Compare commits

..

71 Commits

Author SHA1 Message Date
yyh
0c08c4016d Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-06 14:57:48 +08:00
yyh
948efa129f Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-06 14:47:56 +08:00
yyh
6d612c0909 test: improve Jotai atom test quality and add model-provider atoms tests
Replace dynamic imports with static imports in marketplace atom tests.
Convert type-only and not-toThrow assertions into proper state-change
verifications. Add comprehensive test suite for model-provider-page
atoms covering all four hooks, cross-hook interaction, selectAtom
granularity, and Provider isolation.
2026-03-05 22:49:09 +08:00
yyh
56e0dc0ae6 trigger ci
Signed-off-by: yyh <yuanyouhuilyz@gmail.com>
2026-03-05 21:22:03 +08:00
yyh
975eca00c3 Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-05 20:25:53 +08:00
yyh
f049bafcc3 refactor: simplify Jotai atoms by removing redundant write-only atoms
Replace 2 write-only derived atoms with primitive atom's built-in
updater functions. The selectAtom on the read side already prevents
unnecessary re-renders, making the manual guard logic redundant.
2026-03-05 20:25:29 +08:00
yyh
922dc71e36 fix 2026-03-05 16:17:38 +08:00
yyh
f03ec7f671 Merge branch 'main' into feat/model-provider-refactor 2026-03-05 16:14:36 +08:00
yyh
29f275442d Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor
# Conflicts:
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx
#	web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx
2026-03-05 16:13:40 +08:00
yyh
c9532ffd43 add stories 2026-03-05 15:55:21 +08:00
yyh
840dc33b8b Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-05 15:12:32 +08:00
yyh
cae58a0649 Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-05 15:08:13 +08:00
yyh
1752edc047 refactor(web): optimize model provider re-render and remove useEffect state sync
- Replace useEffect state sync with derived state pattern in useSystemDefaultModelAndModelList
- Use useCallback instead of useMemo for function memoization in useProviderCredentialsAndLoadBalancing
- Add memo() to ProviderAddedCard and CredentialPanel to prevent unnecessary re-renders
- Switch to useProviderContextSelector for precise context subscription in ProviderAddedCard
- Stabilize activate callback ref in useActivateCredential via supportedModelTypes ref
- Add usage priority tooltip with i18n support
2026-03-05 15:07:53 +08:00
yyh
7471c32612 Revert "temp: remove IS_CLOUD_EDITION guard from supportsCredits for local testing"
This reverts commit ab87ac333a.
2026-03-05 14:33:48 +08:00
yyh
2d333bbbe5 refactor(web): extract credential activation into hook and migrate credential-item overlays
Extract credential switching logic from dropdown-content into a dedicated
useActivateCredential hook with optimistic updates and proper data flow
separation. Credential items now stay visible in the popover after clicking
(no auto-close), show cursor-pointer, and disable during activation.

Migrate credential-item from legacy Tooltip and remixicon imports to
base-ui Tooltip and CSS icon classes, pruning stale ESLint suppressions.
2026-03-05 14:22:39 +08:00
yyh
4af6788ce0 fix(web): wrap Header test in Dialog context for base-ui compatibility 2026-03-05 14:20:35 +08:00
yyh
24b072def9 fix: lint 2026-03-05 14:08:20 +08:00
yyh
909c8c3350 Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-05 13:58:51 +08:00
yyh
80e9c8bee0 refactor(web): make account setting fully controlled with action props 2026-03-05 13:39:36 +08:00
yyh
15b7b304d2 refactor(web): migrate model-modal overlays to base-ui Dialog and AlertDialog
Replace legacy PortalToFollowElem and Confirm with Dialog/AlertDialog
primitives. Remove manual ESC handler and backdrop div — now handled
natively by base-ui. Add backdropProps={{ forceRender: true }} for
correct nested overlay rendering.
2026-03-05 13:33:53 +08:00
yyh
61e2672b59 refactor(web): make provider reset event-driven and scope model invalidation
- remove provider-page lifecycle reset effect and handle reset in explicit tab/close actions
- switch account setting tab state to controlled/uncontrolled pattern without sync effect
- use provider-scoped model list queryKey with exact invalidation in credential and model toggle mutations
- update related tests and mocks for new behavior
2026-03-05 13:28:30 +08:00
yyh
5f4ed4c6f6 refactor(web): replace model provider emitter refresh with jotai state
- add atom-based provider expansion state with reset/prune helpers
- remove event-emitter dependency from model provider refresh flow
- invalidate exact provider model-list query key on refresh
- reset expansion state on model provider page mount/unmount
- update and extend tests for external expansion and query invalidation
- update eslint suppressions to match current code
2026-03-05 13:20:58 +08:00
yyh
4a1032c628 fix(web): remove redundant hover text swap on show models button
Merge the two hover-toggling divs into a single always-visible element
and remove the unused showModelsNum i18n key from all locales.
2026-03-05 13:16:04 +08:00
yyh
423c97a47e code style 2026-03-05 13:09:33 +08:00
yyh
a7e3fb2e33 fix(web): use triangle Warning icon instead of circle error icon
Replace i-ri-error-warning-fill (circle exclamation) with the
Warning component (triangle) for api-fallback and credits-fallback
variants to match Figma design.
2026-03-05 13:07:20 +08:00
yyh
ce34937a1c feat(web): add credits-fallback variant for API Key priority with available credits
When API Key is selected but unavailable/unconfigured and credits are
available, the card now shows "AI credits in use" with a warning icon
instead of "API key required". When both credits are exhausted and no
API key exists, it shows "No available usage" (destructive).

New deriveVariant logic for priority=apiKey:
- !exhausted + !authorized → credits-fallback (was api-required-*)
- exhausted + no credential → no-usage (was api-required-add)
- exhausted + named unauthorized → api-unavailable (unchanged)
2026-03-05 13:02:40 +08:00
yyh
ad9ac6978e fix(web): align alert card width with API key section in dropdown
Change mx-1 (4px) to mx-2 (8px) on CreditsFallbackAlert and
CreditsExhaustedAlert to match ApiKeySection's p-2 (8px) padding,
consistent with Figma design where both sections are 8px from the
dropdown edge.
2026-03-05 12:56:55 +08:00
yyh
57c1ba3543 fix(web): hide divider above empty API keys state in dropdown
Move the border from UsagePrioritySection (always visible) to
ApiKeySection's list variant (only when credentials exist). This
removes the unwanted divider line above the "No API Keys" empty
state card when on the AI Credits tab with no keys configured.
2026-03-05 12:25:11 +08:00
yyh
d7a5af2b9a Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor
# Conflicts:
#	web/app/components/header/account-setting/model-provider-page/index.tsx
2026-03-05 10:46:24 +08:00
yyh
d45edffaa3 fix(web): wire upgrade link to pricing modal and add credits-coin icon
Replace broken HTML string interpolation with Trans component and
useModalContextSelector so "upgrade your plan" opens the pricing modal.
Add custom credits-coin SVG icon to replace the generic ri-coin-line.
2026-03-05 10:39:31 +08:00
yyh
530515b6ef fix(web): prevent model list from expanding on priority switch
Remove UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST event emission from
changePriority onSuccess. This event was designed for custom model
add/edit/delete scenarios where the card should expand, but firing
it on priority switch caused ProviderAddedCard to unexpectedly
expand via refreshModelList → setCollapsed(false).
2026-03-05 10:35:03 +08:00
yyh
f13f0d1f9a fix(web): align dropdown alerts with Figma design and fix hardcoded credits total
- Expose totalCredits from useTrialCredits hook instead of hardcoding 10,000
- Align CreditsExhaustedAlert with Figma: dynamic progress bar, correct
  design tokens (components-progress-error-bg/progress), sm-medium/xs-regular
  typography
- Align CreditsFallbackAlert typography to sm-medium/xs-regular
- Fix ApiKeySection empty state: horizontal gradient, sm-medium title,
  Figma-aligned padding (pl-7 for API KEYS label)
- Hoist empty credentials array constant to stabilize memo (rerender-memo-with-default-value)
- Remove redundant useCallback wrapper in ApiKeySection
- Replace nested ternary with Record lookup in TextLabel
- Remove dead || 0 guard in useTrialCredits
- Update all test mocks with totalCredits field
2026-03-05 10:09:51 +08:00
yyh
b597d52c11 refactor(web): remove dialog description from system model selector
Remove the DialogDescription and its i18n key (modelProvider.systemModelSettingsLink)
from the system model settings dialog across all 23 locales.
2026-03-05 10:05:01 +08:00
yyh
34c42fe666 Revert "temp: remove cloud condition"
This reverts commit 29e344ac8b.
2026-03-05 09:44:19 +08:00
yyh
dc109c99f0 test(web): expand credential panel and dropdown test coverage for all 8 card variants
Add comprehensive behavioral tests covering all discriminated union variants,
destructive/default styling, warning icons, CreditsFallbackAlert conditions,
credential CRUD interactions, AlertDialog delete confirmation, and Popover behavior.
2026-03-05 09:41:48 +08:00
yyh
223b9d89c1 refactor(web): migrate priority change to oRPC contract with useMutation
- Add changePreferredProviderType contract in model-providers.ts
- Register in consoleRouterContract
- Replace raw async changeModelProviderPriority with useMutation
- Use Toast.notify (static API) instead of useToastContext hook
- Pass isPending as isChangingPriority to disable buttons during switch
- Add disabled prop to UsagePrioritySection
- Fix pre-existing test assertions for api-unavailable variant
- Update all specs with isChangingPriority prop and oRPC mock pattern
2026-03-05 09:30:38 +08:00
yyh
dd119eb44f fix(web): align UsagePrioritySection with Figma design and fix i18n key ordering
- Single-row layout for icon, label, and option cards
- Icon: arrow-up-double-line matching design spec
- Buttons: flexible width with whitespace-nowrap instead of fixed w-[72px]
- Add min-w-0 + truncate for text overflow, focus-visible ring for a11y
- Sort modelProvider.card.* i18n keys alphabetically
2026-03-05 09:15:16 +08:00
yyh
970493fa85 test(web): update tests for credential panel refactoring and new ModelAuthDropdown components
Rewrite credential-panel.spec.tsx to match the new discriminated union
state model and variant-driven rendering. Add new test files for
useCredentialPanelState hook, SystemQuotaCard Label enhancement,
and all ModelAuthDropdown sub-components.
2026-03-05 08:41:17 +08:00
yyh
ab87ac333a temp: remove IS_CLOUD_EDITION guard from supportsCredits for local testing 2026-03-05 08:34:10 +08:00
yyh
b8b70da9ad refactor(web): rewrite CredentialPanel with declarative variant-driven state and new ModelAuthDropdown
- Extract useCredentialPanelState hook with discriminated union CardVariant type replacing scattered boolean conditions
- Create ModelAuthDropdown compound component (Popover-based) with UsagePrioritySection, CreditsExhaustedAlert, and ApiKeySection
- Enhance SystemQuotaCard.Label to accept className override for flexible styling
- Add i18n keys for new card states and dropdown content (en-US, zh-Hans)
2026-03-05 08:33:04 +08:00
yyh
77d81aebe8 Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-04 23:35:20 +08:00
yyh
deb4cd3ece fix: i18n 2026-03-04 23:35:13 +08:00
yyh
648d9ef1f9 refactor(web): extract SystemQuotaCard compound component and shared useTrialCredits hook
Extract trial credits calculation into a shared useTrialCredits hook to prevent
logic drift between QuotaPanel and CredentialPanel. Add SystemQuotaCard compound
component with explicit default/destructive variants for the system quota UI
state in provider cards, replacing inline conditional styling with composable
Label and Actions slots. Remove unnecessary useMemo for simple derived values.
2026-03-04 23:30:25 +08:00
yyh
5ed4797078 fix 2026-03-04 22:53:29 +08:00
yyh
62631658e9 fix(web): update tests for AlertDialog migration and component API changes
- Replace deprecated Confirm mock with real AlertDialog role-based queries
- Add useInvalidateCheckInstalled mock for QueryClient dependency
- Wrap model-list-item renders in QueryClientProvider
- Migrate PluginVersionPicker from PortalToFollowElem to Popover
- Migrate UpdatePluginModal from Modal to Dialog
- Update version picker offset props (sideOffset/alignOffset)
2026-03-04 22:52:21 +08:00
yyh
22a4100dd7 fix(web): invalidate plugin checkInstalled cache after version updates 2026-03-04 22:33:17 +08:00
yyh
0f7ed6f67e refactor(web): align provider badges with figma and remove dead add-model-button 2026-03-04 22:29:51 +08:00
yyh
4d9fcbec57 refactor(web): migrate remove-plugin dialog to base UI AlertDialog and improve UX
- Replace deprecated Confirm component with AlertDialog primitives
- Add forceRender backdrop for proper overlay rendering
- Add success Toast notification after plugin removal
- Update "View Detail" text to "View on Marketplace" (en/zh-Hans)
- Add i18n keys for delete success message
- Prune stale eslint suppression for header-modals
2026-03-04 22:14:19 +08:00
yyh
4d7a9bc798 fix(web): align model provider cache invalidation with oRPC keys 2026-03-04 22:06:27 +08:00
yyh
d6d04ed657 fix 2026-03-04 22:03:06 +08:00
yyh
f594a71dae fix: icon 2026-03-04 22:02:36 +08:00
yyh
04e0ab7eda refactor(web): migrate provider-added-card model list to oRPC query-driven state 2026-03-04 21:55:34 +08:00
yyh
784bda9c86 refactor(web): migrate operation-dropdown to base UI and align provider card styles with Figma
- Migrate OperationDropdown from legacy portal-to-follow-elem to base UI DropdownMenu primitives
- Add placement, sideOffset, alignOffset, popupClassName props for flexible positioning
- Fix version badge font size: system-2xs-medium-uppercase (10px) → system-xs-medium-uppercase (12px)
- Set provider card dropdown to bottom-start placement with 192px width per Figma spec
- Fix PluginVersionPicker toggle: clicking badge now opens and closes the picker
- Add max-h-[224px] overflow scroll to version list
- Replace Remix icon imports with Tailwind CSS icon classes
- Prune stale eslint suppressions for migrated files
2026-03-04 21:55:23 +08:00
yyh
1af1fb6913 feat(web): add version badge and actions menu to provider cards
Integrate plugin version management into model provider cards by
reusing existing plugin detail panel hooks and components. Batch
query installed plugins at list level to avoid N+1 requests.
2026-03-04 21:29:52 +08:00
yyh
1f0c36e9f7 fix: style 2026-03-04 21:07:42 +08:00
yyh
455ae65025 fix: style 2026-03-04 20:58:14 +08:00
yyh
d44682e957 refactor(web): align quota panel with Figma design and migrate to base UI tooltip
- Rename title from "Quota" to "AI Credits" and update tooltip copy
  (Message Credits → AI Credits, free → Trial)
- Show "Credits exhausted" in destructive text when credits reach zero
  instead of displaying the number "0"
- Migrate from deprecated Tooltip to base UI Tooltip compound component
- Add 4px grid background with radial fade mask via CSS module
- Simplify provider icon tooltip text for uninstalled state
- Update i18n keys for both en-US and zh-Hans
2026-03-04 20:52:30 +08:00
yyh
8c4afc0c18 fix(model-selector): align empty trigger with default trigger style 2026-03-04 20:14:49 +08:00
yyh
539cbcae6a fix(account-settings): render nested system model backdrop via base ui 2026-03-04 19:57:53 +08:00
yyh
8d257fea7c chore(web): commit dialog overlay follow-up changes 2026-03-04 19:37:10 +08:00
yyh
c3364ac350 refactor(web): align account settings dialogs with base UI 2026-03-04 19:31:14 +08:00
yyh
f991644989 refactor(pricing): migrate to base ui dialog and extract category types 2026-03-04 19:26:54 +08:00
yyh
29e344ac8b temp: remove cloud condition 2026-03-04 18:50:38 +08:00
yyh
1ad9305732 fix(web): avoid quota panel flicker on account-setting tab switch
- remove mount-time workspace invalidate in model provider page

- read quota with useCurrentWorkspace and keep loading only for initial empty fetch

- reuse existing useSystemFeaturesQuery for marketplace and trial models

- update model provider and quota panel tests for new query/loading behavior
2026-03-04 18:43:01 +08:00
yyh
17f38f171d lint 2026-03-04 18:21:59 +08:00
yyh
802088c8eb test(web): fix trivial assertion and add useInvalidateDefaultModel tests
Replace the no-provider test assertion from checking a nonexistent i18n
key to verifying actual warning keys are absent. Add unit tests for
useInvalidateDefaultModel following the useUpdateModelList pattern.
2026-03-04 17:51:20 +08:00
yyh
cad6d94491 refactor(web): replace remixicon imports with Tailwind CSS icons in system-model-selector 2026-03-04 17:45:41 +08:00
yyh
621d0fb2c9 fix 2026-03-04 17:42:34 +08:00
yyh
a92fb3244b fix(web): skip top warning for no-provider state and remove unused i18n key
The empty state card below already prompts users to install a provider,
so the top warning bar is redundant for the no-provider case. Remove
the unused noProviderInstalled i18n key and replace the lookup map with
a ternary to preserve i18n literal types without assertions.
2026-03-04 17:39:49 +08:00
yyh
97508f8d7b fix(web): invalidate default model cache after saving system model settings
After saving system models, only the model list cache was invalidated
but not the default model cache, causing stale config status in the UI.
Add useInvalidateDefaultModel hook and call it for all 5 model types
after a successful save.
2026-03-04 17:26:24 +08:00
yyh
70e677a6ac feat(web): refine system model settings to 4 distinct config states
Replace the single `defaultModelNotConfigured` boolean with a derived
`systemModelConfigStatus` that distinguishes between no-provider,
none-configured, partially-configured, and fully-configured states,
each showing a context-appropriate warning message. Also updates the
button label from "System Model Settings" to "Default Model Settings"
and migrates remixicon imports to Tailwind CSS icon classes.
2026-03-04 16:58:46 +08:00
281 changed files with 7403 additions and 24231 deletions

View File

@@ -25,10 +25,6 @@ updates:
interval: "weekly"
open-pull-requests-limit: 2
groups:
lexical:
patterns:
- "lexical"
- "@lexical/*"
storybook:
patterns:
- "storybook"
@@ -37,7 +33,5 @@ updates:
patterns:
- "*"
exclude-patterns:
- "lexical"
- "@lexical/*"
- "storybook"
- "@storybook/*"

View File

@@ -62,22 +62,6 @@ This is the default standard for backend code in this repo. Follow it for new co
- Code should usually include type annotations that match the repos current Python version (avoid untyped public APIs and “mystery” values).
- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless theres a strong reason.
- For dictionary-like data with known keys and value types, prefer `TypedDict` over `dict[...]` or `Mapping[...]`.
- For optional keys in typed payloads, use `NotRequired[...]` (or `total=False` when most fields are optional).
- Keep `dict[...]` / `Mapping[...]` for truly dynamic key spaces where the key set is unknown.
```python
from datetime import datetime
from typing import NotRequired, TypedDict
class UserProfile(TypedDict):
user_id: str
email: str
created_at: datetime
nickname: NotRequired[str]
```
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
```python

View File

@@ -30,7 +30,6 @@ from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from extensions.storage.opendal_storage import OpenDALStorage
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from libs.db_migration_lock import DbMigrationAutoRenewLock
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
@@ -937,12 +936,6 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
is_flag=True,
help="Preview cleanup results without deleting any workflow run data.",
)
@click.option(
"--task-label",
default="daily",
show_default=True,
help="Stable label value used to distinguish multiple cleanup CronJobs in metrics.",
)
def clean_workflow_runs(
before_days: int,
batch_size: int,
@@ -951,13 +944,10 @@ def clean_workflow_runs(
start_from: datetime.datetime | None,
end_before: datetime.datetime | None,
dry_run: bool,
task_label: str,
):
"""
Clean workflow runs and related workflow data for free tenants.
"""
from extensions.otel.runtime import flush_telemetry
if (start_from is None) ^ (end_before is None):
raise click.UsageError("--start-from and --end-before must be provided together.")
@@ -977,17 +967,13 @@ def clean_workflow_runs(
start_time = datetime.datetime.now(datetime.UTC)
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
try:
WorkflowRunCleanup(
days=before_days,
batch_size=batch_size,
start_from=start_from,
end_before=end_before,
dry_run=dry_run,
task_label=task_label,
).run()
finally:
flush_telemetry()
WorkflowRunCleanup(
days=before_days,
batch_size=batch_size,
start_from=start_from,
end_before=end_before,
dry_run=dry_run,
).run()
end_time = datetime.datetime.now(datetime.UTC)
elapsed = end_time - start_time
@@ -2612,29 +2598,15 @@ def migrate_oss(
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=False,
default=None,
required=True,
help="Lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=False,
default=None,
required=True,
help="Upper bound (exclusive) for created_at.",
)
@click.option(
"--from-days-ago",
type=int,
default=None,
help="Relative lower bound in days ago (inclusive). Must be used with --before-days.",
)
@click.option(
"--before-days",
type=int,
default=None,
help="Relative upper bound in days ago (exclusive). Required for relative mode.",
)
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
@click.option(
"--graceful-period",
@@ -2643,99 +2615,33 @@ def migrate_oss(
help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.",
)
@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting")
@click.option(
"--task-label",
default="daily",
show_default=True,
help="Stable label value used to distinguish multiple cleanup CronJobs in metrics.",
)
def clean_expired_messages(
batch_size: int,
graceful_period: int,
start_from: datetime.datetime | None,
end_before: datetime.datetime | None,
from_days_ago: int | None,
before_days: int | None,
start_from: datetime.datetime,
end_before: datetime.datetime,
dry_run: bool,
task_label: str,
):
"""
Clean expired messages and related data for tenants based on clean policy.
"""
from extensions.otel.runtime import flush_telemetry
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
try:
abs_mode = start_from is not None and end_before is not None
rel_mode = before_days is not None
if abs_mode and rel_mode:
raise click.UsageError(
"Options are mutually exclusive: use either (--start-from,--end-before) "
"or (--from-days-ago,--before-days)."
)
if from_days_ago is not None and before_days is None:
raise click.UsageError("--from-days-ago must be used together with --before-days.")
if (start_from is None) ^ (end_before is None):
raise click.UsageError("Both --start-from and --end-before are required when using absolute time range.")
if not abs_mode and not rel_mode:
raise click.UsageError(
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])."
)
if rel_mode:
assert before_days is not None
if before_days < 0:
raise click.UsageError("--before-days must be >= 0.")
if from_days_ago is not None:
if from_days_ago < 0:
raise click.UsageError("--from-days-ago must be >= 0.")
if from_days_ago <= before_days:
raise click.UsageError("--from-days-ago must be greater than --before-days.")
# Create policy based on billing configuration
# NOTE: graceful_period will be ignored when billing is disabled.
policy = create_message_clean_policy(graceful_period_days=graceful_period)
# Create and run the cleanup service
if abs_mode:
assert start_from is not None
assert end_before is not None
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
elif from_days_ago is None:
assert before_days is not None
service = MessagesCleanService.from_days(
policy=policy,
days=before_days,
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
else:
assert before_days is not None
assert from_days_ago is not None
now = naive_utc_now()
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=now - datetime.timedelta(days=from_days_ago),
end_before=now - datetime.timedelta(days=before_days),
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
stats = service.run()
end_at = time.perf_counter()
@@ -2760,81 +2666,5 @@ def clean_expired_messages(
)
)
raise
finally:
flush_telemetry()
click.echo(click.style("messages cleanup completed.", fg="green"))
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
@click.option("--app-id", required=True, help="Application ID to export messages for.")
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
help="Upper bound (exclusive) for created_at.",
)
@click.option(
"--filename",
required=True,
help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
)
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
def export_app_messages(
app_id: str,
start_from: datetime.datetime | None,
end_before: datetime.datetime,
filename: str,
use_cloud_storage: bool,
batch_size: int,
dry_run: bool,
):
if start_from and start_from >= end_before:
raise click.UsageError("--start-from must be before --end-before.")
from services.retention.conversation.message_export_service import AppMessageExportService
try:
validated_filename = AppMessageExportService.validate_export_filename(filename)
except ValueError as e:
raise click.BadParameter(str(e), param_hint="--filename") from e
click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
start_at = time.perf_counter()
try:
service = AppMessageExportService(
app_id=app_id,
end_before=end_before,
filename=validated_filename,
start_from=start_from,
batch_size=batch_size,
use_cloud_storage=use_cloud_storage,
dry_run=dry_run,
)
stats = service.run()
elapsed = time.perf_counter() - start_at
click.echo(
click.style(
f"export_app_messages: completed in {elapsed:.2f}s\n"
f" - Batches: {stats.batches}\n"
f" - Total messages: {stats.total_messages}\n"
f" - Messages with feedback: {stats.messages_with_feedback}\n"
f" - Total feedbacks: {stats.total_feedbacks}",
fg="green",
)
)
except Exception as e:
elapsed = time.perf_counter() - start_at
logger.exception("export_app_messages failed")
click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
raise

View File

@@ -44,13 +44,14 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.app.task_pipeline.message_file_utils import prepare_file_dict
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
@@ -459,40 +460,91 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
"""
self._task_state.metadata.usage = self._task_state.llm_result.usage
metadata_dict = self._task_state.metadata.model_dump()
# Fetch files associated with this message
files = None
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if message_files:
# Fetch all required UploadFile objects in a single query to avoid N+1 problem
upload_file_ids = list(
dict.fromkeys(
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
)
)
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}
files_list = []
for message_file in message_files:
file_dict = prepare_file_dict(message_file, upload_files_map)
files_list.append(file_dict)
files = files_list or None
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
metadata=metadata_dict,
files=files,
)
def _record_files(self):
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if not message_files:
return None
files_list = []
upload_file_ids = [
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
]
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}
for message_file in message_files:
upload_file = None
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
upload_file = upload_files_map.get(message_file.upload_file_id)
url = None
filename = "file"
mime_type = "application/octet-stream"
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
# Fallback: generate URL even if upload_file not found
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
# For tool files, use URL directly if it's HTTP, otherwise sign it
if message_file.url.startswith("http"):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
else:
# Extract tool file id and extension from URL
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0] # Remove query params first
# Use rsplit to correctly handle filenames with multiple dots
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
transfer_method_value = message_file.transfer_method
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
file_dict = {
"related_id": message_file.id,
"extension": extension,
"filename": filename,
"size": size,
"mime_type": mime_type,
"transfer_method": transfer_method_value,
"type": message_file.type,
"url": url or "",
"upload_file_id": message_file.upload_file_id or message_file.id,
"remote_url": remote_url,
}
files_list.append(file_dict)
return files_list or None
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
"""
Agent message to stream response.

View File

@@ -1,6 +1,7 @@
import hashlib
import logging
from threading import Thread, Timer
import time
from threading import Thread
from typing import Union
from flask import Flask, current_app
@@ -95,9 +96,9 @@ class MessageCycleManager:
if auto_generate_conversation_name and is_first_message:
# start generate thread
# time.sleep not block other logic
thread = Timer(
1,
self._generate_conversation_name_worker,
time.sleep(1)
thread = Thread(
target=self._generate_conversation_name_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"conversation_id": conversation_id,

View File

@@ -1,76 +0,0 @@
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
from models.model import MessageFile, UploadFile
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
"""
Prepare file dictionary for message end stream response.
:param message_file: MessageFile instance
:param upload_files_map: Dictionary mapping upload_file_id to UploadFile
:return: Dictionary containing file information
"""
upload_file = None
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
upload_file = upload_files_map.get(message_file.upload_file_id)
url = None
filename = "file"
mime_type = "application/octet-stream"
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
if message_file.url.startswith(("http://", "https://")):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
else:
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0]
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
extension = ".bin"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
transfer_method_value = message_file.transfer_method.value
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
return {
"related_id": message_file.id,
"extension": extension,
"filename": filename,
"size": size,
"mime_type": mime_type,
"transfer_method": transfer_method_value,
"type": message_file.type,
"url": url or "",
"upload_file_id": message_file.upload_file_id or message_file.id,
"remote_url": remote_url,
}

View File

@@ -194,13 +194,6 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
# Create a new database session
with self._session_factory() as session:
existing_model = session.get(WorkflowRun, db_model.id)
if existing_model:
if existing_model.tenant_id != self._tenant_id:
raise ValueError("Unauthorized access to workflow run")
# Preserve the original start time for pause/resume flows.
db_model.created_at = existing_model.created_at
# SQLAlchemy merge intelligently handles both insert and update operations
# based on the presence of the primary key
session.merge(db_model)

View File

@@ -37,7 +37,6 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.CHECKBOX: ToolParameter.ToolParameterType.BOOLEAN,
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
VariableEntityType.JSON_OBJECT: ToolParameter.ToolParameterType.OBJECT,
}

View File

@@ -4,7 +4,6 @@ import json
import logging
import os
import tempfile
import zipfile
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
@@ -83,18 +82,8 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
value = variable.value
inputs = {"variable_selector": variable_selector}
if isinstance(value, list):
value = list(filter(lambda x: x, value))
process_data = {"documents": value if isinstance(value, list) else [value]}
if not value:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs={"text": ArrayStringSegment(value=[])},
)
try:
if isinstance(value, list):
extracted_text_list = [
@@ -122,7 +111,6 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
else:
raise DocumentExtractorError(f"Unsupported variable type: {type(value)}")
except DocumentExtractorError as e:
logger.warning(e, exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
@@ -397,32 +385,6 @@ def parser_docx_part(block, doc: Document, content_items, i):
content_items.append((i, "table", Table(block, doc)))
def _normalize_docx_zip(file_content: bytes) -> bytes:
"""
Some DOCX files (e.g. exported by Evernote on Windows) are malformed:
ZIP entry names use backslash (\\) as path separator instead of the forward
slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry
"word\\document.xml" is never found when python-docx looks for
"word/document.xml", which triggers a KeyError about a missing relationship.
This function rewrites the ZIP in-memory, normalizing all entry names to
use forward slashes without touching any actual document content.
"""
try:
with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin:
out_buf = io.BytesIO()
with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout:
for item in zin.infolist():
data = zin.read(item.filename)
# Normalize backslash path separators to forward slash
item.filename = item.filename.replace("\\", "/")
zout.writestr(item, data)
return out_buf.getvalue()
except zipfile.BadZipFile:
# Not a valid zip — return as-is and let python-docx report the real error
return file_content
def _extract_text_from_docx(file_content: bytes) -> str:
"""
Extract text from a DOCX file.
@@ -430,15 +392,7 @@ def _extract_text_from_docx(file_content: bytes) -> str:
"""
try:
doc_file = io.BytesIO(file_content)
try:
doc = docx.Document(doc_file)
except Exception as e:
logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e)
# Some DOCX files exported by tools like Evernote on Windows use
# backslash path separators in ZIP entries and/or single-quoted XML
# attributes, both of which break python-docx on Linux. Normalize and retry.
file_content = _normalize_docx_zip(file_content)
doc = docx.Document(io.BytesIO(file_content))
doc = docx.Document(doc_file)
text = []
# Keep track of paragraph and table positions

View File

@@ -23,11 +23,7 @@ from dify_graph.variables import (
)
from dify_graph.variables.segments import ArrayObjectSegment
from .entities import (
Condition,
KnowledgeRetrievalNodeData,
MetadataFilteringCondition,
)
from .entities import KnowledgeRetrievalNodeData
from .exc import (
KnowledgeRetrievalNodeError,
RateLimitExceededError,
@@ -175,12 +171,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if node_data.metadata_filtering_mode is not None:
metadata_filtering_mode = node_data.metadata_filtering_mode
resolved_metadata_conditions = (
self._resolve_metadata_filtering_conditions(node_data.metadata_filtering_conditions)
if node_data.metadata_filtering_conditions
else None
)
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
# fetch model config
if node_data.single_retrieval_config is None:
@@ -199,7 +189,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
model_mode=model.mode,
model_name=model.name,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=resolved_metadata_conditions,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_mode=metadata_filtering_mode,
query=query,
)
@@ -257,7 +247,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=resolved_metadata_conditions,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_mode=metadata_filtering_mode,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
)
@@ -266,48 +256,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
usage = self._rag_retrieval.llm_usage
return retrieval_resource_list, usage
def _resolve_metadata_filtering_conditions(
self, conditions: MetadataFilteringCondition
) -> MetadataFilteringCondition:
if conditions.conditions is None:
return MetadataFilteringCondition(
logical_operator=conditions.logical_operator,
conditions=None,
)
variable_pool = self.graph_runtime_state.variable_pool
resolved_conditions: list[Condition] = []
for cond in conditions.conditions or []:
value = cond.value
if isinstance(value, str):
segment_group = variable_pool.convert_template(value)
if len(segment_group.value) == 1:
resolved_value = segment_group.value[0].to_object()
else:
resolved_value = segment_group.text
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
resolved_values = []
for v in value: # type: ignore
segment_group = variable_pool.convert_template(v)
if len(segment_group.value) == 1:
resolved_values.append(segment_group.value[0].to_object())
else:
resolved_values.append(segment_group.text)
resolved_value = resolved_values
else:
resolved_value = value
resolved_conditions.append(
Condition(
name=cond.name,
comparison_operator=cond.comparison_operator,
value=resolved_value,
)
)
return MetadataFilteringCondition(
logical_operator=conditions.logical_operator or "and",
conditions=resolved_conditions,
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@@ -13,7 +13,6 @@ def init_app(app: DifyApp):
convert_to_agent_apps,
create_tenant,
delete_archived_workflow_runs,
export_app_messages,
extract_plugins,
extract_unique_plugins,
file_usage,
@@ -67,7 +66,6 @@ def init_app(app: DifyApp):
restore_workflow_runs,
clean_workflow_runs,
clean_expired_messages,
export_app_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@@ -5,7 +5,7 @@ from typing import Union
from celery.signals import worker_init
from flask_login import user_loaded_from_request, user_logged_in
from opentelemetry import metrics, trace
from opentelemetry import trace
from opentelemetry.propagate import set_global_textmap
from opentelemetry.propagators.b3 import B3Format
from opentelemetry.propagators.composite import CompositePropagator
@@ -31,29 +31,9 @@ def setup_context_propagation() -> None:
def shutdown_tracer() -> None:
flush_telemetry()
def flush_telemetry() -> None:
"""
Best-effort flush for telemetry providers.
This is mainly used by short-lived command processes (e.g. Kubernetes CronJob)
so counters/histograms are exported before the process exits.
"""
provider = trace.get_tracer_provider()
if hasattr(provider, "force_flush"):
try:
provider.force_flush()
except Exception:
logger.exception("otel: failed to flush trace provider")
metric_provider = metrics.get_meter_provider()
if hasattr(metric_provider, "force_flush"):
try:
metric_provider.force_flush()
except Exception:
logger.exception("otel: failed to flush metric provider")
provider.force_flush()
def is_celery_worker():

View File

@@ -66,7 +66,6 @@ def run_migrations_offline():
context.configure(
url=url, target_metadata=get_metadata(), literal_binds=True
)
logger.info("Generating offline migration SQL with url: %s", url)
with context.begin_transaction():
context.run_migrations()

View File

@@ -1,6 +1,5 @@
[pytest]
pythonpath = .
addopts = --cov=./api --cov-report=json --import-mode=importlib
addopts = --cov=./api --cov-report=json
env =
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
@@ -20,7 +19,7 @@ env =
GOOGLE_API_KEY = abcdefghijklmnopqrstuvwxyz
HUGGINGFACE_API_KEY = hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = c
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL = a
MIXEDBREAD_API_KEY = mk-aaaaaaaaaaaaaaaaaaaa
MOCK_SWITCH = true

View File

@@ -63,12 +63,7 @@ class RagPipelineTransformService:
):
node = self._deal_file_extensions(node)
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
if dataset.tenant_id != current_user.current_tenant_id:
raise ValueError("Unauthorized")
node = self._deal_knowledge_index(
knowledge_configuration, dataset, indexing_technique, retrieval_model, node
)
node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
new_nodes.append(node)
if new_nodes:
graph["nodes"] = new_nodes
@@ -160,13 +155,14 @@ class RagPipelineTransformService:
def _deal_knowledge_index(
self,
knowledge_configuration: KnowledgeConfiguration,
dataset: Dataset,
doc_form: str,
indexing_technique: str | None,
retrieval_model: RetrievalSetting | None,
node: dict,
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
if indexing_technique == "high_quality":
knowledge_configuration.embedding_model = dataset.embedding_model

View File

@@ -1,304 +0,0 @@
"""
Export app messages to JSONL.GZ format.
Outputs: conversation_id, message_id, query, answer, inputs (raw JSON),
retriever_resources (from message_metadata), feedback (user feedbacks array).
Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1.
Does NOT touch Message.inputs / Message.user_feedback properties.
"""
import datetime
import gzip
import json
import logging
import tempfile
from collections import defaultdict
from collections.abc import Generator, Iterable
from pathlib import Path, PurePosixPath
from typing import Any, BinaryIO, cast
import orjson
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select, tuple_
from sqlalchemy.orm import Session
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import Message, MessageFeedback
logger = logging.getLogger(__name__)
MAX_FILENAME_BASE_LENGTH = 1024
FORBIDDEN_FILENAME_SUFFIXES = (".jsonl.gz", ".jsonl", ".gz")
class AppMessageExportFeedback(BaseModel):
id: str
app_id: str
conversation_id: str
message_id: str
rating: str
content: str | None = None
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
created_at: str
updated_at: str
model_config = ConfigDict(extra="forbid")
class AppMessageExportRecord(BaseModel):
conversation_id: str
message_id: str
query: str
answer: str
inputs: dict[str, Any]
retriever_resources: list[Any] = Field(default_factory=list)
feedback: list[AppMessageExportFeedback] = Field(default_factory=list)
model_config = ConfigDict(extra="forbid")
class AppMessageExportStats(BaseModel):
batches: int = 0
total_messages: int = 0
messages_with_feedback: int = 0
total_feedbacks: int = 0
model_config = ConfigDict(extra="forbid")
class AppMessageExportService:
@staticmethod
def validate_export_filename(filename: str) -> str:
normalized = filename.strip()
if not normalized:
raise ValueError("--filename must not be empty.")
normalized_lower = normalized.lower()
if normalized_lower.endswith(FORBIDDEN_FILENAME_SUFFIXES):
raise ValueError("--filename must not include .jsonl.gz/.jsonl/.gz suffix; pass base filename only.")
if normalized.startswith("/"):
raise ValueError("--filename must be a relative path; absolute paths are not allowed.")
if "\\" in normalized:
raise ValueError("--filename must use '/' as path separator; '\\' is not allowed.")
if "//" in normalized:
raise ValueError("--filename must not contain empty path segments ('//').")
if len(normalized) > MAX_FILENAME_BASE_LENGTH:
raise ValueError(f"--filename is too long; max length is {MAX_FILENAME_BASE_LENGTH}.")
for ch in normalized:
if ch == "\x00" or ord(ch) < 32 or ord(ch) == 127:
raise ValueError("--filename must not contain control characters or NUL.")
parts = PurePosixPath(normalized).parts
if not parts:
raise ValueError("--filename must include a file name.")
if any(part in (".", "..") for part in parts):
raise ValueError("--filename must not contain '.' or '..' path segments.")
return normalized
@property
def output_gz_name(self) -> str:
return f"{self._filename_base}.jsonl.gz"
@property
def output_jsonl_name(self) -> str:
return f"{self._filename_base}.jsonl"
def __init__(
self,
app_id: str,
end_before: datetime.datetime,
filename: str,
*,
start_from: datetime.datetime | None = None,
batch_size: int = 1000,
use_cloud_storage: bool = False,
dry_run: bool = False,
) -> None:
if start_from and start_from >= end_before:
raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})")
self._app_id = app_id
self._end_before = end_before
self._start_from = start_from
self._filename_base = self.validate_export_filename(filename)
self._batch_size = batch_size
self._use_cloud_storage = use_cloud_storage
self._dry_run = dry_run
def run(self) -> AppMessageExportStats:
stats = AppMessageExportStats()
logger.info(
"export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s, output_gz=%s",
self._app_id,
self._start_from,
self._end_before,
self._dry_run,
self._use_cloud_storage,
self.output_gz_name,
)
if self._dry_run:
for _ in self._iter_records_with_stats(stats):
pass
self._finalize_stats(stats)
return stats
if self._use_cloud_storage:
self._export_to_cloud(stats)
else:
self._export_to_local(stats)
self._finalize_stats(stats)
return stats
def iter_records(self) -> Generator[AppMessageExportRecord, None, None]:
for batch in self._iter_record_batches():
yield from batch
@staticmethod
def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None:
with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz:
for record in records:
gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n")
def _export_to_local(self, stats: AppMessageExportStats) -> None:
output_path = Path.cwd() / self.output_gz_name
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("wb") as output_file:
self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file)
def _export_to_cloud(self, stats: AppMessageExportStats) -> None:
with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp:
self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp))
tmp.seek(0)
data = tmp.read()
storage.save(self.output_gz_name, data)
logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self.output_gz_name)
def _iter_records_with_stats(self, stats: AppMessageExportStats) -> Generator[AppMessageExportRecord, None, None]:
for record in self.iter_records():
self._update_stats(stats, record)
yield record
@staticmethod
def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None:
stats.total_messages += 1
if record.feedback:
stats.messages_with_feedback += 1
stats.total_feedbacks += len(record.feedback)
def _finalize_stats(self, stats: AppMessageExportStats) -> None:
if stats.total_messages == 0:
stats.batches = 0
return
stats.batches = (stats.total_messages + self._batch_size - 1) // self._batch_size
def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]:
cursor: tuple[datetime.datetime, str] | None = None
while True:
rows, cursor = self._fetch_batch(cursor)
if not rows:
break
message_ids = [str(row.id) for row in rows]
feedbacks_map = self._fetch_feedbacks(message_ids)
yield [self._build_record(row, feedbacks_map) for row in rows]
def _fetch_batch(
self, cursor: tuple[datetime.datetime, str] | None
) -> tuple[list[Any], tuple[datetime.datetime, str] | None]:
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(
Message.id,
Message.conversation_id,
Message.query,
Message.answer,
Message._inputs, # pyright: ignore[reportPrivateUsage]
Message.message_metadata,
Message.created_at,
)
.where(
Message.app_id == self._app_id,
Message.created_at < self._end_before,
)
.order_by(Message.created_at, Message.id)
.limit(self._batch_size)
)
if self._start_from:
stmt = stmt.where(Message.created_at >= self._start_from)
if cursor:
stmt = stmt.where(
tuple_(Message.created_at, Message.id)
> tuple_(
sa.literal(cursor[0], type_=sa.DateTime()),
sa.literal(cursor[1], type_=Message.id.type),
)
)
rows = list(session.execute(stmt).all())
if not rows:
return [], cursor
last = rows[-1]
return rows, (last.created_at, last.id)
def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]:
if not message_ids:
return {}
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(MessageFeedback)
.where(
MessageFeedback.message_id.in_(message_ids),
MessageFeedback.from_source == "user",
)
.order_by(MessageFeedback.message_id, MessageFeedback.created_at)
)
feedbacks = list(session.scalars(stmt).all())
result: dict[str, list[AppMessageExportFeedback]] = defaultdict(list)
for feedback in feedbacks:
result[str(feedback.message_id)].append(AppMessageExportFeedback.model_validate(feedback.to_dict()))
return result
@staticmethod
def _build_record(row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]) -> AppMessageExportRecord:
retriever_resources: list[Any] = []
if row.message_metadata:
try:
metadata = json.loads(row.message_metadata)
value = metadata.get("retriever_resources", [])
if isinstance(value, list):
retriever_resources = value
except (json.JSONDecodeError, TypeError):
pass
message_id = str(row.id)
return AppMessageExportRecord(
conversation_id=str(row.conversation_id),
message_id=message_id,
query=row.query,
answer=row.answer,
inputs=row._inputs if isinstance(row._inputs, dict) else {},
retriever_resources=retriever_resources,
feedback=feedbacks_map.get(message_id, []),
)

View File

@@ -1,18 +1,17 @@
import datetime
import logging
import os
import random
import time
from collections.abc import Sequence
from typing import TYPE_CHECKING, cast
from typing import cast
import sqlalchemy as sa
from sqlalchemy import delete, select, tuple_
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from configs import dify_config
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import (
App,
AppAnnotationHitHistory,
@@ -33,128 +32,6 @@ from services.retention.conversation.messages_clean_policy import (
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from opentelemetry.metrics import Counter, Histogram
class MessagesCleanupMetrics:
"""
Records low-cardinality OpenTelemetry metrics for expired message cleanup jobs.
We keep labels stable (dry_run/window_mode/task_label/status) so these metrics remain
dashboard-friendly for long-running CronJob executions.
"""
_job_runs_total: "Counter | None"
_batches_total: "Counter | None"
_messages_scanned_total: "Counter | None"
_messages_filtered_total: "Counter | None"
_messages_deleted_total: "Counter | None"
_job_duration_seconds: "Histogram | None"
_batch_duration_seconds: "Histogram | None"
_base_attributes: dict[str, str]
def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None:
self._job_runs_total = None
self._batches_total = None
self._messages_scanned_total = None
self._messages_filtered_total = None
self._messages_deleted_total = None
self._job_duration_seconds = None
self._batch_duration_seconds = None
self._base_attributes = {
"job_name": "messages_cleanup",
"dry_run": str(dry_run).lower(),
"window_mode": "between" if has_window else "before_cutoff",
"task_label": task_label,
}
self._init_instruments()
def _init_instruments(self) -> None:
try:
from opentelemetry.metrics import get_meter
meter = get_meter("messages_cleanup", version=dify_config.project.version)
self._job_runs_total = meter.create_counter(
"messages_cleanup_jobs_total",
description="Total number of expired message cleanup jobs by status.",
unit="{job}",
)
self._batches_total = meter.create_counter(
"messages_cleanup_batches_total",
description="Total number of message cleanup batches processed.",
unit="{batch}",
)
self._messages_scanned_total = meter.create_counter(
"messages_cleanup_scanned_messages_total",
description="Total messages scanned by cleanup jobs.",
unit="{message}",
)
self._messages_filtered_total = meter.create_counter(
"messages_cleanup_filtered_messages_total",
description="Total messages selected by cleanup policy.",
unit="{message}",
)
self._messages_deleted_total = meter.create_counter(
"messages_cleanup_deleted_messages_total",
description="Total messages deleted by cleanup jobs.",
unit="{message}",
)
self._job_duration_seconds = meter.create_histogram(
"messages_cleanup_job_duration_seconds",
description="Duration of expired message cleanup jobs in seconds.",
unit="s",
)
self._batch_duration_seconds = meter.create_histogram(
"messages_cleanup_batch_duration_seconds",
description="Duration of expired message cleanup batch processing in seconds.",
unit="s",
)
except Exception:
logger.exception("messages_cleanup_metrics: failed to initialize instruments")
def _attrs(self, **extra: str) -> dict[str, str]:
return {**self._base_attributes, **extra}
@staticmethod
def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None:
if not counter or value <= 0:
return
try:
counter.add(value, attributes)
except Exception:
logger.exception("messages_cleanup_metrics: failed to add counter value")
@staticmethod
def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None:
if not histogram:
return
try:
histogram.record(value, attributes)
except Exception:
logger.exception("messages_cleanup_metrics: failed to record histogram value")
def record_batch(
self,
*,
scanned_messages: int,
filtered_messages: int,
deleted_messages: int,
batch_duration_seconds: float,
) -> None:
attributes = self._attrs()
self._add(self._batches_total, 1, attributes)
self._add(self._messages_scanned_total, scanned_messages, attributes)
self._add(self._messages_filtered_total, filtered_messages, attributes)
self._add(self._messages_deleted_total, deleted_messages, attributes)
self._record(self._batch_duration_seconds, batch_duration_seconds, attributes)
def record_completion(self, *, status: str, job_duration_seconds: float) -> None:
attributes = self._attrs(status=status)
self._add(self._job_runs_total, 1, attributes)
self._record(self._job_duration_seconds, job_duration_seconds, attributes)
class MessagesCleanService:
"""
Service for cleaning expired messages based on retention policies.
@@ -170,7 +47,6 @@ class MessagesCleanService:
start_from: datetime.datetime | None = None,
batch_size: int = 1000,
dry_run: bool = False,
task_label: str = "daily",
) -> None:
"""
Initialize the service with cleanup parameters.
@@ -181,20 +57,12 @@ class MessagesCleanService:
start_from: Optional start time (inclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
task_label: Stable task label to distinguish multiple cleanup CronJobs
"""
self._policy = policy
self._end_before = end_before
self._start_from = start_from
self._batch_size = batch_size
self._dry_run = dry_run
normalized_task_label = task_label.strip()
self._task_label = normalized_task_label or "daily"
self._metrics = MessagesCleanupMetrics(
dry_run=dry_run,
has_window=bool(start_from),
task_label=self._task_label,
)
@classmethod
def from_time_range(
@@ -204,7 +72,6 @@ class MessagesCleanService:
end_before: datetime.datetime,
batch_size: int = 1000,
dry_run: bool = False,
task_label: str = "daily",
) -> "MessagesCleanService":
"""
Create a service instance for cleaning messages within a specific time range.
@@ -217,7 +84,6 @@ class MessagesCleanService:
end_before: End time (exclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
task_label: Stable task label to distinguish multiple cleanup CronJobs
Returns:
MessagesCleanService instance
@@ -245,7 +111,6 @@ class MessagesCleanService:
start_from=start_from,
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
@classmethod
@@ -255,7 +120,6 @@ class MessagesCleanService:
days: int = 30,
batch_size: int = 1000,
dry_run: bool = False,
task_label: str = "daily",
) -> "MessagesCleanService":
"""
Create a service instance for cleaning messages older than specified days.
@@ -265,7 +129,6 @@ class MessagesCleanService:
days: Number of days to look back from now
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
task_label: Stable task label to distinguish multiple cleanup CronJobs
Returns:
MessagesCleanService instance
@@ -279,7 +142,7 @@ class MessagesCleanService:
if batch_size <= 0:
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
end_before = naive_utc_now() - datetime.timedelta(days=days)
end_before = datetime.datetime.now() - datetime.timedelta(days=days)
logger.info(
"clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",
@@ -289,14 +152,7 @@ class MessagesCleanService:
policy.__class__.__name__,
)
return cls(
policy=policy,
end_before=end_before,
start_from=None,
batch_size=batch_size,
dry_run=dry_run,
task_label=task_label,
)
return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
def run(self) -> dict[str, int]:
"""
@@ -305,18 +161,7 @@ class MessagesCleanService:
Returns:
Dict with statistics: batches, filtered_messages, total_deleted
"""
status = "success"
run_start = time.monotonic()
try:
return self._clean_messages_by_time_range()
except Exception:
status = "failed"
raise
finally:
self._metrics.record_completion(
status=status,
job_duration_seconds=time.monotonic() - run_start,
)
return self._clean_messages_by_time_range()
def _clean_messages_by_time_range(self) -> dict[str, int]:
"""
@@ -351,14 +196,11 @@ class MessagesCleanService:
self._end_before,
)
max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL
max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200))
while True:
stats["batches"] += 1
batch_start = time.monotonic()
batch_scanned_messages = 0
batch_filtered_messages = 0
batch_deleted_messages = 0
# Step 1: Fetch a batch of messages using cursor
with Session(db.engine, expire_on_commit=False) as session:
@@ -397,16 +239,9 @@ class MessagesCleanService:
# Track total messages fetched across all batches
stats["total_messages"] += len(messages)
batch_scanned_messages = len(messages)
if not messages:
logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
self._metrics.record_batch(
scanned_messages=batch_scanned_messages,
filtered_messages=batch_filtered_messages,
deleted_messages=batch_deleted_messages,
batch_duration_seconds=time.monotonic() - batch_start,
)
break
# Update cursor to the last message's (created_at, id)
@@ -432,12 +267,6 @@ class MessagesCleanService:
if not apps:
logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
self._metrics.record_batch(
scanned_messages=batch_scanned_messages,
filtered_messages=batch_filtered_messages,
deleted_messages=batch_deleted_messages,
batch_duration_seconds=time.monotonic() - batch_start,
)
continue
# Build app_id -> tenant_id mapping
@@ -456,16 +285,9 @@ class MessagesCleanService:
if not message_ids_to_delete:
logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
self._metrics.record_batch(
scanned_messages=batch_scanned_messages,
filtered_messages=batch_filtered_messages,
deleted_messages=batch_deleted_messages,
batch_duration_seconds=time.monotonic() - batch_start,
)
continue
stats["filtered_messages"] += len(message_ids_to_delete)
batch_filtered_messages = len(message_ids_to_delete)
# Step 4: Batch delete messages and their relations
if not self._dry_run:
@@ -486,7 +308,6 @@ class MessagesCleanService:
commit_ms = int((time.monotonic() - commit_start) * 1000)
stats["total_deleted"] += messages_deleted
batch_deleted_messages = messages_deleted
logger.info(
"clean_messages (batch %s): processed %s messages, deleted %s messages",
@@ -521,13 +342,6 @@ class MessagesCleanService:
for msg_id in sampled_ids:
logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
self._metrics.record_batch(
scanned_messages=batch_scanned_messages,
filtered_messages=batch_filtered_messages,
deleted_messages=batch_deleted_messages,
batch_duration_seconds=time.monotonic() - batch_start,
)
logger.info(
"clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
stats["batches"],

View File

@@ -1,9 +1,9 @@
import datetime
import logging
import os
import random
import time
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
import click
from sqlalchemy.orm import Session, sessionmaker
@@ -20,156 +20,6 @@ from services.billing_service import BillingService, SubscriptionPlan
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from opentelemetry.metrics import Counter, Histogram
class WorkflowRunCleanupMetrics:
"""
Records low-cardinality OpenTelemetry metrics for workflow run cleanup jobs.
Metrics are emitted with stable labels only (dry_run/window_mode/task_label/status)
to keep dashboard and alert cardinality predictable in production clusters.
"""
_job_runs_total: "Counter | None"
_batches_total: "Counter | None"
_runs_scanned_total: "Counter | None"
_runs_targeted_total: "Counter | None"
_runs_deleted_total: "Counter | None"
_runs_skipped_total: "Counter | None"
_related_records_total: "Counter | None"
_job_duration_seconds: "Histogram | None"
_batch_duration_seconds: "Histogram | None"
_base_attributes: dict[str, str]
def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None:
self._job_runs_total = None
self._batches_total = None
self._runs_scanned_total = None
self._runs_targeted_total = None
self._runs_deleted_total = None
self._runs_skipped_total = None
self._related_records_total = None
self._job_duration_seconds = None
self._batch_duration_seconds = None
self._base_attributes = {
"job_name": "workflow_run_cleanup",
"dry_run": str(dry_run).lower(),
"window_mode": "between" if has_window else "before_cutoff",
"task_label": task_label,
}
self._init_instruments()
def _init_instruments(self) -> None:
try:
from opentelemetry.metrics import get_meter
meter = get_meter("workflow_run_cleanup", version=dify_config.project.version)
self._job_runs_total = meter.create_counter(
"workflow_run_cleanup_jobs_total",
description="Total number of workflow run cleanup jobs by status.",
unit="{job}",
)
self._batches_total = meter.create_counter(
"workflow_run_cleanup_batches_total",
description="Total number of processed cleanup batches.",
unit="{batch}",
)
self._runs_scanned_total = meter.create_counter(
"workflow_run_cleanup_scanned_runs_total",
description="Total workflow runs scanned by cleanup jobs.",
unit="{run}",
)
self._runs_targeted_total = meter.create_counter(
"workflow_run_cleanup_targeted_runs_total",
description="Total workflow runs targeted by cleanup policy.",
unit="{run}",
)
self._runs_deleted_total = meter.create_counter(
"workflow_run_cleanup_deleted_runs_total",
description="Total workflow runs deleted by cleanup jobs.",
unit="{run}",
)
self._runs_skipped_total = meter.create_counter(
"workflow_run_cleanup_skipped_runs_total",
description="Total workflow runs skipped because tenant is paid/unknown.",
unit="{run}",
)
self._related_records_total = meter.create_counter(
"workflow_run_cleanup_related_records_total",
description="Total related records processed by cleanup jobs.",
unit="{record}",
)
self._job_duration_seconds = meter.create_histogram(
"workflow_run_cleanup_job_duration_seconds",
description="Duration of workflow run cleanup jobs in seconds.",
unit="s",
)
self._batch_duration_seconds = meter.create_histogram(
"workflow_run_cleanup_batch_duration_seconds",
description="Duration of workflow run cleanup batch processing in seconds.",
unit="s",
)
except Exception:
logger.exception("workflow_run_cleanup_metrics: failed to initialize instruments")
def _attrs(self, **extra: str) -> dict[str, str]:
return {**self._base_attributes, **extra}
@staticmethod
def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None:
if not counter or value <= 0:
return
try:
counter.add(value, attributes)
except Exception:
logger.exception("workflow_run_cleanup_metrics: failed to add counter value")
@staticmethod
def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None:
if not histogram:
return
try:
histogram.record(value, attributes)
except Exception:
logger.exception("workflow_run_cleanup_metrics: failed to record histogram value")
def record_batch(
self,
*,
batch_rows: int,
targeted_runs: int,
skipped_runs: int,
deleted_runs: int,
related_counts: dict[str, int] | None,
related_action: str | None,
batch_duration_seconds: float,
) -> None:
attributes = self._attrs()
self._add(self._batches_total, 1, attributes)
self._add(self._runs_scanned_total, batch_rows, attributes)
self._add(self._runs_targeted_total, targeted_runs, attributes)
self._add(self._runs_skipped_total, skipped_runs, attributes)
self._add(self._runs_deleted_total, deleted_runs, attributes)
self._record(self._batch_duration_seconds, batch_duration_seconds, attributes)
if not related_counts or not related_action:
return
for record_type, count in related_counts.items():
self._add(
self._related_records_total,
count,
self._attrs(action=related_action, record_type=record_type),
)
def record_completion(self, *, status: str, job_duration_seconds: float) -> None:
attributes = self._attrs(status=status)
self._add(self._job_runs_total, 1, attributes)
self._record(self._job_duration_seconds, job_duration_seconds, attributes)
class WorkflowRunCleanup:
def __init__(
self,
@@ -179,7 +29,6 @@ class WorkflowRunCleanup:
end_before: datetime.datetime | None = None,
workflow_run_repo: APIWorkflowRunRepository | None = None,
dry_run: bool = False,
task_label: str = "daily",
):
if (start_from is None) ^ (end_before is None):
raise ValueError("start_from and end_before must be both set or both omitted.")
@@ -197,13 +46,6 @@ class WorkflowRunCleanup:
self.batch_size = batch_size
self._cleanup_whitelist: set[str] | None = None
self.dry_run = dry_run
normalized_task_label = task_label.strip()
self.task_label = normalized_task_label or "daily"
self._metrics = WorkflowRunCleanupMetrics(
dry_run=dry_run,
has_window=bool(start_from),
task_label=self.task_label,
)
self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD
self.workflow_run_repo: APIWorkflowRunRepository
if workflow_run_repo:
@@ -232,193 +74,153 @@ class WorkflowRunCleanup:
related_totals = self._empty_related_counts() if self.dry_run else None
batch_index = 0
last_seen: tuple[datetime.datetime, str] | None = None
status = "success"
run_start = time.monotonic()
max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL
try:
while True:
batch_start = time.monotonic()
max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200))
fetch_start = time.monotonic()
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
start_from=self.window_start,
end_before=self.window_end,
last_seen=last_seen,
batch_size=self.batch_size,
while True:
batch_start = time.monotonic()
fetch_start = time.monotonic()
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
start_from=self.window_start,
end_before=self.window_end,
last_seen=last_seen,
batch_size=self.batch_size,
)
if not run_rows:
logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1)
break
batch_index += 1
last_seen = (run_rows[-1].created_at, run_rows[-1].id)
logger.info(
"workflow_run_cleanup (batch #%s): fetched %s rows in %sms",
batch_index,
len(run_rows),
int((time.monotonic() - fetch_start) * 1000),
)
tenant_ids = {row.tenant_id for row in run_rows}
filter_start = time.monotonic()
free_tenants = self._filter_free_tenants(tenant_ids)
logger.info(
"workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms",
batch_index,
len(free_tenants),
len(tenant_ids),
int((time.monotonic() - filter_start) * 1000),
)
free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
paid_or_skipped = len(run_rows) - len(free_runs)
if not free_runs:
skipped_message = (
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
)
if not run_rows:
logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1)
break
batch_index += 1
last_seen = (run_rows[-1].created_at, run_rows[-1].id)
logger.info(
"workflow_run_cleanup (batch #%s): fetched %s rows in %sms",
batch_index,
len(run_rows),
int((time.monotonic() - fetch_start) * 1000),
)
tenant_ids = {row.tenant_id for row in run_rows}
filter_start = time.monotonic()
free_tenants = self._filter_free_tenants(tenant_ids)
logger.info(
"workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms",
batch_index,
len(free_tenants),
len(tenant_ids),
int((time.monotonic() - filter_start) * 1000),
)
free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
paid_or_skipped = len(run_rows) - len(free_runs)
if not free_runs:
skipped_message = (
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
)
click.echo(
click.style(
skipped_message,
fg="yellow",
)
)
self._metrics.record_batch(
batch_rows=len(run_rows),
targeted_runs=0,
skipped_runs=paid_or_skipped,
deleted_runs=0,
related_counts=None,
related_action=None,
batch_duration_seconds=time.monotonic() - batch_start,
)
continue
total_runs_targeted += len(free_runs)
if self.dry_run:
count_start = time.monotonic()
batch_counts = self.workflow_run_repo.count_runs_with_related(
free_runs,
count_node_executions=self._count_node_executions,
count_trigger_logs=self._count_trigger_logs,
)
logger.info(
"workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms",
batch_index,
int((time.monotonic() - count_start) * 1000),
)
if related_totals is not None:
for key in related_totals:
related_totals[key] += batch_counts.get(key, 0)
sample_ids = ", ".join(run.id for run in free_runs[:5])
click.echo(
click.style(
f"[batch #{batch_index}] would delete {len(free_runs)} runs "
f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
fg="yellow",
)
)
logger.info(
"workflow_run_cleanup (batch #%s, dry_run): batch total %sms",
batch_index,
int((time.monotonic() - batch_start) * 1000),
)
self._metrics.record_batch(
batch_rows=len(run_rows),
targeted_runs=len(free_runs),
skipped_runs=paid_or_skipped,
deleted_runs=0,
related_counts={key: batch_counts.get(key, 0) for key in self._empty_related_counts()},
related_action="would_delete",
batch_duration_seconds=time.monotonic() - batch_start,
)
continue
try:
delete_start = time.monotonic()
counts = self.workflow_run_repo.delete_runs_with_related(
free_runs,
delete_node_executions=self._delete_node_executions,
delete_trigger_logs=self._delete_trigger_logs,
)
delete_ms = int((time.monotonic() - delete_start) * 1000)
except Exception:
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
raise
total_runs_deleted += counts["runs"]
click.echo(
click.style(
f"[batch #{batch_index}] deleted runs: {counts['runs']} "
f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
f"skipped {paid_or_skipped} paid/unknown",
fg="green",
skipped_message,
fg="yellow",
)
)
continue
total_runs_targeted += len(free_runs)
if self.dry_run:
count_start = time.monotonic()
batch_counts = self.workflow_run_repo.count_runs_with_related(
free_runs,
count_node_executions=self._count_node_executions,
count_trigger_logs=self._count_trigger_logs,
)
logger.info(
"workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms",
batch_index,
int((time.monotonic() - count_start) * 1000),
)
if related_totals is not None:
for key in related_totals:
related_totals[key] += batch_counts.get(key, 0)
sample_ids = ", ".join(run.id for run in free_runs[:5])
click.echo(
click.style(
f"[batch #{batch_index}] would delete {len(free_runs)} runs "
f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
fg="yellow",
)
)
logger.info(
"workflow_run_cleanup (batch #%s): delete %sms, batch total %sms",
"workflow_run_cleanup (batch #%s, dry_run): batch total %sms",
batch_index,
delete_ms,
int((time.monotonic() - batch_start) * 1000),
)
self._metrics.record_batch(
batch_rows=len(run_rows),
targeted_runs=len(free_runs),
skipped_runs=paid_or_skipped,
deleted_runs=counts["runs"],
related_counts={key: counts.get(key, 0) for key in self._empty_related_counts()},
related_action="deleted",
batch_duration_seconds=time.monotonic() - batch_start,
continue
try:
delete_start = time.monotonic()
counts = self.workflow_run_repo.delete_runs_with_related(
free_runs,
delete_node_executions=self._delete_node_executions,
delete_trigger_logs=self._delete_trigger_logs,
)
delete_ms = int((time.monotonic() - delete_start) * 1000)
except Exception:
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
raise
# Random sleep between batches to avoid overwhelming the database
sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311
logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms)
time.sleep(sleep_ms / 1000)
if self.dry_run:
if self.window_start:
summary_message = (
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
)
else:
summary_message = (
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
f"before {self.window_end.isoformat()}"
)
if related_totals is not None:
summary_message = (
f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
)
summary_color = "yellow"
else:
if self.window_start:
summary_message = (
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
)
else:
summary_message = (
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
f"before {self.window_end.isoformat()}"
)
summary_color = "white"
click.echo(click.style(summary_message, fg=summary_color))
except Exception:
status = "failed"
raise
finally:
self._metrics.record_completion(
status=status,
job_duration_seconds=time.monotonic() - run_start,
total_runs_deleted += counts["runs"]
click.echo(
click.style(
f"[batch #{batch_index}] deleted runs: {counts['runs']} "
f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
f"skipped {paid_or_skipped} paid/unknown",
fg="green",
)
)
logger.info(
"workflow_run_cleanup (batch #%s): delete %sms, batch total %sms",
batch_index,
delete_ms,
int((time.monotonic() - batch_start) * 1000),
)
# Random sleep between batches to avoid overwhelming the database
sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311
logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms)
time.sleep(sleep_ms / 1000)
if self.dry_run:
if self.window_start:
summary_message = (
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
)
else:
summary_message = (
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
f"before {self.window_end.isoformat()}"
)
if related_totals is not None:
summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
summary_color = "yellow"
else:
if self.window_start:
summary_message = (
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
)
else:
summary_message = (
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}"
)
summary_color = "white"
click.echo(click.style(summary_message, fg=summary_color))
def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
tenant_id_list = list(tenant_ids)

View File

@@ -6,6 +6,7 @@ import typing
import click
from celery import shared_task
from core.helper.marketplace import record_install_plugin_event
from core.plugin.entities.marketplace import MarketplacePluginSnapshot
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
@@ -165,6 +166,7 @@ def process_tenant_plugin_autoupgrade_check_task(
# execute upgrade
new_unique_identifier = manifest.latest_package_identifier
record_install_plugin_event(new_unique_identifier)
click.echo(
click.style(
f"Upgrade plugin: {original_unique_identifier} -> {new_unique_identifier}",

View File

@@ -5,10 +5,14 @@ This test module validates the 400-character limit enforcement
for App descriptions across all creation and editing endpoints.
"""
import os
import sys
import pytest
# Add the API root to Python path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
class TestAppDescriptionValidationUnit:
"""Unit tests for description validation function"""

View File

@@ -10,11 +10,8 @@ more reliable and realistic test scenarios.
import logging
import os
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Protocol, TypeVar
import psycopg2
import pytest
from flask import Flask
from flask.testing import FlaskClient
@@ -34,25 +31,6 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(level
logger = logging.getLogger(__name__)
class _CloserProtocol(Protocol):
"""_Closer is any type which implement the close() method."""
def close(self):
"""close the current object, release any external resouece (file, transaction, connection etc.)
associated with it.
"""
pass
_Closer = TypeVar("_Closer", bound=_CloserProtocol)
@contextmanager
def _auto_close(closer: _Closer) -> Generator[_Closer, None, None]:
yield closer
closer.close()
class DifyTestContainers:
"""
Manages all test containers required for Dify integration tests.
@@ -119,28 +97,45 @@ class DifyTestContainers:
wait_for_logs(self.postgres, "is ready to accept connections", timeout=30)
logger.info("PostgreSQL container is ready and accepting connections")
conn = psycopg2.connect(
host=db_host,
port=db_port,
user=self.postgres.username,
password=self.postgres.password,
database=self.postgres.dbname,
)
conn.autocommit = True
with _auto_close(conn):
with conn.cursor() as cursor:
# Install uuid-ossp extension for UUID generation
logger.info("Installing uuid-ossp extension...")
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
logger.info("uuid-ossp extension installed successfully")
# Install uuid-ossp extension for UUID generation
logger.info("Installing uuid-ossp extension...")
try:
import psycopg2
# NOTE: We cannot use `with conn.cursor() as cursor:` as it will wrap the statement
# inside a transaction. However, the `CREATE DATABASE` statement cannot run inside a transaction block.
with _auto_close(conn.cursor()) as cursor:
# Create plugin database for dify-plugin-daemon
logger.info("Creating plugin database...")
cursor.execute("CREATE DATABASE dify_plugin;")
conn = psycopg2.connect(
host=db_host,
port=db_port,
user=self.postgres.username,
password=self.postgres.password,
database=self.postgres.dbname,
)
conn.autocommit = True
cursor = conn.cursor()
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
cursor.close()
conn.close()
logger.info("uuid-ossp extension installed successfully")
except Exception as e:
logger.warning("Failed to install uuid-ossp extension: %s", e)
# Create plugin database for dify-plugin-daemon
logger.info("Creating plugin database...")
try:
conn = psycopg2.connect(
host=db_host,
port=db_port,
user=self.postgres.username,
password=self.postgres.password,
database=self.postgres.dbname,
)
conn.autocommit = True
cursor = conn.cursor()
cursor.execute("CREATE DATABASE dify_plugin;")
cursor.close()
conn.close()
logger.info("Plugin database created successfully")
except Exception as e:
logger.warning("Failed to create plugin database: %s", e)
# Set up storage environment variables
os.environ.setdefault("STORAGE_TYPE", "opendal")
@@ -263,16 +258,23 @@ class DifyTestContainers:
containers = [self.redis, self.postgres, self.dify_sandbox, self.dify_plugin_daemon]
for container in containers:
if container:
container_name = container.image
logger.info("Stopping container: %s", container_name)
container.stop()
logger.info("Successfully stopped container: %s", container_name)
try:
container_name = container.image
logger.info("Stopping container: %s", container_name)
container.stop()
logger.info("Successfully stopped container: %s", container_name)
except Exception as e:
# Log error but don't fail the test cleanup
logger.warning("Failed to stop container %s: %s", container, e)
# Stop and remove the network
if self.network:
logger.info("Removing Docker network...")
self.network.remove()
logger.info("Successfully removed Docker network")
try:
logger.info("Removing Docker network...")
self.network.remove()
logger.info("Successfully removed Docker network")
except Exception as e:
logger.warning("Failed to remove Docker network: %s", e)
self._containers_started = False
logger.info("All test containers stopped and cleaned up successfully")

View File

@@ -1,233 +0,0 @@
import datetime
import json
import uuid
from decimal import Decimal
import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import (
App,
AppAnnotationHitHistory,
Conversation,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats
class TestAppMessageExportServiceIntegration:
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers: Session):
yield
db_session_with_containers.query(DatasetRetrieverResource).delete()
db_session_with_containers.query(AppAnnotationHitHistory).delete()
db_session_with_containers.query(SavedMessage).delete()
db_session_with_containers.query(MessageFile).delete()
db_session_with_containers.query(MessageAgentThought).delete()
db_session_with_containers.query(MessageChain).delete()
db_session_with_containers.query(MessageAnnotation).delete()
db_session_with_containers.query(MessageFeedback).delete()
db_session_with_containers.query(Message).delete()
db_session_with_containers.query(Conversation).delete()
db_session_with_containers.query(App).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
@staticmethod
def _create_app_context(session: Session) -> tuple[App, Conversation]:
account = Account(
email=f"test-{uuid.uuid4()}@example.com",
name="tester",
interface_language="en-US",
status="active",
)
session.add(account)
session.flush()
tenant = Tenant(name=f"tenant-{uuid.uuid4()}", status="normal")
session.add(tenant)
session.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
session.add(join)
session.flush()
app = App(
tenant_id=tenant.id,
name="export-app",
description="integration test app",
mode="chat",
enable_site=True,
enable_api=True,
api_rpm=60,
api_rph=3600,
is_demo=False,
is_public=False,
created_by=account.id,
updated_by=account.id,
)
session.add(app)
session.flush()
conversation = Conversation(
app_id=app.id,
app_model_config_id=str(uuid.uuid4()),
model_provider="openai",
model_id="gpt-4o-mini",
mode="chat",
name="conv",
inputs={"seed": 1},
status="normal",
from_source="api",
from_end_user_id=str(uuid.uuid4()),
)
session.add(conversation)
session.commit()
return app, conversation
@staticmethod
def _create_message(
session: Session,
app: App,
conversation: Conversation,
created_at: datetime.datetime,
*,
query: str,
answer: str,
inputs: dict,
message_metadata: str | None,
) -> Message:
message = Message(
app_id=app.id,
conversation_id=conversation.id,
model_provider="openai",
model_id="gpt-4o-mini",
inputs=inputs,
query=query,
answer=answer,
message=[{"role": "assistant", "content": answer}],
message_tokens=10,
message_unit_price=Decimal("0.001"),
answer_tokens=20,
answer_unit_price=Decimal("0.002"),
total_price=Decimal("0.003"),
currency="USD",
message_metadata=message_metadata,
from_source="api",
from_end_user_id=conversation.from_end_user_id,
created_at=created_at,
)
session.add(message)
session.flush()
return message
def test_iter_records_with_stats(self, db_session_with_containers: Session):
app, conversation = self._create_app_context(db_session_with_containers)
first_inputs = {
"plain": "v1",
"nested": {"a": 1, "b": [1, {"x": True}]},
"list": ["x", 2, {"y": "z"}],
}
second_inputs = {"other": "value", "items": [1, 2, 3]}
base_time = datetime.datetime(2026, 2, 25, 10, 0, 0)
first_message = self._create_message(
db_session_with_containers,
app,
conversation,
created_at=base_time,
query="q1",
answer="a1",
inputs=first_inputs,
message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}),
)
second_message = self._create_message(
db_session_with_containers,
app,
conversation,
created_at=base_time + datetime.timedelta(minutes=1),
query="q2",
answer="a2",
inputs=second_inputs,
message_metadata=None,
)
user_feedback_1 = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="user",
content="first",
from_end_user_id=conversation.from_end_user_id,
)
user_feedback_2 = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="dislike",
from_source="user",
content="second",
from_end_user_id=conversation.from_end_user_id,
)
admin_feedback = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="admin",
content="should-be-filtered",
from_account_id=str(uuid.uuid4()),
)
db_session_with_containers.add_all([user_feedback_1, user_feedback_2, admin_feedback])
user_feedback_1.created_at = base_time + datetime.timedelta(minutes=2)
user_feedback_2.created_at = base_time + datetime.timedelta(minutes=3)
admin_feedback.created_at = base_time + datetime.timedelta(minutes=4)
db_session_with_containers.commit()
service = AppMessageExportService(
app_id=app.id,
start_from=base_time - datetime.timedelta(minutes=1),
end_before=base_time + datetime.timedelta(minutes=10),
filename="unused",
batch_size=1,
dry_run=True,
)
stats = AppMessageExportStats()
records = list(service._iter_records_with_stats(stats))
service._finalize_stats(stats)
assert len(records) == 2
assert records[0].message_id == first_message.id
assert records[1].message_id == second_message.id
assert records[0].inputs == first_inputs
assert records[1].inputs == second_inputs
assert records[0].retriever_resources == [{"dataset_id": "ds-1"}]
assert records[1].retriever_resources == []
assert [feedback.rating for feedback in records[0].feedback] == ["like", "dislike"]
assert [feedback.content for feedback in records[0].feedback] == ["first", "second"]
assert records[1].feedback == []
assert stats.batches == 2
assert stats.total_messages == 2
assert stats.messages_with_feedback == 1
assert stats.total_feedbacks == 2

View File

@@ -1,188 +0,0 @@
import datetime
import re
from unittest.mock import MagicMock, patch
import click
import pytest
from commands import clean_expired_messages
def _mock_service() -> MagicMock:
service = MagicMock()
service.run.return_value = {
"batches": 1,
"total_messages": 10,
"filtered_messages": 5,
"total_deleted": 5,
}
return service
def test_absolute_mode_calls_from_time_range():
policy = object()
service = _mock_service()
start_from = datetime.datetime(2024, 1, 1, 0, 0, 0)
end_before = datetime.datetime(2024, 2, 1, 0, 0, 0)
with (
patch("commands.create_message_clean_policy", return_value=policy),
patch("commands.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
patch("commands.MessagesCleanService.from_days") as mock_from_days,
):
clean_expired_messages.callback(
batch_size=200,
graceful_period=21,
start_from=start_from,
end_before=end_before,
from_days_ago=None,
before_days=None,
dry_run=True,
task_label="daily",
)
mock_from_time_range.assert_called_once_with(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=200,
dry_run=True,
task_label="daily",
)
mock_from_days.assert_not_called()
def test_relative_mode_before_days_only_calls_from_days():
policy = object()
service = _mock_service()
with (
patch("commands.create_message_clean_policy", return_value=policy),
patch("commands.MessagesCleanService.from_days", return_value=service) as mock_from_days,
patch("commands.MessagesCleanService.from_time_range") as mock_from_time_range,
):
clean_expired_messages.callback(
batch_size=500,
graceful_period=14,
start_from=None,
end_before=None,
from_days_ago=None,
before_days=30,
dry_run=False,
task_label="daily",
)
mock_from_days.assert_called_once_with(
policy=policy,
days=30,
batch_size=500,
dry_run=False,
task_label="daily",
)
mock_from_time_range.assert_not_called()
def test_relative_mode_with_from_days_ago_calls_from_time_range():
policy = object()
service = _mock_service()
fixed_now = datetime.datetime(2024, 8, 20, 12, 0, 0)
with (
patch("commands.create_message_clean_policy", return_value=policy),
patch("commands.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
patch("commands.MessagesCleanService.from_days") as mock_from_days,
patch("commands.naive_utc_now", return_value=fixed_now),
):
clean_expired_messages.callback(
batch_size=1000,
graceful_period=21,
start_from=None,
end_before=None,
from_days_ago=60,
before_days=30,
dry_run=False,
task_label="daily",
)
mock_from_time_range.assert_called_once_with(
policy=policy,
start_from=fixed_now - datetime.timedelta(days=60),
end_before=fixed_now - datetime.timedelta(days=30),
batch_size=1000,
dry_run=False,
task_label="daily",
)
mock_from_days.assert_not_called()
@pytest.mark.parametrize(
("kwargs", "message"),
[
(
{
"start_from": datetime.datetime(2024, 1, 1),
"end_before": datetime.datetime(2024, 2, 1),
"from_days_ago": None,
"before_days": 30,
},
"mutually exclusive",
),
(
{
"start_from": datetime.datetime(2024, 1, 1),
"end_before": None,
"from_days_ago": None,
"before_days": None,
},
"Both --start-from and --end-before are required",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": 10,
"before_days": None,
},
"--from-days-ago must be used together with --before-days",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": None,
"before_days": -1,
},
"--before-days must be >= 0",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": 30,
"before_days": 30,
},
"--from-days-ago must be greater than --before-days",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": None,
"before_days": None,
},
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])",
),
],
)
def test_invalid_inputs_raise_usage_error(kwargs: dict, message: str):
with pytest.raises(click.UsageError, match=re.escape(message)):
clean_expired_messages.callback(
batch_size=1000,
graceful_period=21,
start_from=kwargs["start_from"],
end_before=kwargs["end_before"],
from_days_ago=kwargs["from_days_ago"],
before_days=kwargs["before_days"],
dry_run=False,
task_label="daily",
)

View File

@@ -32,6 +32,11 @@ os.environ.setdefault("OPENDAL_SCHEME", "fs")
os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage")
os.environ.setdefault("STORAGE_TYPE", "opendal")
# Add the API directory to Python path to ensure proper imports
import sys
sys.path.insert(0, PROJECT_DIR)
from core.db.session_factory import configure_session_factory, session_factory
from extensions import ext_redis

View File

@@ -1,70 +0,0 @@
from controllers.common.errors import (
BlockedFileExtensionError,
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
RemoteFileUploadError,
TooManyFilesError,
UnsupportedFileTypeError,
)
class TestFilenameNotExistsError:
def test_defaults(self):
error = FilenameNotExistsError()
assert error.code == 400
assert error.description == "The specified filename does not exist."
class TestRemoteFileUploadError:
def test_defaults(self):
error = RemoteFileUploadError()
assert error.code == 400
assert error.description == "Error uploading remote file."
class TestFileTooLargeError:
def test_defaults(self):
error = FileTooLargeError()
assert error.code == 413
assert error.error_code == "file_too_large"
assert error.description == "File size exceeded. {message}"
class TestUnsupportedFileTypeError:
def test_defaults(self):
error = UnsupportedFileTypeError()
assert error.code == 415
assert error.error_code == "unsupported_file_type"
assert error.description == "File type not allowed."
class TestBlockedFileExtensionError:
def test_defaults(self):
error = BlockedFileExtensionError()
assert error.code == 400
assert error.error_code == "file_extension_blocked"
assert error.description == "The file extension is blocked for security reasons."
class TestTooManyFilesError:
def test_defaults(self):
error = TooManyFilesError()
assert error.code == 400
assert error.error_code == "too_many_files"
assert error.description == "Only one file is allowed."
class TestNoFileUploadedError:
def test_defaults(self):
error = NoFileUploadedError()
assert error.code == 400
assert error.error_code == "no_file_uploaded"
assert error.description == "Please upload your file."

View File

@@ -1,95 +1,22 @@
from flask import Response
from controllers.common.file_response import (
_normalize_mime_type,
enforce_download_for_html,
is_html_content,
)
from controllers.common.file_response import enforce_download_for_html, is_html_content
class TestNormalizeMimeType:
def test_returns_empty_string_for_none(self):
assert _normalize_mime_type(None) == ""
def test_returns_empty_string_for_empty_string(self):
assert _normalize_mime_type("") == ""
def test_normalizes_mime_type(self):
assert _normalize_mime_type("Text/HTML; Charset=UTF-8") == "text/html"
class TestIsHtmlContent:
def test_detects_html_via_mime_type(self):
class TestFileResponseHelpers:
def test_is_html_content_detects_mime_type(self):
mime_type = "text/html; charset=UTF-8"
result = is_html_content(
mime_type=mime_type,
filename="file.txt",
extension="txt",
)
result = is_html_content(mime_type, filename="file.txt", extension="txt")
assert result is True
def test_detects_html_via_extension_argument(self):
result = is_html_content(
mime_type="text/plain",
filename=None,
extension="html",
)
def test_is_html_content_detects_extension(self):
result = is_html_content("text/plain", filename="report.html", extension=None)
assert result is True
def test_detects_html_via_filename_extension(self):
result = is_html_content(
mime_type="text/plain",
filename="report.html",
extension=None,
)
assert result is True
def test_returns_false_when_no_html_detected_anywhere(self):
"""
Missing negative test:
- MIME type is not HTML
- filename has no HTML extension
- extension argument is not HTML
"""
result = is_html_content(
mime_type="application/json",
filename="data.json",
extension="json",
)
assert result is False
def test_returns_false_when_all_inputs_are_none(self):
result = is_html_content(
mime_type=None,
filename=None,
extension=None,
)
assert result is False
class TestEnforceDownloadForHtml:
def test_sets_attachment_when_filename_missing(self):
response = Response("payload", mimetype="text/html")
updated = enforce_download_for_html(
response,
mime_type="text/html",
filename=None,
extension="html",
)
assert updated is True
assert response.headers["Content-Disposition"] == "attachment"
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["X-Content-Type-Options"] == "nosniff"
def test_sets_headers_when_filename_present(self):
def test_enforce_download_for_html_sets_headers(self):
response = Response("payload", mimetype="text/html")
updated = enforce_download_for_html(
@@ -100,12 +27,11 @@ class TestEnforceDownloadForHtml:
)
assert updated is True
assert response.headers["Content-Disposition"].startswith("attachment")
assert "unsafe.html" in response.headers["Content-Disposition"]
assert "attachment" in response.headers["Content-Disposition"]
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["X-Content-Type-Options"] == "nosniff"
def test_does_not_modify_response_for_non_html_content(self):
def test_enforce_download_for_html_no_change_for_non_html(self):
response = Response("payload", mimetype="text/plain")
updated = enforce_download_for_html(

View File

@@ -1,188 +0,0 @@
from uuid import UUID
import httpx
import pytest
from controllers.common import helpers
from controllers.common.helpers import FileInfo, guess_file_info_from_response
def make_response(
url="https://example.com/file.txt",
headers=None,
content=None,
):
return httpx.Response(
200,
request=httpx.Request("GET", url),
headers=headers or {},
content=content or b"",
)
class TestGuessFileInfoFromResponse:
def test_filename_from_url(self):
response = make_response(
url="https://example.com/test.pdf",
content=b"Hello World",
)
info = guess_file_info_from_response(response)
assert info.filename == "test.pdf"
assert info.extension == ".pdf"
assert info.mimetype == "application/pdf"
def test_filename_from_content_disposition(self):
headers = {
"Content-Disposition": "attachment; filename=myfile.csv",
"Content-Type": "text/csv",
}
response = make_response(
url="https://example.com/",
headers=headers,
content=b"Hello World",
)
info = guess_file_info_from_response(response)
assert info.filename == "myfile.csv"
assert info.extension == ".csv"
assert info.mimetype == "text/csv"
@pytest.mark.parametrize(
("magic_available", "expected_ext"),
[
(True, "txt"),
(False, "bin"),
],
)
def test_generated_filename_when_missing(self, monkeypatch, magic_available, expected_ext):
if magic_available:
if helpers.magic is None:
pytest.skip("python-magic is not installed, cannot run 'magic_available=True' test variant")
else:
monkeypatch.setattr(helpers, "magic", None)
response = make_response(
url="https://example.com/",
content=b"Hello World",
)
info = guess_file_info_from_response(response)
name, ext = info.filename.split(".")
UUID(name)
assert ext == expected_ext
def test_mimetype_from_header_when_unknown(self):
headers = {"Content-Type": "application/json"}
response = make_response(
url="https://example.com/file.unknown",
headers=headers,
content=b'{"a": 1}',
)
info = guess_file_info_from_response(response)
assert info.mimetype == "application/json"
def test_extension_added_when_missing(self):
headers = {"Content-Type": "image/png"}
response = make_response(
url="https://example.com/image",
headers=headers,
content=b"fakepngdata",
)
info = guess_file_info_from_response(response)
assert info.extension == ".png"
assert info.filename.endswith(".png")
def test_content_length_used_as_size(self):
headers = {
"Content-Length": "1234",
"Content-Type": "text/plain",
}
response = make_response(
url="https://example.com/a.txt",
headers=headers,
content=b"a" * 1234,
)
info = guess_file_info_from_response(response)
assert info.size == 1234
def test_size_minus_one_when_header_missing(self):
response = make_response(url="https://example.com/a.txt")
info = guess_file_info_from_response(response)
assert info.size == -1
def test_fallback_to_bin_extension(self):
headers = {"Content-Type": "application/octet-stream"}
response = make_response(
url="https://example.com/download",
headers=headers,
content=b"\x00\x01\x02\x03",
)
info = guess_file_info_from_response(response)
assert info.extension == ".bin"
assert info.filename.endswith(".bin")
def test_return_type(self):
response = make_response()
info = guess_file_info_from_response(response)
assert isinstance(info, FileInfo)
class TestMagicImportWarnings:
@pytest.mark.parametrize(
("platform_name", "expected_message"),
[
("Windows", "pip install python-magic-bin"),
("Darwin", "brew install libmagic"),
("Linux", "sudo apt-get install libmagic1"),
("Other", "install `libmagic`"),
],
)
def test_magic_import_warning_per_platform(
self,
monkeypatch,
platform_name,
expected_message,
):
import builtins
import importlib
# Force ImportError when "magic" is imported
real_import = builtins.__import__
def fake_import(name, *args, **kwargs):
if name == "magic":
raise ImportError("No module named magic")
return real_import(name, *args, **kwargs)
monkeypatch.setattr(builtins, "__import__", fake_import)
monkeypatch.setattr("platform.system", lambda: platform_name)
# Remove helpers so it imports fresh
import sys
original_helpers = sys.modules.get(helpers.__name__)
sys.modules.pop(helpers.__name__, None)
try:
with pytest.warns(UserWarning, match="To use python-magic") as warning:
imported_helpers = importlib.import_module(helpers.__name__)
assert expected_message in str(warning[0].message)
finally:
if original_helpers is not None:
sys.modules[helpers.__name__] = original_helpers

View File

@@ -1,189 +0,0 @@
import sys
from enum import StrEnum
from unittest.mock import MagicMock, patch
import pytest
from flask_restx import Namespace
from pydantic import BaseModel
class UserModel(BaseModel):
id: int
name: str
class ProductModel(BaseModel):
id: int
price: float
@pytest.fixture(autouse=True)
def mock_console_ns():
"""Mock the console_ns to avoid circular imports during test collection."""
mock_ns = MagicMock(spec=Namespace)
mock_ns.models = {}
# Inject mock before importing schema module
with patch.dict(sys.modules, {"controllers.console": MagicMock(console_ns=mock_ns)}):
yield mock_ns
def test_default_ref_template_value():
from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0
assert DEFAULT_REF_TEMPLATE_SWAGGER_2_0 == "#/definitions/{model}"
def test_register_schema_model_calls_namespace_schema_model():
from controllers.common.schema import register_schema_model
namespace = MagicMock(spec=Namespace)
register_schema_model(namespace, UserModel)
namespace.schema_model.assert_called_once()
model_name, schema = namespace.schema_model.call_args.args
assert model_name == "UserModel"
assert isinstance(schema, dict)
assert "properties" in schema
def test_register_schema_model_passes_schema_from_pydantic():
from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_model
namespace = MagicMock(spec=Namespace)
register_schema_model(namespace, UserModel)
schema = namespace.schema_model.call_args.args[1]
expected_schema = UserModel.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
assert schema == expected_schema
def test_register_schema_models_registers_multiple_models():
from controllers.common.schema import register_schema_models
namespace = MagicMock(spec=Namespace)
register_schema_models(namespace, UserModel, ProductModel)
assert namespace.schema_model.call_count == 2
called_names = [call.args[0] for call in namespace.schema_model.call_args_list]
assert called_names == ["UserModel", "ProductModel"]
def test_register_schema_models_calls_register_schema_model(monkeypatch):
from controllers.common.schema import register_schema_models
namespace = MagicMock(spec=Namespace)
calls = []
def fake_register(ns, model):
calls.append((ns, model))
monkeypatch.setattr(
"controllers.common.schema.register_schema_model",
fake_register,
)
register_schema_models(namespace, UserModel, ProductModel)
assert calls == [
(namespace, UserModel),
(namespace, ProductModel),
]
class StatusEnum(StrEnum):
ACTIVE = "active"
INACTIVE = "inactive"
class PriorityEnum(StrEnum):
HIGH = "high"
LOW = "low"
def test_get_or_create_model_returns_existing_model(mock_console_ns):
from controllers.common.schema import get_or_create_model
existing_model = MagicMock()
mock_console_ns.models = {"TestModel": existing_model}
result = get_or_create_model("TestModel", {"key": "value"})
assert result == existing_model
mock_console_ns.model.assert_not_called()
def test_get_or_create_model_creates_new_model_when_not_exists(mock_console_ns):
from controllers.common.schema import get_or_create_model
mock_console_ns.models = {}
new_model = MagicMock()
mock_console_ns.model.return_value = new_model
field_def = {"name": {"type": "string"}}
result = get_or_create_model("NewModel", field_def)
assert result == new_model
mock_console_ns.model.assert_called_once_with("NewModel", field_def)
def test_get_or_create_model_does_not_call_model_if_exists(mock_console_ns):
from controllers.common.schema import get_or_create_model
existing_model = MagicMock()
mock_console_ns.models = {"ExistingModel": existing_model}
result = get_or_create_model("ExistingModel", {"key": "value"})
assert result == existing_model
mock_console_ns.model.assert_not_called()
def test_register_enum_models_registers_single_enum():
from controllers.common.schema import register_enum_models
namespace = MagicMock(spec=Namespace)
register_enum_models(namespace, StatusEnum)
namespace.schema_model.assert_called_once()
model_name, schema = namespace.schema_model.call_args.args
assert model_name == "StatusEnum"
assert isinstance(schema, dict)
def test_register_enum_models_registers_multiple_enums():
from controllers.common.schema import register_enum_models
namespace = MagicMock(spec=Namespace)
register_enum_models(namespace, StatusEnum, PriorityEnum)
assert namespace.schema_model.call_count == 2
called_names = [call.args[0] for call in namespace.schema_model.call_args_list]
assert called_names == ["StatusEnum", "PriorityEnum"]
def test_register_enum_models_uses_correct_ref_template():
from controllers.common.schema import register_enum_models
namespace = MagicMock(spec=Namespace)
register_enum_models(namespace, StatusEnum)
schema = namespace.schema_model.call_args.args[1]
# Verify the schema contains enum values
assert "enum" in schema or "anyOf" in schema

View File

@@ -124,12 +124,12 @@ def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch):
def start(self):
self.started = True
def fake_thread(*args, **kwargs):
def fake_thread(**kwargs):
thread = DummyThread(**kwargs)
captured["thread"] = thread
return thread
monkeypatch.setattr(message_cycle_manager, "Timer", fake_thread)
monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread)
manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock())
thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello")

View File

@@ -1,8 +1,13 @@
import sys
import time
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
API_DIR = str(Path(__file__).resolve().parents[5])
if API_DIR not in sys.path:
sys.path.insert(0, API_DIR)
import dify_graph.nodes.human_input.entities # noqa: F401
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module

View File

@@ -1,425 +0,0 @@
"""
Unit tests for EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response method.
This test suite ensures that the files array is correctly populated in the message_end
SSE event, which is critical for vision/image chat responses to render correctly.
Test Coverage:
- Files array populated when MessageFile records exist
- Files array is None when no MessageFile records exist
- Correct signed URL generation for LOCAL_FILE transfer method
- Correct URL handling for REMOTE_URL transfer method
- Correct URL handling for TOOL_FILE transfer method
- Proper file metadata formatting (filename, mime_type, size, extension)
"""
import uuid
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy.orm import Session
from core.app.entities.task_entities import MessageEndStreamResponse
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from dify_graph.file.enums import FileTransferMethod
from models.model import MessageFile, UploadFile
class TestMessageEndStreamResponseFiles:
"""Test suite for files array population in message_end SSE event."""
@pytest.fixture
def mock_pipeline(self):
"""Create a mock EasyUIBasedGenerateTaskPipeline instance."""
pipeline = Mock(spec=EasyUIBasedGenerateTaskPipeline)
pipeline._message_id = str(uuid.uuid4())
pipeline._task_state = Mock()
pipeline._task_state.metadata = Mock()
pipeline._task_state.metadata.model_dump = Mock(return_value={"test": "metadata"})
pipeline._task_state.llm_result = Mock()
pipeline._task_state.llm_result.usage = Mock()
pipeline._application_generate_entity = Mock()
pipeline._application_generate_entity.task_id = str(uuid.uuid4())
return pipeline
@pytest.fixture
def mock_message_file_local(self):
"""Create a mock MessageFile with LOCAL_FILE transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.LOCAL_FILE
message_file.upload_file_id = str(uuid.uuid4())
message_file.url = None
message_file.type = "image"
return message_file
@pytest.fixture
def mock_message_file_remote(self):
"""Create a mock MessageFile with REMOTE_URL transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.REMOTE_URL
message_file.upload_file_id = None
message_file.url = "https://example.com/image.jpg"
message_file.type = "image"
return message_file
@pytest.fixture
def mock_message_file_tool(self):
"""Create a mock MessageFile with TOOL_FILE transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.TOOL_FILE
message_file.upload_file_id = None
message_file.url = "tool_file_123.png"
message_file.type = "image"
return message_file
@pytest.fixture
def mock_upload_file(self, mock_message_file_local):
"""Create a mock UploadFile."""
upload_file = Mock(spec=UploadFile)
upload_file.id = mock_message_file_local.upload_file_id
upload_file.name = "test_image.png"
upload_file.mime_type = "image/png"
upload_file.size = 1024
upload_file.extension = "png"
return upload_file
def test_message_end_with_no_files(self, mock_pipeline):
"""Test that files array is None when no MessageFile records exist."""
# Arrange
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.scalars.return_value.all.return_value = []
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is None
assert result.id == mock_pipeline._message_id
assert result.metadata == {"test": "metadata"}
def test_message_end_with_local_file(self, mock_pipeline, mock_message_file_local, mock_upload_file):
"""Test that files array is populated correctly for LOCAL_FILE transfer method."""
# Arrange
mock_message_file_local.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
# First query: MessageFile
mock_message_files_result = Mock()
mock_message_files_result.all.return_value = [mock_message_file_local]
# Second query: UploadFile (batch query to avoid N+1)
mock_upload_files_result = Mock()
mock_upload_files_result.all.return_value = [mock_upload_file]
# Setup scalars to return different results for different queries
call_count = [0] # Use list to allow modification in nested function
def scalars_side_effect(query):
call_count[0] += 1
# First call is for MessageFile, second call is for UploadFile
if call_count[0] == 1:
return mock_message_files_result
else:
return mock_upload_files_result
mock_session.scalars.side_effect = scalars_side_effect
mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert file_dict["related_id"] == mock_message_file_local.id
assert file_dict["filename"] == "test_image.png"
assert file_dict["mime_type"] == "image/png"
assert file_dict["size"] == 1024
assert file_dict["extension"] == ".png"
assert file_dict["type"] == "image"
assert file_dict["transfer_method"] == FileTransferMethod.LOCAL_FILE.value
assert "https://example.com/signed-url" in file_dict["url"]
assert file_dict["upload_file_id"] == mock_message_file_local.upload_file_id
assert file_dict["remote_url"] == ""
# Verify database queries
# Should be called twice: once for MessageFile, once for UploadFile
assert mock_session.scalars.call_count == 2
mock_get_url.assert_called_once_with(upload_file_id=str(mock_upload_file.id))
def test_message_end_with_remote_url(self, mock_pipeline, mock_message_file_remote):
"""Test that files array is populated correctly for REMOTE_URL transfer method."""
# Arrange
mock_message_file_remote.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_remote]
mock_session.scalars.return_value = mock_scalars_result
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert file_dict["related_id"] == mock_message_file_remote.id
assert file_dict["filename"] == "image.jpg"
assert file_dict["url"] == "https://example.com/image.jpg"
assert file_dict["extension"] == ".jpg"
assert file_dict["type"] == "image"
assert file_dict["transfer_method"] == FileTransferMethod.REMOTE_URL.value
assert file_dict["remote_url"] == "https://example.com/image.jpg"
assert file_dict["upload_file_id"] == mock_message_file_remote.id
# Verify only one query for message_files is made
mock_session.scalars.assert_called_once()
def test_message_end_with_tool_file_http(self, mock_pipeline, mock_message_file_tool):
"""Test that files array is populated correctly for TOOL_FILE with HTTP URL."""
# Arrange
mock_message_file_tool.message_id = mock_pipeline._message_id
mock_message_file_tool.url = "https://example.com/tool_file.png"
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_tool]
mock_session.scalars.return_value = mock_scalars_result
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert file_dict["url"] == "https://example.com/tool_file.png"
assert file_dict["filename"] == "tool_file.png"
assert file_dict["extension"] == ".png"
assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value
def test_message_end_with_tool_file_local(self, mock_pipeline, mock_message_file_tool):
"""Test that files array is populated correctly for TOOL_FILE with local path."""
# Arrange
mock_message_file_tool.message_id = mock_pipeline._message_id
mock_message_file_tool.url = "tool_file_123.png"
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_tool]
mock_session.scalars.return_value = mock_scalars_result
mock_sign_tool.return_value = "https://example.com/signed-tool-file.png?signature=xyz"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert "https://example.com/signed-tool-file.png" in file_dict["url"]
assert file_dict["filename"] == "tool_file_123.png"
assert file_dict["extension"] == ".png"
assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value
# Verify tool file signing was called
mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_123", extension=".png")
def test_message_end_with_tool_file_long_extension(self, mock_pipeline, mock_message_file_tool):
"""Test that TOOL_FILE extensions longer than MAX_TOOL_FILE_EXTENSION_LENGTH fall back to .bin."""
mock_message_file_tool.message_id = mock_pipeline._message_id
mock_message_file_tool.url = "tool_file_abc.verylongextension"
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_tool]
mock_session.scalars.return_value = mock_scalars_result
mock_sign_tool.return_value = "https://example.com/signed.bin"
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
assert result.files is not None
file_dict = result.files[0]
assert file_dict["extension"] == ".bin"
mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_abc", extension=".bin")
def test_message_end_with_multiple_files(
self, mock_pipeline, mock_message_file_local, mock_message_file_remote, mock_upload_file
):
"""Test that files array contains all MessageFile records when multiple exist."""
# Arrange
mock_message_file_local.message_id = mock_pipeline._message_id
mock_message_file_remote.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
# First query: MessageFile
mock_message_files_result = Mock()
mock_message_files_result.all.return_value = [mock_message_file_local, mock_message_file_remote]
# Second query: UploadFile (batch query to avoid N+1)
mock_upload_files_result = Mock()
mock_upload_files_result.all.return_value = [mock_upload_file]
# Setup scalars to return different results for different queries
call_count = [0] # Use list to allow modification in nested function
def scalars_side_effect(query):
call_count[0] += 1
# First call is for MessageFile, second call is for UploadFile
if call_count[0] == 1:
return mock_message_files_result
else:
return mock_upload_files_result
mock_session.scalars.side_effect = scalars_side_effect
mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 2
# Verify both files are present
file_ids = [f["related_id"] for f in result.files]
assert mock_message_file_local.id in file_ids
assert mock_message_file_remote.id in file_ids
def test_message_end_with_local_file_no_upload_file(self, mock_pipeline, mock_message_file_local):
"""Test fallback when UploadFile is not found for LOCAL_FILE."""
# Arrange
mock_message_file_local.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
# First query: MessageFile
mock_message_files_result = Mock()
mock_message_files_result.all.return_value = [mock_message_file_local]
# Second query: UploadFile (batch query) - returns empty list (not found)
mock_upload_files_result = Mock()
mock_upload_files_result.all.return_value = [] # UploadFile not found
# Setup scalars to return different results for different queries
call_count = [0] # Use list to allow modification in nested function
def scalars_side_effect(query):
call_count[0] += 1
# First call is for MessageFile, second call is for UploadFile
if call_count[0] == 1:
return mock_message_files_result
else:
return mock_upload_files_result
mock_session.scalars.side_effect = scalars_side_effect
mock_get_url.return_value = "https://example.com/fallback-url?signature=def456"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert "https://example.com/fallback-url" in file_dict["url"]
# Verify fallback URL was generated using upload_file_id from message_file
mock_get_url.assert_called_with(upload_file_id=str(mock_message_file_local.upload_file_id))

View File

@@ -1,84 +0,0 @@
from datetime import datetime
from unittest.mock import MagicMock
from uuid import uuid4
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository:
engine = create_engine("sqlite:///:memory:")
real_session_factory = sessionmaker(bind=engine, expire_on_commit=False)
user = MagicMock(spec=Account)
user.id = str(uuid4())
user.current_tenant_id = str(uuid4())
repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=real_session_factory,
user=user,
app_id="app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
session_context = MagicMock()
session_context.__enter__.return_value = session
session_context.__exit__.return_value = False
repository._session_factory = MagicMock(return_value=session_context)
return repository
def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution:
return WorkflowExecution.new(
id_=execution_id,
workflow_id="workflow-id",
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0.0",
graph={"nodes": [], "edges": []},
inputs={"query": "hello"},
started_at=started_at,
)
def test_save_uses_execution_started_at_when_record_does_not_exist():
session = MagicMock()
session.get.return_value = None
repository = _build_repository_with_mocked_session(session)
started_at = datetime(2026, 1, 1, 12, 0, 0)
execution = _build_execution(execution_id=str(uuid4()), started_at=started_at)
repository.save(execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == started_at
session.commit.assert_called_once()
def test_save_preserves_existing_created_at_when_record_already_exists():
session = MagicMock()
repository = _build_repository_with_mocked_session(session)
execution_id = str(uuid4())
existing_created_at = datetime(2026, 1, 1, 12, 0, 0)
existing_run = WorkflowRun()
existing_run.id = execution_id
existing_run.tenant_id = repository._tenant_id
existing_run.created_at = existing_created_at
session.get.return_value = existing_run
execution = _build_execution(
execution_id=execution_id,
started_at=datetime(2026, 1, 1, 12, 30, 0),
)
repository.save(execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == existing_created_at
session.commit.assert_called_once()

View File

@@ -2,7 +2,15 @@
Simple test to verify MockNodeFactory works with iteration nodes.
"""
import sys
from pathlib import Path
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
# Add api directory to path
api_dir = Path(__file__).parent.parent.parent.parent.parent.parent
sys.path.insert(0, str(api_dir))
from dify_graph.enums import NodeType
from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder
from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory

View File

@@ -3,8 +3,14 @@ Simple test to validate the auto-mock system without external dependencies.
"""
import sys
from pathlib import Path
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
# Add api directory to path
api_dir = Path(__file__).parent.parent.parent.parent.parent.parent
sys.path.insert(0, str(api_dir))
from dify_graph.enums import NodeType
from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig
from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory

View File

@@ -8,9 +8,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.nodes.knowledge_retrieval.entities import (
Condition,
KnowledgeRetrievalNodeData,
MetadataFilteringCondition,
MultipleRetrievalConfig,
RerankingModelConfig,
SingleRetrievalConfig,
@@ -595,106 +593,3 @@ class TestFetchDatasetRetriever:
# Assert
assert version == "1"
def test_resolve_metadata_filtering_conditions_templates(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
):
"""_resolve_metadata_filtering_conditions should expand {{#...#}} and keep numbers/None unchanged."""
# Arrange
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": {
"title": "Knowledge Retrieval",
"type": "knowledge-retrieval",
"dataset_ids": [str(uuid.uuid4())],
"retrieval_mode": "multiple",
},
}
# Variable in pool used by template
mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme"))
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
conditions = MetadataFilteringCondition(
logical_operator="and",
conditions=[
Condition(name="document_name", comparison_operator="is", value="{{#start.query#}}"),
Condition(name="tags", comparison_operator="in", value=["x", "{{#start.query#}}"]),
Condition(name="year", comparison_operator="=", value=2025),
],
)
# Act
resolved = node._resolve_metadata_filtering_conditions(conditions)
# Assert
assert resolved.logical_operator == "and"
assert resolved.conditions[0].value == "readme"
assert isinstance(resolved.conditions[1].value, list)
assert resolved.conditions[1].value[1] == "readme"
assert resolved.conditions[2].value == 2025
def test_fetch_passes_resolved_metadata_conditions(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
):
"""_fetch_dataset_retriever should pass resolved metadata conditions into request."""
# Arrange
query = "hi"
variables = {"query": query}
mock_graph_runtime_state.variable_pool.add(["start", "q"], StringSegment(value="readme"))
node_data = KnowledgeRetrievalNodeData(
title="Knowledge Retrieval",
type="knowledge-retrieval",
dataset_ids=[str(uuid.uuid4())],
retrieval_mode="multiple",
multiple_retrieval_config=MultipleRetrievalConfig(
top_k=4,
score_threshold=0.0,
reranking_mode="reranking_model",
reranking_enable=True,
reranking_model=RerankingModelConfig(provider="cohere", model="rerank-v2"),
),
metadata_filtering_mode="manual",
metadata_filtering_conditions=MetadataFilteringCondition(
logical_operator="and",
conditions=[
Condition(name="document_name", comparison_operator="is", value="{{#start.q#}}"),
],
),
)
node_id = str(uuid.uuid4())
config = {"id": node_id, "data": node_data.model_dump()}
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
mock_rag_retrieval.knowledge_retrieval.return_value = []
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
# Act
node._fetch_dataset_retriever(node_data=node_data, variables=variables)
# Assert the passed request has resolved value
call_args = mock_rag_retrieval.knowledge_retrieval.call_args
request = call_args[1]["request"]
assert request.metadata_filtering_conditions is not None
assert request.metadata_filtering_conditions.conditions[0].value == "readme"

View File

@@ -16,7 +16,6 @@ from dify_graph.nodes.document_extractor.node import (
_extract_text_from_excel,
_extract_text_from_pdf,
_extract_text_from_plain_text,
_normalize_docx_zip,
)
from dify_graph.variables import ArrayFileSegment
from dify_graph.variables.segments import ArrayStringSegment
@@ -87,38 +86,6 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
assert "is not an ArrayFileSegment" in result.error
def test_run_empty_file_list_returns_succeeded(document_extractor_node, mock_graph_runtime_state):
"""Empty file list should return SUCCEEDED with empty documents and ArrayStringSegment([])."""
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
# Provide an actual ArrayFileSegment with an empty list
mock_graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(value=[])
result = document_extractor_node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error
assert result.process_data.get("documents") == []
assert result.outputs["text"] == ArrayStringSegment(value=[])
def test_run_none_only_file_list_returns_succeeded(document_extractor_node, mock_graph_runtime_state):
"""A file list containing only None (e.g., [None]) should be filtered to [] and succeed."""
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
# Use a Mock to bypass type validation for None entries in the list
afs = Mock(spec=ArrayFileSegment)
afs.value = [None]
mock_graph_runtime_state.variable_pool.get.return_value = afs
result = document_extractor_node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error
assert result.process_data.get("documents") == []
assert result.outputs["text"] == ArrayStringSegment(value=[])
@pytest.mark.parametrize(
("mime_type", "file_content", "expected_text", "transfer_method", "extension"),
[
@@ -418,58 +385,3 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file):
expected_manual = "| 1.0 | 1.1 |\n| --- | --- |\n| Test | Test |\n\n"
assert expected_manual == result
def _make_docx_zip(use_backslash: bool) -> bytes:
"""Helper to build a minimal in-memory DOCX zip.
When use_backslash=True the ZIP entry names use backslash separators
(as produced by Evernote on Windows), otherwise forward slashes are used.
"""
import zipfile
sep = "\\" if use_backslash else "/"
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
zf.writestr("[Content_Types].xml", b"<Types/>")
zf.writestr(f"_rels{sep}.rels", b"<Relationships/>")
zf.writestr(f"word{sep}document.xml", b"<w:document/>")
zf.writestr(f"word{sep}_rels{sep}document.xml.rels", b"<Relationships/>")
return buf.getvalue()
def test_normalize_docx_zip_replaces_backslashes():
"""ZIP entries with backslash separators must be rewritten to forward slashes."""
import zipfile
malformed = _make_docx_zip(use_backslash=True)
fixed = _normalize_docx_zip(malformed)
with zipfile.ZipFile(io.BytesIO(fixed)) as zf:
names = zf.namelist()
assert "word/document.xml" in names
assert "word/_rels/document.xml.rels" in names
# No entry should contain a backslash after normalization
assert all("\\" not in name for name in names)
def test_normalize_docx_zip_leaves_forward_slash_unchanged():
"""ZIP entries that already use forward slashes must not be modified."""
import zipfile
normal = _make_docx_zip(use_backslash=False)
fixed = _normalize_docx_zip(normal)
with zipfile.ZipFile(io.BytesIO(fixed)) as zf:
names = zf.namelist()
assert "word/document.xml" in names
assert "word/_rels/document.xml.rels" in names
def test_normalize_docx_zip_returns_original_on_bad_zip():
"""Non-zip bytes must be returned as-is without raising."""
garbage = b"not a zip file at all"
result = _normalize_docx_zip(garbage)
assert result == garbage

View File

@@ -265,61 +265,6 @@ def test_run_exits_on_empty_batch(monkeypatch: pytest.MonkeyPatch) -> None:
cleanup.run()
def test_run_records_metrics_on_success(monkeypatch: pytest.MonkeyPatch) -> None:
cutoff = datetime.datetime.now()
repo = FakeRepo(
batches=[[FakeRun("run-free", "t_free", cutoff)]],
delete_result={
"runs": 0,
"node_executions": 2,
"offloads": 1,
"app_logs": 3,
"trigger_logs": 4,
"pauses": 5,
"pause_reasons": 6,
},
)
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)
batch_calls: list[dict[str, object]] = []
completion_calls: list[dict[str, object]] = []
monkeypatch.setattr(cleanup._metrics, "record_batch", lambda **kwargs: batch_calls.append(kwargs))
monkeypatch.setattr(cleanup._metrics, "record_completion", lambda **kwargs: completion_calls.append(kwargs))
cleanup.run()
assert len(batch_calls) == 1
assert batch_calls[0]["batch_rows"] == 1
assert batch_calls[0]["targeted_runs"] == 1
assert batch_calls[0]["deleted_runs"] == 1
assert batch_calls[0]["related_action"] == "deleted"
assert len(completion_calls) == 1
assert completion_calls[0]["status"] == "success"
def test_run_records_failed_metrics(monkeypatch: pytest.MonkeyPatch) -> None:
class FailingRepo(FakeRepo):
def delete_runs_with_related(
self, runs: list[FakeRun], delete_node_executions=None, delete_trigger_logs=None
) -> dict[str, int]:
raise RuntimeError("delete failed")
cutoff = datetime.datetime.now()
repo = FailingRepo(batches=[[FakeRun("run-free", "t_free", cutoff)]])
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)
completion_calls: list[dict[str, object]] = []
monkeypatch.setattr(cleanup._metrics, "record_completion", lambda **kwargs: completion_calls.append(kwargs))
with pytest.raises(RuntimeError, match="delete failed"):
cleanup.run()
assert len(completion_calls) == 1
assert completion_calls[0]["status"] == "failed"
def test_run_dry_run_skips_deletions(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
cutoff = datetime.datetime.now()
repo = FakeRepo(

View File

@@ -1,43 +0,0 @@
import datetime
import pytest
from services.retention.conversation.message_export_service import AppMessageExportService
def test_validate_export_filename_accepts_relative_path():
assert AppMessageExportService.validate_export_filename("exports/2026/test01") == "exports/2026/test01"
@pytest.mark.parametrize(
"filename",
[
"test01.jsonl.gz",
"test01.jsonl",
"test01.gz",
"/tmp/test01",
"exports/../test01",
"bad\x00name",
"bad\tname",
"a" * 1025,
],
)
def test_validate_export_filename_rejects_invalid_values(filename: str):
with pytest.raises(ValueError):
AppMessageExportService.validate_export_filename(filename)
def test_service_derives_output_names_from_filename_base():
service = AppMessageExportService(
app_id="736b9b03-20f2-4697-91da-8d00f6325900",
start_from=None,
end_before=datetime.datetime(2026, 3, 1),
filename="exports/2026/test01",
batch_size=1000,
use_cloud_storage=True,
dry_run=True,
)
assert service._filename_base == "exports/2026/test01"
assert service.output_gz_name == "exports/2026/test01.jsonl.gz"
assert service.output_jsonl_name == "exports/2026/test01.jsonl"

View File

@@ -554,9 +554,11 @@ class TestMessagesCleanServiceFromDays:
MessagesCleanService.from_days(policy=policy, days=-1)
# Act
with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now:
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0)
mock_now.return_value = fixed_now
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
service = MessagesCleanService.from_days(policy=policy, days=0)
# Assert
@@ -584,9 +586,11 @@ class TestMessagesCleanServiceFromDays:
dry_run = True
# Act
with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now:
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_now.return_value = fixed_now
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
service = MessagesCleanService.from_days(
policy=policy,
days=days,
@@ -609,9 +613,11 @@ class TestMessagesCleanServiceFromDays:
policy = BillingDisabledPolicy()
# Act
with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now:
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_now.return_value = fixed_now
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
service = MessagesCleanService.from_days(policy=policy)
# Assert
@@ -619,53 +625,3 @@ class TestMessagesCleanServiceFromDays:
assert service._end_before == expected_end_before
assert service._batch_size == 1000 # default
assert service._dry_run is False # default
class TestMessagesCleanServiceRun:
"""Unit tests for MessagesCleanService.run instrumentation behavior."""
def test_run_records_completion_metrics_on_success(self):
# Arrange
service = MessagesCleanService(
policy=BillingDisabledPolicy(),
start_from=datetime.datetime(2024, 1, 1),
end_before=datetime.datetime(2024, 1, 2),
batch_size=100,
dry_run=False,
)
expected_stats = {
"batches": 1,
"total_messages": 10,
"filtered_messages": 5,
"total_deleted": 5,
}
service._clean_messages_by_time_range = MagicMock(return_value=expected_stats) # type: ignore[method-assign]
completion_calls: list[dict[str, object]] = []
service._metrics.record_completion = lambda **kwargs: completion_calls.append(kwargs) # type: ignore[method-assign]
# Act
result = service.run()
# Assert
assert result == expected_stats
assert len(completion_calls) == 1
assert completion_calls[0]["status"] == "success"
def test_run_records_completion_metrics_on_failure(self):
# Arrange
service = MessagesCleanService(
policy=BillingDisabledPolicy(),
start_from=datetime.datetime(2024, 1, 1),
end_before=datetime.datetime(2024, 1, 2),
batch_size=100,
dry_run=False,
)
service._clean_messages_by_time_range = MagicMock(side_effect=RuntimeError("clean failed")) # type: ignore[method-assign]
completion_calls: list[dict[str, object]] = []
service._metrics.record_completion = lambda **kwargs: completion_calls.append(kwargs) # type: ignore[method-assign]
# Act & Assert
with pytest.raises(RuntimeError, match="clean failed"):
service.run()
assert len(completion_calls) == 1
assert completion_calls[0]["status"] == "failed"

View File

@@ -6,13 +6,6 @@ from typing import Any
class ConfigHelper:
_LEGACY_SECTION_MAP = {
"admin_config": "admin",
"token_config": "auth",
"app_config": "app",
"api_key_config": "api_key",
}
"""Helper class for reading and writing configuration files."""
def __init__(self, base_dir: Path | None = None):
@@ -57,8 +50,14 @@ class ConfigHelper:
Dictionary containing config data, or None if file doesn't exist
"""
# Provide backward compatibility for old config names
if filename in self._LEGACY_SECTION_MAP:
return self.get_state_section(self._LEGACY_SECTION_MAP[filename])
if filename in ["admin_config", "token_config", "app_config", "api_key_config"]:
section_map = {
"admin_config": "admin",
"token_config": "auth",
"app_config": "app",
"api_key_config": "api_key",
}
return self.get_state_section(section_map[filename])
config_path = self.get_config_path(filename)
@@ -86,11 +85,14 @@ class ConfigHelper:
True if successful, False otherwise
"""
# Provide backward compatibility for old config names
if filename in self._LEGACY_SECTION_MAP:
return self.update_state_section(
self._LEGACY_SECTION_MAP[filename],
data,
)
if filename in ["admin_config", "token_config", "app_config", "api_key_config"]:
section_map = {
"admin_config": "admin",
"token_config": "auth",
"app_config": "app",
"api_key_config": "api_key",
}
return self.update_state_section(section_map[filename], data)
self.ensure_config_dir()
config_path = self.get_config_path(filename)

View File

@@ -2,12 +2,6 @@
- Refer to the `./docs/test.md` and `./docs/lint.md` for detailed frontend workflow instructions.
## Overlay Components (Mandatory)
- `./docs/overlay-migration.md` is the source of truth for overlay-related work.
- In new or modified code, use only overlay primitives from `@/app/components/base/ui/*`.
- Do not introduce deprecated overlay imports from `@/app/components/base/*`; when touching legacy callers, prefer migrating them and keep the allowlist shrinking (never expanding).
## Automated Test Generation
- Use `./docs/test.md` as the canonical instruction set for generating frontend automated tests.

View File

@@ -295,24 +295,7 @@ describe('Pricing Modal Flow', () => {
})
})
// ─── 6. Close Handling ───────────────────────────────────────────────────
describe('Close handling', () => {
it('should call onCancel when pressing ESC key', () => {
render(<Pricing onCancel={onCancel} />)
// ahooks useKeyPress listens on document for keydown events
document.dispatchEvent(new KeyboardEvent('keydown', {
key: 'Escape',
code: 'Escape',
keyCode: 27,
bubbles: true,
}))
expect(onCancel).toHaveBeenCalledTimes(1)
})
})
// ─── 7. Pricing URL ─────────────────────────────────────────────────────
// ─── 6. Pricing URL ─────────────────────────────────────────────────────
describe('Pricing page URL', () => {
it('should render pricing link with correct URL', () => {
render(<Pricing onCancel={onCancel} />)

View File

@@ -1,139 +0,0 @@
import * as amplitude from '@amplitude/analytics-browser'
import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser'
import { render } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import AmplitudeProvider, { isAmplitudeEnabled } from './AmplitudeProvider'
const mockConfig = vi.hoisted(() => ({
AMPLITUDE_API_KEY: 'test-api-key',
IS_CLOUD_EDITION: true,
}))
vi.mock('@/config', () => mockConfig)
vi.mock('@amplitude/analytics-browser', () => ({
init: vi.fn(),
add: vi.fn(),
}))
vi.mock('@amplitude/plugin-session-replay-browser', () => ({
sessionReplayPlugin: vi.fn(() => ({ name: 'session-replay' })),
}))
describe('AmplitudeProvider', () => {
beforeEach(() => {
vi.clearAllMocks()
mockConfig.AMPLITUDE_API_KEY = 'test-api-key'
mockConfig.IS_CLOUD_EDITION = true
})
describe('isAmplitudeEnabled', () => {
it('returns true when cloud edition and api key present', () => {
expect(isAmplitudeEnabled()).toBe(true)
})
it('returns false when cloud edition but no api key', () => {
mockConfig.AMPLITUDE_API_KEY = ''
expect(isAmplitudeEnabled()).toBe(false)
})
it('returns false when not cloud edition', () => {
mockConfig.IS_CLOUD_EDITION = false
expect(isAmplitudeEnabled()).toBe(false)
})
})
describe('Component', () => {
it('initializes amplitude when enabled', () => {
render(<AmplitudeProvider sessionReplaySampleRate={0.8} />)
expect(amplitude.init).toHaveBeenCalledWith('test-api-key', expect.any(Object))
expect(sessionReplayPlugin).toHaveBeenCalledWith({ sampleRate: 0.8 })
expect(amplitude.add).toHaveBeenCalledTimes(2)
})
it('does not initialize amplitude when disabled', () => {
mockConfig.AMPLITUDE_API_KEY = ''
render(<AmplitudeProvider />)
expect(amplitude.init).not.toHaveBeenCalled()
expect(amplitude.add).not.toHaveBeenCalled()
})
it('pageNameEnrichmentPlugin logic works as expected', async () => {
render(<AmplitudeProvider />)
const plugin = vi.mocked(amplitude.add).mock.calls[0]?.[0] as amplitude.Types.EnrichmentPlugin | undefined
expect(plugin).toBeDefined()
if (!plugin?.execute || !plugin.setup)
throw new Error('Expected page-name-enrichment plugin with setup/execute')
expect(plugin.name).toBe('page-name-enrichment')
const execute = plugin.execute
const setup = plugin.setup
type SetupFn = NonNullable<amplitude.Types.EnrichmentPlugin['setup']>
const getPageTitle = (evt: amplitude.Types.Event | null | undefined) =>
(evt?.event_properties as Record<string, unknown> | undefined)?.['[Amplitude] Page Title']
await setup(
{} as Parameters<SetupFn>[0],
{} as Parameters<SetupFn>[1],
)
const originalWindowLocation = window.location
try {
Object.defineProperty(window, 'location', {
value: { pathname: '/datasets' },
writable: true,
})
const event: amplitude.Types.Event = {
event_type: '[Amplitude] Page Viewed',
event_properties: {},
}
const result = await execute(event)
expect(getPageTitle(result)).toBe('Knowledge')
window.location.pathname = '/'
await execute(event)
expect(getPageTitle(event)).toBe('Home')
window.location.pathname = '/apps'
await execute(event)
expect(getPageTitle(event)).toBe('Studio')
window.location.pathname = '/explore'
await execute(event)
expect(getPageTitle(event)).toBe('Explore')
window.location.pathname = '/tools'
await execute(event)
expect(getPageTitle(event)).toBe('Tools')
window.location.pathname = '/account'
await execute(event)
expect(getPageTitle(event)).toBe('Account')
window.location.pathname = '/signin'
await execute(event)
expect(getPageTitle(event)).toBe('Sign In')
window.location.pathname = '/signup'
await execute(event)
expect(getPageTitle(event)).toBe('Sign Up')
window.location.pathname = '/unknown'
await execute(event)
expect(getPageTitle(event)).toBe('Unknown')
const otherEvent = {
event_type: 'Button Clicked',
event_properties: {},
} as amplitude.Types.Event
const otherResult = await execute(otherEvent)
expect(getPageTitle(otherResult)).toBeUndefined()
const noPropsEvent = {
event_type: '[Amplitude] Page Viewed',
} as amplitude.Types.Event
const noPropsResult = await execute(noPropsEvent)
expect(noPropsResult?.event_properties).toBeUndefined()
}
finally {
Object.defineProperty(window, 'location', {
value: originalWindowLocation,
writable: true,
})
}
})
})
})

View File

@@ -1,32 +0,0 @@
import { describe, expect, it } from 'vitest'
import AmplitudeProvider, { isAmplitudeEnabled } from './AmplitudeProvider'
import indexDefault, {
isAmplitudeEnabled as indexIsAmplitudeEnabled,
resetUser,
setUserId,
setUserProperties,
trackEvent,
} from './index'
import {
resetUser as utilsResetUser,
setUserId as utilsSetUserId,
setUserProperties as utilsSetUserProperties,
trackEvent as utilsTrackEvent,
} from './utils'
describe('Amplitude index exports', () => {
it('exports AmplitudeProvider as default', () => {
expect(indexDefault).toBe(AmplitudeProvider)
})
it('exports isAmplitudeEnabled', () => {
expect(indexIsAmplitudeEnabled).toBe(isAmplitudeEnabled)
})
it('exports utils', () => {
expect(resetUser).toBe(utilsResetUser)
expect(setUserId).toBe(utilsSetUserId)
expect(setUserProperties).toBe(utilsSetUserProperties)
expect(trackEvent).toBe(utilsTrackEvent)
})
})

View File

@@ -1,119 +0,0 @@
import { resetUser, setUserId, setUserProperties, trackEvent } from './utils'
const mockState = vi.hoisted(() => ({
enabled: true,
}))
const mockTrack = vi.hoisted(() => vi.fn())
const mockSetUserId = vi.hoisted(() => vi.fn())
const mockIdentify = vi.hoisted(() => vi.fn())
const mockReset = vi.hoisted(() => vi.fn())
const MockIdentify = vi.hoisted(() =>
class {
setCalls: Array<[string, unknown]> = []
set(key: string, value: unknown) {
this.setCalls.push([key, value])
return this
}
},
)
vi.mock('./AmplitudeProvider', () => ({
isAmplitudeEnabled: () => mockState.enabled,
}))
vi.mock('@amplitude/analytics-browser', () => ({
track: (...args: unknown[]) => mockTrack(...args),
setUserId: (...args: unknown[]) => mockSetUserId(...args),
identify: (...args: unknown[]) => mockIdentify(...args),
reset: (...args: unknown[]) => mockReset(...args),
Identify: MockIdentify,
}))
describe('amplitude utils', () => {
beforeEach(() => {
vi.clearAllMocks()
mockState.enabled = true
})
describe('trackEvent', () => {
it('should call amplitude.track when amplitude is enabled', () => {
trackEvent('dataset_created', { source: 'wizard' })
expect(mockTrack).toHaveBeenCalledTimes(1)
expect(mockTrack).toHaveBeenCalledWith('dataset_created', { source: 'wizard' })
})
it('should not call amplitude.track when amplitude is disabled', () => {
mockState.enabled = false
trackEvent('dataset_created', { source: 'wizard' })
expect(mockTrack).not.toHaveBeenCalled()
})
})
describe('setUserId', () => {
it('should call amplitude.setUserId when amplitude is enabled', () => {
setUserId('user-123')
expect(mockSetUserId).toHaveBeenCalledTimes(1)
expect(mockSetUserId).toHaveBeenCalledWith('user-123')
})
it('should not call amplitude.setUserId when amplitude is disabled', () => {
mockState.enabled = false
setUserId('user-123')
expect(mockSetUserId).not.toHaveBeenCalled()
})
})
describe('setUserProperties', () => {
it('should build identify event and call amplitude.identify when amplitude is enabled', () => {
const properties: Record<string, unknown> = {
role: 'owner',
seats: 3,
verified: true,
}
setUserProperties(properties)
expect(mockIdentify).toHaveBeenCalledTimes(1)
const identifyArg = mockIdentify.mock.calls[0][0] as InstanceType<typeof MockIdentify>
expect(identifyArg).toBeInstanceOf(MockIdentify)
expect(identifyArg.setCalls).toEqual([
['role', 'owner'],
['seats', 3],
['verified', true],
])
})
it('should not call amplitude.identify when amplitude is disabled', () => {
mockState.enabled = false
setUserProperties({ role: 'owner' })
expect(mockIdentify).not.toHaveBeenCalled()
})
})
describe('resetUser', () => {
it('should call amplitude.reset when amplitude is enabled', () => {
resetUser()
expect(mockReset).toHaveBeenCalledTimes(1)
})
it('should not call amplitude.reset when amplitude is disabled', () => {
mockState.enabled = false
resetUser()
expect(mockReset).not.toHaveBeenCalled()
})
})
})

View File

@@ -1,148 +0,0 @@
import { AudioPlayerManager } from '../audio.player.manager'
type AudioCallback = ((event: string) => void) | null
type AudioPlayerCtorArgs = [
string,
boolean,
string | undefined,
string | null | undefined,
string | undefined,
AudioCallback,
]
type MockAudioPlayerInstance = {
setCallback: ReturnType<typeof vi.fn>
pauseAudio: ReturnType<typeof vi.fn>
resetMsgId: ReturnType<typeof vi.fn>
cacheBuffers: Array<ArrayBuffer>
sourceBuffer: {
abort: ReturnType<typeof vi.fn>
} | undefined
}
const mockState = vi.hoisted(() => ({
instances: [] as MockAudioPlayerInstance[],
}))
const mockAudioPlayerConstructor = vi.hoisted(() => vi.fn())
const MockAudioPlayer = vi.hoisted(() => {
return class MockAudioPlayerClass {
setCallback = vi.fn()
pauseAudio = vi.fn()
resetMsgId = vi.fn()
cacheBuffers = [new ArrayBuffer(1)]
sourceBuffer = { abort: vi.fn() }
constructor(...args: AudioPlayerCtorArgs) {
mockAudioPlayerConstructor(...args)
mockState.instances.push(this as unknown as MockAudioPlayerInstance)
}
}
})
vi.mock('@/app/components/base/audio-btn/audio', () => ({
default: MockAudioPlayer,
}))
describe('AudioPlayerManager', () => {
beforeEach(() => {
vi.clearAllMocks()
mockState.instances = []
Reflect.set(AudioPlayerManager, 'instance', undefined)
})
describe('getInstance', () => {
it('should return the same singleton instance across calls', () => {
const first = AudioPlayerManager.getInstance()
const second = AudioPlayerManager.getInstance()
expect(first).toBe(second)
})
})
describe('getAudioPlayer', () => {
it('should create a new audio player when no existing player is cached', () => {
const manager = AudioPlayerManager.getInstance()
const callback = vi.fn()
const result = manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', callback)
expect(mockAudioPlayerConstructor).toHaveBeenCalledTimes(1)
expect(mockAudioPlayerConstructor).toHaveBeenCalledWith(
'/text-to-audio',
false,
'msg-1',
'hello',
'en-US',
callback,
)
expect(result).toBe(mockState.instances[0])
})
it('should reuse existing player and update callback when msg id is unchanged', () => {
const manager = AudioPlayerManager.getInstance()
const firstCallback = vi.fn()
const secondCallback = vi.fn()
const first = manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', firstCallback)
const second = manager.getAudioPlayer('/ignored', true, 'msg-1', 'ignored', 'fr-FR', secondCallback)
expect(mockAudioPlayerConstructor).toHaveBeenCalledTimes(1)
expect(first).toBe(second)
expect(mockState.instances[0].setCallback).toHaveBeenCalledTimes(1)
expect(mockState.instances[0].setCallback).toHaveBeenCalledWith(secondCallback)
})
it('should cleanup existing player and create a new one when msg id changes', () => {
const manager = AudioPlayerManager.getInstance()
const callback = vi.fn()
manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', callback)
const previous = mockState.instances[0]
const next = manager.getAudioPlayer('/apps/1/text-to-audio', false, 'msg-2', 'world', 'en-US', callback)
expect(previous.pauseAudio).toHaveBeenCalledTimes(1)
expect(previous.cacheBuffers).toEqual([])
expect(previous.sourceBuffer?.abort).toHaveBeenCalledTimes(1)
expect(mockAudioPlayerConstructor).toHaveBeenCalledTimes(2)
expect(next).toBe(mockState.instances[1])
})
it('should swallow cleanup errors and still create a new player', () => {
const manager = AudioPlayerManager.getInstance()
const callback = vi.fn()
manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', callback)
const previous = mockState.instances[0]
previous.pauseAudio.mockImplementation(() => {
throw new Error('cleanup failure')
})
expect(() => {
manager.getAudioPlayer('/apps/1/text-to-audio', false, 'msg-2', 'world', 'en-US', callback)
}).not.toThrow()
expect(previous.pauseAudio).toHaveBeenCalledTimes(1)
expect(mockAudioPlayerConstructor).toHaveBeenCalledTimes(2)
})
})
describe('resetMsgId', () => {
it('should forward reset message id to the cached audio player when present', () => {
const manager = AudioPlayerManager.getInstance()
const callback = vi.fn()
manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', callback)
manager.resetMsgId('msg-updated')
expect(mockState.instances[0].resetMsgId).toHaveBeenCalledTimes(1)
expect(mockState.instances[0].resetMsgId).toHaveBeenCalledWith('msg-updated')
})
it('should not throw when resetting message id without an audio player', () => {
const manager = AudioPlayerManager.getInstance()
expect(() => manager.resetMsgId('msg-updated')).not.toThrow()
})
})
})

View File

@@ -1,610 +0,0 @@
import { Buffer } from 'node:buffer'
import { waitFor } from '@testing-library/react'
import { AppSourceType } from '@/service/share'
import AudioPlayer from '../audio'
const mockToastNotify = vi.hoisted(() => vi.fn())
const mockTextToAudioStream = vi.hoisted(() => vi.fn())
vi.mock('@/app/components/base/toast', () => ({
default: {
notify: (...args: unknown[]) => mockToastNotify(...args),
},
}))
vi.mock('@/service/share', () => ({
AppSourceType: {
webApp: 'webApp',
installedApp: 'installedApp',
},
textToAudioStream: (...args: unknown[]) => mockTextToAudioStream(...args),
}))
type AudioEventName = 'ended' | 'paused' | 'loaded' | 'play' | 'timeupdate' | 'loadeddate' | 'canplay' | 'error' | 'sourceopen'
type AudioEventListener = () => void
type ReaderResult = {
value: Uint8Array | undefined
done: boolean
}
type Reader = {
read: () => Promise<ReaderResult>
}
type AudioResponse = {
status: number
body: {
getReader: () => Reader
}
}
class MockSourceBuffer {
updating = false
appendBuffer = vi.fn((_buffer: ArrayBuffer) => undefined)
abort = vi.fn(() => undefined)
}
class MockMediaSource {
readyState: 'open' | 'closed' = 'open'
sourceBuffer = new MockSourceBuffer()
private listeners: Partial<Record<AudioEventName, AudioEventListener[]>> = {}
addEventListener = vi.fn((event: AudioEventName, listener: AudioEventListener) => {
const listeners = this.listeners[event] || []
listeners.push(listener)
this.listeners[event] = listeners
})
addSourceBuffer = vi.fn((_contentType: string) => this.sourceBuffer)
endOfStream = vi.fn(() => undefined)
emit(event: AudioEventName) {
const listeners = this.listeners[event] || []
listeners.forEach((listener) => {
listener()
})
}
}
class MockAudio {
src = ''
autoplay = false
disableRemotePlayback = false
controls = false
paused = true
ended = false
played: unknown = null
private listeners: Partial<Record<AudioEventName, AudioEventListener[]>> = {}
addEventListener = vi.fn((event: AudioEventName, listener: AudioEventListener) => {
const listeners = this.listeners[event] || []
listeners.push(listener)
this.listeners[event] = listeners
})
play = vi.fn(async () => {
this.paused = false
})
pause = vi.fn(() => {
this.paused = true
})
emit(event: AudioEventName) {
const listeners = this.listeners[event] || []
listeners.forEach((listener) => {
listener()
})
}
}
class MockAudioContext {
state: 'running' | 'suspended' = 'running'
destination = {}
connect = vi.fn(() => undefined)
createMediaElementSource = vi.fn((_audio: MockAudio) => ({
connect: this.connect,
}))
resume = vi.fn(async () => {
this.state = 'running'
})
suspend = vi.fn(() => {
this.state = 'suspended'
})
}
const testState = {
mediaSources: [] as MockMediaSource[],
audios: [] as MockAudio[],
audioContexts: [] as MockAudioContext[],
}
class MockMediaSourceCtor extends MockMediaSource {
constructor() {
super()
testState.mediaSources.push(this)
}
}
class MockAudioCtor extends MockAudio {
constructor() {
super()
testState.audios.push(this)
}
}
class MockAudioContextCtor extends MockAudioContext {
constructor() {
super()
testState.audioContexts.push(this)
}
}
const originalAudio = globalThis.Audio
const originalAudioContext = globalThis.AudioContext
const originalCreateObjectURL = globalThis.URL.createObjectURL
const originalMediaSource = window.MediaSource
const originalManagedMediaSource = window.ManagedMediaSource
const setMediaSourceSupport = (options: { mediaSource: boolean, managedMediaSource: boolean }) => {
Object.defineProperty(window, 'MediaSource', {
configurable: true,
writable: true,
value: options.mediaSource ? MockMediaSourceCtor : undefined,
})
Object.defineProperty(window, 'ManagedMediaSource', {
configurable: true,
writable: true,
value: options.managedMediaSource ? MockMediaSourceCtor : undefined,
})
}
const makeAudioResponse = (status: number, reads: ReaderResult[]): AudioResponse => {
const read = vi.fn<() => Promise<ReaderResult>>()
reads.forEach((result) => {
read.mockResolvedValueOnce(result)
})
return {
status,
body: {
getReader: () => ({ read }),
},
}
}
describe('AudioPlayer', () => {
beforeEach(() => {
vi.clearAllMocks()
testState.mediaSources = []
testState.audios = []
testState.audioContexts = []
Object.defineProperty(globalThis, 'Audio', {
configurable: true,
writable: true,
value: MockAudioCtor,
})
Object.defineProperty(globalThis, 'AudioContext', {
configurable: true,
writable: true,
value: MockAudioContextCtor,
})
Object.defineProperty(globalThis.URL, 'createObjectURL', {
configurable: true,
writable: true,
value: vi.fn(() => 'blob:mock-url'),
})
setMediaSourceSupport({ mediaSource: true, managedMediaSource: false })
})
afterAll(() => {
Object.defineProperty(globalThis, 'Audio', {
configurable: true,
writable: true,
value: originalAudio,
})
Object.defineProperty(globalThis, 'AudioContext', {
configurable: true,
writable: true,
value: originalAudioContext,
})
Object.defineProperty(globalThis.URL, 'createObjectURL', {
configurable: true,
writable: true,
value: originalCreateObjectURL,
})
Object.defineProperty(window, 'MediaSource', {
configurable: true,
writable: true,
value: originalMediaSource,
})
Object.defineProperty(window, 'ManagedMediaSource', {
configurable: true,
writable: true,
value: originalManagedMediaSource,
})
})
describe('constructor behavior', () => {
it('should initialize media source, audio, and media element source when MediaSource exists', () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
const audioContext = testState.audioContexts[0]
const mediaSource = testState.mediaSources[0]
expect(player.mediaSource).toBe(mediaSource as unknown as MediaSource)
expect(globalThis.URL.createObjectURL).toHaveBeenCalledTimes(1)
expect(audio.src).toBe('blob:mock-url')
expect(audio.autoplay).toBe(true)
expect(audioContext.createMediaElementSource).toHaveBeenCalledWith(audio)
expect(audioContext.connect).toHaveBeenCalledTimes(1)
})
it('should notify unsupported browser when no MediaSource implementation exists', () => {
setMediaSourceSupport({ mediaSource: false, managedMediaSource: false })
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null)
const audio = testState.audios[0]
expect(player.mediaSource).toBeNull()
expect(audio.src).toBe('')
expect(mockToastNotify).toHaveBeenCalledTimes(1)
expect(mockToastNotify).toHaveBeenCalledWith(
expect.objectContaining({
type: 'error',
}),
)
})
it('should configure fallback audio controls when ManagedMediaSource is used', () => {
setMediaSourceSupport({ mediaSource: false, managedMediaSource: true })
// Create with callback to ensure constructor path completes with fallback source.
const player = new AudioPlayer('/text-to-audio', false, 'msg-1', 'hello', undefined, vi.fn())
const audio = testState.audios[0]
expect(player.mediaSource).not.toBeNull()
expect(audio.disableRemotePlayback).toBe(true)
expect(audio.controls).toBe(true)
})
})
describe('event wiring', () => {
it('should forward registered audio events to callback', () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
audio.emit('play')
audio.emit('ended')
audio.emit('error')
audio.emit('paused')
audio.emit('loaded')
audio.emit('timeupdate')
audio.emit('loadeddate')
audio.emit('canplay')
expect(player.callback).toBe(callback)
expect(callback).toHaveBeenCalledWith('play')
expect(callback).toHaveBeenCalledWith('ended')
expect(callback).toHaveBeenCalledWith('error')
expect(callback).toHaveBeenCalledWith('paused')
expect(callback).toHaveBeenCalledWith('loaded')
expect(callback).toHaveBeenCalledWith('timeupdate')
expect(callback).toHaveBeenCalledWith('loadeddate')
expect(callback).toHaveBeenCalledWith('canplay')
})
it('should initialize source buffer only once when sourceopen fires multiple times', () => {
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', vi.fn())
const mediaSource = testState.mediaSources[0]
mediaSource.emit('sourceopen')
mediaSource.emit('sourceopen')
expect(mediaSource.addSourceBuffer).toHaveBeenCalledTimes(1)
expect(player.sourceBuffer).toBe(mediaSource.sourceBuffer)
})
})
describe('playback control', () => {
it('should request streaming audio when playAudio is called before loading', async () => {
mockTextToAudioStream.mockResolvedValue(
makeAudioResponse(200, [
{ value: new Uint8Array([4, 5]), done: false },
{ value: new Uint8Array([1, 2, 3]), done: true },
]),
)
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', vi.fn())
player.playAudio()
await waitFor(() => {
expect(mockTextToAudioStream).toHaveBeenCalledTimes(1)
})
expect(mockTextToAudioStream).toHaveBeenCalledWith(
'/text-to-audio',
AppSourceType.webApp,
{ content_type: 'audio/mpeg' },
{
message_id: 'msg-1',
streaming: true,
voice: 'en-US',
text: 'hello',
},
)
expect(player.isLoadData).toBe(true)
})
it('should emit error callback and reset load flag when stream response status is not 200', async () => {
const callback = vi.fn()
mockTextToAudioStream.mockResolvedValue(
makeAudioResponse(500, [{ value: new Uint8Array([1]), done: true }]),
)
const player = new AudioPlayer('/text-to-audio', false, 'msg-2', 'world', undefined, callback)
player.playAudio()
await waitFor(() => {
expect(callback).toHaveBeenCalledWith('error')
})
expect(player.isLoadData).toBe(false)
})
it('should resume and play immediately when playAudio is called in suspended loaded state', async () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', false, 'msg-1', 'hello', undefined, callback)
const audio = testState.audios[0]
const audioContext = testState.audioContexts[0]
player.isLoadData = true
audioContext.state = 'suspended'
player.playAudio()
await Promise.resolve()
expect(audioContext.resume).toHaveBeenCalledTimes(1)
expect(audio.play).toHaveBeenCalledTimes(1)
expect(callback).toHaveBeenCalledWith('play')
})
it('should play ended audio when data is already loaded', () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', false, 'msg-1', 'hello', undefined, callback)
const audio = testState.audios[0]
const audioContext = testState.audioContexts[0]
player.isLoadData = true
audioContext.state = 'running'
audio.ended = true
player.playAudio()
expect(audio.play).toHaveBeenCalledTimes(1)
expect(callback).toHaveBeenCalledWith('play')
})
it('should only emit play callback without replaying when loaded audio is already playing', () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', false, 'msg-1', 'hello', undefined, callback)
const audio = testState.audios[0]
const audioContext = testState.audioContexts[0]
player.isLoadData = true
audioContext.state = 'running'
audio.ended = false
player.playAudio()
expect(audio.play).not.toHaveBeenCalled()
expect(callback).toHaveBeenCalledWith('play')
})
it('should emit error callback when stream request throws', async () => {
const callback = vi.fn()
mockTextToAudioStream.mockRejectedValue(new Error('network failed'))
const player = new AudioPlayer('/text-to-audio', false, 'msg-2', 'world', undefined, callback)
player.playAudio()
await waitFor(() => {
expect(callback).toHaveBeenCalledWith('error')
})
expect(player.isLoadData).toBe(false)
})
it('should call pause flow and notify paused event when pauseAudio is invoked', () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
const audioContext = testState.audioContexts[0]
player.pauseAudio()
expect(callback).toHaveBeenCalledWith('paused')
expect(audio.pause).toHaveBeenCalledTimes(1)
expect(audioContext.suspend).toHaveBeenCalledTimes(1)
})
})
describe('message and direct-audio helpers', () => {
it('should update message id through resetMsgId', () => {
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null)
player.resetMsgId('msg-2')
expect(player.msgId).toBe('msg-2')
})
it('should end stream without playback when playAudioWithAudio receives empty content', async () => {
vi.useFakeTimers()
try {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const mediaSource = testState.mediaSources[0]
await player.playAudioWithAudio('', true)
await vi.advanceTimersByTimeAsync(40)
expect(player.isLoadData).toBe(false)
expect(player.cacheBuffers).toHaveLength(0)
expect(mediaSource.endOfStream).toHaveBeenCalledTimes(1)
expect(callback).not.toHaveBeenCalledWith('play')
}
finally {
vi.useRealTimers()
}
})
it('should decode base64 and start playback when playAudioWithAudio is called with playable content', async () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
const audioContext = testState.audioContexts[0]
const mediaSource = testState.mediaSources[0]
const audioBase64 = Buffer.from('hello').toString('base64')
mediaSource.emit('sourceopen')
audio.paused = true
await player.playAudioWithAudio(audioBase64, true)
await Promise.resolve()
expect(player.isLoadData).toBe(true)
expect(player.cacheBuffers).toHaveLength(0)
expect(mediaSource.sourceBuffer.appendBuffer).toHaveBeenCalledTimes(1)
const appendedAudioData = mediaSource.sourceBuffer.appendBuffer.mock.calls[0][0]
expect(appendedAudioData).toBeInstanceOf(ArrayBuffer)
expect(appendedAudioData.byteLength).toBeGreaterThan(0)
expect(audioContext.resume).toHaveBeenCalledTimes(1)
expect(audio.play).toHaveBeenCalledTimes(1)
expect(callback).toHaveBeenCalledWith('play')
})
it('should skip playback when playAudioWithAudio is called with play=false', async () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
const audioContext = testState.audioContexts[0]
await player.playAudioWithAudio(Buffer.from('hello').toString('base64'), false)
expect(player.isLoadData).toBe(false)
expect(audioContext.resume).not.toHaveBeenCalled()
expect(audio.play).not.toHaveBeenCalled()
expect(callback).not.toHaveBeenCalledWith('play')
})
it('should play immediately for ended audio in playAudioWithAudio', async () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
audio.paused = false
audio.ended = true
await player.playAudioWithAudio(Buffer.from('hello').toString('base64'), true)
expect(audio.play).toHaveBeenCalledTimes(1)
expect(callback).toHaveBeenCalledWith('play')
})
it('should not replay when played list exists in playAudioWithAudio', async () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
audio.paused = false
audio.ended = false
audio.played = {}
await player.playAudioWithAudio(Buffer.from('hello').toString('base64'), true)
expect(audio.play).not.toHaveBeenCalled()
expect(callback).not.toHaveBeenCalledWith('play')
})
it('should replay when paused is false and played list is empty in playAudioWithAudio', async () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
audio.paused = false
audio.ended = false
audio.played = null
await player.playAudioWithAudio(Buffer.from('hello').toString('base64'), true)
expect(audio.play).toHaveBeenCalledTimes(1)
expect(callback).toHaveBeenCalledWith('play')
})
})
describe('buffering internals', () => {
it('should finish stream when receiveAudioData gets an undefined chunk', () => {
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null)
const finishStream = vi
.spyOn(player as unknown as { finishStream: () => void }, 'finishStream')
.mockImplementation(() => { })
; (player as unknown as { receiveAudioData: (data: Uint8Array | undefined) => void }).receiveAudioData(undefined)
expect(finishStream).toHaveBeenCalledTimes(1)
})
it('should finish stream when receiveAudioData gets empty bytes while source is open', () => {
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null)
const finishStream = vi
.spyOn(player as unknown as { finishStream: () => void }, 'finishStream')
.mockImplementation(() => { })
; (player as unknown as { receiveAudioData: (data: Uint8Array) => void }).receiveAudioData(new Uint8Array(0))
expect(finishStream).toHaveBeenCalledTimes(1)
})
it('should queue incoming buffer when source buffer is updating', () => {
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null)
const mediaSource = testState.mediaSources[0]
mediaSource.emit('sourceopen')
mediaSource.sourceBuffer.updating = true
; (player as unknown as { receiveAudioData: (data: Uint8Array) => void }).receiveAudioData(new Uint8Array([1, 2, 3]))
expect(player.cacheBuffers.length).toBe(1)
})
it('should append previously queued buffer before new one when source buffer is idle', () => {
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null)
const mediaSource = testState.mediaSources[0]
mediaSource.emit('sourceopen')
const existingBuffer = new ArrayBuffer(2)
player.cacheBuffers = [existingBuffer]
mediaSource.sourceBuffer.updating = false
; (player as unknown as { receiveAudioData: (data: Uint8Array) => void }).receiveAudioData(new Uint8Array([9]))
expect(mediaSource.sourceBuffer.appendBuffer).toHaveBeenCalledTimes(1)
expect(mediaSource.sourceBuffer.appendBuffer).toHaveBeenCalledWith(existingBuffer)
expect(player.cacheBuffers.length).toBe(1)
})
it('should append cache chunks and end stream when finishStream drains buffers', () => {
vi.useFakeTimers()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null)
const mediaSource = testState.mediaSources[0]
mediaSource.emit('sourceopen')
mediaSource.sourceBuffer.updating = false
player.cacheBuffers = [new ArrayBuffer(3)]
; (player as unknown as { finishStream: () => void }).finishStream()
vi.advanceTimersByTime(50)
expect(mediaSource.sourceBuffer.appendBuffer).toHaveBeenCalledTimes(1)
expect(mediaSource.endOfStream).toHaveBeenCalledTimes(1)
vi.useRealTimers()
})
})
})

View File

@@ -26,7 +26,6 @@ const AudioPlayer: React.FC<AudioPlayerProps> = ({ src, srcs }) => {
useEffect(() => {
const audio = audioRef.current
/* v8 ignore next 2 - @preserve */
if (!audio)
return
@@ -218,7 +217,6 @@ const AudioPlayer: React.FC<AudioPlayerProps> = ({ src, srcs }) => {
const drawWaveform = useCallback(() => {
const canvas = canvasRef.current
/* v8 ignore next 2 - @preserve */
if (!canvas)
return
@@ -270,20 +268,14 @@ const AudioPlayer: React.FC<AudioPlayerProps> = ({ src, srcs }) => {
drawWaveform()
}, [drawWaveform, bufferedTime, hasStartedPlaying])
const handleMouseMove = useCallback((e: React.MouseEvent<HTMLCanvasElement> | React.TouchEvent<HTMLCanvasElement>) => {
const handleMouseMove = useCallback((e: React.MouseEvent) => {
const canvas = canvasRef.current
const audio = audioRef.current
if (!canvas || !audio)
return
const clientX = 'touches' in e
? e.touches[0]?.clientX ?? e.changedTouches[0]?.clientX
: e.clientX
if (clientX === undefined)
return
const rect = canvas.getBoundingClientRect()
const percent = Math.min(Math.max(0, clientX - rect.left), rect.width) / rect.width
const percent = Math.min(Math.max(0, e.clientX - rect.left), rect.width) / rect.width
const time = percent * duration
// Check if the hovered position is within a buffered range before updating hoverTime
@@ -297,7 +289,7 @@ const AudioPlayer: React.FC<AudioPlayerProps> = ({ src, srcs }) => {
return (
<div className="flex h-9 min-w-[240px] max-w-[420px] items-center gap-2 rounded-[10px] border border-components-panel-border-subtle bg-components-chat-input-audio-bg-alt p-2 shadow-xs backdrop-blur-sm">
<audio ref={audioRef} src={src} preload="auto" data-testid="audio-player">
<audio ref={audioRef} src={src} preload="auto">
{/* If srcs array is provided, render multiple source elements */}
{srcs && srcs.map((srcUrl, index) => (
<source key={index} src={srcUrl} />
@@ -305,8 +297,12 @@ const AudioPlayer: React.FC<AudioPlayerProps> = ({ src, srcs }) => {
</audio>
<button type="button" data-testid="play-pause-btn" className="inline-flex shrink-0 cursor-pointer items-center justify-center border-none text-text-accent transition-all hover:text-text-accent-secondary disabled:text-components-button-primary-bg-disabled" onClick={togglePlay} disabled={!isAudioAvailable}>
{isPlaying
? (<div className="i-ri-pause-circle-fill h-5 w-5" />)
: (<div className="i-ri-play-large-fill h-5 w-5" />)}
? (
<div className="i-ri-pause-circle-fill h-5 w-5" />
)
: (
<div className="i-ri-play-large-fill h-5 w-5" />
)}
</button>
<div className={cn(isAudioAvailable && 'grow')} hidden={!isAudioAvailable}>
<div className="flex h-8 items-center justify-center">
@@ -317,8 +313,6 @@ const AudioPlayer: React.FC<AudioPlayerProps> = ({ src, srcs }) => {
onClick={handleCanvasInteraction}
onMouseMove={handleMouseMove}
onMouseDown={handleCanvasInteraction}
onTouchMove={handleMouseMove}
onTouchStart={handleCanvasInteraction}
/>
<div className="inline-flex min-w-[50px] items-center justify-center text-text-accent-secondary system-xs-medium">
<span className="rounded-[10px] px-0.5 py-1">{formatTime(duration)}</span>

View File

@@ -1,7 +1,8 @@
import type { ToastHandle } from '@/app/components/base/toast'
import { act, fireEvent, render, screen } from '@testing-library/react'
import Toast from '@/app/components/base/toast'
import * as React from 'react'
import { vi } from 'vitest'
import useThemeMock from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import AudioPlayer from '../AudioPlayer'
@@ -44,13 +45,6 @@ async function advanceWaveformTimer() {
})
}
// eslint-disable-next-line ts/no-explicit-any
type ReactEventHandler = ((...args: any[]) => void) | undefined
function getReactProps<T extends Element>(el: T): Record<string, ReactEventHandler> {
const key = Object.keys(el).find(k => k.startsWith('__reactProps$'))
return key ? (el as unknown as Record<string, Record<string, ReactEventHandler>>)[key] : {}
}
// ─── Setup / teardown ─────────────────────────────────────────────────────────
beforeEach(() => {
@@ -62,12 +56,8 @@ beforeEach(() => {
HTMLMediaElement.prototype.load = vi.fn()
})
afterEach(async () => {
await act(async () => {
vi.runOnlyPendingTimers()
await Promise.resolve()
await Promise.resolve()
})
afterEach(() => {
vi.runOnlyPendingTimers()
vi.useRealTimers()
vi.unstubAllGlobals()
})
@@ -310,47 +300,36 @@ describe('AudioPlayer — waveform generation', () => {
expect(screen.getByTestId('waveform-canvas')).toBeInTheDocument()
})
it('should use webkitAudioContext when AudioContext is unavailable', async () => {
vi.stubGlobal('AudioContext', undefined)
vi.stubGlobal('webkitAudioContext', buildAudioContext(320))
stubFetchOk(256)
render(<AudioPlayer src="https://cdn.example/audio.mp3" />)
await advanceWaveformTimer()
expect(screen.getByTestId('waveform-canvas')).toBeInTheDocument()
})
})
// ─── Canvas interactions ──────────────────────────────────────────────────────
async function renderWithDuration(src = 'https://example.com/audio.mp3', durationVal = 120) {
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src={src} />)
const audio = document.querySelector('audio') as HTMLAudioElement
Object.defineProperty(audio, 'duration', { value: durationVal, configurable: true })
Object.defineProperty(audio, 'buffered', {
value: { length: 1, start: () => 0, end: () => durationVal },
configurable: true,
})
await act(async () => {
audio.dispatchEvent(new Event('loadedmetadata'))
})
await advanceWaveformTimer()
const canvas = screen.getByTestId('waveform-canvas') as HTMLCanvasElement
canvas.getBoundingClientRect = () =>
({ left: 0, width: 200, top: 0, height: 10, right: 200, bottom: 10 }) as DOMRect
return { audio, canvas }
}
describe('AudioPlayer — canvas seek interactions', () => {
async function renderWithDuration(src = 'https://example.com/audio.mp3', durationVal = 120) {
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src={src} />)
const audio = document.querySelector('audio') as HTMLAudioElement
Object.defineProperty(audio, 'duration', { value: durationVal, configurable: true })
Object.defineProperty(audio, 'buffered', {
value: { length: 1, start: () => 0, end: () => durationVal },
configurable: true,
})
await act(async () => {
audio.dispatchEvent(new Event('loadedmetadata'))
})
await advanceWaveformTimer()
const canvas = screen.getByTestId('waveform-canvas') as HTMLCanvasElement
canvas.getBoundingClientRect = () =>
({ left: 0, width: 200, top: 0, height: 10, right: 200, bottom: 10 }) as DOMRect
return { audio, canvas }
}
it('should seek to clicked position and start playback', async () => {
const { audio, canvas } = await renderWithDuration()
@@ -413,309 +392,3 @@ describe('AudioPlayer — canvas seek interactions', () => {
})
})
})
// ─── Missing coverage tests ───────────────────────────────────────────────────
describe('AudioPlayer — missing coverage', () => {
it('should handle unmounting without crashing (clears timeout)', () => {
const { unmount } = render(<AudioPlayer src="https://example.com/a.mp3" />)
unmount()
// Timer is cleared, no state update should happen after unmount
})
it('should handle getContext returning null safely', () => {
const originalGetContext = HTMLCanvasElement.prototype.getContext
HTMLCanvasElement.prototype.getContext = vi.fn().mockReturnValue(null)
render(<AudioPlayer src="https://example.com/audio.mp3" />)
expect(screen.getByTestId('waveform-canvas')).toBeInTheDocument()
HTMLCanvasElement.prototype.getContext = originalGetContext
})
it('should fallback to fillRect when roundRect is missing in drawWaveform', async () => {
// Note: React 18 / testing-library wraps updates automatically, but we still wait for advanceWaveformTimer
const originalGetContext = HTMLCanvasElement.prototype.getContext
let fillRectCalled = false
HTMLCanvasElement.prototype.getContext = function (this: HTMLCanvasElement, ...args: Parameters<typeof HTMLCanvasElement.prototype.getContext>) {
const ctx = originalGetContext.apply(this, args) as CanvasRenderingContext2D | null
if (ctx) {
Object.defineProperty(ctx, 'roundRect', { value: undefined, configurable: true })
const origFillRect = ctx.fillRect
ctx.fillRect = function (...fArgs: Parameters<CanvasRenderingContext2D['fillRect']>) {
fillRectCalled = true
return origFillRect.apply(this, fArgs)
}
}
return ctx as CanvasRenderingContext2D
} as typeof HTMLCanvasElement.prototype.getContext
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src="https://example.com/audio.mp3" />)
await advanceWaveformTimer()
expect(fillRectCalled).toBe(true)
HTMLCanvasElement.prototype.getContext = originalGetContext
})
it('should handle play error gracefully when togglePlay is clicked', async () => {
const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => { })
vi.spyOn(HTMLMediaElement.prototype, 'play').mockRejectedValue(new Error('play failed'))
render(<AudioPlayer src="https://example.com/audio.mp3" />)
const btn = screen.getByTestId('play-pause-btn')
await act(async () => {
fireEvent.click(btn)
})
expect(errorSpy).toHaveBeenCalled()
errorSpy.mockRestore()
})
it('should notify error when audio.play() fails during canvas seek', async () => {
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src="https://example.com/audio.mp3" />)
await advanceWaveformTimer()
const canvas = screen.getByTestId('waveform-canvas') as HTMLCanvasElement
const audio = document.querySelector('audio') as HTMLAudioElement
Object.defineProperty(audio, 'duration', { value: 120, configurable: true })
canvas.getBoundingClientRect = () => ({ left: 0, width: 200, top: 0, height: 10, right: 200, bottom: 10 }) as DOMRect
vi.spyOn(HTMLMediaElement.prototype, 'play').mockRejectedValue(new Error('play failed'))
await act(async () => {
fireEvent.click(canvas, { clientX: 100 })
})
// We can observe the error by checking document body for toast if Toast acts synchronously
// Or we just ensure the execution branched into catch naturally.
expect(HTMLMediaElement.prototype.play).toHaveBeenCalled()
})
it('should support touch events on canvas', async () => {
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src="https://example.com/audio.mp3" />)
await advanceWaveformTimer()
const canvas = screen.getByTestId('waveform-canvas') as HTMLCanvasElement
const audio = document.querySelector('audio') as HTMLAudioElement
Object.defineProperty(audio, 'duration', { value: 120, configurable: true })
canvas.getBoundingClientRect = () => ({ left: 0, width: 200, top: 0, height: 10, right: 200, bottom: 10 }) as DOMRect
await act(async () => {
// Use touch events
fireEvent.touchStart(canvas, {
touches: [{ clientX: 50 }],
})
})
expect(HTMLMediaElement.prototype.play).toHaveBeenCalled()
})
it('should gracefully handle interaction when canvas/audio refs are null', async () => {
const { unmount } = render(<AudioPlayer src="https://example.com/audio.mp3" />)
const canvas = screen.getByTestId('waveform-canvas')
unmount()
expect(canvas).toBeTruthy()
})
it('should keep play button disabled when source is unavailable', async () => {
vi.stubGlobal('AudioContext', buildAudioContext(300))
const toastSpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({} as unknown as ToastHandle))
render(<AudioPlayer src="blob:https://example.com" />)
await advanceWaveformTimer() // sets isAudioAvailable to false (invalid protocol)
const btn = screen.getByTestId('play-pause-btn')
await act(async () => {
fireEvent.click(btn)
})
expect(btn).toBeDisabled()
expect(HTMLMediaElement.prototype.play).not.toHaveBeenCalled()
expect(toastSpy).not.toHaveBeenCalled()
toastSpy.mockRestore()
})
it('should notify when toggle is invoked while audio is unavailable', async () => {
const toastSpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({} as unknown as ToastHandle))
render(<AudioPlayer src="https://example.com/a.mp3" />)
const audio = document.querySelector('audio') as HTMLAudioElement
await act(async () => {
audio.dispatchEvent(new Event('error'))
})
const btn = screen.getByTestId('play-pause-btn')
const props = getReactProps(btn)
await act(async () => {
props.onClick?.()
})
expect(toastSpy).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
message: 'Audio element not found',
}))
toastSpy.mockRestore()
})
})
describe('AudioPlayer — additional branch coverage', () => {
it('should render multiple source elements when srcs is provided', () => {
render(<AudioPlayer srcs={['a.mp3', 'b.ogg']} />)
const audio = screen.getByTestId('audio-player')
const sources = audio.querySelectorAll('source')
expect(sources).toHaveLength(2)
})
it('should handle handleMouseMove with empty touch list', async () => {
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src="https://example.com/a.mp3" />)
await advanceWaveformTimer()
const canvas = screen.getByTestId('waveform-canvas')
await act(async () => {
fireEvent.touchMove(canvas, {
touches: [],
changedTouches: [{ clientX: 50 }],
})
})
})
it('should handle handleMouseMove with missing clientX', async () => {
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src="https://example.com/a.mp3" />)
await advanceWaveformTimer()
const canvas = screen.getByTestId('waveform-canvas')
await act(async () => {
fireEvent.touchMove(canvas, {
touches: [{}] as unknown as TouchList,
})
})
})
it('should render "Audio source unavailable" when isAudioAvailable is false', async () => {
render(<AudioPlayer src="https://example.com/a.mp3" />)
const audio = document.querySelector('audio') as HTMLAudioElement
await act(async () => {
audio.dispatchEvent(new Event('error'))
})
expect(screen.queryByTestId('play-pause-btn')).toBeDisabled()
})
it('should update current time on timeupdate event', async () => {
render(<AudioPlayer src="https://example.com/a.mp3" />)
const audio = document.querySelector('audio') as HTMLAudioElement
Object.defineProperty(audio, 'currentTime', { value: 10, configurable: true })
await act(async () => {
audio.dispatchEvent(new Event('timeupdate'))
})
})
it('should ignore toggle click after audio error marks source unavailable', async () => {
const toastSpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({} as unknown as ToastHandle))
render(<AudioPlayer src="https://example.com/a.mp3" />)
const audio = document.querySelector('audio') as HTMLAudioElement
await act(async () => {
audio.dispatchEvent(new Event('error'))
})
const btn = screen.getByTestId('play-pause-btn')
await act(async () => {
fireEvent.click(btn)
})
expect(btn).toBeDisabled()
expect(HTMLMediaElement.prototype.play).not.toHaveBeenCalled()
expect(toastSpy).not.toHaveBeenCalled()
toastSpy.mockRestore()
})
it('should cover Dark theme waveform states', async () => {
; (useThemeMock as ReturnType<typeof vi.fn>).mockReturnValue({ theme: Theme.dark })
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src="https://example.com/audio.mp3" />)
const audio = document.querySelector('audio') as HTMLAudioElement
Object.defineProperty(audio, 'duration', { value: 100, configurable: true })
Object.defineProperty(audio, 'currentTime', { value: 50, configurable: true })
await act(async () => {
audio.dispatchEvent(new Event('loadedmetadata'))
audio.dispatchEvent(new Event('timeupdate'))
})
await advanceWaveformTimer()
expect(screen.getByTestId('waveform-canvas')).toBeInTheDocument()
})
it('should handle missing canvas/audio in handleCanvasInteraction/handleMouseMove', async () => {
const { unmount } = render(<AudioPlayer src="https://example.com/a.mp3" />)
const canvas = screen.getByTestId('waveform-canvas')
unmount()
fireEvent.click(canvas)
fireEvent.mouseMove(canvas)
})
it('should cover waveform branches for hover and played states', async () => {
const { audio, canvas } = await renderWithDuration('https://example.com/a.mp3', 100)
// Set some progress
Object.defineProperty(audio, 'currentTime', { value: 20, configurable: true })
// Trigger hover on a buffered range
Object.defineProperty(audio, 'buffered', {
value: { length: 1, start: () => 0, end: () => 100 },
configurable: true,
})
await act(async () => {
fireEvent.mouseMove(canvas, { clientX: 50 }) // 50s hover
audio.dispatchEvent(new Event('timeupdate'))
})
expect(canvas).toBeInTheDocument()
})
it('should hit null-ref guards in canvas handlers after unmount', async () => {
const { unmount } = render(<AudioPlayer src="https://example.com/a.mp3" />)
const canvas = screen.getByTestId('waveform-canvas')
const props = getReactProps(canvas)
unmount()
await act(async () => {
props.onClick?.({ preventDefault: vi.fn(), clientX: 10 })
props.onMouseMove?.({ clientX: 10 })
})
})
it('should execute non-matching buffered branch in hover loop', async () => {
const { audio, canvas } = await renderWithDuration('https://example.com/a.mp3', 100)
Object.defineProperty(audio, 'buffered', {
value: { length: 1, start: () => 0, end: () => 10 },
configurable: true,
})
await act(async () => {
fireEvent.mouseMove(canvas, { clientX: 180 }) // time near 90, outside 0-10
})
expect(canvas).toBeInTheDocument()
})
})

View File

@@ -1,9 +1,24 @@
import { render, screen } from '@testing-library/react'
import * as React from 'react'
// AudioGallery.spec.tsx
import { describe, expect, it, vi } from 'vitest'
import AudioGallery from '../index'
// Mock AudioPlayer so we only assert prop forwarding
const audioPlayerMock = vi.fn()
vi.mock('../AudioPlayer', () => ({
default: (props: { srcs: string[] }) => {
audioPlayerMock(props)
return <div data-testid="audio-player" />
},
}))
describe('AudioGallery', () => {
beforeEach(() => {
vi.spyOn(HTMLMediaElement.prototype, 'load').mockImplementation(() => { })
afterEach(() => {
audioPlayerMock.mockClear()
vi.resetModules()
})
it('returns null when srcs array is empty', () => {
@@ -18,15 +33,11 @@ describe('AudioGallery', () => {
expect(screen.queryByTestId('audio-player')).toBeNull()
})
it('filters out falsy srcs and renders only valid sources in AudioPlayer', () => {
it('filters out falsy srcs and passes valid srcs to AudioPlayer', () => {
render(<AudioGallery srcs={['a.mp3', '', 'b.mp3']} />)
const audio = screen.getByTestId('audio-player')
const sources = audio.querySelectorAll('source')
expect(audio).toBeInTheDocument()
expect(sources).toHaveLength(2)
expect(sources[0]?.getAttribute('src')).toBe('a.mp3')
expect(sources[1]?.getAttribute('src')).toBe('b.mp3')
expect(screen.getByTestId('audio-player')).toBeInTheDocument()
expect(audioPlayerMock).toHaveBeenCalledTimes(1)
expect(audioPlayerMock).toHaveBeenCalledWith({ srcs: ['a.mp3', 'b.mp3'] })
})
it('wraps AudioPlayer inside container with expected class', () => {
@@ -34,6 +45,5 @@ describe('AudioGallery', () => {
const root = container.firstChild as HTMLElement
expect(root).toBeTruthy()
expect(root.className).toContain('my-3')
expect(screen.getByTestId('audio-player')).toBeInTheDocument()
})
})

View File

@@ -1,18 +1,6 @@
import type { IChatItem } from '../chat/type'
import type { ChatItem, ChatItemInTree } from '../types'
import type { ChatItemInTree } from '../types'
import { get } from 'es-toolkit/compat'
import { UUID_NIL } from '../constants'
import {
buildChatItemTree,
getLastAnswer,
getProcessedInputsFromUrlParams,
getProcessedSystemVariablesFromUrlParams,
getProcessedUserVariablesFromUrlParams,
getRawInputsFromUrlParams,
getRawUserVariablesFromUrlParams,
getThreadMessages,
isValidGeneratedAnswer,
} from '../utils'
import { buildChatItemTree, getThreadMessages } from '../utils'
import branchedTestMessages from './branchedTestMessages.json'
import legacyTestMessages from './legacyTestMessages.json'
import mixedTestMessages from './mixedTestMessages.json'
@@ -25,15 +13,6 @@ function visitNode(tree: ChatItemInTree | ChatItemInTree[], path: string): ChatI
return get(tree, path)
}
class MockDecompressionStream {
readable: unknown
writable: unknown
constructor() {
this.readable = {}
this.writable = {}
}
}
describe('build chat item tree and get thread messages', () => {
const tree1 = buildChatItemTree(branchedTestMessages as ChatItemInTree[])
@@ -268,12 +247,12 @@ describe('build chat item tree and get thread messages', () => {
expect(tree6).toMatchSnapshot()
})
it('should get thread messages from tree6, using the last message as target', () => {
it ('should get thread messages from tree6, using the last message as target', () => {
const threadMessages6_1 = getThreadMessages(tree6)
expect(threadMessages6_1).toMatchSnapshot()
})
it('should get thread messages from tree6, using specified message as target', () => {
it ('should get thread messages from tree6, using specified message as target', () => {
const threadMessages6_2 = getThreadMessages(tree6, 'ff4c2b43-48a5-47ad-9dc5-08b34ddba61b')
expect(threadMessages6_2).toMatchSnapshot()
})
@@ -290,285 +269,3 @@ describe('build chat item tree and get thread messages', () => {
expect(tree8).toMatchSnapshot()
})
})
describe('chat utils - url params and answer helpers', () => {
const setSearch = (search: string) => {
window.history.replaceState({}, '', `${window.location.pathname}${search}`)
}
beforeEach(() => {
vi.clearAllMocks()
vi.stubGlobal('DecompressionStream', MockDecompressionStream)
vi.stubGlobal('TextDecoder', class {
decode() { return 'decompressed_text' }
})
const mockPipeThrough = vi.fn().mockReturnValue({})
vi.stubGlobal('Response', class {
body = { pipeThrough: mockPipeThrough }
arrayBuffer = vi.fn().mockResolvedValue(new ArrayBuffer(8))
})
setSearch('')
})
afterEach(() => {
vi.unstubAllGlobals()
})
describe('URL Parameter Extractors', () => {
it('getRawInputsFromUrlParams extracts inputs except sys. and user.', async () => {
setSearch('?custom=123&sys.param=456&user.param=789&encoded=a%20b')
const res = await getRawInputsFromUrlParams()
expect(res).toEqual({ custom: '123', encoded: 'a b' })
})
it('getRawUserVariablesFromUrlParams extracts only user. prefixed params', async () => {
setSearch('?custom=123&sys.param=456&user.param=789&user.encoded=a%20b')
const res = await getRawUserVariablesFromUrlParams()
expect(res).toEqual({ param: '789', encoded: 'a b' })
})
it('getProcessedInputsFromUrlParams decompresses base64 inputs', async () => {
setSearch('?custom=123&sys.param=456&user.param=789')
const res = await getProcessedInputsFromUrlParams()
expect(res).toEqual({ custom: 'decompressed_text' })
})
it('getProcessedSystemVariablesFromUrlParams decompresses sys. prefixed params', async () => {
setSearch('?custom=123&sys.param=456&user.param=789')
const res = await getProcessedSystemVariablesFromUrlParams()
expect(res).toEqual({ param: 'decompressed_text' })
})
it('getProcessedSystemVariablesFromUrlParams parses redirect_url without query string', async () => {
setSearch(`?redirect_url=${encodeURIComponent('http://example.com')}&sys.param=456`)
const res = await getProcessedSystemVariablesFromUrlParams()
expect(res).toEqual({ param: 'decompressed_text' })
})
it('getProcessedSystemVariablesFromUrlParams parses redirect_url', async () => {
setSearch(`?redirect_url=${encodeURIComponent('http://example.com?sys.redirected=abc')}&sys.param=456`)
const res = await getProcessedSystemVariablesFromUrlParams()
expect(res).toEqual({ param: 'decompressed_text', redirected: 'decompressed_text' })
})
it('getProcessedUserVariablesFromUrlParams decompresses user. prefixed params', async () => {
setSearch('?custom=123&sys.param=456&user.param=789')
const res = await getProcessedUserVariablesFromUrlParams()
expect(res).toEqual({ param: 'decompressed_text' })
})
it('decodeBase64AndDecompress failure returns undefined softly', async () => {
vi.stubGlobal('atob', () => {
throw new Error('invalid')
})
setSearch('?custom=invalid_base64')
const res = await getProcessedInputsFromUrlParams()
expect(res).toEqual({ custom: undefined })
})
})
describe('Answer Validation', () => {
it('isValidGeneratedAnswer returns true for typical answers', () => {
expect(isValidGeneratedAnswer({ isAnswer: true, id: '123', isOpeningStatement: false } as ChatItem)).toBe(true)
})
it('isValidGeneratedAnswer returns false for placeholders', () => {
expect(isValidGeneratedAnswer({ isAnswer: true, id: 'answer-placeholder-123', isOpeningStatement: false } as ChatItem)).toBe(false)
})
it('isValidGeneratedAnswer returns false for opening statements', () => {
expect(isValidGeneratedAnswer({ isAnswer: true, id: '123', isOpeningStatement: true } as ChatItem)).toBe(false)
})
it('isValidGeneratedAnswer returns false for questions', () => {
expect(isValidGeneratedAnswer({ isAnswer: false, id: '123', isOpeningStatement: false } as ChatItem)).toBe(false)
})
it('isValidGeneratedAnswer returns false for falsy items', () => {
expect(isValidGeneratedAnswer(undefined)).toBe(false)
})
it('getLastAnswer returns the last valid answer from a list', () => {
const list = [
{ isAnswer: false, id: 'q1', isOpeningStatement: false },
{ isAnswer: true, id: 'a1', isOpeningStatement: false },
{ isAnswer: false, id: 'q2', isOpeningStatement: false },
{ isAnswer: true, id: 'answer-placeholder-2', isOpeningStatement: false },
] as ChatItem[]
expect(getLastAnswer(list)?.id).toBe('a1')
})
it('getLastAnswer returns null if no valid answer', () => {
const list = [
{ isAnswer: false, id: 'q1', isOpeningStatement: false },
{ isAnswer: true, id: 'answer-placeholder-2', isOpeningStatement: false },
] as ChatItem[]
expect(getLastAnswer(list)).toBeNull()
})
})
describe('ChatItem Tree Builders', () => {
it('buildChatItemTree builds a flat tree for legacy messages (parentMessageId = UUID_NIL)', () => {
const list: IChatItem[] = [
{ id: 'q1', isAnswer: false, parentMessageId: UUID_NIL } as IChatItem,
{ id: 'a1', isAnswer: true, parentMessageId: UUID_NIL } as IChatItem,
{ id: 'q2', isAnswer: false, parentMessageId: UUID_NIL } as IChatItem,
{ id: 'a2', isAnswer: true, parentMessageId: UUID_NIL } as IChatItem,
]
const tree = buildChatItemTree(list)
expect(tree.length).toBe(1)
expect(tree[0].id).toBe('q1')
expect(tree[0].children?.[0].id).toBe('a1')
expect(tree[0].children?.[0].children?.[0].id).toBe('q2')
expect(tree[0].children?.[0].children?.[0].children?.[0].id).toBe('a2')
expect(tree[0].children?.[0].children?.[0].children?.[0].siblingIndex).toBe(0)
})
it('buildChatItemTree builds nested tree based on parentMessageId', () => {
const list: IChatItem[] = [
{ id: 'q1', isAnswer: false, parentMessageId: null } as IChatItem,
{ id: 'a1', isAnswer: true } as IChatItem,
{ id: 'q2', isAnswer: false, parentMessageId: 'a1' } as IChatItem,
{ id: 'a2', isAnswer: true } as IChatItem,
{ id: 'q3', isAnswer: false, parentMessageId: 'a1' } as IChatItem,
{ id: 'a3', isAnswer: true } as IChatItem,
{ id: 'q4', isAnswer: false, parentMessageId: 'missing-parent' } as IChatItem,
{ id: 'a4', isAnswer: true } as IChatItem,
]
const tree = buildChatItemTree(list)
expect(tree.length).toBe(2)
expect(tree[0].id).toBe('q1')
expect(tree[1].id).toBe('q4')
const a1 = tree[0].children![0]
expect(a1.id).toBe('a1')
expect(a1.children?.length).toBe(2)
expect(a1.children![0].id).toBe('q2')
expect(a1.children![1].id).toBe('q3')
expect(a1.children![0].children![0].siblingIndex).toBe(0)
expect(a1.children![1].children![0].siblingIndex).toBe(1)
})
it('getThreadMessages node without children', () => {
const tree = [{ id: 'q1', isAnswer: false }]
const thread = getThreadMessages(tree as unknown as ChatItemInTree[], 'q1')
expect(thread.length).toBe(1)
expect(thread[0].id).toBe('q1')
})
it('getThreadMessages target not found', () => {
const tree = [{ id: 'q1', isAnswer: false, children: [] }]
const thread = getThreadMessages(tree as unknown as ChatItemInTree[], 'missing')
expect(thread.length).toBe(0)
})
it('getThreadMessages target not found with undefined children', () => {
const tree = [{ id: 'q1', isAnswer: false }]
const thread = getThreadMessages(tree as unknown as ChatItemInTree[], 'missing')
expect(thread.length).toBe(0)
})
it('getThreadMessages flat path logic', () => {
const tree = [{
id: 'q1',
isAnswer: false,
children: [{
id: 'a1',
isAnswer: true,
siblingIndex: 0,
children: [{
id: 'q2',
isAnswer: false,
children: [{
id: 'a2',
isAnswer: true,
siblingIndex: 0,
children: [],
}],
}],
}],
}]
const thread = getThreadMessages(tree as unknown as ChatItemInTree[])
expect(thread.length).toBe(4)
expect(thread.map(t => t.id)).toEqual(['q1', 'a1', 'q2', 'a2'])
expect(thread[1].siblingCount).toBe(1)
expect(thread[3].siblingCount).toBe(1)
})
it('getThreadMessages to specific target', () => {
const tree = [{
id: 'q1',
isAnswer: false,
children: [{
id: 'a1',
isAnswer: true,
siblingIndex: 0,
children: [{
id: 'q2',
isAnswer: false,
children: [{
id: 'a2',
isAnswer: true,
siblingIndex: 0,
children: [],
}],
}, {
id: 'q3',
isAnswer: false,
children: [{
id: 'a3',
isAnswer: true,
siblingIndex: 1,
children: [],
}],
}],
}],
}]
const thread = getThreadMessages(tree as unknown as ChatItemInTree[], 'a3')
expect(thread.length).toBe(4)
expect(thread.map(t => t.id)).toEqual(['q1', 'a1', 'q3', 'a3'])
expect(thread[3].prevSibling).toBe('a2')
expect(thread[3].nextSibling).toBeUndefined()
})
it('getThreadMessages targetNode has descendants', () => {
const tree = [{
id: 'q1',
isAnswer: false,
children: [{
id: 'a1',
isAnswer: true,
siblingIndex: 0,
children: [{
id: 'q2',
isAnswer: false,
children: [{
id: 'a2',
isAnswer: true,
siblingIndex: 0,
children: [],
}],
}, {
id: 'q3',
isAnswer: false,
children: [{
id: 'a3',
isAnswer: true,
siblingIndex: 1,
children: [],
}],
}],
}],
}]
const thread = getThreadMessages(tree as unknown as ChatItemInTree[], 'a1')
expect(thread.length).toBe(4)
expect(thread.map(t => t.id)).toEqual(['q1', 'a1', 'q3', 'a3'])
expect(thread[3].prevSibling).toBe('a2')
})
})
})

View File

@@ -4,11 +4,12 @@ import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { AppData, AppMeta, ConversationItem } from '@/models/share'
import type { HumanInputFormData } from '@/types/workflow'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { InputVarType } from '@/app/components/workflow/types'
import {
fetchSuggestedQuestions,
stopChatMessageResponding,
submitHumanInputForm,
} from '@/service/share'
import { TransferMethod } from '@/types/app'
import { useChat } from '../../chat/hooks'
@@ -500,34 +501,6 @@ describe('ChatWrapper', () => {
expect(handleSwitchSibling).toHaveBeenCalledWith('1', expect.any(Object))
})
it('should call fetchSuggestedQuestions from workflow resumption options callback', () => {
const handleSwitchSibling = vi.fn()
vi.mocked(useChat).mockReturnValue({
...defaultChatHookReturn,
chatList: [],
handleSwitchSibling,
} as unknown as ChatHookReturn)
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
appPrevChatTree: [{
id: 'resume-node',
content: 'Paused answer',
isAnswer: true,
workflow_run_id: 'workflow-1',
humanInputFormDataList: [{ label: 'resume' }] as unknown as HumanInputFormData[],
children: [],
}],
})
render(<ChatWrapper />)
expect(handleSwitchSibling).toHaveBeenCalledWith('resume-node', expect.any(Object))
const resumeOptions = handleSwitchSibling.mock.calls[0][1]
resumeOptions.onGetSuggestedQuestions('response-from-resume')
expect(fetchSuggestedQuestions).toHaveBeenCalledWith('response-from-resume', 'webApp', 'test-app-id')
})
it('should handle workflow resumption with nested children (DFS)', () => {
const handleSwitchSibling = vi.fn()
vi.mocked(useChat).mockReturnValue({
@@ -787,47 +760,6 @@ describe('ChatWrapper', () => {
})
})
it('should handle human input form submission for web app', async () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
isInstalledApp: false,
})
vi.mocked(useChat).mockReturnValue({
...defaultChatHookReturn,
chatList: [
{ id: 'q1', content: 'Question' },
{
id: 'a1',
isAnswer: true,
content: '',
humanInputFormDataList: [{
id: 'node1',
form_id: 'form1',
form_token: 'token-web-1',
node_id: 'node1',
node_title: 'Node Web 1',
display_in_ui: true,
form_content: '{{#$output.test#}}',
inputs: [{ variable: 'test', label: 'Test', type: 'paragraph', required: true, output_variable_name: 'test', default: { type: 'text', value: '' } }],
actions: [{ id: 'run', title: 'Run', button_style: 'primary' }],
}] as unknown as HumanInputFormData[],
},
],
} as unknown as ChatHookReturn)
render(<ChatWrapper />)
expect(await screen.findByText('Node Web 1')).toBeInTheDocument()
const input = screen.getAllByRole('textbox').find(el => el.closest('.chat-answer-container')) || screen.getAllByRole('textbox')[0]
fireEvent.change(input, { target: { value: 'web-test' } })
fireEvent.click(screen.getByText('Run'))
await waitFor(() => {
expect(submitHumanInputForm).toHaveBeenCalledWith('token-web-1', expect.any(Object))
})
})
it('should filter opening statement in new conversation with single item', () => {
vi.mocked(useChat).mockReturnValue({
...defaultChatHookReturn,
@@ -956,16 +888,8 @@ describe('ChatWrapper', () => {
})
it('should render answer icon when configured', () => {
const appDataWithAnswerIcon = {
site: {
...mockAppData.site,
use_icon_as_answer_icon: true,
},
} as unknown as AppData
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
appData: appDataWithAnswerIcon,
} as ChatWithHistoryContextValue)
vi.mocked(useChat).mockReturnValue({
@@ -975,7 +899,6 @@ describe('ChatWrapper', () => {
render(<ChatWrapper />)
expect(screen.getByText('Answer')).toBeInTheDocument()
expect(screen.getByAltText('answer icon')).toBeInTheDocument()
})
it('should render question icon when user avatar is available', () => {
@@ -997,26 +920,6 @@ describe('ChatWrapper', () => {
expect(avatar).toBeInTheDocument()
})
it('should use fallback values for nullable appData, appMeta and user name', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
appData: null as unknown as AppData,
appMeta: null as unknown as AppMeta,
initUserVariables: {
avatar_url: 'https://example.com/avatar-fallback.png',
},
})
vi.mocked(useChat).mockReturnValue({
...defaultChatHookReturn,
chatList: [{ id: 'q1', content: 'Question with fallback avatar name' }],
} as unknown as ChatHookReturn)
render(<ChatWrapper />)
expect(screen.getByText('Question with fallback avatar name')).toBeInTheDocument()
expect(screen.getByAltText('user')).toBeInTheDocument()
})
it('should set handleStop on currentChatInstanceRef', () => {
const handleStop = vi.fn()
const currentChatInstanceRef = { current: { handleStop: vi.fn() } } as ChatWithHistoryContextValue['currentChatInstanceRef']
@@ -1309,45 +1212,20 @@ describe('ChatWrapper', () => {
it('should handle doRegenerate with editedQuestion', async () => {
const handleSend = vi.fn()
const mockFiles = [
{
id: 'file-q1',
name: 'q1.txt',
type: 'text/plain',
size: 100,
url: 'https://example.com/q1.txt',
extension: 'txt',
mime_type: 'text/plain',
} as unknown as FileEntity,
] as FileEntity[]
vi.mocked(useChat).mockReturnValue({
...defaultChatHookReturn,
chatList: [
{ id: 'q1', content: 'Original question', message_files: mockFiles },
{ id: 'q1', content: 'Original question', message_files: [] },
{ id: 'a1', isAnswer: true, content: 'Answer', parentMessageId: 'q1' },
],
handleSend,
} as unknown as ChatHookReturn)
render(<ChatWrapper />)
const { container } = render(<ChatWrapper />)
fireEvent.click(await screen.findByTestId('edit-btn'))
const editedTextarea = await screen.findByDisplayValue('Original question')
fireEvent.change(editedTextarea, { target: { value: 'Edited question text' } })
fireEvent.click(screen.getByTestId('save-edit-btn'))
await waitFor(() => {
expect(handleSend).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({
query: 'Edited question text',
files: mockFiles,
}),
expect.any(Object),
)
})
// This would test line 198-200 - the editedQuestion path
// The actual regenerate with edited question happens through the UI
expect(container).toBeInTheDocument()
})
it('should handle doRegenerate when parentAnswer is not a valid generated answer', async () => {
@@ -1814,31 +1692,4 @@ describe('ChatWrapper', () => {
// Should not be disabled because it's not required
expect(container).not.toBeInTheDocument()
})
it('should handle fallback branches for appParams, appId and empty chat instance ref', async () => {
const handleSend = vi.fn()
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
appParams: undefined as unknown as ChatConfig,
appId: '',
currentConversationId: '',
currentChatInstanceRef: { current: null } as unknown as ChatWithHistoryContextValue['currentChatInstanceRef'],
})
vi.mocked(useChat).mockReturnValue({
...defaultChatHookReturn,
handleSend,
} as unknown as ChatHookReturn)
render(<ChatWrapper />)
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'trigger fallback path' } })
fireEvent.keyDown(textarea, { key: 'Enter', code: 'Enter', keyCode: 13 })
await waitFor(() => {
expect(handleSend).toHaveBeenCalled()
})
})
})

View File

@@ -1,9 +1,9 @@
import type { i18n } from 'i18next'
import type { ChatConfig } from '../../types'
import type { ChatWithHistoryContextValue } from '../context'
import type { AppData, AppMeta } from '@/models/share'
import type { AppData, AppMeta, ConversationItem } from '@/models/share'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as ReactI18next from 'react-i18next'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { useChatWithHistoryContext } from '../context'
import HeaderInMobile from '../header-in-mobile'
@@ -80,14 +80,7 @@ vi.mock('@/app/components/base/modal', () => ({
// Sidebar mock removed to use real component
const mockAppData: AppData = {
app_id: 'test-app',
custom_config: null,
site: {
title: 'Test Chat',
chat_color_theme: 'blue',
},
}
const mockAppData = { site: { title: 'Test Chat', chat_color_theme: 'blue' } } as unknown as AppData
const defaultContextValue: ChatWithHistoryContextValue = {
appData: mockAppData,
currentConversationId: '',
@@ -111,27 +104,18 @@ const defaultContextValue: ChatWithHistoryContextValue = {
currentChatInstanceRef: { current: { handleStop: vi.fn() } } as ChatWithHistoryContextValue['currentChatInstanceRef'],
setIsResponding: vi.fn(),
setClearChatList: vi.fn(),
appParams: {
system_parameters: {
audio_file_size_limit: 10,
file_size_limit: 10,
image_file_size_limit: 10,
video_file_size_limit: 10,
workflow_file_upload_limit: 10,
},
more_like_this: { enabled: false },
} as ChatConfig,
appMeta: { tool_icons: {} } as AppMeta,
appParams: { system_parameters: { vision_config: { enabled: false } } } as unknown as ChatConfig,
appMeta: {} as AppMeta,
appPrevChatTree: [],
newConversationInputs: {},
newConversationInputsRef: { current: {} },
newConversationInputsRef: { current: {} } as ChatWithHistoryContextValue['newConversationInputsRef'],
appChatListDataLoading: false,
chatShouldReloadKey: '',
isMobile: true,
currentConversationInputs: null,
setCurrentConversationInputs: vi.fn(),
allInputsHidden: false,
conversationRenaming: false,
conversationRenaming: false, // Added missing property
}
describe('HeaderInMobile', () => {
@@ -150,7 +134,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
})
render(<HeaderInMobile />)
@@ -286,7 +270,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
handlePinConversation: handlePin,
pinnedConversationList: [],
})
@@ -308,9 +292,9 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
handleUnpinConversation: handleUnpin,
pinnedConversationList: [{ id: '1', name: 'Conv 1', inputs: null, introduction: '' }],
pinnedConversationList: [{ id: '1' }] as unknown as ConversationItem[],
})
render(<HeaderInMobile />)
@@ -330,7 +314,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
handleRenameConversation: handleRename,
pinnedConversationList: [],
})
@@ -358,7 +342,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
handleRenameConversation: handleRename,
pinnedConversationList: [],
})
@@ -389,7 +373,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
handleRenameConversation: vi.fn(),
conversationRenaming: true, // Loading state
pinnedConversationList: [],
@@ -412,7 +396,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
handleDeleteConversation: handleDelete,
pinnedConversationList: [],
})
@@ -438,7 +422,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
handleDeleteConversation: handleDelete,
pinnedConversationList: [],
})
@@ -470,7 +454,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: '', inputs: null, introduction: '' },
currentConversationItem: { id: '1', name: '' } as unknown as ConversationItem,
})
render(<HeaderInMobile />)
@@ -501,17 +485,16 @@ describe('HeaderInMobile', () => {
})
it('should render app icon and title correctly', () => {
const appDataWithIcon: AppData = {
app_id: 'test-app',
custom_config: null,
const appDataWithIcon = {
site: {
title: 'My App',
icon: 'emoji',
icon_type: 'emoji',
icon_url: '',
icon_background: '#FF0000',
chat_color_theme: 'blue',
},
}
} as unknown as AppData
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
@@ -529,7 +512,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
handleRenameConversation: handleRename,
handleDeleteConversation: handleDelete,
pinnedConversationList: [],
@@ -541,59 +524,4 @@ describe('HeaderInMobile', () => {
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
expect(screen.queryByText('share.chat.deleteConversation.title')).not.toBeInTheDocument()
})
it('should use empty string fallback for delete content translation', async () => {
const handleDelete = vi.fn()
const useTranslationSpy = vi.spyOn(ReactI18next, 'useTranslation')
useTranslationSpy.mockReturnValue({
t: (key: string) => key === 'chat.deleteConversation.content' ? '' : key,
i18n: {} as unknown as i18n,
ready: true,
tReady: true,
} as unknown as ReturnType<typeof ReactI18next.useTranslation>)
try {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
handleDeleteConversation: handleDelete,
pinnedConversationList: [],
})
render(<HeaderInMobile />)
fireEvent.click(await screen.findByText('Conv 1'))
fireEvent.click(await screen.findByText(/sidebar\.action\.delete/i))
expect(await screen.findByRole('button', { name: /common\.operation\.confirm|operation\.confirm/i })).toBeInTheDocument()
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.confirm|operation\.confirm/i }))
expect(handleDelete).toHaveBeenCalledWith('1', expect.any(Object))
}
finally {
useTranslationSpy.mockRestore()
}
})
it('should use empty string fallback for rename modal name', async () => {
const handleRename = vi.fn()
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: '', inputs: null, introduction: '' },
handleRenameConversation: handleRename,
pinnedConversationList: [],
})
const { container } = render(<HeaderInMobile />)
const operationTrigger = container.querySelector('.system-md-semibold')?.parentElement as HTMLElement
fireEvent.click(operationTrigger)
fireEvent.click(await screen.findByText(/explore\.sidebar\.action\.rename|sidebar\.action\.rename/i))
const input = await screen.findByRole('textbox')
expect(input).toHaveValue('')
fireEvent.change(input, { target: { value: 'Renamed from empty' } })
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/i }))
expect(handleRename).toHaveBeenCalledWith('1', 'Renamed from empty', expect.any(Object))
})
})

View File

@@ -2,7 +2,9 @@ import type { RefObject } from 'react'
import type { ChatConfig } from '../../types'
import type { InstalledApp } from '@/models/explore'
import type { AppConversationData, AppData, AppMeta, ConversationItem } from '@/models/share'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import useDocumentTitle from '@/hooks/use-document-title'
import { useChatWithHistory } from '../hooks'
@@ -111,22 +113,81 @@ describe('ChatWithHistory', () => {
vi.mocked(useChatWithHistory).mockReturnValue(defaultHookReturn)
})
it('renders desktop view with expanded sidebar and builds theme', async () => {
it('renders desktop view with expanded sidebar and builds theme', () => {
vi.mocked(useBreakpoints).mockReturnValue(MediaType.pc)
render(<ChatWithHistory />)
// header-in-mobile renders 'Test Chat'.
// Checks if the desktop elements render correctly
// Checks if the desktop elements render correctly
// Sidebar real component doesn't have data-testid="sidebar", so we check for its presence via class or content.
// Sidebar usually has "New Chat" button or similar.
// However, looking at the Sidebar mock it was just a div.
// Real Sidebar -> web/app/components/base/chat/chat-with-history/sidebar/index.tsx
// It likely has some text or distinct element.
// ChatWrapper also removed mock.
// Header also removed mock.
// For now, let's verify some key elements that should be present in these components.
// Sidebar: "Explore" or "Chats" or verify navigation structure.
// Header: Title or similar.
// ChatWrapper: "Start a new chat" or similar.
// Given the complexity of real components and lack of testIds, we might need to rely on:
// 1. Adding testIds to real components (preferred but might be out of scope if I can't touch them? Guidelines say "don't mock base components", but adding testIds is fine).
// But I can't see those files right now.
// 2. Use getByText for known static content.
// Let's assume some content based on `mockAppData` title 'Test Chat'.
// Header should contain 'Test Chat'.
// Check for "Test Chat" - might appear multiple times (header, sidebar, document title etc)
const titles = screen.getAllByText('Test Chat')
expect(titles.length).toBeGreaterThan(0)
// Sidebar should be present.
// We can check for a specific element in sidebar, e.g. "New Chat" button if it exists.
// Or we can check for the sidebar container class if possible.
// Let's look at `index.tsx` logic.
// Sidebar is rendered.
// Let's try to query by something generic or update to use `container.querySelector`.
// But `screen` is better.
// ChatWrapper is rendered.
// It renders "ChatWrapper" text? No, it's the real component now.
// Real ChatWrapper renders "Welcome" or chat list.
// In `chat-wrapper.spec.tsx`, we saw it renders "Welcome" or "Q1".
// Here `defaultHookReturn` returns empty chat list/conversation.
// So it might render nothing or empty state?
// Let's wait and see what `chat-wrapper.spec.tsx` expectations were.
// It expects "Welcome" if `isOpeningStatement` is true.
// In `index.spec.tsx` mock hook return:
// `currentConversationItem` is undefined.
// `conversationList` is [].
// `appPrevChatTree` is [].
// So ChatWrapper might render empty or loading?
// This is an integration test now.
// We need to ensure the hook return makes sense for the child components.
// Let's just assert the document title since we know that works?
// And check if we can find *something*.
// For now, I'll comment out the specific testId checks and rely on visual/text checks that are likely to flourish.
// header-in-mobile renders 'Test Chat'.
// Sidebar?
// Actually, `ChatWithHistory` renders `Sidebar` in a div with width.
// We can check if that div exists?
// Let's update to checks that are likely to pass or allow us to debug.
// expect(document.title).toBe('Test Chat')
// Checks if the document title was set correctly
expect(useDocumentTitle).toHaveBeenCalledWith('Test Chat')
// Checks if the themeBuilder useEffect fired
await waitFor(() => {
expect(mockBuildTheme).toHaveBeenCalledWith('blue', false)
})
expect(mockBuildTheme).toHaveBeenCalledWith('blue', false)
})
it('renders desktop view with collapsed sidebar and tests hover effects', () => {

View File

@@ -46,7 +46,6 @@ const HeaderInMobile = () => {
setShowConfirm(null)
}, [])
const handleDelete = useCallback(() => {
/* v8 ignore next 2 -- @preserve */
if (showConfirm)
handleDeleteConversation(showConfirm.id, { onSuccess: handleCancelConfirm })
}, [showConfirm, handleDeleteConversation, handleCancelConfirm])
@@ -54,7 +53,6 @@ const HeaderInMobile = () => {
setShowRename(null)
}, [])
const handleRename = useCallback((newName: string) => {
/* v8 ignore next 2 -- @preserve */
if (showRename)
handleRenameConversation(showRename.id, newName, { onSuccess: handleCancelRename })
}, [showRename, handleRenameConversation, handleCancelRename])

View File

@@ -1,128 +0,0 @@
import type { InputForm } from '../type'
import { renderHook } from '@testing-library/react'
import { InputVarType } from '@/app/components/workflow/types'
import { TransferMethod } from '@/types/app'
import { useCheckInputsForms } from '../check-input-forms-hooks'
const mockNotify = vi.fn()
vi.mock('@/app/components/base/toast/context', () => ({
useToastContext: () => ({ notify: mockNotify }),
}))
describe('useCheckInputsForms', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should return true when no inputs required', () => {
const { result } = renderHook(() => useCheckInputsForms())
const isValid = result.current.checkInputsForm({}, [])
expect(isValid).toBe(true)
})
it('should return false and notify when a required input is missing', () => {
const { result } = renderHook(() => useCheckInputsForms())
const inputsForm = [{ variable: 'test_var', label: 'Test Variable', required: true, type: InputVarType.textInput as string }]
const isValid = result.current.checkInputsForm({}, inputsForm as InputForm[])
expect(isValid).toBe(false)
expect(mockNotify).toHaveBeenCalledWith(
expect.objectContaining({
type: 'error',
message: expect.stringContaining('appDebug.errorMessage.valueOfVarRequired'),
}),
)
})
it('should ignore missing but not required inputs', () => {
const { result } = renderHook(() => useCheckInputsForms())
const inputsForm = [{ variable: 'test_var', label: 'Test Variable', required: false, type: InputVarType.textInput as string }]
const isValid = result.current.checkInputsForm({}, inputsForm as InputForm[])
expect(isValid).toBe(true)
expect(mockNotify).not.toHaveBeenCalled()
})
it('should notify and return undefined when a file is still uploading (singleFile)', () => {
const { result } = renderHook(() => useCheckInputsForms())
const inputsForm = [{ variable: 'test_file', label: 'Test File', required: true, type: InputVarType.singleFile as string }]
const inputs = {
test_file: { transferMethod: TransferMethod.local_file }, // no uploadedId means still uploading
}
const isValid = result.current.checkInputsForm(inputs, inputsForm as InputForm[])
expect(isValid).toBeUndefined()
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'info',
message: 'appDebug.errorMessage.waitForFileUpload',
}))
})
it('should notify and return undefined when a file is still uploading (multiFiles)', () => {
const { result } = renderHook(() => useCheckInputsForms())
const inputsForm = [{ variable: 'test_files', label: 'Test Files', required: true, type: InputVarType.multiFiles as string }]
const inputs = {
test_files: [{ transferMethod: TransferMethod.local_file }], // no uploadedId
}
const isValid = result.current.checkInputsForm(inputs, inputsForm as InputForm[])
expect(isValid).toBeUndefined()
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'info',
message: 'appDebug.errorMessage.waitForFileUpload',
}))
})
it('should return true when all files are uploaded and required variables are present', () => {
const { result } = renderHook(() => useCheckInputsForms())
const inputsForm = [{ variable: 'test_file', label: 'Test File', required: true, type: InputVarType.singleFile as string }]
const inputs = {
test_file: { transferMethod: TransferMethod.local_file, uploadedId: '123' }, // uploaded
}
const isValid = result.current.checkInputsForm(inputs, inputsForm as InputForm[])
expect(isValid).toBe(true)
expect(mockNotify).not.toHaveBeenCalled()
})
it('should short-circuit remaining fields after first required input is missing', () => {
const { result } = renderHook(() => useCheckInputsForms())
const inputsForm = [
{ variable: 'missing_text', label: 'Missing Text', required: true, type: InputVarType.textInput as string },
{ variable: 'later_file', label: 'Later File', required: true, type: InputVarType.singleFile as string },
]
const inputs = {
later_file: { transferMethod: TransferMethod.local_file },
}
const isValid = result.current.checkInputsForm(inputs, inputsForm as InputForm[])
expect(isValid).toBe(false)
expect(mockNotify).toHaveBeenCalledTimes(1)
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
message: expect.stringContaining('appDebug.errorMessage.valueOfVarRequired'),
}))
})
it('should short-circuit remaining fields after detecting file upload in progress', () => {
const { result } = renderHook(() => useCheckInputsForms())
const inputsForm = [
{ variable: 'uploading_file', label: 'Uploading File', required: true, type: InputVarType.singleFile as string },
{ variable: 'later_required_text', label: 'Later Required Text', required: true, type: InputVarType.textInput as string },
]
const inputs = {
uploading_file: { transferMethod: TransferMethod.local_file }, // still uploading
later_required_text: '',
}
const isValid = result.current.checkInputsForm(inputs, inputsForm as InputForm[])
expect(isValid).toBeUndefined()
expect(mockNotify).toHaveBeenCalledTimes(1)
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'info',
message: 'appDebug.errorMessage.waitForFileUpload',
}))
})
})

File diff suppressed because it is too large Load Diff

View File

@@ -5,6 +5,7 @@ import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import copy from 'copy-to-clipboard'
import * as React from 'react'
import { vi } from 'vitest'
import Toast from '../../../toast'
import { ThemeBuilder } from '../../embedded-chatbot/theme/theme-context'
@@ -168,8 +169,7 @@ describe('Question component', () => {
const user = userEvent.setup()
const onRegenerate = vi.fn() as unknown as OnRegenerate
const item = makeItem()
renderWithProvider(item, onRegenerate)
renderWithProvider(makeItem(), onRegenerate)
const editBtn = screen.getByTestId('edit-btn')
await user.click(editBtn)
@@ -184,7 +184,7 @@ describe('Question component', () => {
await user.click(resendBtn)
await waitFor(() => {
expect(onRegenerate).toHaveBeenCalledWith(item, { message: 'Edited question', files: [] })
expect(onRegenerate).toHaveBeenCalledWith(makeItem(), { message: 'Edited question', files: [] })
})
})
@@ -199,7 +199,7 @@ describe('Question component', () => {
await user.clear(textbox)
await user.type(textbox, 'Edited question')
const cancelBtn = await screen.findByTestId('cancel-edit-btn')
const cancelBtn = screen.getByRole('button', { name: /operation.cancel/i })
await user.click(cancelBtn)
await waitFor(() => {
@@ -349,120 +349,4 @@ describe('Question component', () => {
const contentContainer = screen.getByTestId('question-content')
expect(contentContainer.getAttribute('style')).not.toBeNull()
})
it('should cover composition lifecycle preventing enter submitting when composing', async () => {
const user = userEvent.setup()
const onRegenerate = vi.fn() as unknown as OnRegenerate
const item = makeItem()
renderWithProvider(item, onRegenerate)
const editBtn = screen.getByTestId('edit-btn')
await user.click(editBtn)
const textbox = await screen.findByRole('textbox')
await user.clear(textbox)
// Simulate composition start and typing
act(() => {
textbox.focus()
})
// Simulate composition start
fireEvent.compositionStart(textbox)
// Try to press Enter while composing
fireEvent.keyDown(textbox, { key: 'Enter', code: 'Enter' })
// Simulate composition end
fireEvent.compositionEnd(textbox)
// Expect onRegenerate not to be called because Enter was pressed during composition
expect(onRegenerate).not.toHaveBeenCalled()
// Let setTimeout finish its 50ms interval to clear isComposing
await new Promise(r => setTimeout(r, 60))
// Now press Enter after composition is fully cleared
fireEvent.keyDown(textbox, { key: 'Enter', code: 'Enter' })
expect(onRegenerate).toHaveBeenCalledWith(item, { message: '', files: [] })
})
it('should prevent Enter from submitting when shiftKey is pressed', async () => {
const user = userEvent.setup()
const onRegenerate = vi.fn() as unknown as OnRegenerate
const item = makeItem()
renderWithProvider(item, onRegenerate)
await user.click(screen.getByTestId('edit-btn'))
const textbox = await screen.findByRole('textbox')
// Press Shift+Enter
fireEvent.keyDown(textbox, { key: 'Enter', code: 'Enter', shiftKey: true })
expect(onRegenerate).not.toHaveBeenCalled()
})
it('should ignore enter when nativeEvent.isComposing is true', async () => {
const user = userEvent.setup()
const onRegenerate = vi.fn() as unknown as OnRegenerate
renderWithProvider(makeItem(), onRegenerate)
await user.click(screen.getByTestId('edit-btn'))
const textbox = await screen.findByRole('textbox')
// Create an event with nativeEvent.isComposing = true
const event = new KeyboardEvent('keydown', { key: 'Enter', code: 'Enter' })
Object.defineProperty(event, 'isComposing', { value: true })
fireEvent(textbox, event)
expect(onRegenerate).not.toHaveBeenCalled()
})
it('should clear timer on cancel and on component unmount', async () => {
const user = userEvent.setup()
const onRegenerate = vi.fn() as unknown as OnRegenerate
const { unmount } = renderWithProvider(makeItem(), onRegenerate)
await user.click(screen.getByTestId('edit-btn'))
const textbox = await screen.findByRole('textbox')
fireEvent.compositionStart(textbox)
fireEvent.compositionEnd(textbox)
// Timer is now running, let's start another composition to clear it
fireEvent.compositionStart(textbox)
fireEvent.compositionEnd(textbox)
const cancelBtn = await screen.findByTestId('cancel-edit-btn')
await user.click(cancelBtn)
// Test unmount clearing timer
await user.click(screen.getByTestId('edit-btn'))
const textbox2 = await screen.findByRole('textbox')
fireEvent.compositionStart(textbox2)
fireEvent.compositionEnd(textbox2)
unmount()
expect(onRegenerate).not.toHaveBeenCalled()
})
it('should ignore enter when handleResend with active timer', async () => {
const user = userEvent.setup()
const onRegenerate = vi.fn() as unknown as OnRegenerate
renderWithProvider(makeItem(), onRegenerate)
await user.click(screen.getByTestId('edit-btn'))
const textbox = await screen.findByRole('textbox')
fireEvent.compositionStart(textbox)
fireEvent.compositionEnd(textbox) // starts timer
const saveBtn = screen.getByTestId('save-edit-btn')
await user.click(saveBtn) // handleResend clears timer
expect(onRegenerate).toHaveBeenCalled()
})
})

View File

@@ -1,121 +0,0 @@
import type { InputForm } from '../type'
import { InputVarType } from '@/app/components/workflow/types'
import { getProcessedInputs, processInputFileFromServer, processOpeningStatement } from '../utils'
vi.mock('@/app/components/base/file-uploader/utils', () => ({
getProcessedFiles: vi.fn((files: File[]) => files.map((f: File) => ({ ...f, processed: true }))),
}))
describe('chat/chat/utils.ts', () => {
describe('processOpeningStatement', () => {
it('returns empty string if openingStatement is falsy', () => {
expect(processOpeningStatement('', {}, [])).toBe('')
})
it('replaces variables with input values when available', () => {
const result = processOpeningStatement('Hello {{name}}', { name: 'Alice' }, [])
expect(result).toBe('Hello Alice')
})
it('replaces variables with labels when input value is not available but form has variable', () => {
const result = processOpeningStatement('Hello {{user_name}}', {}, [{ variable: 'user_name', label: 'Name Label', type: InputVarType.textInput }] as InputForm[])
expect(result).toBe('Hello {{Name Label}}')
})
it('keeps original match when input value and form are not available', () => {
const result = processOpeningStatement('Hello {{unknown}}', {}, [])
expect(result).toBe('Hello {{unknown}}')
})
})
describe('processInputFileFromServer', () => {
it('maps server file object to local schema', () => {
const result = processInputFileFromServer({
type: 'image',
transfer_method: 'local_file',
remote_url: 'http://example.com/img.png',
related_id: '123',
})
expect(result).toEqual({
type: 'image',
transfer_method: 'local_file',
url: 'http://example.com/img.png',
upload_file_id: '123',
})
})
})
describe('getProcessedInputs', () => {
it('processes checkbox input types to boolean', () => {
const inputs = { terms: 'true', conds: null }
const inputsForm = [
{ variable: 'terms', type: InputVarType.checkbox as string },
{ variable: 'conds', type: InputVarType.checkbox as string },
]
const result = getProcessedInputs(inputs, inputsForm as InputForm[])
expect(result).toEqual({ terms: true, conds: false })
})
it('ignores null values', () => {
const inputs = { test: null }
const inputsForm = [{ variable: 'test', type: InputVarType.textInput as string }]
const result = getProcessedInputs(inputs, inputsForm as InputForm[])
expect(result).toEqual({ test: null })
})
it('processes singleFile using transfer_method logic', () => {
const inputs = {
file1: { transfer_method: 'local_file', url: '1' },
file2: { id: 'file2' },
}
const inputsForm = [
{ variable: 'file1', type: InputVarType.singleFile as string },
{ variable: 'file2', type: InputVarType.singleFile as string },
]
const result = getProcessedInputs(inputs, inputsForm as InputForm[])
expect(result.file1).toHaveProperty('transfer_method', 'local_file')
expect(result.file2).toHaveProperty('processed', true)
})
it('processes multiFiles using transfer_method logic', () => {
const inputs = {
files1: [{ transfer_method: 'local_file', url: '1' }],
files2: [{ id: 'file2' }],
}
const inputsForm = [
{ variable: 'files1', type: InputVarType.multiFiles as string },
{ variable: 'files2', type: InputVarType.multiFiles as string },
]
const result = getProcessedInputs(inputs, inputsForm as InputForm[])
expect(result.files1[0]).toHaveProperty('transfer_method', 'local_file')
expect(result.files2[0]).toHaveProperty('processed', true)
})
it('processes jsonObject parsing correct json', () => {
const inputs = {
json1: '{"key": "value"}',
}
const inputsForm = [{ variable: 'json1', type: InputVarType.jsonObject as string }]
const result = getProcessedInputs(inputs, inputsForm as InputForm[])
expect(result.json1).toEqual({ key: 'value' })
})
it('processes jsonObject falling back to original if json is array or plain string/invalid json', () => {
const inputs = {
jsonInvalid: 'invalid json',
jsonArray: '["a", "b"]',
jsonPlainObj: { key: 'value' },
}
const inputsForm = [
{ variable: 'jsonInvalid', type: InputVarType.jsonObject as string },
{ variable: 'jsonArray', type: InputVarType.jsonObject as string },
{ variable: 'jsonPlainObj', type: InputVarType.jsonObject as string },
]
const result = getProcessedInputs(inputs, inputsForm as InputForm[])
expect(result.jsonInvalid).toBe('invalid json')
expect(result.jsonArray).toBe('["a", "b"]')
expect(result.jsonPlainObj).toEqual({ key: 'value' })
})
})
})

View File

@@ -1,437 +0,0 @@
import { act, renderHook } from '@testing-library/react'
import { useTextAreaHeight } from '../hooks'
describe('useTextAreaHeight', () => {
// Mock getBoundingClientRect for all ref elements
const mockGetBoundingClientRect = (
width: number = 0,
height: number = 0,
) => ({
width,
height,
top: 0,
left: 0,
bottom: height,
right: width,
x: 0,
y: 0,
toJSON: () => ({}),
})
beforeEach(() => {
vi.clearAllMocks()
})
describe('Rendering', () => {
it('should render without crashing', () => {
const { result } = renderHook(() => useTextAreaHeight())
expect(result.current).toBeDefined()
})
it('should return all required properties', () => {
const { result } = renderHook(() => useTextAreaHeight())
expect(result.current).toHaveProperty('wrapperRef')
expect(result.current).toHaveProperty('textareaRef')
expect(result.current).toHaveProperty('textValueRef')
expect(result.current).toHaveProperty('holdSpaceRef')
expect(result.current).toHaveProperty('handleTextareaResize')
expect(result.current).toHaveProperty('isMultipleLine')
})
})
describe('Initial State', () => {
it('should initialize with isMultipleLine as false', () => {
const { result } = renderHook(() => useTextAreaHeight())
expect(result.current.isMultipleLine).toBe(false)
})
it('should initialize refs as null', () => {
const { result } = renderHook(() => useTextAreaHeight())
expect(result.current.wrapperRef.current).toBeNull()
expect(result.current.textValueRef.current).toBeNull()
expect(result.current.holdSpaceRef.current).toBeNull()
})
it('should initialize textareaRef as undefined', () => {
const { result } = renderHook(() => useTextAreaHeight())
expect(result.current.textareaRef.current).toBeUndefined()
})
})
describe('Height Computation Logic (via handleTextareaResize)', () => {
it('should not update state when any ref is missing', () => {
const { result } = renderHook(() => useTextAreaHeight())
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(false)
})
it('should set isMultipleLine to true when textarea height exceeds 32px', () => {
const { result } = renderHook(() => useTextAreaHeight())
// Set up refs with mock elements
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 0),
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 64), // height > 32
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100, 0),
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(50, 0),
)
// Assign elements to refs
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(true)
})
it('should set isMultipleLine to true when combined content width exceeds wrapper width', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(200, 0), // wrapperWidth = 200
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 20), // height <= 32
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(120, 0), // textValueWidth = 120
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100, 0), // holdSpaceWidth = 100, total = 220 > 200
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(true)
})
it('should set isMultipleLine to false when content fits in wrapper', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 0), // wrapperWidth = 300
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 20), // height <= 32
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100, 0), // textValueWidth = 100
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(50, 0), // holdSpaceWidth = 50, total = 150 < 300
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(false)
})
it('should handle exact boundary when combined width equals wrapper width', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(200, 0),
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 20),
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100, 0),
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100, 0), // total = 200, equals wrapperWidth
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(true)
})
it('should handle boundary case when textarea height equals 32px', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 0),
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 32), // exactly 32
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100, 0),
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(50, 0),
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
// height = 32 is not > 32, so should check width condition
expect(result.current.isMultipleLine).toBe(false)
})
})
describe('handleTextareaResize', () => {
it('should be a function', () => {
const { result } = renderHook(() => useTextAreaHeight())
expect(typeof result.current.handleTextareaResize).toBe('function')
})
it('should call handleComputeHeight when invoked', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 0),
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 64),
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100, 0),
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(50, 0),
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(true)
})
it('should update state based on new dimensions', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
const wrapperRect = vi.spyOn(wrapperElement, 'getBoundingClientRect')
const textareaRect = vi.spyOn(textareaElement, 'getBoundingClientRect')
const textValueRect = vi.spyOn(textValueElement, 'getBoundingClientRect')
const holdSpaceRect = vi.spyOn(holdSpaceElement, 'getBoundingClientRect')
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
// First call - content fits
wrapperRect.mockReturnValue(mockGetBoundingClientRect(300, 0))
textareaRect.mockReturnValue(mockGetBoundingClientRect(300, 20))
textValueRect.mockReturnValue(mockGetBoundingClientRect(100, 0))
holdSpaceRect.mockReturnValue(mockGetBoundingClientRect(50, 0))
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(false)
// Second call - content overflows
textareaRect.mockReturnValue(mockGetBoundingClientRect(300, 64))
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(true)
})
})
describe('Callback Stability', () => {
it('should maintain ref objects across rerenders', () => {
const { result, rerender } = renderHook(() => useTextAreaHeight())
const firstWrapperRef = result.current.wrapperRef
const firstTextareaRef = result.current.textareaRef
const firstTextValueRef = result.current.textValueRef
const firstHoldSpaceRef = result.current.holdSpaceRef
rerender()
expect(result.current.wrapperRef).toBe(firstWrapperRef)
expect(result.current.textareaRef).toBe(firstTextareaRef)
expect(result.current.textValueRef).toBe(firstTextValueRef)
expect(result.current.holdSpaceRef).toBe(firstHoldSpaceRef)
})
})
describe('Edge Cases', () => {
it('should handle zero dimensions', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(0, 0),
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(0, 0),
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(0, 0),
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(0, 0),
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
// When all dimensions are 0, 0 + 0 >= 0 is true, so isMultipleLine is true
expect(result.current.isMultipleLine).toBe(true)
})
it('should handle very large dimensions', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(10000, 0),
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(10000, 100),
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(5000, 0),
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(5000, 0),
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(true)
})
it('should handle numeric precision edge cases', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(200.5, 0),
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 20),
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100.2, 0),
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100.3, 0),
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(true)
})
})
})

View File

@@ -1,7 +1,7 @@
import type { FileUpload } from '@/app/components/base/features/types'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { TransferMethod } from '@/types/app'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import { vi } from 'vitest'
@@ -52,8 +52,6 @@ vi.mock('@/app/components/base/file-uploader/store', () => ({
// ---------------------------------------------------------------------------
// File-uploader hooks provide stable drag/drop handlers
// ---------------------------------------------------------------------------
let mockIsDragActive = false
vi.mock('@/app/components/base/file-uploader/hooks', () => ({
useFile: () => ({
handleDragFileEnter: vi.fn(),
@@ -61,7 +59,7 @@ vi.mock('@/app/components/base/file-uploader/hooks', () => ({
handleDragFileOver: vi.fn(),
handleDropFile: vi.fn(),
handleClipboardPasteFile: vi.fn(),
isDragActive: mockIsDragActive,
isDragActive: false,
}),
}))
@@ -212,7 +210,6 @@ describe('ChatInputArea', () => {
beforeEach(() => {
vi.clearAllMocks()
mockFileStore.files = []
mockIsDragActive = false
mockIsMultipleLine = false
})
@@ -239,12 +236,6 @@ describe('ChatInputArea', () => {
expect(disabledWrapper).toBeInTheDocument()
})
it('should apply drag-active styles when a file is being dragged over the input', () => {
mockIsDragActive = true
const { container } = render(<ChatInputArea visionConfig={mockVisionConfig} />)
expect(container.querySelector('.border-dashed')).toBeInTheDocument()
})
it('should render the operation section inline when single-line', () => {
// mockIsMultipleLine is false by default
render(<ChatInputArea visionConfig={mockVisionConfig} />)
@@ -340,30 +331,6 @@ describe('ChatInputArea', () => {
expect(onSend).toHaveBeenCalledWith('With attachment', [uploadedFile])
})
it('should not send on Enter while IME composition is active, then send after composition ends', () => {
vi.useFakeTimers()
try {
const onSend = vi.fn()
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
const textarea = getTextarea()
fireEvent.change(textarea, { target: { value: 'Composed text' } })
fireEvent.compositionStart(textarea)
fireEvent.keyDown(textarea, { key: 'Enter' })
expect(onSend).not.toHaveBeenCalled()
fireEvent.compositionEnd(textarea)
vi.advanceTimersByTime(60)
fireEvent.keyDown(textarea, { key: 'Enter' })
expect(onSend).toHaveBeenCalledWith('Composed text', [])
}
finally {
vi.useRealTimers()
}
})
})
// -------------------------------------------------------------------------

View File

@@ -219,8 +219,8 @@ const Question: FC<QuestionProps> = ({
/>
</div>
<div className="flex items-center justify-end gap-2">
<Button className="min-w-24" onClick={handleCancelEditing} data-testid="cancel-edit-btn">{t('operation.cancel', { ns: 'common' })}</Button>
<Button className="min-w-24" variant="primary" onClick={handleResend} data-testid="save-edit-btn">{t('operation.save', { ns: 'common' })}</Button>
<Button className="min-w-24" onClick={handleCancelEditing}>{t('operation.cancel', { ns: 'common' })}</Button>
<Button className="min-w-24" variant="primary" onClick={handleResend}>{t('operation.save', { ns: 'common' })}</Button>
</div>
</div>
)}

View File

@@ -14,17 +14,6 @@ import { shareQueryKeys } from '@/service/use-share'
import { CONVERSATION_ID_INFO } from '../../constants'
import { useEmbeddedChatbot } from '../hooks'
type InputForm = {
variable: string
type: string
default?: unknown
required?: boolean
label?: string
max_length?: number
options?: string[]
hide?: boolean
}
vi.mock('@/i18n-config/client', () => ({
changeLanguage: vi.fn().mockResolvedValue(undefined),
}))
@@ -51,23 +40,13 @@ vi.mock('@/context/web-app-context', () => ({
useWebAppStore: (selector?: (state: typeof mockStoreState) => unknown) => useWebAppStoreMock(selector),
}))
const {
mockGetProcessedInputsFromUrlParams,
mockGetProcessedSystemVariablesFromUrlParams,
mockGetProcessedUserVariablesFromUrlParams,
} = vi.hoisted(() => ({
mockGetProcessedInputsFromUrlParams: vi.fn(),
mockGetProcessedSystemVariablesFromUrlParams: vi.fn(),
mockGetProcessedUserVariablesFromUrlParams: vi.fn(),
}))
vi.mock('../../utils', async () => {
const actual = await vi.importActual<typeof import('../../utils')>('../../utils')
return {
...actual,
getProcessedInputsFromUrlParams: mockGetProcessedInputsFromUrlParams,
getProcessedSystemVariablesFromUrlParams: mockGetProcessedSystemVariablesFromUrlParams,
getProcessedUserVariablesFromUrlParams: mockGetProcessedUserVariablesFromUrlParams,
getProcessedInputsFromUrlParams: vi.fn().mockResolvedValue({}),
getProcessedSystemVariablesFromUrlParams: vi.fn().mockResolvedValue({}),
getProcessedUserVariablesFromUrlParams: vi.fn().mockResolvedValue({}),
}
})
@@ -86,12 +65,6 @@ vi.mock('@/service/share', async (importOriginal) => {
}
})
const STABLE_MOCK_DATA = { data: {} }
vi.mock('@/service/use-try-app', () => ({
useGetTryAppInfo: vi.fn(() => STABLE_MOCK_DATA),
useGetTryAppParams: vi.fn(() => STABLE_MOCK_DATA),
}))
const mockFetchConversations = vi.mocked(fetchConversations)
const mockFetchChatList = vi.mocked(fetchChatList)
const mockGenerationConversationName = vi.mocked(generationConversationName)
@@ -112,20 +85,12 @@ const createWrapper = (queryClient: QueryClient) => {
)
}
const renderWithClient = async <T,>(hook: () => T) => {
const renderWithClient = <T,>(hook: () => T) => {
const queryClient = createQueryClient()
const wrapper = createWrapper(queryClient)
let result: ReturnType<typeof renderHook<T, unknown>> | undefined
act(() => {
result = renderHook(hook, { wrapper })
})
await waitFor(() => {
if (queryClient.isFetching() > 0)
throw new Error('Queries are still fetching')
}, { timeout: 2000 })
return {
queryClient,
...result!,
...renderHook(hook, { wrapper }),
}
}
@@ -148,10 +113,6 @@ const createConversationData = (overrides: Partial<AppConversationData> = {}): A
describe('useEmbeddedChatbot', () => {
beforeEach(() => {
vi.clearAllMocks()
// Re-establish default mock implementations after clearAllMocks
mockGetProcessedInputsFromUrlParams.mockResolvedValue({})
mockGetProcessedSystemVariablesFromUrlParams.mockResolvedValue({})
mockGetProcessedUserVariablesFromUrlParams.mockResolvedValue({})
localStorage.removeItem(CONVERSATION_ID_INFO)
mockStoreState.appInfo = {
app_id: 'app-1',
@@ -167,8 +128,6 @@ describe('useEmbeddedChatbot', () => {
mockStoreState.appParams = null
mockStoreState.embeddedConversationId = 'conversation-1'
mockStoreState.embeddedUserId = 'embedded-user-1'
mockFetchConversations.mockResolvedValue({ data: [], has_more: false, limit: 100 })
mockFetchChatList.mockResolvedValue({ data: [] })
})
afterEach(() => {
@@ -191,7 +150,7 @@ describe('useEmbeddedChatbot', () => {
mockFetchChatList.mockResolvedValue({ data: [] })
// Act
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
const { result } = renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
// Assert
await waitFor(() => {
@@ -208,49 +167,6 @@ describe('useEmbeddedChatbot', () => {
expect(result.current.conversationList).toEqual(listData.data)
})
})
it('should format chat list history correctly into appPrevChatList', async () => {
// Provide a currentConversationId by rendering successfully
mockStoreState.embeddedConversationId = 'conversation-1'
mockGetProcessedSystemVariablesFromUrlParams.mockResolvedValue({ conversation_id: 'conversation-1' })
mockFetchChatList.mockResolvedValue({
data: [{
id: 'msg-1',
query: 'Hello',
answer: 'Hi there!',
message_files: [{ belongs_to: 'user', id: 'mf-1' }, { belongs_to: 'assistant', id: 'mf-2' }],
agent_thoughts: [{ id: 'at-1' }],
feedback: { rating: 'like' },
}],
})
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
// Wait for the mock to be called
await waitFor(() => {
expect(mockFetchChatList).toHaveBeenCalledWith('conversation-1', AppSourceType.webApp, 'app-1')
})
// Wait for the chat list to be populated
await waitFor(() => {
expect(result.current.appPrevChatList.length).toBeGreaterThan(0)
})
// We expect the formatting logic to split the message into question and answer ChatItems
const chatList = result.current.appPrevChatList
const userMsg = chatList.find((msg: unknown) => (msg as Record<string, unknown>).id === 'question-msg-1')
expect(userMsg).toBeDefined()
expect((userMsg as Record<string, unknown>)?.content).toBe('Hello')
expect((userMsg as Record<string, unknown>)?.isAnswer).toBe(false)
const assistantMsg = ((userMsg as Record<string, unknown>)?.children as unknown[])?.[0]
expect(assistantMsg).toBeDefined()
expect((assistantMsg as Record<string, unknown>)?.id).toBe('msg-1')
expect((assistantMsg as Record<string, unknown>)?.content).toBe('Hi there!')
expect((assistantMsg as Record<string, unknown>)?.isAnswer).toBe(true)
expect(((assistantMsg as Record<string, unknown>)?.feedback as Record<string, unknown>)?.rating).toBe('like')
})
})
// Scenario: completion invalidates share caches and merges generated names.
@@ -268,7 +184,7 @@ describe('useEmbeddedChatbot', () => {
mockFetchChatList.mockResolvedValue({ data: [] })
mockGenerationConversationName.mockResolvedValue(generatedConversation)
const { result, queryClient } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
const { result, queryClient } = renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
const invalidateSpy = vi.spyOn(queryClient, 'invalidateQueries')
// Act
@@ -298,7 +214,7 @@ describe('useEmbeddedChatbot', () => {
mockFetchChatList.mockResolvedValue({ data: [] })
mockGenerationConversationName.mockResolvedValue(createConversationItem({ id: 'conversation-1' }))
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
const { result } = renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
await waitFor(() => {
expect(mockFetchChatList).toHaveBeenCalledTimes(1)
@@ -328,7 +244,7 @@ describe('useEmbeddedChatbot', () => {
mockFetchChatList.mockResolvedValue({ data: [] })
mockGenerationConversationName.mockResolvedValue(createConversationItem({ id: 'conversation-new' }))
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
const { result } = renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
// Act
act(() => {
@@ -345,215 +261,4 @@ describe('useEmbeddedChatbot', () => {
})
})
})
// Scenario: TryApp mode initialization and logic.
describe('TryApp mode', () => {
it('should use tryApp source type and skip URL overrides and user fetch', async () => {
// Arrange
const { useGetTryAppInfo } = await import('@/service/use-try-app')
const mockTryAppInfo = { app_id: 'try-app-1', site: { title: 'Try App' } };
(useGetTryAppInfo as unknown as ReturnType<typeof vi.fn>).mockReturnValue({ data: mockTryAppInfo })
mockGetProcessedSystemVariablesFromUrlParams.mockResolvedValue({})
// Act
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.tryApp, 'try-app-1'))
// Assert
expect(result.current.isInstalledApp).toBe(false)
expect(result.current.appId).toBe('try-app-1')
expect(result.current.appData?.site.title).toBe('Try App')
// ensure URL fetching is skipped
expect(mockGetProcessedSystemVariablesFromUrlParams).not.toHaveBeenCalled()
})
})
// Language overrides tests were causing hang, removed for now.
// Scenario: Removing conversation id info
describe('removeConversationIdInfo', () => {
it('should successfully remove a stored conversation ID info by appId', async () => {
// Setup some initial info
localStorage.setItem(CONVERSATION_ID_INFO, JSON.stringify({ 'app-1': { 'user-1': 'conv-id' } }))
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
act(() => {
result.current.removeConversationIdInfo('app-1')
})
await waitFor(() => {
const storedValue = localStorage.getItem(CONVERSATION_ID_INFO)
const parsed = storedValue ? JSON.parse(storedValue) : {}
expect(parsed['app-1']).toBeUndefined()
})
})
})
// Scenario: various form inputs configurations and default parsing
describe('inputsForms mapping and default parsing', () => {
const mockAppParamsWithInputs = {
user_input_form: [
{ paragraph: { variable: 'p1', default: 'para', max_length: 5 } },
{ number: { variable: 'n1', default: 42 } },
{ checkbox: { variable: 'c1', default: true } },
{ select: { variable: 's1', options: ['A', 'B'], default: 'A' } },
{ 'file-list': { variable: 'fl1' } },
{ file: { variable: 'f1' } },
{ json_object: { variable: 'j1' } },
{ 'text-input': { variable: 't1', default: 'txt', max_length: 3 } },
],
}
it('should map various types properly with max_length truncation when defaults supplied via URL', async () => {
mockGetProcessedInputsFromUrlParams.mockResolvedValue({
p1: 'toolongparagraph', // truncated to 5
n1: '99',
c1: true,
s1: 'B', // Matches options
t1: '1234', // truncated to 3
})
mockStoreState.appParams = mockAppParamsWithInputs as unknown as ChatConfig
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
// Wait for the mock to be called
await waitFor(() => {
expect(mockGetProcessedInputsFromUrlParams).toHaveBeenCalled()
})
await waitFor(() => {
expect(result.current.inputsForms).toHaveLength(8)
})
const forms = result.current.inputsForms
expect(forms.find((f: InputForm) => f.variable === 'p1')?.default).toBe('toolo')
expect(forms.find((f: InputForm) => f.variable === 'n1')?.default).toBe(99)
expect(forms.find((f: InputForm) => f.variable === 'c1')?.default).toBe(true)
expect(forms.find((f: InputForm) => f.variable === 's1')?.default).toBe('B')
expect(forms.find((f: InputForm) => f.variable === 't1')?.default).toBe('123')
expect(forms.find((f: InputForm) => f.variable === 'fl1')?.type).toBe('file-list')
expect(forms.find((f: InputForm) => f.variable === 'f1')?.type).toBe('file')
expect(forms.find((f: InputForm) => f.variable === 'j1')?.type).toBe('json_object')
})
})
// Scenario: checkInputsRequired validates empty fields and pending multi-file uploads
describe('checkInputsRequired and handleStartChat', () => {
it('should return undefined and notify when file is still uploading', async () => {
mockStoreState.appParams = {
user_input_form: [
{ file: { variable: 'file_var', required: true } },
],
} as unknown as ChatConfig
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
// Simulate a local file uploading
act(() => {
result.current.handleNewConversationInputsChange({
file_var: [{ transferMethod: 'local_file', uploadedId: null }],
})
})
const onStart = vi.fn()
let checkResult: boolean | undefined
act(() => {
checkResult = (result.current as unknown as { handleStartChat: (onStart?: () => void) => boolean }).handleStartChat(onStart)
})
expect(checkResult).toBeUndefined()
expect(onStart).not.toHaveBeenCalled()
})
it('should fail checkInputsRequired when required fields are missing', async () => {
mockStoreState.appParams = {
user_input_form: [
{ 'text-input': { variable: 't1', required: true, label: 'T1' } },
],
} as unknown as ChatConfig
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
act(() => {
result.current.handleNewConversationInputsChange({
t1: '',
})
})
const onStart = vi.fn()
act(() => {
(result.current as unknown as { handleStartChat: (cb?: () => void) => void }).handleStartChat(onStart)
})
expect(onStart).not.toHaveBeenCalled()
})
it('should pass checkInputsRequired when allInputsHidden is true', async () => {
mockStoreState.appParams = {
user_input_form: [
{ 'text-input': { variable: 't1', required: true, label: 'T1', hide: true } },
],
} as unknown as ChatConfig
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
const callback = vi.fn()
act(() => {
(result.current as unknown as { handleStartChat: (cb?: () => void) => void }).handleStartChat(callback)
})
expect(callback).toHaveBeenCalled()
})
})
// Scenario: handlers (New Conversation, Change Conversation, Feedback)
describe('Event Handlers', () => {
it('handleNewConversation sets clearChatList to true for webApp', async () => {
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
await act(async () => {
await result.current.handleNewConversation()
})
expect(result.current.clearChatList).toBe(true)
})
it('handleNewConversation sets clearChatList to true for tryApp without complex parsing', async () => {
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.tryApp, 'app-try-1'))
await act(async () => {
await result.current.handleNewConversation()
})
expect(result.current.clearChatList).toBe(true)
})
it('handleChangeConversation updates current conversation and refetches chat list', async () => {
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
act(() => {
result.current.handleChangeConversation('another-convo')
})
await waitFor(() => {
expect(result.current.currentConversationId).toBe('another-convo')
})
await waitFor(() => {
expect(mockFetchChatList).toHaveBeenCalledWith('another-convo', AppSourceType.webApp, 'app-1')
})
expect(result.current.newConversationId).toBe('')
expect(result.current.clearChatList).toBe(false)
})
it('handleFeedback invokes updateFeedback service successfully', async () => {
const { updateFeedback } = await import('@/service/share')
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
await act(async () => {
await result.current.handleFeedback('msg-123', { rating: 'like' })
})
expect(updateFeedback).toHaveBeenCalled()
})
})
})

View File

@@ -1,189 +0,0 @@
/**
* Tests for embedded-chatbot utility functions.
*/
import { isDify } from '../utils'
describe('isDify', () => {
const originalReferrer = document.referrer
afterEach(() => {
Object.defineProperty(document, 'referrer', {
value: originalReferrer,
writable: true,
})
})
it('should return true when referrer includes dify.ai', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://dify.ai/something',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when referrer includes www.dify.ai', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://www.dify.ai/app/xyz',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return false when referrer does not include dify.ai', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://example.com',
writable: true,
})
expect(isDify()).toBe(false)
})
it('should return false when referrer is empty', () => {
Object.defineProperty(document, 'referrer', {
value: '',
writable: true,
})
expect(isDify()).toBe(false)
})
it('should return false when referrer does not contain dify.ai domain', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://example-dify.com',
writable: true,
})
expect(isDify()).toBe(false)
})
it('should handle referrer without protocol', () => {
Object.defineProperty(document, 'referrer', {
value: 'dify.ai',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when referrer includes api.dify.ai', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://api.dify.ai/v1/endpoint',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when referrer includes app.dify.ai', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://app.dify.ai/chat',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when referrer includes docs.dify.ai', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://docs.dify.ai/guide',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when referrer has dify.ai with query parameters', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://dify.ai/?ref=test&id=123',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when referrer has dify.ai with hash fragment', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://dify.ai/page#section',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when referrer has dify.ai with port number', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://dify.ai:8080/app',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when dify.ai appears after another domain', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://example.com/redirect?url=https://dify.ai',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when substring contains dify.ai', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://notdify.ai',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when dify.ai is part of a different domain', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://fake-dify.ai.example.com',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true with multiple referrer variations', () => {
const variations = [
'https://dify.ai',
'http://www.dify.ai',
'http://dify.ai/',
'https://dify.ai/app?token=123#section',
'dify.ai/test',
'www.dify.ai/en',
]
variations.forEach((referrer) => {
Object.defineProperty(document, 'referrer', {
value: referrer,
writable: true,
})
expect(isDify()).toBe(true)
})
})
it('should return false with multiple non-dify referrer variations', () => {
const variations = [
'https://github.com',
'https://google.com',
'https://stackoverflow.com',
'https://example.dify',
'https://difyai.com',
'',
]
variations.forEach((referrer) => {
Object.defineProperty(document, 'referrer', {
value: referrer,
writable: true,
})
expect(isDify()).toBe(false)
})
})
})

View File

@@ -1,221 +0,0 @@
import { renderHook } from '@testing-library/react'
import { Theme, ThemeBuilder, useThemeContext } from '../theme-context'
// Scenario: Theme class configures colors from chatColorTheme and chatColorThemeInverted flags.
describe('Theme', () => {
describe('Default colors', () => {
it('should use default primary color when chatColorTheme is null', () => {
const theme = new Theme(null, false)
expect(theme.primaryColor).toBe('#1C64F2')
})
it('should use gradient background header when chatColorTheme is null', () => {
const theme = new Theme(null, false)
expect(theme.backgroundHeaderColorStyle).toBe(
'backgroundImage: linear-gradient(to right, #2563eb, #0ea5e9)',
)
})
it('should have empty chatBubbleColorStyle when chatColorTheme is null', () => {
const theme = new Theme(null, false)
expect(theme.chatBubbleColorStyle).toBe('')
})
it('should use default colors when chatColorTheme is empty string', () => {
const theme = new Theme('', false)
expect(theme.primaryColor).toBe('#1C64F2')
expect(theme.backgroundHeaderColorStyle).toBe(
'backgroundImage: linear-gradient(to right, #2563eb, #0ea5e9)',
)
})
})
describe('Custom color (configCustomColor)', () => {
it('should set primaryColor to chatColorTheme value', () => {
const theme = new Theme('#FF5733', false)
expect(theme.primaryColor).toBe('#FF5733')
})
it('should set backgroundHeaderColorStyle to solid custom color', () => {
const theme = new Theme('#FF5733', false)
expect(theme.backgroundHeaderColorStyle).toBe('backgroundColor: #FF5733')
})
it('should include primary color in backgroundButtonDefaultColorStyle', () => {
const theme = new Theme('#FF5733', false)
expect(theme.backgroundButtonDefaultColorStyle).toContain('#FF5733')
})
it('should set roundedBackgroundColorStyle with 5% opacity rgba', () => {
const theme = new Theme('#FF5733', false)
// #FF5733 → r=255 g=87 b=51
expect(theme.roundedBackgroundColorStyle).toBe('backgroundColor: rgba(255,87,51,0.05)')
})
it('should set chatBubbleColorStyle with 15% opacity rgba', () => {
const theme = new Theme('#FF5733', false)
expect(theme.chatBubbleColorStyle).toBe('backgroundColor: rgba(255,87,51,0.15)')
})
})
describe('Inverted color (configInvertedColor)', () => {
it('should use white background header when inverted with no custom color', () => {
const theme = new Theme(null, true)
expect(theme.backgroundHeaderColorStyle).toBe('backgroundColor: #ffffff')
})
it('should set colorFontOnHeaderStyle to default primaryColor when inverted with no custom color', () => {
const theme = new Theme(null, true)
expect(theme.colorFontOnHeaderStyle).toBe('color: #1C64F2')
})
it('should set headerBorderBottomStyle when inverted', () => {
const theme = new Theme(null, true)
expect(theme.headerBorderBottomStyle).toBe('borderBottom: 1px solid #ccc')
})
it('should set colorPathOnHeader to primaryColor when inverted', () => {
const theme = new Theme(null, true)
expect(theme.colorPathOnHeader).toBe('#1C64F2')
})
it('should have empty headerBorderBottomStyle when not inverted', () => {
const theme = new Theme(null, false)
expect(theme.headerBorderBottomStyle).toBe('')
})
})
describe('Custom color + inverted combined', () => {
it('should override background to white even when custom color is set', () => {
const theme = new Theme('#FF5733', true)
// configCustomColor runs first (solid bg), then configInvertedColor overrides to white
expect(theme.backgroundHeaderColorStyle).toBe('backgroundColor: #ffffff')
})
it('should use custom primaryColor for colorFontOnHeaderStyle when inverted', () => {
const theme = new Theme('#FF5733', true)
expect(theme.colorFontOnHeaderStyle).toBe('color: #FF5733')
})
it('should set colorPathOnHeader to custom primaryColor when inverted', () => {
const theme = new Theme('#FF5733', true)
expect(theme.colorPathOnHeader).toBe('#FF5733')
})
})
})
// Scenario: ThemeBuilder manages a lazily-created Theme instance and rebuilds on config change.
describe('ThemeBuilder', () => {
describe('theme getter', () => {
it('should create a default Theme when _theme is undefined (first access)', () => {
const builder = new ThemeBuilder()
const theme = builder.theme
expect(theme).toBeInstanceOf(Theme)
expect(theme.primaryColor).toBe('#1C64F2')
})
it('should return the same Theme instance on subsequent accesses', () => {
const builder = new ThemeBuilder()
const first = builder.theme
const second = builder.theme
expect(first).toBe(second)
})
})
describe('buildTheme', () => {
it('should create a Theme with the given color on first call', () => {
const builder = new ThemeBuilder()
builder.buildTheme('#AABBCC', false)
expect(builder.theme.primaryColor).toBe('#AABBCC')
})
it('should not rebuild the Theme when called again with the same config', () => {
const builder = new ThemeBuilder()
builder.buildTheme('#AABBCC', false)
const themeAfterFirstBuild = builder.theme
builder.buildTheme('#AABBCC', false)
// Same instance: no rebuild occurred
expect(builder.theme).toBe(themeAfterFirstBuild)
})
it('should rebuild the Theme when chatColorTheme changes', () => {
const builder = new ThemeBuilder()
builder.buildTheme('#AABBCC', false)
const originalTheme = builder.theme
builder.buildTheme('#FF0000', false)
expect(builder.theme).not.toBe(originalTheme)
expect(builder.theme.primaryColor).toBe('#FF0000')
})
it('should rebuild the Theme when chatColorThemeInverted changes', () => {
const builder = new ThemeBuilder()
builder.buildTheme('#AABBCC', false)
const originalTheme = builder.theme
builder.buildTheme('#AABBCC', true)
expect(builder.theme).not.toBe(originalTheme)
expect(builder.theme.chatColorThemeInverted).toBe(true)
})
it('should use default args (null, false) when called with no arguments', () => {
const builder = new ThemeBuilder()
builder.buildTheme()
expect(builder.theme.chatColorTheme).toBeNull()
expect(builder.theme.chatColorThemeInverted).toBe(false)
})
it('should store chatColorTheme and chatColorThemeInverted on the built Theme', () => {
const builder = new ThemeBuilder()
builder.buildTheme('#123456', true)
expect(builder.theme.chatColorTheme).toBe('#123456')
expect(builder.theme.chatColorThemeInverted).toBe(true)
})
})
})
// Scenario: useThemeContext returns a ThemeBuilder from the nearest ThemeContext.
describe('useThemeContext', () => {
it('should return a ThemeBuilder instance from the default context', () => {
const { result } = renderHook(() => useThemeContext())
expect(result.current).toBeInstanceOf(ThemeBuilder)
})
it('should expose a valid theme on the returned ThemeBuilder', () => {
const { result } = renderHook(() => useThemeContext())
expect(result.current.theme).toBeInstanceOf(Theme)
})
})

View File

@@ -1,5 +1,6 @@
import type { Dayjs } from 'dayjs'
import type { DatePickerProps, Period } from '../types'
import { RiCalendarLine, RiCloseCircleFill } from '@remixicon/react'
import * as React from 'react'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
@@ -217,29 +218,38 @@ const DatePicker = ({
>
<PortalToFollowElemTrigger className={triggerWrapClassName}>
{renderTrigger
? (
renderTrigger({
value: normalizedValue,
selectedDate,
isOpen,
handleClear,
handleClickTrigger,
}))
? (renderTrigger({
value: normalizedValue,
selectedDate,
isOpen,
handleClear,
handleClickTrigger,
}))
: (
<div
className="group flex w-[252px] cursor-pointer items-center gap-x-0.5 rounded-lg bg-components-input-bg-normal px-2 py-1 hover:bg-state-base-hover-alt"
onClick={handleClickTrigger}
data-testid="date-picker-trigger"
>
<input
className="flex-1 cursor-pointer appearance-none truncate bg-transparent p-1 text-components-input-text-filled
outline-none system-xs-regular placeholder:text-components-input-text-placeholder"
className="system-xs-regular flex-1 cursor-pointer appearance-none truncate bg-transparent p-1
text-components-input-text-filled outline-none placeholder:text-components-input-text-placeholder"
readOnly
value={isOpen ? '' : displayValue}
placeholder={placeholderDate}
/>
<span className={cn('i-ri-calendar-line h-4 w-4 shrink-0 text-text-quaternary', isOpen ? 'text-text-secondary' : 'group-hover:text-text-secondary', (displayValue || (isOpen && selectedDate)) && 'group-hover:hidden')} />
<span className={cn('i-ri-close-circle-fill hidden h-4 w-4 shrink-0 text-text-quaternary', (displayValue || (isOpen && selectedDate)) && 'hover:text-text-secondary group-hover:inline-block')} onClick={handleClear} data-testid="date-picker-clear-button" />
<RiCalendarLine className={cn(
'h-4 w-4 shrink-0 text-text-quaternary',
isOpen ? 'text-text-secondary' : 'group-hover:text-text-secondary',
(displayValue || (isOpen && selectedDate)) && 'group-hover:hidden',
)}
/>
<RiCloseCircleFill
className={cn(
'hidden h-4 w-4 shrink-0 text-text-quaternary',
(displayValue || (isOpen && selectedDate)) && 'hover:text-text-secondary group-hover:inline-block',
)}
onClick={handleClear}
/>
</div>
)}
</PortalToFollowElemTrigger>

View File

@@ -503,7 +503,7 @@ describe('TimePicker', () => {
const emitted = onChange.mock.calls[0][0]
expect(isDayjsObject(emitted)).toBe(true)
// 10:30 UTC converted to America/New_York (UTC-5 in Jan) = 05:30
expect(emitted.utcOffset()).toBe(dayjs.tz('2024-01-01', 'America/New_York').utcOffset())
expect(emitted.utcOffset()).toBe(dayjs().tz('America/New_York').utcOffset())
expect(emitted.hour()).toBe(5)
expect(emitted.minute()).toBe(30)
})

View File

@@ -1,5 +1,6 @@
import type { Dayjs } from 'dayjs'
import type { TimePickerProps } from '../types'
import { RiCloseCircleFill, RiTimeLine } from '@remixicon/react'
import * as React from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
@@ -198,8 +199,8 @@ const TimePicker = ({
const inputElem = (
<input
className="flex-1 cursor-pointer select-none appearance-none truncate bg-transparent p-1 text-components-input-text-filled
outline-none system-xs-regular placeholder:text-components-input-text-placeholder"
className="system-xs-regular flex-1 cursor-pointer select-none appearance-none truncate bg-transparent p-1
text-components-input-text-filled outline-none placeholder:text-components-input-text-placeholder"
readOnly
value={isOpen ? '' : displayValue}
placeholder={placeholderDate}
@@ -225,14 +226,26 @@ const TimePicker = ({
triggerFullWidth ? 'w-full min-w-0' : 'w-[252px]',
)}
onClick={handleClickTrigger}
data-testid="time-picker-trigger"
>
{inputElem}
{showTimezone && timezone && (
<TimezoneLabel timezone={timezone} inline className="shrink-0 select-none text-xs" />
)}
<span className={cn('i-ri-time-line h-4 w-4 shrink-0 text-text-quaternary', isOpen ? 'text-text-secondary' : 'group-hover:text-text-secondary', (displayValue || (isOpen && selectedTime)) && !notClearable && 'group-hover:hidden')} />
<span className={cn('i-ri-close-circle-fill hidden h-4 w-4 shrink-0 text-text-quaternary', (displayValue || (isOpen && selectedTime)) && !notClearable && 'hover:text-text-secondary group-hover:inline-block')} role="button" aria-label={t('operation.clear', { ns: 'common' })} onClick={handleClear} />
<RiTimeLine className={cn(
'h-4 w-4 shrink-0 text-text-quaternary',
isOpen ? 'text-text-secondary' : 'group-hover:text-text-secondary',
(displayValue || (isOpen && selectedTime)) && !notClearable && 'group-hover:hidden',
)}
/>
<RiCloseCircleFill
className={cn(
'hidden h-4 w-4 shrink-0 text-text-quaternary',
(displayValue || (isOpen && selectedTime)) && !notClearable && 'hover:text-text-secondary group-hover:inline-block',
)}
role="button"
aria-label={t('operation.clear', { ns: 'common' })}
onClick={handleClear}
/>
</div>
)}
</PortalToFollowElemTrigger>

View File

@@ -20,7 +20,7 @@ describe('dayjs utilities', () => {
const result = toDayjs('07:15 PM', { timezone: tz })
expect(result).toBeDefined()
expect(result?.format('HH:mm')).toBe('19:15')
expect(result?.utcOffset()).toBe(getDateWithTimezone({ timezone: tz }).startOf('day').utcOffset())
expect(result?.utcOffset()).toBe(getDateWithTimezone({ timezone: tz }).utcOffset())
})
it('isDayjsObject detects dayjs instances', () => {

View File

@@ -1,105 +0,0 @@
import { fireEvent, render, screen } from '@testing-library/react'
import DynamicPdfPreview from './dynamic-pdf-preview'
type DynamicPdfPreviewProps = {
url: string
onCancel: () => void
}
type DynamicLoader = () => Promise<unknown> | undefined
type DynamicOptions = {
ssr?: boolean
}
const mockState = vi.hoisted(() => ({
loader: undefined as DynamicLoader | undefined,
options: undefined as DynamicOptions | undefined,
}))
const mockDynamicRender = vi.hoisted(() => vi.fn())
const mockDynamic = vi.hoisted(() =>
vi.fn((loader: DynamicLoader, options: DynamicOptions) => {
mockState.loader = loader
mockState.options = options
const MockDynamicPdfPreview = ({ url, onCancel }: DynamicPdfPreviewProps) => {
mockDynamicRender({ url, onCancel })
return (
<button data-testid="dynamic-pdf-preview" data-url={url} onClick={onCancel}>
Dynamic PDF Preview
</button>
)
}
return MockDynamicPdfPreview
}),
)
const mockPdfPreview = vi.hoisted(() =>
vi.fn(() => null),
)
vi.mock('next/dynamic', () => ({
default: mockDynamic,
}))
vi.mock('./pdf-preview', () => ({
default: mockPdfPreview,
}))
describe('dynamic-pdf-preview', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should configure next/dynamic with ssr disabled', () => {
expect(mockState.loader).toEqual(expect.any(Function))
expect(mockState.options).toEqual({ ssr: false })
})
it('should render the dynamic component and forward props', () => {
const onCancel = vi.fn()
render(<DynamicPdfPreview url="https://example.com/test.pdf" onCancel={onCancel} />)
const trigger = screen.getByTestId('dynamic-pdf-preview')
expect(trigger).toHaveAttribute('data-url', 'https://example.com/test.pdf')
expect(mockDynamicRender).toHaveBeenCalledWith({
url: 'https://example.com/test.pdf',
onCancel,
})
fireEvent.click(trigger)
expect(onCancel).toHaveBeenCalledTimes(1)
})
it('should return pdf-preview module when loader is executed in browser-like environment', async () => {
const loaded = mockState.loader?.()
expect(loaded).toBeInstanceOf(Promise)
const loadedModule = (await loaded) as { default: unknown }
const pdfPreviewModule = await import('./pdf-preview')
expect(loadedModule.default).toBe(pdfPreviewModule.default)
})
it('should return undefined when loader runs without window', () => {
const originalWindow = globalThis.window
Object.defineProperty(globalThis, 'window', {
configurable: true,
writable: true,
value: undefined,
})
try {
const loaded = mockState.loader?.()
expect(loaded).toBeUndefined()
}
finally {
Object.defineProperty(globalThis, 'window', {
configurable: true,
writable: true,
value: originalWindow,
})
}
})
})

View File

@@ -44,16 +44,4 @@ describe('VariableOrConstantInputField', () => {
fireEvent.click(modeButtons[0])
expect(screen.getByRole('button', { name: 'Variable picker' })).toBeInTheDocument()
})
it('should handle variable picker changes', () => {
const logSpy = vi.spyOn(console, 'log').mockImplementation(() => { })
try {
render(<VariableOrConstantInputField label="Input source" />)
fireEvent.click(screen.getByRole('button', { name: 'Variable picker' }))
expect(logSpy).toHaveBeenCalledWith('Variable value changed')
}
finally {
logSpy.mockRestore()
}
})
})

View File

@@ -46,54 +46,4 @@ describe('base scenario schema generator', () => {
expect(schema.safeParse({}).success).toBe(true)
expect(schema.safeParse({ mode: null }).success).toBe(true)
})
it('should validate required checkbox values as booleans', () => {
const schema = generateZodSchema([{
type: BaseFieldType.checkbox,
variable: 'accepted',
label: 'Accepted',
required: true,
showConditions: [],
}])
expect(schema.safeParse({ accepted: true }).success).toBe(true)
expect(schema.safeParse({ accepted: false }).success).toBe(true)
expect(schema.safeParse({ accepted: 'yes' }).success).toBe(false)
expect(schema.safeParse({}).success).toBe(false)
})
it('should fallback to any schema for unsupported field types', () => {
const schema = generateZodSchema([{
type: BaseFieldType.file,
variable: 'attachment',
label: 'Attachment',
required: false,
showConditions: [],
allowedFileTypes: [],
allowedFileExtensions: [],
allowedFileUploadMethods: [],
}])
expect(schema.safeParse({ attachment: { id: 'file-1' } }).success).toBe(true)
expect(schema.safeParse({ attachment: 'raw-string' }).success).toBe(true)
expect(schema.safeParse({}).success).toBe(true)
expect(schema.safeParse({ attachment: null }).success).toBe(true)
})
it('should ignore numeric and text constraints for non-applicable field types', () => {
const schema = generateZodSchema([{
type: BaseFieldType.checkbox,
variable: 'toggle',
label: 'Toggle',
required: true,
showConditions: [],
maxLength: 1,
min: 10,
max: 20,
}])
expect(schema.safeParse({ toggle: true }).success).toBe(true)
expect(schema.safeParse({ toggle: false }).success).toBe(true)
expect(schema.safeParse({ toggle: 1 }).success).toBe(false)
})
})

View File

@@ -8,7 +8,7 @@ import * as utils from '../utils'
vi.mock('../utils', () => ({
generate: vi.fn((icon, key, props) => (
<svg
data-testid={key}
data-testid="mock-svg"
key={key}
{...props}
>
@@ -29,7 +29,7 @@ describe('IconBase Component', () => {
it('renders properly with required props', () => {
render(<IconBase data={mockData} />)
const svg = screen.getByTestId('svg-test-icon')
const svg = screen.getByTestId('mock-svg')
expect(svg).toBeInTheDocument()
expect(svg).toHaveAttribute('data-icon', mockData.name)
expect(svg).toHaveAttribute('aria-hidden', 'true')
@@ -37,7 +37,7 @@ describe('IconBase Component', () => {
it('passes className to the generated SVG', () => {
render(<IconBase data={mockData} className="custom-class" />)
const svg = screen.getByTestId('svg-test-icon')
const svg = screen.getByTestId('mock-svg')
expect(svg).toHaveAttribute('class', 'custom-class')
expect(utils.generate).toHaveBeenCalledWith(
mockData.icon,
@@ -49,7 +49,7 @@ describe('IconBase Component', () => {
it('handles onClick events', () => {
const handleClick = vi.fn()
render(<IconBase data={mockData} onClick={handleClick} />)
const svg = screen.getByTestId('svg-test-icon')
const svg = screen.getByTestId('mock-svg')
fireEvent.click(svg)
expect(handleClick).toHaveBeenCalledTimes(1)
})

View File

@@ -21,28 +21,6 @@ describe('generate icon base utils', () => {
const result = normalizeAttrs(attrs)
expect(result).toEqual({ dataTest: 'value', xlinkHref: 'url' })
})
it('should filter out editor metadata attributes', () => {
const attrs = {
'inkscape:version': '1.0',
'sodipodi:docname': 'icon.svg',
'xmlns:inkscape': 'http...',
'xmlns:sodipodi': 'http...',
'xmlns:svg': 'http...',
'data-name': 'Layer 1',
'xmlns-inkscape': 'http...',
'xmlns-sodipodi': 'http...',
'xmlns-svg': 'http...',
'dataName': 'Layer 1',
'valid': 'value',
}
expect(normalizeAttrs(attrs)).toEqual({ valid: 'value' })
})
it('should ignore undefined attribute values and handle default argument', () => {
expect(normalizeAttrs()).toEqual({})
expect(normalizeAttrs({ missing: undefined, valid: 'true' })).toEqual({ valid: 'true' })
})
})
describe('generate', () => {
@@ -80,19 +58,7 @@ describe('generate icon base utils', () => {
const node: AbstractNode = {
name: 'div',
attributes: { class: 'container' },
children: [{ name: 'span', attributes: {} }],
}
const rootProps = { id: 'root' }
const { container } = render(generate(node, 'key', rootProps))
expect(container.querySelector('div')).toHaveAttribute('id', 'root')
expect(container.querySelector('span')).toBeInTheDocument()
})
it('should handle undefined children with rootProps', () => {
const node: AbstractNode = {
name: 'div',
attributes: { class: 'container' },
children: [],
}
const rootProps = { id: 'root' }

View File

@@ -0,0 +1,4 @@
<svg width="10" height="10" viewBox="0 0 10 10" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M2 5C2 3.44487 2.58482 1.98537 3.54004 1.04932C2.17681 1.34034 1 2.90001 1 5C1 7.09996 2.17685 8.65912 3.54004 8.9502C2.58496 8.01413 2 6.55501 2 5ZM3 5C3 7.33338 4.4528 9 6 9C7.5472 9 9 7.33338 9 5C9 2.66664 7.5472 1 6 1C4.4528 1 3 2.66664 3 5ZM10 5C10 7.63722 8.3188 10 6 10H4C1.6812 10 0 7.63722 0 5C0 2.3628 1.6812 0 4 0H6C8.3188 0 10 2.3628 10 5Z" fill="#676F83"/>
<path d="M6.71519 4.09259L6.45385 3.18667C6.42141 3.07421 6.34037 3 6.25 3C6.15963 3 6.07859 3.07421 6.04615 3.18667L5.78481 4.09259C5.74675 4.22464 5.66849 4.32899 5.56945 4.37978L4.88999 4.7282C4.80565 4.77146 4.75 4.87951 4.75 5C4.75 5.12049 4.80565 5.22854 4.88999 5.2718L5.56945 5.62022C5.66849 5.67101 5.74675 5.77536 5.78481 5.90741L6.04615 6.81333C6.07859 6.92579 6.15963 7 6.25 7C6.34037 7 6.42141 6.92579 6.45385 6.81333L6.71519 5.90741C6.75325 5.77536 6.83151 5.67101 6.93055 5.62022L7.61001 5.2718C7.69435 5.22854 7.75 5.12049 7.75 5C7.75 4.87951 7.69435 4.77146 7.61001 4.7282L6.93055 4.37978C6.83151 4.32899 6.75325 4.22464 6.71519 4.09259Z" fill="#676F83"/>
</svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@@ -0,0 +1,35 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "10",
"height": "10",
"viewBox": "0 0 10 10",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"d": "M2 5C2 3.44487 2.58482 1.98537 3.54004 1.04932C2.17681 1.34034 1 2.90001 1 5C1 7.09996 2.17685 8.65912 3.54004 8.9502C2.58496 8.01413 2 6.55501 2 5ZM3 5C3 7.33338 4.4528 9 6 9C7.5472 9 9 7.33338 9 5C9 2.66664 7.5472 1 6 1C4.4528 1 3 2.66664 3 5ZM10 5C10 7.63722 8.3188 10 6 10H4C1.6812 10 0 7.63722 0 5C0 2.3628 1.6812 0 4 0H6C8.3188 0 10 2.3628 10 5Z",
"fill": "currentColor"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M6.71519 4.09259L6.45385 3.18667C6.42141 3.07421 6.34037 3 6.25 3C6.15963 3 6.07859 3.07421 6.04615 3.18667L5.78481 4.09259C5.74675 4.22464 5.66849 4.32899 5.56945 4.37978L4.88999 4.7282C4.80565 4.77146 4.75 4.87951 4.75 5C4.75 5.12049 4.80565 5.22854 4.88999 5.2718L5.56945 5.62022C5.66849 5.67101 5.74675 5.77536 5.78481 5.90741L6.04615 6.81333C6.07859 6.92579 6.15963 7 6.25 7C6.34037 7 6.42141 6.92579 6.45385 6.81333L6.71519 5.90741C6.75325 5.77536 6.83151 5.67101 6.93055 5.62022L7.61001 5.2718C7.69435 5.22854 7.75 5.12049 7.75 5C7.75 4.87951 7.69435 4.77146 7.61001 4.7282L6.93055 4.37978C6.83151 4.32899 6.75325 4.22464 6.71519 4.09259Z",
"fill": "currentColor"
},
"children": []
}
]
},
"name": "CreditsCoin"
}

View File

@@ -0,0 +1,20 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import type { IconData } from '@/app/components/base/icons/IconBase'
import * as React from 'react'
import IconBase from '@/app/components/base/icons/IconBase'
import data from './CreditsCoin.json'
const Icon = (
{
ref,
...props
}: React.SVGProps<SVGSVGElement> & {
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>
},
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'CreditsCoin'
export default Icon

View File

@@ -1,5 +1,6 @@
export { default as Balance } from './Balance'
export { default as CoinsStacked01 } from './CoinsStacked01'
export { default as CreditsCoin } from './CreditsCoin'
export { default as GoldCoin } from './GoldCoin'
export { default as ReceiptList } from './ReceiptList'
export { default as Tag01 } from './Tag01'

View File

@@ -36,7 +36,7 @@ const ImageGallery: FC<Props> = ({
const imgNum = srcs.length
const imgStyle = getWidthStyle(imgNum)
return (
<div className={cn(s[`img-${imgNum}`], 'flex flex-wrap')} data-testid="image-gallery">
<div className={cn(s[`img-${imgNum}`], 'flex flex-wrap')}>
{srcs.map((src, index) => (
!src
? null

View File

@@ -1,6 +1,6 @@
import type { useLocalFileUploader } from '../hooks'
import type { ImageFile, VisionSettings } from '@/types/app'
import { fireEvent, render, screen } from '@testing-library/react'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { Resolution, TransferMethod } from '@/types/app'
import ChatImageUploader from '../chat-image-uploader'
@@ -193,23 +193,6 @@ describe('ChatImageUploader', () => {
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
})
it('should keep popover closed when trigger wrapper is clicked while disabled', async () => {
const user = userEvent.setup()
const settings = createSettings({
transfer_methods: [TransferMethod.remote_url],
})
render(<ChatImageUploader settings={settings} onUpload={defaultOnUpload} disabled />)
const button = screen.getByRole('button')
const triggerWrapper = button.parentElement
if (!triggerWrapper)
throw new Error('Expected trigger wrapper to exist')
await user.click(triggerWrapper)
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
})
it('should show OR separator and local uploader when both methods are available', async () => {
const user = userEvent.setup()
const settings = createSettings({
@@ -224,30 +207,6 @@ describe('ChatImageUploader', () => {
expect(queryFileInput()).toBeInTheDocument()
})
it('should toggle local-upload hover style in mixed transfer mode', async () => {
const user = userEvent.setup()
const settings = createSettings({
transfer_methods: [TransferMethod.local_file, TransferMethod.remote_url],
})
render(<ChatImageUploader settings={settings} onUpload={defaultOnUpload} />)
await user.click(screen.getByRole('button'))
const uploadFromComputer = screen.getByText('common.imageUploader.uploadFromComputer')
expect(uploadFromComputer).not.toHaveClass('bg-primary-50')
const localInput = getFileInput()
const hoverWrapper = localInput.parentElement
if (!hoverWrapper)
throw new Error('Expected local uploader wrapper to exist')
fireEvent.mouseEnter(hoverWrapper)
expect(uploadFromComputer).toHaveClass('bg-primary-50')
fireEvent.mouseLeave(hoverWrapper)
expect(uploadFromComputer).not.toHaveClass('bg-primary-50')
})
it('should not show OR separator or local uploader when only remote_url method', async () => {
const user = userEvent.setup()
const settings = createSettings({

View File

@@ -140,11 +140,9 @@ describe('ImageLinkInput', () => {
const input = screen.getByRole('textbox')
await user.type(input, 'https://example.com/image.png')
const button = screen.getByRole('button')
expect(button).toBeDisabled()
await user.click(button)
await user.click(screen.getByRole('button'))
// Button is disabled, so click won't fire handleClick
expect(onUpload).not.toHaveBeenCalled()
})

View File

@@ -2,15 +2,22 @@ import { act, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import ImagePreview from '../image-preview'
type _HotkeyHandler = () => void
type HotkeyHandler = () => void
const mocks = vi.hoisted(() => ({
hotkeys: {} as Record<string, HotkeyHandler>,
notify: vi.fn(),
downloadUrl: vi.fn(),
windowOpen: vi.fn<(...args: unknown[]) => Window | null>(),
clipboardWrite: vi.fn<(items: ClipboardItem[]) => Promise<void>>(),
}))
vi.mock('react-hotkeys-hook', () => ({
useHotkeys: (keys: string, handler: HotkeyHandler) => {
mocks.hotkeys[keys] = handler
},
}))
vi.mock('@/app/components/base/toast', () => ({
default: {
notify: (...args: Parameters<typeof mocks.notify>) => mocks.notify(...args),
@@ -37,6 +44,7 @@ describe('ImagePreview', () => {
beforeEach(() => {
vi.clearAllMocks()
mocks.hotkeys = {}
if (!navigator.clipboard) {
Object.defineProperty(globalThis.navigator, 'clipboard', {
@@ -101,8 +109,7 @@ describe('ImagePreview', () => {
})
describe('Hotkeys', () => {
it('should trigger esc/left/right handlers from keyboard', async () => {
const user = userEvent.setup()
it('should register hotkeys and invoke esc/left/right handlers', () => {
const onCancel = vi.fn()
const onPrev = vi.fn()
const onNext = vi.fn()
@@ -116,34 +123,18 @@ describe('ImagePreview', () => {
/>,
)
await user.keyboard('{Escape}{ArrowLeft}{ArrowRight}')
expect(mocks.hotkeys.esc).toBeInstanceOf(Function)
expect(mocks.hotkeys.left).toBeInstanceOf(Function)
expect(mocks.hotkeys.right).toBeInstanceOf(Function)
mocks.hotkeys.esc?.()
mocks.hotkeys.left?.()
mocks.hotkeys.right?.()
expect(onCancel).toHaveBeenCalledTimes(1)
expect(onPrev).toHaveBeenCalledTimes(1)
expect(onNext).toHaveBeenCalledTimes(1)
})
it('should zoom in and out from keyboard up/down hotkeys', async () => {
const user = userEvent.setup()
render(
<ImagePreview
url="https://example.com/image.png"
title="Preview Image"
onCancel={vi.fn()}
/>,
)
const image = screen.getByRole('img', { name: 'Preview Image' })
await user.keyboard('{ArrowUp}')
await waitFor(() => {
expect(image).toHaveStyle({ transform: 'scale(1.2) translate(0px, 0px)' })
})
await user.keyboard('{ArrowDown}')
await waitFor(() => {
expect(image).toHaveStyle({ transform: 'scale(1) translate(0px, 0px)' })
})
})
})
describe('User Interactions', () => {
@@ -234,18 +225,13 @@ describe('ImagePreview', () => {
act(() => {
overlay.dispatchEvent(new MouseEvent('mousedown', { bubbles: true, clientX: 10, clientY: 10 }))
overlay.dispatchEvent(new MouseEvent('mousemove', { bubbles: true, clientX: 40, clientY: 30 }))
})
await waitFor(() => {
expect(image.style.transition).toBe('none')
})
act(() => {
overlay.dispatchEvent(new MouseEvent('mousemove', { bubbles: true, clientX: 200, clientY: -100 }))
})
await waitFor(() => {
expect(image).toHaveStyle({ transform: 'scale(1.2) translate(70px, -22px)' })
})
expect(image.style.transform).toContain('translate(')
act(() => {
document.dispatchEvent(new MouseEvent('mouseup', { bubbles: true }))

View File

@@ -1,5 +1,4 @@
import { fireEvent, render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { InputNumber } from '../index'
describe('InputNumber Component', () => {
@@ -17,130 +16,70 @@ describe('InputNumber Component', () => {
expect(input).toBeInTheDocument()
})
it('handles increment button click', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={5} />)
it('handles increment button click', () => {
render(<InputNumber {...defaultProps} value={5} />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
await user.click(incrementBtn)
expect(onChange).toHaveBeenCalledWith(6)
fireEvent.click(incrementBtn)
expect(defaultProps.onChange).toHaveBeenCalledWith(6)
})
it('handles decrement button click', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={5} />)
it('handles decrement button click', () => {
render(<InputNumber {...defaultProps} value={5} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
await user.click(decrementBtn)
expect(onChange).toHaveBeenCalledWith(4)
fireEvent.click(decrementBtn)
expect(defaultProps.onChange).toHaveBeenCalledWith(4)
})
it('respects max value constraint', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={10} max={10} />)
it('respects max value constraint', () => {
render(<InputNumber {...defaultProps} value={10} max={10} />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
await user.click(incrementBtn)
expect(onChange).not.toHaveBeenCalled()
fireEvent.click(incrementBtn)
expect(defaultProps.onChange).not.toHaveBeenCalled()
})
it('respects min value constraint', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={0} min={0} />)
it('respects min value constraint', () => {
render(<InputNumber {...defaultProps} value={0} min={0} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
await user.click(decrementBtn)
expect(onChange).not.toHaveBeenCalled()
fireEvent.click(decrementBtn)
expect(defaultProps.onChange).not.toHaveBeenCalled()
})
it('handles direct input changes', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} />)
render(<InputNumber {...defaultProps} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: '42' } })
expect(onChange).toHaveBeenCalledWith(42)
expect(defaultProps.onChange).toHaveBeenCalledWith(42)
})
it('handles empty input', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={1} />)
render(<InputNumber {...defaultProps} value={1} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: '' } })
expect(onChange).toHaveBeenCalledWith(0)
expect(defaultProps.onChange).toHaveBeenCalledWith(0)
})
it('does not call onChange when parsed value is NaN', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} />)
it('handles invalid input', () => {
render(<InputNumber {...defaultProps} />)
const input = screen.getByRole('spinbutton')
const originalNumber = globalThis.Number
const numberSpy = vi.spyOn(globalThis, 'Number').mockImplementation((val: unknown) => {
if (val === '123') {
return Number.NaN
}
return originalNumber(val)
})
try {
fireEvent.change(input, { target: { value: '123' } })
expect(onChange).not.toHaveBeenCalled()
}
finally {
numberSpy.mockRestore()
}
})
it('does not call onChange when direct input exceeds range', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} max={10} min={0} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: '11' } })
expect(onChange).not.toHaveBeenCalled()
})
it('uses default value when increment and decrement are clicked without value prop', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} defaultValue={7} />)
await user.click(screen.getByRole('button', { name: /increment/i }))
await user.click(screen.getByRole('button', { name: /decrement/i }))
expect(onChange).toHaveBeenNthCalledWith(1, 7)
expect(onChange).toHaveBeenNthCalledWith(2, 7)
})
it('falls back to zero when controls are used without value and defaultValue', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} />)
await user.click(screen.getByRole('button', { name: /increment/i }))
await user.click(screen.getByRole('button', { name: /decrement/i }))
expect(onChange).toHaveBeenNthCalledWith(1, 0)
expect(onChange).toHaveBeenNthCalledWith(2, 0)
fireEvent.change(input, { target: { value: 'abc' } })
expect(defaultProps.onChange).toHaveBeenCalledWith(0)
})
it('displays unit when provided', () => {
const onChange = vi.fn()
const unit = 'px'
render(<InputNumber onChange={onChange} unit={unit} />)
render(<InputNumber {...defaultProps} unit={unit} />)
expect(screen.getByText(unit)).toBeInTheDocument()
})
it('disables controls when disabled prop is true', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} disabled />)
render(<InputNumber {...defaultProps} disabled />)
const input = screen.getByRole('spinbutton')
const incrementBtn = screen.getByRole('button', { name: /increment/i })
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
@@ -149,205 +88,4 @@ describe('InputNumber Component', () => {
expect(incrementBtn).toBeDisabled()
expect(decrementBtn).toBeDisabled()
})
it('does not change value when disabled controls are clicked', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
const { getByRole } = render(<InputNumber onChange={onChange} disabled value={5} />)
const incrementBtn = getByRole('button', { name: /increment/i })
const decrementBtn = getByRole('button', { name: /decrement/i })
expect(incrementBtn).toBeDisabled()
expect(decrementBtn).toBeDisabled()
await user.click(incrementBtn)
await user.click(decrementBtn)
expect(onChange).not.toHaveBeenCalled()
})
it('keeps increment guard when disabled even if button is force-clickable', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} disabled value={5} />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
// Remove native disabled to force event dispatch and hit component-level guard.
incrementBtn.removeAttribute('disabled')
fireEvent.click(incrementBtn)
expect(onChange).not.toHaveBeenCalled()
})
it('keeps decrement guard when disabled even if button is force-clickable', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} disabled value={5} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
// Remove native disabled to force event dispatch and hit component-level guard.
decrementBtn.removeAttribute('disabled')
fireEvent.click(decrementBtn)
expect(onChange).not.toHaveBeenCalled()
})
it('applies large-size classes for control buttons', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} size="large" />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
expect(incrementBtn).toHaveClass('pt-1.5')
expect(decrementBtn).toHaveClass('pb-1.5')
})
it('prevents increment beyond max with custom amount', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={8} max={10} amount={5} />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
await user.click(incrementBtn)
expect(onChange).not.toHaveBeenCalled()
})
it('prevents decrement below min with custom amount', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={2} min={0} amount={5} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
await user.click(decrementBtn)
expect(onChange).not.toHaveBeenCalled()
})
it('increments when value with custom amount stays within bounds', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={5} max={10} amount={3} />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
await user.click(incrementBtn)
expect(onChange).toHaveBeenCalledWith(8)
})
it('decrements when value with custom amount stays within bounds', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={5} min={0} amount={3} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
await user.click(decrementBtn)
expect(onChange).toHaveBeenCalledWith(2)
})
it('validates input against max constraint', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} max={10} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: '15' } })
expect(onChange).not.toHaveBeenCalled()
})
it('validates input against min constraint', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} min={5} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: '2' } })
expect(onChange).not.toHaveBeenCalled()
})
it('accepts input within min and max constraints', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} min={0} max={100} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: '50' } })
expect(onChange).toHaveBeenCalledWith(50)
})
it('handles negative min and max values', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} min={-10} max={10} value={0} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
await user.click(decrementBtn)
expect(onChange).toHaveBeenCalledWith(-1)
})
it('prevents decrement below negative min', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} min={-10} value={-10} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
await user.click(decrementBtn)
expect(onChange).not.toHaveBeenCalled()
})
it('applies wrapClassName to outer div', () => {
const onChange = vi.fn()
const wrapClassName = 'custom-wrap-class'
render(<InputNumber onChange={onChange} wrapClassName={wrapClassName} />)
const wrapper = screen.getByTestId('input-number-wrapper')
expect(wrapper).toHaveClass(wrapClassName)
})
it('applies controlWrapClassName to control buttons container', () => {
const onChange = vi.fn()
const controlWrapClassName = 'custom-control-wrap'
render(<InputNumber onChange={onChange} controlWrapClassName={controlWrapClassName} />)
const controlDiv = screen.getByTestId('input-number-controls')
expect(controlDiv).toHaveClass(controlWrapClassName)
})
it('applies controlClassName to individual control buttons', () => {
const onChange = vi.fn()
const controlClassName = 'custom-control'
render(<InputNumber onChange={onChange} controlClassName={controlClassName} />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
expect(incrementBtn).toHaveClass(controlClassName)
expect(decrementBtn).toHaveClass(controlClassName)
})
it('applies regular-size classes for control buttons when size is regular', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} size="regular" />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
expect(incrementBtn).toHaveClass('pt-1')
expect(decrementBtn).toHaveClass('pb-1')
})
it('handles zero as a valid input', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} min={-5} max={5} value={1} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: '0' } })
expect(onChange).toHaveBeenCalledWith(0)
})
it('prevents exact max boundary increment', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={10} max={10} />)
await user.click(screen.getByRole('button', { name: /increment/i }))
expect(onChange).not.toHaveBeenCalled()
})
it('prevents exact min boundary decrement', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={0} min={0} />)
await user.click(screen.getByRole('button', { name: /decrement/i }))
expect(onChange).not.toHaveBeenCalled()
})
})

View File

@@ -1,5 +1,6 @@
import type { FC } from 'react'
import type { InputProps } from '../input'
import { RiArrowDownSLine, RiArrowUpSLine } from '@remixicon/react'
import { useCallback } from 'react'
import { cn } from '@/utils/classnames'
import Input from '../input'
@@ -44,7 +45,6 @@ export const InputNumber: FC<InputNumberProps> = (props) => {
}, [max, min])
const inc = () => {
/* v8 ignore next 2 - @preserve */
if (disabled)
return
@@ -58,7 +58,6 @@ export const InputNumber: FC<InputNumberProps> = (props) => {
onChange(newValue)
}
const dec = () => {
/* v8 ignore next 2 - @preserve */
if (disabled)
return
@@ -87,12 +86,12 @@ export const InputNumber: FC<InputNumberProps> = (props) => {
}, [isValidValue, onChange])
return (
<div data-testid="input-number-wrapper" className={cn('flex', wrapClassName)}>
<div className={cn('flex', wrapClassName)}>
<Input
{...rest}
// disable default controller
type="number"
className={cn('rounded-r-none no-spinner', className)}
className={cn('no-spinner rounded-r-none', className)}
value={value ?? 0}
max={max}
min={min}
@@ -101,10 +100,7 @@ export const InputNumber: FC<InputNumberProps> = (props) => {
unit={unit}
size={size}
/>
<div
data-testid="input-number-controls"
className={cn('flex flex-col rounded-r-md border-l border-divider-subtle bg-components-input-bg-normal text-text-tertiary focus:shadow-xs', disabled && 'cursor-not-allowed opacity-50', controlWrapClassName)}
>
<div className={cn('flex flex-col rounded-r-md border-l border-divider-subtle bg-components-input-bg-normal text-text-tertiary focus:shadow-xs', disabled && 'cursor-not-allowed opacity-50', controlWrapClassName)}>
<button
type="button"
onClick={inc}
@@ -112,7 +108,7 @@ export const InputNumber: FC<InputNumberProps> = (props) => {
aria-label="increment"
className={cn(size === 'regular' ? 'pt-1' : 'pt-1.5', 'px-1.5 hover:bg-components-input-bg-hover', disabled && 'cursor-not-allowed hover:bg-transparent', controlClassName)}
>
<span className="i-ri-arrow-up-s-line size-3" />
<RiArrowUpSLine className="size-3" />
</button>
<button
type="button"
@@ -121,7 +117,7 @@ export const InputNumber: FC<InputNumberProps> = (props) => {
aria-label="decrement"
className={cn(size === 'regular' ? 'pb-1' : 'pb-1.5', 'px-1.5 hover:bg-components-input-bg-hover', disabled && 'cursor-not-allowed hover:bg-transparent', controlClassName)}
>
<span className="i-ri-arrow-down-s-line size-3" />
<RiArrowDownSLine className="size-3" />
</button>
</div>
</div>

View File

@@ -35,7 +35,7 @@ describe('Input component', () => {
it('renders correctly with default props', () => {
render(<Input />)
const input = screen.getByPlaceholderText(/input/i)
const input = screen.getByPlaceholderText('Please input')
expect(input).toBeInTheDocument()
expect(input).not.toBeDisabled()
expect(input).not.toHaveClass('cursor-not-allowed')
@@ -45,7 +45,7 @@ describe('Input component', () => {
render(<Input showLeftIcon />)
const searchIcon = document.querySelector('.i-ri-search-line')
expect(searchIcon).toBeInTheDocument()
const input = screen.getByPlaceholderText(/search/i)
const input = screen.getByPlaceholderText('Search')
expect(input).toHaveClass('pl-[26px]')
})
@@ -75,13 +75,13 @@ describe('Input component', () => {
render(<Input destructive />)
const warningIcon = document.querySelector('.i-ri-error-warning-line')
expect(warningIcon).toBeInTheDocument()
const input = screen.getByPlaceholderText(/input/i)
const input = screen.getByPlaceholderText('Please input')
expect(input).toHaveClass('border-components-input-border-destructive')
})
it('applies disabled styles when disabled', () => {
render(<Input disabled />)
const input = screen.getByPlaceholderText(/input/i)
const input = screen.getByPlaceholderText('Please input')
expect(input).toBeDisabled()
expect(input).toHaveClass('cursor-not-allowed')
expect(input).toHaveClass('bg-components-input-bg-disabled')
@@ -97,7 +97,7 @@ describe('Input component', () => {
const customClass = 'test-class'
const customStyle = { color: 'red' }
render(<Input className={customClass} styleCss={customStyle} />)
const input = screen.getByPlaceholderText(/input/i)
const input = screen.getByPlaceholderText('Please input')
expect(input).toHaveClass(customClass)
expect(input).toHaveStyle({ color: 'rgb(255, 0, 0)' })
})
@@ -114,61 +114,4 @@ describe('Input component', () => {
const input = screen.getByPlaceholderText(placeholder)
expect(input).toBeInTheDocument()
})
describe('Number Input Formatting', () => {
it('removes leading zeros on change when current value is zero', () => {
let changedValue = ''
const onChange = vi.fn((e: React.ChangeEvent<HTMLInputElement>) => {
changedValue = e.target.value
})
render(<Input type="number" value={0} onChange={onChange} />)
const input = screen.getByRole('spinbutton') as HTMLInputElement
fireEvent.change(input, { target: { value: '00042' } })
expect(onChange).toHaveBeenCalledTimes(1)
expect(changedValue).toBe('42')
})
it('keeps typed value on change when current value is not zero', () => {
let changedValue = ''
const onChange = vi.fn((e: React.ChangeEvent<HTMLInputElement>) => {
changedValue = e.target.value
})
render(<Input type="number" value={1} onChange={onChange} />)
const input = screen.getByRole('spinbutton') as HTMLInputElement
fireEvent.change(input, { target: { value: '00042' } })
expect(onChange).toHaveBeenCalledTimes(1)
expect(changedValue).toBe('00042')
})
it('normalizes value and triggers change on blur when leading zeros exist', () => {
const onChange = vi.fn()
const onBlur = vi.fn()
render(<Input type="number" defaultValue="0012" onChange={onChange} onBlur={onBlur} />)
const input = screen.getByRole('spinbutton')
fireEvent.blur(input)
expect(onChange).toHaveBeenCalledTimes(1)
expect(onChange.mock.calls[0][0].type).toBe('change')
expect(onChange.mock.calls[0][0].target.value).toBe('12')
expect(onBlur).toHaveBeenCalledTimes(1)
expect(onBlur.mock.calls[0][0].target.value).toBe('12')
})
it('does not trigger change on blur when value is already normalized', () => {
const onChange = vi.fn()
const onBlur = vi.fn()
render(<Input type="number" defaultValue="12" onChange={onChange} onBlur={onBlur} />)
const input = screen.getByRole('spinbutton')
fireEvent.blur(input)
expect(onChange).not.toHaveBeenCalled()
expect(onBlur).toHaveBeenCalledTimes(1)
expect(onBlur.mock.calls[0][0].target.value).toBe('12')
})
})
})

View File

@@ -1,6 +1,7 @@
import { createRequire } from 'node:module'
import { act, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { Theme } from '@/types/app'
import CodeBlock from '../code-block'
@@ -153,12 +154,12 @@ describe('CodeBlock', () => {
expect(screen.getByText('Ruby')).toBeInTheDocument()
})
// it('should render mermaid controls when language is mermaid', async () => {
// render(<CodeBlock className="language-mermaid">graph TB; A--&gt;B;</CodeBlock>)
it('should render mermaid controls when language is mermaid', async () => {
render(<CodeBlock className="language-mermaid">graph TB; A--&gt;B;</CodeBlock>)
// expect(await screen.findByTestId('classic')).toBeInTheDocument()
// expect(screen.getByText('Mermaid')).toBeInTheDocument()
// })
expect(await screen.findByText('app.mermaid.classic')).toBeInTheDocument()
expect(screen.getByText('Mermaid')).toBeInTheDocument()
})
it('should render abc section header when language is abc', () => {
render(<CodeBlock className="language-abc">X:1\nT:test</CodeBlock>)

View File

@@ -200,7 +200,7 @@ describe('MarkdownForm', () => {
})
it('should handle invalid data-options string without crashing', () => {
const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => { })
const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
const node = createRootNode([
createElementNode('input', {
'type': 'select',
@@ -317,174 +317,4 @@ describe('MarkdownForm', () => {
expect(mockOnSend).not.toHaveBeenCalled()
})
})
// DatePicker onChange and onClear callbacks should update form state.
describe('DatePicker interaction', () => {
it('should update form value when date is picked via onChange', async () => {
const user = userEvent.setup()
const node = createRootNode(
[
createElementNode('input', { type: 'date', name: 'startDate', value: '' }),
createElementNode('button', {}, [createTextNode('Submit')]),
],
{ dataFormat: 'json' },
)
render(<MarkdownForm node={node} />)
// Click the DatePicker trigger to open the popup
const trigger = screen.getByTestId('date-picker-trigger')
await user.click(trigger)
// Click the "Now" button in the footer to select current date (calls onChange)
const nowButton = await screen.findByText('time.operation.now')
await user.click(nowButton)
// Submit the form
await user.click(screen.getByRole('button', { name: 'Submit' }))
await waitFor(() => {
// onChange was called with a Dayjs object that has .format, so formatDateForOutput is called
expect(mockFormatDateForOutput).toHaveBeenCalledWith(expect.anything(), false)
expect(mockOnSend).toHaveBeenCalled()
})
})
it('should clear form value when date is cleared via onClear', async () => {
const user = userEvent.setup()
const node = createRootNode(
[
createElementNode('input', { type: 'date', name: 'startDate', value: dayjs('2026-01-10') }),
createElementNode('button', {}, [createTextNode('Submit')]),
],
{ dataFormat: 'json' },
)
render(<MarkdownForm node={node} />)
const clearIcon = screen.getByTestId('date-picker-clear-button')
await user.click(clearIcon)
await user.click(screen.getByRole('button', { name: 'Submit' }))
await waitFor(() => {
// onClear sets value to undefined, which JSON.stringify omits
expect(mockOnSend).toHaveBeenCalledWith('{}')
})
})
})
// TimePicker rendering, onChange, and onClear should work correctly.
describe('TimePicker interaction', () => {
it('should render TimePicker for time input type', () => {
const node = createRootNode([
createElementNode('input', { type: 'time', name: 'meetingTime', value: '09:00' }),
])
render(<MarkdownForm node={node} />)
// The real TimePicker renders a trigger with a readonly input showing the formatted time
const timeInput = screen.getByTestId('time-picker-trigger').querySelector('input[readonly]') as HTMLInputElement
expect(timeInput).not.toBeNull()
expect(timeInput.value).toBe('09:00 AM')
})
it('should update form value when time is picked via onChange', async () => {
const user = userEvent.setup()
const node = createRootNode(
[
createElementNode('input', { type: 'time', name: 'meetingTime', value: '' }),
createElementNode('button', {}, [createTextNode('Submit')]),
],
)
render(<MarkdownForm node={node} />)
// Click the TimePicker trigger to open the popup
const trigger = screen.getByTestId('time-picker-trigger')
await user.click(trigger)
// Click the "Now" button in the footer to select current time (calls onChange)
const nowButtons = await screen.findAllByText('time.operation.now')
await user.click(nowButtons[0])
// Submit the form
await user.click(screen.getByRole('button', { name: 'Submit' }))
await waitFor(() => {
expect(mockOnSend).toHaveBeenCalled()
})
})
it('should clear form value when time is cleared via onClear', async () => {
const user = userEvent.setup()
const node = createRootNode(
[
createElementNode('input', { type: 'time', name: 'meetingTime', value: '09:00' }),
createElementNode('button', {}, [createTextNode('Submit')]),
],
{ dataFormat: 'json' },
)
render(<MarkdownForm node={node} />)
// The TimePicker's clear icon has role="button" and an aria-label
const clearButton = screen.getByRole('button', { name: 'common.operation.clear' })
await user.click(clearButton)
await user.click(screen.getByRole('button', { name: 'Submit' }))
await waitFor(() => {
// onClear sets value to undefined, which JSON.stringify omits
expect(mockOnSend).toHaveBeenCalledWith('{}')
})
})
})
// Fallback branches for edge cases in tag rendering.
describe('Fallback branches', () => {
it('should render label with empty text when children array is empty', () => {
const node = createRootNode([
createElementNode('label', { for: 'field' }, []),
])
render(<MarkdownForm node={node} />)
const label = screen.getByTestId('label-field')
expect(label).not.toBeNull()
expect(label?.textContent).toBe('')
})
it('should render checkbox without tip text when dataTip is missing', () => {
const node = createRootNode([
createElementNode('input', { type: 'checkbox', name: 'agree', value: false }),
])
render(<MarkdownForm node={node} />)
expect(screen.getByTestId('checkbox-agree')).toBeInTheDocument()
})
it('should render select with no options when dataOptions is missing', () => {
const node = createRootNode([
createElementNode('input', { type: 'select', name: 'color', value: '' }),
])
render(<MarkdownForm node={node} />)
// Select renders with empty items list
expect(screen.getByTestId('markdown-form')).toBeInTheDocument()
})
it('should render button with empty text when children array is empty', () => {
const node = createRootNode([
createElementNode('button', {}, []),
])
render(<MarkdownForm node={node} />)
const button = screen.getByRole('button')
expect(button.textContent).toBe('')
})
})
})

View File

@@ -1,86 +0,0 @@
import { render, screen } from '@testing-library/react'
import { Img } from '..'
describe('Img', () => {
describe('Rendering', () => {
it('should render with the correct wrapper class', () => {
const { container } = render(<Img src="https://example.com/image.png" />)
const wrapper = container.querySelector('.markdown-img-wrapper')
expect(wrapper).toBeInTheDocument()
})
it('should render ImageGallery with the src as an array', () => {
render(<Img src="https://example.com/image.png" />)
const gallery = screen.getByTestId('image-gallery')
expect(gallery).toBeInTheDocument()
const images = gallery.querySelectorAll('img')
expect(images).toHaveLength(1)
expect(images[0]).toHaveAttribute('src', 'https://example.com/image.png')
})
it('should pass src as single element array to ImageGallery', () => {
const testSrc = 'https://example.com/test-image.jpg'
render(<Img src={testSrc} />)
const gallery = screen.getByTestId('image-gallery')
const images = gallery.querySelectorAll('img')
expect(images[0]).toHaveAttribute('src', testSrc)
})
it('should render with different src values', () => {
const { rerender } = render(<Img src="https://example.com/first.png" />)
expect(screen.getByTestId('gallery-image')).toHaveAttribute('src', 'https://example.com/first.png')
rerender(<Img src="https://example.com/second.jpg" />)
expect(screen.getByTestId('gallery-image')).toHaveAttribute('src', 'https://example.com/second.jpg')
})
})
describe('Props', () => {
it('should accept src prop with various URL formats', () => {
// Test with HTTPS URL
const { container: container1 } = render(<Img src="https://example.com/image.png" />)
expect(container1.querySelector('.markdown-img-wrapper')).toBeInTheDocument()
// Test with HTTP URL
const { container: container2 } = render(<Img src="http://example.com/image.png" />)
expect(container2.querySelector('.markdown-img-wrapper')).toBeInTheDocument()
// Test with data URL
const { container: container3 } = render(<Img src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" />)
expect(container3.querySelector('.markdown-img-wrapper')).toBeInTheDocument()
// Test with relative URL
const { container: container4 } = render(<Img src="/images/photo.jpg" />)
expect(container4.querySelector('.markdown-img-wrapper')).toBeInTheDocument()
})
it('should handle empty string src', () => {
const { container } = render(<Img src="" />)
const wrapper = container.querySelector('.markdown-img-wrapper')
expect(wrapper).toBeInTheDocument()
})
})
describe('Structure', () => {
it('should have exactly one wrapper div', () => {
const { container } = render(<Img src="https://example.com/image.png" />)
const wrappers = container.querySelectorAll('.markdown-img-wrapper')
expect(wrappers).toHaveLength(1)
})
it('should contain ImageGallery component inside wrapper', () => {
const { container } = render(<Img src="https://example.com/image.png" />)
const wrapper = container.querySelector('.markdown-img-wrapper')
const gallery = wrapper?.querySelector('[data-testid="image-gallery"]')
expect(gallery).toBeInTheDocument()
})
})
})

View File

@@ -1,121 +0,0 @@
import { getMarkdownImageURL, isValidUrl } from '../utils'
vi.mock('@/config', () => ({
ALLOW_UNSAFE_DATA_SCHEME: false,
MARKETPLACE_API_PREFIX: '/api/marketplace',
}))
describe('utils', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('isValidUrl', () => {
it('should return true for http: URLs', () => {
expect(isValidUrl('http://example.com')).toBe(true)
})
it('should return true for https: URLs', () => {
expect(isValidUrl('https://example.com')).toBe(true)
})
it('should return true for protocol-relative URLs', () => {
expect(isValidUrl('//cdn.example.com/image.png')).toBe(true)
})
it('should return true for mailto: URLs', () => {
expect(isValidUrl('mailto:user@example.com')).toBe(true)
})
it('should return false for data: URLs when ALLOW_UNSAFE_DATA_SCHEME is false', () => {
expect(isValidUrl('data:image/png;base64,abc123')).toBe(false)
})
it('should return false for javascript: URLs', () => {
expect(isValidUrl('javascript:alert(1)')).toBe(false)
})
it('should return false for ftp: URLs', () => {
expect(isValidUrl('ftp://files.example.com')).toBe(false)
})
it('should return false for relative paths', () => {
expect(isValidUrl('/images/photo.png')).toBe(false)
})
it('should return false for empty string', () => {
expect(isValidUrl('')).toBe(false)
})
it('should return false for plain text', () => {
expect(isValidUrl('not a url')).toBe(false)
})
})
describe('isValidUrl with ALLOW_UNSAFE_DATA_SCHEME enabled', () => {
beforeEach(() => {
vi.resetModules()
vi.doMock('@/config', () => ({
ALLOW_UNSAFE_DATA_SCHEME: true,
MARKETPLACE_API_PREFIX: '/api/marketplace',
}))
})
it('should return true for data: URLs when ALLOW_UNSAFE_DATA_SCHEME is true', async () => {
const { isValidUrl: isValidUrlWithData } = await import('../utils')
expect(isValidUrlWithData('data:image/png;base64,abc123')).toBe(true)
})
})
describe('getMarkdownImageURL', () => {
it('should return the original URL when it does not match the asset regex', () => {
expect(getMarkdownImageURL('https://example.com/image.png')).toBe('https://example.com/image.png')
})
it('should transform ./_assets URL without pathname', () => {
const result = getMarkdownImageURL('./_assets/icon.png')
expect(result).toBe('/api/marketplace/plugins//_assets/icon.png')
})
it('should transform ./_assets URL with pathname', () => {
const result = getMarkdownImageURL('./_assets/icon.png', 'my-plugin/')
expect(result).toBe('/api/marketplace/plugins/my-plugin//_assets/icon.png')
})
it('should transform _assets URL without leading dot-slash', () => {
const result = getMarkdownImageURL('_assets/logo.svg')
expect(result).toBe('/api/marketplace/plugins//_assets/logo.svg')
})
it('should transform _assets URL with pathname', () => {
const result = getMarkdownImageURL('_assets/logo.svg', 'org/plugin/')
expect(result).toBe('/api/marketplace/plugins/org/plugin//_assets/logo.svg')
})
it('should not transform URLs that contain _assets in the middle', () => {
expect(getMarkdownImageURL('https://cdn.example.com/_assets/image.png'))
.toBe('https://cdn.example.com/_assets/image.png')
})
it('should use empty string for pathname when undefined', () => {
const result = getMarkdownImageURL('./_assets/test.png')
expect(result).toBe('/api/marketplace/plugins//_assets/test.png')
})
})
describe('getMarkdownImageURL with trailing slash prefix', () => {
beforeEach(() => {
vi.resetModules()
vi.doMock('@/config', () => ({
ALLOW_UNSAFE_DATA_SCHEME: false,
MARKETPLACE_API_PREFIX: '/api/marketplace/',
}))
})
it('should not add extra slash when prefix ends with slash', async () => {
const { getMarkdownImageURL: getURL } = await import('../utils')
const result = getURL('./_assets/icon.png', 'my-plugin/')
expect(result).toBe('/api/marketplace/plugins/my-plugin//_assets/icon.png')
})
})
})

View File

@@ -90,7 +90,6 @@ const MarkdownForm = ({ node }: any) => {
<form
autoComplete="off"
className="flex flex-col self-stretch"
data-testid="markdown-form"
onSubmit={(e: any) => {
e.preventDefault()
e.stopPropagation()
@@ -103,7 +102,6 @@ const MarkdownForm = ({ node }: any) => {
key={index}
htmlFor={child.properties.htmlFor || child.properties.name}
className="my-2 text-text-secondary system-md-semibold"
data-testid="label-field"
>
{child.children[0]?.value || ''}
</label>

View File

@@ -1,3 +1,6 @@
// app/components/base/markdown/preprocess.spec.ts
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
/**
* Helper to (re)load the module with a mocked config value.
* We need to reset modules because the tested module imports

View File

@@ -8,9 +8,9 @@ vi.mock('@/app/components/base/markdown-blocks', () => ({
Link: ({ children, href }: { children?: ReactNode, href?: string }) => <a href={href}>{children}</a>,
MarkdownButton: ({ children }: PropsWithChildren) => <button>{children}</button>,
MarkdownForm: ({ children }: PropsWithChildren) => <form>{children}</form>,
Paragraph: ({ children }: PropsWithChildren) => <p data-testid="paragraph">{children}</p>,
Paragraph: ({ children }: PropsWithChildren) => <p>{children}</p>,
PluginImg: ({ alt }: { alt?: string }) => <span data-testid="plugin-img">{alt}</span>,
PluginParagraph: ({ children }: PropsWithChildren) => <p data-testid="plugin-paragraph">{children}</p>,
PluginParagraph: ({ children }: PropsWithChildren) => <p>{children}</p>,
ScriptBlock: () => null,
ThinkBlock: ({ children }: PropsWithChildren) => <details>{children}</details>,
VideoBlock: ({ children }: PropsWithChildren) => <div data-testid="video-block">{children}</div>,
@@ -105,85 +105,5 @@ describe('ReactMarkdownWrapper', () => {
expect(screen.getByText('italic text')).toBeInTheDocument()
expect(document.querySelector('em')).not.toBeNull()
})
it('should render standard Image component when pluginInfo is not provided', () => {
// Act
render(<ReactMarkdownWrapper latexContent="![standard-img](https://example.com/img.png)" />)
// Assert
expect(screen.getByTestId('img')).toBeInTheDocument()
})
it('should render a CodeBlock component for code markdown', async () => {
// Arrange
const content = '```javascript\nconsole.log("hello")\n```'
// Act
render(<ReactMarkdownWrapper latexContent={content} />)
// Assert
// We mocked code block to return <code>{children}</code>
const codeElement = await screen.findByText('console.log("hello")')
expect(codeElement).toBeInTheDocument()
})
})
describe('Plugin Info behavior', () => {
it('should render PluginImg and PluginParagraph when pluginInfo is provided', () => {
// Arrange
const content = 'This is a plugin paragraph\n\n![plugin-img](https://example.com/plugin.png)'
const pluginInfo = { pluginUniqueIdentifier: 'test-plugin', pluginId: 'plugin-1' }
// Act
render(<ReactMarkdownWrapper latexContent={content} pluginInfo={pluginInfo} />)
// Assert
expect(screen.getByTestId('plugin-img')).toBeInTheDocument()
expect(screen.queryByTestId('img')).toBeNull()
expect(screen.getAllByTestId('plugin-paragraph').length).toBeGreaterThan(0)
expect(screen.queryByTestId('paragraph')).toBeNull()
})
})
describe('Custom elements configuration', () => {
it('should use customComponents if provided', () => {
// Arrange
const customComponents = {
a: ({ children }: PropsWithChildren) => <a data-testid="custom-link">{children}</a>,
}
// Act
render(<ReactMarkdownWrapper latexContent="[link](https://example.com)" customComponents={customComponents} />)
// Assert
expect(screen.getByTestId('custom-link')).toBeInTheDocument()
})
it('should disallow customDisallowedElements', () => {
// Act - disallow strong (which is usually **bold**)
render(<ReactMarkdownWrapper latexContent="**bold**" customDisallowedElements={['strong']} />)
// Assert - strong element shouldn't be rendered (it will be stripped out)
expect(document.querySelector('strong')).toBeNull()
})
})
describe('Rehype AST modification', () => {
it('should remove ref attributes from elements', () => {
// Act
render(<ReactMarkdownWrapper latexContent={'<div ref="someRef">content</div>'} />)
// Assert - If ref isn't stripped, it gets passed to React DOM causing warnings, but here we just ensure content renders
expect(screen.getByText('content')).toBeInTheDocument()
})
it('should convert invalid tag names to text nodes', () => {
// Act - <custom-element> is invalid because it contains a hyphen
render(<ReactMarkdownWrapper latexContent="<custom-element>content</custom-element>" />)
// Assert - The AST node is changed to text with value `<custom-element`
expect(screen.getByText(/<custom-element/)).toBeInTheDocument()
})
})
})

View File

@@ -27,11 +27,6 @@ describe('Mermaid Flowchart Component', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.mocked(mermaid.initialize).mockImplementation(() => { })
vi.mocked(mermaid.render).mockResolvedValue({ svg: '<svg id="mermaid-chart">test-svg</svg>', diagramType: 'flowchart' })
})
afterEach(() => {
vi.useRealTimers()
})
describe('Rendering', () => {
@@ -137,86 +132,6 @@ describe('Mermaid Flowchart Component', () => {
}, { timeout: 3000 })
})
it('should keep selected look unchanged when clicking an already-selected look button', async () => {
await act(async () => {
render(<Flowchart PrimitiveCode={mockCode} />)
})
await waitFor(() => screen.getByText('test-svg'), { timeout: 3000 })
const initialRenderCalls = vi.mocked(mermaid.render).mock.calls.length
const initialApiRenderCalls = vi.mocked(mermaid.mermaidAPI.render).mock.calls.length
await act(async () => {
fireEvent.click(screen.getByText(/classic/i))
})
expect(vi.mocked(mermaid.render).mock.calls.length).toBe(initialRenderCalls)
expect(vi.mocked(mermaid.mermaidAPI.render).mock.calls.length).toBe(initialApiRenderCalls)
await act(async () => {
fireEvent.click(screen.getByText(/handDrawn/i))
})
await waitFor(() => {
expect(screen.getByText('test-svg-api')).toBeInTheDocument()
}, { timeout: 3000 })
const afterFirstHandDrawnApiCalls = vi.mocked(mermaid.mermaidAPI.render).mock.calls.length
await act(async () => {
fireEvent.click(screen.getByText(/handDrawn/i))
})
expect(vi.mocked(mermaid.mermaidAPI.render).mock.calls.length).toBe(afterFirstHandDrawnApiCalls)
})
it('should toggle theme from light to dark and back to light', async () => {
await act(async () => {
render(<Flowchart PrimitiveCode={mockCode} theme="light" />)
})
await waitFor(() => {
expect(screen.getByText('test-svg')).toBeInTheDocument()
}, { timeout: 3000 })
const toggleBtn = screen.getByRole('button')
await act(async () => {
fireEvent.click(toggleBtn)
})
await waitFor(() => {
expect(screen.getByRole('button')).toHaveAttribute('title', expect.stringMatching(/switchLight$/))
}, { timeout: 3000 })
await act(async () => {
fireEvent.click(screen.getByRole('button'))
})
await waitFor(() => {
expect(screen.getByRole('button')).toHaveAttribute('title', expect.stringMatching(/switchDark$/))
}, { timeout: 3000 })
})
it('should configure handDrawn mode for dark non-flowchart diagrams', async () => {
const sequenceCode = 'sequenceDiagram\n A->>B: Hi'
await act(async () => {
render(<Flowchart PrimitiveCode={sequenceCode} theme="dark" />)
})
await waitFor(() => {
expect(screen.getByText('test-svg')).toBeInTheDocument()
}, { timeout: 3000 })
await act(async () => {
fireEvent.click(screen.getByText(/handDrawn/i))
})
await waitFor(() => {
expect(screen.getByText('test-svg-api')).toBeInTheDocument()
}, { timeout: 3000 })
expect(mermaid.initialize).toHaveBeenCalledWith(expect.objectContaining({
theme: 'default',
themeVariables: expect.objectContaining({
primaryBorderColor: '#60a5fa',
}),
}))
})
it('should open image preview when clicking the chart', async () => {
await act(async () => {
render(<Flowchart PrimitiveCode={mockCode} />)
@@ -229,7 +144,7 @@ describe('Mermaid Flowchart Component', () => {
fireEvent.click(chartDiv!)
})
await waitFor(() => {
expect(screen.getByTestId('image-preview-container')).toBeInTheDocument()
expect(document.body.querySelector('.image-preview-container')).toBeInTheDocument()
}, { timeout: 3000 })
})
})
@@ -249,79 +164,35 @@ describe('Mermaid Flowchart Component', () => {
const errorMsg = 'Syntax error'
vi.mocked(mermaid.render).mockRejectedValue(new Error(errorMsg))
try {
const uniqueCode = 'graph TD\n X-->Y\n Y-->Z'
render(<Flowchart PrimitiveCode={uniqueCode} />)
// Use unique code to avoid hitting the module-level diagramCache from previous tests
const uniqueCode = 'graph TD\n X-->Y\n Y-->Z'
const { container } = render(<Flowchart PrimitiveCode={uniqueCode} />)
const errorMessage = await screen.findByText(/Rendering failed/i)
expect(errorMessage).toBeInTheDocument()
}
finally {
consoleSpy.mockRestore()
}
})
it('should show unknown-error fallback when render fails without an error message', async () => {
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => { })
vi.mocked(mermaid.render).mockRejectedValue({} as Error)
try {
render(<Flowchart PrimitiveCode={'graph TD\n P-->Q\n Q-->R'} />)
expect(await screen.findByText(/Unknown error\. Please check the console\./i)).toBeInTheDocument()
}
finally {
consoleSpy.mockRestore()
}
})
await waitFor(() => {
const errorSpan = container.querySelector('.text-red-500 span.ml-2')
expect(errorSpan).toBeInTheDocument()
expect(errorSpan?.textContent).toContain('Rendering failed')
}, { timeout: 5000 })
consoleSpy.mockRestore()
// Restore default mock to prevent leaking into subsequent tests
vi.mocked(mermaid.render).mockResolvedValue({ svg: '<svg id="mermaid-chart">test-svg</svg>', diagramType: 'flowchart' })
}, 10000)
it('should use cached diagram if available', async () => {
const { rerender } = render(<Flowchart PrimitiveCode={mockCode} />)
// Wait for initial render to complete
await waitFor(() => {
expect(vi.mocked(mermaid.render)).toHaveBeenCalled()
}, { timeout: 3000 })
const initialCallCount = vi.mocked(mermaid.render).mock.calls.length
await waitFor(() => screen.getByText('test-svg'), { timeout: 3000 })
vi.mocked(mermaid.render).mockClear()
// Rerender with same code
await act(async () => {
rerender(<Flowchart PrimitiveCode={mockCode} />)
})
await waitFor(() => {
expect(vi.mocked(mermaid.render).mock.calls.length).toBe(initialCallCount)
}, { timeout: 3000 })
// Call count should not increase (cache was used)
expect(vi.mocked(mermaid.render).mock.calls.length).toBe(initialCallCount)
})
it('should keep previous svg visible while next render is loading', async () => {
let resolveSecondRender: ((value: { svg: string, diagramType: string }) => void) | null = null
const secondRenderPromise = new Promise<{ svg: string, diagramType: string }>((resolve) => {
resolveSecondRender = resolve
})
vi.mocked(mermaid.render)
.mockResolvedValueOnce({ svg: '<svg id="mermaid-chart">initial-svg</svg>', diagramType: 'flowchart' })
.mockImplementationOnce(() => secondRenderPromise)
const { rerender } = render(<Flowchart PrimitiveCode="graph TD\n A-->B" />)
await waitFor(() => {
expect(screen.getByText('initial-svg')).toBeInTheDocument()
}, { timeout: 3000 })
await act(async () => {
rerender(<Flowchart PrimitiveCode="graph TD\n C-->D" />)
await new Promise(resolve => setTimeout(resolve, 500))
})
expect(screen.getByText('initial-svg')).toBeInTheDocument()
resolveSecondRender!({ svg: '<svg id="mermaid-chart">second-svg</svg>', diagramType: 'flowchart' })
await waitFor(() => {
expect(screen.getByText('second-svg')).toBeInTheDocument()
}, { timeout: 3000 })
expect(mermaid.render).not.toHaveBeenCalled()
})
it('should handle invalid mermaid code completion', async () => {
@@ -335,116 +206,6 @@ describe('Mermaid Flowchart Component', () => {
}, { timeout: 3000 })
})
it('should keep single "after" gantt dependency formatting unchanged', async () => {
const singleAfterGantt = [
'gantt',
'title One after dependency',
'Single task :after task1, 2024-01-01, 1d',
].join('\n')
await act(async () => {
render(<Flowchart PrimitiveCode={singleAfterGantt} />)
})
await waitFor(() => {
expect(mermaid.render).toHaveBeenCalled()
}, { timeout: 3000 })
const lastRenderArgs = vi.mocked(mermaid.render).mock.calls.at(-1)
expect(lastRenderArgs?.[1]).toContain('Single task :after task1, 2024-01-01, 1d')
})
it('should use cache without rendering again when PrimitiveCode changes back to previous', async () => {
const firstCode = 'graph TD\n CacheOne-->CacheTwo'
const secondCode = 'graph TD\n CacheThree-->CacheFour'
const { rerender } = render(<Flowchart PrimitiveCode={firstCode} />)
// Wait for initial render
await waitFor(() => {
expect(vi.mocked(mermaid.render)).toHaveBeenCalled()
}, { timeout: 3000 })
const firstRenderCallCount = vi.mocked(mermaid.render).mock.calls.length
// Change to different code
await act(async () => {
rerender(<Flowchart PrimitiveCode={secondCode} />)
})
// Wait for second render
await waitFor(() => {
expect(vi.mocked(mermaid.render).mock.calls.length).toBeGreaterThan(firstRenderCallCount)
}, { timeout: 3000 })
const afterSecondRenderCallCount = vi.mocked(mermaid.render).mock.calls.length
// Change back to first code - should use cache
await act(async () => {
rerender(<Flowchart PrimitiveCode={firstCode} />)
})
await waitFor(() => {
expect(vi.mocked(mermaid.render).mock.calls.length).toBe(afterSecondRenderCallCount)
}, { timeout: 3000 })
// Call count should not increase (cache was used)
expect(vi.mocked(mermaid.render).mock.calls.length).toBe(afterSecondRenderCallCount)
})
it('should close image preview when cancel is clicked', async () => {
await act(async () => {
render(<Flowchart PrimitiveCode={mockCode} />)
})
// Wait for SVG to be rendered
await waitFor(() => {
const svgElement = screen.queryByText('test-svg')
expect(svgElement).toBeInTheDocument()
}, { timeout: 3000 })
const mermaidDiv = screen.getByText('test-svg').closest('.mermaid')
await act(async () => {
fireEvent.click(mermaidDiv!)
})
// Wait for image preview to appear
const cancelBtn = await screen.findByTestId('image-preview-close-button')
expect(cancelBtn).toBeInTheDocument()
await act(async () => {
fireEvent.click(cancelBtn)
})
await waitFor(() => {
expect(screen.queryByTestId('image-preview-container')).not.toBeInTheDocument()
expect(screen.queryByTestId('image-preview-close-button')).not.toBeInTheDocument()
})
})
it('should handle configuration failure during configureMermaid', async () => {
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => { })
const originalMock = vi.mocked(mermaid.initialize).getMockImplementation()
vi.mocked(mermaid.initialize).mockImplementation(() => {
throw new Error('Config fail')
})
try {
await act(async () => {
render(<Flowchart PrimitiveCode="graph TD\n G-->H" />)
})
await waitFor(() => {
expect(consoleSpy).toHaveBeenCalledWith('Config error:', expect.any(Error))
})
}
finally {
consoleSpy.mockRestore()
if (originalMock) {
vi.mocked(mermaid.initialize).mockImplementation(originalMock)
}
else {
vi.mocked(mermaid.initialize).mockImplementation(() => { })
}
}
})
it('should handle unmount cleanup', async () => {
const { unmount } = render(<Flowchart PrimitiveCode={mockCode} />)
await act(async () => {
@@ -458,20 +219,6 @@ describe('Mermaid Flowchart Component Module Isolation', () => {
const mockCode = 'graph TD\n A-->B'
let mermaidFresh: typeof mermaid
const setWindowUndefined = () => {
const descriptor = Object.getOwnPropertyDescriptor(globalThis, 'window')
Object.defineProperty(globalThis, 'window', {
configurable: true,
writable: true,
value: undefined,
})
return descriptor
}
const restoreWindowDescriptor = (descriptor?: PropertyDescriptor) => {
if (descriptor)
Object.defineProperty(globalThis, 'window', descriptor)
}
beforeEach(async () => {
vi.resetModules()
@@ -548,212 +295,5 @@ describe('Mermaid Flowchart Component Module Isolation', () => {
})
consoleSpy.mockRestore()
})
it('should load module safely when window is undefined', async () => {
const descriptor = setWindowUndefined()
try {
vi.resetModules()
const { default: FlowchartFresh } = await import('../index')
expect(FlowchartFresh).toBeDefined()
}
finally {
restoreWindowDescriptor(descriptor)
}
})
it('should skip configuration when window is unavailable before debounce execution', async () => {
const { default: FlowchartFresh } = await import('../index')
const descriptor = Object.getOwnPropertyDescriptor(globalThis, 'window')
vi.useFakeTimers()
try {
await act(async () => {
render(<FlowchartFresh PrimitiveCode={mockCode} />)
})
await Promise.resolve()
Object.defineProperty(globalThis, 'window', {
configurable: true,
writable: true,
value: undefined,
})
await vi.advanceTimersByTimeAsync(350)
expect(mermaidFresh.render).not.toHaveBeenCalled()
}
finally {
if (descriptor)
Object.defineProperty(globalThis, 'window', descriptor)
vi.useRealTimers()
}
})
it.skip('should show container-not-found error when container ref remains null', async () => {
vi.resetModules()
vi.doMock('react', async () => {
const reactActual = await vi.importActual<typeof import('react')>('react')
let pendingContainerRef: ReturnType<typeof reactActual.useRef> | null = null
let patchedContainerRef = false
const mockedUseRef = ((initialValue: unknown) => {
const ref = reactActual.useRef(initialValue as never)
if (!patchedContainerRef && initialValue === null)
pendingContainerRef = ref
if (!patchedContainerRef
&& pendingContainerRef
&& typeof initialValue === 'string'
&& initialValue.startsWith('mermaid-chart-')) {
Object.defineProperty(pendingContainerRef, 'current', {
configurable: true,
get() {
return null
},
set(_value: HTMLDivElement | null) { },
})
patchedContainerRef = true
pendingContainerRef = null
}
return ref
}) as typeof reactActual.useRef
return {
...reactActual,
useRef: mockedUseRef,
}
})
try {
const { default: FlowchartFresh } = await import('../index')
render(<FlowchartFresh PrimitiveCode={mockCode} />)
expect(await screen.findByText('Container element not found')).toBeInTheDocument()
}
finally {
vi.doUnmock('react')
}
})
it('should tolerate missing hidden container during classic render and cleanup', async () => {
vi.resetModules()
let pendingContainerRef: unknown | null = null
let patchedContainerRef = false
let patchedTimeoutRef = false
let containerReadCount = 0
const virtualContainer = { innerHTML: 'seed' } as HTMLDivElement
vi.doMock('react', async () => {
const reactActual = await vi.importActual<typeof import('react')>('react')
const mockedUseRef = ((initialValue: unknown) => {
const ref = reactActual.useRef(initialValue as never)
if (!patchedContainerRef && initialValue === null)
pendingContainerRef = ref
if (!patchedContainerRef
&& pendingContainerRef
&& typeof initialValue === 'string'
&& initialValue.startsWith('mermaid-chart-')) {
Object.defineProperty(pendingContainerRef as { current: unknown }, 'current', {
configurable: true,
get() {
containerReadCount += 1
if (containerReadCount === 1)
return virtualContainer
return null
},
set(_value: HTMLDivElement | null) { },
})
patchedContainerRef = true
pendingContainerRef = null
}
if (patchedContainerRef && !patchedTimeoutRef && initialValue === undefined) {
patchedTimeoutRef = true
Object.defineProperty(ref, 'current', {
configurable: true,
get() {
return undefined
},
set(_value: NodeJS.Timeout | undefined) { },
})
return ref
}
return ref
}) as typeof reactActual.useRef
return {
...reactActual,
useRef: mockedUseRef,
}
})
try {
const { default: FlowchartFresh } = await import('../index')
const { unmount } = render(<FlowchartFresh PrimitiveCode={mockCode} />)
await waitFor(() => {
expect(screen.getByText('test-svg')).toBeInTheDocument()
}, { timeout: 3000 })
unmount()
}
finally {
vi.doUnmock('react')
}
})
it('should tolerate missing hidden container during handDrawn render', async () => {
vi.resetModules()
let pendingContainerRef: unknown | null = null
let patchedContainerRef = false
let containerReadCount = 0
const virtualContainer = { innerHTML: 'seed' } as HTMLDivElement
vi.doMock('react', async () => {
const reactActual = await vi.importActual<typeof import('react')>('react')
const mockedUseRef = ((initialValue: unknown) => {
const ref = reactActual.useRef(initialValue as never)
if (!patchedContainerRef && initialValue === null)
pendingContainerRef = ref
if (!patchedContainerRef
&& pendingContainerRef
&& typeof initialValue === 'string'
&& initialValue.startsWith('mermaid-chart-')) {
Object.defineProperty(pendingContainerRef as { current: unknown }, 'current', {
configurable: true,
get() {
containerReadCount += 1
if (containerReadCount === 1)
return virtualContainer
return null
},
set(_value: HTMLDivElement | null) { },
})
patchedContainerRef = true
pendingContainerRef = null
}
return ref
}) as typeof reactActual.useRef
return {
...reactActual,
useRef: mockedUseRef,
}
})
vi.useFakeTimers()
try {
const { default: FlowchartFresh } = await import('../index')
const { rerender } = render(<FlowchartFresh PrimitiveCode="graph" />)
await act(async () => {
fireEvent.click(screen.getByText(/handDrawn/i))
rerender(<FlowchartFresh PrimitiveCode={mockCode} />)
await vi.advanceTimersByTimeAsync(350)
})
await Promise.resolve()
expect(screen.getByText('test-svg-api')).toBeInTheDocument()
}
finally {
vi.useRealTimers()
vi.doUnmock('react')
}
})
})
})

View File

@@ -1,4 +1,6 @@
import type { MermaidConfig } from 'mermaid'
import { ExclamationTriangleIcon } from '@heroicons/react/24/outline'
import { MoonIcon, SunIcon } from '@heroicons/react/24/solid'
import mermaid from 'mermaid'
import * as React from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
@@ -20,7 +22,7 @@ import {
// Global flags and cache for mermaid
let isMermaidInitialized = false
const diagramCache = new Map<string, string>()
let mermaidAPI: typeof mermaid.mermaidAPI | null = null
let mermaidAPI: any = null
if (typeof window !== 'undefined')
mermaidAPI = mermaid.mermaidAPI
@@ -133,7 +135,6 @@ const Flowchart = (props: FlowchartProps) => {
const renderMermaidChart = async (code: string, style: 'classic' | 'handDrawn') => {
if (style === 'handDrawn') {
// Special handling for hand-drawn style
/* v8 ignore next */
if (containerRef.current)
containerRef.current.innerHTML = `<div id="${chartId}"></div>`
await new Promise(resolve => setTimeout(resolve, 30))
@@ -151,7 +152,6 @@ const Flowchart = (props: FlowchartProps) => {
else {
// Standard rendering for classic style - using the extracted waitForDOMElement function
const renderWithRetry = async () => {
/* v8 ignore next */
if (containerRef.current)
containerRef.current.innerHTML = `<div id="${chartId}"></div>`
await new Promise(resolve => setTimeout(resolve, 30))
@@ -207,16 +207,20 @@ const Flowchart = (props: FlowchartProps) => {
}, [props.theme])
const renderFlowchart = useCallback(async (primitiveCode: string) => {
/* v8 ignore next */
if (!isInitialized || !containerRef.current) {
/* v8 ignore next */
setIsLoading(false)
/* v8 ignore next */
setErrMsg(!isInitialized ? 'Mermaid initialization failed' : 'Container element not found')
return
}
// Return cached result if available
const cacheKey = `${primitiveCode}-${look}-${currentTheme}`
if (diagramCache.has(cacheKey)) {
setErrMsg('')
setSvgString(diagramCache.get(cacheKey) || null)
setIsLoading(false)
return
}
setIsLoading(true)
setErrMsg('')
@@ -244,7 +248,9 @@ const Flowchart = (props: FlowchartProps) => {
// Rule 1: Correct multiple "after" dependencies ONLY if they exist.
// This is a common mistake, e.g., "..., after task1, after task2, ..."
paramsStr = paramsStr.replace(/,\s*after\s+/g, ' ')
const afterCount = (paramsStr.match(/after /g) || []).length
if (afterCount > 1)
paramsStr = paramsStr.replace(/,\s*after\s+/g, ' ')
// Rule 2: Normalize spacing between parameters for consistency.
const finalParams = paramsStr.replace(/\s*,\s*/g, ', ').trim()
@@ -280,8 +286,10 @@ const Flowchart = (props: FlowchartProps) => {
// Step 4: Clean up SVG code
const cleanedSvg = cleanUpSvgCode(processedSvg)
diagramCache.set(cacheKey, cleanedSvg as string)
setSvgString(cleanedSvg as string)
if (cleanedSvg && typeof cleanedSvg === 'string') {
diagramCache.set(cacheKey, cleanedSvg)
setSvgString(cleanedSvg)
}
setIsLoading(false)
}
@@ -413,7 +421,7 @@ const Flowchart = (props: FlowchartProps) => {
const cacheKey = `${props.PrimitiveCode}-${look}-${currentTheme}`
if (diagramCache.has(cacheKey)) {
setErrMsg('')
setSvgString(diagramCache.get(cacheKey)!)
setSvgString(diagramCache.get(cacheKey) || null)
setIsLoading(false)
return
}
@@ -423,23 +431,26 @@ const Flowchart = (props: FlowchartProps) => {
}, 300) // 300ms debounce
return () => {
clearTimeout(renderTimeoutRef.current)
if (renderTimeoutRef.current)
clearTimeout(renderTimeoutRef.current)
}
}, [props.PrimitiveCode, look, currentTheme, isInitialized, configureMermaid, renderFlowchart])
// Cleanup on unmount
useEffect(() => {
return () => {
if (containerRef.current)
containerRef.current.innerHTML = ''
if (renderTimeoutRef.current)
clearTimeout(renderTimeoutRef.current)
}
}, [])
const handlePreviewClick = async () => {
if (!svgString)
return
const base64 = await svgToBase64(svgString)
setImagePreviewUrl(base64)
if (svgString) {
const base64 = await svgToBase64(svgString)
setImagePreviewUrl(base64)
}
}
const toggleTheme = () => {
@@ -473,24 +484,20 @@ const Flowchart = (props: FlowchartProps) => {
'text-gray-300': currentTheme === Theme.dark,
}),
themeToggle: cn('flex h-10 w-10 items-center justify-center rounded-full shadow-md backdrop-blur-sm transition-all duration-300', {
'border border-gray-200 bg-white/80 text-gray-700 hover:bg-white hover:shadow-lg': currentTheme === Theme.light,
'border border-slate-600 bg-slate-800/80 text-yellow-300 hover:bg-slate-700 hover:shadow-lg': currentTheme === Theme.dark,
'bg-white/80 hover:bg-white hover:shadow-lg text-gray-700 border border-gray-200': currentTheme === Theme.light,
'bg-slate-800/80 hover:bg-slate-700 hover:shadow-lg text-yellow-300 border border-slate-600': currentTheme === Theme.dark,
}),
}
// Style classes for look options
const getLookButtonClass = (lookType: 'classic' | 'handDrawn') => {
return cn(
'mb-4 flex h-8 w-[calc((100%-8px)/2)] cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg text-text-secondary system-sm-medium',
'system-sm-medium mb-4 flex h-8 w-[calc((100%-8px)/2)] cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg text-text-secondary',
look === lookType && 'border-[1.5px] border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary',
currentTheme === Theme.dark && 'border-slate-600 bg-slate-800 text-slate-300',
look === lookType && currentTheme === Theme.dark && 'border-blue-500 bg-slate-700 text-white',
)
}
const themeToggleTitleByTheme = {
light: t('theme.switchDark', { ns: 'app' }),
dark: t('theme.switchLight', { ns: 'app' }),
} as const
return (
<div ref={props.ref as React.RefObject<HTMLDivElement>} className={themeClasses.container}>
@@ -548,10 +555,10 @@ const Flowchart = (props: FlowchartProps) => {
toggleTheme()
}}
className={themeClasses.themeToggle}
title={themeToggleTitleByTheme[currentTheme] || ''}
title={(currentTheme === Theme.light ? t('theme.switchDark', { ns: 'app' }) : t('theme.switchLight', { ns: 'app' })) || ''}
style={{ transform: 'translate3d(0, 0, 0)' }}
>
{currentTheme === Theme.light ? <span className="i-heroicons-moon-solid h-5 w-5" /> : <span className="i-heroicons-sun-solid h-5 w-5" />}
{currentTheme === Theme.light ? <MoonIcon className="h-5 w-5" /> : <SunIcon className="h-5 w-5" />}
</button>
</div>
@@ -565,7 +572,7 @@ const Flowchart = (props: FlowchartProps) => {
{errMsg && (
<div className={themeClasses.errorMessage}>
<div className="flex items-center">
<span className={`i-heroicons-exclamation-triangle ${themeClasses.errorIcon}`} />
<ExclamationTriangleIcon className={themeClasses.errorIcon} />
<span className="ml-2">{errMsg}</span>
</div>
</div>

File diff suppressed because it is too large Load Diff

View File

@@ -1,209 +0,0 @@
import type { LexicalEditor } from 'lexical'
import { act, waitFor } from '@testing-library/react'
import {
$createParagraphNode,
$createTextNode,
$getRoot,
$getSelection,
$isRangeSelection,
ParagraphNode,
TextNode,
} from 'lexical'
import {
createLexicalTestEditor,
expectInlineWrapperDom,
getNodeCount,
getNodesByType,
readEditorStateValue,
readRootTextContent,
renderLexicalEditor,
selectRootEnd,
setEditorRootText,
waitForEditorReady,
} from '../test-helpers'
describe('test-helpers', () => {
describe('renderLexicalEditor & waitForEditorReady', () => {
it('should render the editor and wait for it', async () => {
const { getEditor } = renderLexicalEditor({
namespace: 'TestNamespace',
nodes: [ParagraphNode, TextNode],
children: null,
})
const editor = await waitForEditorReady(getEditor)
expect(editor).toBeDefined()
expect(editor).toBe(getEditor())
})
it('should throw if wait times out without editor', async () => {
await expect(waitForEditorReady(() => null)).rejects.toThrow()
})
it('should throw if editor is null after waitFor completes', async () => {
let callCount = 0
await expect(
waitForEditorReady(() => {
callCount++
// Return non-null on the last check of `waitFor` so it passes,
// then null when actually retrieving the editor
return callCount === 1 ? ({} as LexicalEditor) : null
}),
).rejects.toThrow('Editor is not available')
})
it('should surface errors through configured onError callback', async () => {
const { getEditor } = renderLexicalEditor({
namespace: 'TestNamespace',
nodes: [ParagraphNode, TextNode],
children: null,
})
const editor = await waitForEditorReady(getEditor)
expect(() => {
editor.update(() => {
throw new Error('test error')
}, { discrete: true })
}).toThrow('test error')
})
})
describe('selectRootEnd', () => {
it('should select the end of the root', async () => {
const { getEditor } = renderLexicalEditor({ namespace: 'test', nodes: [ParagraphNode, TextNode], children: null })
const editor = await waitForEditorReady(getEditor)
selectRootEnd(editor)
await waitFor(() => {
let isRangeSelection = false
editor.getEditorState().read(() => {
const selection = $getSelection()
isRangeSelection = $isRangeSelection(selection)
})
expect(isRangeSelection).toBe(true)
})
})
})
describe('Content Reading/Writing Helpers', () => {
it('should read root text content', async () => {
const { getEditor } = renderLexicalEditor({ namespace: 'test', nodes: [ParagraphNode, TextNode], children: null })
const editor = await waitForEditorReady(getEditor)
act(() => {
editor.update(() => {
const root = $getRoot()
root.clear()
const paragraph = $createParagraphNode()
paragraph.append($createTextNode('Hello World'))
root.append(paragraph)
}, { discrete: true })
})
let content = ''
act(() => {
content = readRootTextContent(editor)
})
expect(content).toBe('Hello World')
})
it('should set editor root text and select end', async () => {
const { getEditor } = renderLexicalEditor({ namespace: 'test', nodes: [ParagraphNode, TextNode], children: null })
const editor = await waitForEditorReady(getEditor)
setEditorRootText(editor, 'New Text', $createTextNode)
await waitFor(() => {
let content = ''
editor.getEditorState().read(() => {
content = $getRoot().getTextContent()
})
expect(content).toBe('New Text')
})
})
})
describe('Node Selection Helpers', () => {
it('should get node count', async () => {
const { getEditor } = renderLexicalEditor({ namespace: 'test', nodes: [ParagraphNode, TextNode], children: null })
const editor = await waitForEditorReady(getEditor)
act(() => {
editor.update(() => {
const root = $getRoot()
root.clear()
root.append($createParagraphNode())
root.append($createParagraphNode())
}, { discrete: true })
})
let count = 0
act(() => {
count = getNodeCount(editor, ParagraphNode)
})
expect(count).toBe(2)
})
it('should get nodes by type', async () => {
const { getEditor } = renderLexicalEditor({ namespace: 'test', nodes: [ParagraphNode, TextNode], children: null })
const editor = await waitForEditorReady(getEditor)
act(() => {
editor.update(() => {
const root = $getRoot()
root.clear()
root.append($createParagraphNode())
}, { discrete: true })
})
let nodes: ParagraphNode[] = []
act(() => {
nodes = getNodesByType(editor, ParagraphNode)
})
expect(nodes).toHaveLength(1)
expect(nodes[0]).not.toBeUndefined()
})
})
describe('readEditorStateValue', () => {
it('should read primitive values from editor state', () => {
const editor = createLexicalTestEditor('test', [ParagraphNode, TextNode])
const val = readEditorStateValue(editor, () => {
return $getRoot().isEmpty()
})
expect(val).toBe(true)
})
it('should throw if value is undefined', () => {
const editor = createLexicalTestEditor('test', [ParagraphNode, TextNode])
expect(() => {
readEditorStateValue(editor, () => undefined)
}).toThrow('Failed to read editor state value')
})
})
describe('createLexicalTestEditor', () => {
it('should expose createLexicalTestEditor with onError throw', () => {
const editor = createLexicalTestEditor('custom-namespace', [ParagraphNode, TextNode])
expect(editor).toBeDefined()
expect(() => {
editor.update(() => {
throw new Error('test error')
}, { discrete: true })
}).toThrow('test error')
})
})
describe('expectInlineWrapperDom', () => {
it('should assert wrapper properties on a valid DOM element', () => {
const div = document.createElement('div')
div.classList.add('inline-flex', 'items-center', 'align-middle', 'extra1', 'extra2')
expectInlineWrapperDom(div, ['extra1', 'extra2']) // Does not throw
})
})
})

View File

@@ -1,300 +0,0 @@
import type { RootNode } from 'lexical'
import { $createParagraphNode, $createTextNode, $getRoot, ParagraphNode, TextNode } from 'lexical'
import { describe, expect, it, vi } from 'vitest'
import { createTestEditor, withEditorUpdate } from './utils'
describe('Prompt Editor Test Utils', () => {
describe('createTestEditor', () => {
it('should create an editor without crashing', () => {
const editor = createTestEditor()
expect(editor).toBeDefined()
})
it('should create an editor with no nodes by default', () => {
const editor = createTestEditor()
expect(editor).toBeDefined()
})
it('should create an editor with provided nodes', () => {
const nodes = [ParagraphNode, TextNode]
const editor = createTestEditor(nodes)
expect(editor).toBeDefined()
})
it('should set up root element for the editor', () => {
const editor = createTestEditor()
// The editor should be properly initialized with a root element
expect(editor).toBeDefined()
})
it('should throw errors when they occur', () => {
const nodes = [ParagraphNode, TextNode]
const editor = createTestEditor(nodes)
expect(() => {
editor.update(() => {
throw new Error('Test error')
}, { discrete: true })
}).toThrow('Test error')
})
it('should allow multiple editors to be created independently', () => {
const editor1 = createTestEditor()
const editor2 = createTestEditor()
expect(editor1).not.toBe(editor2)
})
it('should initialize with basic node types', () => {
const nodes = [ParagraphNode, TextNode]
const editor = createTestEditor(nodes)
let content: string = ''
editor.update(() => {
const root = $getRoot()
const paragraph = $createParagraphNode()
const text = $createTextNode('Hello World')
paragraph.append(text)
root.append(paragraph)
content = root.getTextContent()
}, { discrete: true })
expect(content).toBe('Hello World')
})
})
describe('withEditorUpdate', () => {
it('should execute update function without crashing', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
const updateFn = vi.fn()
withEditorUpdate(editor, updateFn)
expect(updateFn).toHaveBeenCalled()
})
it('should pass discrete: true option to editor.update', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
const updateSpy = vi.spyOn(editor, 'update')
withEditorUpdate(editor, () => {
$getRoot()
})
expect(updateSpy).toHaveBeenCalledWith(expect.any(Function), { discrete: true })
})
it('should allow updating editor state', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
let textContent: string = ''
withEditorUpdate(editor, () => {
const root = $getRoot()
const paragraph = $createParagraphNode()
const text = $createTextNode('Test Content')
paragraph.append(text)
root.append(paragraph)
})
withEditorUpdate(editor, () => {
textContent = $getRoot().getTextContent()
})
expect(textContent).toBe('Test Content')
})
it('should handle multiple consecutive updates', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
withEditorUpdate(editor, () => {
const root = $getRoot()
const p1 = $createParagraphNode()
p1.append($createTextNode('First'))
root.append(p1)
})
withEditorUpdate(editor, () => {
const root = $getRoot()
const p2 = $createParagraphNode()
p2.append($createTextNode('Second'))
root.append(p2)
})
let content: string = ''
withEditorUpdate(editor, () => {
content = $getRoot().getTextContent()
})
expect(content).toContain('First')
expect(content).toContain('Second')
})
it('should provide access to editor state within update', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
let capturedState: RootNode | null = null
withEditorUpdate(editor, () => {
const root = $getRoot()
capturedState = root
})
expect(capturedState).toBeDefined()
})
it('should execute update function immediately', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
let executed = false
withEditorUpdate(editor, () => {
executed = true
})
// Update should be executed synchronously in discrete mode
expect(executed).toBe(true)
})
it('should handle complex editor operations within update', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
let nodeCount: number = 0
withEditorUpdate(editor, () => {
const root = $getRoot()
for (let i = 0; i < 3; i++) {
const paragraph = $createParagraphNode()
paragraph.append($createTextNode(`Paragraph ${i}`))
root.append(paragraph)
}
// Count child nodes
nodeCount = root.getChildrenSize()
})
expect(nodeCount).toBe(3)
})
it('should allow reading editor state after update', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
withEditorUpdate(editor, () => {
const root = $getRoot()
const paragraph = $createParagraphNode()
paragraph.append($createTextNode('Read Test'))
root.append(paragraph)
})
let readContent: string = ''
withEditorUpdate(editor, () => {
readContent = $getRoot().getTextContent()
})
expect(readContent).toBe('Read Test')
})
it('should handle error thrown within update function', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
expect(() => {
withEditorUpdate(editor, () => {
throw new Error('Update error')
})
}).toThrow('Update error')
})
it('should preserve editor state across multiple updates', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
withEditorUpdate(editor, () => {
const root = $getRoot()
const p = $createParagraphNode()
p.append($createTextNode('Persistent'))
root.append(p)
})
let persistedContent: string = ''
withEditorUpdate(editor, () => {
persistedContent = $getRoot().getTextContent()
})
expect(persistedContent).toBe('Persistent')
})
})
describe('Integration', () => {
it('should work together to create and update editor', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
withEditorUpdate(editor, () => {
const root = $getRoot()
const p = $createParagraphNode()
p.append($createTextNode('Integration Test'))
root.append(p)
})
let result: string = ''
withEditorUpdate(editor, () => {
result = $getRoot().getTextContent()
})
expect(result).toBe('Integration Test')
})
it('should support multiple editors with isolated state', () => {
const editor1 = createTestEditor([ParagraphNode, TextNode])
const editor2 = createTestEditor([ParagraphNode, TextNode])
withEditorUpdate(editor1, () => {
const root = $getRoot()
const p = $createParagraphNode()
p.append($createTextNode('Editor 1'))
root.append(p)
})
withEditorUpdate(editor2, () => {
const root = $getRoot()
const p = $createParagraphNode()
p.append($createTextNode('Editor 2'))
root.append(p)
})
let content1: string = ''
let content2: string = ''
withEditorUpdate(editor1, () => {
content1 = $getRoot().getTextContent()
})
withEditorUpdate(editor2, () => {
content2 = $getRoot().getTextContent()
})
expect(content1).toBe('Editor 1')
expect(content2).toBe('Editor 2')
})
it('should handle nested paragraph and text nodes', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
withEditorUpdate(editor, () => {
const root = $getRoot()
const p1 = $createParagraphNode()
const p2 = $createParagraphNode()
p1.append($createTextNode('First Para'))
p2.append($createTextNode('Second Para'))
root.append(p1)
root.append(p2)
})
let content: string = ''
withEditorUpdate(editor, () => {
content = $getRoot().getTextContent()
})
expect(content).toContain('First Para')
expect(content).toContain('Second Para')
})
})
})

View File

@@ -1,251 +1,112 @@
import type { LexicalEditor } from 'lexical'
import type { JSX, RefObject } from 'react'
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { act, render, screen } from '@testing-library/react'
import { LexicalComposer } from '@lexical/react/LexicalComposer'
import { ContentEditable } from '@lexical/react/LexicalContentEditable'
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary'
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import DraggableBlockPlugin from '..'
type DraggableExperimentalProps = {
anchorElem: HTMLElement
menuRef: RefObject<HTMLDivElement>
targetLineRef: RefObject<HTMLDivElement>
menuComponent: JSX.Element | null
targetLineComponent: JSX.Element
isOnMenu: (element: HTMLElement) => boolean
onElementChanged: (element: HTMLElement | null) => void
const CONTENT_EDITABLE_TEST_ID = 'draggable-content-editable'
let namespaceCounter = 0
function renderWithEditor(anchorElem?: HTMLElement) {
render(
<LexicalComposer
initialConfig={{
namespace: `draggable-plugin-test-${namespaceCounter++}`,
onError: (error: Error) => { throw error },
}}
>
<RichTextPlugin
contentEditable={<ContentEditable data-testid={CONTENT_EDITABLE_TEST_ID} />}
placeholder={null}
ErrorBoundary={LexicalErrorBoundary}
/>
<DraggableBlockPlugin anchorElem={anchorElem} />
</LexicalComposer>,
)
return screen.getByTestId(CONTENT_EDITABLE_TEST_ID)
}
type MouseMoveHandler = (event: MouseEvent) => void
const { draggableMockState } = vi.hoisted(() => ({
draggableMockState: {
latestProps: null as DraggableExperimentalProps | null,
},
}))
vi.mock('@lexical/react/LexicalComposerContext')
vi.mock('@lexical/react/LexicalDraggableBlockPlugin', () => ({
DraggableBlockPlugin_EXPERIMENTAL: (props: DraggableExperimentalProps) => {
draggableMockState.latestProps = props
return (
<div data-testid="draggable-plugin-experimental-mock">
{props.menuComponent}
{props.targetLineComponent}
</div>
)
},
}))
function createRootElementMock() {
let mouseMoveHandler: MouseMoveHandler | null = null
const addEventListener = vi.fn((eventName: string, handler: EventListenerOrEventListenerObject) => {
if (eventName === 'mousemove' && typeof handler === 'function')
mouseMoveHandler = handler as MouseMoveHandler
})
const removeEventListener = vi.fn()
return {
rootElement: {
addEventListener,
removeEventListener,
} as unknown as HTMLElement,
addEventListener,
removeEventListener,
getMouseMoveHandler: () => mouseMoveHandler,
}
}
function getRegisteredMouseMoveHandler(
rootMock: ReturnType<typeof createRootElementMock>,
): MouseMoveHandler {
const handler = rootMock.getMouseMoveHandler()
if (!handler)
throw new Error('Expected mousemove handler to be registered')
return handler
}
function setupEditorRoot(rootElement: HTMLElement | null) {
const editor = {
getRootElement: vi.fn(() => rootElement),
} as unknown as LexicalEditor
vi.mocked(useLexicalComposerContext).mockReturnValue([
editor,
{},
] as unknown as ReturnType<typeof useLexicalComposerContext>)
return editor
function appendChildToRoot(rootElement: HTMLElement, className = '') {
const element = document.createElement('div')
element.className = className
rootElement.appendChild(element)
return element
}
describe('DraggableBlockPlugin', () => {
beforeEach(() => {
vi.clearAllMocks()
draggableMockState.latestProps = null
})
describe('Rendering', () => {
it('should use body as default anchor and render target line', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
renderWithEditor()
render(<DraggableBlockPlugin />)
expect(draggableMockState.latestProps?.anchorElem).toBe(document.body)
expect(screen.getByTestId('draggable-target-line')).toBeInTheDocument()
const targetLine = screen.getByTestId('draggable-target-line')
expect(targetLine).toBeInTheDocument()
expect(document.body.contains(targetLine)).toBe(true)
expect(screen.queryByTestId('draggable-menu')).not.toBeInTheDocument()
})
it('should render with custom anchor when provided', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
const anchorElem = document.createElement('div')
it('should render inside custom anchor element when provided', () => {
const customAnchor = document.createElement('div')
document.body.appendChild(customAnchor)
render(<DraggableBlockPlugin anchorElem={anchorElem} />)
renderWithEditor(customAnchor)
expect(draggableMockState.latestProps?.anchorElem).toBe(anchorElem)
expect(screen.getByTestId('draggable-target-line')).toBeInTheDocument()
})
const targetLine = screen.getByTestId('draggable-target-line')
expect(customAnchor.contains(targetLine)).toBe(true)
it('should return early when editor root element is null', () => {
const editor = setupEditorRoot(null)
render(<DraggableBlockPlugin />)
expect(editor.getRootElement).toHaveBeenCalledTimes(1)
expect(screen.getByTestId('draggable-target-line')).toBeInTheDocument()
expect(screen.queryByTestId('draggable-menu')).not.toBeInTheDocument()
customAnchor.remove()
})
})
describe('Drag support detection', () => {
it('should show menu when target has support-drag class', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
render(<DraggableBlockPlugin />)
const onMove = getRegisteredMouseMoveHandler(rootMock)
const target = document.createElement('div')
target.className = 'support-drag'
act(() => {
onMove({ target } as unknown as MouseEvent)
})
expect(screen.getByTestId('draggable-menu')).toBeInTheDocument()
})
it('should show menu when target contains a support-drag descendant', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
render(<DraggableBlockPlugin />)
const onMove = getRegisteredMouseMoveHandler(rootMock)
const target = document.createElement('div')
target.appendChild(Object.assign(document.createElement('span'), { className: 'support-drag' }))
act(() => {
onMove({ target } as unknown as MouseEvent)
})
expect(screen.getByTestId('draggable-menu')).toBeInTheDocument()
})
it('should show menu when target is inside a support-drag ancestor', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
render(<DraggableBlockPlugin />)
const onMove = getRegisteredMouseMoveHandler(rootMock)
const ancestor = document.createElement('div')
ancestor.className = 'support-drag'
const child = document.createElement('span')
ancestor.appendChild(child)
act(() => {
onMove({ target: child } as unknown as MouseEvent)
})
expect(screen.getByTestId('draggable-menu')).toBeInTheDocument()
})
it('should hide menu when target does not support drag', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
render(<DraggableBlockPlugin />)
const onMove = getRegisteredMouseMoveHandler(rootMock)
const supportDragTarget = document.createElement('div')
supportDragTarget.className = 'support-drag'
act(() => {
onMove({ target: supportDragTarget } as unknown as MouseEvent)
})
expect(screen.getByTestId('draggable-menu')).toBeInTheDocument()
const plainTarget = document.createElement('div')
act(() => {
onMove({ target: plainTarget } as unknown as MouseEvent)
})
describe('Drag Support Detection', () => {
it('should render drag menu when mouse moves over a support-drag element', async () => {
const rootElement = renderWithEditor()
const supportDragTarget = appendChildToRoot(rootElement, 'support-drag')
expect(screen.queryByTestId('draggable-menu')).not.toBeInTheDocument()
fireEvent.mouseMove(supportDragTarget)
await waitFor(() => {
expect(screen.getByTestId('draggable-menu')).toBeInTheDocument()
})
})
it('should keep menu hidden when event target becomes null', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
render(<DraggableBlockPlugin />)
it('should hide drag menu when support-drag target is removed and mouse moves again', async () => {
const rootElement = renderWithEditor()
const supportDragTarget = appendChildToRoot(rootElement, 'support-drag')
const onMove = getRegisteredMouseMoveHandler(rootMock)
const supportDragTarget = document.createElement('div')
supportDragTarget.className = 'support-drag'
act(() => {
onMove({ target: supportDragTarget } as unknown as MouseEvent)
})
expect(screen.getByTestId('draggable-menu')).toBeInTheDocument()
act(() => {
onMove({ target: null } as unknown as MouseEvent)
fireEvent.mouseMove(supportDragTarget)
await waitFor(() => {
expect(screen.getByTestId('draggable-menu')).toBeInTheDocument()
})
expect(screen.queryByTestId('draggable-menu')).not.toBeInTheDocument()
supportDragTarget.remove()
fireEvent.mouseMove(rootElement)
await waitFor(() => {
expect(screen.queryByTestId('draggable-menu')).not.toBeInTheDocument()
})
})
})
describe('Forwarded callbacks', () => {
it('should forward isOnMenu and detect menu membership correctly', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
render(<DraggableBlockPlugin />)
describe('Menu Detection Contract', () => {
it('should render menu with draggable-block-menu class and keep non-menu elements outside it', async () => {
const rootElement = renderWithEditor()
const supportDragTarget = appendChildToRoot(rootElement, 'support-drag')
const onMove = getRegisteredMouseMoveHandler(rootMock)
const supportDragTarget = document.createElement('div')
supportDragTarget.className = 'support-drag'
act(() => {
onMove({ target: supportDragTarget } as unknown as MouseEvent)
})
fireEvent.mouseMove(supportDragTarget)
const renderedMenu = screen.getByTestId('draggable-menu')
const isOnMenu = draggableMockState.latestProps?.isOnMenu
if (!isOnMenu)
throw new Error('Expected isOnMenu callback')
const menuIcon = await screen.findByTestId('draggable-menu-icon')
expect(menuIcon.closest('.draggable-block-menu')).not.toBeNull()
const menuIcon = screen.getByTestId('draggable-menu-icon')
const outsideElement = document.createElement('div')
expect(isOnMenu(menuIcon)).toBe(true)
expect(isOnMenu(renderedMenu)).toBe(true)
expect(isOnMenu(outsideElement)).toBe(false)
})
it('should register and cleanup mousemove listener on mount and unmount', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
const { unmount } = render(<DraggableBlockPlugin />)
const onMove = getRegisteredMouseMoveHandler(rootMock)
expect(rootMock.addEventListener).toHaveBeenCalledWith('mousemove', expect.any(Function))
unmount()
expect(rootMock.removeEventListener).toHaveBeenCalledWith('mousemove', onMove)
const normalElement = document.createElement('div')
document.body.appendChild(normalElement)
expect(normalElement.closest('.draggable-block-menu')).toBeNull()
normalElement.remove()
})
})
})

View File

@@ -1,10 +1,8 @@
import type { LexicalCommand } from 'lexical'
import { LexicalComposer } from '@lexical/react/LexicalComposer'
import { ContentEditable } from '@lexical/react/LexicalContentEditable'
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary'
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { createCommand } from 'lexical'
import * as React from 'react'
import { useState } from 'react'
import ShortcutsPopupPlugin, { SHORTCUTS_EMPTY_CONTENT } from '../index'
@@ -23,9 +21,6 @@ const mockDOMRect = {
toJSON: () => ({}),
}
const originalRangeGetClientRects = Range.prototype.getClientRects
const originalRangeGetBoundingClientRect = Range.prototype.getBoundingClientRect
beforeAll(() => {
// Mock getClientRects on Range prototype
Range.prototype.getClientRects = vi.fn(() => {
@@ -39,31 +34,12 @@ beforeAll(() => {
Range.prototype.getBoundingClientRect = vi.fn(() => mockDOMRect as DOMRect)
})
afterAll(() => {
Range.prototype.getClientRects = originalRangeGetClientRects
Range.prototype.getBoundingClientRect = originalRangeGetBoundingClientRect
})
const CONTAINER_ID = 'host'
const CONTENT_EDITABLE_ID = 'ce'
type MinimalEditorProps = {
const MinimalEditor: React.FC<{
withContainer?: boolean
hotkey?: string | string[] | string[][] | ((e: KeyboardEvent) => boolean)
children?: React.ReactNode | ((close: () => void, onInsert: (command: LexicalCommand<unknown>, params: unknown[]) => void) => React.ReactNode)
className?: string
onOpen?: () => void
onClose?: () => void
}
const MinimalEditor: React.FC<MinimalEditorProps> = ({
withContainer = true,
hotkey,
children,
className,
onOpen,
onClose,
}) => {
}> = ({ withContainer = true }) => {
const initialConfig = {
namespace: 'shortcuts-popup-plugin-test',
onError: (e: Error) => {
@@ -82,35 +58,25 @@ const MinimalEditor: React.FC<MinimalEditorProps> = ({
/>
<ShortcutsPopupPlugin
container={withContainer ? containerEl : undefined}
hotkey={hotkey}
className={className}
onOpen={onOpen}
onClose={onClose}
>
{children}
</ShortcutsPopupPlugin>
/>
</div>
</LexicalComposer>
)
}
/** Helper: focus the content editable and trigger a hotkey. */
function focusAndTriggerHotkey(key: string, modifiers: Partial<Record<'ctrlKey' | 'metaKey' | 'altKey' | 'shiftKey', boolean>> = { ctrlKey: true }) {
const ce = screen.getByTestId(CONTENT_EDITABLE_ID)
ce.focus()
fireEvent.keyDown(document, { key, ...modifiers })
}
describe('ShortcutsPopupPlugin', () => {
// ─── Basic open / close ───
it('opens on hotkey when editor is focused', async () => {
render(<MinimalEditor />)
focusAndTriggerHotkey('/')
const ce = screen.getByTestId(CONTENT_EDITABLE_ID)
ce.focus()
fireEvent.keyDown(document, { key: '/', ctrlKey: true }) // 模拟 Ctrl+/
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not open when editor is not focused', async () => {
render(<MinimalEditor />)
// 未聚焦
fireEvent.keyDown(document, { key: '/', ctrlKey: true })
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
@@ -119,7 +85,10 @@ describe('ShortcutsPopupPlugin', () => {
it('closes on Escape', async () => {
render(<MinimalEditor />)
focusAndTriggerHotkey('/')
const ce = screen.getByTestId(CONTENT_EDITABLE_ID)
ce.focus()
fireEvent.keyDown(document, { key: '/', ctrlKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
fireEvent.keyDown(document, { key: 'Escape' })
@@ -142,370 +111,24 @@ describe('ShortcutsPopupPlugin', () => {
})
})
// ─── Container / portal ───
it('portals into provided container when container is set', async () => {
render(<MinimalEditor withContainer />)
const ce = screen.getByTestId(CONTENT_EDITABLE_ID)
const host = screen.getByTestId(CONTAINER_ID)
focusAndTriggerHotkey('/')
ce.focus()
fireEvent.keyDown(document, { key: '/', ctrlKey: true })
const portalContent = await screen.findByText(SHORTCUTS_EMPTY_CONTENT)
expect(host).toContainElement(portalContent)
})
it('falls back to document.body when container is not provided', async () => {
render(<MinimalEditor withContainer={false} />)
focusAndTriggerHotkey('/')
const ce = screen.getByTestId(CONTENT_EDITABLE_ID)
ce.focus()
fireEvent.keyDown(document, { key: '/', ctrlKey: true })
const portalContent = await screen.findByText(SHORTCUTS_EMPTY_CONTENT)
expect(document.body).toContainElement(portalContent)
})
// ─── matchHotkey: string hotkey ───
it('matches a string hotkey like "mod+/"', async () => {
render(<MinimalEditor hotkey="mod+/" />)
focusAndTriggerHotkey('/', { metaKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('matches ctrl+/ when hotkey is "mod+/" (mod matches ctrl or meta)', async () => {
render(<MinimalEditor hotkey="mod+/" />)
focusAndTriggerHotkey('/', { ctrlKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
// ─── matchHotkey: string[] hotkey ───
it('matches when hotkey is a string array like ["mod", "/"]', async () => {
render(<MinimalEditor hotkey={['mod', '/']} />)
focusAndTriggerHotkey('/', { ctrlKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
// ─── matchHotkey: string[][] (nested) hotkey ───
it('matches when hotkey is a nested array (any combo matches)', async () => {
render(<MinimalEditor hotkey={[['ctrl', 'k'], ['meta', 'j']]} />)
focusAndTriggerHotkey('k', { ctrlKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('matches the second combo in a nested array', async () => {
render(<MinimalEditor hotkey={[['ctrl', 'k'], ['meta', 'j']]} />)
focusAndTriggerHotkey('j', { metaKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not match nested array when no combo matches', async () => {
render(<MinimalEditor hotkey={[['ctrl', 'k'], ['meta', 'j']]} />)
focusAndTriggerHotkey('x', { ctrlKey: true })
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
// ─── matchHotkey: function hotkey ───
it('matches when hotkey is a custom function returning true', async () => {
const customMatcher = (e: KeyboardEvent) => e.key === 'F1'
render(<MinimalEditor hotkey={customMatcher} />)
focusAndTriggerHotkey('F1', {})
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not match when custom function returns false', async () => {
const customMatcher = (e: KeyboardEvent) => e.key === 'F1'
render(<MinimalEditor hotkey={customMatcher} />)
focusAndTriggerHotkey('F2', {})
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
// ─── matchHotkey: modifier aliases ───
it('matches meta/cmd/command aliases', async () => {
render(<MinimalEditor hotkey="cmd+k" />)
focusAndTriggerHotkey('k', { metaKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('matches "command" alias for meta', async () => {
render(<MinimalEditor hotkey="command+k" />)
focusAndTriggerHotkey('k', { metaKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not match meta alias when meta is not pressed', async () => {
render(<MinimalEditor hotkey="cmd+k" />)
focusAndTriggerHotkey('k', {})
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
it('matches alt/option alias', async () => {
render(<MinimalEditor hotkey="alt+a" />)
focusAndTriggerHotkey('a', { altKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not match alt alias when alt is not pressed', async () => {
render(<MinimalEditor hotkey="alt+a" />)
focusAndTriggerHotkey('a', {})
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
it('matches shift alias', async () => {
render(<MinimalEditor hotkey="shift+s" />)
focusAndTriggerHotkey('s', { shiftKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not match shift alias when shift is not pressed', async () => {
render(<MinimalEditor hotkey="shift+s" />)
focusAndTriggerHotkey('s', {})
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
it('matches ctrl alias', async () => {
render(<MinimalEditor hotkey="ctrl+b" />)
focusAndTriggerHotkey('b', { ctrlKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not match ctrl alias when ctrl is not pressed', async () => {
render(<MinimalEditor hotkey="ctrl+b" />)
focusAndTriggerHotkey('b', {})
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
// ─── matchHotkey: space key normalization ───
it('normalizes space key to "space" for matching', async () => {
render(<MinimalEditor hotkey="ctrl+space" />)
focusAndTriggerHotkey(' ', { ctrlKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
// ─── matchHotkey: key mismatch ───
it('does not match when expected key does not match pressed key', async () => {
render(<MinimalEditor hotkey="ctrl+z" />)
focusAndTriggerHotkey('x', { ctrlKey: true })
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
// ─── Children rendering ───
it('renders children as ReactNode when provided', async () => {
render(
<MinimalEditor>
<div data-testid="custom-content">My Content</div>
</MinimalEditor>,
)
focusAndTriggerHotkey('/')
expect(await screen.findByTestId('custom-content')).toBeInTheDocument()
expect(screen.getByText('My Content')).toBeInTheDocument()
})
it('renders children as render function and provides close/onInsert', async () => {
const TEST_COMMAND = createCommand<unknown>('TEST_COMMAND')
const childrenFn = vi.fn((close: () => void, onInsert: (cmd: LexicalCommand<unknown>, params: unknown[]) => void) => (
<div>
<button type="button" data-testid="close-btn" onClick={close}>Close</button>
<button type="button" data-testid="insert-btn" onClick={() => onInsert(TEST_COMMAND, ['param1'])}>Insert</button>
</div>
))
render(
<MinimalEditor>
{childrenFn}
</MinimalEditor>,
)
focusAndTriggerHotkey('/')
// Children render function should have been called
expect(await screen.findByTestId('close-btn')).toBeInTheDocument()
expect(screen.getByTestId('insert-btn')).toBeInTheDocument()
})
it('renders SHORTCUTS_EMPTY_CONTENT when children is undefined', async () => {
render(<MinimalEditor />)
focusAndTriggerHotkey('/')
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
// ─── handleInsert callback ───
it('calls close after insert via children render function', async () => {
const TEST_COMMAND = createCommand<unknown>('TEST_INSERT_COMMAND')
render(
<MinimalEditor>
{(close: () => void, onInsert: (cmd: LexicalCommand<unknown>, params: unknown[]) => void) => (
<div>
<button type="button" data-testid="insert-btn" onClick={() => onInsert(TEST_COMMAND, ['value'])}>Insert</button>
</div>
)}
</MinimalEditor>,
)
focusAndTriggerHotkey('/')
const insertBtn = await screen.findByTestId('insert-btn')
fireEvent.click(insertBtn)
// After insert, the popup should close
await waitFor(() => {
expect(screen.queryByTestId('insert-btn')).not.toBeInTheDocument()
})
})
it('calls close via children render function close callback', async () => {
render(
<MinimalEditor>
{(close: () => void) => (
<button type="button" data-testid="close-via-fn" onClick={close}>Close</button>
)}
</MinimalEditor>,
)
focusAndTriggerHotkey('/')
const closeBtn = await screen.findByTestId('close-via-fn')
fireEvent.click(closeBtn)
await waitFor(() => {
expect(screen.queryByTestId('close-via-fn')).not.toBeInTheDocument()
})
})
// ─── onOpen / onClose callbacks ───
it('calls onOpen when popup opens', async () => {
const onOpen = vi.fn()
render(<MinimalEditor onOpen={onOpen} />)
focusAndTriggerHotkey('/')
await screen.findByText(SHORTCUTS_EMPTY_CONTENT)
expect(onOpen).toHaveBeenCalledTimes(1)
})
it('calls onClose when popup closes', async () => {
const onClose = vi.fn()
render(<MinimalEditor onClose={onClose} />)
focusAndTriggerHotkey('/')
await screen.findByText(SHORTCUTS_EMPTY_CONTENT)
fireEvent.keyDown(document, { key: 'Escape' })
await waitFor(() => {
expect(onClose).toHaveBeenCalledTimes(1)
})
})
// ─── className prop ───
it('applies custom className to floating popup', async () => {
render(<MinimalEditor className="custom-popup-class" />)
focusAndTriggerHotkey('/')
const content = await screen.findByText(SHORTCUTS_EMPTY_CONTENT)
const floatingDiv = content.closest('div')
expect(floatingDiv).toHaveClass('custom-popup-class')
})
// ─── mousedown inside portal should not close ───
it('does not close on mousedown inside the portal', async () => {
render(
<MinimalEditor>
<div data-testid="portal-inner">Inner content</div>
</MinimalEditor>,
)
focusAndTriggerHotkey('/')
const inner = await screen.findByTestId('portal-inner')
fireEvent.mouseDown(inner)
// Should still be open
await waitFor(() => {
expect(screen.getByTestId('portal-inner')).toBeInTheDocument()
})
})
it('prevents default and stops propagation on Escape when open', async () => {
render(<MinimalEditor />)
focusAndTriggerHotkey('/')
await screen.findByText(SHORTCUTS_EMPTY_CONTENT)
const preventDefaultSpy = vi.fn()
const stopPropagationSpy = vi.fn()
// Use a custom event to capture preventDefault/stopPropagation calls
const escEvent = new KeyboardEvent('keydown', { key: 'Escape', bubbles: true, cancelable: true })
Object.defineProperty(escEvent, 'preventDefault', { value: preventDefaultSpy })
Object.defineProperty(escEvent, 'stopPropagation', { value: stopPropagationSpy })
document.dispatchEvent(escEvent)
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
expect(preventDefaultSpy).toHaveBeenCalledTimes(1)
expect(stopPropagationSpy).toHaveBeenCalledTimes(1)
})
// ─── Zero-rect fallback in openPortal ───
it('handles zero-size range rects by falling back to node bounding rect', async () => {
// Temporarily override getClientRects to return zero-size rect
const zeroRect = { x: 0, y: 0, width: 0, height: 0, top: 0, right: 0, bottom: 0, left: 0, toJSON: () => ({}) }
const originalGetClientRects = Range.prototype.getClientRects
const originalGetBoundingClientRect = Range.prototype.getBoundingClientRect
Range.prototype.getClientRects = vi.fn(() => {
const rectList = [zeroRect] as unknown as DOMRectList
Object.defineProperty(rectList, 'length', { value: 1 })
Object.defineProperty(rectList, 'item', { value: () => zeroRect })
return rectList
})
Range.prototype.getBoundingClientRect = vi.fn(() => zeroRect as DOMRect)
render(<MinimalEditor />)
focusAndTriggerHotkey('/')
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
// Restore
Range.prototype.getClientRects = originalGetClientRects
Range.prototype.getBoundingClientRect = originalGetBoundingClientRect
})
it('handles empty getClientRects by using getBoundingClientRect fallback', async () => {
const originalGetClientRects = Range.prototype.getClientRects
const originalGetBoundingClientRect = Range.prototype.getBoundingClientRect
Range.prototype.getClientRects = vi.fn(() => {
const rectList = [] as unknown as DOMRectList
Object.defineProperty(rectList, 'length', { value: 0 })
Object.defineProperty(rectList, 'item', { value: () => null })
return rectList
})
Range.prototype.getBoundingClientRect = vi.fn(() => mockDOMRect as DOMRect)
render(<MinimalEditor />)
focusAndTriggerHotkey('/')
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
Range.prototype.getClientRects = originalGetClientRects
Range.prototype.getBoundingClientRect = originalGetBoundingClientRect
})
// ─── Combined modifier hotkeys ───
it('matches hotkey with multiple modifiers: ctrl+shift+k', async () => {
render(<MinimalEditor hotkey="ctrl+shift+k" />)
focusAndTriggerHotkey('k', { ctrlKey: true, shiftKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('matches "option" alias for alt', async () => {
render(<MinimalEditor hotkey="option+o" />)
focusAndTriggerHotkey('o', { altKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not match mod hotkey when neither ctrl nor meta is pressed', async () => {
render(<MinimalEditor hotkey="mod+k" />)
focusAndTriggerHotkey('k', {})
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
})

Some files were not shown because too many files have changed in this diff Show More