Compare commits

...

9 Commits

Author SHA1 Message Date
CodingOnStar
48cba768b1 test: enhance text generation and workflow callbacks tests
- Added comprehensive tests for the useTextGeneration hook, covering handleSend and handleStop functionalities, including file processing and error handling.
- Improved workflow callbacks tests to handle edge cases and ensure graceful handling of missing nodes and undefined tracing.
- Introduced new test cases for validating prompt variables and processing files in the text generation flow.
2026-02-06 18:03:41 +08:00
CodingOnStar
a985eb8725 Merge remote-tracking branch 'origin/main' into refactor/share 2026-02-06 17:32:21 +08:00
CodingOnStar
15d25f8876 feat: enhance text generation result components and add tests
- Removed unused ESLint suppressions from the configuration.
- Introduced new Result and Header components for better organization of text generation results.
- Added comprehensive tests for the new components and hooks to ensure functionality and reliability.
- Updated the useTextGeneration hook to streamline state management and improve performance.
2026-02-06 17:31:42 +08:00
CodingOnStar
cd4a4ed770 refactor: restructure text generation component and introduce new hooks
- Simplified the TextGeneration component by removing unused imports and state management.
- Introduced new hooks for managing app configuration and batch tasks.
- Added HeaderSection and ResultPanel components for better UI organization.
- Implemented tests for new components and hooks to ensure functionality.
- Updated types for improved type safety and clarity.
2026-02-06 16:51:01 +08:00
zyssyz123
2c9430313d fix: redis for api token (#31861)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: hj24 <mambahj24@gmail.com>
2026-02-06 16:25:27 +08:00
QuantumGhost
552ee369b2 chore: update deploy branches for deploy-hitl.yaml (#32051) 2026-02-06 16:14:05 +08:00
Stephen Zhou
d5b9a7b2f8 test: only remove text coverage in CI (#32043) 2026-02-06 16:12:28 +08:00
NeatGuyCoding
c2a3f459c7 fix(api): return proper HTTP 204 status code in DELETE endpoints (#32012)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-06 15:32:52 +08:00
QuantumGhost
4971e11734 perf: use batch delete method instead of single delete (#32036)
Co-authored-by: fatelei <fatelei@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: FFXN <lizy@dify.ai>
2026-02-06 15:12:32 +08:00
52 changed files with 5481 additions and 1178 deletions

View File

@@ -4,7 +4,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "feat/hitl"
- "build/feat/hitl"
types:
- completed
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'feat/hitl'
github.event.workflow_run.head_branch == 'build/feat/hitl'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v1

View File

@@ -122,7 +122,7 @@ These commands assume you start from the repository root.
```bash
cd api
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).

View File

@@ -1155,6 +1155,16 @@ class CeleryScheduleTasksConfig(BaseSettings):
default=0,
)
# API token last_used_at batch update
ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK: bool = Field(
description="Enable periodic batch update of API token last_used_at timestamps",
default=True,
)
API_TOKEN_LAST_USED_UPDATE_INTERVAL: int = Field(
description="Interval in minutes for batch updating API token last_used_at (default 30)",
default=30,
)
# Trigger provider refresh (simple version)
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
description="Enable trigger provider refresh poller",

View File

@@ -10,6 +10,7 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache
from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
@@ -131,6 +132,11 @@ class BaseApiKeyResource(Resource):
if key is None:
flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")
# Invalidate cache before deleting from database
# Type assertion: key is guaranteed to be non-None here because abort() raises
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

View File

@@ -55,6 +55,7 @@ from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
# Register models for flask_restx to avoid dict type issues in Swagger
@@ -820,6 +821,11 @@ class DatasetApiDeleteApi(Resource):
if key is None:
console_ns.abort(404, message="API key not found")
# Invalidate cache before deleting from database
# Type assertion: key is guaranteed to be non-None here because abort() raises
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

View File

@@ -120,7 +120,7 @@ class TagUpdateDeleteApi(Resource):
TagService.delete_tag(tag_id)
return 204
return "", 204
@console_ns.route("/tag-bindings/create")

View File

@@ -396,7 +396,7 @@ class DatasetApi(DatasetApiResource):
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return 204
return "", 204
else:
raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError:
@@ -557,7 +557,7 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)
return 204
return "", 204
@service_api_ns.route("/datasets/tags/binding")
@@ -581,7 +581,7 @@ class DatasetTagBindingApi(DatasetApiResource):
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
return 204
return "", 204
@service_api_ns.route("/datasets/tags/unbinding")
@@ -605,7 +605,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
return 204
return "", 204
@service_api_ns.route("/datasets/<uuid:dataset_id>/tags")

View File

@@ -746,4 +746,4 @@ class DocumentApi(DatasetApiResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return 204
return "", 204

View File

@@ -128,7 +128,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
return 204
return "", 204
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")

View File

@@ -233,7 +233,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
return 204
return "", 204
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
@service_api_ns.doc("update_segment")
@@ -499,7 +499,7 @@ class DatasetChildChunkApi(DatasetApiResource):
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return 204
return "", 204
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
@service_api_ns.doc("update_child_chunk")

View File

@@ -1,27 +1,24 @@
import logging
import time
from collections.abc import Callable
from datetime import timedelta
from enum import StrEnum, auto
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from typing import Concatenate, ParamSpec, TypeVar, cast
from flask import current_app, request
from flask_login import user_logged_in
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache, fetch_token_with_single_flight, record_token_usage
from services.end_user_service import EndUserService
from services.feature_service import FeatureService
@@ -296,7 +293,14 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
def validate_and_get_api_token(scope: str | None = None):
"""
Validate and get API token.
Validate and get API token with Redis caching.
This function uses a two-tier approach:
1. First checks Redis cache for the token
2. If not cached, queries database and caches the result
The last_used_at field is updated asynchronously via Celery task
to avoid blocking the request.
"""
auth_header = request.headers.get("Authorization")
if auth_header is None or " " not in auth_header:
@@ -308,29 +312,18 @@ def validate_and_get_api_token(scope: str | None = None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")
current_time = naive_utc_now()
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
.where(
ApiToken.token == auth_token,
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
ApiToken.type == scope,
)
.values(last_used_at=current_time)
)
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
result = session.execute(update_stmt)
api_token = session.scalar(stmt)
# Try to get token from cache first
# Returns a CachedApiToken (plain Python object), not a SQLAlchemy model
cached_token = ApiTokenCache.get(auth_token, scope)
if cached_token is not None:
logger.debug("Token validation served from cache for scope: %s", scope)
# Record usage in Redis for later batch update (no Celery task per request)
record_token_usage(auth_token, scope)
return cast(ApiToken, cached_token)
if hasattr(result, "rowcount") and result.rowcount > 0:
session.commit()
if not api_token:
raise Unauthorized("Access token is invalid")
return api_token
# Cache miss - use Redis lock for single-flight mode
# This ensures only one request queries DB for the same token concurrently
return fetch_token_with_single_flight(auth_token, scope)
class DatasetApiResource(Resource):

View File

@@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"

View File

@@ -184,6 +184,14 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh",
"schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL),
}
if dify_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK:
imports.append("schedule.update_api_token_last_used_task")
beat_schedule["batch_update_api_token_last_used"] = {
"task": "schedule.update_api_token_last_used_task.batch_update_api_token_last_used",
"schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL),
}
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app

View File

@@ -0,0 +1,114 @@
"""
Scheduled task to batch-update API token last_used_at timestamps.
Instead of updating the database on every request, token usage is recorded
in Redis as lightweight SET keys (api_token_active:{scope}:{token}).
This task runs periodically (default every 30 minutes) to flush those
records into the database in a single batch operation.
"""
import logging
import time
from datetime import datetime
import click
from sqlalchemy import update
from sqlalchemy.orm import Session
import app
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import ApiToken
from services.api_token_service import ACTIVE_TOKEN_KEY_PREFIX
logger = logging.getLogger(__name__)
@app.celery.task(queue="api_token")
def batch_update_api_token_last_used():
    """
    Batch update last_used_at for all recently active API tokens.

    Scans Redis for api_token_active:* keys, parses the token and scope
    from each key, and updates each matching row in its own short
    transaction. Processed keys — and now also malformed or invalid keys —
    are removed from Redis so they are not rescanned on the next run.
    """
    click.echo(click.style("batch_update_api_token_last_used: start.", fg="green"))
    start_at = time.perf_counter()
    updated_count = 0
    scanned_count = 0
    try:
        # Collect all active token keys and their values (the actual usage timestamps)
        token_entries: list[tuple[str, str | None, datetime]] = []  # (token, scope, usage_time)
        keys_to_delete: list[str | bytes] = []
        for key in redis_client.scan_iter(match=f"{ACTIVE_TOKEN_KEY_PREFIX}*", count=200):
            if isinstance(key, bytes):
                key = key.decode("utf-8")
            scanned_count += 1
            # Read the value (ISO timestamp recorded at actual request time)
            value = redis_client.get(key)
            if not value:
                keys_to_delete.append(key)
                continue
            if isinstance(value, bytes):
                value = value.decode("utf-8")
            try:
                usage_time = datetime.fromisoformat(value)
            except (ValueError, TypeError):
                logger.warning("Invalid timestamp in key %s: %s", key, value)
                keys_to_delete.append(key)
                continue
            # Parse token info from key: api_token_active:{scope}:{token}
            suffix = key[len(ACTIVE_TOKEN_KEY_PREFIX) :]
            parts = suffix.split(":", 1)
            if len(parts) == 2:
                scope_str, token = parts
                scope = None if scope_str == "None" else scope_str
                token_entries.append((token, scope, usage_time))
                keys_to_delete.append(key)
            else:
                # Fix: malformed keys were previously skipped without cleanup,
                # so they were rescanned on every run until their TTL expired.
                logger.warning("Malformed active-token key, deleting: %s", key)
                keys_to_delete.append(key)
        if not token_entries:
            click.echo(click.style("batch_update_api_token_last_used: no active tokens found.", fg="yellow"))
            # Still clean up any invalid keys
            if keys_to_delete:
                redis_client.delete(*keys_to_delete)
            return
        # Update each token in its own short transaction to avoid long transactions
        for token, scope, usage_time in token_entries:
            with Session(db.engine, expire_on_commit=False) as session, session.begin():
                stmt = (
                    update(ApiToken)
                    .where(
                        ApiToken.token == token,
                        ApiToken.type == scope,
                        # Only move last_used_at forward; never overwrite a newer value.
                        (ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < usage_time)),
                    )
                    .values(last_used_at=usage_time)
                )
                result = session.execute(stmt)
                rowcount = getattr(result, "rowcount", 0)
                if rowcount > 0:
                    updated_count += 1
        # Delete processed keys from Redis
        if keys_to_delete:
            redis_client.delete(*keys_to_delete)
    except Exception:
        logger.exception("batch_update_api_token_last_used failed")
    elapsed = time.perf_counter() - start_at
    click.echo(
        click.style(
            f"batch_update_api_token_last_used: done. "
            f"scanned={scanned_count}, updated={updated_count}, elapsed={elapsed:.2f}s",
            fg="green",
        )
    )

View File

@@ -0,0 +1,330 @@
"""
API Token Service
Handles all API token caching, validation, and usage recording.
Includes Redis cache operations, database queries, and single-flight concurrency control.
"""
import logging
from datetime import datetime
from typing import Any
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
from libs.datetime_utils import naive_utc_now
from models.model import ApiToken
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------
# Pydantic DTO
# ---------------------------------------------------------------------
class CachedApiToken(BaseModel):
    """
    Read-only snapshot of an ApiToken row, suitable for Redis caching.

    NOTE: this is a plain Pydantic model, not a SQLAlchemy entity; it only
    mirrors the attribute names that token-validation code reads.
    """

    id: str
    app_id: str | None
    tenant_id: str | None
    type: str
    token: str
    last_used_at: datetime | None
    created_at: datetime | None

    def __repr__(self) -> str:
        # Compact debug form: identifies the row without dumping every field.
        return "<CachedApiToken id={} type={}>".format(self.id, self.type)
# ---------------------------------------------------------------------
# Cache configuration
# ---------------------------------------------------------------------
CACHE_KEY_PREFIX = "api_token"  # positive/negative cache keys: api_token:{scope|any}:{token}
CACHE_TTL_SECONDS = 600  # 10 minutes for valid tokens
CACHE_NULL_TTL_SECONDS = 60  # 1 minute negative cache for non-existent tokens
ACTIVE_TOKEN_KEY_PREFIX = "api_token_active:"  # usage-recording keys scanned by the batch task
# ---------------------------------------------------------------------
# Cache class
# ---------------------------------------------------------------------
class ApiTokenCache:
    """
    Redis cache wrapper for API tokens.

    Handles serialization, deserialization, and cache invalidation. Valid
    tokens are stored as CachedApiToken JSON; unknown tokens are negative-
    cached with a literal ``null`` payload and a shorter TTL. A per-tenant
    Redis SET index allows bulk invalidation of a tenant's tokens.
    """

    @staticmethod
    def make_active_key(token: str, scope: str | None = None) -> str:
        """Generate Redis key for recording token usage."""
        return f"{ACTIVE_TOKEN_KEY_PREFIX}{scope}:{token}"

    @staticmethod
    def _make_tenant_index_key(tenant_id: str) -> str:
        """Generate Redis key for tenant token index."""
        return f"tenant_tokens:{tenant_id}"

    @staticmethod
    def _make_cache_key(token: str, scope: str | None = None) -> str:
        """Generate cache key for the given token and scope."""
        scope_str = scope or "any"
        return f"{CACHE_KEY_PREFIX}:{scope_str}:{token}"

    @staticmethod
    def _serialize_token(api_token: Any) -> bytes:
        """Serialize an ApiToken (or CachedApiToken) object to JSON bytes."""
        if isinstance(api_token, CachedApiToken):
            return api_token.model_dump_json().encode("utf-8")
        cached = CachedApiToken(
            id=str(api_token.id),
            app_id=str(api_token.app_id) if api_token.app_id else None,
            tenant_id=str(api_token.tenant_id) if api_token.tenant_id else None,
            type=api_token.type,
            token=api_token.token,
            last_used_at=api_token.last_used_at,
            created_at=api_token.created_at,
        )
        return cached.model_dump_json().encode("utf-8")

    @staticmethod
    def _deserialize_token(cached_data: bytes | str) -> Any:
        """Deserialize JSON bytes/string back to a CachedApiToken Pydantic model.

        Returns None for the negative-cache sentinel or any malformed payload.
        """
        if cached_data in {b"null", "null"}:
            return None
        try:
            if isinstance(cached_data, bytes):
                cached_data = cached_data.decode("utf-8")
            return CachedApiToken.model_validate_json(cached_data)
        except Exception as e:
            # Fix: was `except (ValueError, Exception)`, which is redundant —
            # Exception already subsumes ValueError. Any failure is a cache miss.
            logger.warning("Failed to deserialize token from cache: %s", e)
            return None

    @staticmethod
    @redis_fallback(default_return=None)
    def get(token: str, scope: str | None) -> Any | None:
        """Get API token from cache; None on miss or negative entry."""
        cache_key = ApiTokenCache._make_cache_key(token, scope)
        cached_data = redis_client.get(cache_key)
        if cached_data is None:
            logger.debug("Cache miss for token key: %s", cache_key)
            return None
        logger.debug("Cache hit for token key: %s", cache_key)
        return ApiTokenCache._deserialize_token(cached_data)

    @staticmethod
    def _add_to_tenant_index(tenant_id: str | None, cache_key: str) -> None:
        """Add cache key to tenant index for efficient invalidation."""
        if not tenant_id:
            return
        try:
            index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
            redis_client.sadd(index_key, cache_key)
            # Index outlives the entries slightly so cleanup can still find them.
            redis_client.expire(index_key, CACHE_TTL_SECONDS + 60)
        except Exception as e:
            logger.warning("Failed to update tenant index: %s", e)

    @staticmethod
    def _remove_from_tenant_index(tenant_id: str | None, cache_key: str) -> None:
        """Remove cache key from tenant index."""
        if not tenant_id:
            return
        try:
            index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
            redis_client.srem(index_key, cache_key)
        except Exception as e:
            logger.warning("Failed to remove from tenant index: %s", e)

    @staticmethod
    @redis_fallback(default_return=False)
    def set(token: str, scope: str | None, api_token: Any | None, ttl: int = CACHE_TTL_SECONDS) -> bool:
        """Set API token in cache; api_token=None stores a short-lived negative entry."""
        cache_key = ApiTokenCache._make_cache_key(token, scope)
        if api_token is None:
            cached_value = b"null"
            ttl = CACHE_NULL_TTL_SECONDS
        else:
            cached_value = ApiTokenCache._serialize_token(api_token)
        try:
            redis_client.setex(cache_key, ttl, cached_value)
            if api_token is not None and hasattr(api_token, "tenant_id"):
                ApiTokenCache._add_to_tenant_index(api_token.tenant_id, cache_key)
            logger.debug("Cached token with key: %s, ttl: %ss", cache_key, ttl)
            return True
        except Exception as e:
            logger.warning("Failed to cache token: %s", e)
            return False

    @staticmethod
    @redis_fallback(default_return=False)
    def delete(token: str, scope: str | None = None) -> bool:
        """Delete API token from cache.

        With scope=None, deletes the token's entries for every scope via a
        SCAN; otherwise deletes the single scoped entry and its tenant-index
        reference.
        """
        if scope is None:
            pattern = f"{CACHE_KEY_PREFIX}:*:{token}"
            try:
                keys_to_delete = list(redis_client.scan_iter(match=pattern))
                if keys_to_delete:
                    redis_client.delete(*keys_to_delete)
                    logger.info("Deleted %d cache entries for token", len(keys_to_delete))
                return True
            except Exception as e:
                logger.warning("Failed to delete token cache with pattern: %s", e)
                return False
        else:
            cache_key = ApiTokenCache._make_cache_key(token, scope)
            try:
                tenant_id = None
                try:
                    # Best-effort read of the entry to learn its tenant for
                    # index cleanup; failures only skip the index removal.
                    cached_data = redis_client.get(cache_key)
                    if cached_data and cached_data != b"null":
                        cached_token = ApiTokenCache._deserialize_token(cached_data)
                        if cached_token:
                            tenant_id = cached_token.tenant_id
                except Exception as e:
                    logger.debug("Failed to get tenant_id for cache cleanup: %s", e)
                redis_client.delete(cache_key)
                if tenant_id:
                    ApiTokenCache._remove_from_tenant_index(tenant_id, cache_key)
                logger.info("Deleted cache for key: %s", cache_key)
                return True
            except Exception as e:
                logger.warning("Failed to delete token cache: %s", e)
                return False

    @staticmethod
    @redis_fallback(default_return=False)
    def invalidate_by_tenant(tenant_id: str) -> bool:
        """Invalidate all API token caches for a specific tenant via tenant index."""
        try:
            index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
            cache_keys = redis_client.smembers(index_key)
            if cache_keys:
                deleted_count = 0
                for cache_key in cache_keys:
                    if isinstance(cache_key, bytes):
                        cache_key = cache_key.decode("utf-8")
                    redis_client.delete(cache_key)
                    deleted_count += 1
                redis_client.delete(index_key)
                logger.info(
                    "Invalidated %d token cache entries for tenant: %s",
                    deleted_count,
                    tenant_id,
                )
            else:
                logger.info(
                    "No tenant index found for %s, relying on TTL expiration",
                    tenant_id,
                )
            return True
        except Exception as e:
            logger.warning("Failed to invalidate tenant token cache: %s", e)
            return False
# ---------------------------------------------------------------------
# Token usage recording (for batch update)
# ---------------------------------------------------------------------
def record_token_usage(auth_token: str, scope: str | None) -> None:
    """
    Record token usage in Redis for later batch update by a scheduled job.

    No Celery task is dispatched per request; we only SET a small Redis key
    holding the usage timestamp. A Celery Beat task periodically scans these
    keys and batch-updates last_used_at in the database.
    """
    try:
        active_key = ApiTokenCache.make_active_key(auth_token, scope)
        # 1h TTL: stale entries self-expire even if the batch task is down.
        redis_client.set(active_key, naive_utc_now().isoformat(), ex=3600)
    except Exception as e:
        logger.warning("Failed to record token usage: %s", e)
# ---------------------------------------------------------------------
# Database query + single-flight
# ---------------------------------------------------------------------
def query_token_from_db(auth_token: str, scope: str | None) -> ApiToken:
    """
    Look up the API token in the database, caching the outcome.

    Caches a short-lived negative entry and raises Unauthorized when the
    token does not exist; otherwise caches the row, records its usage, and
    returns it.
    """
    lookup = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
    with Session(db.engine, expire_on_commit=False) as session:
        token_row = session.scalar(lookup)
        if token_row is None:
            # Negative-cache the miss so repeated bad tokens don't hammer the DB.
            ApiTokenCache.set(auth_token, scope, None)
            raise Unauthorized("Access token is invalid")
        ApiTokenCache.set(auth_token, scope, token_row)
        record_token_usage(auth_token, scope)
        return token_row
def fetch_token_with_single_flight(auth_token: str, scope: str | None) -> ApiToken | Any:
    """
    Fetch token from DB with single-flight pattern using a Redis lock.

    Ensures only one concurrent request queries the database for the same
    token; waiters re-check the cache after acquiring the lock. Falls back
    to a direct query if the lock cannot be acquired or Redis fails.
    """
    logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope)
    lock_key = f"api_token_query_lock:{scope}:{auth_token}"
    lock = redis_client.lock(lock_key, timeout=10, blocking_timeout=5)
    try:
        if lock.acquire(blocking=True):
            try:
                # Another request may have populated the cache while we waited.
                cached_token = ApiTokenCache.get(auth_token, scope)
                if cached_token is not None:
                    logger.debug("Token cached by concurrent request, using cached version")
                    return cached_token
                return query_token_from_db(auth_token, scope)
            finally:
                # Fix: a failed release (e.g. the 10s lock timeout elapsed and
                # the lock was re-acquired elsewhere) previously propagated into
                # the outer `except Exception`, discarding a successful DB
                # result and triggering a redundant second query.
                try:
                    lock.release()
                except Exception as release_error:
                    logger.warning("Failed to release token query lock: %s", release_error)
        else:
            logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10])
            return query_token_from_db(auth_token, scope)
    except Unauthorized:
        raise
    except Exception as e:
        logger.warning("Redis lock failed for token query: %s, proceeding anyway", e)
        return query_token_from_db(auth_token, scope)

View File

@@ -14,6 +14,9 @@ from models.model import UploadFile
logger = logging.getLogger(__name__)
# Batch size for database operations to keep transactions short
BATCH_SIZE = 1000
@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
@@ -31,63 +34,179 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
if not doc_form:
raise ValueError("doc_form is required")
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
).delete(synchronize_session=False)
storage_keys_to_delete: list[str] = []
index_node_ids: list[str] = []
segment_ids: list[str] = []
total_image_upload_file_ids: list[str] = []
try:
# ============ Step 1: Query segment and file data (short read-only transaction) ============
with session_factory.create_session() as session:
# Get segments info
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
segment_ids = [segment.id for segment in segments]
# Collect image file IDs from segment content
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
for image_file in image_files:
try:
if image_file and image_file.key:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(stmt)
session.delete(segment)
total_image_upload_file_ids.extend(image_upload_file_ids)
# Query storage keys for image files
if total_image_upload_file_ids:
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(total_image_upload_file_ids))
).all()
storage_keys_to_delete.extend([f.key for f in image_files if f and f.key])
# Query storage keys for document files
if file_ids:
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(stmt)
storage_keys_to_delete.extend([f.key for f in files if f and f.key])
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
fg="green",
# ============ Step 2: Clean vector index (external service, fresh session for dataset) ============
if index_node_ids:
try:
# Fetch dataset in a fresh session to avoid DetachedInstanceError
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.warning("Dataset not found for vector index cleanup, dataset_id: %s", dataset_id)
else:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
except Exception:
logger.exception(
"Failed to clean vector index for dataset_id: %s, document_ids: %s, index_node_ids count: %d",
dataset_id,
document_ids,
len(index_node_ids),
)
)
# ============ Step 3: Delete metadata binding (separate short transaction) ============
try:
with session_factory.create_session() as session:
deleted_count = (
session.query(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
)
.delete(synchronize_session=False)
)
session.commit()
logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id)
except Exception:
logger.exception("Cleaned documents when documents deleted failed")
logger.exception(
"Failed to delete metadata bindings for dataset_id: %s, document_ids: %s",
dataset_id,
document_ids,
)
# ============ Step 4: Batch delete UploadFile records (multiple short transactions) ============
if total_image_upload_file_ids:
failed_batches = 0
total_batches = (len(total_image_upload_file_ids) + BATCH_SIZE - 1) // BATCH_SIZE
for i in range(0, len(total_image_upload_file_ids), BATCH_SIZE):
batch = total_image_upload_file_ids[i : i + BATCH_SIZE]
try:
with session_factory.create_session() as session:
stmt = delete(UploadFile).where(UploadFile.id.in_(batch))
session.execute(stmt)
session.commit()
except Exception:
failed_batches += 1
logger.exception(
"Failed to delete image UploadFile batch %d-%d for dataset_id: %s",
i,
i + len(batch),
dataset_id,
)
if failed_batches > 0:
logger.warning(
"Image UploadFile deletion: %d/%d batches failed for dataset_id: %s",
failed_batches,
total_batches,
dataset_id,
)
# ============ Step 5: Batch delete DocumentSegment records (multiple short transactions) ============
if segment_ids:
failed_batches = 0
total_batches = (len(segment_ids) + BATCH_SIZE - 1) // BATCH_SIZE
for i in range(0, len(segment_ids), BATCH_SIZE):
batch = segment_ids[i : i + BATCH_SIZE]
try:
with session_factory.create_session() as session:
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(batch))
session.execute(segment_delete_stmt)
session.commit()
except Exception:
failed_batches += 1
logger.exception(
"Failed to delete DocumentSegment batch %d-%d for dataset_id: %s, document_ids: %s",
i,
i + len(batch),
dataset_id,
document_ids,
)
if failed_batches > 0:
logger.warning(
"DocumentSegment deletion: %d/%d batches failed, document_ids: %s",
failed_batches,
total_batches,
document_ids,
)
# ============ Step 6: Delete document-associated files (separate short transaction) ============
if file_ids:
try:
with session_factory.create_session() as session:
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(stmt)
session.commit()
except Exception:
logger.exception(
"Failed to delete document UploadFile records for dataset_id: %s, file_ids: %s",
dataset_id,
file_ids,
)
# ============ Step 7: Delete storage files (I/O operations, no DB transaction) ============
storage_delete_failures = 0
for storage_key in storage_keys_to_delete:
try:
storage.delete(storage_key)
except Exception:
storage_delete_failures += 1
logger.exception("Failed to delete file from storage, key: %s", storage_key)
if storage_delete_failures > 0:
logger.warning(
"Storage file deletion completed with %d failures out of %d total files for dataset_id: %s",
storage_delete_failures,
len(storage_keys_to_delete),
dataset_id,
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at:.2f}s, "
f"dataset_id: {dataset_id}, document_ids: {document_ids}, "
f"segments: {len(segment_ids)}, image_files: {len(total_image_upload_file_ids)}, "
f"storage_files: {len(storage_keys_to_delete)}",
fg="green",
)
)
except Exception:
logger.exception(
"Batch clean documents failed for dataset_id: %s, document_ids: %s",
dataset_id,
document_ids,
)

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import delete
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -67,8 +68,14 @@ def delete_segment_from_index_task(
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
for binding in segment_attachment_bindings:
session.delete(binding)
segment_attachment_bind_ids = [i.id for i in segment_attachment_bindings]
for i in range(0, len(segment_attachment_bind_ids), 1000):
segment_attachment_bind_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(segment_attachment_bind_ids[i : i + 1000])
)
session.execute(segment_attachment_bind_delete_stmt)
# delete upload file
session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
session.commit()

View File

@@ -28,7 +28,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
start_at = time.perf_counter()
with session_factory.create_session() as session:
with session_factory.create_session() as session, session.begin():
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
@@ -68,7 +68,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
document.indexing_status = "error"
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
session.commit()
return
loader = NotionExtractor(
@@ -85,7 +84,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
if last_edited_time != page_edited_time:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.commit()
# delete all document segment and index
try:

View File

@@ -48,6 +48,7 @@ from models.workflow import (
WorkflowArchiveLog,
)
from repositories.factory import DifyAPIRepositoryFactory
from services.api_token_service import ApiTokenCache
logger = logging.getLogger(__name__)
@@ -134,6 +135,12 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(session, api_token_id: str):
# Fetch token details for cache invalidation
token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
if token_obj:
# Invalidate cache before deletion
ApiTokenCache.delete(token_obj.token, token_obj.type)
session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
_delete_records(

View File

@@ -0,0 +1,375 @@
"""
Integration tests for API Token Cache with Redis.
These tests require:
- Redis server running
- Test database configured
"""
import time
from datetime import datetime, timedelta
from unittest.mock import patch
import pytest
from extensions.ext_redis import redis_client
from models.model import ApiToken
from services.api_token_service import ApiTokenCache, CachedApiToken
class TestApiTokenCacheRedisIntegration:
    """Integration tests with real Redis."""

    def setup_method(self):
        """Setup test fixtures and clean Redis."""
        self.test_token = "test-integration-token-123"
        self.test_scope = "app"
        # Key layout mirrors the cache's own scheme: "api_token:<scope>:<token>".
        self.cache_key = f"api_token:{self.test_scope}:{self.test_token}"
        # Clean up any existing test data
        self._cleanup()

    def teardown_method(self):
        """Cleanup test data from Redis."""
        self._cleanup()

    def _cleanup(self):
        """Remove test data from Redis."""
        try:
            redis_client.delete(self.cache_key)
            redis_client.delete(ApiTokenCache._make_tenant_index_key("test-tenant-id"))
            redis_client.delete(ApiTokenCache.make_active_key(self.test_token, self.test_scope))
        except Exception:
            pass  # Ignore cleanup errors

    def test_cache_set_and_get_with_real_redis(self):
        """Test cache set and get operations with real Redis."""
        from unittest.mock import MagicMock

        # Stand-in token with every attribute the cache serializes.
        mock_token = MagicMock()
        mock_token.id = "test-id-123"
        mock_token.app_id = "test-app-456"
        mock_token.tenant_id = "test-tenant-789"
        mock_token.type = "app"
        mock_token.token = self.test_token
        mock_token.last_used_at = datetime.now()
        mock_token.created_at = datetime.now() - timedelta(days=30)
        # Set in cache
        result = ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
        assert result is True
        # Verify in Redis
        cached_data = redis_client.get(self.cache_key)
        assert cached_data is not None
        # Get from cache — should round-trip into a CachedApiToken value object.
        cached_token = ApiTokenCache.get(self.test_token, self.test_scope)
        assert cached_token is not None
        assert isinstance(cached_token, CachedApiToken)
        assert cached_token.id == "test-id-123"
        assert cached_token.app_id == "test-app-456"
        assert cached_token.tenant_id == "test-tenant-789"
        assert cached_token.type == "app"
        assert cached_token.token == self.test_token

    def test_cache_ttl_with_real_redis(self):
        """Test cache TTL is set correctly."""
        from unittest.mock import MagicMock

        mock_token = MagicMock()
        mock_token.id = "test-id"
        mock_token.app_id = "test-app"
        mock_token.tenant_id = "test-tenant"
        mock_token.type = "app"
        mock_token.token = self.test_token
        mock_token.last_used_at = None
        mock_token.created_at = datetime.now()
        ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
        ttl = redis_client.ttl(self.cache_key)
        assert 595 <= ttl <= 600  # Should be around 600 seconds (10 minutes)

    def test_cache_null_value_for_invalid_token(self):
        """Test caching null value for invalid tokens."""
        # None is stored as the literal b"null" sentinel (cache-penetration guard).
        result = ApiTokenCache.set(self.test_token, self.test_scope, None)
        assert result is True
        cached_data = redis_client.get(self.cache_key)
        assert cached_data == b"null"
        cached_token = ApiTokenCache.get(self.test_token, self.test_scope)
        assert cached_token is None
        # Null entries carry a shorter TTL than real tokens.
        ttl = redis_client.ttl(self.cache_key)
        assert 55 <= ttl <= 60

    def test_cache_delete_with_real_redis(self):
        """Test cache deletion with real Redis."""
        from unittest.mock import MagicMock

        mock_token = MagicMock()
        mock_token.id = "test-id"
        mock_token.app_id = "test-app"
        mock_token.tenant_id = "test-tenant"
        mock_token.type = "app"
        mock_token.token = self.test_token
        mock_token.last_used_at = None
        mock_token.created_at = datetime.now()
        ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
        assert redis_client.exists(self.cache_key) == 1
        result = ApiTokenCache.delete(self.test_token, self.test_scope)
        assert result is True
        assert redis_client.exists(self.cache_key) == 0

    def test_tenant_index_creation(self):
        """Test tenant index is created when caching token."""
        from unittest.mock import MagicMock

        tenant_id = "test-tenant-id"
        mock_token = MagicMock()
        mock_token.id = "test-id"
        mock_token.app_id = "test-app"
        mock_token.tenant_id = tenant_id
        mock_token.type = "app"
        mock_token.token = self.test_token
        mock_token.last_used_at = None
        mock_token.created_at = datetime.now()
        ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
        # Caching a token should also register its cache key in the tenant's index set.
        index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
        assert redis_client.exists(index_key) == 1
        members = redis_client.smembers(index_key)
        # smembers may return bytes depending on the client's decode_responses setting.
        cache_keys = [m.decode("utf-8") if isinstance(m, bytes) else m for m in members]
        assert self.cache_key in cache_keys

    def test_invalidate_by_tenant_via_index(self):
        """Test tenant-wide cache invalidation using index (fast path)."""
        from unittest.mock import MagicMock

        tenant_id = "test-tenant-id"
        # Cache three tokens under the same tenant.
        for i in range(3):
            token_value = f"test-token-{i}"
            mock_token = MagicMock()
            mock_token.id = f"test-id-{i}"
            mock_token.app_id = "test-app"
            mock_token.tenant_id = tenant_id
            mock_token.type = "app"
            mock_token.token = token_value
            mock_token.last_used_at = None
            mock_token.created_at = datetime.now()
            ApiTokenCache.set(token_value, "app", mock_token)
        for i in range(3):
            key = f"api_token:app:test-token-{i}"
            assert redis_client.exists(key) == 1
        result = ApiTokenCache.invalidate_by_tenant(tenant_id)
        assert result is True
        # All tenant entries and the index set itself should be gone.
        for i in range(3):
            key = f"api_token:app:test-token-{i}"
            assert redis_client.exists(key) == 0
        assert redis_client.exists(ApiTokenCache._make_tenant_index_key(tenant_id)) == 0

    def test_concurrent_cache_access(self):
        """Test concurrent cache access doesn't cause issues."""
        import concurrent.futures
        from unittest.mock import MagicMock

        mock_token = MagicMock()
        mock_token.id = "test-id"
        mock_token.app_id = "test-app"
        mock_token.tenant_id = "test-tenant"
        mock_token.type = "app"
        mock_token.token = self.test_token
        mock_token.last_used_at = None
        mock_token.created_at = datetime.now()
        ApiTokenCache.set(self.test_token, self.test_scope, mock_token)

        def get_from_cache():
            return ApiTokenCache.get(self.test_token, self.test_scope)

        # 50 parallel reads should all succeed and deserialize correctly.
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(get_from_cache) for _ in range(50)]
            results = [f.result() for f in concurrent.futures.as_completed(futures)]
        assert len(results) == 50
        assert all(r is not None for r in results)
        assert all(isinstance(r, CachedApiToken) for r in results)
class TestTokenUsageRecording:
    """Tests for recording token usage in Redis (batch update approach)."""

    def setup_method(self):
        # Active-usage marker key for the token under test.
        self.test_token = "test-usage-token"
        self.test_scope = "app"
        self.active_key = ApiTokenCache.make_active_key(self.test_token, self.test_scope)

    def teardown_method(self):
        try:
            redis_client.delete(self.active_key)
        except Exception:
            pass

    def test_record_token_usage_sets_redis_key(self):
        """Test that record_token_usage writes an active key to Redis."""
        from services.api_token_service import record_token_usage

        record_token_usage(self.test_token, self.test_scope)
        # Key should exist
        assert redis_client.exists(self.active_key) == 1
        # Value should be an ISO timestamp
        value = redis_client.get(self.active_key)
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        datetime.fromisoformat(value)  # Should not raise

    def test_record_token_usage_has_ttl(self):
        """Test that active keys have a TTL as safety net."""
        from services.api_token_service import record_token_usage

        record_token_usage(self.test_token, self.test_scope)
        ttl = redis_client.ttl(self.active_key)
        assert 3595 <= ttl <= 3600  # ~1 hour

    def test_record_token_usage_overwrites(self):
        """Test that repeated calls overwrite the same key (no accumulation)."""
        from services.api_token_service import record_token_usage

        record_token_usage(self.test_token, self.test_scope)
        first_value = redis_client.get(self.active_key)
        time.sleep(0.01)  # Tiny delay so timestamp differs
        record_token_usage(self.test_token, self.test_scope)
        second_value = redis_client.get(self.active_key)
        # Key count should still be 1 (overwritten, not accumulated)
        assert redis_client.exists(self.active_key) == 1
        # Fix: the original captured both values but never compared them, so the
        # "overwrite" behavior was not actually asserted. The ISO timestamps
        # include sub-second precision, so they must differ across the sleep.
        assert second_value != first_value
class TestEndToEndCacheFlow:
    """End-to-end integration test for complete cache flow."""

    # NOTE(review): the usefixtures mark looks redundant here since db_session is
    # also declared as a parameter; pytest injects it either way — confirm intent.
    @pytest.mark.usefixtures("db_session")
    def test_complete_flow_cache_miss_then_hit(self, db_session):
        """
        Test complete flow:
        1. First request (cache miss) -> query DB -> cache result
        2. Second request (cache hit) -> return from cache
        3. Verify Redis state
        """
        test_token_value = "test-e2e-token"
        test_scope = "app"
        # Persist a real ApiToken row so the cache round-trips DB-backed data.
        test_token = ApiToken()
        test_token.id = "test-e2e-id"
        test_token.token = test_token_value
        test_token.type = test_scope
        test_token.app_id = "test-app"
        test_token.tenant_id = "test-tenant"
        test_token.last_used_at = None
        test_token.created_at = datetime.now()
        db_session.add(test_token)
        db_session.commit()
        try:
            # Step 1: Cache miss - set token in cache
            ApiTokenCache.set(test_token_value, test_scope, test_token)
            cache_key = f"api_token:{test_scope}:{test_token_value}"
            assert redis_client.exists(cache_key) == 1
            # Step 2: Cache hit - get from cache
            cached_token = ApiTokenCache.get(test_token_value, test_scope)
            assert cached_token is not None
            assert cached_token.id == test_token.id
            assert cached_token.token == test_token_value
            # Step 3: Verify tenant index
            index_key = ApiTokenCache._make_tenant_index_key(test_token.tenant_id)
            assert redis_client.exists(index_key) == 1
            assert cache_key.encode() in redis_client.smembers(index_key)
            # Step 4: Delete and verify cleanup
            ApiTokenCache.delete(test_token_value, test_scope)
            assert redis_client.exists(cache_key) == 0
            assert cache_key.encode() not in redis_client.smembers(index_key)
        finally:
            # Always remove the DB row and any Redis keys this test created.
            db_session.delete(test_token)
            db_session.commit()
            redis_client.delete(f"api_token:{test_scope}:{test_token_value}")
            redis_client.delete(ApiTokenCache._make_tenant_index_key(test_token.tenant_id))

    def test_high_concurrency_simulation(self):
        """Simulate high concurrency access to cache."""
        import concurrent.futures
        from unittest.mock import MagicMock

        test_token_value = "test-concurrent-token"
        test_scope = "app"
        mock_token = MagicMock()
        mock_token.id = "concurrent-id"
        mock_token.app_id = "test-app"
        mock_token.tenant_id = "test-tenant"
        mock_token.type = test_scope
        mock_token.token = test_token_value
        mock_token.last_used_at = datetime.now()
        mock_token.created_at = datetime.now()
        ApiTokenCache.set(test_token_value, test_scope, mock_token)
        try:

            def read_cache():
                return ApiTokenCache.get(test_token_value, test_scope)

            start_time = time.time()
            with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
                futures = [executor.submit(read_cache) for _ in range(100)]
                results = [f.result() for f in concurrent.futures.as_completed(futures)]
            elapsed = time.time() - start_time
            assert len(results) == 100
            assert all(r is not None for r in results)
            # Loose latency bound: 100 cached reads should finish well under a second.
            assert elapsed < 1.0, f"Too slow: {elapsed}s for 100 cache reads"
        finally:
            ApiTokenCache.delete(test_token_value, test_scope)
            redis_client.delete(ApiTokenCache._make_tenant_index_key(mock_token.tenant_id))
class TestRedisFailover:
    """Test behavior when Redis is unavailable."""

    @patch("services.api_token_service.redis_client")
    def test_graceful_degradation_when_redis_fails(self, mock_redis):
        """Test system degrades gracefully when Redis is unavailable."""
        from redis import RedisError

        # Simulate a full Redis outage on both the read and the write path.
        outage = RedisError("Connection failed")
        mock_redis.get.side_effect = outage
        mock_redis.setex.side_effect = outage
        # Reads degrade to "not cached" instead of raising.
        assert ApiTokenCache.get("test-token", "app") is None
        # Writes report failure instead of raising.
        assert ApiTokenCache.set("test-token", "app", None) is False

View File

@@ -132,6 +132,8 @@ class TestCelerySSLConfiguration:
mock_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK = 0
mock_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK = False
mock_config.TRIGGER_PROVIDER_REFRESH_INTERVAL = 15
mock_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK = False
mock_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL = 30
with patch("extensions.ext_celery.dify_config", mock_config):
from dify_app import DifyApp

View File

@@ -0,0 +1,250 @@
"""
Unit tests for API Token Cache module.
"""
import json
from datetime import datetime
from unittest.mock import MagicMock, patch
from services.api_token_service import (
CACHE_KEY_PREFIX,
CACHE_NULL_TTL_SECONDS,
CACHE_TTL_SECONDS,
ApiTokenCache,
CachedApiToken,
)
class TestApiTokenCache:
    """Test cases for ApiTokenCache class."""

    def setup_method(self):
        """Setup test fixtures."""
        # Stand-in ApiToken with every attribute the cache serializes.
        self.mock_token = MagicMock()
        self.mock_token.id = "test-token-id-123"
        self.mock_token.app_id = "test-app-id-456"
        self.mock_token.tenant_id = "test-tenant-id-789"
        self.mock_token.type = "app"
        self.mock_token.token = "test-token-value-abc"
        self.mock_token.last_used_at = datetime(2026, 2, 3, 10, 0, 0)
        self.mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)

    def test_make_cache_key(self):
        """Test cache key generation."""
        # Test with scope
        key = ApiTokenCache._make_cache_key("my-token", "app")
        assert key == f"{CACHE_KEY_PREFIX}:app:my-token"
        # Test without scope — a None scope maps to the literal "any" segment.
        key = ApiTokenCache._make_cache_key("my-token", None)
        assert key == f"{CACHE_KEY_PREFIX}:any:my-token"

    def test_serialize_token(self):
        """Test token serialization."""
        serialized = ApiTokenCache._serialize_token(self.mock_token)
        data = json.loads(serialized)
        assert data["id"] == "test-token-id-123"
        assert data["app_id"] == "test-app-id-456"
        assert data["tenant_id"] == "test-tenant-id-789"
        assert data["type"] == "app"
        assert data["token"] == "test-token-value-abc"
        # datetimes serialize as ISO-8601 strings.
        assert data["last_used_at"] == "2026-02-03T10:00:00"
        assert data["created_at"] == "2026-01-01T00:00:00"

    def test_serialize_token_with_nulls(self):
        """Test token serialization with None values."""
        mock_token = MagicMock()
        mock_token.id = "test-id"
        mock_token.app_id = None
        mock_token.tenant_id = None
        mock_token.type = "dataset"
        mock_token.token = "test-token"
        mock_token.last_used_at = None
        mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)
        serialized = ApiTokenCache._serialize_token(mock_token)
        data = json.loads(serialized)
        # None attributes survive the round trip as JSON null.
        assert data["app_id"] is None
        assert data["tenant_id"] is None
        assert data["last_used_at"] is None

    def test_deserialize_token(self):
        """Test token deserialization."""
        cached_data = json.dumps(
            {
                "id": "test-id",
                "app_id": "test-app",
                "tenant_id": "test-tenant",
                "type": "app",
                "token": "test-token",
                "last_used_at": "2026-02-03T10:00:00",
                "created_at": "2026-01-01T00:00:00",
            }
        )
        result = ApiTokenCache._deserialize_token(cached_data)
        assert isinstance(result, CachedApiToken)
        assert result.id == "test-id"
        assert result.app_id == "test-app"
        assert result.tenant_id == "test-tenant"
        assert result.type == "app"
        assert result.token == "test-token"
        # ISO strings parse back into datetime objects.
        assert result.last_used_at == datetime(2026, 2, 3, 10, 0, 0)
        assert result.created_at == datetime(2026, 1, 1, 0, 0, 0)

    def test_deserialize_null_token(self):
        """Test deserialization of null token (cached miss)."""
        result = ApiTokenCache._deserialize_token("null")
        assert result is None

    def test_deserialize_invalid_json(self):
        """Test deserialization with invalid JSON."""
        # Corrupt payloads degrade to a cache miss rather than raising.
        result = ApiTokenCache._deserialize_token("invalid-json{")
        assert result is None

    @patch("services.api_token_service.redis_client")
    def test_get_cache_hit(self, mock_redis):
        """Test cache hit scenario."""
        cached_data = json.dumps(
            {
                "id": "test-id",
                "app_id": "test-app",
                "tenant_id": "test-tenant",
                "type": "app",
                "token": "test-token",
                "last_used_at": "2026-02-03T10:00:00",
                "created_at": "2026-01-01T00:00:00",
            }
        ).encode("utf-8")
        mock_redis.get.return_value = cached_data
        result = ApiTokenCache.get("test-token", "app")
        assert result is not None
        assert isinstance(result, CachedApiToken)
        assert result.app_id == "test-app"
        mock_redis.get.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token")

    @patch("services.api_token_service.redis_client")
    def test_get_cache_miss(self, mock_redis):
        """Test cache miss scenario."""
        mock_redis.get.return_value = None
        result = ApiTokenCache.get("test-token", "app")
        assert result is None
        mock_redis.get.assert_called_once()

    @patch("services.api_token_service.redis_client")
    def test_set_valid_token(self, mock_redis):
        """Test setting a valid token in cache."""
        result = ApiTokenCache.set("test-token", "app", self.mock_token)
        assert result is True
        mock_redis.setex.assert_called_once()
        # setex positional args: (key, ttl, value)
        args = mock_redis.setex.call_args[0]
        assert args[0] == f"{CACHE_KEY_PREFIX}:app:test-token"
        assert args[1] == CACHE_TTL_SECONDS

    @patch("services.api_token_service.redis_client")
    def test_set_null_token(self, mock_redis):
        """Test setting a null token (cache penetration prevention)."""
        result = ApiTokenCache.set("invalid-token", "app", None)
        assert result is True
        mock_redis.setex.assert_called_once()
        args = mock_redis.setex.call_args[0]
        assert args[0] == f"{CACHE_KEY_PREFIX}:app:invalid-token"
        # Null entries get the shorter TTL and the b"null" sentinel payload.
        assert args[1] == CACHE_NULL_TTL_SECONDS
        assert args[2] == b"null"

    @patch("services.api_token_service.redis_client")
    def test_delete_with_scope(self, mock_redis):
        """Test deleting token cache with specific scope."""
        result = ApiTokenCache.delete("test-token", "app")
        assert result is True
        mock_redis.delete.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token")

    @patch("services.api_token_service.redis_client")
    def test_delete_without_scope(self, mock_redis):
        """Test deleting token cache without scope (delete all)."""
        # Mock scan_iter to return an iterator of keys
        mock_redis.scan_iter.return_value = iter(
            [
                b"api_token:app:test-token",
                b"api_token:dataset:test-token",
            ]
        )
        result = ApiTokenCache.delete("test-token", None)
        assert result is True
        # Verify scan_iter was called with the correct pattern
        mock_redis.scan_iter.assert_called_once()
        call_args = mock_redis.scan_iter.call_args
        assert call_args[1]["match"] == f"{CACHE_KEY_PREFIX}:*:test-token"
        # Verify delete was called with all matched keys
        mock_redis.delete.assert_called_once_with(
            b"api_token:app:test-token",
            b"api_token:dataset:test-token",
        )

    @patch("services.api_token_service.redis_client")
    def test_redis_fallback_on_exception(self, mock_redis):
        """Test Redis fallback when Redis is unavailable."""
        from redis import RedisError

        mock_redis.get.side_effect = RedisError("Connection failed")
        result = ApiTokenCache.get("test-token", "app")
        # Should return None (fallback) instead of raising exception
        assert result is None
class TestApiTokenCacheIntegration:
    """Integration test scenarios."""

    @patch("services.api_token_service.redis_client")
    def test_full_cache_lifecycle(self, mock_redis):
        """Test complete cache lifecycle: set -> get -> delete."""
        # Build a stand-in token carrying every attribute the cache serializes.
        token_stub = MagicMock()
        token_stub.id = "id-123"
        token_stub.app_id = "app-456"
        token_stub.tenant_id = "tenant-789"
        token_stub.type = "app"
        token_stub.token = "token-abc"
        token_stub.last_used_at = datetime(2026, 2, 3, 10, 0, 0)
        token_stub.created_at = datetime(2026, 1, 1, 0, 0, 0)

        # 1. Set token in cache
        ApiTokenCache.set("token-abc", "app", token_stub)
        assert mock_redis.setex.called

        # 2. Simulate cache hit by feeding the serialized form back through get()
        mock_redis.get.return_value = ApiTokenCache._serialize_token(token_stub)  # bytes from model_dump_json().encode()
        round_tripped = ApiTokenCache.get("token-abc", "app")
        assert round_tripped is not None
        assert isinstance(round_tripped, CachedApiToken)

        # 3. Delete from cache
        ApiTokenCache.delete("token-abc", "app")
        assert mock_redis.delete.called

    @patch("services.api_token_service.redis_client")
    def test_cache_penetration_prevention(self, mock_redis):
        """Test that non-existent tokens are cached as null."""
        # Set null token (cache miss): stored as the b"null" sentinel.
        ApiTokenCache.set("non-existent-token", "app", None)
        stored_args = mock_redis.setex.call_args[0]
        assert stored_args[2] == b"null"
        assert stored_args[1] == CACHE_NULL_TTL_SECONDS  # Shorter TTL for null values

View File

@@ -114,6 +114,21 @@ def mock_db_session():
session = MagicMock()
# Ensure tests can observe session.close() via context manager teardown
session.close = MagicMock()
session.commit = MagicMock()
# Mock session.begin() context manager to auto-commit on exit
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
def _begin_exit_side_effect(*args, **kwargs):
# session.begin().__exit__() should commit if no exception
if args[0] is None: # No exception
session.commit()
begin_cm.__exit__.side_effect = _begin_exit_side_effect
session.begin.return_value = begin_cm
# Mock create_session() context manager
cm = MagicMock()
cm.__enter__.return_value = session

View File

@@ -0,0 +1,128 @@
import type { SavedMessage } from '@/models/debug'
import type { SiteInfo } from '@/models/share'
import { fireEvent, render, screen } from '@testing-library/react'
import { AccessMode } from '@/models/access-control'
import HeaderSection from './header-section'
// Unit tests for the HeaderSection component of the text-generation share page.
// Covers: basic rendering, tab visibility rules, MenuDropdown wiring and the
// onTabChange callback.

// Mock menu-dropdown (sibling with external deps)
vi.mock('../menu-dropdown', () => ({
  default: ({ hideLogout, data }: { hideLogout: boolean, data: SiteInfo }) => (
    <div data-testid="menu-dropdown" data-hide-logout={String(hideLogout)}>{data.title}</div>
  ),
}))

// Minimal SiteInfo fixture; individual tests override fields via spread.
const baseSiteInfo: SiteInfo = {
  title: 'Test App',
  icon_type: 'emoji',
  icon: '🤖',
  icon_background: '#eee',
  icon_url: '',
  description: 'A description',
  default_language: 'en-US',
  prompt_public: false,
  copyright: '',
  privacy_policy: '',
  custom_disclaimer: '',
  show_workflow_steps: false,
  use_icon_as_answer_icon: false,
  chat_color_theme: '',
}

const defaultProps = {
  isPC: true,
  isInstalledApp: false,
  isWorkflow: false,
  siteInfo: baseSiteInfo,
  accessMode: AccessMode.PUBLIC,
  savedMessages: [] as SavedMessage[],
  currentTab: 'create',
  onTabChange: vi.fn(),
}

describe('HeaderSection', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  // Basic rendering
  describe('Rendering', () => {
    it('should render app title and description', () => {
      render(<HeaderSection {...defaultProps} />)
      // MenuDropdown mock also renders title, so use getAllByText
      expect(screen.getAllByText('Test App').length).toBeGreaterThanOrEqual(1)
      expect(screen.getByText('A description')).toBeInTheDocument()
    })
    it('should not render description when empty', () => {
      render(<HeaderSection {...defaultProps} siteInfo={{ ...baseSiteInfo, description: '' }} />)
      expect(screen.queryByText('A description')).not.toBeInTheDocument()
    })
  })

  // Tab rendering
  describe('Tabs', () => {
    it('should render create and batch tabs', () => {
      render(<HeaderSection {...defaultProps} />)
      expect(screen.getByText(/share\.generation\.tabs\.create/)).toBeInTheDocument()
      expect(screen.getByText(/share\.generation\.tabs\.batch/)).toBeInTheDocument()
    })
    it('should render saved tab when not workflow', () => {
      render(<HeaderSection {...defaultProps} isWorkflow={false} />)
      expect(screen.getByText(/share\.generation\.tabs\.saved/)).toBeInTheDocument()
    })
    it('should hide saved tab when isWorkflow is true', () => {
      render(<HeaderSection {...defaultProps} isWorkflow />)
      expect(screen.queryByText(/share\.generation\.tabs\.saved/)).not.toBeInTheDocument()
    })
    it('should show badge count for saved messages', () => {
      const messages: SavedMessage[] = [
        { id: '1', answer: 'a' } as SavedMessage,
        { id: '2', answer: 'b' } as SavedMessage,
      ]
      render(<HeaderSection {...defaultProps} savedMessages={messages} />)
      expect(screen.getByText('2')).toBeInTheDocument()
    })
  })

  // Menu dropdown: hideLogout should be true for installed apps or public access
  describe('MenuDropdown', () => {
    it('should pass hideLogout=true when accessMode is PUBLIC', () => {
      render(<HeaderSection {...defaultProps} accessMode={AccessMode.PUBLIC} />)
      expect(screen.getByTestId('menu-dropdown')).toHaveAttribute('data-hide-logout', 'true')
    })
    it('should pass hideLogout=true when isInstalledApp', () => {
      render(<HeaderSection {...defaultProps} isInstalledApp={true} accessMode={AccessMode.SPECIFIC_GROUPS_MEMBERS} />)
      expect(screen.getByTestId('menu-dropdown')).toHaveAttribute('data-hide-logout', 'true')
    })
    it('should pass hideLogout=false when not installed and accessMode is not PUBLIC', () => {
      render(<HeaderSection {...defaultProps} isInstalledApp={false} accessMode={AccessMode.SPECIFIC_GROUPS_MEMBERS} />)
      expect(screen.getByTestId('menu-dropdown')).toHaveAttribute('data-hide-logout', 'false')
    })
  })

  // Tab change callback
  describe('Interaction', () => {
    it('should call onTabChange when a tab is clicked', () => {
      const onTabChange = vi.fn()
      render(<HeaderSection {...defaultProps} onTabChange={onTabChange} />)
      fireEvent.click(screen.getByText(/share\.generation\.tabs\.batch/))
      expect(onTabChange).toHaveBeenCalledWith('batch')
    })
  })
})

View File

@@ -0,0 +1,78 @@
import type { FC } from 'react'
import type { SavedMessage } from '@/models/debug'
import type { SiteInfo } from '@/models/share'
import { RiBookmark3Line } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import AppIcon from '@/app/components/base/app-icon'
import Badge from '@/app/components/base/badge'
import TabHeader from '@/app/components/base/tab-header'
import { appDefaultIconBackground } from '@/config'
import { AccessMode } from '@/models/access-control'
import { cn } from '@/utils/classnames'
import MenuDropdown from '../menu-dropdown'
type HeaderSectionProps = {
  isPC: boolean
  isInstalledApp: boolean
  isWorkflow: boolean
  siteInfo: SiteInfo
  accessMode: AccessMode
  savedMessages: SavedMessage[]
  currentTab: string
  onTabChange: (tab: string) => void
}

// Header of the text-generation share page: app icon + title, options menu,
// optional description, and the create/batch(/saved) tab switcher.
const HeaderSection: FC<HeaderSectionProps> = ({
  isPC,
  isInstalledApp,
  isWorkflow,
  siteInfo,
  accessMode,
  savedMessages,
  currentTab,
  onTabChange,
}) => {
  const { t } = useTranslation()

  // Logout is hidden for installed apps and for publicly accessible apps.
  const shouldHideLogout = isInstalledApp || accessMode === AccessMode.PUBLIC
  const savedCount = savedMessages.length

  // The "saved" tab only exists for non-workflow apps; it sits on the right
  // and shows a badge with the number of saved messages (when non-zero).
  const savedTab = isWorkflow
    ? []
    : [{
        id: 'saved',
        name: t('generation.tabs.saved', { ns: 'share' }),
        isRight: true,
        icon: <RiBookmark3Line className="h-4 w-4" />,
        extra: savedCount > 0
          ? <Badge className="ml-1">{savedCount}</Badge>
          : null,
      }]
  const tabs = [
    { id: 'create', name: t('generation.tabs.create', { ns: 'share' }) },
    { id: 'batch', name: t('generation.tabs.batch', { ns: 'share' }) },
    ...savedTab,
  ]

  return (
    <div className={cn('shrink-0 space-y-4 border-b border-divider-subtle', isPC ? 'bg-components-panel-bg p-8 pb-0' : 'p-4 pb-0')}>
      <div className="flex items-center gap-3">
        <AppIcon
          size={isPC ? 'large' : 'small'}
          iconType={siteInfo.icon_type}
          icon={siteInfo.icon}
          background={siteInfo.icon_background || appDefaultIconBackground}
          imageUrl={siteInfo.icon_url}
        />
        <div className="system-md-semibold grow truncate text-text-secondary">{siteInfo.title}</div>
        <MenuDropdown hideLogout={shouldHideLogout} data={siteInfo} />
      </div>
      {siteInfo.description && (
        <div className="system-xs-regular text-text-tertiary">{siteInfo.description}</div>
      )}
      <TabHeader
        items={tabs}
        value={currentTab}
        onChange={onTabChange}
      />
    </div>
  )
}

View File

@@ -0,0 +1,96 @@
import { render, screen } from '@testing-library/react'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { defaultSystemFeatures } from '@/types/feature'
import PoweredBy from './powered-by'
// Unit tests for the PoweredBy footer: default Dify logo, branding/custom
// logo precedence, and hiding via `remove_webapp_brand`.

// Helper to override branding in system features while keeping other defaults
const setBranding = (branding: Partial<typeof defaultSystemFeatures.branding>) => {
  useGlobalPublicStore.setState({
    systemFeatures: {
      ...defaultSystemFeatures,
      branding: { ...defaultSystemFeatures.branding, ...branding },
    },
  })
}

describe('PoweredBy', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  // Renders default Dify logo
  describe('Default rendering', () => {
    it('should render powered-by text', () => {
      render(<PoweredBy isPC={true} resultExisted={false} customConfig={null} />)
      expect(screen.getByText(/share\.chat\.poweredBy/)).toBeInTheDocument()
    })
  })

  // Branding logo
  describe('Custom branding', () => {
    it('should render workspace logo when branding is enabled', () => {
      setBranding({ enabled: true, workspace_logo: 'https://example.com/logo.png' })
      render(<PoweredBy isPC={true} resultExisted={false} customConfig={null} />)
      const img = screen.getByAltText('logo')
      expect(img).toHaveAttribute('src', 'https://example.com/logo.png')
    })
    it('should render custom logo from customConfig', () => {
      render(
        <PoweredBy
          isPC={true}
          resultExisted={false}
          customConfig={{ replace_webapp_logo: 'https://custom.com/logo.png' }}
        />,
      )
      const img = screen.getByAltText('logo')
      expect(img).toHaveAttribute('src', 'https://custom.com/logo.png')
    })
    // Branding (system feature) wins over the per-app custom logo.
    it('should prefer branding logo over custom config logo', () => {
      setBranding({ enabled: true, workspace_logo: 'https://brand.com/logo.png' })
      render(
        <PoweredBy
          isPC={true}
          resultExisted={false}
          customConfig={{ replace_webapp_logo: 'https://custom.com/logo.png' }}
        />,
      )
      const img = screen.getByAltText('logo')
      expect(img).toHaveAttribute('src', 'https://brand.com/logo.png')
    })
  })

  // Hidden when remove_webapp_brand
  describe('Visibility', () => {
    it('should return null when remove_webapp_brand is truthy', () => {
      const { container } = render(
        <PoweredBy
          isPC={true}
          resultExisted={false}
          customConfig={{ remove_webapp_brand: true }}
        />,
      )
      expect(container.innerHTML).toBe('')
    })
    it('should render when remove_webapp_brand is falsy', () => {
      const { container } = render(
        <PoweredBy
          isPC={true}
          resultExisted={false}
          customConfig={{ remove_webapp_brand: false }}
        />,
      )
      expect(container.innerHTML).not.toBe('')
    })
  })
})

View File

@@ -0,0 +1,40 @@
import type { FC } from 'react'
import type { CustomConfigValueType } from '@/models/share'
import { useTranslation } from 'react-i18next'
import DifyLogo from '@/app/components/base/logo/dify-logo'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { cn } from '@/utils/classnames'
type PoweredByProps = {
  isPC: boolean
  resultExisted: boolean
  customConfig: Record<string, CustomConfigValueType> | null
}

// "Powered by" footer. Renders the workspace branding logo when branding is
// enabled, otherwise a custom webapp logo from config, otherwise the default
// Dify logo. Hidden entirely when `remove_webapp_brand` is set.
const PoweredBy: FC<PoweredByProps> = ({ isPC, resultExisted, customConfig }) => {
  const { t } = useTranslation()
  const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)

  if (customConfig?.remove_webapp_brand)
    return null

  // Precedence: workspace branding logo > custom webapp logo > Dify default.
  const { branding } = systemFeatures
  const replacementLogo = customConfig?.replace_webapp_logo
  let logoSrc: string | undefined
  if (branding.enabled && branding.workspace_logo)
    logoSrc = branding.workspace_logo
  else if (typeof replacementLogo === 'string' && replacementLogo)
    logoSrc = replacementLogo

  const wrapperClassName = cn(
    'flex shrink-0 items-center gap-1.5 bg-components-panel-bg py-3',
    isPC ? 'px-8' : 'px-4',
    !isPC && resultExisted && 'rounded-b-2xl border-b-[0.5px] border-divider-regular',
  )

  return (
    <div className={wrapperClassName}>
      <div className="system-2xs-medium-uppercase text-text-tertiary">{t('chat.poweredBy', { ns: 'share' })}</div>
      {logoSrc
        ? <img src={logoSrc} alt="logo" className="block h-5 w-auto" />
        : <DifyLogo size="small" />}
    </div>
  )
}

View File

@@ -0,0 +1,157 @@
import { fireEvent, render, screen } from '@testing-library/react'
import ResultPanel from './result-panel'
// Unit tests for ResultPanel: children rendering, batch-mode header with
// download button, pending-task loading state, and the failed-task retry bar.

// Mock ResDownload (sibling dep with CSV logic)
vi.mock('../run-batch/res-download', () => ({
  default: ({ values }: { values: Record<string, string>[] }) => (
    <button data-testid="res-download">
      Download (
      {values.length}
      )
    </button>
  ),
}))

const defaultProps = {
  isPC: true,
  isShowResultPanel: false,
  isCallBatchAPI: false,
  totalTasks: 0,
  successCount: 0,
  failedCount: 0,
  noPendingTask: true,
  exportRes: [] as Record<string, string>[],
  onRetryFailed: vi.fn(),
}

describe('ResultPanel', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  // Renders children
  describe('Rendering', () => {
    it('should render children content', () => {
      render(
        <ResultPanel {...defaultProps}>
          <div>Result content</div>
        </ResultPanel>,
      )
      expect(screen.getByText('Result content')).toBeInTheDocument()
    })
  })

  // Batch header
  describe('Batch mode header', () => {
    it('should show execution count when isCallBatchAPI is true', () => {
      render(
        <ResultPanel {...defaultProps} isCallBatchAPI={true} totalTasks={5}>
          <div />
        </ResultPanel>,
      )
      expect(screen.getByText(/share\.generation\.executions/)).toBeInTheDocument()
    })
    it('should not show execution header when not in batch mode', () => {
      render(
        <ResultPanel {...defaultProps} isCallBatchAPI={false}>
          <div />
        </ResultPanel>,
      )
      expect(screen.queryByText(/share\.generation\.executions/)).not.toBeInTheDocument()
    })
    it('should show download button when there are successful tasks', () => {
      render(
        <ResultPanel {...defaultProps} isCallBatchAPI={true} successCount={3} exportRes={[{ a: 'b' }]}>
          <div />
        </ResultPanel>,
      )
      expect(screen.getByTestId('res-download')).toBeInTheDocument()
    })
    it('should not show download button when no successful tasks', () => {
      render(
        <ResultPanel {...defaultProps} isCallBatchAPI={true} successCount={0}>
          <div />
        </ResultPanel>,
      )
      expect(screen.queryByTestId('res-download')).not.toBeInTheDocument()
    })
  })

  // Loading indicator for pending tasks; the loading wrapper is identified
  // by its `.mt-4` container class.
  describe('Pending tasks', () => {
    it('should show loading area when there are pending tasks', () => {
      const { container } = render(
        <ResultPanel {...defaultProps} noPendingTask={false}>
          <div />
        </ResultPanel>,
      )
      expect(container.querySelector('.mt-4')).toBeInTheDocument()
    })
    it('should not show loading when all tasks are done', () => {
      const { container } = render(
        <ResultPanel {...defaultProps} noPendingTask={true}>
          <div />
        </ResultPanel>,
      )
      expect(container.querySelector('.mt-4')).not.toBeInTheDocument()
    })
  })

  // Failed tasks retry bar
  describe('Failed tasks retry', () => {
    it('should show retry bar when batch has failed tasks', () => {
      render(
        <ResultPanel {...defaultProps} isCallBatchAPI={true} failedCount={2}>
          <div />
        </ResultPanel>,
      )
      expect(screen.getByText(/share\.generation\.batchFailed\.info/)).toBeInTheDocument()
      expect(screen.getByText(/share\.generation\.batchFailed\.retry/)).toBeInTheDocument()
    })
    it('should call onRetryFailed when retry is clicked', () => {
      const onRetry = vi.fn()
      render(
        <ResultPanel {...defaultProps} isCallBatchAPI={true} failedCount={1} onRetryFailed={onRetry}>
          <div />
        </ResultPanel>,
      )
      fireEvent.click(screen.getByText(/share\.generation\.batchFailed\.retry/))
      expect(onRetry).toHaveBeenCalledTimes(1)
    })
    it('should not show retry bar when no failed tasks', () => {
      render(
        <ResultPanel {...defaultProps} isCallBatchAPI={true} failedCount={0}>
          <div />
        </ResultPanel>,
      )
      expect(screen.queryByText(/share\.generation\.batchFailed\.retry/)).not.toBeInTheDocument()
    })
    it('should not show retry bar when not in batch mode even with failed count', () => {
      render(
        <ResultPanel {...defaultProps} isCallBatchAPI={false} failedCount={3}>
          <div />
        </ResultPanel>,
      )
      expect(screen.queryByText(/share\.generation\.batchFailed\.retry/)).not.toBeInTheDocument()
    })
  })
})

View File

@@ -0,0 +1,94 @@
import type { FC, ReactNode } from 'react'
import { RiErrorWarningFill } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import Loading from '@/app/components/base/loading'
import { cn } from '@/utils/classnames'
import ResDownload from '../run-batch/res-download'
type ResultPanelProps = {
  isPC: boolean
  isShowResultPanel: boolean
  isCallBatchAPI: boolean
  totalTasks: number
  successCount: number
  failedCount: number
  noPendingTask: boolean
  exportRes: Record<string, string>[]
  onRetryFailed: () => void
  children: ReactNode
}

// Right-hand result panel of the text-generation page. In batch mode it adds
// an executions header (with CSV download once something succeeded), a
// loading area while tasks are still pending, and a floating retry bar when
// some tasks failed.
const ResultPanel: FC<ResultPanelProps> = ({
  isPC,
  isShowResultPanel,
  isCallBatchAPI,
  totalTasks,
  successCount,
  failedCount,
  noPendingTask,
  exportRes,
  onRetryFailed,
  children,
}) => {
  const { t } = useTranslation()

  // Class-name argument order is kept stable (tailwind-merge resolves
  // conflicts by order).
  const rootClassName = cn(
    'relative flex h-full flex-col',
    !isPC && 'h-[calc(100vh_-_36px)] rounded-t-2xl shadow-lg backdrop-blur-sm',
    !isPC
      ? isShowResultPanel
        ? 'bg-background-default-burn'
        : 'border-t-[0.5px] border-divider-regular bg-components-panel-bg'
      : 'bg-chatbot-bg',
  )
  const headerClassName = cn(
    'flex shrink-0 items-center justify-between px-14 pb-2 pt-9',
    !isPC && 'px-4 pb-1 pt-3',
  )
  const bodyClassName = cn(
    'flex h-0 grow flex-col overflow-y-auto',
    isPC && 'px-14 py-8',
    isPC && isCallBatchAPI && 'pt-0',
    !isPC && 'p-0 pb-2',
  )

  const showRetryBar = isCallBatchAPI && failedCount > 0

  return (
    <div className={rootClassName}>
      {isCallBatchAPI && (
        <div className={headerClassName}>
          <div className="system-md-semibold-uppercase text-text-primary">
            {t('generation.executions', { ns: 'share', num: totalTasks })}
          </div>
          {successCount > 0 && (
            <ResDownload isMobile={!isPC} values={exportRes} />
          )}
        </div>
      )}
      <div className={bodyClassName}>
        {children}
        {!noPendingTask && (
          <div className="mt-4">
            <Loading type="area" />
          </div>
        )}
      </div>
      {showRetryBar && (
        <div className="absolute bottom-6 left-1/2 z-10 flex -translate-x-1/2 items-center gap-2 rounded-xl border border-components-panel-border bg-components-panel-bg-blur p-3 shadow-lg backdrop-blur-sm">
          <RiErrorWarningFill className="h-4 w-4 text-text-destructive" />
          <div className="system-sm-medium text-text-secondary">
            {t('generation.batchFailed.info', { ns: 'share', num: failedCount })}
          </div>
          <div className="h-3.5 w-px bg-divider-regular"></div>
          <div
            onClick={onRetryFailed}
            className="system-sm-semibold-uppercase cursor-pointer text-text-accent"
          >
            {t('generation.batchFailed.retry', { ns: 'share' })}
          </div>
        </div>
      )}
    </div>
  )
}

View File

@@ -0,0 +1,210 @@
import type { ChatConfig } from '@/app/components/base/chat/types'
import type { Locale } from '@/i18n-config/language'
import type { SiteInfo } from '@/models/share'
import { renderHook } from '@testing-library/react'
import { useWebAppStore } from '@/context/web-app-context'
import { PromptMode } from '@/models/debug'
import { useAppConfig } from './use-app-config'
// Unit tests for the useAppConfig hook: default (empty-store) state, config
// derivation from appInfo/appParams, the isReady flag, and the language-sync
// side effect.

// Mock changeLanguage side-effect
const mockChangeLanguage = vi.fn()
vi.mock('@/i18n-config/client', () => ({
  changeLanguage: (...args: unknown[]) => mockChangeLanguage(...args),
}))

const baseSiteInfo: SiteInfo = {
  title: 'My App',
  icon_type: 'emoji',
  icon: '🤖',
  icon_background: '#fff',
  icon_url: '',
  description: 'A test app',
  default_language: 'en-US' as Locale,
  prompt_public: false,
  copyright: '',
  privacy_policy: '',
  custom_disclaimer: '',
  show_workflow_steps: false,
  use_icon_as_answer_icon: false,
  chat_color_theme: '',
}

// Minimal ChatConfig-shaped app params used by all derivation tests.
const baseAppParams = {
  user_input_form: [
    { 'text-input': { label: 'Name', variable: 'name', required: true, default: '', max_length: 100, hide: false } },
  ],
  more_like_this: { enabled: true },
  text_to_speech: { enabled: false },
  file_upload: {
    allowed_file_upload_methods: ['local_file'],
    allowed_file_types: [],
    max_length: 10,
    number_limits: 3,
  },
  system_parameters: {
    audio_file_size_limit: 50,
    file_size_limit: 15,
    image_file_size_limit: 10,
    video_file_size_limit: 100,
    workflow_file_upload_limit: 10,
  },
  opening_statement: '',
  pre_prompt: '',
  prompt_type: PromptMode.simple,
  suggested_questions_after_answer: { enabled: false },
  speech_to_text: { enabled: false },
  retriever_resource: { enabled: false },
  sensitive_word_avoidance: { enabled: false },
  agent_mode: { enabled: false, tools: [] },
  dataset_configs: { datasets: { datasets: [] }, retrieval_model: 'single' },
} as unknown as ChatConfig

describe('useAppConfig', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  // Default state when store has no data
  describe('Default state', () => {
    it('should return not-ready state when store is empty', () => {
      useWebAppStore.setState({ appInfo: null, appParams: null })
      const { result } = renderHook(() => useAppConfig())
      expect(result.current.appId).toBe('')
      expect(result.current.siteInfo).toBeNull()
      expect(result.current.promptConfig).toBeNull()
      expect(result.current.isReady).toBe(false)
    })
  })

  // Deriving config from store data
  describe('Config derivation', () => {
    it('should derive appId and siteInfo from appInfo', () => {
      useWebAppStore.setState({
        appInfo: { app_id: 'app-123', site: baseSiteInfo, custom_config: null },
        appParams: baseAppParams,
      })
      const { result } = renderHook(() => useAppConfig())
      expect(result.current.appId).toBe('app-123')
      expect(result.current.siteInfo?.title).toBe('My App')
    })
    it('should derive promptConfig with prompt_variables from user_input_form', () => {
      useWebAppStore.setState({
        appInfo: { app_id: 'app-1', site: baseSiteInfo, custom_config: null },
        appParams: baseAppParams,
      })
      const { result } = renderHook(() => useAppConfig())
      expect(result.current.promptConfig).not.toBeNull()
      expect(result.current.promptConfig!.prompt_variables).toHaveLength(1)
      expect(result.current.promptConfig!.prompt_variables[0].key).toBe('name')
    })
    it('should derive moreLikeThisConfig and textToSpeechConfig from appParams', () => {
      useWebAppStore.setState({
        appInfo: { app_id: 'app-1', site: baseSiteInfo, custom_config: null },
        appParams: baseAppParams,
      })
      const { result } = renderHook(() => useAppConfig())
      expect(result.current.moreLikeThisConfig).toEqual({ enabled: true })
      expect(result.current.textToSpeechConfig).toEqual({ enabled: false })
    })
    it('should derive visionConfig from file_upload and system_parameters', () => {
      useWebAppStore.setState({
        appInfo: { app_id: 'app-1', site: baseSiteInfo, custom_config: null },
        appParams: baseAppParams,
      })
      const { result } = renderHook(() => useAppConfig())
      expect(result.current.visionConfig.transfer_methods).toEqual(['local_file'])
      expect(result.current.visionConfig.image_file_size_limit).toBe(10)
    })
    it('should return default visionConfig when appParams is null', () => {
      useWebAppStore.setState({
        appInfo: { app_id: 'app-1', site: baseSiteInfo, custom_config: null },
        appParams: null,
      })
      const { result } = renderHook(() => useAppConfig())
      expect(result.current.visionConfig.enabled).toBe(false)
      expect(result.current.visionConfig.number_limits).toBe(2)
    })
    it('should return customConfig from appInfo', () => {
      useWebAppStore.setState({
        appInfo: { app_id: 'app-1', site: baseSiteInfo, custom_config: { remove_webapp_brand: true } },
        appParams: baseAppParams,
      })
      const { result } = renderHook(() => useAppConfig())
      expect(result.current.customConfig).toEqual({ remove_webapp_brand: true })
    })
  })

  // Readiness condition: appId + siteInfo + promptConfig must all be present
  describe('isReady', () => {
    it('should be true when appId, siteInfo and promptConfig are all present', () => {
      useWebAppStore.setState({
        appInfo: { app_id: 'app-1', site: baseSiteInfo, custom_config: null },
        appParams: baseAppParams,
      })
      const { result } = renderHook(() => useAppConfig())
      expect(result.current.isReady).toBe(true)
    })
    it('should be false when appParams is missing (no promptConfig)', () => {
      useWebAppStore.setState({
        appInfo: { app_id: 'app-1', site: baseSiteInfo, custom_config: null },
        appParams: null,
      })
      const { result } = renderHook(() => useAppConfig())
      expect(result.current.isReady).toBe(false)
    })
    it('should be false when appInfo is missing (no appId or siteInfo)', () => {
      useWebAppStore.setState({ appInfo: null, appParams: baseAppParams })
      const { result } = renderHook(() => useAppConfig())
      expect(result.current.isReady).toBe(false)
    })
  })

  // Language sync side-effect
  describe('Language sync', () => {
    it('should call changeLanguage when siteInfo has default_language', () => {
      useWebAppStore.setState({
        appInfo: { app_id: 'app-1', site: baseSiteInfo, custom_config: null },
        appParams: baseAppParams,
      })
      renderHook(() => useAppConfig())
      expect(mockChangeLanguage).toHaveBeenCalledWith('en-US')
    })
    it('should not call changeLanguage when siteInfo is null', () => {
      useWebAppStore.setState({ appInfo: null, appParams: null })
      renderHook(() => useAppConfig())
      expect(mockChangeLanguage).not.toHaveBeenCalled()
    })
  })
})

View File

@@ -0,0 +1,84 @@
import type { AccessMode } from '@/models/access-control'
import type {
MoreLikeThisConfig,
PromptConfig,
TextToSpeechConfig,
} from '@/models/debug'
import type { CustomConfigValueType, SiteInfo } from '@/models/share'
import type { VisionSettings } from '@/types/app'
import { useEffect, useMemo } from 'react'
import { useWebAppStore } from '@/context/web-app-context'
import { changeLanguage } from '@/i18n-config/client'
import { Resolution, TransferMethod } from '@/types/app'
import { userInputsFormToPromptVariables } from '@/utils/model-config'
// Fallback vision settings returned before `appParams` has loaded.
const DEFAULT_VISION_CONFIG: VisionSettings = {
  enabled: false,
  number_limits: 2,
  detail: Resolution.low,
  transfer_methods: [TransferMethod.local_file],
}

// Aggregated, render-ready app configuration derived from the web-app store.
export type AppConfig = {
  appId: string
  siteInfo: SiteInfo | null
  customConfig: Record<string, CustomConfigValueType> | null
  promptConfig: PromptConfig | null
  moreLikeThisConfig: MoreLikeThisConfig | null
  textToSpeechConfig: TextToSpeechConfig | null
  visionConfig: VisionSettings
  accessMode: AccessMode
  isReady: boolean
}

/**
 * Derives the text-generation app configuration (site info, prompt
 * variables, feature flags, vision/file-upload settings) from the web-app
 * store, and keeps the UI language in sync with the app's default language.
 *
 * `isReady` is true only once appId, siteInfo and promptConfig are all
 * available, i.e. both appInfo and appParams have been loaded.
 */
export function useAppConfig(): AppConfig {
  const appData = useWebAppStore(s => s.appInfo)
  const appParams = useWebAppStore(s => s.appParams)
  const accessMode = useWebAppStore(s => s.webAppAccessMode)
  const appId = appData?.app_id ?? ''
  const siteInfo = (appData?.site as SiteInfo) ?? null
  const customConfig = appData?.custom_config ?? null
  // Prompt variables are derived from the app's user input form; the
  // template itself is not needed on the share page, hence ''.
  const promptConfig = useMemo<PromptConfig | null>(() => {
    if (!appParams)
      return null
    const prompt_variables = userInputsFormToPromptVariables(appParams.user_input_form)
    return { prompt_template: '', prompt_variables } as PromptConfig
  }, [appParams])
  const moreLikeThisConfig: MoreLikeThisConfig | null = appParams?.more_like_this ?? null
  const textToSpeechConfig: TextToSpeechConfig | null = appParams?.text_to_speech ?? null
  const visionConfig = useMemo<VisionSettings>(() => {
    if (!appParams)
      return DEFAULT_VISION_CONFIG
    const { file_upload, system_parameters } = appParams
    return {
      ...file_upload,
      // NOTE(review): `allowed_upload_methods` looks like a legacy field name
      // kept as a fallback for older API responses — confirm before removing.
      transfer_methods: file_upload?.allowed_file_upload_methods || file_upload?.allowed_upload_methods || [],
      image_file_size_limit: system_parameters?.image_file_size_limit,
      fileUploadConfig: system_parameters,
    } as unknown as VisionSettings
  }, [appParams])
  // Sync language when site info changes
  useEffect(() => {
    if (siteInfo?.default_language)
      changeLanguage(siteInfo.default_language)
  }, [siteInfo?.default_language])
  const isReady = !!(appId && siteInfo && promptConfig)
  return {
    appId,
    siteInfo,
    customConfig,
    promptConfig,
    moreLikeThisConfig,
    textToSpeechConfig,
    visionConfig,
    accessMode,
    isReady,
  }
}

View File

@@ -0,0 +1,299 @@
import type { PromptConfig } from '@/models/debug'
import { act, renderHook } from '@testing-library/react'
import { TaskStatus } from '../types'
import { useBatchTasks } from './use-batch-tasks'
// Unit tests for the useBatchTasks hook: CSV validation, concurrency-limited
// task scheduling, completion handling, derived lists, state reset, export
// formatting and retry control.

vi.mock('@/app/components/base/toast', () => ({
  default: { notify: vi.fn() },
}))

// Two prompt variables: required "Name" (max 100) and optional "Age" (max 10).
const createPromptConfig = (overrides?: Partial<PromptConfig>): PromptConfig => ({
  prompt_template: '',
  prompt_variables: [
    { key: 'name', name: 'Name', type: 'string', required: true, max_length: 100 },
    { key: 'age', name: 'Age', type: 'string', required: false, max_length: 10 },
  ] as PromptConfig['prompt_variables'],
  ...overrides,
})

// Build a valid CSV data matrix: [header, ...rows]
const buildCsvData = (rows: string[][]): string[][] => [
  ['Name', 'Age'],
  ...rows,
]

describe('useBatchTasks', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  // Initial state
  describe('Initial state', () => {
    it('should start with empty task list and batch mode off', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      expect(result.current.isCallBatchAPI).toBe(false)
      expect(result.current.allTaskList).toEqual([])
      expect(result.current.noPendingTask).toBe(true)
      expect(result.current.allTasksRun).toBe(true)
    })
  })

  // Batch validation via startBatchRun
  describe('startBatchRun validation', () => {
    it('should reject empty data', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      let ok = false
      act(() => {
        ok = result.current.startBatchRun([])
      })
      expect(ok).toBe(false)
      expect(result.current.isCallBatchAPI).toBe(false)
    })
    it('should reject data with mismatched header', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      const data = [['Wrong', 'Header'], ['a', 'b']]
      let ok = false
      act(() => {
        ok = result.current.startBatchRun(data)
      })
      expect(ok).toBe(false)
    })
    it('should reject data with no payload rows (header only)', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      const data = [['Name', 'Age']]
      let ok = false
      act(() => {
        ok = result.current.startBatchRun(data)
      })
      expect(ok).toBe(false)
    })
    it('should reject when required field is empty', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      const data = buildCsvData([['', '25']])
      let ok = false
      act(() => {
        ok = result.current.startBatchRun(data)
      })
      expect(ok).toBe(false)
    })
    it('should reject when required field exceeds max_length', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      const longName = 'a'.repeat(101)
      const data = buildCsvData([[longName, '25']])
      let ok = false
      act(() => {
        ok = result.current.startBatchRun(data)
      })
      expect(ok).toBe(false)
    })
  })

  // Successful batch run
  describe('startBatchRun success', () => {
    it('should create tasks and enable batch mode', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      const data = buildCsvData([['Alice', '30'], ['Bob', '25']])
      let ok = false
      act(() => {
        ok = result.current.startBatchRun(data)
      })
      expect(ok).toBe(true)
      expect(result.current.isCallBatchAPI).toBe(true)
      expect(result.current.allTaskList).toHaveLength(2)
      expect(result.current.allTaskList[0].params.inputs.name).toBe('Alice')
      expect(result.current.allTaskList[1].params.inputs.name).toBe('Bob')
    })
    it('should set first tasks to running status (limited by BATCH_CONCURRENCY)', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      const data = buildCsvData([['Alice', '30'], ['Bob', '25']])
      act(() => {
        result.current.startBatchRun(data)
      })
      // Both should be running since 2 < BATCH_CONCURRENCY (5)
      expect(result.current.allTaskList[0].status).toBe(TaskStatus.running)
      expect(result.current.allTaskList[1].status).toBe(TaskStatus.running)
    })
    it('should set excess tasks to pending when exceeding BATCH_CONCURRENCY', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      // Create 7 tasks (BATCH_CONCURRENCY=5, so 2 should be pending)
      const rows = Array.from({ length: 7 }, (_, i) => [`User${i}`, `${20 + i}`])
      const data = buildCsvData(rows)
      act(() => {
        result.current.startBatchRun(data)
      })
      const running = result.current.allTaskList.filter(t => t.status === TaskStatus.running)
      const pending = result.current.allTaskList.filter(t => t.status === TaskStatus.pending)
      expect(running).toHaveLength(5)
      expect(pending).toHaveLength(2)
    })
  })

  // Task completion handling (task ids are 1-based)
  describe('handleCompleted', () => {
    it('should mark task as completed on success', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      act(() => {
        result.current.startBatchRun(buildCsvData([['Alice', '30']]))
      })
      act(() => {
        result.current.handleCompleted('result text', 1, true)
      })
      expect(result.current.allTaskList[0].status).toBe(TaskStatus.completed)
    })
    it('should mark task as failed on failure', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      act(() => {
        result.current.startBatchRun(buildCsvData([['Alice', '30']]))
      })
      act(() => {
        result.current.handleCompleted('', 1, false)
      })
      expect(result.current.allTaskList[0].status).toBe(TaskStatus.failed)
      expect(result.current.allFailedTaskList).toHaveLength(1)
    })
    it('should promote pending tasks to running when group completes', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      // 7 tasks: first 5 running, last 2 pending
      const rows = Array.from({ length: 7 }, (_, i) => [`User${i}`, `${20 + i}`])
      act(() => {
        result.current.startBatchRun(buildCsvData(rows))
      })
      // Complete all 5 running tasks
      for (let i = 1; i <= 5; i++) {
        act(() => {
          result.current.handleCompleted(`res${i}`, i, true)
        })
      }
      // Tasks 6 and 7 should now be running
      expect(result.current.allTaskList[5].status).toBe(TaskStatus.running)
      expect(result.current.allTaskList[6].status).toBe(TaskStatus.running)
    })
  })

  // Derived task lists
  describe('Derived lists', () => {
    it('should compute showTaskList excluding pending tasks', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      const rows = Array.from({ length: 7 }, (_, i) => [`User${i}`, `${i}`])
      act(() => {
        result.current.startBatchRun(buildCsvData(rows))
      })
      expect(result.current.showTaskList).toHaveLength(5) // 5 running
      expect(result.current.noPendingTask).toBe(false)
    })
    it('should compute allTasksRun when all tasks completed or failed', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      act(() => {
        result.current.startBatchRun(buildCsvData([['Alice', '30'], ['Bob', '25']]))
      })
      expect(result.current.allTasksRun).toBe(false)
      act(() => {
        result.current.handleCompleted('res1', 1, true)
      })
      act(() => {
        result.current.handleCompleted('', 2, false)
      })
      expect(result.current.allTasksRun).toBe(true)
      expect(result.current.allSuccessTaskList).toHaveLength(1)
      expect(result.current.allFailedTaskList).toHaveLength(1)
    })
  })

  // Clear state
  describe('clearBatchState', () => {
    it('should reset batch mode and task list', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      act(() => {
        result.current.startBatchRun(buildCsvData([['Alice', '30']]))
      })
      expect(result.current.isCallBatchAPI).toBe(true)
      act(() => {
        result.current.clearBatchState()
      })
      expect(result.current.isCallBatchAPI).toBe(false)
      expect(result.current.allTaskList).toEqual([])
    })
  })

  // Export results
  describe('exportRes', () => {
    it('should format export data with variable names as keys', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      act(() => {
        result.current.startBatchRun(buildCsvData([['Alice', '30']]))
      })
      act(() => {
        result.current.handleCompleted('Generated text', 1, true)
      })
      const exported = result.current.exportRes
      expect(exported).toHaveLength(1)
      expect(exported[0].Name).toBe('Alice')
      expect(exported[0].Age).toBe('30')
    })
    it('should use empty string for missing optional inputs', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      act(() => {
        result.current.startBatchRun(buildCsvData([['Alice', '']]))
      })
      act(() => {
        result.current.handleCompleted('res', 1, true)
      })
      expect(result.current.exportRes[0].Age).toBe('')
    })
  })

  // Retry failed tasks
  describe('handleRetryAllFailedTask', () => {
    it('should update controlRetry timestamp', () => {
      const { result } = renderHook(() => useBatchTasks(createPromptConfig()))
      const before = result.current.controlRetry
      act(() => {
        result.current.handleRetryAllFailedTask()
      })
      expect(result.current.controlRetry).toBeGreaterThan(before)
    })
  })
})

View File

@@ -0,0 +1,219 @@
import type { TFunction } from 'i18next'
import type { Task } from '../types'
import type { PromptConfig } from '@/models/debug'
import { useCallback, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Toast from '@/app/components/base/toast'
import { BATCH_CONCURRENCY } from '@/config'
import { TaskStatus } from '../types'
/**
 * Validates parsed CSV batch data against the app's prompt variables.
 *
 * Checks, in order:
 * 1. the data is non-empty,
 * 2. the header row matches the prompt variable names (same order),
 * 3. at least one data row exists,
 * 4. empty rows appear only as a trailing run (none in the middle),
 * 5. each cell respects its variable's `max_length` and `required` rules.
 *
 * @param data - Raw CSV rows; `data[0]` is the header row.
 * @param promptVariables - Variable schema the columns must follow.
 * @param t - i18n translate function used to build localized error messages.
 * @returns A localized error message, or `null` when the data is valid.
 */
function validateBatchData(
  data: string[][],
  promptVariables: PromptConfig['prompt_variables'],
  t: TFunction,
): string | null {
  if (!data?.length)
    return t('generation.errorMsg.empty', { ns: 'share' })
  // Validate header matches prompt variables
  const header = data[0]
  if (promptVariables.some((v, i) => v.name !== header[i]))
    return t('generation.errorMsg.fileStructNotMatch', { ns: 'share' })
  const rows = data.slice(1)
  if (!rows.length)
    return t('generation.errorMsg.atLeastOne', { ns: 'share' })
  // Detect non-consecutive empty lines (empty rows in the middle of data)
  const emptyIndexes = rows
    .map((row, i) => row.every(c => c === '') ? i : -1)
    .filter(i => i >= 0)
  if (emptyIndexes.length > 0) {
    let prev = emptyIndexes[0] - 1
    for (const idx of emptyIndexes) {
      if (prev + 1 !== idx)
        // +2 converts the 0-based data-row index to a 1-based CSV row number
        // that also accounts for the header row.
        return t('generation.errorMsg.emptyLine', { ns: 'share', rowIndex: prev + 2 })
      prev = idx
    }
  }
  // Remove trailing empty rows and re-check
  const nonEmptyRows = rows.filter(row => !row.every(c => c === ''))
  if (!nonEmptyRows.length)
    return t('generation.errorMsg.atLeastOne', { ns: 'share' })
  // Validate individual row values
  for (let r = 0; r < nonEmptyRows.length; r++) {
    const row = nonEmptyRows[r]
    for (let v = 0; v < promptVariables.length; v++) {
      const varItem = promptVariables[v]
      // A row may have fewer cells than there are variables; treat a missing
      // cell as an empty string instead of crashing on `undefined.length`.
      const cell = row[v] ?? ''
      if (varItem.type === 'string' && varItem.max_length && cell.length > varItem.max_length) {
        return t('generation.errorMsg.moreThanMaxLengthLine', {
          ns: 'share',
          rowIndex: r + 2,
          varName: varItem.name,
          maxLength: varItem.max_length,
        })
      }
      if (varItem.required && cell.trim() === '') {
        return t('generation.errorMsg.invalidLine', {
          ns: 'share',
          rowIndex: r + 2,
          varName: varItem.name,
        })
      }
    }
  }
  return null
}
/**
 * Manages the batch-run lifecycle for text generation: validating CSV input,
 * creating one task per data row, running tasks in fixed-size concurrent
 * groups of BATCH_CONCURRENCY, collecting per-task results, and deriving
 * export-ready rows for CSV download.
 *
 * @param promptConfig - Prompt variable schema; null while the app config loads.
 */
export function useBatchTasks(promptConfig: PromptConfig | null) {
  const { t } = useTranslation()
  // Whether a batch run is active (as opposed to single-run mode).
  const [isCallBatchAPI, setIsCallBatchAPI] = useState(false)
  // Timestamp signal consumed by result components to re-trigger failed tasks.
  const [controlRetry, setControlRetry] = useState(0)
  // Task list with ref for accessing latest value in async callbacks
  const [allTaskList, doSetAllTaskList] = useState<Task[]>([])
  const allTaskListRef = useRef<Task[]>([])
  // Keeps state and ref in sync so handleCompleted always sees fresh data.
  const setAllTaskList = useCallback((tasks: Task[]) => {
    doSetAllTaskList(tasks)
    allTaskListRef.current = tasks
  }, [])
  // Batch completion results stored in ref (no re-render needed on each update)
  const batchCompletionResRef = useRef<Record<string, string>>({})
  // Done-count recorded when the current concurrency group was scheduled;
  // guards against scheduling the same next group more than once.
  const currGroupNumRef = useRef(0)
  // Derived task lists
  const pendingTaskList = allTaskList.filter(task => task.status === TaskStatus.pending)
  const noPendingTask = pendingTaskList.length === 0
  const showTaskList = allTaskList.filter(task => task.status !== TaskStatus.pending)
  const allSuccessTaskList = allTaskList.filter(task => task.status === TaskStatus.completed)
  const allFailedTaskList = allTaskList.filter(task => task.status === TaskStatus.failed)
  // NOTE: both flags are vacuously true for an empty list (no batch started).
  const allTasksFinished = allTaskList.every(task => task.status === TaskStatus.completed)
  const allTasksRun = allTaskList.every(task =>
    [TaskStatus.completed, TaskStatus.failed].includes(task.status),
  )
  // Export-ready results for CSV download: one row per task, keyed by the
  // prompt variable display names plus a localized "completion result" column.
  const exportRes = allTaskList.map((task) => {
    const completionRes = batchCompletionResRef.current
    const res: Record<string, string> = {}
    const { inputs } = task.params
    promptConfig?.prompt_variables.forEach((v) => {
      res[v.name] = inputs[v.key] ?? ''
    })
    let result = completionRes[task.id]
    // A task may return multiple fields; marshal objects to a JSON string.
    if (typeof result === 'object')
      result = JSON.stringify(result)
    res[t('generation.completionResult', { ns: 'share' })] = result
    return res
  })
  // Clear batch state (used when switching to single-run mode)
  const clearBatchState = useCallback(() => {
    setIsCallBatchAPI(false)
    setAllTaskList([])
  }, [setAllTaskList])
  // Attempt to start a batch run. Returns true on success, false on validation failure.
  const startBatchRun = useCallback((data: string[][]): boolean => {
    const error = validateBatchData(data, promptConfig?.prompt_variables ?? [], t)
    if (error) {
      Toast.notify({ type: 'error', message: error })
      return false
    }
    // Refuse to restart while a previous batch still has unfinished tasks.
    if (!allTasksFinished) {
      Toast.notify({ type: 'info', message: t('errorMessage.waitForBatchResponse', { ns: 'appDebug' }) })
      return false
    }
    // Drop fully-empty rows and the header, then build one task per row.
    const payloadData = data.filter(row => !row.every(c => c === '')).slice(1)
    const varLen = promptConfig?.prompt_variables.length ?? 0
    const tasks: Task[] = payloadData.map((item, i) => {
      const inputs: Record<string, string | undefined> = {}
      if (varLen > 0) {
        item.slice(0, varLen).forEach((input, index) => {
          const varSchema = promptConfig?.prompt_variables[index]
          const key = varSchema?.key as string
          // Blank cells: text-like variables get '', others are left undefined.
          if (!input)
            inputs[key] = (varSchema?.type === 'string' || varSchema?.type === 'paragraph') ? '' : undefined
          else
            inputs[key] = input
        })
      }
      return {
        id: i + 1,
        // Only the first concurrency group starts immediately; the rest wait.
        status: i < BATCH_CONCURRENCY ? TaskStatus.running : TaskStatus.pending,
        params: { inputs },
      }
    })
    setAllTaskList(tasks)
    currGroupNumRef.current = 0
    batchCompletionResRef.current = {}
    setIsCallBatchAPI(true)
    return true
  }, [allTasksFinished, promptConfig?.prompt_variables, setAllTaskList, t])
  // Callback invoked when a single task completes; manages group concurrency.
  const handleCompleted = useCallback((completionRes: string, taskId?: number, isSuccess?: boolean) => {
    const latestTasks = allTaskListRef.current
    const latestCompletionRes = batchCompletionResRef.current
    const pending = latestTasks.filter(task => task.status === TaskStatus.pending)
    // +1 counts the task that just finished (its status is only updated below).
    const doneCount = 1 + latestTasks.filter(task =>
      [TaskStatus.completed, TaskStatus.failed].includes(task.status),
    ).length
    // Schedule the next group when the current group has fully finished
    // (doneCount is a multiple of the group size) or only a partial final
    // group remains; the currGroupNumRef comparison prevents launching the
    // same group twice when completions arrive close together.
    const shouldAddNextGroup
      = currGroupNumRef.current !== doneCount
      && pending.length > 0
      && (doneCount % BATCH_CONCURRENCY === 0 || latestTasks.length - doneCount < BATCH_CONCURRENCY)
    if (shouldAddNextGroup)
      currGroupNumRef.current = doneCount
    const nextPendingIds = shouldAddNextGroup
      ? pending.slice(0, BATCH_CONCURRENCY).map(t => t.id)
      : []
    // Single pass: mark the finished task and promote the next group to running.
    const updatedTasks = latestTasks.map((item) => {
      if (item.id === taskId)
        return { ...item, status: isSuccess ? TaskStatus.completed : TaskStatus.failed }
      if (shouldAddNextGroup && nextPendingIds.includes(item.id))
        return { ...item, status: TaskStatus.running }
      return item
    })
    setAllTaskList(updatedTasks)
    // Record the result keyed by task id (ids start at 1, so truthy check is safe).
    if (taskId) {
      batchCompletionResRef.current = {
        ...latestCompletionRes,
        [`${taskId}`]: completionRes,
      }
    }
  }, [setAllTaskList])
  // Signal result components to retry every failed task.
  const handleRetryAllFailedTask = useCallback(() => {
    setControlRetry(Date.now())
  }, [])
  return {
    isCallBatchAPI,
    controlRetry,
    allTaskList,
    showTaskList,
    noPendingTask,
    allSuccessTaskList,
    allFailedTaskList,
    allTasksRun,
    exportRes,
    clearBatchState,
    startBatchRun,
    handleCompleted,
    handleRetryAllFailedTask,
  }
}

View File

@@ -0,0 +1,141 @@
import type { ReactNode } from 'react'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { renderHook, waitFor } from '@testing-library/react'
import * as React from 'react'
import { AppSourceType } from '@/service/share'
import {
useInvalidateSavedMessages,
useRemoveMessageMutation,
useSavedMessages,
useSaveMessageMutation,
} from './use-saved-messages'
// Mock service layer (preserve enum exports)
vi.mock('@/service/share', async (importOriginal) => {
const actual = await importOriginal<Record<string, unknown>>()
return {
...actual,
fetchSavedMessage: vi.fn(),
saveMessage: vi.fn(),
removeMessage: vi.fn(),
}
})
vi.mock('@/app/components/base/toast', () => ({
default: { notify: vi.fn() },
}))
// Get mocked functions for assertion
const shareModule = await import('@/service/share')
const mockFetchSavedMessage = shareModule.fetchSavedMessage as ReturnType<typeof vi.fn>
const mockSaveMessage = shareModule.saveMessage as ReturnType<typeof vi.fn>
const mockRemoveMessage = shareModule.removeMessage as ReturnType<typeof vi.fn>
const createWrapper = () => {
const queryClient = new QueryClient({
defaultOptions: { queries: { retry: false }, mutations: { retry: false } },
})
return ({ children }: { children: ReactNode }) =>
React.createElement(QueryClientProvider, { client: queryClient }, children)
}
const APP_SOURCE = AppSourceType.webApp
const APP_ID = 'test-app-id'
describe('useSavedMessages', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })
  // Fetching saved messages
  describe('Query behavior', () => {
    it('should fetch saved messages when enabled and appId present', async () => {
      const saved = [{ id: 'm1', answer: 'Hello' }]
      mockFetchSavedMessage.mockResolvedValue({ data: saved })
      const { result } = renderHook(
        () => useSavedMessages(APP_SOURCE, APP_ID, true),
        { wrapper: createWrapper() },
      )
      await waitFor(() => expect(result.current.isSuccess).toBe(true))
      expect(result.current.data).toEqual(saved)
      expect(mockFetchSavedMessage).toHaveBeenCalledWith(APP_SOURCE, APP_ID)
    })
    it('should not fetch when disabled', () => {
      const { result } = renderHook(
        () => useSavedMessages(APP_SOURCE, APP_ID, false),
        { wrapper: createWrapper() },
      )
      // Query stays idle and the service layer is never touched.
      expect(result.current.fetchStatus).toBe('idle')
      expect(mockFetchSavedMessage).not.toHaveBeenCalled()
    })
    it('should not fetch when appId is empty', () => {
      const { result } = renderHook(
        () => useSavedMessages(APP_SOURCE, '', true),
        { wrapper: createWrapper() },
      )
      // Query stays idle and the service layer is never touched.
      expect(result.current.fetchStatus).toBe('idle')
      expect(mockFetchSavedMessage).not.toHaveBeenCalled()
    })
  })
})
describe('useSaveMessageMutation', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })
  it('should call saveMessage service on mutate', async () => {
    mockSaveMessage.mockResolvedValue({})
    const { result: mutation } = renderHook(
      () => useSaveMessageMutation(APP_SOURCE, APP_ID),
      { wrapper: createWrapper() },
    )
    mutation.current.mutate('msg-1')
    await waitFor(() => expect(mutation.current.isSuccess).toBe(true))
    // The service receives the message id plus the app identity.
    expect(mockSaveMessage).toHaveBeenCalledWith('msg-1', APP_SOURCE, APP_ID)
  })
})
describe('useRemoveMessageMutation', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })
  it('should call removeMessage service on mutate', async () => {
    mockRemoveMessage.mockResolvedValue({})
    const { result: mutation } = renderHook(
      () => useRemoveMessageMutation(APP_SOURCE, APP_ID),
      { wrapper: createWrapper() },
    )
    mutation.current.mutate('msg-2')
    await waitFor(() => expect(mutation.current.isSuccess).toBe(true))
    // The service receives the message id plus the app identity.
    expect(mockRemoveMessage).toHaveBeenCalledWith('msg-2', APP_SOURCE, APP_ID)
  })
})
describe('useInvalidateSavedMessages', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })
  it('should return a callable invalidation function', () => {
    const { result } = renderHook(
      () => useInvalidateSavedMessages(APP_SOURCE, APP_ID),
      { wrapper: createWrapper() },
    )
    const invalidate = result.current
    expect(typeof invalidate).toBe('function')
    expect(() => invalidate()).not.toThrow()
  })
})

View File

@@ -0,0 +1,79 @@
import type { SavedMessage } from '@/models/debug'
import type { AppSourceType } from '@/service/share'
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'
import { useTranslation } from 'react-i18next'
import Toast from '@/app/components/base/toast'
import {
fetchSavedMessage,
removeMessage,
saveMessage,
} from '@/service/share'
// Namespace prefix shared by this feature's query keys.
const NAME_SPACE = 'text-generation'
// Query-key factory: scopes the saved-messages cache per app source and app id.
export const savedMessagesQueryKeys = {
  all: (appSourceType: AppSourceType, appId: string) =>
    [NAME_SPACE, 'savedMessages', appSourceType, appId] as const,
}
/**
 * Query hook for the user's saved messages in the given app.
 * The query is disabled when `enabled` is false or `appId` is empty.
 */
export function useSavedMessages(
  appSourceType: AppSourceType,
  appId: string,
  enabled = true,
) {
  const queryKey = savedMessagesQueryKeys.all(appSourceType, appId)
  const queryFn = async () => {
    const response = await fetchSavedMessage(appSourceType, appId) as { data: SavedMessage[] }
    return response.data
  }
  return useQuery<SavedMessage[]>({
    queryKey,
    queryFn,
    enabled: enabled && !!appId,
  })
}
/**
 * Returns a function that invalidates the saved-messages cache for the app,
 * forcing the next `useSavedMessages` read to refetch.
 */
export function useInvalidateSavedMessages(
  appSourceType: AppSourceType,
  appId: string,
) {
  const queryClient = useQueryClient()
  const invalidate = () => {
    const queryKey = savedMessagesQueryKeys.all(appSourceType, appId)
    queryClient.invalidateQueries({ queryKey })
  }
  return invalidate
}
/**
 * Mutation hook that saves a message; on success it shows a toast and
 * invalidates the saved-messages cache so the list refreshes.
 */
export function useSaveMessageMutation(
  appSourceType: AppSourceType,
  appId: string,
) {
  const { t } = useTranslation()
  const invalidate = useInvalidateSavedMessages(appSourceType, appId)
  const onSuccess = () => {
    Toast.notify({ type: 'success', message: t('api.saved', { ns: 'common' }) })
    invalidate()
  }
  return useMutation({
    mutationFn: (messageId: string) => saveMessage(messageId, appSourceType, appId),
    onSuccess,
  })
}
/**
 * Mutation hook that removes a saved message; on success it shows a toast
 * and invalidates the saved-messages cache so the list refreshes.
 */
export function useRemoveMessageMutation(
  appSourceType: AppSourceType,
  appId: string,
) {
  const { t } = useTranslation()
  const invalidate = useInvalidateSavedMessages(appSourceType, appId)
  const onSuccess = () => {
    Toast.notify({ type: 'success', message: t('api.remove', { ns: 'common' }) })
    invalidate()
  }
  return useMutation({
    mutationFn: (messageId: string) => removeMessage(messageId, appSourceType, appId),
    onSuccess,
  })
}

View File

@@ -1,65 +1,29 @@
'use client'
import type { FC } from 'react'
import type {
MoreLikeThisConfig,
PromptConfig,
SavedMessage,
TextToSpeechConfig,
} from '@/models/debug'
import type { InputValueTypes, Task } from './types'
import type { InstalledApp } from '@/models/explore'
import type { SiteInfo } from '@/models/share'
import type { VisionFile, VisionSettings } from '@/types/app'
import {
RiBookmark3Line,
RiErrorWarningFill,
} from '@remixicon/react'
import type { VisionFile } from '@/types/app'
import { useBoolean } from 'ahooks'
import { useSearchParams } from 'next/navigation'
import * as React from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useCallback, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import SavedItems from '@/app/components/app/text-generate/saved-items'
import AppIcon from '@/app/components/base/app-icon'
import Badge from '@/app/components/base/badge'
import Loading from '@/app/components/base/loading'
import DifyLogo from '@/app/components/base/logo/dify-logo'
import Toast from '@/app/components/base/toast'
import Res from '@/app/components/share/text-generation/result'
import RunOnce from '@/app/components/share/text-generation/run-once'
import { appDefaultIconBackground, BATCH_CONCURRENCY } from '@/config'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useWebAppStore } from '@/context/web-app-context'
import { useAppFavicon } from '@/hooks/use-app-favicon'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import useDocumentTitle from '@/hooks/use-document-title'
import { changeLanguage } from '@/i18n-config/client'
import { AccessMode } from '@/models/access-control'
import { AppSourceType, fetchSavedMessage as doFetchSavedMessage, removeMessage, saveMessage } from '@/service/share'
import { Resolution, TransferMethod } from '@/types/app'
import { AppSourceType } from '@/service/share'
import { cn } from '@/utils/classnames'
import { userInputsFormToPromptVariables } from '@/utils/model-config'
import TabHeader from '../../base/tab-header'
import MenuDropdown from './menu-dropdown'
import HeaderSection from './components/header-section'
import PoweredBy from './components/powered-by'
import ResultPanel from './components/result-panel'
import { useAppConfig } from './hooks/use-app-config'
import { useBatchTasks } from './hooks/use-batch-tasks'
import { useRemoveMessageMutation, useSavedMessages, useSaveMessageMutation } from './hooks/use-saved-messages'
import RunBatch from './run-batch'
import ResDownload from './run-batch/res-download'
const GROUP_SIZE = BATCH_CONCURRENCY // to avoid RPM(Request per minute) limit. The group task finished then the next group.
enum TaskStatus {
pending = 'pending',
running = 'running',
completed = 'completed',
failed = 'failed',
}
type TaskParam = {
inputs: Record<string, any>
}
type Task = {
id: number
status: TaskStatus
params: TaskParam
}
import { TaskStatus } from './types'
export type IMainProps = {
isInstalledApp?: boolean
@@ -71,9 +35,6 @@ const TextGeneration: FC<IMainProps> = ({
isInstalledApp = false,
isWorkflow = false,
}) => {
const { notify } = Toast
const appSourceType = isInstalledApp ? AppSourceType.installedApp : AppSourceType.webApp
const { t } = useTranslation()
const media = useBreakpoints()
const isPC = media === MediaType.pc
@@ -82,325 +43,81 @@ const TextGeneration: FC<IMainProps> = ({
const mode = searchParams.get('mode') || 'create'
const [currentTab, setCurrentTab] = useState<string>(['create', 'batch'].includes(mode) ? mode : 'create')
// Notice this situation isCallBatchAPI but not in batch tab
const [isCallBatchAPI, setIsCallBatchAPI] = useState(false)
const isInBatchTab = currentTab === 'batch'
const [inputs, doSetInputs] = useState<Record<string, any>>({})
// App configuration derived from store
const {
appId,
siteInfo,
customConfig,
promptConfig,
moreLikeThisConfig,
textToSpeechConfig,
visionConfig,
accessMode,
isReady,
} = useAppConfig()
const appSourceType = isInstalledApp ? AppSourceType.installedApp : AppSourceType.webApp
// Saved messages (React Query)
const { data: savedMessages = [] } = useSavedMessages(appSourceType, appId, !isWorkflow)
const saveMutation = useSaveMessageMutation(appSourceType, appId)
const removeMutation = useRemoveMessageMutation(appSourceType, appId)
// Batch task management
const {
isCallBatchAPI,
controlRetry,
allTaskList,
showTaskList,
noPendingTask,
allSuccessTaskList,
allFailedTaskList,
allTasksRun,
exportRes,
clearBatchState,
startBatchRun,
handleCompleted,
handleRetryAllFailedTask,
} = useBatchTasks(promptConfig)
// Input state with ref for accessing latest value in async callbacks
const [inputs, doSetInputs] = useState<Record<string, InputValueTypes>>({})
const inputsRef = useRef(inputs)
const setInputs = useCallback((newInputs: Record<string, any>) => {
const setInputs = useCallback((newInputs: Record<string, InputValueTypes>) => {
doSetInputs(newInputs)
inputsRef.current = newInputs
}, [])
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
const [appId, setAppId] = useState<string>('')
const [siteInfo, setSiteInfo] = useState<SiteInfo | null>(null)
const [customConfig, setCustomConfig] = useState<Record<string, any> | null>(null)
const [promptConfig, setPromptConfig] = useState<PromptConfig | null>(null)
const [moreLikeThisConfig, setMoreLikeThisConfig] = useState<MoreLikeThisConfig | null>(null)
const [textToSpeechConfig, setTextToSpeechConfig] = useState<TextToSpeechConfig | null>(null)
// save message
const [savedMessages, setSavedMessages] = useState<SavedMessage[]>([])
const fetchSavedMessage = useCallback(async () => {
if (!appId)
return
const res: any = await doFetchSavedMessage(appSourceType, appId)
setSavedMessages(res.data)
}, [appSourceType, appId])
const handleSaveMessage = async (messageId: string) => {
await saveMessage(messageId, appSourceType, appId)
notify({ type: 'success', message: t('api.saved', { ns: 'common' }) })
fetchSavedMessage()
}
const handleRemoveSavedMessage = async (messageId: string) => {
await removeMessage(messageId, appSourceType, appId)
notify({ type: 'success', message: t('api.remove', { ns: 'common' }) })
fetchSavedMessage()
}
// send message task
// Send control signals
const [controlSend, setControlSend] = useState(0)
const [controlStopResponding, setControlStopResponding] = useState(0)
const [visionConfig, setVisionConfig] = useState<VisionSettings>({
enabled: false,
number_limits: 2,
detail: Resolution.low,
transfer_methods: [TransferMethod.local_file],
})
const [completionFiles, setCompletionFiles] = useState<VisionFile[]>([])
const [runControl, setRunControl] = useState<{ onStop: () => Promise<void> | void, isStopping: boolean } | null>(null)
useEffect(() => {
if (isCallBatchAPI)
setRunControl(null)
}, [isCallBatchAPI])
// Result panel visibility
const [isShowResultPanel, { setTrue: doShowResultPanel, setFalse: hideResultPanel }] = useBoolean(false)
const showResultPanel = useCallback(() => {
// Delay to avoid useClickAway closing the panel immediately
setTimeout(doShowResultPanel, 0)
}, [doShowResultPanel])
const [resultExisted, setResultExisted] = useState(false)
const handleSend = () => {
setIsCallBatchAPI(false)
const handleSend = useCallback(() => {
clearBatchState()
setControlSend(Date.now())
// eslint-disable-next-line ts/no-use-before-define
setAllTaskList([]) // clear batch task running status
// eslint-disable-next-line ts/no-use-before-define
showResultPanel()
}
}, [clearBatchState, showResultPanel])
const [controlRetry, setControlRetry] = useState(0)
const handleRetryAllFailedTask = () => {
setControlRetry(Date.now())
}
const [allTaskList, doSetAllTaskList] = useState<Task[]>([])
const allTaskListRef = useRef<Task[]>([])
const getLatestTaskList = () => allTaskListRef.current
const setAllTaskList = (taskList: Task[]) => {
doSetAllTaskList(taskList)
allTaskListRef.current = taskList
}
const pendingTaskList = allTaskList.filter(task => task.status === TaskStatus.pending)
const noPendingTask = pendingTaskList.length === 0
const showTaskList = allTaskList.filter(task => task.status !== TaskStatus.pending)
const currGroupNumRef = useRef(0)
const setCurrGroupNum = (num: number) => {
currGroupNumRef.current = num
}
const getCurrGroupNum = () => {
return currGroupNumRef.current
}
const allSuccessTaskList = allTaskList.filter(task => task.status === TaskStatus.completed)
const allFailedTaskList = allTaskList.filter(task => task.status === TaskStatus.failed)
const allTasksFinished = allTaskList.every(task => task.status === TaskStatus.completed)
const allTasksRun = allTaskList.every(task => [TaskStatus.completed, TaskStatus.failed].includes(task.status))
const batchCompletionResRef = useRef<Record<string, string>>({})
const setBatchCompletionRes = (res: Record<string, string>) => {
batchCompletionResRef.current = res
}
const getBatchCompletionRes = () => batchCompletionResRef.current
const exportRes = allTaskList.map((task) => {
const batchCompletionResLatest = getBatchCompletionRes()
const res: Record<string, string> = {}
const { inputs } = task.params
promptConfig?.prompt_variables.forEach((v) => {
res[v.name] = inputs[v.key]
})
let result = batchCompletionResLatest[task.id]
// task might return multiple fields, should marshal object to string
if (typeof batchCompletionResLatest[task.id] === 'object')
result = JSON.stringify(result)
res[t('generation.completionResult', { ns: 'share' })] = result
return res
})
const checkBatchInputs = (data: string[][]) => {
if (!data || data.length === 0) {
notify({ type: 'error', message: t('generation.errorMsg.empty', { ns: 'share' }) })
return false
}
const headerData = data[0]
let isMapVarName = true
promptConfig?.prompt_variables.forEach((item, index) => {
if (!isMapVarName)
return
if (item.name !== headerData[index])
isMapVarName = false
})
if (!isMapVarName) {
notify({ type: 'error', message: t('generation.errorMsg.fileStructNotMatch', { ns: 'share' }) })
return false
}
let payloadData = data.slice(1)
if (payloadData.length === 0) {
notify({ type: 'error', message: t('generation.errorMsg.atLeastOne', { ns: 'share' }) })
return false
}
// check middle empty line
const allEmptyLineIndexes = payloadData.filter(item => item.every(i => i === '')).map(item => payloadData.indexOf(item))
if (allEmptyLineIndexes.length > 0) {
let hasMiddleEmptyLine = false
let startIndex = allEmptyLineIndexes[0] - 1
allEmptyLineIndexes.forEach((index) => {
if (hasMiddleEmptyLine)
return
if (startIndex + 1 !== index) {
hasMiddleEmptyLine = true
return
}
startIndex++
})
if (hasMiddleEmptyLine) {
notify({ type: 'error', message: t('generation.errorMsg.emptyLine', { ns: 'share', rowIndex: startIndex + 2 }) })
return false
}
}
// check row format
payloadData = payloadData.filter(item => !item.every(i => i === ''))
// after remove empty rows in the end, checked again
if (payloadData.length === 0) {
notify({ type: 'error', message: t('generation.errorMsg.atLeastOne', { ns: 'share' }) })
return false
}
let errorRowIndex = 0
let requiredVarName = ''
let moreThanMaxLengthVarName = ''
let maxLength = 0
payloadData.forEach((item, index) => {
if (errorRowIndex !== 0)
return
promptConfig?.prompt_variables.forEach((varItem, varIndex) => {
if (errorRowIndex !== 0)
return
if (varItem.type === 'string' && varItem.max_length) {
if (item[varIndex].length > varItem.max_length) {
moreThanMaxLengthVarName = varItem.name
maxLength = varItem.max_length
errorRowIndex = index + 1
return
}
}
if (!varItem.required)
return
if (item[varIndex].trim() === '') {
requiredVarName = varItem.name
errorRowIndex = index + 1
}
})
})
if (errorRowIndex !== 0) {
if (requiredVarName)
notify({ type: 'error', message: t('generation.errorMsg.invalidLine', { ns: 'share', rowIndex: errorRowIndex + 1, varName: requiredVarName }) })
if (moreThanMaxLengthVarName)
notify({ type: 'error', message: t('generation.errorMsg.moreThanMaxLengthLine', { ns: 'share', rowIndex: errorRowIndex + 1, varName: moreThanMaxLengthVarName, maxLength }) })
return false
}
return true
}
const handleRunBatch = (data: string[][]) => {
if (!checkBatchInputs(data))
const handleRunBatch = useCallback((data: string[][]) => {
if (!startBatchRun(data))
return
if (!allTasksFinished) {
notify({ type: 'info', message: t('errorMessage.waitForBatchResponse', { ns: 'appDebug' }) })
return
}
const payloadData = data.filter(item => !item.every(i => i === '')).slice(1)
const varLen = promptConfig?.prompt_variables.length || 0
setIsCallBatchAPI(true)
const allTaskList: Task[] = payloadData.map((item, i) => {
const inputs: Record<string, any> = {}
if (varLen > 0) {
item.slice(0, varLen).forEach((input, index) => {
const varSchema = promptConfig?.prompt_variables[index]
inputs[varSchema?.key as string] = input
if (!input) {
if (varSchema?.type === 'string' || varSchema?.type === 'paragraph')
inputs[varSchema?.key as string] = ''
else
inputs[varSchema?.key as string] = undefined
}
})
}
return {
id: i + 1,
status: i < GROUP_SIZE ? TaskStatus.running : TaskStatus.pending,
params: {
inputs,
},
}
})
setAllTaskList(allTaskList)
setCurrGroupNum(0)
setRunControl(null)
setControlSend(Date.now())
// clear run once task status
setControlStopResponding(Date.now())
// eslint-disable-next-line ts/no-use-before-define
showResultPanel()
}
const handleCompleted = (completionRes: string, taskId?: number, isSuccess?: boolean) => {
const allTaskListLatest = getLatestTaskList()
const batchCompletionResLatest = getBatchCompletionRes()
const pendingTaskList = allTaskListLatest.filter(task => task.status === TaskStatus.pending)
const runTasksCount = 1 + allTaskListLatest.filter(task => [TaskStatus.completed, TaskStatus.failed].includes(task.status)).length
const needToAddNextGroupTask = (getCurrGroupNum() !== runTasksCount) && pendingTaskList.length > 0 && (runTasksCount % GROUP_SIZE === 0 || (allTaskListLatest.length - runTasksCount < GROUP_SIZE))
// avoid add many task at the same time
if (needToAddNextGroupTask)
setCurrGroupNum(runTasksCount)
}, [startBatchRun, showResultPanel])
const nextPendingTaskIds = needToAddNextGroupTask ? pendingTaskList.slice(0, GROUP_SIZE).map(item => item.id) : []
const newAllTaskList = allTaskListLatest.map((item) => {
if (item.id === taskId) {
return {
...item,
status: isSuccess ? TaskStatus.completed : TaskStatus.failed,
}
}
if (needToAddNextGroupTask && nextPendingTaskIds.includes(item.id)) {
return {
...item,
status: TaskStatus.running,
}
}
return item
})
setAllTaskList(newAllTaskList)
if (taskId) {
setBatchCompletionRes({
...batchCompletionResLatest,
[`${taskId}`]: completionRes,
})
}
}
const appData = useWebAppStore(s => s.appInfo)
const appParams = useWebAppStore(s => s.appParams)
const accessMode = useWebAppStore(s => s.webAppAccessMode)
useEffect(() => {
(async () => {
if (!appData || !appParams)
return
if (!isWorkflow)
fetchSavedMessage()
const { app_id: appId, site: siteInfo, custom_config } = appData
setAppId(appId)
setSiteInfo(siteInfo as SiteInfo)
setCustomConfig(custom_config)
await changeLanguage(siteInfo.default_language)
const { user_input_form, more_like_this, file_upload, text_to_speech }: any = appParams
setVisionConfig({
// legacy of image upload compatible
...file_upload,
transfer_methods: file_upload?.allowed_file_upload_methods || file_upload?.allowed_upload_methods,
// legacy of image upload compatible
image_file_size_limit: appParams?.system_parameters.image_file_size_limit,
fileUploadConfig: appParams?.system_parameters,
} as any)
const prompt_variables = userInputsFormToPromptVariables(user_input_form)
setPromptConfig({
prompt_template: '', // placeholder for future
prompt_variables,
} as PromptConfig)
setMoreLikeThisConfig(more_like_this)
setTextToSpeechConfig(text_to_speech)
})()
}, [appData, appParams, fetchSavedMessage, isWorkflow])
// Can Use metadata(https://beta.nextjs.org/docs/api-reference/metadata) to set title. But it only works in server side client.
useDocumentTitle(siteInfo?.title || t('generation.title', { ns: 'share' }))
useAppFavicon({
enable: !isInstalledApp,
icon_type: siteInfo?.icon_type,
@@ -409,15 +126,6 @@ const TextGeneration: FC<IMainProps> = ({
icon_url: siteInfo?.icon_url,
})
const [isShowResultPanel, { setTrue: doShowResultPanel, setFalse: hideResultPanel }] = useBoolean(false)
const showResultPanel = () => {
// fix: useClickAway hideResSidebar will close sidebar
setTimeout(() => {
doShowResultPanel()
}, 0)
}
const [resultExisted, setResultExisted] = useState(false)
const renderRes = (task?: Task) => (
<Res
key={task?.id}
@@ -425,7 +133,7 @@ const TextGeneration: FC<IMainProps> = ({
isCallBatchAPI={isCallBatchAPI}
isPC={isPC}
isMobile={!isPC}
appSourceType={isInstalledApp ? AppSourceType.installedApp : AppSourceType.webApp}
appSourceType={appSourceType}
appId={appId}
isError={task?.status === TaskStatus.failed}
promptConfig={promptConfig}
@@ -435,7 +143,7 @@ const TextGeneration: FC<IMainProps> = ({
controlRetry={task?.status === TaskStatus.failed ? controlRetry : 0}
controlStopResponding={controlStopResponding}
onShowRes={showResultPanel}
handleSaveMessage={handleSaveMessage}
handleSaveMessage={id => saveMutation.mutate(id)}
taskId={task?.id}
onCompleted={handleCompleted}
visionConfig={visionConfig}
@@ -448,69 +156,14 @@ const TextGeneration: FC<IMainProps> = ({
/>
)
const renderBatchRes = () => {
return (showTaskList.map(task => renderRes(task)))
}
const renderResWrap = (
<div
className={cn(
'relative flex h-full flex-col',
!isPC && 'h-[calc(100vh_-_36px)] rounded-t-2xl shadow-lg backdrop-blur-sm',
!isPC
? isShowResultPanel
? 'bg-background-default-burn'
: 'border-t-[0.5px] border-divider-regular bg-components-panel-bg'
: 'bg-chatbot-bg',
)}
>
{isCallBatchAPI && (
<div className={cn(
'flex shrink-0 items-center justify-between px-14 pb-2 pt-9',
!isPC && 'px-4 pb-1 pt-3',
)}
>
<div className="system-md-semibold-uppercase text-text-primary">{t('generation.executions', { ns: 'share', num: allTaskList.length })}</div>
{allSuccessTaskList.length > 0 && (
<ResDownload
isMobile={!isPC}
values={exportRes}
/>
)}
</div>
)}
<div className={cn(
'flex h-0 grow flex-col overflow-y-auto',
isPC && 'px-14 py-8',
isPC && isCallBatchAPI && 'pt-0',
!isPC && 'p-0 pb-2',
)}
>
{!isCallBatchAPI ? renderRes() : renderBatchRes()}
{!noPendingTask && (
<div className="mt-4">
<Loading type="area" />
</div>
)}
</div>
{isCallBatchAPI && allFailedTaskList.length > 0 && (
<div className="absolute bottom-6 left-1/2 z-10 flex -translate-x-1/2 items-center gap-2 rounded-xl border border-components-panel-border bg-components-panel-bg-blur p-3 shadow-lg backdrop-blur-sm">
<RiErrorWarningFill className="h-4 w-4 text-text-destructive" />
<div className="system-sm-medium text-text-secondary">{t('generation.batchFailed.info', { ns: 'share', num: allFailedTaskList.length })}</div>
<div className="h-3.5 w-px bg-divider-regular"></div>
<div onClick={handleRetryAllFailedTask} className="system-sm-semibold-uppercase cursor-pointer text-text-accent">{t('generation.batchFailed.retry', { ns: 'share' })}</div>
</div>
)}
</div>
)
if (!appId || !siteInfo || !promptConfig) {
if (!isReady) {
return (
<div className="flex h-screen items-center">
<Loading type="app" />
</div>
)
}
return (
<div className={cn(
'bg-background-default-burn',
@@ -519,54 +172,24 @@ const TextGeneration: FC<IMainProps> = ({
isInstalledApp ? 'h-full rounded-2xl shadow-md' : 'h-screen',
)}
>
{/* Left */}
{/* Left panel */}
<div className={cn(
'relative flex h-full shrink-0 flex-col',
isPC ? 'w-[600px] max-w-[50%]' : resultExisted ? 'h-[calc(100%_-_64px)]' : '',
isInstalledApp && 'rounded-l-2xl',
)}
>
{/* header */}
<div className={cn('shrink-0 space-y-4 border-b border-divider-subtle', isPC ? 'bg-components-panel-bg p-8 pb-0' : 'p-4 pb-0')}>
<div className="flex items-center gap-3">
<AppIcon
size={isPC ? 'large' : 'small'}
iconType={siteInfo.icon_type}
icon={siteInfo.icon}
background={siteInfo.icon_background || appDefaultIconBackground}
imageUrl={siteInfo.icon_url}
/>
<div className="system-md-semibold grow truncate text-text-secondary">{siteInfo.title}</div>
<MenuDropdown hideLogout={isInstalledApp || accessMode === AccessMode.PUBLIC} data={siteInfo} />
</div>
{siteInfo.description && (
<div className="system-xs-regular text-text-tertiary">{siteInfo.description}</div>
)}
<TabHeader
items={[
{ id: 'create', name: t('generation.tabs.create', { ns: 'share' }) },
{ id: 'batch', name: t('generation.tabs.batch', { ns: 'share' }) },
...(!isWorkflow
? [{
id: 'saved',
name: t('generation.tabs.saved', { ns: 'share' }),
isRight: true,
icon: <RiBookmark3Line className="h-4 w-4" />,
extra: savedMessages.length > 0
? (
<Badge className="ml-1">
{savedMessages.length}
</Badge>
)
: null,
}]
: []),
]}
value={currentTab}
onChange={setCurrentTab}
/>
</div>
{/* form */}
<HeaderSection
isPC={isPC}
isInstalledApp={isInstalledApp}
isWorkflow={isWorkflow}
siteInfo={siteInfo!}
accessMode={accessMode}
savedMessages={savedMessages}
currentTab={currentTab}
onTabChange={setCurrentTab}
/>
{/* Form content */}
<div className={cn(
'h-0 grow overflow-y-auto bg-components-panel-bg',
isPC ? 'px-8' : 'px-4',
@@ -575,20 +198,20 @@ const TextGeneration: FC<IMainProps> = ({
>
<div className={cn(currentTab === 'create' ? 'block' : 'hidden')}>
<RunOnce
siteInfo={siteInfo}
siteInfo={siteInfo!}
inputs={inputs}
inputsRef={inputsRef}
onInputsChange={setInputs}
promptConfig={promptConfig}
promptConfig={promptConfig!}
onSend={handleSend}
visionConfig={visionConfig}
onVisionFilesChange={setCompletionFiles}
runControl={runControl}
/>
</div>
<div className={cn(isInBatchTab ? 'block' : 'hidden')}>
<div className={cn(currentTab === 'batch' ? 'block' : 'hidden')}>
<RunBatch
vars={promptConfig.prompt_variables}
vars={promptConfig!.prompt_variables}
onSend={handleRunBatch}
isAllFinished={allTasksRun}
/>
@@ -598,31 +221,15 @@ const TextGeneration: FC<IMainProps> = ({
className={cn(isPC ? 'mt-6' : 'mt-4')}
isShowTextToSpeech={textToSpeechConfig?.enabled}
list={savedMessages}
onRemove={handleRemoveSavedMessage}
onRemove={id => removeMutation.mutate(id)}
onStartCreateContent={() => setCurrentTab('create')}
/>
)}
</div>
{/* powered by */}
{!customConfig?.remove_webapp_brand && (
<div className={cn(
'flex shrink-0 items-center gap-1.5 bg-components-panel-bg py-3',
isPC ? 'px-8' : 'px-4',
!isPC && resultExisted && 'rounded-b-2xl border-b-[0.5px] border-divider-regular',
)}
>
<div className="system-2xs-medium-uppercase text-text-tertiary">{t('chat.poweredBy', { ns: 'share' })}</div>
{
systemFeatures.branding.enabled && systemFeatures.branding.workspace_logo
? <img src={systemFeatures.branding.workspace_logo} alt="logo" className="block h-5 w-auto" />
: customConfig?.replace_webapp_logo
? <img src={`${customConfig?.replace_webapp_logo}`} alt="logo" className="block h-5 w-auto" />
: <DifyLogo size="small" />
}
</div>
)}
<PoweredBy isPC={isPC} resultExisted={resultExisted} customConfig={customConfig} />
</div>
{/* Result */}
{/* Right panel - Results */}
<div className={cn(
isPC
? 'h-full w-0 grow'
@@ -640,17 +247,26 @@ const TextGeneration: FC<IMainProps> = ({
? 'flex items-center justify-center p-2 pt-6'
: 'absolute left-0 top-0 z-10 flex w-full items-center justify-center px-2 pb-[57px] pt-[3px]',
)}
onClick={() => {
if (isShowResultPanel)
hideResultPanel()
else
showResultPanel()
}}
onClick={() => isShowResultPanel ? hideResultPanel() : showResultPanel()}
>
<div className="h-1 w-8 cursor-grab rounded bg-divider-solid" />
</div>
)}
{renderResWrap}
<ResultPanel
isPC={isPC}
isShowResultPanel={isShowResultPanel}
isCallBatchAPI={isCallBatchAPI}
totalTasks={allTaskList.length}
successCount={allSuccessTaskList.length}
failedCount={allFailedTaskList.length}
noPendingTask={noPendingTask}
exportRes={exportRes}
onRetryFailed={handleRetryAllFailedTask}
>
{!isCallBatchAPI
? renderRes()
: showTaskList.map(task => renderRes(task))}
</ResultPanel>
</div>
</div>
)

View File

@@ -24,7 +24,7 @@ const Header: FC<IResultHeaderProps> = ({
}) => {
const { t } = useTranslation()
return (
<div className="flex w-full items-center justify-between ">
<div className="flex w-full items-center justify-between">
<div className="text-2xl font-normal leading-4 text-gray-800">{t('generation.resultTitle', { ns: 'share' })}</div>
<div className="flex items-center space-x-2">
<Button
@@ -50,7 +50,7 @@ const Header: FC<IResultHeaderProps> = ({
rating: null,
})
}}
className="flex h-7 w-7 cursor-pointer items-center justify-center rounded-md border border-primary-200 bg-primary-100 !text-primary-600 hover:border-primary-300 hover:bg-primary-200"
className="flex h-7 w-7 cursor-pointer items-center justify-center rounded-md border border-primary-200 bg-primary-100 !text-primary-600 hover:border-primary-300 hover:bg-primary-200"
>
<HandThumbUpIcon width={16} height={16} />
</div>
@@ -67,7 +67,7 @@ const Header: FC<IResultHeaderProps> = ({
rating: null,
})
}}
className="flex h-7 w-7 cursor-pointer items-center justify-center rounded-md border border-red-200 bg-red-100 !text-red-600 hover:border-red-300 hover:bg-red-200"
className="flex h-7 w-7 cursor-pointer items-center justify-center rounded-md border border-red-200 bg-red-100 !text-red-600 hover:border-red-300 hover:bg-red-200"
>
<HandThumbDownIcon width={16} height={16} />
</div>

View File

@@ -0,0 +1,904 @@
import type { UseTextGenerationProps } from './use-text-generation'
import { act, renderHook } from '@testing-library/react'
import Toast from '@/app/components/base/toast'
import {
AppSourceType,
sendCompletionMessage,
sendWorkflowMessage,
stopChatMessageResponding,
stopWorkflowMessage,
} from '@/service/share'
import { TransferMethod } from '@/types/app'
import { sleep } from '@/utils'
import { useTextGeneration } from './use-text-generation'
// Mock external services. `importOriginal` keeps the module's real exports
// (e.g. the AppSourceType enum) while replacing the network-facing functions.
vi.mock('@/service/share', async (importOriginal) => {
  const actual = await importOriginal<Record<string, unknown>>()
  return {
    ...actual,
    sendCompletionMessage: vi.fn(),
    sendWorkflowMessage: vi.fn(() => Promise.resolve()),
    stopChatMessageResponding: vi.fn(() => Promise.resolve()),
    stopWorkflowMessage: vi.fn(() => Promise.resolve()),
    updateFeedback: vi.fn(() => Promise.resolve()),
  }
})
// Toast notifications are observed through `Toast.notify` in assertions below.
vi.mock('@/app/components/base/toast', () => ({
  default: { notify: vi.fn() },
}))
// Default `sleep` resolves immediately; individual tests override it with a
// never-resolving promise to keep the hook's timeout branch from firing.
vi.mock('@/utils', () => ({
  sleep: vi.fn(() => Promise.resolve()),
}))
// File helpers are pass-through so inputs can be asserted unchanged.
vi.mock('@/app/components/base/file-uploader/utils', () => ({
  getProcessedFiles: vi.fn((files: unknown[]) => files),
  getFilesInLogs: vi.fn(() => []),
}))
vi.mock('@/utils/model-config', () => ({
  formatBooleanInputs: vi.fn((_vars: unknown, inputs: unknown) => inputs),
}))
// Extracted parameter types for typed mock implementations
type CompletionBody = Parameters<typeof sendCompletionMessage>[0]
type CompletionCbs = Parameters<typeof sendCompletionMessage>[1]
type WorkflowBody = Parameters<typeof sendWorkflowMessage>[0]
type WorkflowCbs = Parameters<typeof sendWorkflowMessage>[1]
// Factory for default hook props; any field can be overridden per test.
function createProps(overrides: Partial<UseTextGenerationProps> = {}): UseTextGenerationProps {
  const defaults: UseTextGenerationProps = {
    isWorkflow: false,
    isCallBatchAPI: false,
    isPC: true,
    appSourceType: AppSourceType.webApp,
    appId: 'app-1',
    promptConfig: { prompt_template: '', prompt_variables: [] },
    inputs: {},
    onShowRes: vi.fn(),
    onCompleted: vi.fn(),
    visionConfig: { enabled: false } as UseTextGenerationProps['visionConfig'],
    completionFiles: [],
    onRunStart: vi.fn(),
  }
  return { ...defaults, ...overrides }
}
describe('useTextGeneration', () => {
beforeEach(() => {
  // Reset mock call history so each test's assertions are isolated.
  vi.clearAllMocks()
})
// Initial state
describe('initial state', () => {
  it('should return correct default values', () => {
    const { result } = renderHook(() => useTextGeneration(createProps()))
    expect(result.current.isResponding).toBe(false)
    expect(result.current.completionRes).toBe('')
    expect(result.current.workflowProcessData).toBeUndefined()
    expect(result.current.messageId).toBeNull()
    expect(result.current.feedback).toEqual({ rating: null })
    expect(result.current.isStopping).toBe(false)
    expect(result.current.currentTaskId).toBeNull()
    expect(result.current.controlClearMoreLikeThis).toBe(0)
  })
  it('should expose handler functions', () => {
    const { result } = renderHook(() => useTextGeneration(createProps()))
    expect(typeof result.current.handleSend).toBe('function')
    expect(typeof result.current.handleStop).toBe('function')
    expect(typeof result.current.handleFeedback).toBe('function')
  })
})
// Feedback
describe('handleFeedback', () => {
  it('should call updateFeedback API and update state', async () => {
    const { updateFeedback } = await import('@/service/share')
    const { result } = renderHook(() => useTextGeneration(createProps()))
    await act(async () => {
      await result.current.handleFeedback({ rating: 'like' })
    })
    // The hook forwards the rating (content undefined) scoped to the
    // current app source and id from createProps().
    expect(updateFeedback).toHaveBeenCalledWith(
      expect.objectContaining({ body: { rating: 'like', content: undefined } }),
      AppSourceType.webApp,
      'app-1',
    )
    expect(result.current.feedback).toEqual({ rating: 'like' })
  })
})
// Stop — guard behavior when no task has been started yet.
describe('handleStop', () => {
  it('should do nothing when no currentTaskId', async () => {
    const { stopChatMessageResponding } = await import('@/service/share')
    const { result } = renderHook(() => useTextGeneration(createProps()))
    await act(async () => {
      await result.current.handleStop()
    })
    // With no task id the hook must not hit the stop API.
    expect(stopChatMessageResponding).not.toHaveBeenCalled()
  })
  // Renamed: this test only exercises the empty-taskId guard in workflow
  // mode; the real "stop an active workflow" path is covered in
  // 'handleStop - with active task' below.
  it('should not call stopWorkflowMessage when no task has started (workflow mode)', async () => {
    const { stopWorkflowMessage } = await import('@/service/share')
    const props = createProps({ isWorkflow: true })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleStop()
    })
    // No task to stop, so the workflow stop API is never invoked.
    expect(stopWorkflowMessage).not.toHaveBeenCalled()
  })
})
// Send - validation
describe('handleSend - validation', () => {
  // Renamed: the original title claimed a "toast while responding" check
  // that was never asserted; this test verifies the happy path where no
  // required variables exist and the send proceeds.
  it('should send when there are no required variables to validate', async () => {
    const { sendCompletionMessage } = await import('@/service/share')
    const { result } = renderHook(() => useTextGeneration(createProps({ controlSend: 1 })))
    await act(async () => {
      await result.current.handleSend()
    })
    expect(sendCompletionMessage).toHaveBeenCalled()
  })
  it('should validate required prompt variables', async () => {
    const Toast = (await import('@/app/components/base/toast')).default
    const props = createProps({
      promptConfig: {
        prompt_template: '',
        prompt_variables: [
          { key: 'name', name: 'Name', type: 'string', required: true },
          // NonNullable<...> is equivalent to (and far more readable than)
          // the previous `extends infer T ? ...` conditional-type chain.
        ] as NonNullable<UseTextGenerationProps['promptConfig']>['prompt_variables'],
      },
      inputs: {}, // missing required 'name'
    })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleSend()
    })
    // Missing required input surfaces as an error toast, not a send.
    expect(Toast.notify).toHaveBeenCalledWith(
      expect.objectContaining({ type: 'error' }),
    )
  })
  it('should pass validation for batch API mode', async () => {
    const { sendCompletionMessage } = await import('@/service/share')
    const props = createProps({ isCallBatchAPI: true })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleSend()
    })
    // Batch mode skips validation - should call send
    expect(sendCompletionMessage).toHaveBeenCalled()
  })
})
// Send - API calls
describe('handleSend - API', () => {
  it('should call sendCompletionMessage for non-workflow mode', async () => {
    const { sendCompletionMessage } = await import('@/service/share')
    const props = createProps({ isWorkflow: false })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleSend()
    })
    // Completion mode wires onData/onCompleted/onError streaming callbacks.
    expect(sendCompletionMessage).toHaveBeenCalledWith(
      expect.objectContaining({ inputs: {} }),
      expect.objectContaining({
        onData: expect.any(Function),
        onCompleted: expect.any(Function),
        onError: expect.any(Function),
      }),
      AppSourceType.webApp,
      'app-1',
    )
  })
  it('should call sendWorkflowMessage for workflow mode', async () => {
    const { sendWorkflowMessage } = await import('@/service/share')
    const props = createProps({ isWorkflow: true })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleSend()
    })
    // Workflow mode wires the workflow lifecycle callbacks instead.
    expect(sendWorkflowMessage).toHaveBeenCalledWith(
      expect.objectContaining({ inputs: {} }),
      expect.objectContaining({
        onWorkflowStarted: expect.any(Function),
        onNodeStarted: expect.any(Function),
        onWorkflowFinished: expect.any(Function),
      }),
      AppSourceType.webApp,
      'app-1',
    )
  })
  it('should call onShowRes and onRunStart on mobile', async () => {
    const onShowRes = vi.fn()
    const onRunStart = vi.fn()
    const props = createProps({ isPC: false, onShowRes, onRunStart })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleSend()
    })
    expect(onShowRes).toHaveBeenCalled()
    expect(onRunStart).toHaveBeenCalled()
  })
})
// Effects
describe('effects', () => {
  it('should trigger send when controlSend changes', async () => {
    const { sendCompletionMessage } = await import('@/service/share')
    const { result, rerender } = renderHook(
      (props: UseTextGenerationProps) => useTextGeneration(props),
      { initialProps: createProps({ controlSend: 0 }) },
    )
    // Change controlSend to trigger the effect
    await act(async () => {
      rerender(createProps({ controlSend: Date.now() }))
    })
    expect(sendCompletionMessage).toHaveBeenCalled()
    // The same effect also bumps controlClearMoreLikeThis.
    expect(result.current.controlClearMoreLikeThis).toBeGreaterThan(0)
  })
  it('should trigger send when controlRetry changes', async () => {
    const { sendCompletionMessage } = await import('@/service/share')
    await act(async () => {
      renderHook(() => useTextGeneration(createProps({ controlRetry: Date.now() })))
    })
    expect(sendCompletionMessage).toHaveBeenCalled()
  })
  it('should sync run control with parent via onRunControlChange', () => {
    const onRunControlChange = vi.fn()
    renderHook(() => useTextGeneration(createProps({ onRunControlChange })))
    // Initially not responding, so should pass null
    expect(onRunControlChange).toHaveBeenCalledWith(null)
  })
})
// handleStop with active task. In these tests `sleep` is overridden with a
// never-resolving promise so the hook's timeout branch stays pending and the
// run remains "responding" while stop is exercised.
describe('handleStop - with active task', () => {
  it('should call stopWorkflowMessage for workflow', async () => {
    vi.mocked(sleep).mockReturnValueOnce(new Promise(() => {}))
    // Mock emits onWorkflowStarted so the hook records task-1 as active.
    vi.mocked(sendWorkflowMessage).mockImplementationOnce(
      (async (_data: WorkflowBody, callbacks: WorkflowCbs) => {
        callbacks.onWorkflowStarted({ workflow_run_id: 'run-1', task_id: 'task-1' } as never)
      }) as unknown as typeof sendWorkflowMessage,
    )
    const { result } = renderHook(() => useTextGeneration(createProps({ isWorkflow: true })))
    await act(async () => {
      await result.current.handleSend()
    })
    await act(async () => {
      await result.current.handleStop()
    })
    expect(stopWorkflowMessage).toHaveBeenCalledWith('app-1', 'task-1', AppSourceType.webApp, 'app-1')
  })
  it('should call stopChatMessageResponding for completion', async () => {
    vi.mocked(sleep).mockReturnValueOnce(new Promise(() => {}))
    // onData carries the taskId that stop must later reference.
    vi.mocked(sendCompletionMessage).mockImplementationOnce(
      (async (_data: CompletionBody, { onData, getAbortController }: CompletionCbs) => {
        getAbortController?.(new AbortController())
        onData('chunk', true, { messageId: 'msg-1', taskId: 'task-1' })
      }) as unknown as typeof sendCompletionMessage,
    )
    const { result } = renderHook(() => useTextGeneration(createProps()))
    await act(async () => {
      await result.current.handleSend()
    })
    await act(async () => {
      await result.current.handleStop()
    })
    expect(stopChatMessageResponding).toHaveBeenCalledWith('app-1', 'task-1', AppSourceType.webApp, 'app-1')
  })
  it('should handle stop API errors gracefully', async () => {
    vi.mocked(sleep).mockReturnValueOnce(new Promise(() => {}))
    vi.mocked(sendWorkflowMessage).mockImplementationOnce(
      (async (_data: WorkflowBody, callbacks: WorkflowCbs) => {
        callbacks.onWorkflowStarted({ workflow_run_id: 'run-1', task_id: 'task-1' } as never)
      }) as unknown as typeof sendWorkflowMessage,
    )
    vi.mocked(stopWorkflowMessage).mockRejectedValueOnce(new Error('Network error'))
    const { result } = renderHook(() => useTextGeneration(createProps({ isWorkflow: true })))
    await act(async () => {
      await result.current.handleSend()
    })
    await act(async () => {
      await result.current.handleStop()
    })
    // Failure is surfaced as an error toast and isStopping is reset.
    expect(Toast.notify).toHaveBeenCalledWith(
      expect.objectContaining({ type: 'error', message: 'Network error' }),
    )
    expect(result.current.isStopping).toBe(false)
  })
})
// File processing in handleSend
describe('handleSend - file processing', () => {
  it('should process file-type and file-list prompt variables', async () => {
    const fileValue = { name: 'doc.pdf', size: 100 }
    const fileListValue = [{ name: 'a.pdf' }, { name: 'b.pdf' }]
    const props = createProps({
      promptConfig: {
        prompt_template: '',
        prompt_variables: [
          { key: 'doc', name: 'Document', type: 'file', required: false },
          { key: 'docs', name: 'Documents', type: 'file-list', required: false },
        ] as UseTextGenerationProps['promptConfig'] extends infer T ? T extends { prompt_variables: infer V } ? V : never : never,
      },
      inputs: { doc: fileValue, docs: fileListValue } as unknown as Record<string, UseTextGenerationProps['inputs'][string]>,
    })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleSend()
    })
    // getProcessedFiles is mocked as pass-through, so the file inputs are
    // forwarded to the API under their original keys.
    expect(sendCompletionMessage).toHaveBeenCalledWith(
      expect.objectContaining({ inputs: expect.objectContaining({ doc: expect.anything(), docs: expect.anything() }) }),
      expect.anything(),
      expect.anything(),
      expect.anything(),
    )
  })
  it('should include vision files when vision is enabled', async () => {
    const props = createProps({
      visionConfig: { enabled: true, number_limits: 2, detail: 'low', transfer_methods: [] } as UseTextGenerationProps['visionConfig'],
      completionFiles: [
        { transfer_method: TransferMethod.local_file, url: 'http://local', upload_file_id: 'f1' },
        { transfer_method: TransferMethod.remote_url, url: 'http://remote' },
      ] as UseTextGenerationProps['completionFiles'],
    })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleSend()
    })
    // Local-file uploads have their url blanked (upload_file_id is used
    // instead); remote urls are forwarded as-is.
    expect(sendCompletionMessage).toHaveBeenCalledWith(
      expect.objectContaining({
        files: expect.arrayContaining([
          expect.objectContaining({ transfer_method: TransferMethod.local_file, url: '' }),
          expect.objectContaining({ transfer_method: TransferMethod.remote_url, url: 'http://remote' }),
        ]),
      }),
      expect.anything(),
      expect.anything(),
      expect.anything(),
    )
  })
})
// Validation edge cases
describe('handleSend - validation edge cases', () => {
  it('should block when files are uploading and no prompt variables', async () => {
    const props = createProps({
      // A local_file with an empty url and no upload_file_id counts as
      // "still uploading" and blocks the send with an info toast.
      promptConfig: { prompt_template: '', prompt_variables: [] },
      completionFiles: [
        { transfer_method: TransferMethod.local_file, url: '' },
      ] as UseTextGenerationProps['completionFiles'],
    })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleSend()
    })
    expect(Toast.notify).toHaveBeenCalledWith(
      expect.objectContaining({ type: 'info' }),
    )
    expect(sendCompletionMessage).not.toHaveBeenCalled()
  })
  it('should skip boolean/checkbox vars in required check', async () => {
    const props = createProps({
      promptConfig: {
        prompt_template: '',
        prompt_variables: [
          { key: 'flag', name: 'Flag', type: 'boolean', required: true },
          { key: 'check', name: 'Check', type: 'checkbox', required: true },
        ] as UseTextGenerationProps['promptConfig'] extends infer T ? T extends { prompt_variables: infer V } ? V : never : never,
      },
      inputs: {},
    })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleSend()
    })
    // Should pass validation - boolean/checkbox are skipped
    expect(sendCompletionMessage).toHaveBeenCalled()
  })
  it('should stop checking after first empty required var', async () => {
    const props = createProps({
      promptConfig: {
        prompt_template: '',
        prompt_variables: [
          { key: 'first', name: 'First', type: 'string', required: true },
          { key: 'second', name: 'Second', type: 'string', required: true },
        ] as UseTextGenerationProps['promptConfig'] extends infer T ? T extends { prompt_variables: infer V } ? V : never : never,
      },
      inputs: { second: 'value' },
    })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleSend()
    })
    // Error should mention 'First', not 'Second'
    // NOTE(review): only the toast type is asserted; the message content
    // ("First") is not actually checked here.
    expect(Toast.notify).toHaveBeenCalledWith(
      expect.objectContaining({ type: 'error' }),
    )
  })
  it('should block when files uploading after vars pass', async () => {
    const props = createProps({
      promptConfig: {
        prompt_template: '',
        prompt_variables: [
          { key: 'name', name: 'Name', type: 'string', required: true },
        ] as UseTextGenerationProps['promptConfig'] extends infer T ? T extends { prompt_variables: infer V } ? V : never : never,
      },
      inputs: { name: 'Alice' },
      completionFiles: [
        { transfer_method: TransferMethod.local_file, url: '' },
      ] as UseTextGenerationProps['completionFiles'],
    })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleSend()
    })
    expect(Toast.notify).toHaveBeenCalledWith(
      expect.objectContaining({ type: 'info' }),
    )
    expect(sendCompletionMessage).not.toHaveBeenCalled()
  })
})
// sendCompletionMessage callbacks. Tests that need the run to stay "live"
// override `sleep` with a never-resolving promise; tests of the timeout path
// rely on the default immediately-resolving `sleep` mock instead.
describe('sendCompletionMessage callbacks', () => {
  it('should accumulate text and track task/message via onData', async () => {
    vi.mocked(sleep).mockReturnValueOnce(new Promise(() => {}))
    vi.mocked(sendCompletionMessage).mockImplementationOnce(
      (async (_data: CompletionBody, { onData }: CompletionCbs) => {
        onData('Hello ', true, { messageId: 'msg-1', taskId: 'task-1' })
        onData('World', false, { messageId: 'msg-1', taskId: 'task-1' })
      }) as unknown as typeof sendCompletionMessage,
    )
    const { result } = renderHook(() => useTextGeneration(createProps()))
    await act(async () => {
      await result.current.handleSend()
    })
    // Chunks are concatenated; taskId captured from the first chunk.
    expect(result.current.completionRes).toBe('Hello World')
    expect(result.current.currentTaskId).toBe('task-1')
  })
  it('should finalize state via onCompleted', async () => {
    vi.mocked(sleep).mockReturnValueOnce(new Promise(() => {}))
    const onCompleted = vi.fn()
    vi.mocked(sendCompletionMessage).mockImplementationOnce(
      (async (_data: CompletionBody, callbacks: CompletionCbs) => {
        callbacks.onData('result', true, { messageId: 'msg-1', taskId: 'task-1' })
        callbacks.onCompleted()
      }) as unknown as typeof sendCompletionMessage,
    )
    const { result } = renderHook(() => useTextGeneration(createProps({ onCompleted })))
    await act(async () => {
      await result.current.handleSend()
    })
    expect(result.current.isResponding).toBe(false)
    expect(result.current.messageId).toBe('msg-1')
    expect(onCompleted).toHaveBeenCalledWith('result', undefined, true)
  })
  it('should replace text via onMessageReplace', async () => {
    vi.mocked(sleep).mockReturnValueOnce(new Promise(() => {}))
    vi.mocked(sendCompletionMessage).mockImplementationOnce(
      (async (_data: CompletionBody, { onData, onMessageReplace }: CompletionCbs) => {
        onData('old text', true, { messageId: 'msg-1', taskId: 'task-1' })
        onMessageReplace!({ answer: 'replaced text' } as never)
      }) as unknown as typeof sendCompletionMessage,
    )
    const { result } = renderHook(() => useTextGeneration(createProps()))
    await act(async () => {
      await result.current.handleSend()
    })
    // onMessageReplace overwrites (not appends to) the accumulated text.
    expect(result.current.completionRes).toBe('replaced text')
  })
  it('should handle error via onError', async () => {
    vi.mocked(sleep).mockReturnValueOnce(new Promise(() => {}))
    const onCompleted = vi.fn()
    vi.mocked(sendCompletionMessage).mockImplementationOnce(
      (async (_data: CompletionBody, { onError }: CompletionCbs) => {
        onError('test error')
      }) as unknown as typeof sendCompletionMessage,
    )
    const { result } = renderHook(() => useTextGeneration(createProps({ onCompleted })))
    await act(async () => {
      await result.current.handleSend()
    })
    // Errors finish the run with success=false.
    expect(result.current.isResponding).toBe(false)
    expect(onCompleted).toHaveBeenCalledWith('', undefined, false)
  })
  it('should store abort controller via getAbortController', async () => {
    vi.mocked(sleep).mockReturnValueOnce(new Promise(() => {}))
    const abortController = new AbortController()
    vi.mocked(sendCompletionMessage).mockImplementationOnce(
      (async (_data: CompletionBody, { getAbortController }: CompletionCbs) => {
        getAbortController?.(abortController)
      }) as unknown as typeof sendCompletionMessage,
    )
    const { result } = renderHook(() => useTextGeneration(createProps()))
    await act(async () => {
      await result.current.handleSend()
    })
    // Verify abort controller is stored by triggering stop
    expect(result.current.isResponding).toBe(true)
  })
  it('should show timeout warning when onCompleted fires after timeout', async () => {
    // Default sleep mock resolves immediately, so timeout fires
    let capturedCallbacks: CompletionCbs | null = null
    vi.mocked(sendCompletionMessage).mockImplementationOnce(
      (async (_data: CompletionBody, callbacks: CompletionCbs) => {
        capturedCallbacks = callbacks
      }) as unknown as typeof sendCompletionMessage,
    )
    const { result } = renderHook(() => useTextGeneration(createProps()))
    await act(async () => {
      await result.current.handleSend()
    })
    // Timeout has fired (sleep resolved immediately, isEndRef still false)
    await act(async () => {
      capturedCallbacks!.onCompleted()
    })
    expect(Toast.notify).toHaveBeenCalledWith(
      expect.objectContaining({ type: 'warning' }),
    )
  })
  it('should show timeout warning when onError fires after timeout', async () => {
    let capturedCallbacks: CompletionCbs | null = null
    vi.mocked(sendCompletionMessage).mockImplementationOnce(
      (async (_data: CompletionBody, callbacks: CompletionCbs) => {
        capturedCallbacks = callbacks
      }) as unknown as typeof sendCompletionMessage,
    )
    const { result } = renderHook(() => useTextGeneration(createProps()))
    await act(async () => {
      await result.current.handleSend()
    })
    await act(async () => {
      capturedCallbacks!.onError('test error')
    })
    expect(Toast.notify).toHaveBeenCalledWith(
      expect.objectContaining({ type: 'warning' }),
    )
  })
})
// sendWorkflowMessage error handling
describe('sendWorkflowMessage error', () => {
  it('should handle workflow API rejection', async () => {
    vi.mocked(sendWorkflowMessage).mockRejectedValueOnce(new Error('API error'))
    const { result } = renderHook(() => useTextGeneration(createProps({ isWorkflow: true })))
    await act(async () => {
      await result.current.handleSend()
      // Wait for the catch handler to process
      await new Promise(resolve => setTimeout(resolve, 0))
    })
    // Rejection resets the responding flag and surfaces an error toast.
    expect(result.current.isResponding).toBe(false)
    expect(Toast.notify).toHaveBeenCalledWith(
      expect.objectContaining({ type: 'error', message: 'API error' }),
    )
  })
})
// controlStopResponding effect
describe('effects - controlStopResponding', () => {
  it('should abort and reset state when controlStopResponding changes', async () => {
    vi.mocked(sleep).mockReturnValueOnce(new Promise(() => {}))
    vi.mocked(sendCompletionMessage).mockImplementationOnce(
      (async (_data: CompletionBody, { onData, getAbortController }: CompletionCbs) => {
        getAbortController?.(new AbortController())
        onData('chunk', true, { messageId: 'msg-1', taskId: 'task-1' })
      }) as unknown as typeof sendCompletionMessage,
    )
    const { result, rerender } = renderHook(
      (props: UseTextGenerationProps) => useTextGeneration(props),
      { initialProps: createProps({ controlStopResponding: 0 }) },
    )
    await act(async () => {
      await result.current.handleSend()
    })
    expect(result.current.isResponding).toBe(true)
    // Bumping controlStopResponding aborts the in-flight run.
    await act(async () => {
      rerender(createProps({ controlStopResponding: Date.now() }))
    })
    expect(result.current.isResponding).toBe(false)
  })
})
// onRunControlChange with active task
describe('effects - onRunControlChange with active task', () => {
  it('should provide control object when responding with active task', async () => {
    vi.mocked(sleep).mockReturnValueOnce(new Promise(() => {}))
    vi.mocked(sendWorkflowMessage).mockImplementationOnce(
      (async (_data: WorkflowBody, callbacks: WorkflowCbs) => {
        callbacks.onWorkflowStarted({ workflow_run_id: 'run-1', task_id: 'task-1' } as never)
      }) as unknown as typeof sendWorkflowMessage,
    )
    const onRunControlChange = vi.fn()
    const { result } = renderHook(() =>
      useTextGeneration(createProps({ isWorkflow: true, onRunControlChange })),
    )
    await act(async () => {
      await result.current.handleSend()
    })
    // Once a task is live, the parent receives a control object with a
    // stop handler instead of null.
    expect(onRunControlChange).toHaveBeenCalledWith(
      expect.objectContaining({ onStop: expect.any(Function), isStopping: false }),
    )
  })
})
// Branch coverage: handleStop when already stopping
describe('handleStop - branch coverage', () => {
  it('should do nothing when already stopping', async () => {
    vi.mocked(sleep).mockReturnValueOnce(new Promise(() => {}))
    vi.mocked(sendWorkflowMessage).mockImplementationOnce(
      (async (_data: WorkflowBody, callbacks: WorkflowCbs) => {
        callbacks.onWorkflowStarted({ workflow_run_id: 'run-1', task_id: 'task-1' } as never)
      }) as unknown as typeof sendWorkflowMessage,
    )
    // Make stopWorkflowMessage hang to keep isStopping=true
    vi.mocked(stopWorkflowMessage).mockReturnValueOnce(new Promise(() => {}))
    const { result } = renderHook(() => useTextGeneration(createProps({ isWorkflow: true })))
    await act(async () => {
      await result.current.handleSend()
    })
    // First stop sets isStopping=true
    // (deliberately not awaited - the hanging stop call would never settle)
    act(() => {
      result.current.handleStop()
    })
    expect(result.current.isStopping).toBe(true)
    // Second stop should be a no-op
    await act(async () => {
      await result.current.handleStop()
    })
    expect(stopWorkflowMessage).toHaveBeenCalledTimes(1)
  })
})
// Branch coverage: onData with falsy/empty taskId
describe('sendCompletionMessage callbacks - branch coverage', () => {
  it('should not set taskId when taskId is empty', async () => {
    vi.mocked(sleep).mockReturnValueOnce(new Promise(() => {}))
    vi.mocked(sendCompletionMessage).mockImplementationOnce(
      (async (_data: CompletionBody, { onData }: CompletionCbs) => {
        onData('chunk', true, { messageId: 'msg-1', taskId: '' })
      }) as unknown as typeof sendCompletionMessage,
    )
    const { result } = renderHook(() => useTextGeneration(createProps()))
    await act(async () => {
      await result.current.handleSend()
    })
    // Empty taskId is ignored rather than stored.
    expect(result.current.currentTaskId).toBeNull()
  })
  it('should not override taskId when already set', async () => {
    vi.mocked(sleep).mockReturnValueOnce(new Promise(() => {}))
    vi.mocked(sendCompletionMessage).mockImplementationOnce(
      (async (_data: CompletionBody, { onData }: CompletionCbs) => {
        onData('a', true, { messageId: 'msg-1', taskId: 'first-task' })
        onData('b', false, { messageId: 'msg-1', taskId: 'second-task' })
      }) as unknown as typeof sendCompletionMessage,
    )
    const { result } = renderHook(() => useTextGeneration(createProps()))
    await act(async () => {
      await result.current.handleSend()
    })
    // Should keep 'first-task', not override with 'second-task'
    expect(result.current.currentTaskId).toBe('first-task')
  })
})
// Branch coverage: promptConfig null
describe('handleSend - promptConfig null', () => {
  it('should handle null promptConfig gracefully', async () => {
    const props = createProps({ promptConfig: null })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleSend()
    })
    // No promptConfig means no variables to validate; send proceeds.
    expect(sendCompletionMessage).toHaveBeenCalled()
  })
})
// Branch coverage: onCompleted before timeout (isEndRef=true skips timeout)
describe('sendCompletionMessage - timeout skip branch', () => {
  it('should skip timeout when onCompleted fires before timeout resolves', async () => {
    // Use default sleep mock (resolves immediately) - NOT overriding to never-resolve
    const onCompleted = vi.fn()
    vi.mocked(sendCompletionMessage).mockImplementationOnce(
      (async (_data: CompletionBody, callbacks: CompletionCbs) => {
        callbacks.onData('res', true, { messageId: 'msg-1', taskId: 'task-1' })
        callbacks.onCompleted()
        // isEndRef.current = true now, so timeout IIFE will skip
      }) as unknown as typeof sendCompletionMessage,
    )
    const { result } = renderHook(() => useTextGeneration(createProps({ onCompleted })))
    await act(async () => {
      await result.current.handleSend()
    })
    expect(result.current.isResponding).toBe(false)
    // onCompleted should be called once (from callback), not twice (timeout skipped)
    expect(onCompleted).toHaveBeenCalledTimes(1)
    expect(onCompleted).toHaveBeenCalledWith('res', undefined, true)
  })
})
// Branch coverage: workflow error with non-Error object
describe('sendWorkflowMessage - non-Error rejection', () => {
  it('should handle non-Error rejection via String()', async () => {
    // Reject with a plain string so the catch path exercises String(error).
    vi.mocked(sendWorkflowMessage).mockRejectedValueOnce('string error')
    const hook = renderHook(() => useTextGeneration(createProps({ isWorkflow: true })))
    await act(async () => {
      await hook.result.current.handleSend()
      // Flush the task queue so the rejection handler has run.
      await new Promise(done => setTimeout(done, 0))
    })
    expect(hook.result.current.isResponding).toBe(false)
    expect(Toast.notify).toHaveBeenCalledWith(
      expect.objectContaining({ type: 'error', message: 'string error' }),
    )
  })
})
// Branch coverage: hasUploadingFiles false branch
describe('handleSend - file upload branch', () => {
  it('should proceed when files have upload_file_id (not uploading)', async () => {
    // Local file that already has its upload_file_id does not block sending.
    const props = createProps({
      completionFiles: [
        { transfer_method: TransferMethod.local_file, url: 'http://file', upload_file_id: 'f1' },
      ] as UseTextGenerationProps['completionFiles'],
    })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleSend()
    })
    expect(sendCompletionMessage).toHaveBeenCalled()
  })
  it('should proceed when files use remote_url transfer method', async () => {
    // Remote files never count as "uploading".
    const props = createProps({
      completionFiles: [
        { transfer_method: TransferMethod.remote_url, url: 'http://remote' },
      ] as UseTextGenerationProps['completionFiles'],
    })
    const { result } = renderHook(() => useTextGeneration(props))
    await act(async () => {
      await result.current.handleSend()
    })
    expect(sendCompletionMessage).toHaveBeenCalled()
  })
})
})

View File

@@ -0,0 +1,357 @@
import type { InputValueTypes } from '../../types'
import type { FeedbackType } from '@/app/components/base/chat/chat/type'
import type { WorkflowProcess } from '@/app/components/base/chat/types'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { PromptConfig } from '@/models/debug'
import type { AppSourceType } from '@/service/share'
import type { VisionFile, VisionSettings } from '@/types/app'
import { useBoolean } from 'ahooks'
import { t } from 'i18next'
import { useCallback, useEffect, useRef, useState } from 'react'
import {
getProcessedFiles,
} from '@/app/components/base/file-uploader/utils'
import Toast from '@/app/components/base/toast'
import { TEXT_GENERATION_TIMEOUT_MS } from '@/config'
import {
sendCompletionMessage,
sendWorkflowMessage,
stopChatMessageResponding,
stopWorkflowMessage,
updateFeedback,
} from '@/service/share'
import { TransferMethod } from '@/types/app'
import { sleep } from '@/utils'
import { formatBooleanInputs } from '@/utils/model-config'
import { createWorkflowCallbacks } from './workflow-callbacks'
// Props consumed by useTextGeneration.
export type UseTextGenerationProps = {
  // True for workflow apps: runs stream node tracing via sendWorkflowMessage.
  isWorkflow: boolean
  // True when the batch-run API drives sending; per-variable validation is skipped.
  isCallBatchAPI: boolean
  // Desktop layout flag; when false, sending also opens the result panel.
  isPC: boolean
  appSourceType: AppSourceType
  appId?: string
  // Prompt variable definitions; may be null before config is loaded.
  promptConfig: PromptConfig | null
  inputs: Record<string, InputValueTypes>
  // Incrementing counters used as external triggers (send / retry / stop).
  controlSend?: number
  controlRetry?: number
  controlStopResponding?: number
  // Opens the result panel (invoked on send when !isPC).
  onShowRes: () => void
  // Caller-supplied task id, echoed back through onCompleted.
  taskId?: number
  onCompleted: (completionRes: string, taskId?: number, success?: boolean) => void
  visionConfig: VisionSettings
  completionFiles: VisionFile[]
  onRunStart: () => void
  // Publishes a stop control to the parent while a run is in flight.
  onRunControlChange?: (control: { onStop: () => Promise<void> | void, isStopping: boolean } | null) => void
}
// True when any local file is still waiting for its server-side upload id.
function hasUploadingFiles(files: VisionFile[]): boolean {
  for (const file of files) {
    const isLocal = file.transfer_method === TransferMethod.local_file
    if (isLocal && !file.upload_file_id)
      return true
  }
  return false
}
// Convert file / file-list prompt variables in place via getProcessedFiles.
function processFileInputs(
  processedInputs: Record<string, string | number | boolean | object>,
  promptVariables: PromptConfig['prompt_variables'],
) {
  for (const { key, type } of promptVariables) {
    const current = processedInputs[key]
    const isSingleObject = !!current && typeof current === 'object' && !Array.isArray(current)
    if (type === 'file' && isSingleObject) {
      processedInputs[key] = getProcessedFiles([current as FileEntity])[0]
    }
    else if (type === 'file-list' && Array.isArray(current) && current.length > 0) {
      processedInputs[key] = getProcessedFiles(current as FileEntity[])
    }
  }
}
// Blank out local-file preview URLs before sending; other files pass through.
function prepareVisionFiles(files: VisionFile[]): VisionFile[] {
  return files.map((file) => {
    if (file.transfer_method !== TransferMethod.local_file)
      return file
    return { ...file, url: '' }
  })
}
/**
 * State machine for a single text-generation / workflow run.
 *
 * Owns the streaming response text, workflow tracing data, stop/feedback
 * controls, and a client-side timeout. External triggers arrive through the
 * numeric `controlSend` / `controlRetry` / `controlStopResponding` props.
 */
export function useTextGeneration(props: UseTextGenerationProps) {
  const {
    isWorkflow,
    isCallBatchAPI,
    isPC,
    appSourceType,
    appId,
    promptConfig,
    inputs,
    controlSend,
    controlRetry,
    controlStopResponding,
    onShowRes,
    taskId,
    onCompleted,
    visionConfig,
    completionFiles,
    onRunStart,
    onRunControlChange,
  } = props
  const { notify } = Toast
  const [isResponding, { setTrue: setRespondingTrue, setFalse: setRespondingFalse }] = useBoolean(false)
  // Response text is mirrored into a ref so streaming callbacks can always
  // read the latest value without stale closures.
  const [completionRes, doSetCompletionRes] = useState('')
  const completionResRef = useRef('')
  const setCompletionRes = (res: string) => {
    completionResRef.current = res
    doSetCompletionRes(res)
  }
  const getCompletionRes = () => completionResRef.current
  // Same state+ref mirroring for the workflow tracing data.
  const [workflowProcessData, doSetWorkflowProcessData] = useState<WorkflowProcess>()
  const workflowProcessDataRef = useRef<WorkflowProcess | undefined>(undefined)
  const setWorkflowProcessData = (data: WorkflowProcess) => {
    workflowProcessDataRef.current = data
    doSetWorkflowProcessData(data)
  }
  const getWorkflowProcessData = () => workflowProcessDataRef.current
  const [currentTaskId, setCurrentTaskId] = useState<string | null>(null)
  const [isStopping, setIsStopping] = useState(false)
  const abortControllerRef = useRef<AbortController | null>(null)
  // isEndRef: run finished (success or error). isTimeoutRef: the client-side
  // timeout fired first, so late stream callbacks only warn.
  const isEndRef = useRef(false)
  const isTimeoutRef = useRef(false)
  const tempMessageIdRef = useRef('')
  const resetRunState = useCallback(() => {
    setCurrentTaskId(null) // eslint-disable-line react-hooks-extra/no-direct-set-state-in-use-effect
    setIsStopping(false) // eslint-disable-line react-hooks-extra/no-direct-set-state-in-use-effect
    abortControllerRef.current = null
    onRunControlChange?.(null)
  }, [onRunControlChange])
  const [messageId, setMessageId] = useState<string | null>(null)
  const [feedback, setFeedback] = useState<FeedbackType>({ rating: null })
  const [controlClearMoreLikeThis, setControlClearMoreLikeThis] = useState(0)
  // Persist a rating/comment for the finished message.
  const handleFeedback = async (fb: FeedbackType) => {
    await updateFeedback(
      { url: `/messages/${messageId}/feedbacks`, body: { rating: fb.rating, content: fb.content } },
      appSourceType,
      appId,
    )
    setFeedback(fb)
  }
  // Ask the backend to stop the current task, then abort the local stream.
  const handleStop = useCallback(async () => {
    if (!currentTaskId || isStopping)
      return
    setIsStopping(true)
    try {
      // NOTE(review): appId is passed both as the first argument and as the
      // trailing `appId || ''` — confirm against the service signatures.
      if (isWorkflow)
        await stopWorkflowMessage(appId!, currentTaskId, appSourceType, appId || '')
      else
        await stopChatMessageResponding(appId!, currentTaskId, appSourceType, appId || '')
      abortControllerRef.current?.abort()
    }
    catch (error) {
      notify({ type: 'error', message: error instanceof Error ? error.message : String(error) })
    }
    finally {
      setIsStopping(false)
    }
  }, [appId, currentTaskId, appSourceType, isStopping, isWorkflow, notify])
  // Validate required prompt variables and pending uploads before sending.
  const checkCanSend = (): boolean => {
    if (isCallBatchAPI)
      return true
    const promptVariables = promptConfig?.prompt_variables
    if (!promptVariables?.length) {
      if (hasUploadingFiles(completionFiles)) {
        notify({ type: 'info', message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }) })
        return false
      }
      return true
    }
    let hasEmptyInput = ''
    // A variable counts as required when `required` is anything but false
    // (undefined/null default to required), or when its key/name is blank.
    const requiredVars = promptVariables?.filter(({ key, name, required, type }) => {
      if (type === 'boolean' || type === 'checkbox')
        return false
      return (!key || !key.trim()) || (!name || !name.trim()) || (required || required === undefined || required === null)
    }) || []
    requiredVars.forEach(({ key, name }) => {
      if (hasEmptyInput)
        return
      if (!inputs[key])
        hasEmptyInput = name
    })
    if (hasEmptyInput) {
      notify({ type: 'error', message: t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: hasEmptyInput }) })
      return false
    }
    if (hasUploadingFiles(completionFiles)) {
      notify({ type: 'info', message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }) })
      return false
    }
    return !hasEmptyInput
  }
  // Kick off a run: build the payload, reset per-run state, arm the
  // client-side timeout, then stream either a workflow or a completion.
  const handleSend = async () => {
    if (isResponding) {
      notify({ type: 'info', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) })
      return
    }
    if (!checkCanSend())
      return
    // Drop undefined values, normalize booleans via formatBooleanInputs,
    // then convert file variables into their upload representations.
    const definedInputs = Object.fromEntries(
      Object.entries(inputs).filter(([, v]) => v !== undefined),
    ) as Record<string, string | number | boolean | object>
    const processedInputs = { ...formatBooleanInputs(promptConfig?.prompt_variables, definedInputs) }
    processFileInputs(processedInputs, promptConfig?.prompt_variables ?? [])
    const data: { inputs: Record<string, string | number | boolean | object>, files?: VisionFile[] } = { inputs: processedInputs }
    if (visionConfig.enabled && completionFiles?.length > 0)
      data.files = prepareVisionFiles(completionFiles)
    setMessageId(null)
    setFeedback({ rating: null })
    setCompletionRes('')
    resetRunState()
    isEndRef.current = false
    isTimeoutRef.current = false
    tempMessageIdRef.current = ''
    if (!isPC) {
      onShowRes()
      onRunStart()
    }
    setRespondingTrue()
    // Fire-and-forget timeout: if the run has not ended when it elapses,
    // report failure and flag the timeout so late callbacks only warn.
    ;(async () => {
      await sleep(TEXT_GENERATION_TIMEOUT_MS)
      if (!isEndRef.current) {
        setRespondingFalse()
        onCompleted(getCompletionRes(), taskId, false)
        resetRunState()
        isTimeoutRef.current = true
      }
    })()
    if (isWorkflow) {
      const callbacks = createWorkflowCallbacks({
        getProcessData: getWorkflowProcessData,
        setProcessData: setWorkflowProcessData,
        setCurrentTaskId,
        setIsStopping,
        getCompletionRes,
        setCompletionRes,
        setRespondingFalse,
        resetRunState,
        setMessageId,
        isTimeoutRef,
        isEndRef,
        tempMessageIdRef,
        taskId,
        onCompleted,
        notify,
        t,
        requestData: data,
      })
      sendWorkflowMessage(data, callbacks, appSourceType, appId).catch((error) => {
        setRespondingFalse()
        resetRunState()
        notify({ type: 'error', message: error instanceof Error ? error.message : String(error) })
      })
    }
    else {
      let res: string[] = []
      sendCompletionMessage(data, {
        onData: (chunk: string, _isFirstMessage: boolean, { messageId: msgId, taskId: tId }) => {
          tempMessageIdRef.current = msgId
          // Only the first non-empty task id wins; later chunks cannot override it.
          if (tId && typeof tId === 'string' && tId.trim() !== '')
            setCurrentTaskId(prev => prev ?? tId)
          res.push(chunk)
          setCompletionRes(res.join(''))
        },
        onCompleted: () => {
          if (isTimeoutRef.current) {
            notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
            return
          }
          setRespondingFalse()
          resetRunState()
          setMessageId(tempMessageIdRef.current)
          onCompleted(getCompletionRes(), taskId, true)
          isEndRef.current = true
        },
        onMessageReplace: (messageReplace) => {
          // Replace the whole accumulated answer with the provided one.
          res = [messageReplace.answer]
          setCompletionRes(res.join(''))
        },
        onError() {
          if (isTimeoutRef.current) {
            notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
            return
          }
          setRespondingFalse()
          resetRunState()
          onCompleted(getCompletionRes(), taskId, false)
          isEndRef.current = true
        },
        getAbortController: (abortController) => {
          abortControllerRef.current = abortController
        },
      }, appSourceType, appId)
    }
  }
  // Abort the in-flight request on external stop and on unmount (cleanup).
  useEffect(() => {
    const abortCurrentRequest = () => {
      abortControllerRef.current?.abort()
    }
    if (controlStopResponding) {
      abortCurrentRequest()
      setRespondingFalse()
      resetRunState()
    }
    return abortCurrentRequest
  }, [controlStopResponding, resetRunState, setRespondingFalse])
  // Expose a stop control to the parent only while a task is actually running.
  useEffect(() => {
    if (!onRunControlChange)
      return
    if (isResponding && currentTaskId)
      onRunControlChange({ onStop: handleStop, isStopping })
    else
      onRunControlChange(null)
  }, [currentTaskId, handleStop, isResponding, isStopping, onRunControlChange])
  useEffect(() => {
    if (controlSend) {
      handleSend()
      setControlClearMoreLikeThis(Date.now()) // eslint-disable-line react-hooks-extra/no-direct-set-state-in-use-effect
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [controlSend])
  useEffect(() => {
    if (controlRetry)
      handleSend()
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [controlRetry])
  return {
    isResponding,
    completionRes,
    workflowProcessData,
    messageId,
    feedback,
    isStopping,
    currentTaskId,
    controlClearMoreLikeThis,
    handleSend,
    handleStop,
    handleFeedback,
  }
}

View File

@@ -0,0 +1,597 @@
import type { WorkflowCallbackDeps } from './workflow-callbacks'
import type { WorkflowProcess } from '@/app/components/base/chat/types'
import type { NodeTracing } from '@/types/workflow'
import { NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types'
import { createWorkflowCallbacks } from './workflow-callbacks'
// Stub the file-uploader helper so tests do not depend on real log parsing.
vi.mock('@/app/components/base/file-uploader/utils', () => ({
  getFilesInLogs: vi.fn(() => [{ name: 'file.png' }]),
}))
// Factory for a minimal NodeTracing-like object
const createTrace = (overrides: Partial<NodeTracing> = {}): NodeTracing => {
  const base = {
    id: 'trace-1',
    index: 0,
    predecessor_node_id: '',
    node_id: 'node-1',
    node_type: 'start',
    title: 'Node',
    status: NodeRunningStatus.Running,
  }
  return { ...base, ...overrides } as NodeTracing
}
// Factory for a base WorkflowProcess
const createProcess = (overrides: Partial<WorkflowProcess> = {}): WorkflowProcess => {
  const defaults = {
    status: WorkflowRunningStatus.Running,
    tracing: [],
    expand: false,
    resultText: '',
  }
  return { ...defaults, ...overrides }
}
// Factory for mock dependencies
function createMockDeps(overrides: Partial<WorkflowCallbackDeps> = {}): WorkflowCallbackDeps {
  const process = createProcess()
  const base: WorkflowCallbackDeps = {
    getProcessData: vi.fn(() => process),
    setProcessData: vi.fn(),
    setCurrentTaskId: vi.fn(),
    setIsStopping: vi.fn(),
    getCompletionRes: vi.fn(() => ''),
    setCompletionRes: vi.fn(),
    setRespondingFalse: vi.fn(),
    resetRunState: vi.fn(),
    setMessageId: vi.fn(),
    isTimeoutRef: { current: false },
    isEndRef: { current: false },
    tempMessageIdRef: { current: '' },
    onCompleted: vi.fn(),
    notify: vi.fn(),
    t: vi.fn((key: string) => key) as unknown as WorkflowCallbackDeps['t'],
    requestData: { inputs: {} },
  }
  return { ...base, ...overrides }
}
describe('createWorkflowCallbacks', () => {
  beforeEach(() => {
    // Reset call history between tests so mock assertions stay independent.
    vi.clearAllMocks()
  })
  // Workflow lifecycle start
  describe('onWorkflowStarted', () => {
    it('should initialize process data and set task id', () => {
      const deps = createMockDeps()
      const cb = createWorkflowCallbacks(deps)
      cb.onWorkflowStarted({ workflow_run_id: 'run-1', task_id: 'task-1' } as never)
      // The run id is stashed as the temporary message id.
      expect(deps.tempMessageIdRef.current).toBe('run-1')
      expect(deps.setCurrentTaskId).toHaveBeenCalledWith('task-1')
      expect(deps.setIsStopping).toHaveBeenCalledWith(false)
      expect(deps.setProcessData).toHaveBeenCalledWith(
        expect.objectContaining({ status: WorkflowRunningStatus.Running, tracing: [] }),
      )
    })
    it('should default task_id to null when not provided', () => {
      const deps = createMockDeps()
      const cb = createWorkflowCallbacks(deps)
      cb.onWorkflowStarted({ workflow_run_id: 'run-2' } as never)
      expect(deps.setCurrentTaskId).toHaveBeenCalledWith(null)
    })
  })
  // Shared group handlers (iteration & loop use the same logic)
  describe('group handlers (iteration/loop)', () => {
    it('onIterationStart should push a running trace', () => {
      const deps = createMockDeps()
      const cb = createWorkflowCallbacks(deps)
      const trace = createTrace({ node_id: 'iter-node' })
      cb.onIterationStart({ data: trace } as never)
      // Inspect the WorkflowProcess passed to the first setProcessData call.
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      expect(produced.expand).toBe(true)
      expect(produced.tracing).toHaveLength(1)
      expect(produced.tracing[0].node_id).toBe('iter-node')
      expect(produced.tracing[0].status).toBe(NodeRunningStatus.Running)
    })
    it('onLoopStart should behave identically to onIterationStart', () => {
      const deps = createMockDeps()
      const cb = createWorkflowCallbacks(deps)
      cb.onLoopStart({ data: createTrace({ node_id: 'loop-node' }) } as never)
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      expect(produced.tracing[0].node_id).toBe('loop-node')
    })
    it('onIterationFinish should replace trace entry', () => {
      const existing = createTrace({ node_id: 'n1', execution_metadata: { parallel_id: 'p1' } as NodeTracing['execution_metadata'] })
      const deps = createMockDeps({
        getProcessData: vi.fn(() => createProcess({ tracing: [existing] })),
      })
      const cb = createWorkflowCallbacks(deps)
      const updated = createTrace({ node_id: 'n1', execution_metadata: { parallel_id: 'p1' } as NodeTracing['execution_metadata'], error: 'fail' } as NodeTracing)
      cb.onIterationFinish({ data: updated } as never)
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      expect(produced.tracing[0].expand).toBe(true) // error -> expand
    })
  })
  // Node lifecycle
  describe('onNodeStarted', () => {
    it('should add a running trace for top-level nodes', () => {
      const deps = createMockDeps()
      const cb = createWorkflowCallbacks(deps)
      cb.onNodeStarted({ data: createTrace({ node_id: 'top-node' }) } as never)
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      expect(produced.tracing).toHaveLength(1)
    })
    it('should skip nodes inside an iteration', () => {
      // Nodes tagged with an iteration_id must not reach top-level tracing.
      const deps = createMockDeps()
      const cb = createWorkflowCallbacks(deps)
      cb.onNodeStarted({ data: createTrace({ iteration_id: 'iter-1' }) } as never)
      expect(deps.setProcessData).not.toHaveBeenCalled()
    })
    it('should skip nodes inside a loop', () => {
      const deps = createMockDeps()
      const cb = createWorkflowCallbacks(deps)
      cb.onNodeStarted({ data: createTrace({ loop_id: 'loop-1' }) } as never)
      expect(deps.setProcessData).not.toHaveBeenCalled()
    })
  })
  describe('onNodeFinished', () => {
    it('should update existing trace entry', () => {
      // Seed the process with a running trace the callback should overwrite.
      const trace = createTrace({ node_id: 'n1', execution_metadata: { parallel_id: 'p1' } as NodeTracing['execution_metadata'] })
      const deps = createMockDeps({
        getProcessData: vi.fn(() => createProcess({ tracing: [trace] })),
      })
      const cb = createWorkflowCallbacks(deps)
      const finished = createTrace({
        node_id: 'n1',
        execution_metadata: { parallel_id: 'p1' } as NodeTracing['execution_metadata'],
        status: NodeRunningStatus.Succeeded as NodeTracing['status'],
      })
      cb.onNodeFinished({ data: finished } as never)
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      expect(produced.tracing[0].status).toBe(NodeRunningStatus.Succeeded)
    })
    it('should skip nodes inside iteration or loop', () => {
      const deps = createMockDeps()
      const cb = createWorkflowCallbacks(deps)
      cb.onNodeFinished({ data: createTrace({ iteration_id: 'i1' }) } as never)
      cb.onNodeFinished({ data: createTrace({ loop_id: 'l1' }) } as never)
      expect(deps.setProcessData).not.toHaveBeenCalled()
    })
  })
  // Workflow completion
  describe('onWorkflowFinished', () => {
    it('should handle success with outputs', () => {
      const deps = createMockDeps({ taskId: 1 })
      const cb = createWorkflowCallbacks(deps)
      cb.onWorkflowFinished({
        data: { status: 'succeeded', outputs: { result: 'hello' } },
      } as never)
      expect(deps.setCompletionRes).toHaveBeenCalledWith({ result: 'hello' })
      expect(deps.setRespondingFalse).toHaveBeenCalled()
      expect(deps.resetRunState).toHaveBeenCalled()
      // taskId from deps is echoed back; success flag is true.
      expect(deps.onCompleted).toHaveBeenCalledWith('', 1, true)
      expect(deps.isEndRef.current).toBe(true)
    })
    it('should handle success with single string output and set resultText', () => {
      const deps = createMockDeps()
      const cb = createWorkflowCallbacks(deps)
      cb.onWorkflowFinished({
        data: { status: 'succeeded', outputs: { text: 'response' } },
      } as never)
      // setProcessData called multiple times: succeeded status, then resultText
      expect(deps.setProcessData).toHaveBeenCalledTimes(2)
    })
    it('should handle success without outputs', () => {
      const deps = createMockDeps()
      const cb = createWorkflowCallbacks(deps)
      cb.onWorkflowFinished({
        data: { status: 'succeeded', outputs: null },
      } as never)
      expect(deps.setCompletionRes).toHaveBeenCalledWith('')
    })
    it('should handle stopped status', () => {
      const deps = createMockDeps()
      const cb = createWorkflowCallbacks(deps)
      cb.onWorkflowFinished({
        data: { status: WorkflowRunningStatus.Stopped },
      } as never)
      expect(deps.onCompleted).toHaveBeenCalledWith('', undefined, false)
      expect(deps.isEndRef.current).toBe(true)
    })
    it('should handle error status', () => {
      const deps = createMockDeps()
      const cb = createWorkflowCallbacks(deps)
      cb.onWorkflowFinished({
        data: { status: 'failed', error: 'Something broke' },
      } as never)
      expect(deps.notify).toHaveBeenCalledWith({ type: 'error', message: 'Something broke' })
      expect(deps.onCompleted).toHaveBeenCalledWith('', undefined, false)
    })
    it('should skip processing when timeout has already occurred', () => {
      // A late finish after the client-side timeout only warns, never completes.
      const deps = createMockDeps()
      deps.isTimeoutRef.current = true
      const cb = createWorkflowCallbacks(deps)
      cb.onWorkflowFinished({
        data: { status: 'succeeded', outputs: { text: 'late' } },
      } as never)
      expect(deps.notify).toHaveBeenCalledWith(
        expect.objectContaining({ type: 'warning' }),
      )
      expect(deps.onCompleted).not.toHaveBeenCalled()
    })
  })
// Streaming text handlers
describe('text handlers', () => {
it('onTextChunk should append text to resultText', () => {
const deps = createMockDeps({
getProcessData: vi.fn(() => createProcess({ resultText: 'hello' })),
})
const cb = createWorkflowCallbacks(deps)
cb.onTextChunk({ data: { text: ' world' } } as never)
const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
expect(produced.resultText).toBe('hello world')
})
it('onTextReplace should replace resultText entirely', () => {
const deps = createMockDeps({
getProcessData: vi.fn(() => createProcess({ resultText: 'old' })),
})
const cb = createWorkflowCallbacks(deps)
cb.onTextReplace({ data: { text: 'new' } } as never)
const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
expect(produced.resultText).toBe('new')
})
})
  // handleGroupNext with valid node_id (covers findTrace)
  describe('handleGroupNext', () => {
    it('should push empty details to matching group when node_id exists', () => {
      const existingTrace = createTrace({
        node_id: 'group-node',
        execution_metadata: { parallel_id: 'p1' } as NodeTracing['execution_metadata'],
        details: [[]],
      } as Partial<NodeTracing>)
      const deps = createMockDeps({
        getProcessData: vi.fn(() => createProcess({ tracing: [existingTrace] })),
        // requestData carries the group identity used by the lookup.
        requestData: { inputs: {}, node_id: 'group-node', execution_metadata: { parallel_id: 'p1' } },
      })
      const cb = createWorkflowCallbacks(deps)
      cb.onIterationNext()
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      expect(produced.tracing[0].details).toHaveLength(2)
      expect(produced.expand).toBe(true)
    })
    it('should handle no matching group gracefully', () => {
      const deps = createMockDeps({
        getProcessData: vi.fn(() => createProcess({ tracing: [] })),
        requestData: { inputs: {}, node_id: 'nonexistent' },
      })
      const cb = createWorkflowCallbacks(deps)
      // Should not throw even when no matching trace is found
      cb.onLoopNext()
      expect(deps.setProcessData).toHaveBeenCalled()
    })
  })
  // markNodesStopped edge cases
  describe('markNodesStopped', () => {
    it('should handle undefined tracing gracefully', () => {
      // Process intentionally lacks a tracing array.
      const deps = createMockDeps({
        getProcessData: vi.fn(() => ({
          status: WorkflowRunningStatus.Running,
          expand: false,
          resultText: '',
        } as unknown as WorkflowProcess)),
      })
      const cb = createWorkflowCallbacks(deps)
      cb.onWorkflowFinished({
        data: { status: WorkflowRunningStatus.Stopped },
      } as never)
      expect(deps.setProcessData).toHaveBeenCalled()
      expect(deps.onCompleted).toHaveBeenCalledWith('', undefined, false)
    })
    it('should recursively mark running/waiting nodes and nested structures as stopped', () => {
      // Build a parent trace with all three nesting containers populated:
      // details (iteration/loop children), retryDetail and parallelDetail.
      const nestedTrace = createTrace({ node_id: 'nested', status: NodeRunningStatus.Running })
      const retryTrace = createTrace({ node_id: 'retry', status: NodeRunningStatus.Waiting })
      const parallelChild = createTrace({ node_id: 'p-child', status: NodeRunningStatus.Running })
      const parentTrace = createTrace({
        node_id: 'parent',
        status: NodeRunningStatus.Running,
        details: [[nestedTrace]],
        retryDetail: [retryTrace],
        parallelDetail: { children: [parallelChild] },
      } as Partial<NodeTracing>)
      const deps = createMockDeps({
        getProcessData: vi.fn(() => createProcess({ tracing: [parentTrace] })),
      })
      const cb = createWorkflowCallbacks(deps)
      cb.onWorkflowFinished({
        data: { status: WorkflowRunningStatus.Stopped },
      } as never)
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      expect(produced.tracing[0].status).toBe(NodeRunningStatus.Stopped)
      expect(produced.tracing[0].details![0][0].status).toBe(NodeRunningStatus.Stopped)
      expect(produced.tracing[0].retryDetail![0].status).toBe(NodeRunningStatus.Stopped)
      const parallel = produced.tracing[0].parallelDetail as { children: NodeTracing[] }
      expect(parallel.children[0].status).toBe(NodeRunningStatus.Stopped)
    })
    it('should not change status of already succeeded nodes', () => {
      const succeededTrace = createTrace({
        node_id: 'done',
        status: NodeRunningStatus.Succeeded,
      })
      const deps = createMockDeps({
        getProcessData: vi.fn(() => createProcess({ tracing: [succeededTrace] })),
      })
      const cb = createWorkflowCallbacks(deps)
      cb.onWorkflowFinished({
        data: { status: WorkflowRunningStatus.Stopped },
      } as never)
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      expect(produced.tracing[0].status).toBe(NodeRunningStatus.Succeeded)
    })
    it('should handle trace with no nested details/retryDetail/parallelDetail', () => {
      const simpleTrace = createTrace({ node_id: 'simple', status: NodeRunningStatus.Running })
      const deps = createMockDeps({
        getProcessData: vi.fn(() => createProcess({ tracing: [simpleTrace] })),
      })
      const cb = createWorkflowCallbacks(deps)
      cb.onWorkflowFinished({
        data: { status: WorkflowRunningStatus.Stopped },
      } as never)
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      expect(produced.tracing[0].status).toBe(NodeRunningStatus.Stopped)
    })
  })
// Branch coverage: handleGroupNext early return
describe('handleGroupNext - early return', () => {
it('should return early when requestData has no node_id', () => {
const deps = createMockDeps({
requestData: { inputs: {} }, // no node_id
})
const cb = createWorkflowCallbacks(deps)
cb.onIterationNext()
expect(deps.setProcessData).not.toHaveBeenCalled()
})
})
  // Branch coverage: onNodeFinished edge cases
  describe('onNodeFinished - branch coverage', () => {
    it('should preserve existing extras when updating trace', () => {
      const trace = createTrace({
        node_id: 'n1',
        execution_metadata: { parallel_id: 'p1' } as NodeTracing['execution_metadata'],
        extras: { key: 'val' },
      } as Partial<NodeTracing>)
      const deps = createMockDeps({
        getProcessData: vi.fn(() => createProcess({ tracing: [trace] })),
      })
      const cb = createWorkflowCallbacks(deps)
      cb.onNodeFinished({
        data: createTrace({
          node_id: 'n1',
          execution_metadata: { parallel_id: 'p1' } as NodeTracing['execution_metadata'],
          status: NodeRunningStatus.Succeeded as NodeTracing['status'],
        }),
      } as never)
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      // extras from the pre-existing trace must survive the update.
      expect(produced.tracing[0].extras).toEqual({ key: 'val' })
    })
    it('should not add extras when existing trace has no extras', () => {
      const trace = createTrace({
        node_id: 'n1',
        execution_metadata: { parallel_id: 'p1' } as NodeTracing['execution_metadata'],
      })
      const deps = createMockDeps({
        getProcessData: vi.fn(() => createProcess({ tracing: [trace] })),
      })
      const cb = createWorkflowCallbacks(deps)
      cb.onNodeFinished({
        data: createTrace({
          node_id: 'n1',
          execution_metadata: { parallel_id: 'p1' } as NodeTracing['execution_metadata'],
        }),
      } as never)
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      expect(produced.tracing[0]).not.toHaveProperty('extras')
    })
    it('should do nothing when trace is not found (idx === -1)', () => {
      const deps = createMockDeps({
        getProcessData: vi.fn(() => createProcess({ tracing: [] })),
      })
      const cb = createWorkflowCallbacks(deps)
      cb.onNodeFinished({
        data: createTrace({ node_id: 'nonexistent' }),
      } as never)
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      expect(produced.tracing).toHaveLength(0)
    })
  })
  // Branch coverage: handleGroupFinish without error
  describe('handleGroupFinish - branch coverage', () => {
    it('should set expand=false when no error', () => {
      // A clean finish collapses the group (expand stays false).
      const existing = createTrace({
        node_id: 'n1',
        execution_metadata: { parallel_id: 'p1' } as NodeTracing['execution_metadata'],
      })
      const deps = createMockDeps({
        getProcessData: vi.fn(() => createProcess({ tracing: [existing] })),
      })
      const cb = createWorkflowCallbacks(deps)
      cb.onLoopFinish({
        data: createTrace({
          node_id: 'n1',
          execution_metadata: { parallel_id: 'p1' } as NodeTracing['execution_metadata'],
        }),
      } as never)
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      expect(produced.tracing[0].expand).toBe(false)
    })
  })
// Branch coverage: handleWorkflowEnd without error
describe('handleWorkflowEnd - branch coverage', () => {
it('should not notify when no error message', () => {
const deps = createMockDeps()
const cb = createWorkflowCallbacks(deps)
cb.onWorkflowFinished({
data: { status: WorkflowRunningStatus.Stopped },
} as never)
expect(deps.notify).not.toHaveBeenCalled()
})
})
  // Branch coverage: findTraceIndex matching via parallel_id vs execution_metadata
  describe('findTrace matching', () => {
    it('should match trace via parallel_id field', () => {
      // Existing trace carries parallel_id at the top level (not in execution_metadata).
      const trace = createTrace({
        node_id: 'n1',
        parallel_id: 'p1',
      } as Partial<NodeTracing>)
      const deps = createMockDeps({
        getProcessData: vi.fn(() => createProcess({ tracing: [trace] })),
      })
      const cb = createWorkflowCallbacks(deps)
      cb.onNodeFinished({
        data: createTrace({
          node_id: 'n1',
          execution_metadata: { parallel_id: 'p1' } as NodeTracing['execution_metadata'],
          status: NodeRunningStatus.Succeeded as NodeTracing['status'],
        }),
      } as never)
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      expect(produced.tracing[0].status).toBe(NodeRunningStatus.Succeeded)
    })
    it('should not match when both parallel_id fields differ', () => {
      const trace = createTrace({
        node_id: 'group-node',
        execution_metadata: { parallel_id: 'other' } as NodeTracing['execution_metadata'],
        parallel_id: 'also-other',
        details: [[]],
      } as Partial<NodeTracing>)
      const deps = createMockDeps({
        getProcessData: vi.fn(() => createProcess({ tracing: [trace] })),
        requestData: { inputs: {}, node_id: 'group-node', execution_metadata: { parallel_id: 'target' } },
      })
      const cb = createWorkflowCallbacks(deps)
      cb.onIterationNext()
      // group not found, details unchanged
      const produced = (deps.setProcessData as ReturnType<typeof vi.fn>).mock.calls[0][0] as WorkflowProcess
      expect(produced.tracing[0].details).toHaveLength(1)
    })
  })
  // Branch coverage: onWorkflowFinished success with multiple output keys
  describe('onWorkflowFinished - output branches', () => {
    it('should not set resultText when outputs have multiple keys', () => {
      const deps = createMockDeps()
      const cb = createWorkflowCallbacks(deps)
      cb.onWorkflowFinished({
        data: { status: 'succeeded', outputs: { key1: 'val1', key2: 'val2' } },
      } as never)
      // setProcessData called once (for succeeded status), not twice (no resultText)
      expect(deps.setProcessData).toHaveBeenCalledTimes(1)
    })
    it('should not set resultText when single key is not a string', () => {
      const deps = createMockDeps()
      const cb = createWorkflowCallbacks(deps)
      cb.onWorkflowFinished({
        data: { status: 'succeeded', outputs: { data: { nested: true } } },
      } as never)
      expect(deps.setProcessData).toHaveBeenCalledTimes(1)
    })
  })
})

View File

@@ -0,0 +1,237 @@
import type { TFunction } from 'i18next'
import type { WorkflowProcess } from '@/app/components/base/chat/types'
import type { VisionFile } from '@/types/app'
import type { NodeTracing, WorkflowFinishedResponse } from '@/types/workflow'
import { produce } from 'immer'
import {
getFilesInLogs,
} from '@/app/components/base/file-uploader/utils'
import { NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types'
// Payload shape delivered by the workflow-finished event.
type WorkflowFinishedData = WorkflowFinishedResponse['data']
// A single entry in the workflow run's tracing list.
type TraceItem = WorkflowProcess['tracing'][number]
function findTraceIndex(
tracing: WorkflowProcess['tracing'],
nodeId: string,
parallelId?: string,
): number {
return tracing.findIndex(item =>
item.node_id === nodeId
&& (item.execution_metadata?.parallel_id === parallelId || item.parallel_id === parallelId),
)
}
function findTrace(
tracing: WorkflowProcess['tracing'],
nodeId: string,
parallelId?: string,
): TraceItem | undefined {
return tracing.find(item =>
item.node_id === nodeId
&& (item.execution_metadata?.parallel_id === parallelId || item.parallel_id === parallelId),
)
}
/**
 * Recursively flip every still-running or waiting trace to Stopped.
 * Descends into iteration/loop detail groups, retry details, and parallel
 * children so nested nodes are marked as well. Mutates the traces in place
 * (callers invoke this inside an immer draft).
 */
function markNodesStopped(traces?: WorkflowProcess['tracing']) {
  if (!traces)
    return
  const visit = (trace: TraceItem) => {
    const status = trace.status as NodeRunningStatus
    if (status === NodeRunningStatus.Running || status === NodeRunningStatus.Waiting)
      trace.status = NodeRunningStatus.Stopped
    for (const group of trace.details ?? [])
      group.forEach(visit)
    trace.retryDetail?.forEach(visit)
    trace.parallelDetail?.children?.forEach(visit)
  }
  traces.forEach(visit)
}
// Dependency bag injected into createWorkflowCallbacks. Getter/setter pairs
// (getProcessData/setProcessData, getCompletionRes/setCompletionRes) exist so
// the long-lived SSE callbacks always read the caller's freshest state rather
// than a stale closure capture.
export type WorkflowCallbackDeps = {
  getProcessData: () => WorkflowProcess | undefined
  setProcessData: (data: WorkflowProcess) => void
  setCurrentTaskId: (id: string | null) => void
  setIsStopping: (v: boolean) => void
  getCompletionRes: () => string
  setCompletionRes: (res: string) => void
  setRespondingFalse: () => void
  resetRunState: () => void
  setMessageId: (id: string | null) => void
  // Mutable flags/ids shared with the caller's timeout watchdog and run state.
  isTimeoutRef: { current: boolean }
  isEndRef: { current: boolean }
  tempMessageIdRef: { current: string }
  // NOTE(review): numeric, distinct from the server-side task id string set
  // via setCurrentTaskId — presumably the batch task index; confirm at caller.
  taskId?: number
  onCompleted: (completionRes: string, taskId?: number, success?: boolean) => void
  notify: (options: { type: 'error' | 'info' | 'success' | 'warning', message: string }) => void
  t: TFunction
  // The outer request data object passed to sendWorkflowMessage.
  // Used by group next handlers to match traces (mirrors original closure behavior).
  requestData: { inputs: Record<string, string | number | boolean | object>, files?: VisionFile[], node_id?: string, execution_metadata?: { parallel_id?: string } }
}
/**
 * Build the full set of streaming callbacks for a workflow run.
 *
 * All state access goes through the injected getters/setters in `deps`, so the
 * callbacks (which outlive any single render) never read stale React state.
 * Iteration and loop events share the same group handlers because their
 * tracing updates are identical.
 */
export function createWorkflowCallbacks(deps: WorkflowCallbackDeps) {
  const {
    getProcessData,
    setProcessData,
    setCurrentTaskId,
    setIsStopping,
    getCompletionRes,
    setCompletionRes,
    setRespondingFalse,
    resetRunState,
    setMessageId,
    isTimeoutRef,
    isEndRef,
    tempMessageIdRef,
    taskId,
    onCompleted,
    notify,
    t,
    requestData,
  } = deps
  // Immutably apply `updater` to the current process data via immer.
  // NOTE(review): assumes process data already exists (`!`) — onWorkflowStarted
  // seeds it before other callbacks fire; confirm event ordering upstream.
  const updateProcessData = (updater: (draft: WorkflowProcess) => void) => {
    setProcessData(produce(getProcessData()!, updater))
  }
  // Iteration/loop started: append a running, expanded trace entry.
  const handleGroupStart = ({ data }: { data: NodeTracing }) => {
    updateProcessData((draft) => {
      draft.expand = true
      draft.tracing!.push({ ...data, status: NodeRunningStatus.Running, expand: true })
    })
  }
  // Iteration/loop advanced one round: open a fresh detail group on the
  // matching trace. Matching uses the outer requestData (next events are
  // handled without a node payload here, mirroring the original closure).
  // NOTE(review): `group?.details!` throws if the group exists but has no
  // details array — same as the original; confirm details is always seeded.
  const handleGroupNext = () => {
    if (!requestData.node_id)
      return
    updateProcessData((draft) => {
      draft.expand = true
      const group = findTrace(
        draft.tracing,
        requestData.node_id!,
        requestData.execution_metadata?.parallel_id,
      )
      group?.details!.push([])
    })
  }
  // Iteration/loop finished: replace the matched trace wholesale; keep the
  // entry expanded only when it errored.
  const handleGroupFinish = ({ data }: { data: NodeTracing }) => {
    updateProcessData((draft) => {
      draft.expand = true
      const idx = findTraceIndex(draft.tracing, data.node_id, data.execution_metadata?.parallel_id)
      draft.tracing[idx] = { ...data, expand: !!data.error }
    })
  }
  // Shared terminal path for stopped/failed runs: surface the error (if any),
  // mark still-running nodes as stopped, and report failure to the caller.
  const handleWorkflowEnd = (status: WorkflowRunningStatus, error?: string) => {
    if (error)
      notify({ type: 'error', message: error })
    updateProcessData((draft) => {
      draft.status = status
      markNodesStopped(draft.tracing)
    })
    setRespondingFalse()
    resetRunState()
    onCompleted(getCompletionRes(), taskId, false)
    isEndRef.current = true
  }
  return {
    // Reset per-run state and seed fresh process data for this run.
    onWorkflowStarted: ({ workflow_run_id, task_id }: { workflow_run_id: string, task_id?: string }) => {
      tempMessageIdRef.current = workflow_run_id
      setCurrentTaskId(task_id || null)
      setIsStopping(false)
      setProcessData({
        status: WorkflowRunningStatus.Running,
        tracing: [],
        expand: false,
        resultText: '',
      })
    },
    // Iterations and loops share identical tracing semantics.
    onIterationStart: handleGroupStart,
    onIterationNext: handleGroupNext,
    onIterationFinish: handleGroupFinish,
    onLoopStart: handleGroupStart,
    onLoopNext: handleGroupNext,
    onLoopFinish: handleGroupFinish,
    onNodeStarted: ({ data }: { data: NodeTracing }) => {
      // Nodes inside an iteration/loop are tracked via their group, not here.
      if (data.iteration_id || data.loop_id)
        return
      updateProcessData((draft) => {
        draft.expand = true
        draft.tracing!.push({ ...data, status: NodeRunningStatus.Running, expand: true })
      })
    },
    onNodeFinished: ({ data }: { data: NodeTracing }) => {
      if (data.iteration_id || data.loop_id)
        return
      updateProcessData((draft) => {
        const idx = findTraceIndex(draft.tracing!, data.node_id, data.execution_metadata?.parallel_id)
        if (idx > -1 && draft.tracing) {
          // Preserve previously attached extras; otherwise replace the entry
          // with the final node data. Expand only on error.
          draft.tracing[idx] = {
            ...(draft.tracing[idx].extras ? { extras: draft.tracing[idx].extras } : {}),
            ...data,
            expand: !!data.error,
          }
        }
      })
    },
    onWorkflowFinished: ({ data }: { data: WorkflowFinishedData }) => {
      // The timeout watchdog already reported completion; only warn here.
      if (isTimeoutRef.current) {
        notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
        return
      }
      if (data.status === WorkflowRunningStatus.Stopped) {
        handleWorkflowEnd(WorkflowRunningStatus.Stopped)
        return
      }
      if (data.error) {
        handleWorkflowEnd(WorkflowRunningStatus.Failed, data.error)
        return
      }
      // Success: record status and any file outputs found in the logs.
      updateProcessData((draft) => {
        draft.status = WorkflowRunningStatus.Succeeded
        // eslint-disable-next-line ts/no-explicit-any
        draft.files = getFilesInLogs(data.outputs || []) as any[]
      })
      if (data.outputs) {
        // NOTE(review): outputs is an object but setCompletionRes is typed for
        // string — mirrors the original implementation; confirm downstream use.
        setCompletionRes(data.outputs)
        const keys = Object.keys(data.outputs)
        // A single string output is surfaced directly as the result text.
        if (keys.length === 1 && typeof data.outputs[keys[0]] === 'string') {
          updateProcessData((draft) => {
            draft.resultText = data.outputs[keys[0]]
          })
        }
      }
      else {
        setCompletionRes('')
      }
      setRespondingFalse()
      resetRunState()
      setMessageId(tempMessageIdRef.current)
      onCompleted(getCompletionRes(), taskId, true)
      isEndRef.current = true
    },
    // Streamed text: append chunks; replace wholesale on text-replace events.
    onTextChunk: (params: { data: { text: string } }) => {
      updateProcessData((draft) => {
        draft.resultText += params.data.text
      })
    },
    onTextReplace: (params: { data: { text: string } }) => {
      updateProcessData((draft) => {
        draft.resultText = params.data.text
      })
    },
  }
}

View File

@@ -0,0 +1,245 @@
import type { IResultProps } from './index'
import { render, screen } from '@testing-library/react'
import { AppSourceType } from '@/service/share'
import Result from './index'
// Mock the custom hook to control state
const mockHandleSend = vi.fn()
const mockHandleStop = vi.fn()
const mockHandleFeedback = vi.fn()
// Mutable hook return value: individual tests tweak fields before render and
// the suite's beforeEach resets it to this baseline.
let hookReturnValue = {
  isResponding: false,
  completionRes: '',
  // NOTE(review): this conditional type resolves to `undefined` because
  // IResultProps['isWorkflow'] is boolean, not the literal `true` — tests
  // assign real process data via `as never`. Consider simplifying the type.
  workflowProcessData: undefined as IResultProps['isWorkflow'] extends true ? object : undefined,
  messageId: null as string | null,
  feedback: { rating: null as string | null },
  isStopping: false,
  currentTaskId: null as string | null,
  controlClearMoreLikeThis: 0,
  handleSend: mockHandleSend,
  handleStop: mockHandleStop,
  handleFeedback: mockHandleFeedback,
}
vi.mock('./hooks/use-text-generation', () => ({
  useTextGeneration: () => hookReturnValue,
}))
// t() returns the key verbatim so assertions can match translation keys.
vi.mock('i18next', () => ({
  t: (key: string) => key,
}))
// Mock complex external component to keep tests focused; props of interest are
// exposed as data-* attributes for assertions.
vi.mock('@/app/components/app/text-generate/item', () => ({
  default: ({ content, isWorkflow, taskId, isLoading }: {
    content: string
    isWorkflow: boolean
    taskId?: string
    isLoading: boolean
  }) => (
    <div
      data-testid="text-generation-res"
      data-content={content}
      data-workflow={String(isWorkflow)}
      data-task-id={taskId ?? ''}
      data-loading={String(isLoading)}
    />
  ),
}))
vi.mock('@/app/components/share/text-generation/no-data', () => ({
  default: () => <div data-testid="no-data" />,
}))
// Builds a complete prop set for <Result />; each test overrides only the
// fields it cares about.
const createProps = (overrides: Partial<IResultProps> = {}): IResultProps => {
  const defaults: IResultProps = {
    isWorkflow: false,
    isCallBatchAPI: false,
    isPC: true,
    isMobile: false,
    appSourceType: AppSourceType.webApp,
    appId: 'app-1',
    isError: false,
    isShowTextToSpeech: false,
    promptConfig: { prompt_template: '', prompt_variables: [] },
    moreLikeThisEnabled: false,
    inputs: {},
    onShowRes: vi.fn(),
    handleSaveMessage: vi.fn(),
    onCompleted: vi.fn(),
    visionConfig: { enabled: false } as IResultProps['visionConfig'],
    completionFiles: [],
    siteInfo: null,
    onRunStart: vi.fn(),
  }
  return { ...defaults, ...overrides }
}
describe('Result', () => {
  beforeEach(() => {
    vi.clearAllMocks()
    // Restore the mocked hook's baseline state before every test.
    hookReturnValue = {
      isResponding: false,
      completionRes: '',
      workflowProcessData: undefined,
      messageId: null,
      feedback: { rating: null },
      isStopping: false,
      currentTaskId: null,
      controlClearMoreLikeThis: 0,
      handleSend: mockHandleSend,
      handleStop: mockHandleStop,
      handleFeedback: mockHandleFeedback,
    }
  })
  // Empty state rendering
  describe('empty state', () => {
    it('should show NoData when not batch and no completion data', () => {
      render(<Result {...createProps()} />)
      expect(screen.getByTestId('no-data')).toBeInTheDocument()
      expect(screen.queryByTestId('text-generation-res')).not.toBeInTheDocument()
    })
    it('should show NoData when workflow mode has no process data', () => {
      render(<Result {...createProps({ isWorkflow: true })} />)
      expect(screen.getByTestId('no-data')).toBeInTheDocument()
    })
  })
  // Loading state rendering
  describe('loading state', () => {
    it('should show loading spinner when responding but no data yet', () => {
      hookReturnValue.isResponding = true
      hookReturnValue.completionRes = ''
      const { container } = render(<Result {...createProps()} />)
      // Loading area renders a spinner
      expect(container.querySelector('.items-center.justify-center')).toBeInTheDocument()
      expect(screen.queryByTestId('no-data')).not.toBeInTheDocument()
      expect(screen.queryByTestId('text-generation-res')).not.toBeInTheDocument()
    })
    it('should not show loading in batch mode even when responding', () => {
      hookReturnValue.isResponding = true
      hookReturnValue.completionRes = ''
      render(<Result {...createProps({ isCallBatchAPI: true })} />)
      // Batch mode skips loading state and goes to TextGenerationRes
      expect(screen.getByTestId('text-generation-res')).toBeInTheDocument()
    })
  })
  // Result rendering
  describe('result rendering', () => {
    it('should render TextGenerationRes when completion data exists', () => {
      hookReturnValue.completionRes = 'Generated output'
      render(<Result {...createProps()} />)
      const res = screen.getByTestId('text-generation-res')
      expect(res).toBeInTheDocument()
      expect(res.dataset.content).toBe('Generated output')
    })
    it('should render TextGenerationRes for workflow with process data', () => {
      hookReturnValue.workflowProcessData = { status: 'running', tracing: [] } as never
      render(<Result {...createProps({ isWorkflow: true })} />)
      const res = screen.getByTestId('text-generation-res')
      expect(res.dataset.workflow).toBe('true')
    })
    // Single-digit batch indices are zero-padded for display.
    it('should format batch taskId with leading zero for single digit', () => {
      hookReturnValue.completionRes = 'batch result'
      render(<Result {...createProps({ isCallBatchAPI: true, taskId: 3 })} />)
      expect(screen.getByTestId('text-generation-res').dataset.taskId).toBe('03')
    })
    it('should format batch taskId without leading zero for double digit', () => {
      hookReturnValue.completionRes = 'batch result'
      render(<Result {...createProps({ isCallBatchAPI: true, taskId: 12 })} />)
      expect(screen.getByTestId('text-generation-res').dataset.taskId).toBe('12')
    })
    it('should show loading in TextGenerationRes for batch mode while responding', () => {
      hookReturnValue.isResponding = true
      hookReturnValue.completionRes = ''
      render(<Result {...createProps({ isCallBatchAPI: true })} />)
      expect(screen.getByTestId('text-generation-res').dataset.loading).toBe('true')
    })
  })
  // Stop button
  describe('stop button', () => {
    it('should show stop button when responding with active task', () => {
      hookReturnValue.isResponding = true
      hookReturnValue.completionRes = 'data'
      hookReturnValue.currentTaskId = 'task-1'
      render(<Result {...createProps()} />)
      expect(screen.getByText('operation.stopResponding')).toBeInTheDocument()
    })
    it('should hide stop button when hideInlineStopButton is true', () => {
      hookReturnValue.isResponding = true
      hookReturnValue.completionRes = 'data'
      hookReturnValue.currentTaskId = 'task-1'
      render(<Result {...createProps({ hideInlineStopButton: true })} />)
      expect(screen.queryByText('operation.stopResponding')).not.toBeInTheDocument()
    })
    it('should hide stop button when not responding', () => {
      hookReturnValue.completionRes = 'data'
      hookReturnValue.currentTaskId = 'task-1'
      render(<Result {...createProps()} />)
      expect(screen.queryByText('operation.stopResponding')).not.toBeInTheDocument()
    })
    it('should show spinner icon when stopping', () => {
      hookReturnValue.isResponding = true
      hookReturnValue.completionRes = 'data'
      hookReturnValue.currentTaskId = 'task-1'
      hookReturnValue.isStopping = true
      const { container } = render(<Result {...createProps()} />)
      expect(container.querySelector('.animate-spin')).toBeInTheDocument()
    })
    it('should align stop button to end on PC, center on mobile', () => {
      hookReturnValue.isResponding = true
      hookReturnValue.completionRes = 'data'
      hookReturnValue.currentTaskId = 'task-1'
      const { container, rerender } = render(<Result {...createProps({ isPC: true })} />)
      expect(container.querySelector('.justify-end')).toBeInTheDocument()
      rerender(<Result {...createProps({ isPC: false })} />)
      expect(container.querySelector('.justify-center')).toBeInTheDocument()
    })
  })
  // Memo
  describe('memoization', () => {
    it('should be wrapped with React.memo', () => {
      expect((Result as unknown as { $$typeof: symbol }).$$typeof).toBe(Symbol.for('react.memo'))
    })
  })
})

View File

@@ -1,34 +1,19 @@
'use client'
import type { FC } from 'react'
import type { FeedbackType } from '@/app/components/base/chat/chat/type'
import type { WorkflowProcess } from '@/app/components/base/chat/types'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { InputValueTypes } from '../types'
import type { PromptConfig } from '@/models/debug'
import type { SiteInfo } from '@/models/share'
import type { AppSourceType } from '@/service/share'
import type { VisionFile, VisionSettings } from '@/types/app'
import { RiLoader2Line } from '@remixicon/react'
import { useBoolean } from 'ahooks'
import { t } from 'i18next'
import { produce } from 'immer'
import * as React from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import TextGenerationRes from '@/app/components/app/text-generate/item'
import Button from '@/app/components/base/button'
import {
getFilesInLogs,
getProcessedFiles,
} from '@/app/components/base/file-uploader/utils'
import { StopCircle } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices'
import Loading from '@/app/components/base/loading'
import Toast from '@/app/components/base/toast'
import NoData from '@/app/components/share/text-generation/no-data'
import { NodeRunningStatus, WorkflowRunningStatus } from '@/app/components/workflow/types'
import { TEXT_GENERATION_TIMEOUT_MS } from '@/config'
import { sendCompletionMessage, sendWorkflowMessage, stopChatMessageResponding, stopWorkflowMessage, updateFeedback } from '@/service/share'
import { TransferMethod } from '@/types/app'
import { sleep } from '@/utils'
import { formatBooleanInputs } from '@/utils/model-config'
import { useTextGeneration } from './hooks/use-text-generation'
export type IResultProps = {
isWorkflow: boolean
@@ -41,7 +26,7 @@ export type IResultProps = {
isShowTextToSpeech: boolean
promptConfig: PromptConfig | null
moreLikeThisEnabled: boolean
inputs: Record<string, any>
inputs: Record<string, InputValueTypes>
controlSend?: number
controlRetry?: number
controlStopResponding?: number
@@ -57,492 +42,61 @@ export type IResultProps = {
hideInlineStopButton?: boolean
}
const Result: FC<IResultProps> = ({
isWorkflow,
isCallBatchAPI,
isPC,
isMobile,
appSourceType,
appId,
isError,
isShowTextToSpeech,
promptConfig,
moreLikeThisEnabled,
inputs,
controlSend,
controlRetry,
controlStopResponding,
onShowRes,
handleSaveMessage,
taskId,
onCompleted,
visionConfig,
completionFiles,
siteInfo,
onRunStart,
onRunControlChange,
hideInlineStopButton = false,
}) => {
const [isResponding, { setTrue: setRespondingTrue, setFalse: setRespondingFalse }] = useBoolean(false)
const [completionRes, doSetCompletionRes] = useState<string>('')
const completionResRef = useRef<string>('')
const setCompletionRes = (res: string) => {
completionResRef.current = res
doSetCompletionRes(res)
}
const getCompletionRes = () => completionResRef.current
const [workflowProcessData, doSetWorkflowProcessData] = useState<WorkflowProcess>()
const workflowProcessDataRef = useRef<WorkflowProcess | undefined>(undefined)
const setWorkflowProcessData = (data: WorkflowProcess) => {
workflowProcessDataRef.current = data
doSetWorkflowProcessData(data)
}
const getWorkflowProcessData = () => workflowProcessDataRef.current
const [currentTaskId, setCurrentTaskId] = useState<string | null>(null)
const [isStopping, setIsStopping] = useState(false)
const abortControllerRef = useRef<AbortController | null>(null)
const resetRunState = useCallback(() => {
setCurrentTaskId(null)
setIsStopping(false)
abortControllerRef.current = null
onRunControlChange?.(null)
}, [onRunControlChange])
const Result: FC<IResultProps> = (props) => {
const {
isWorkflow,
isCallBatchAPI,
isPC,
isMobile,
appSourceType,
appId,
isError,
isShowTextToSpeech,
moreLikeThisEnabled,
handleSaveMessage,
taskId,
siteInfo,
hideInlineStopButton = false,
} = props
useEffect(() => {
const abortCurrentRequest = () => {
abortControllerRef.current?.abort()
}
const {
isResponding,
completionRes,
workflowProcessData,
messageId,
feedback,
isStopping,
currentTaskId,
controlClearMoreLikeThis,
handleSend,
handleStop,
handleFeedback,
} = useTextGeneration(props)
if (controlStopResponding) {
abortCurrentRequest()
setRespondingFalse()
resetRunState()
}
// Determine content state using a unified check
const hasData = isWorkflow ? !!workflowProcessData : !!completionRes
const isLoadingState = !isCallBatchAPI && isResponding && !hasData
const isEmptyState = !isCallBatchAPI && !hasData
return abortCurrentRequest
}, [controlStopResponding, resetRunState, setRespondingFalse])
const { notify } = Toast
const isNoData = !completionRes
const [messageId, setMessageId] = useState<string | null>(null)
const [feedback, setFeedback] = useState<FeedbackType>({
rating: null,
})
const handleFeedback = async (feedback: FeedbackType) => {
await updateFeedback({ url: `/messages/${messageId}/feedbacks`, body: { rating: feedback.rating, content: feedback.content } }, appSourceType, appId)
setFeedback(feedback)
if (isLoadingState) {
return (
<div className="flex h-full w-full items-center justify-center">
<Loading type="area" />
</div>
)
}
const logError = (message: string) => {
notify({ type: 'error', message })
}
if (isEmptyState)
return <NoData />
const handleStop = useCallback(async () => {
if (!currentTaskId || isStopping)
return
setIsStopping(true)
try {
if (isWorkflow)
await stopWorkflowMessage(appId!, currentTaskId, appSourceType, appId || '')
else
await stopChatMessageResponding(appId!, currentTaskId, appSourceType, appId || '')
abortControllerRef.current?.abort()
}
catch (error) {
const message = error instanceof Error ? error.message : String(error)
notify({ type: 'error', message })
}
finally {
setIsStopping(false)
}
}, [appId, currentTaskId, appSourceType, appId, isStopping, isWorkflow, notify])
useEffect(() => {
if (!onRunControlChange)
return
if (isResponding && currentTaskId) {
onRunControlChange({
onStop: handleStop,
isStopping,
})
}
else {
onRunControlChange(null)
}
}, [currentTaskId, handleStop, isResponding, isStopping, onRunControlChange])
const checkCanSend = () => {
// batch will check outer
if (isCallBatchAPI)
return true
const prompt_variables = promptConfig?.prompt_variables
if (!prompt_variables || prompt_variables?.length === 0) {
if (completionFiles.find(item => item.transfer_method === TransferMethod.local_file && !item.upload_file_id)) {
notify({ type: 'info', message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }) })
return false
}
return true
}
let hasEmptyInput = ''
const requiredVars = prompt_variables?.filter(({ key, name, required, type }) => {
if (type === 'boolean' || type === 'checkbox')
return false // boolean/checkbox input is not required
const res = (!key || !key.trim()) || (!name || !name.trim()) || (required || required === undefined || required === null)
return res
}) || [] // compatible with old version
requiredVars.forEach(({ key, name }) => {
if (hasEmptyInput)
return
if (!inputs[key])
hasEmptyInput = name
})
if (hasEmptyInput) {
logError(t('errorMessage.valueOfVarRequired', { ns: 'appDebug', key: hasEmptyInput }))
return false
}
if (completionFiles.find(item => item.transfer_method === TransferMethod.local_file && !item.upload_file_id)) {
notify({ type: 'info', message: t('errorMessage.waitForFileUpload', { ns: 'appDebug' }) })
return false
}
return !hasEmptyInput
}
const handleSend = async () => {
if (isResponding) {
notify({ type: 'info', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) })
return false
}
if (!checkCanSend())
return
// Process inputs: convert file entities to API format
const processedInputs = { ...formatBooleanInputs(promptConfig?.prompt_variables, inputs) }
promptConfig?.prompt_variables.forEach((variable) => {
const value = processedInputs[variable.key]
if (variable.type === 'file' && value && typeof value === 'object' && !Array.isArray(value)) {
// Convert single file entity to API format
processedInputs[variable.key] = getProcessedFiles([value as FileEntity])[0]
}
else if (variable.type === 'file-list' && Array.isArray(value) && value.length > 0) {
// Convert file entity array to API format
processedInputs[variable.key] = getProcessedFiles(value as FileEntity[])
}
})
const data: Record<string, any> = {
inputs: processedInputs,
}
if (visionConfig.enabled && completionFiles && completionFiles?.length > 0) {
data.files = completionFiles.map((item) => {
if (item.transfer_method === TransferMethod.local_file) {
return {
...item,
url: '',
}
}
return item
})
}
setMessageId(null)
setFeedback({
rating: null,
})
setCompletionRes('')
resetRunState()
let res: string[] = []
let tempMessageId = ''
if (!isPC) {
onShowRes()
onRunStart()
}
setRespondingTrue()
let isEnd = false
let isTimeout = false;
(async () => {
await sleep(TEXT_GENERATION_TIMEOUT_MS)
if (!isEnd) {
setRespondingFalse()
onCompleted(getCompletionRes(), taskId, false)
resetRunState()
isTimeout = true
}
})()
if (isWorkflow) {
sendWorkflowMessage(
data,
{
onWorkflowStarted: ({ workflow_run_id, task_id }) => {
tempMessageId = workflow_run_id
setCurrentTaskId(task_id || null)
setIsStopping(false)
setWorkflowProcessData({
status: WorkflowRunningStatus.Running,
tracing: [],
expand: false,
resultText: '',
})
},
onIterationStart: ({ data }) => {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.expand = true
draft.tracing!.push({
...data,
status: NodeRunningStatus.Running,
expand: true,
})
}))
},
onIterationNext: () => {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.expand = true
const iterations = draft.tracing.find(item => item.node_id === data.node_id
&& (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))!
iterations?.details!.push([])
}))
},
onIterationFinish: ({ data }) => {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.expand = true
const iterationsIndex = draft.tracing.findIndex(item => item.node_id === data.node_id
&& (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))!
draft.tracing[iterationsIndex] = {
...data,
expand: !!data.error,
}
}))
},
onLoopStart: ({ data }) => {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.expand = true
draft.tracing!.push({
...data,
status: NodeRunningStatus.Running,
expand: true,
})
}))
},
onLoopNext: () => {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.expand = true
const loops = draft.tracing.find(item => item.node_id === data.node_id
&& (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))!
loops?.details!.push([])
}))
},
onLoopFinish: ({ data }) => {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.expand = true
const loopsIndex = draft.tracing.findIndex(item => item.node_id === data.node_id
&& (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id))!
draft.tracing[loopsIndex] = {
...data,
expand: !!data.error,
}
}))
},
onNodeStarted: ({ data }) => {
if (data.iteration_id)
return
if (data.loop_id)
return
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.expand = true
draft.tracing!.push({
...data,
status: NodeRunningStatus.Running,
expand: true,
})
}))
},
onNodeFinished: ({ data }) => {
if (data.iteration_id)
return
if (data.loop_id)
return
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
const currentIndex = draft.tracing!.findIndex(trace => trace.node_id === data.node_id
&& (trace.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || trace.parallel_id === data.execution_metadata?.parallel_id))
if (currentIndex > -1 && draft.tracing) {
draft.tracing[currentIndex] = {
...(draft.tracing[currentIndex].extras
? { extras: draft.tracing[currentIndex].extras }
: {}),
...data,
expand: !!data.error,
}
}
}))
},
onWorkflowFinished: ({ data }) => {
if (isTimeout) {
notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
return
}
const workflowStatus = data.status as WorkflowRunningStatus | undefined
const markNodesStopped = (traces?: WorkflowProcess['tracing']) => {
if (!traces)
return
const markTrace = (trace: WorkflowProcess['tracing'][number]) => {
if ([NodeRunningStatus.Running, NodeRunningStatus.Waiting].includes(trace.status as NodeRunningStatus))
trace.status = NodeRunningStatus.Stopped
trace.details?.forEach(detailGroup => detailGroup.forEach(markTrace))
trace.retryDetail?.forEach(markTrace)
trace.parallelDetail?.children?.forEach(markTrace)
}
traces.forEach(markTrace)
}
if (workflowStatus === WorkflowRunningStatus.Stopped) {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.status = WorkflowRunningStatus.Stopped
markNodesStopped(draft.tracing)
}))
setRespondingFalse()
resetRunState()
onCompleted(getCompletionRes(), taskId, false)
isEnd = true
return
}
if (data.error) {
notify({ type: 'error', message: data.error })
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.status = WorkflowRunningStatus.Failed
markNodesStopped(draft.tracing)
}))
setRespondingFalse()
resetRunState()
onCompleted(getCompletionRes(), taskId, false)
isEnd = true
return
}
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.status = WorkflowRunningStatus.Succeeded
draft.files = getFilesInLogs(data.outputs || []) as any[]
}))
if (!data.outputs) {
setCompletionRes('')
}
else {
setCompletionRes(data.outputs)
const isStringOutput = Object.keys(data.outputs).length === 1 && typeof data.outputs[Object.keys(data.outputs)[0]] === 'string'
if (isStringOutput) {
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.resultText = data.outputs[Object.keys(data.outputs)[0]]
}))
}
}
setRespondingFalse()
resetRunState()
setMessageId(tempMessageId)
onCompleted(getCompletionRes(), taskId, true)
isEnd = true
},
onTextChunk: (params) => {
const { data: { text } } = params
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.resultText += text
}))
},
onTextReplace: (params) => {
const { data: { text } } = params
setWorkflowProcessData(produce(getWorkflowProcessData()!, (draft) => {
draft.resultText = text
}))
},
},
appSourceType,
appId,
).catch((error) => {
setRespondingFalse()
resetRunState()
const message = error instanceof Error ? error.message : String(error)
notify({ type: 'error', message })
})
}
else {
sendCompletionMessage(data, {
onData: (data: string, _isFirstMessage: boolean, { messageId, taskId }) => {
tempMessageId = messageId
if (taskId && typeof taskId === 'string' && taskId.trim() !== '')
setCurrentTaskId(prev => prev ?? taskId)
res.push(data)
setCompletionRes(res.join(''))
},
onCompleted: () => {
if (isTimeout) {
notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
return
}
setRespondingFalse()
resetRunState()
setMessageId(tempMessageId)
onCompleted(getCompletionRes(), taskId, true)
isEnd = true
},
onMessageReplace: (messageReplace) => {
res = [messageReplace.answer]
setCompletionRes(res.join(''))
},
onError() {
if (isTimeout) {
notify({ type: 'warning', message: t('warningMessage.timeoutExceeded', { ns: 'appDebug' }) })
return
}
setRespondingFalse()
resetRunState()
onCompleted(getCompletionRes(), taskId, false)
isEnd = true
},
getAbortController: (abortController) => {
abortControllerRef.current = abortController
},
}, appSourceType, appId)
}
}
const [controlClearMoreLikeThis, setControlClearMoreLikeThis] = useState(0)
useEffect(() => {
if (controlSend) {
handleSend()
setControlClearMoreLikeThis(Date.now())
}
}, [controlSend])
useEffect(() => {
if (controlRetry)
handleSend()
}, [controlRetry])
const renderTextGenerationRes = () => (
return (
<>
{!hideInlineStopButton && isResponding && currentTaskId && (
<div className={`mb-3 flex ${isPC ? 'justify-end' : 'justify-center'}`}>
<Button
variant="secondary"
disabled={isStopping}
onClick={handleStop}
>
{
isStopping
? <RiLoader2Line className="mr-[5px] h-3.5 w-3.5 animate-spin" />
: <StopCircle className="mr-[5px] h-3.5 w-3.5" />
}
<Button variant="secondary" disabled={isStopping} onClick={handleStop}>
{isStopping
? <RiLoader2Line className="mr-[5px] h-3.5 w-3.5 animate-spin" />
: <StopCircle className="mr-[5px] h-3.5 w-3.5" />}
<span className="text-xs font-normal">{t('operation.stopResponding', { ns: 'appDebug' })}</span>
</Button>
</div>
@@ -571,37 +125,6 @@ const Result: FC<IResultProps> = ({
/>
</>
)
return (
<>
{!isCallBatchAPI && !isWorkflow && (
(isResponding && !completionRes)
? (
<div className="flex h-full w-full items-center justify-center">
<Loading type="area" />
</div>
)
: (
<>
{(isNoData)
? <NoData />
: renderTextGenerationRes()}
</>
)
)}
{!isCallBatchAPI && isWorkflow && (
(isResponding && !workflowProcessData)
? (
<div className="flex h-full w-full items-center justify-center">
<Loading type="area" />
</div>
)
: !workflowProcessData
? <NoData />
: renderTextGenerationRes()
)}
{isCallBatchAPI && renderTextGenerationRes()}
</>
)
}
export default React.memo(Result)

View File

@@ -1,5 +1,7 @@
import type { ChangeEvent, FC, FormEvent } from 'react'
import type { InputValueTypes } from '../types'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { FileUploadConfigResponse } from '@/models/common'
import type { PromptConfig } from '@/models/debug'
import type { SiteInfo } from '@/models/share'
import type { VisionFile, VisionSettings } from '@/types/app'
@@ -8,7 +10,7 @@ import {
RiPlayLargeLine,
} from '@remixicon/react'
import * as React from 'react'
import { useCallback, useEffect, useState } from 'react'
import { useCallback, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import { FileUploaderInAttachmentWrapper } from '@/app/components/base/file-uploader'
@@ -50,7 +52,7 @@ const RunOnce: FC<IRunOnceProps> = ({
const { t } = useTranslation()
const media = useBreakpoints()
const isPC = media === MediaType.pc
const [isInitialized, setIsInitialized] = useState(false)
const isInitializedRef = React.useRef(false)
const onClear = () => {
const newInputs: Record<string, InputValueTypes> = {}
@@ -80,15 +82,16 @@ const RunOnce: FC<IRunOnceProps> = ({
runControl?.onStop?.()
}, [isRunning, runControl])
const handleInputsChange = useCallback((newInputs: Record<string, any>) => {
const handleInputsChange = useCallback((newInputs: Record<string, InputValueTypes>) => {
onInputsChange(newInputs)
inputsRef.current = newInputs
}, [onInputsChange, inputsRef])
useEffect(() => {
if (isInitialized)
if (isInitializedRef.current)
return
const newInputs: Record<string, any> = {}
isInitializedRef.current = true
const newInputs: Record<string, InputValueTypes> = {}
promptConfig.prompt_variables.forEach((item) => {
if (item.type === 'select')
newInputs[item.key] = item.default
@@ -106,7 +109,6 @@ const RunOnce: FC<IRunOnceProps> = ({
newInputs[item.key] = undefined
})
onInputsChange(newInputs)
setIsInitialized(true)
}, [promptConfig.prompt_variables, onInputsChange])
return (
@@ -114,7 +116,7 @@ const RunOnce: FC<IRunOnceProps> = ({
<section>
{/* input form */}
<form onSubmit={onSubmit}>
{(inputs === null || inputs === undefined || Object.keys(inputs).length === 0) || !isInitialized
{Object.keys(inputs).length === 0
? null
: promptConfig.prompt_variables.filter(item => item.hide !== true).map(item => (
<div className="mt-4 w-full" key={item.key}>
@@ -169,22 +171,21 @@ const RunOnce: FC<IRunOnceProps> = ({
)}
{item.type === 'file' && (
<FileUploaderInAttachmentWrapper
value={(inputs[item.key] && typeof inputs[item.key] === 'object') ? [inputs[item.key]] : []}
value={(inputs[item.key] && typeof inputs[item.key] === 'object') ? [inputs[item.key] as FileEntity] : []}
onChange={(files) => { handleInputsChange({ ...inputsRef.current, [item.key]: files[0] }) }}
fileConfig={{
...item.config,
fileUploadConfig: (visionConfig as any).fileUploadConfig,
fileUploadConfig: (visionConfig as VisionSettings & { fileUploadConfig?: FileUploadConfigResponse }).fileUploadConfig,
}}
/>
)}
{item.type === 'file-list' && (
<FileUploaderInAttachmentWrapper
value={Array.isArray(inputs[item.key]) ? inputs[item.key] : []}
value={Array.isArray(inputs[item.key]) ? inputs[item.key] as FileEntity[] : []}
onChange={(files) => { handleInputsChange({ ...inputsRef.current, [item.key]: files }) }}
fileConfig={{
...item.config,
// eslint-disable-next-line ts/no-explicit-any
fileUploadConfig: (visionConfig as any).fileUploadConfig,
fileUploadConfig: (visionConfig as VisionSettings & { fileUploadConfig?: FileUploadConfigResponse }).fileUploadConfig,
}}
/>
)}

View File

@@ -1,5 +1,9 @@
type TaskParam = {
inputs: Record<string, string | boolean | undefined>
import type { FileEntity } from '@/app/components/base/file-uploader/types'
export type InputValueTypes = string | boolean | number | string[] | FileEntity | FileEntity[] | Record<string, unknown> | undefined
export type TaskParam = {
inputs: Record<string, string | undefined>
}
export type Task = {
@@ -14,6 +18,3 @@ export enum TaskStatus {
completed = 'completed',
failed = 'failed',
}
// eslint-disable-next-line ts/no-explicit-any
export type InputValueTypes = string | boolean | number | string[] | object | undefined | any

View File

@@ -3141,14 +3141,6 @@
"count": 1
}
},
"app/components/share/text-generation/index.tsx": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 1
},
"ts/no-explicit-any": {
"count": 8
}
},
"app/components/share/text-generation/menu-dropdown.tsx": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 1
@@ -3159,19 +3151,6 @@
"count": 1
}
},
"app/components/share/text-generation/result/header.tsx": {
"tailwindcss/no-unnecessary-whitespace": {
"count": 3
}
},
"app/components/share/text-generation/result/index.tsx": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 3
},
"ts/no-explicit-any": {
"count": 3
}
},
"app/components/share/text-generation/run-batch/csv-reader/index.spec.tsx": {
"ts/no-explicit-any": {
"count": 2
@@ -3182,14 +3161,6 @@
"count": 2
}
},
"app/components/share/text-generation/run-once/index.tsx": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 1
},
"ts/no-explicit-any": {
"count": 3
}
},
"app/components/share/utils.ts": {
"ts/no-explicit-any": {
"count": 2

View File

@@ -47,7 +47,7 @@
"i18n:check": "tsx ./scripts/check-i18n.js",
"test": "vitest run",
"test:coverage": "vitest run --coverage",
"test:ci": "vitest run --coverage --reporter vitest-tiny-reporter --silent=passed-only",
"test:ci": "vitest run --coverage --silent=passed-only",
"test:watch": "vitest --watch",
"analyze-component": "node ./scripts/analyze-component.js",
"refactor-component": "node ./scripts/refactor-component.js",
@@ -236,8 +236,7 @@
"vite": "7.3.1",
"vite-tsconfig-paths": "6.0.4",
"vitest": "4.0.17",
"vitest-canvas-mock": "1.1.3",
"vitest-tiny-reporter": "1.3.1"
"vitest-canvas-mock": "1.1.3"
},
"pnpm": {
"overrides": {

15
web/pnpm-lock.yaml generated
View File

@@ -585,9 +585,6 @@ importers:
vitest-canvas-mock:
specifier: 1.1.3
version: 1.1.3(vitest@4.0.17)
vitest-tiny-reporter:
specifier: 1.3.1
version: 1.3.1(@vitest/runner@4.0.17)(vitest@4.0.17)
packages:
@@ -7294,12 +7291,6 @@ packages:
peerDependencies:
vitest: ^3.0.0 || ^4.0.0
vitest-tiny-reporter@1.3.1:
resolution: {integrity: sha512-9WfLruQBbxm4EqMIS0jDZmQjvMgsWgHUso9mHQWgjA6hM3tEVhjdG8wYo7ePFh1XbwEFzEo3XUQqkGoKZ/Td2Q==}
peerDependencies:
'@vitest/runner': ^2.0.0 || ^3.0.2 || ^4.0.0
vitest: ^2.0.0 || ^3.0.2 || ^4.0.0
vitest@4.0.17:
resolution: {integrity: sha512-FQMeF0DJdWY0iOnbv466n/0BudNdKj1l5jYgl5JVTwjSsZSlqyXFt/9+1sEyhR6CLowbZpV7O1sCHrzBhucKKg==}
engines: {node: ^20.0.0 || ^22.0.0 || >=24.0.0}
@@ -15351,12 +15342,6 @@ snapshots:
moo-color: 1.0.3
vitest: 4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)
vitest-tiny-reporter@1.3.1(@vitest/runner@4.0.17)(vitest@4.0.17):
dependencies:
'@vitest/runner': 4.0.17
tinyrainbow: 3.0.3
vitest: 4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)
vitest@4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2):
dependencies:
'@vitest/expect': 4.0.17

View File

@@ -1,6 +1,8 @@
import { defineConfig, mergeConfig } from 'vitest/config'
import viteConfig from './vite.config'
const isCI = !!process.env.CI
export default mergeConfig(viteConfig, defineConfig({
test: {
environment: 'jsdom',
@@ -8,7 +10,7 @@ export default mergeConfig(viteConfig, defineConfig({
setupFiles: ['./vitest.setup.ts'],
coverage: {
provider: 'v8',
reporter: ['json', 'json-summary'],
reporter: isCI ? ['json', 'json-summary'] : ['text', 'json', 'json-summary'],
},
},
}))