Compare commits

..

2 Commits

Author SHA1 Message Date
autofix-ci[bot]
6e6922c4ac [autofix.ci] apply automated fixes 2026-01-13 15:17:42 +00:00
hj24
c385283356 refactor: enhance clean message task 2026-01-13 23:14:31 +08:00
30 changed files with 2800 additions and 807 deletions

View File

@@ -709,6 +709,7 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
# Maximum number of concurrent annotation import tasks per tenant
ANNOTATION_IMPORT_MAX_CONCURRENT=5
# Sandbox expired records clean configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000

View File

@@ -3,6 +3,7 @@ import datetime
import json
import logging
import secrets
import time
from typing import Any
import click
@@ -46,6 +47,8 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
from services.retention.conversation.messages_clean_policy import create_message_clean_policy
from services.retention.conversation.messages_clean_service import MessagesCleanService
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
@@ -2168,3 +2171,79 @@ def migrate_oss(
except Exception as e:
db.session.rollback()
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
@click.command("clean-expired-messages", help="Clean expired messages.")
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
help="Lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
help="Upper bound (exclusive) for created_at.",
)
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
@click.option(
"--graceful-period",
default=21,
show_default=True,
help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.",
)
@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleteing")
def clean_expired_messages(
batch_size: int,
graceful_period: int,
start_from: datetime.datetime,
end_before: datetime.datetime,
dry_run: bool,
):
"""
Clean expired messages and related data for tenants based on clean policy.
"""
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
try:
# Create policy based on billing configuration
# NOTE: graceful_period will be ignored when billing is disabled.
policy = create_message_clean_policy(graceful_period_days=graceful_period)
# Create and run the cleanup service
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
stats = service.run()
end_at = time.perf_counter()
click.echo(
click.style(
f"clean_messages: completed successfully\n"
f" - Latency: {end_at - start_at:.2f}s\n"
f" - Batches processed: {stats['batches']}\n"
f" - Total messages scanned: {stats['total_messages']}\n"
f" - Messages filtered: {stats['filtered_messages']}\n"
f" - Messages deleted: {stats['total_deleted']}",
fg="green",
)
)
except Exception as e:
end_at = time.perf_counter()
logger.exception("clean_messages failed")
click.echo(
click.style(
f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
fg="red",
)
)
raise
click.echo(click.style("messages cleanup completed.", fg="green"))

View File

@@ -188,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
),
)
assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls:
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
@@ -200,6 +200,8 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
for tool_call in tool_calls
]
else:
assistant_message.content = response
self._current_thoughts.append(assistant_message)

View File

@@ -251,7 +251,10 @@ class AssistantPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
return super().is_empty() and not self.tool_calls
if not super().is_empty() and not self.tool_calls:
return False
return True
class SystemPromptMessage(PromptMessage):

View File

@@ -1,7 +1,6 @@
import logging
from collections.abc import Sequence
from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@@ -152,7 +151,6 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -458,7 +456,6 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -478,7 +475,6 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
)
self.trace_client.add_span(workflow_span)

View File

@@ -166,7 +166,7 @@ class SpanBuilder:
attributes=span_data.attributes,
events=span_data.events,
links=span_data.links,
kind=span_data.span_kind,
kind=trace_api.SpanKind.INTERNAL,
status=span_data.status,
start_time=span_data.start_time,
end_time=span_data.end_time,

View File

@@ -4,7 +4,7 @@ from typing import Any
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.trace import Status, StatusCode
from pydantic import BaseModel, Field
@@ -34,4 +34,3 @@ class SpanData(BaseModel):
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")

View File

@@ -19,7 +19,6 @@ from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@@ -137,11 +136,13 @@ class WorkflowEntry:
:param user_inputs: user inputs
:return:
"""
node_config = dict(workflow.get_node_config_by_id(node_id))
node_config = workflow.get_node_config_by_id(node_id)
node_config_data = node_config.get("data", {})
# Get node type
# Get node class
node_type = NodeType(node_config_data.get("type"))
node_version = node_config_data.get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init graph init params and runtime state
graph_init_params = GraphInitParams(
@@ -157,12 +158,12 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node_factory = DifyNodeFactory(
node = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
node = node_factory.create_node(node_config)
node_cls = type(node)
try:
# variable selector to variable mapping

View File

@@ -4,6 +4,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
clean_expired_messages,
clean_workflow_runs,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
@@ -58,6 +59,7 @@ def init_app(app: DifyApp):
transform_datasource_credentials,
install_rag_pipeline_plugins,
clean_workflow_runs,
clean_expired_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@@ -0,0 +1,33 @@
"""feat: add created_at id index to messages
Revision ID: 3334862ee907
Revises: 905527cc8fd3
Create Date: 2026-01-12 17:29:44.846544
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '3334862ee907'
down_revision = '905527cc8fd3'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.drop_index('message_created_at_id_idx')
# ### end Alembic commands ###

View File

@@ -968,6 +968,7 @@ class Message(Base):
Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
Index("message_created_at_idx", "created_at"),
Index("message_app_mode_idx", "app_mode"),
Index("message_created_at_id_idx", "created_at", "id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))

View File

@@ -1,90 +1,62 @@
import datetime
import logging
import time
import click
from sqlalchemy.exc import SQLAlchemyError
import app
from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import (
App,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.feature_service import FeatureService
from services.retention.conversation.messages_clean_policy import create_message_clean_policy
from services.retention.conversation.messages_clean_service import MessagesCleanService
logger = logging.getLogger(__name__)
@app.celery.task(queue="dataset")
@app.celery.task(queue="retention")
def clean_messages():
click.echo(click.style("Start clean messages.", fg="green"))
start_at = time.perf_counter()
plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta(
days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING
)
while True:
try:
# Main query with join and filter
messages = (
db.session.query(Message)
.where(Message.created_at < plan_sandbox_clean_message_day)
.order_by(Message.created_at.desc())
.limit(100)
.all()
)
"""
Clean expired messages based on clean policy.
except SQLAlchemyError:
raise
if not messages:
break
for message in messages:
app = db.session.query(App).filter_by(id=message.app_id).first()
if not app:
logger.warning(
"Expected App record to exist, but none was found, app_id=%s, message_id=%s",
message.app_id,
message.id,
)
continue
features_cache_key = f"features:{app.tenant_id}"
plan_cache = redis_client.get(features_cache_key)
if plan_cache is None:
features = FeatureService.get_features(app.tenant_id)
redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
plan = features.billing.subscription.plan
else:
plan = plan_cache.decode()
if plan == CloudPlan.SANDBOX:
# clean related message
db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(Message).where(Message.id == message.id).delete()
db.session.commit()
end_at = time.perf_counter()
click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green"))
This task uses MessagesCleanService to efficiently clean messages in batches.
The behavior depends on BILLING_ENABLED configuration:
- BILLING_ENABLED=True: only delete messages from sandbox tenants (with whitelist/grace period)
- BILLING_ENABLED=False: delete all messages within the time range
"""
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
try:
# Create policy based on billing configuration
policy = create_message_clean_policy(
graceful_period_days=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD,
)
# Create and run the cleanup service
service = MessagesCleanService.from_days(
policy=policy,
days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
)
stats = service.run()
end_at = time.perf_counter()
click.echo(
click.style(
f"clean_messages: completed successfully\n"
f" - Latency: {end_at - start_at:.2f}s\n"
f" - Batches processed: {stats['batches']}\n"
f" - Total messages scanned: {stats['total_messages']}\n"
f" - Messages filtered: {stats['filtered_messages']}\n"
f" - Messages deleted: {stats['total_deleted']}",
fg="green",
)
)
except Exception as e:
end_at = time.perf_counter()
logger.exception("clean_messages failed")
click.echo(
click.style(
f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
fg="red",
)
)
raise

View File

@@ -0,0 +1,216 @@
import datetime
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from configs import dify_config
from enums.cloud_plan import CloudPlan
from services.billing_service import BillingService, SubscriptionPlan
logger = logging.getLogger(__name__)
@dataclass
class SimpleMessage:
id: str
app_id: str
created_at: datetime.datetime
class MessagesCleanPolicy(ABC):
"""
Abstract base class for message cleanup policies.
A policy determines which messages from a batch should be deleted.
"""
@abstractmethod
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
"""
Filter messages and return IDs of messages that should be deleted.
Args:
messages: Batch of messages to evaluate
app_to_tenant: Mapping from app_id to tenant_id
Returns:
List of message IDs that should be deleted
"""
...
class BillingDisabledPolicy(MessagesCleanPolicy):
"""
Policy for community or enterpriseedition (billing disabled).
No special filter logic, just return all message ids.
"""
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
return [msg.id for msg in messages]
class BillingSandboxPolicy(MessagesCleanPolicy):
"""
Policy for sandbox plan tenants in cloud edition (billing enabled).
Filters messages based on sandbox plan expiration rules:
- Skip tenants in the whitelist
- Only delete messages from sandbox plan tenants
- Respect grace period after subscription expiration
- Safe default: if tenant mapping or plan is missing, do NOT delete
"""
def __init__(
self,
plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]],
graceful_period_days: int = 21,
tenant_whitelist: Sequence[str] | None = None,
current_timestamp: int | None = None,
) -> None:
self._graceful_period_days = graceful_period_days
self._tenant_whitelist: Sequence[str] = tenant_whitelist or []
self._plan_provider = plan_provider
self._current_timestamp = current_timestamp
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
"""
Filter messages based on sandbox plan expiration rules.
Args:
messages: Batch of messages to evaluate
app_to_tenant: Mapping from app_id to tenant_id
Returns:
List of message IDs that should be deleted
"""
if not messages or not app_to_tenant:
return []
# Get unique tenant_ids and fetch subscription plans
tenant_ids = list(set(app_to_tenant.values()))
tenant_plans = self._plan_provider(tenant_ids)
if not tenant_plans:
return []
# Apply sandbox deletion rules
return self._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
)
def _filter_expired_sandbox_messages(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
tenant_plans: dict[str, SubscriptionPlan],
) -> list[str]:
"""
Filter messages that should be deleted based on sandbox plan expiration.
A message should be deleted if:
1. It belongs to a sandbox tenant AND
2. Either:
a) The tenant has no previous subscription (expiration_date == -1), OR
b) The subscription expired more than graceful_period_days ago
Args:
messages: List of message objects with id and app_id attributes
app_to_tenant: Mapping from app_id to tenant_id
tenant_plans: Mapping from tenant_id to subscription plan info
Returns:
List of message IDs that should be deleted
"""
current_timestamp = self._current_timestamp
if current_timestamp is None:
current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
sandbox_message_ids: list[str] = []
graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60
for msg in messages:
# Get tenant_id for this message's app
tenant_id = app_to_tenant.get(msg.app_id)
if not tenant_id:
continue
# Skip tenant messages in whitelist
if tenant_id in self._tenant_whitelist:
continue
# Get subscription plan for this tenant
tenant_plan = tenant_plans.get(tenant_id)
if not tenant_plan:
continue
plan = str(tenant_plan["plan"])
expiration_date = int(tenant_plan["expiration_date"])
# Only process sandbox plans
if plan != CloudPlan.SANDBOX:
continue
# Case 1: No previous subscription (-1 means never had a paid subscription)
if expiration_date == -1:
sandbox_message_ids.append(msg.id)
continue
# Case 2: Subscription expired beyond grace period
if current_timestamp - expiration_date > graceful_period_seconds:
sandbox_message_ids.append(msg.id)
return sandbox_message_ids
def create_message_clean_policy(
graceful_period_days: int = 21,
current_timestamp: int | None = None,
) -> MessagesCleanPolicy:
"""
Factory function to create the appropriate message clean policy.
Determines which policy to use based on BILLING_ENABLED configuration:
- If BILLING_ENABLED is True: returns BillingSandboxPolicy
- If BILLING_ENABLED is False: returns BillingDisabledPolicy
Args:
graceful_period_days: Grace period in days after subscription expiration (default: 21)
current_timestamp: Current Unix timestamp for testing (default: None, uses current time)
"""
if not dify_config.BILLING_ENABLED:
logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy")
return BillingDisabledPolicy()
# Billing enabled - fetch whitelist from BillingService
tenant_whitelist = BillingService.get_expired_subscription_cleanup_whitelist()
plan_provider = BillingService.get_plan_bulk_with_cache
logger.info(
"create_message_clean_policy: billing enabled, using BillingSandboxPolicy "
"(graceful_period_days=%s, whitelist=%s)",
graceful_period_days,
tenant_whitelist,
)
return BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=graceful_period_days,
tenant_whitelist=tenant_whitelist,
current_timestamp=current_timestamp,
)

View File

@@ -0,0 +1,334 @@
import datetime
import logging
import random
from collections.abc import Sequence
from typing import cast
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.model import (
App,
AppAnnotationHitHistory,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.retention.conversation.messages_clean_policy import (
MessagesCleanPolicy,
SimpleMessage,
)
logger = logging.getLogger(__name__)
class MessagesCleanService:
"""
Service for cleaning expired messages based on retention policies.
Compatible with non cloud edition (billing disabled): all messages in the time range will be deleted.
If billing is enabled: only sandbox plan tenant messages are deleted (with whitelist and grace period support).
"""
def __init__(
self,
policy: MessagesCleanPolicy,
end_before: datetime.datetime,
start_from: datetime.datetime | None = None,
batch_size: int = 1000,
dry_run: bool = False,
) -> None:
"""
Initialize the service with cleanup parameters.
Args:
policy: The policy that determines which messages to delete
end_before: End time (exclusive) of the range
start_from: Optional start time (inclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
"""
self._policy = policy
self._end_before = end_before
self._start_from = start_from
self._batch_size = batch_size
self._dry_run = dry_run
@classmethod
def from_time_range(
cls,
policy: MessagesCleanPolicy,
start_from: datetime.datetime,
end_before: datetime.datetime,
batch_size: int = 1000,
dry_run: bool = False,
) -> "MessagesCleanService":
"""
Create a service instance for cleaning messages within a specific time range.
Time range is [start_from, end_before).
Args:
policy: The policy that determines which messages to delete
start_from: Start time (inclusive) of the range
end_before: End time (exclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
Returns:
MessagesCleanService instance
Raises:
ValueError: If start_from >= end_before or invalid parameters
"""
if start_from >= end_before:
raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})")
if batch_size <= 0:
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
logger.info(
"clean_messages: start_from=%s, end_before=%s, batch_size=%s, policy=%s",
start_from,
end_before,
batch_size,
policy,
)
return cls(
policy=policy,
end_before=end_before,
start_from=start_from,
batch_size=batch_size,
dry_run=dry_run,
)
@classmethod
def from_days(
cls,
policy: MessagesCleanPolicy,
days: int = 30,
batch_size: int = 1000,
dry_run: bool = False,
) -> "MessagesCleanService":
"""
Create a service instance for cleaning messages older than specified days.
Args:
policy: The policy that determines which messages to delete
days: Number of days to look back from now
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
Returns:
MessagesCleanService instance
Raises:
ValueError: If invalid parameters
"""
if days < 0:
raise ValueError(f"days ({days}) must be greater than or equal to 0")
if batch_size <= 0:
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
end_before = datetime.datetime.now() - datetime.timedelta(days=days)
logger.info(
"clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",
days,
end_before,
batch_size,
policy,
)
return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
def run(self) -> dict[str, int]:
"""
Execute the message cleanup operation.
Returns:
Dict with statistics: batches, filtered_messages, total_deleted
"""
return self._clean_messages_by_time_range()
def _clean_messages_by_time_range(self) -> dict[str, int]:
"""
Clean messages within a time range using cursor-based pagination.
Time range is [start_from, end_before)
Steps:
1. Iterate messages using cursor pagination (by created_at, id)
2. Query app_id -> tenant_id mapping
3. Delegate to policy to determine which messages to delete
4. Batch delete messages and their relations
Returns:
Dict with statistics: batches, filtered_messages, total_deleted
"""
stats = {
"batches": 0,
"total_messages": 0,
"filtered_messages": 0,
"total_deleted": 0,
}
# Cursor-based pagination using (created_at, id) to avoid infinite loops
# and ensure proper ordering with time-based filtering
_cursor: tuple[datetime.datetime, str] | None = None
logger.info(
"clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s",
self._dry_run,
self._start_from,
self._end_before,
)
while True:
stats["batches"] += 1
# Step 1: Fetch a batch of messages using cursor
with Session(db.engine, expire_on_commit=False) as session:
msg_stmt = (
select(Message.id, Message.app_id, Message.created_at)
.where(Message.created_at < self._end_before)
.order_by(Message.created_at, Message.id)
.limit(self._batch_size)
)
if self._start_from:
msg_stmt = msg_stmt.where(Message.created_at >= self._start_from)
# Apply cursor condition: (created_at, id) > (last_created_at, last_message_id)
# This translates to:
# created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id)
if _cursor:
# Continuing from previous batch
msg_stmt = msg_stmt.where(
(Message.created_at > _cursor[0])
| ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1]))
)
raw_messages = list(session.execute(msg_stmt).all())
messages = [
SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at)
for msg_id, app_id, msg_created_at in raw_messages
]
# Track total messages fetched across all batches
stats["total_messages"] += len(messages)
if not messages:
logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
break
# Update cursor to the last message's (created_at, id)
_cursor = (messages[-1].created_at, messages[-1].id)
# Step 2: Extract app_ids and query tenant_ids
app_ids = list({msg.app_id for msg in messages})
if not app_ids:
logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"])
continue
app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids))
apps = list(session.execute(app_stmt).all())
if not apps:
logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
continue
# Build app_id -> tenant_id mapping
app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps}
# Step 3: Delegate to policy to determine which messages to delete
message_ids_to_delete = self._policy.filter_message_ids(messages, app_to_tenant)
if not message_ids_to_delete:
logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
continue
stats["filtered_messages"] += len(message_ids_to_delete)
# Step 4: Batch delete messages and their relations
if not self._dry_run:
with Session(db.engine, expire_on_commit=False) as session:
# Delete related records first
self._batch_delete_message_relations(session, message_ids_to_delete)
# Delete messages
delete_stmt = delete(Message).where(Message.id.in_(message_ids_to_delete))
delete_result = cast(CursorResult, session.execute(delete_stmt))
messages_deleted = delete_result.rowcount
session.commit()
stats["total_deleted"] += messages_deleted
logger.info(
"clean_messages (batch %s): processed %s messages, deleted %s messages",
stats["batches"],
len(messages),
messages_deleted,
)
else:
# Log random sample of message IDs that would be deleted (up to 10)
sample_size = min(10, len(message_ids_to_delete))
sampled_ids = random.sample(list(message_ids_to_delete), sample_size)
logger.info(
"clean_messages (batch %s, dry_run): would delete %s messages, sampling %s ids:",
stats["batches"],
len(message_ids_to_delete),
sample_size,
)
for msg_id in sampled_ids:
logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
logger.info(
"clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
stats["batches"],
stats["total_messages"],
stats["filtered_messages"],
stats["total_deleted"],
)
return stats
@staticmethod
def _batch_delete_message_relations(session: Session, message_ids: Sequence[str]) -> None:
"""
Batch delete all related records for given message IDs.
Args:
session: Database session
message_ids: List of message IDs to delete relations for
"""
if not message_ids:
return
# Delete all related records in batch
session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)))
session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids)))
session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids)))
session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids)))
session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids)))
session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids)))
session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids)))
session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids)))

View File

@@ -2,17 +2,13 @@ from types import SimpleNamespace
import pytest
from configs import dify_config
from core.file.enums import FileType
from core.file.models import File, FileTransferMethod
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.variables import StringVariable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
)
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
@@ -100,58 +96,6 @@ class TestWorkflowEntry:
assert output_var is not None
assert output_var.value == "system_user"
def test_single_step_run_injects_code_limits(self):
"""Ensure single-step CodeNode execution configures limits."""
# Arrange
node_id = "code_node"
node_data = {
"type": "code",
"title": "Code",
"desc": None,
"variables": [],
"code_language": CodeLanguage.PYTHON3,
"code": "def main():\n return {}",
"outputs": {},
}
node_config = {"id": node_id, "data": node_data}
class StubWorkflow:
def __init__(self):
self.tenant_id = "tenant"
self.app_id = "app"
self.id = "workflow"
self.graph_dict = {"nodes": [node_config], "edges": []}
def get_node_config_by_id(self, target_id: str):
assert target_id == node_id
return node_config
workflow = StubWorkflow()
variable_pool = VariablePool(system_variables=SystemVariable.empty(), user_inputs={})
expected_limits = CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
min_number=dify_config.CODE_MIN_NUMBER,
max_precision=dify_config.CODE_MAX_PRECISION,
max_depth=dify_config.CODE_MAX_DEPTH,
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
# Act
node, _ = WorkflowEntry.single_step_run(
workflow=workflow,
node_id=node_id,
user_id="user",
user_inputs={},
variable_pool=variable_pool,
)
# Assert
assert isinstance(node, CodeNode)
assert node._limits == expected_limits
def test_mapping_user_inputs_to_variable_pool_with_env_variables(self):
"""Test mapping environment variables from user inputs to variable pool."""
# Initialize variable pool with environment variables

View File

@@ -0,0 +1,627 @@
import datetime
from unittest.mock import MagicMock, patch
import pytest
from enums.cloud_plan import CloudPlan
from services.retention.conversation.messages_clean_policy import (
BillingDisabledPolicy,
BillingSandboxPolicy,
SimpleMessage,
create_message_clean_policy,
)
from services.retention.conversation.messages_clean_service import MessagesCleanService
def make_simple_message(msg_id: str, app_id: str) -> SimpleMessage:
"""Helper to create a SimpleMessage with a fixed created_at timestamp."""
return SimpleMessage(id=msg_id, app_id=app_id, created_at=datetime.datetime(2024, 1, 1))
def make_plan_provider(tenant_plans: dict) -> MagicMock:
"""Helper to create a mock plan_provider that returns the given tenant_plans."""
provider = MagicMock()
provider.return_value = tenant_plans
return provider
class TestBillingSandboxPolicyFilterMessageIds:
"""Unit tests for BillingSandboxPolicy.filter_message_ids method."""
# Fixed timestamp for deterministic tests
CURRENT_TIMESTAMP = 1000000
GRACEFUL_PERIOD_DAYS = 8
GRACEFUL_PERIOD_SECONDS = GRACEFUL_PERIOD_DAYS * 24 * 60 * 60
def test_missing_tenant_mapping_excluded(self):
"""Test that messages with missing app-to-tenant mapping are excluded."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant = {} # No mapping
tenant_plans = {"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert
assert list(result) == []
def test_missing_tenant_plan_excluded(self):
"""Test that messages with missing tenant plan are excluded (safe default)."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2"}
tenant_plans = {} # No plans
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert
assert list(result) == []
def test_non_sandbox_plan_excluded(self):
"""Test that messages from non-sandbox plans (PROFESSIONAL/TEAM) are excluded."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
make_simple_message("msg3", "app3"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.TEAM, "expiration_date": -1},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, # Only this one
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - only msg3 (sandbox tenant) should be included
assert set(result) == {"msg3"}
def test_whitelist_skip(self):
"""Test that whitelisted tenants are excluded even if sandbox + expired."""
# Arrange
messages = [
make_simple_message("msg1", "app1"), # Whitelisted - excluded
make_simple_message("msg2", "app2"), # Not whitelisted - included
make_simple_message("msg3", "app3"), # Whitelisted - excluded
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
}
plan_provider = make_plan_provider(tenant_plans)
tenant_whitelist = ["tenant1", "tenant3"]
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
tenant_whitelist=tenant_whitelist,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - only msg2 should be included
assert set(result) == {"msg2"}
def test_no_previous_subscription_included(self):
"""Test that messages with expiration_date=-1 (no previous subscription) are included."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - all messages should be included
assert set(result) == {"msg1", "msg2"}
def test_within_grace_period_excluded(self):
"""Test that messages within grace period are excluded."""
# Arrange
now = self.CURRENT_TIMESTAMP
expired_1_day_ago = now - (1 * 24 * 60 * 60)
expired_5_days_ago = now - (5 * 24 * 60 * 60)
expired_7_days_ago = now - (7 * 24 * 60 * 60)
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
make_simple_message("msg3", "app3"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_1_day_ago},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_5_days_ago},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_7_days_ago},
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS, # 8 days
current_timestamp=now,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - all within 8-day grace period, none should be included
assert list(result) == []
def test_exactly_at_boundary_excluded(self):
"""Test that messages exactly at grace period boundary are excluded (code uses >)."""
# Arrange
now = self.CURRENT_TIMESTAMP
expired_exactly_8_days_ago = now - self.GRACEFUL_PERIOD_SECONDS # Exactly at boundary
messages = [make_simple_message("msg1", "app1")]
app_to_tenant = {"app1": "tenant1"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_exactly_8_days_ago},
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=now,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - exactly at boundary (==) should be excluded (code uses >)
assert list(result) == []
def test_beyond_grace_period_included(self):
"""Test that messages beyond grace period are included."""
# Arrange
now = self.CURRENT_TIMESTAMP
expired_9_days_ago = now - (9 * 24 * 60 * 60) # Just beyond 8-day grace
expired_30_days_ago = now - (30 * 24 * 60 * 60) # Well beyond
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_9_days_ago},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_30_days_ago},
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=now,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - both beyond grace period, should be included
assert set(result) == {"msg1", "msg2"}
def test_empty_messages_returns_empty(self):
"""Test that empty messages returns empty list."""
# Arrange
messages: list[SimpleMessage] = []
app_to_tenant = {"app1": "tenant1"}
plan_provider = make_plan_provider({"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}})
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert
assert list(result) == []
def test_plan_provider_called_with_correct_tenant_ids(self):
"""Test that plan_provider is called with correct tenant_ids."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
make_simple_message("msg3", "app3"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant1"} # tenant1 appears twice
plan_provider = make_plan_provider({})
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
policy.filter_message_ids(messages, app_to_tenant)
# Assert - plan_provider should be called once with unique tenant_ids
plan_provider.assert_called_once()
called_tenant_ids = set(plan_provider.call_args[0][0])
assert called_tenant_ids == {"tenant1", "tenant2"}
def test_complex_mixed_scenario(self):
"""Test complex scenario with mixed plans, expirations, whitelist, and missing mappings."""
# Arrange
now = self.CURRENT_TIMESTAMP
sandbox_expired_old = now - (15 * 24 * 60 * 60) # Beyond grace
sandbox_expired_recent = now - (3 * 24 * 60 * 60) # Within grace
future_expiration = now + (30 * 24 * 60 * 60)
messages = [
make_simple_message("msg1", "app1"), # Sandbox, no subscription - included
make_simple_message("msg2", "app2"), # Sandbox, expired old - included
make_simple_message("msg3", "app3"), # Sandbox, within grace - excluded
make_simple_message("msg4", "app4"), # Team plan, active - excluded
make_simple_message("msg5", "app5"), # No tenant mapping - excluded
make_simple_message("msg6", "app6"), # No plan info - excluded
make_simple_message("msg7", "app7"), # Sandbox, expired old, whitelisted - excluded
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
"app4": "tenant4",
"app6": "tenant6", # Has mapping but no plan
"app7": "tenant7",
# app5 has no mapping
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_recent},
"tenant4": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration},
"tenant7": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old},
# tenant6 has no plan
}
plan_provider = make_plan_provider(tenant_plans)
tenant_whitelist = ["tenant7"]
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
tenant_whitelist=tenant_whitelist,
current_timestamp=now,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - only msg1 and msg2 should be included
assert set(result) == {"msg1", "msg2"}
class TestBillingDisabledPolicyFilterMessageIds:
"""Unit tests for BillingDisabledPolicy.filter_message_ids method."""
def test_returns_all_message_ids(self):
"""Test that all message IDs are returned (order-preserving)."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
make_simple_message("msg3", "app3"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2"}
policy = BillingDisabledPolicy()
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - all message IDs returned in order
assert list(result) == ["msg1", "msg2", "msg3"]
def test_ignores_app_to_tenant(self):
"""Test that app_to_tenant mapping is ignored."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant: dict[str, str] = {} # Empty - should be ignored
policy = BillingDisabledPolicy()
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - all message IDs still returned
assert list(result) == ["msg1", "msg2"]
def test_empty_messages_returns_empty(self):
"""Test that empty messages returns empty list."""
# Arrange
messages: list[SimpleMessage] = []
app_to_tenant = {"app1": "tenant1"}
policy = BillingDisabledPolicy()
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert
assert list(result) == []
class TestCreateMessageCleanPolicy:
"""Unit tests for create_message_clean_policy factory function."""
@patch("services.retention.conversation.messages_clean_policy.dify_config")
def test_billing_disabled_returns_billing_disabled_policy(self, mock_config):
"""Test that BILLING_ENABLED=False returns BillingDisabledPolicy."""
# Arrange
mock_config.BILLING_ENABLED = False
# Act
policy = create_message_clean_policy(graceful_period_days=21)
# Assert
assert isinstance(policy, BillingDisabledPolicy)
@patch("services.retention.conversation.messages_clean_policy.BillingService")
@patch("services.retention.conversation.messages_clean_policy.dify_config")
def test_billing_enabled_policy_has_correct_internals(self, mock_config, mock_billing_service):
"""Test that BillingSandboxPolicy is created with correct internal values."""
# Arrange
mock_config.BILLING_ENABLED = True
whitelist = ["tenant1", "tenant2"]
mock_billing_service.get_expired_subscription_cleanup_whitelist.return_value = whitelist
mock_plan_provider = MagicMock()
mock_billing_service.get_plan_bulk_with_cache = mock_plan_provider
# Act
policy = create_message_clean_policy(graceful_period_days=14, current_timestamp=1234567)
# Assert
mock_billing_service.get_expired_subscription_cleanup_whitelist.assert_called_once()
assert isinstance(policy, BillingSandboxPolicy)
assert policy._graceful_period_days == 14
assert list(policy._tenant_whitelist) == whitelist
assert policy._plan_provider == mock_plan_provider
assert policy._current_timestamp == 1234567
class TestMessagesCleanServiceFromTimeRange:
"""Unit tests for MessagesCleanService.from_time_range factory method."""
def test_start_from_end_before_raises_value_error(self):
"""Test that start_from == end_before raises ValueError."""
policy = BillingDisabledPolicy()
# Arrange
same_time = datetime.datetime(2024, 1, 1, 12, 0, 0)
# Act & Assert
with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
MessagesCleanService.from_time_range(
policy=policy,
start_from=same_time,
end_before=same_time,
)
# Arrange
start_from = datetime.datetime(2024, 12, 31)
end_before = datetime.datetime(2024, 1, 1)
# Act & Assert
with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
)
def test_batch_size_raises_value_error(self):
"""Test that batch_size=0 raises ValueError."""
# Arrange
start_from = datetime.datetime(2024, 1, 1)
end_before = datetime.datetime(2024, 2, 1)
policy = BillingDisabledPolicy()
# Act & Assert
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=0,
)
start_from = datetime.datetime(2024, 1, 1)
end_before = datetime.datetime(2024, 2, 1)
policy = BillingDisabledPolicy()
# Act & Assert
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=-100,
)
def test_valid_params_creates_instance(self):
"""Test that valid parameters create a correctly configured instance."""
# Arrange
start_from = datetime.datetime(2024, 1, 1, 0, 0, 0)
end_before = datetime.datetime(2024, 12, 31, 23, 59, 59)
policy = BillingDisabledPolicy()
batch_size = 500
dry_run = True
# Act
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
# Assert
assert isinstance(service, MessagesCleanService)
assert service._policy is policy
assert service._start_from == start_from
assert service._end_before == end_before
assert service._batch_size == batch_size
assert service._dry_run == dry_run
def test_default_params(self):
"""Test that default parameters are applied correctly."""
# Arrange
start_from = datetime.datetime(2024, 1, 1)
end_before = datetime.datetime(2024, 2, 1)
policy = BillingDisabledPolicy()
# Act
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
)
# Assert
assert service._batch_size == 1000 # default
assert service._dry_run is False # default
class TestMessagesCleanServiceFromDays:
"""Unit tests for MessagesCleanService.from_days factory method."""
def test_days_raises_value_error(self):
"""Test that days < 0 raises ValueError."""
# Arrange
policy = BillingDisabledPolicy()
# Act & Assert
with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"):
MessagesCleanService.from_days(policy=policy, days=-1)
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
service = MessagesCleanService.from_days(policy=policy, days=0)
# Assert
assert service._end_before == fixed_now
def test_batch_size_raises_value_error(self):
"""Test that batch_size=0 raises ValueError."""
# Arrange
policy = BillingDisabledPolicy()
# Act & Assert
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
MessagesCleanService.from_days(policy=policy, days=30, batch_size=0)
# Act & Assert
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
MessagesCleanService.from_days(policy=policy, days=30, batch_size=-500)
def test_valid_params_creates_instance(self):
"""Test that valid parameters create a correctly configured instance."""
# Arrange
policy = BillingDisabledPolicy()
days = 90
batch_size = 500
dry_run = True
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
service = MessagesCleanService.from_days(
policy=policy,
days=days,
batch_size=batch_size,
dry_run=dry_run,
)
# Assert
expected_end_before = fixed_now - datetime.timedelta(days=days)
assert isinstance(service, MessagesCleanService)
assert service._policy is policy
assert service._start_from is None
assert service._end_before == expected_end_before
assert service._batch_size == batch_size
assert service._dry_run == dry_run
def test_default_params(self):
"""Test that default parameters are applied correctly."""
# Arrange
policy = BillingDisabledPolicy()
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
service = MessagesCleanService.from_days(policy=policy)
# Assert
expected_end_before = fixed_now - datetime.timedelta(days=30) # default days=30
assert service._end_before == expected_end_before
assert service._batch_size == 1000 # default
assert service._dry_run is False # default

View File

@@ -31,8 +31,6 @@ NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false
# The timeout for the text generation in millisecond
NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000
# Used by web/docker/entrypoint.sh to overwrite/export NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS at container startup (Docker only)
TEXT_GENERATION_TIMEOUT_MS=60000
# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
NEXT_PUBLIC_CSP_WHITELIST=

View File

@@ -54,7 +54,7 @@ const pageNameEnrichmentPlugin = (): amplitude.Types.EnrichmentPlugin => {
}
const AmplitudeProvider: FC<IAmplitudeProps> = ({
sessionReplaySampleRate = 0.5,
sessionReplaySampleRate = 1,
}) => {
useEffect(() => {
// Only enable in Saas edition with valid API key

View File

@@ -1,201 +0,0 @@
'use client'
import type { FC } from 'react'
import type { Item } from '@/app/components/base/select'
import type { BuiltInMetadataItem, MetadataItemWithValueLength } from '@/app/components/datasets/metadata/types'
import type { SortType } from '@/service/datasets'
import { PlusIcon } from '@heroicons/react/24/solid'
import { RiDraftLine, RiExternalLinkLine } from '@remixicon/react'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Chip from '@/app/components/base/chip'
import Input from '@/app/components/base/input'
import Sort from '@/app/components/base/sort'
import AutoDisabledDocument from '@/app/components/datasets/common/document-status-with-action/auto-disabled-document'
import IndexFailed from '@/app/components/datasets/common/document-status-with-action/index-failed'
import StatusWithAction from '@/app/components/datasets/common/document-status-with-action/status-with-action'
import DatasetMetadataDrawer from '@/app/components/datasets/metadata/metadata-dataset/dataset-metadata-drawer'
import { useDocLink } from '@/context/i18n'
import { DataSourceType } from '@/models/datasets'
import { useIndexStatus } from '../status-item/hooks'
type DocumentsHeaderProps = {
// Dataset info
datasetId: string
dataSourceType?: DataSourceType
embeddingAvailable: boolean
isFreePlan: boolean
// Filter & sort
statusFilterValue: string
sortValue: SortType
inputValue: string
onStatusFilterChange: (value: string) => void
onStatusFilterClear: () => void
onSortChange: (value: string) => void
onInputChange: (value: string) => void
// Metadata modal
isShowEditMetadataModal: boolean
showEditMetadataModal: () => void
hideEditMetadataModal: () => void
datasetMetaData?: MetadataItemWithValueLength[]
builtInMetaData?: BuiltInMetadataItem[]
builtInEnabled: boolean
onAddMetaData: (payload: BuiltInMetadataItem) => Promise<void>
onRenameMetaData: (payload: MetadataItemWithValueLength) => Promise<void>
onDeleteMetaData: (metaDataId: string) => Promise<void>
onBuiltInEnabledChange: (enabled: boolean) => void
// Actions
onAddDocument: () => void
}
const DocumentsHeader: FC<DocumentsHeaderProps> = ({
datasetId,
dataSourceType,
embeddingAvailable,
isFreePlan,
statusFilterValue,
sortValue,
inputValue,
onStatusFilterChange,
onStatusFilterClear,
onSortChange,
onInputChange,
isShowEditMetadataModal,
showEditMetadataModal,
hideEditMetadataModal,
datasetMetaData,
builtInMetaData,
builtInEnabled,
onAddMetaData,
onRenameMetaData,
onDeleteMetaData,
onBuiltInEnabledChange,
onAddDocument,
}) => {
const { t } = useTranslation()
const docLink = useDocLink()
const DOC_INDEX_STATUS_MAP = useIndexStatus()
const isDataSourceNotion = dataSourceType === DataSourceType.NOTION
const isDataSourceWeb = dataSourceType === DataSourceType.WEB
const statusFilterItems: Item[] = useMemo(() => [
{ value: 'all', name: t('list.index.all', { ns: 'datasetDocuments' }) as string },
{ value: 'queuing', name: DOC_INDEX_STATUS_MAP.queuing.text },
{ value: 'indexing', name: DOC_INDEX_STATUS_MAP.indexing.text },
{ value: 'paused', name: DOC_INDEX_STATUS_MAP.paused.text },
{ value: 'error', name: DOC_INDEX_STATUS_MAP.error.text },
{ value: 'available', name: DOC_INDEX_STATUS_MAP.available.text },
{ value: 'enabled', name: DOC_INDEX_STATUS_MAP.enabled.text },
{ value: 'disabled', name: DOC_INDEX_STATUS_MAP.disabled.text },
{ value: 'archived', name: DOC_INDEX_STATUS_MAP.archived.text },
], [DOC_INDEX_STATUS_MAP, t])
const sortItems: Item[] = useMemo(() => [
{ value: 'created_at', name: t('list.sort.uploadTime', { ns: 'datasetDocuments' }) as string },
{ value: 'hit_count', name: t('list.sort.hitCount', { ns: 'datasetDocuments' }) as string },
], [t])
// Determine add button text based on data source type
const addButtonText = useMemo(() => {
if (isDataSourceNotion)
return t('list.addPages', { ns: 'datasetDocuments' })
if (isDataSourceWeb)
return t('list.addUrl', { ns: 'datasetDocuments' })
return t('list.addFile', { ns: 'datasetDocuments' })
}, [isDataSourceNotion, isDataSourceWeb, t])
return (
<>
{/* Title section */}
<div className="flex flex-col justify-center gap-1 px-6 pt-4">
<h1 className="text-base font-semibold text-text-primary">
{t('list.title', { ns: 'datasetDocuments' })}
</h1>
<div className="flex items-center space-x-0.5 text-sm font-normal text-text-tertiary">
<span>{t('list.desc', { ns: 'datasetDocuments' })}</span>
<a
className="flex items-center text-text-accent"
target="_blank"
rel="noopener noreferrer"
href={docLink('/guides/knowledge-base/integrate-knowledge-within-application')}
>
<span>{t('list.learnMore', { ns: 'datasetDocuments' })}</span>
<RiExternalLinkLine className="h-3 w-3" />
</a>
</div>
</div>
{/* Toolbar section */}
<div className="flex flex-wrap items-center justify-between px-6 pt-4">
{/* Left: Filters */}
<div className="flex items-center gap-2">
<Chip
className="w-[160px]"
showLeftIcon={false}
value={statusFilterValue}
items={statusFilterItems}
onSelect={item => onStatusFilterChange(item?.value ? String(item.value) : '')}
onClear={onStatusFilterClear}
/>
<Input
showLeftIcon
showClearIcon
wrapperClassName="!w-[200px]"
value={inputValue}
onChange={e => onInputChange(e.target.value)}
onClear={() => onInputChange('')}
/>
<div className="h-3.5 w-px bg-divider-regular"></div>
<Sort
order={sortValue.startsWith('-') ? '-' : ''}
value={sortValue.replace('-', '')}
items={sortItems}
onSelect={value => onSortChange(String(value))}
/>
</div>
{/* Right: Actions */}
<div className="flex !h-8 items-center justify-center gap-2">
{!isFreePlan && <AutoDisabledDocument datasetId={datasetId} />}
<IndexFailed datasetId={datasetId} />
{!embeddingAvailable && (
<StatusWithAction
type="warning"
description={t('embeddingModelNotAvailable', { ns: 'dataset' })}
/>
)}
{embeddingAvailable && (
<Button variant="secondary" className="shrink-0" onClick={showEditMetadataModal}>
<RiDraftLine className="mr-1 size-4" />
{t('metadata.metadata', { ns: 'dataset' })}
</Button>
)}
{isShowEditMetadataModal && (
<DatasetMetadataDrawer
userMetadata={datasetMetaData ?? []}
onClose={hideEditMetadataModal}
onAdd={onAddMetaData}
onRename={onRenameMetaData}
onRemove={onDeleteMetaData}
builtInMetadata={builtInMetaData ?? []}
isBuiltInEnabled={builtInEnabled}
onIsBuiltInEnabledChange={onBuiltInEnabledChange}
/>
)}
{embeddingAvailable && (
<Button variant="primary" onClick={onAddDocument} className="shrink-0">
<PlusIcon className="mr-2 h-4 w-4 stroke-current" />
{addButtonText}
</Button>
)}
</div>
</div>
</>
)
}
export default DocumentsHeader

View File

@@ -1,41 +0,0 @@
'use client'
import type { FC } from 'react'
import { PlusIcon } from '@heroicons/react/24/solid'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import s from '../style.module.css'
import { FolderPlusIcon, NotionIcon, ThreeDotsIcon } from './icons'
type EmptyElementProps = {
canAdd: boolean
onClick: () => void
type?: 'upload' | 'sync'
}
const EmptyElement: FC<EmptyElementProps> = ({ canAdd = true, onClick, type = 'upload' }) => {
const { t } = useTranslation()
return (
<div className={s.emptyWrapper}>
<div className={s.emptyElement}>
<div className={s.emptySymbolIconWrapper}>
{type === 'upload' ? <FolderPlusIcon /> : <NotionIcon />}
</div>
<span className={s.emptyTitle}>
{t('list.empty.title', { ns: 'datasetDocuments' })}
<ThreeDotsIcon className="relative -left-1.5 -top-3 inline" />
</span>
<div className={s.emptyTip}>
{t(`list.empty.${type}.tip`, { ns: 'datasetDocuments' })}
</div>
{type === 'upload' && canAdd && (
<Button onClick={onClick} className={s.addFileBtn} variant="secondary-accent">
<PlusIcon className={s.plusIcon} />
{t('list.addFile', { ns: 'datasetDocuments' })}
</Button>
)}
</div>
</div>
)
}
export default EmptyElement

View File

@@ -1,34 +0,0 @@
import type * as React from 'react'
export const FolderPlusIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path d="M10.8332 5.83333L9.90355 3.9741C9.63601 3.439 9.50222 3.17144 9.30265 2.97597C9.12615 2.80311 8.91344 2.67164 8.6799 2.59109C8.41581 2.5 8.11668 2.5 7.51841 2.5H4.33317C3.39975 2.5 2.93304 2.5 2.57652 2.68166C2.26292 2.84144 2.00795 3.09641 1.84816 3.41002C1.6665 3.76654 1.6665 4.23325 1.6665 5.16667V5.83333M1.6665 5.83333H14.3332C15.7333 5.83333 16.4334 5.83333 16.9681 6.10582C17.4386 6.3455 17.821 6.72795 18.0607 7.19836C18.3332 7.73314 18.3332 8.4332 18.3332 9.83333V13.5C18.3332 14.9001 18.3332 15.6002 18.0607 16.135C17.821 16.6054 17.4386 16.9878 16.9681 17.2275C16.4334 17.5 15.7333 17.5 14.3332 17.5H5.6665C4.26637 17.5 3.56631 17.5 3.03153 17.2275C2.56112 16.9878 2.17867 16.6054 1.93899 16.135C1.6665 15.6002 1.6665 14.9001 1.6665 13.5V5.83333ZM9.99984 14.1667V9.16667M7.49984 11.6667H12.4998" stroke="#667085" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round" />
</svg>
)
}
export const ThreeDotsIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path d="M5 6.5V5M8.93934 7.56066L10 6.5M10.0103 11.5H11.5103" stroke="#374151" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" />
</svg>
)
}
export const NotionIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<g clipPath="url(#clip0_2164_11263)">
<path fillRule="evenodd" clipRule="evenodd" d="M3.5725 18.2611L1.4229 15.5832C0.905706 14.9389 0.625 14.1466 0.625 13.3312V3.63437C0.625 2.4129 1.60224 1.39936 2.86295 1.31328L12.8326 0.632614C13.5569 0.583164 14.2768 0.775682 14.8717 1.17794L18.3745 3.5462C19.0015 3.97012 19.375 4.66312 19.375 5.40266V16.427C19.375 17.6223 18.4141 18.6121 17.1798 18.688L6.11458 19.3692C5.12958 19.4298 4.17749 19.0148 3.5725 18.2611Z" fill="white" />
<path d="M7.03006 8.48669V8.35974C7.03006 8.03794 7.28779 7.77104 7.61997 7.74886L10.0396 7.58733L13.3857 12.5147V8.19009L12.5244 8.07528V8.01498C12.5244 7.68939 12.788 7.42074 13.1244 7.4035L15.326 7.29073V7.60755C15.326 7.75628 15.2154 7.88349 15.0638 7.90913L14.534 7.99874V15.0023L13.8691 15.231C13.3136 15.422 12.6952 15.2175 12.3772 14.7377L9.12879 9.83574V14.5144L10.1287 14.7057L10.1147 14.7985C10.0711 15.089 9.82028 15.3087 9.51687 15.3222L7.03006 15.4329C6.99718 15.1205 7.23132 14.841 7.55431 14.807L7.88143 14.7727V8.53453L7.03006 8.48669Z" fill="black" />
<path fillRule="evenodd" clipRule="evenodd" d="M12.9218 1.85424L2.95217 2.53491C2.35499 2.57568 1.89209 3.05578 1.89209 3.63437V13.3312C1.89209 13.8748 2.07923 14.403 2.42402 14.8325L4.57362 17.5104C4.92117 17.9434 5.46812 18.1818 6.03397 18.147L17.0991 17.4658C17.6663 17.4309 18.1078 16.9762 18.1078 16.427V5.40266C18.1078 5.06287 17.9362 4.74447 17.6481 4.54969L14.1453 2.18143C13.7883 1.94008 13.3564 1.82457 12.9218 1.85424ZM3.44654 3.78562C3.30788 3.68296 3.37387 3.46909 3.54806 3.4566L12.9889 2.77944C13.2897 2.75787 13.5886 2.8407 13.8318 3.01305L15.7261 4.35508C15.798 4.40603 15.7642 4.51602 15.6752 4.52086L5.67742 5.0646C5.37485 5.08106 5.0762 4.99217 4.83563 4.81406L3.44654 3.78562ZM5.20848 6.76919C5.20848 6.4444 5.47088 6.1761 5.80642 6.15783L16.3769 5.58216C16.7039 5.56435 16.9792 5.81583 16.9792 6.13239V15.6783C16.9792 16.0025 16.7177 16.2705 16.3829 16.2896L5.8793 16.8872C5.51537 16.9079 5.20848 16.6283 5.20848 16.2759V6.76919Z" fill="black" />
</g>
<defs>
<clipPath id="clip0_2164_11263">
<rect width="20" height="20" fill="white" />
</clipPath>
</defs>
</svg>
)
}

View File

@@ -18,7 +18,7 @@ import { useDocumentDetail, useDocumentMetadata, useInvalidDocumentList } from '
import { useCheckSegmentBatchImportProgress, useChildSegmentListKey, useSegmentBatchImport, useSegmentListKey } from '@/service/knowledge/use-segment'
import { useInvalid } from '@/service/use-base'
import { cn } from '@/utils/classnames'
import Operations from '../components/operations'
import Operations from '../operations'
import StatusItem from '../status-item'
import BatchModal from './batch-modal'
import Completed from './completed'

View File

@@ -1,197 +0,0 @@
import type { DocumentListResponse } from '@/models/datasets'
import type { SortType } from '@/service/datasets'
import { useDebounce, useDebounceFn } from 'ahooks'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { normalizeStatusForQuery, sanitizeStatusValue } from '../status-filter'
import useDocumentListQueryState from './use-document-list-query-state'
/**
* Custom hook to manage documents page state including:
* - Search state (input value, debounced search value)
* - Filter state (status filter, sort value)
* - Pagination state (current page, limit)
* - Selection state (selected document ids)
* - Polling state (timer control for auto-refresh)
*/
export function useDocumentsPageState() {
const { query, updateQuery } = useDocumentListQueryState()
// Search state
const [inputValue, setInputValue] = useState<string>('')
const [searchValue, setSearchValue] = useState<string>('')
const debouncedSearchValue = useDebounce(searchValue, { wait: 500 })
// Filter & sort state
const [statusFilterValue, setStatusFilterValue] = useState<string>(() => sanitizeStatusValue(query.status))
const [sortValue, setSortValue] = useState<SortType>(query.sort)
const normalizedStatusFilterValue = useMemo(
() => normalizeStatusForQuery(statusFilterValue),
[statusFilterValue],
)
// Pagination state
const [currPage, setCurrPage] = useState<number>(query.page - 1)
const [limit, setLimit] = useState<number>(query.limit)
// Selection state
const [selectedIds, setSelectedIds] = useState<string[]>([])
// Polling state
const [timerCanRun, setTimerCanRun] = useState(true)
// Initialize search value from URL on mount
useEffect(() => {
if (query.keyword) {
setInputValue(query.keyword)
setSearchValue(query.keyword)
}
}, []) // Only run on mount
// Sync local state with URL query changes
useEffect(() => {
setCurrPage(query.page - 1)
setLimit(query.limit)
if (query.keyword !== searchValue) {
setInputValue(query.keyword)
setSearchValue(query.keyword)
}
setStatusFilterValue((prev) => {
const nextValue = sanitizeStatusValue(query.status)
return prev === nextValue ? prev : nextValue
})
setSortValue(query.sort)
}, [query])
// Update URL when search changes
useEffect(() => {
if (debouncedSearchValue !== query.keyword) {
setCurrPage(0)
updateQuery({ keyword: debouncedSearchValue, page: 1 })
}
}, [debouncedSearchValue, query.keyword, updateQuery])
// Clear selection when search changes
useEffect(() => {
if (searchValue !== query.keyword)
setSelectedIds([])
}, [searchValue, query.keyword])
// Clear selection when status filter changes
useEffect(() => {
setSelectedIds([])
}, [normalizedStatusFilterValue])
// Page change handler
const handlePageChange = useCallback((newPage: number) => {
setCurrPage(newPage)
updateQuery({ page: newPage + 1 })
}, [updateQuery])
// Limit change handler
const handleLimitChange = useCallback((newLimit: number) => {
setLimit(newLimit)
setCurrPage(0)
updateQuery({ limit: newLimit, page: 1 })
}, [updateQuery])
// Debounced search handler
const { run: handleSearch } = useDebounceFn(() => {
setSearchValue(inputValue)
}, { wait: 500 })
// Input change handler
const handleInputChange = useCallback((value: string) => {
setInputValue(value)
handleSearch()
}, [handleSearch])
// Status filter change handler
const handleStatusFilterChange = useCallback((value: string) => {
const selectedValue = sanitizeStatusValue(value)
setStatusFilterValue(selectedValue)
setCurrPage(0)
updateQuery({ status: selectedValue, page: 1 })
}, [updateQuery])
// Status filter clear handler
const handleStatusFilterClear = useCallback(() => {
if (statusFilterValue === 'all')
return
setStatusFilterValue('all')
setCurrPage(0)
updateQuery({ status: 'all', page: 1 })
}, [statusFilterValue, updateQuery])
// Sort change handler
const handleSortChange = useCallback((value: string) => {
const next = value as SortType
if (next === sortValue)
return
setSortValue(next)
setCurrPage(0)
updateQuery({ sort: next, page: 1 })
}, [sortValue, updateQuery])
// Update polling state based on documents response
const updatePollingState = useCallback((documentsRes: DocumentListResponse | undefined) => {
if (!documentsRes?.data)
return
let completedNum = 0
documentsRes.data.forEach((documentItem) => {
const { indexing_status } = documentItem
const isEmbedded = indexing_status === 'completed' || indexing_status === 'paused' || indexing_status === 'error'
if (isEmbedded)
completedNum++
})
const hasIncompleteDocuments = completedNum !== documentsRes.data.length
const transientStatuses = ['queuing', 'indexing', 'paused']
const shouldForcePolling = normalizedStatusFilterValue === 'all'
? false
: transientStatuses.includes(normalizedStatusFilterValue)
setTimerCanRun(shouldForcePolling || hasIncompleteDocuments)
}, [normalizedStatusFilterValue])
// Adjust page when total pages change
const adjustPageForTotal = useCallback((documentsRes: DocumentListResponse | undefined) => {
if (!documentsRes)
return
const totalPages = Math.ceil(documentsRes.total / limit)
if (currPage > 0 && currPage + 1 > totalPages)
handlePageChange(totalPages > 0 ? totalPages - 1 : 0)
}, [limit, currPage, handlePageChange])
return {
// Search state
inputValue,
searchValue,
debouncedSearchValue,
handleInputChange,
// Filter & sort state
statusFilterValue,
sortValue,
normalizedStatusFilterValue,
handleStatusFilterChange,
handleStatusFilterClear,
handleSortChange,
// Pagination state
currPage,
limit,
handlePageChange,
handleLimitChange,
// Selection state
selectedIds,
setSelectedIds,
// Polling state
timerCanRun,
updatePollingState,
adjustPageForTotal,
}
}
export default useDocumentsPageState

View File

@@ -1,55 +1,185 @@
'use client'
import type { FC } from 'react'
import type { Item } from '@/app/components/base/select'
import type { SortType } from '@/service/datasets'
import { PlusIcon } from '@heroicons/react/24/solid'
import { RiDraftLine, RiExternalLinkLine } from '@remixicon/react'
import { useDebounce, useDebounceFn } from 'ahooks'
import { useRouter } from 'next/navigation'
import { useCallback, useEffect } from 'react'
import * as React from 'react'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import Loading from '@/app/components/base/loading'
import IndexFailed from '@/app/components/datasets/common/document-status-with-action/index-failed'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useDocLink } from '@/context/i18n'
import { useProviderContext } from '@/context/provider-context'
import { DataSourceType } from '@/models/datasets'
import { useDocumentList, useInvalidDocumentDetail, useInvalidDocumentList } from '@/service/knowledge/use-document'
import { useChildSegmentListKey, useSegmentListKey } from '@/service/knowledge/use-segment'
import { useInvalid } from '@/service/use-base'
import { cn } from '@/utils/classnames'
import Chip from '../../base/chip'
import Sort from '../../base/sort'
import AutoDisabledDocument from '../common/document-status-with-action/auto-disabled-document'
import StatusWithAction from '../common/document-status-with-action/status-with-action'
import useEditDocumentMetadata from '../metadata/hooks/use-edit-dataset-metadata'
import DocumentsHeader from './components/documents-header'
import EmptyElement from './components/empty-element'
import List from './components/list'
import useDocumentsPageState from './hooks/use-documents-page-state'
import DatasetMetadataDrawer from '../metadata/metadata-dataset/dataset-metadata-drawer'
import useDocumentListQueryState from './hooks/use-document-list-query-state'
import List from './list'
import { normalizeStatusForQuery, sanitizeStatusValue } from './status-filter'
import { useIndexStatus } from './status-item/hooks'
import s from './style.module.css'
const FolderPlusIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path d="M10.8332 5.83333L9.90355 3.9741C9.63601 3.439 9.50222 3.17144 9.30265 2.97597C9.12615 2.80311 8.91344 2.67164 8.6799 2.59109C8.41581 2.5 8.11668 2.5 7.51841 2.5H4.33317C3.39975 2.5 2.93304 2.5 2.57652 2.68166C2.26292 2.84144 2.00795 3.09641 1.84816 3.41002C1.6665 3.76654 1.6665 4.23325 1.6665 5.16667V5.83333M1.6665 5.83333H14.3332C15.7333 5.83333 16.4334 5.83333 16.9681 6.10582C17.4386 6.3455 17.821 6.72795 18.0607 7.19836C18.3332 7.73314 18.3332 8.4332 18.3332 9.83333V13.5C18.3332 14.9001 18.3332 15.6002 18.0607 16.135C17.821 16.6054 17.4386 16.9878 16.9681 17.2275C16.4334 17.5 15.7333 17.5 14.3332 17.5H5.6665C4.26637 17.5 3.56631 17.5 3.03153 17.2275C2.56112 16.9878 2.17867 16.6054 1.93899 16.135C1.6665 15.6002 1.6665 14.9001 1.6665 13.5V5.83333ZM9.99984 14.1667V9.16667M7.49984 11.6667H12.4998" stroke="#667085" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round" />
</svg>
)
}
const ThreeDotsIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path d="M5 6.5V5M8.93934 7.56066L10 6.5M10.0103 11.5H11.5103" stroke="#374151" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" />
</svg>
)
}
const NotionIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<g clipPath="url(#clip0_2164_11263)">
<path fillRule="evenodd" clipRule="evenodd" d="M3.5725 18.2611L1.4229 15.5832C0.905706 14.9389 0.625 14.1466 0.625 13.3312V3.63437C0.625 2.4129 1.60224 1.39936 2.86295 1.31328L12.8326 0.632614C13.5569 0.583164 14.2768 0.775682 14.8717 1.17794L18.3745 3.5462C19.0015 3.97012 19.375 4.66312 19.375 5.40266V16.427C19.375 17.6223 18.4141 18.6121 17.1798 18.688L6.11458 19.3692C5.12958 19.4298 4.17749 19.0148 3.5725 18.2611Z" fill="white" />
<path d="M7.03006 8.48669V8.35974C7.03006 8.03794 7.28779 7.77104 7.61997 7.74886L10.0396 7.58733L13.3857 12.5147V8.19009L12.5244 8.07528V8.01498C12.5244 7.68939 12.788 7.42074 13.1244 7.4035L15.326 7.29073V7.60755C15.326 7.75628 15.2154 7.88349 15.0638 7.90913L14.534 7.99874V15.0023L13.8691 15.231C13.3136 15.422 12.6952 15.2175 12.3772 14.7377L9.12879 9.83574V14.5144L10.1287 14.7057L10.1147 14.7985C10.0711 15.089 9.82028 15.3087 9.51687 15.3222L7.03006 15.4329C6.99718 15.1205 7.23132 14.841 7.55431 14.807L7.88143 14.7727V8.53453L7.03006 8.48669Z" fill="black" />
<path fillRule="evenodd" clipRule="evenodd" d="M12.9218 1.85424L2.95217 2.53491C2.35499 2.57568 1.89209 3.05578 1.89209 3.63437V13.3312C1.89209 13.8748 2.07923 14.403 2.42402 14.8325L4.57362 17.5104C4.92117 17.9434 5.46812 18.1818 6.03397 18.147L17.0991 17.4658C17.6663 17.4309 18.1078 16.9762 18.1078 16.427V5.40266C18.1078 5.06287 17.9362 4.74447 17.6481 4.54969L14.1453 2.18143C13.7883 1.94008 13.3564 1.82457 12.9218 1.85424ZM3.44654 3.78562C3.30788 3.68296 3.37387 3.46909 3.54806 3.4566L12.9889 2.77944C13.2897 2.75787 13.5886 2.8407 13.8318 3.01305L15.7261 4.35508C15.798 4.40603 15.7642 4.51602 15.6752 4.52086L5.67742 5.0646C5.37485 5.08106 5.0762 4.99217 4.83563 4.81406L3.44654 3.78562ZM5.20848 6.76919C5.20848 6.4444 5.47088 6.1761 5.80642 6.15783L16.3769 5.58216C16.7039 5.56435 16.9792 5.81583 16.9792 6.13239V15.6783C16.9792 16.0025 16.7177 16.2705 16.3829 16.2896L5.8793 16.8872C5.51537 16.9079 5.20848 16.6283 5.20848 16.2759V6.76919Z" fill="black" />
</g>
<defs>
<clipPath id="clip0_2164_11263">
<rect width="20" height="20" fill="white" />
</clipPath>
</defs>
</svg>
)
}
const EmptyElement: FC<{ canAdd: boolean, onClick: () => void, type?: 'upload' | 'sync' }> = ({ canAdd = true, onClick, type = 'upload' }) => {
const { t } = useTranslation()
return (
<div className={s.emptyWrapper}>
<div className={s.emptyElement}>
<div className={s.emptySymbolIconWrapper}>
{type === 'upload' ? <FolderPlusIcon /> : <NotionIcon />}
</div>
<span className={s.emptyTitle}>
{t('list.empty.title', { ns: 'datasetDocuments' })}
<ThreeDotsIcon className="relative -left-1.5 -top-3 inline" />
</span>
<div className={s.emptyTip}>
{t(`list.empty.${type}.tip`, { ns: 'datasetDocuments' })}
</div>
{type === 'upload' && canAdd && (
<Button onClick={onClick} className={s.addFileBtn} variant="secondary-accent">
<PlusIcon className={s.plusIcon} />
{t('list.addFile', { ns: 'datasetDocuments' })}
</Button>
)}
</div>
</div>
)
}
type IDocumentsProps = {
datasetId: string
}
const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
const router = useRouter()
const { t } = useTranslation()
const docLink = useDocLink()
const { plan } = useProviderContext()
const isFreePlan = plan.type === 'sandbox'
const { query, updateQuery } = useDocumentListQueryState()
const [inputValue, setInputValue] = useState<string>('') // the input value
const [searchValue, setSearchValue] = useState<string>('')
const [statusFilterValue, setStatusFilterValue] = useState<string>(() => sanitizeStatusValue(query.status))
const [sortValue, setSortValue] = useState<SortType>(query.sort)
const DOC_INDEX_STATUS_MAP = useIndexStatus()
const [currPage, setCurrPage] = React.useState<number>(query.page - 1) // Convert to 0-based index
const [limit, setLimit] = useState<number>(query.limit)
const router = useRouter()
const dataset = useDatasetDetailContextWithSelector(s => s.dataset)
const [timerCanRun, setTimerCanRun] = useState(true)
const isDataSourceNotion = dataset?.data_source_type === DataSourceType.NOTION
const isDataSourceWeb = dataset?.data_source_type === DataSourceType.WEB
const isDataSourceFile = dataset?.data_source_type === DataSourceType.FILE
const embeddingAvailable = !!dataset?.embedding_available
const debouncedSearchValue = useDebounce(searchValue, { wait: 500 })
// Use custom hook for page state management
const {
inputValue,
debouncedSearchValue,
handleInputChange,
statusFilterValue,
sortValue,
normalizedStatusFilterValue,
handleStatusFilterChange,
handleStatusFilterClear,
handleSortChange,
currPage,
limit,
handlePageChange,
handleLimitChange,
selectedIds,
setSelectedIds,
timerCanRun,
updatePollingState,
adjustPageForTotal,
} = useDocumentsPageState()
const statusFilterItems: Item[] = useMemo(() => [
{ value: 'all', name: t('list.index.all', { ns: 'datasetDocuments' }) as string },
{ value: 'queuing', name: DOC_INDEX_STATUS_MAP.queuing.text },
{ value: 'indexing', name: DOC_INDEX_STATUS_MAP.indexing.text },
{ value: 'paused', name: DOC_INDEX_STATUS_MAP.paused.text },
{ value: 'error', name: DOC_INDEX_STATUS_MAP.error.text },
{ value: 'available', name: DOC_INDEX_STATUS_MAP.available.text },
{ value: 'enabled', name: DOC_INDEX_STATUS_MAP.enabled.text },
{ value: 'disabled', name: DOC_INDEX_STATUS_MAP.disabled.text },
{ value: 'archived', name: DOC_INDEX_STATUS_MAP.archived.text },
], [DOC_INDEX_STATUS_MAP, t])
const normalizedStatusFilterValue = useMemo(() => normalizeStatusForQuery(statusFilterValue), [statusFilterValue])
const sortItems: Item[] = useMemo(() => [
{ value: 'created_at', name: t('list.sort.uploadTime', { ns: 'datasetDocuments' }) as string },
{ value: 'hit_count', name: t('list.sort.hitCount', { ns: 'datasetDocuments' }) as string },
], [t])
// Initialize search value from URL on mount
useEffect(() => {
if (query.keyword) {
setInputValue(query.keyword)
setSearchValue(query.keyword)
}
}, []) // Only run on mount
// Sync local state with URL query changes
useEffect(() => {
setCurrPage(query.page - 1)
setLimit(query.limit)
if (query.keyword !== searchValue) {
setInputValue(query.keyword)
setSearchValue(query.keyword)
}
setStatusFilterValue((prev) => {
const nextValue = sanitizeStatusValue(query.status)
return prev === nextValue ? prev : nextValue
})
setSortValue(query.sort)
}, [query])
// Update URL when pagination changes
const handlePageChange = (newPage: number) => {
setCurrPage(newPage)
updateQuery({ page: newPage + 1 }) // Pagination emits 0-based page, convert to 1-based for URL
}
// Update URL when limit changes
const handleLimitChange = (newLimit: number) => {
setLimit(newLimit)
setCurrPage(0) // Reset to first page when limit changes
updateQuery({ limit: newLimit, page: 1 })
}
// Update URL when search changes
useEffect(() => {
if (debouncedSearchValue !== query.keyword) {
setCurrPage(0) // Reset to first page when search changes
updateQuery({ keyword: debouncedSearchValue, page: 1 })
}
}, [debouncedSearchValue, query.keyword, updateQuery])
// Fetch document list
const { data: documentsRes, isLoading: isListLoading } = useDocumentList({
datasetId,
query: {
@@ -62,18 +192,16 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
refetchInterval: timerCanRun ? 2500 : 0,
})
// Update polling state when documents change
useEffect(() => {
updatePollingState(documentsRes)
}, [documentsRes, updatePollingState])
// Adjust page when total changes
useEffect(() => {
adjustPageForTotal(documentsRes)
}, [documentsRes, adjustPageForTotal])
// Invalidation hooks
const invalidDocumentList = useInvalidDocumentList(datasetId)
useEffect(() => {
if (documentsRes) {
const totalPages = Math.ceil(documentsRes.total / limit)
if (totalPages < currPage + 1)
setCurrPage(totalPages === 0 ? 0 : totalPages - 1)
}
}, [documentsRes])
const invalidDocumentDetail = useInvalidDocumentDetail()
const invalidChunkList = useInvalid(useSegmentListKey)
const invalidChildChunkList = useInvalid(useChildSegmentListKey)
@@ -85,9 +213,73 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
invalidChunkList()
invalidChildChunkList()
}, 5000)
}, [invalidDocumentList, invalidDocumentDetail, invalidChunkList, invalidChildChunkList])
}, [])
useEffect(() => {
let completedNum = 0
let percent = 0
documentsRes?.data?.forEach((documentItem) => {
const { indexing_status, completed_segments, total_segments } = documentItem
const isEmbedded = indexing_status === 'completed' || indexing_status === 'paused' || indexing_status === 'error'
if (isEmbedded)
completedNum++
const completedCount = completed_segments || 0
const totalCount = total_segments || 0
if (totalCount === 0 && completedCount === 0) {
percent = isEmbedded ? 100 : 0
}
else {
const per = Math.round(completedCount * 100 / totalCount)
percent = per > 100 ? 100 : per
}
return {
...documentItem,
percent,
}
})
const hasIncompleteDocuments = completedNum !== documentsRes?.data?.length
const transientStatuses = ['queuing', 'indexing', 'paused']
const shouldForcePolling = normalizedStatusFilterValue === 'all'
? false
: transientStatuses.includes(normalizedStatusFilterValue)
setTimerCanRun(shouldForcePolling || hasIncompleteDocuments)
}, [documentsRes, normalizedStatusFilterValue])
const total = documentsRes?.total || 0
const routeToDocCreate = () => {
// if dataset is created from pipeline, go to create from pipeline page
if (dataset?.runtime_mode === 'rag_pipeline') {
router.push(`/datasets/${datasetId}/documents/create-from-pipeline`)
return
}
router.push(`/datasets/${datasetId}/documents/create`)
}
const documentsList = documentsRes?.data
const [selectedIds, setSelectedIds] = useState<string[]>([])
// Clear selection when search changes to avoid confusion
useEffect(() => {
if (searchValue !== query.keyword)
setSelectedIds([])
}, [searchValue, query.keyword])
useEffect(() => {
setSelectedIds([])
}, [normalizedStatusFilterValue])
const { run: handleSearch } = useDebounceFn(() => {
setSearchValue(inputValue)
}, { wait: 500 })
const handleInputChange = (value: string) => {
setInputValue(value)
handleSearch()
}
// Metadata editing hook
const {
isShowEditModal: isShowEditMetadataModal,
showEditModal: showEditMetadataModal,
@@ -105,84 +297,130 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
onUpdateDocList: invalidDocumentList,
})
// Route to document creation page
const routeToDocCreate = useCallback(() => {
if (dataset?.runtime_mode === 'rag_pipeline') {
router.push(`/datasets/${datasetId}/documents/create-from-pipeline`)
return
}
router.push(`/datasets/${datasetId}/documents/create`)
}, [dataset?.runtime_mode, datasetId, router])
const total = documentsRes?.total || 0
const documentsList = documentsRes?.data
// Render content based on loading and data state
const renderContent = () => {
if (isListLoading)
return <Loading type="app" />
if (total > 0) {
return (
<List
embeddingAvailable={embeddingAvailable}
documents={documentsList || []}
datasetId={datasetId}
onUpdate={handleUpdate}
selectedIds={selectedIds}
onSelectedIdChange={setSelectedIds}
statusFilterValue={normalizedStatusFilterValue}
remoteSortValue={sortValue}
pagination={{
total,
limit,
onLimitChange: handleLimitChange,
current: currPage,
onChange: handlePageChange,
}}
onManageMetadata={showEditMetadataModal}
/>
)
}
const isDataSourceNotion = dataset?.data_source_type === DataSourceType.NOTION
return (
<EmptyElement
canAdd={embeddingAvailable}
onClick={routeToDocCreate}
type={isDataSourceNotion ? 'sync' : 'upload'}
/>
)
}
return (
<div className="flex h-full flex-col">
<DocumentsHeader
datasetId={datasetId}
dataSourceType={dataset?.data_source_type}
embeddingAvailable={embeddingAvailable}
isFreePlan={isFreePlan}
statusFilterValue={statusFilterValue}
sortValue={sortValue}
inputValue={inputValue}
onStatusFilterChange={handleStatusFilterChange}
onStatusFilterClear={handleStatusFilterClear}
onSortChange={handleSortChange}
onInputChange={handleInputChange}
isShowEditMetadataModal={isShowEditMetadataModal}
showEditMetadataModal={showEditMetadataModal}
hideEditMetadataModal={hideEditMetadataModal}
datasetMetaData={datasetMetaData}
builtInMetaData={builtInMetaData}
builtInEnabled={!!builtInEnabled}
onAddMetaData={handleAddMetaData}
onRenameMetaData={handleRename}
onDeleteMetaData={handleDeleteMetaData}
onBuiltInEnabledChange={setBuiltInEnabled}
onAddDocument={routeToDocCreate}
/>
<div className="flex flex-col justify-center gap-1 px-6 pt-4">
<h1 className="text-base font-semibold text-text-primary">{t('list.title', { ns: 'datasetDocuments' })}</h1>
<div className="flex items-center space-x-0.5 text-sm font-normal text-text-tertiary">
<span>{t('list.desc', { ns: 'datasetDocuments' })}</span>
<a
className="flex items-center text-text-accent"
target="_blank"
href={docLink('/guides/knowledge-base/integrate-knowledge-within-application')}
>
<span>{t('list.learnMore', { ns: 'datasetDocuments' })}</span>
<RiExternalLinkLine className="h-3 w-3" />
</a>
</div>
</div>
<div className="flex h-0 grow flex-col px-6 pt-4">
{renderContent()}
<div className="flex flex-wrap items-center justify-between">
<div className="flex items-center gap-2">
<Chip
className="w-[160px]"
showLeftIcon={false}
value={statusFilterValue}
items={statusFilterItems}
onSelect={(item) => {
const selectedValue = sanitizeStatusValue(item?.value ? String(item.value) : '')
setStatusFilterValue(selectedValue)
setCurrPage(0)
updateQuery({ status: selectedValue, page: 1 })
}}
onClear={() => {
if (statusFilterValue === 'all')
return
setStatusFilterValue('all')
setCurrPage(0)
updateQuery({ status: 'all', page: 1 })
}}
/>
<Input
showLeftIcon
showClearIcon
wrapperClassName="!w-[200px]"
value={inputValue}
onChange={e => handleInputChange(e.target.value)}
onClear={() => handleInputChange('')}
/>
<div className="h-3.5 w-px bg-divider-regular"></div>
<Sort
order={sortValue.startsWith('-') ? '-' : ''}
value={sortValue.replace('-', '')}
items={sortItems}
onSelect={(value) => {
const next = String(value) as SortType
if (next === sortValue)
return
setSortValue(next)
setCurrPage(0)
updateQuery({ sort: next, page: 1 })
}}
/>
</div>
<div className="flex !h-8 items-center justify-center gap-2">
{!isFreePlan && <AutoDisabledDocument datasetId={datasetId} />}
<IndexFailed datasetId={datasetId} />
{!embeddingAvailable && <StatusWithAction type="warning" description={t('embeddingModelNotAvailable', { ns: 'dataset' })} />}
{embeddingAvailable && (
<Button variant="secondary" className="shrink-0" onClick={showEditMetadataModal}>
<RiDraftLine className="mr-1 size-4" />
{t('metadata.metadata', { ns: 'dataset' })}
</Button>
)}
{isShowEditMetadataModal && (
<DatasetMetadataDrawer
userMetadata={datasetMetaData || []}
onClose={hideEditMetadataModal}
onAdd={handleAddMetaData}
onRename={handleRename}
onRemove={handleDeleteMetaData}
builtInMetadata={builtInMetaData || []}
isBuiltInEnabled={!!builtInEnabled}
onIsBuiltInEnabledChange={setBuiltInEnabled}
/>
)}
{embeddingAvailable && (
<Button variant="primary" onClick={routeToDocCreate} className="shrink-0">
<PlusIcon className={cn('mr-2 h-4 w-4 stroke-current')} />
{isDataSourceNotion && t('list.addPages', { ns: 'datasetDocuments' })}
{isDataSourceWeb && t('list.addUrl', { ns: 'datasetDocuments' })}
{(!dataset?.data_source_type || isDataSourceFile) && t('list.addFile', { ns: 'datasetDocuments' })}
</Button>
)}
</div>
</div>
{isListLoading
? <Loading type="app" />
// eslint-disable-next-line sonarjs/no-nested-conditional
: total > 0
? (
<List
embeddingAvailable={embeddingAvailable}
documents={documentsList || []}
datasetId={datasetId}
onUpdate={handleUpdate}
selectedIds={selectedIds}
onSelectedIdChange={setSelectedIds}
statusFilterValue={normalizedStatusFilterValue}
remoteSortValue={sortValue}
pagination={{
total,
limit,
onLimitChange: handleLimitChange,
current: currPage,
onChange: handlePageChange,
}}
onManageMetadata={showEditMetadataModal}
/>
)
: (
<EmptyElement
canAdd={embeddingAvailable}
onClick={routeToDocCreate}
type={isDataSourceNotion ? 'sync' : 'upload'}
/>
)}
</div>
</div>
)

View File

@@ -16,16 +16,13 @@ import * as React from 'react'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Checkbox from '@/app/components/base/checkbox'
import FileTypeIcon from '@/app/components/base/file-uploader/file-type-icon'
import NotionIcon from '@/app/components/base/notion-icon'
import Pagination from '@/app/components/base/pagination'
import Toast from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import ChunkingModeLabel from '@/app/components/datasets/common/chunking-mode-label'
import { normalizeStatusForQuery } from '@/app/components/datasets/documents/status-filter'
import { extensionToFileType } from '@/app/components/datasets/hit-testing/utils/extension-to-file-type'
import EditMetadataBatchModal from '@/app/components/datasets/metadata/edit-metadata-batch/modal'
import useBatchEditDocumentMetadata from '@/app/components/datasets/metadata/hooks/use-batch-edit-document-metadata'
import { useDatasetDetailContextWithSelector as useDatasetDetailContext } from '@/context/dataset-detail'
import useTimestamp from '@/hooks/use-timestamp'
import { ChunkingMode, DataSourceType, DocumentActionType } from '@/models/datasets'
@@ -34,11 +31,14 @@ import { useDocumentArchive, useDocumentBatchRetryIndex, useDocumentDelete, useD
import { asyncRunSafe } from '@/utils'
import { cn } from '@/utils/classnames'
import { formatNumber } from '@/utils/format'
import BatchAction from '../detail/completed/common/batch-action'
import StatusItem from '../status-item'
import s from '../style.module.css'
import FileTypeIcon from '../../base/file-uploader/file-type-icon'
import ChunkingModeLabel from '../common/chunking-mode-label'
import useBatchEditDocumentMetadata from '../metadata/hooks/use-batch-edit-document-metadata'
import BatchAction from './detail/completed/common/batch-action'
import Operations from './operations'
import RenameModal from './rename-modal'
import StatusItem from './status-item'
import s from './style.module.css'
export const renderTdValue = (value: string | number | null, isEmptyStyle = false) => {
return (

View File

@@ -1,4 +1,4 @@
import type { OperationName } from '../types'
import type { OperationName } from './types'
import type { CommonResponse } from '@/models/common'
import {
RiArchive2Line,
@@ -17,12 +17,6 @@ import * as React from 'react'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import Confirm from '@/app/components/base/confirm'
import Divider from '@/app/components/base/divider'
import CustomPopover from '@/app/components/base/popover'
import Switch from '@/app/components/base/switch'
import { ToastContext } from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import { DataSourceType, DocumentActionType } from '@/models/datasets'
import {
useDocumentArchive,
@@ -37,8 +31,14 @@ import {
} from '@/service/knowledge/use-document'
import { asyncRunSafe } from '@/utils'
import { cn } from '@/utils/classnames'
import s from '../style.module.css'
import Confirm from '../../base/confirm'
import Divider from '../../base/divider'
import CustomPopover from '../../base/popover'
import Switch from '../../base/switch'
import { ToastContext } from '../../base/toast'
import Tooltip from '../../base/tooltip'
import RenameModal from './rename-modal'
import s from './style.module.css'
type OperationsProps = {
embeddingAvailable: boolean

View File

@@ -7,8 +7,8 @@ import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import Modal from '@/app/components/base/modal'
import Toast from '@/app/components/base/toast'
import { renameDocumentName } from '@/service/datasets'
import Toast from '../../base/toast'
type Props = {
datasetId: string

View File

@@ -19,28 +19,6 @@ describe('formatNumber', () => {
it('should correctly handle empty input', () => {
expect(formatNumber('')).toBe('')
})
it('should format very small numbers without scientific notation', () => {
expect(formatNumber(0.0000008)).toBe('0.0000008')
expect(formatNumber(0.0000001)).toBe('0.0000001')
expect(formatNumber(0.000001)).toBe('0.000001')
expect(formatNumber(0.00001)).toBe('0.00001')
})
it('should format negative small numbers without scientific notation', () => {
expect(formatNumber(-0.0000008)).toBe('-0.0000008')
expect(formatNumber(-0.0000001)).toBe('-0.0000001')
})
it('should handle small numbers from string input', () => {
expect(formatNumber('0.0000008')).toBe('0.0000008')
expect(formatNumber('8E-7')).toBe('0.0000008')
expect(formatNumber('1e-7')).toBe('0.0000001')
})
it('should handle small numbers with multi-digit mantissa in scientific notation', () => {
expect(formatNumber(1.23e-7)).toBe('0.000000123')
expect(formatNumber(1.234e-7)).toBe('0.0000001234')
expect(formatNumber(12.34e-7)).toBe('0.000001234')
expect(formatNumber(0.0001234)).toBe('0.0001234')
expect(formatNumber('1.23e-7')).toBe('0.000000123')
})
})
describe('formatFileSize', () => {
it('should return the input if it is falsy', () => {

View File

@@ -26,39 +26,11 @@ import 'dayjs/locale/zh-tw'
* Formats a number with comma separators.
* @example formatNumber(1234567) will return '1,234,567'
* @example formatNumber(1234567.89) will return '1,234,567.89'
* @example formatNumber(0.0000008) will return '0.0000008'
*/
export const formatNumber = (num: number | string) => {
if (!num)
return num
const n = typeof num === 'string' ? Number(num) : num
let numStr: string
// Force fixed decimal for small numbers to avoid scientific notation
if (Math.abs(n) < 0.001 && n !== 0) {
const str = n.toString()
const match = str.match(/e-(\d+)$/)
let precision: number
if (match) {
// Scientific notation: precision is exponent + decimal digits in mantissa
const exponent = Number.parseInt(match[1], 10)
const mantissa = str.split('e')[0]
const mantissaDecimalPart = mantissa.split('.')[1]
precision = exponent + (mantissaDecimalPart?.length || 0)
}
else {
// Decimal notation: count decimal places
const decimalPart = str.split('.')[1]
precision = decimalPart?.length || 0
}
numStr = n.toFixed(precision)
}
else {
numStr = n.toString()
}
const parts = numStr.split('.')
const parts = num.toString().split('.')
parts[0] = parts[0].replace(/\B(?=(\d{3})+(?!\d))/g, ',')
return parts.join('.')
}