Compare commits

...

88 Commits

Author SHA1 Message Date
yyh
00a79a3e26 Merge branch 'feat/model-provider-refactor' into deploy/dev
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
2026-03-05 14:34:45 +08:00
yyh
7471c32612 Revert "temp: remove IS_CLOUD_EDITION guard from supportsCredits for local testing"
This reverts commit ab87ac333a.
2026-03-05 14:33:48 +08:00
yyh
88d87d6053 Merge branch 'feat/model-provider-refactor' into deploy/dev 2026-03-05 14:24:55 +08:00
yyh
2d333bbbe5 refactor(web): extract credential activation into hook and migrate credential-item overlays
Extract credential switching logic from dropdown-content into a dedicated
useActivateCredential hook with optimistic updates and proper data flow
separation. Credential items now stay visible in the popover after clicking
(no auto-close), show cursor-pointer, and disable during activation.

Migrate credential-item from legacy Tooltip and remixicon imports to
base-ui Tooltip and CSS icon classes, pruning stale ESLint suppressions.
2026-03-05 14:22:39 +08:00
yyh
4af6788ce0 fix(web): wrap Header test in Dialog context for base-ui compatibility 2026-03-05 14:20:35 +08:00
yyh
24b072def9 fix: lint 2026-03-05 14:08:20 +08:00
yyh
909c8c3350 Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-05 13:58:51 +08:00
Coding On Star
89a859ae32 refactor: simplify oauthNewUser state handling in AppInitializer component (#33019)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-03-05 13:53:20 +08:00
yyh
80e9c8bee0 refactor(web): make account setting fully controlled with action props 2026-03-05 13:39:36 +08:00
yyh
15b7b304d2 refactor(web): migrate model-modal overlays to base-ui Dialog and AlertDialog
Replace legacy PortalToFollowElem and Confirm with Dialog/AlertDialog
primitives. Remove manual ESC handler and backdrop div — now handled
natively by base-ui. Add backdropProps={{ forceRender: true }} for
correct nested overlay rendering.
2026-03-05 13:33:53 +08:00
yyh
61e2672b59 refactor(web): make provider reset event-driven and scope model invalidation
- remove provider-page lifecycle reset effect and handle reset in explicit tab/close actions
- switch account setting tab state to controlled/uncontrolled pattern without sync effect
- use provider-scoped model list queryKey with exact invalidation in credential and model toggle mutations
- update related tests and mocks for new behavior
2026-03-05 13:28:30 +08:00
yyh
5f4ed4c6f6 refactor(web): replace model provider emitter refresh with jotai state
- add atom-based provider expansion state with reset/prune helpers
- remove event-emitter dependency from model provider refresh flow
- invalidate exact provider model-list query key on refresh
- reset expansion state on model provider page mount/unmount
- update and extend tests for external expansion and query invalidation
- update eslint suppressions to match current code
2026-03-05 13:20:58 +08:00
yyh
4a1032c628 fix(web): remove redundant hover text swap on show models button
Merge the two hover-toggling divs into a single always-visible element
and remove the unused showModelsNum i18n key from all locales.
2026-03-05 13:16:04 +08:00
yyh
423c97a47e code style 2026-03-05 13:09:33 +08:00
yyh
a7e3fb2e33 fix(web): use triangle Warning icon instead of circle error icon
Replace i-ri-error-warning-fill (circle exclamation) with the
Warning component (triangle) for api-fallback and credits-fallback
variants to match Figma design.
2026-03-05 13:07:20 +08:00
yyh
ce34937a1c feat(web): add credits-fallback variant for API Key priority with available credits
When API Key is selected but unavailable/unconfigured and credits are
available, the card now shows "AI credits in use" with a warning icon
instead of "API key required". When both credits are exhausted and no
API key exists, it shows "No available usage" (destructive).

New deriveVariant logic for priority=apiKey:
- !exhausted + !authorized → credits-fallback (was api-required-*)
- exhausted + no credential → no-usage (was api-required-add)
- exhausted + named unauthorized → api-unavailable (unchanged)
2026-03-05 13:02:40 +08:00
yyh
ad9ac6978e fix(web): align alert card width with API key section in dropdown
Change mx-1 (4px) to mx-2 (8px) on CreditsFallbackAlert and
CreditsExhaustedAlert to match ApiKeySection's p-2 (8px) padding,
consistent with Figma design where both sections are 8px from the
dropdown edge.
2026-03-05 12:56:55 +08:00
yyh
e2b247d762 Merge branch 'feat/model-provider-refactor' into deploy/dev 2026-03-05 12:25:53 +08:00
yyh
57c1ba3543 fix(web): hide divider above empty API keys state in dropdown
Move the border from UsagePrioritySection (always visible) to
ApiKeySection's list variant (only when credentials exist). This
removes the unwanted divider line above the "No API Keys" empty
state card when on the AI Credits tab with no keys configured.
2026-03-05 12:25:11 +08:00
yyh
37b15acd0d Merge branch 'feat/model-provider-refactor' into deploy/dev 2026-03-05 10:48:48 +08:00
yyh
d7a5af2b9a Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor
# Conflicts:
#	web/app/components/header/account-setting/model-provider-page/index.tsx
2026-03-05 10:46:24 +08:00
Stephen Zhou
a4373d8b7b chore: i18n hmr support, fix hmr for app context (#32997) 2026-03-05 10:45:16 +08:00
yyh
0ce9ebed63 Merge branch 'feat/model-provider-refactor' into deploy/dev 2026-03-05 10:40:55 +08:00
yyh
d45edffaa3 fix(web): wire upgrade link to pricing modal and add credits-coin icon
Replace broken HTML string interpolation with Trans component and
useModalContextSelector so "upgrade your plan" opens the pricing modal.
Add custom credits-coin SVG icon to replace the generic ri-coin-line.
2026-03-05 10:39:31 +08:00
yyh
530515b6ef fix(web): prevent model list from expanding on priority switch
Remove UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST event emission from
changePriority onSuccess. This event was designed for custom model
add/edit/delete scenarios where the card should expand, but firing
it on priority switch caused ProviderAddedCard to unexpectedly
expand via refreshModelList → setCollapsed(false).
2026-03-05 10:35:03 +08:00
yyh
1fa68d0863 Merge branch 'feat/model-provider-refactor' into deploy/dev 2026-03-05 10:20:02 +08:00
yyh
f13f0d1f9a fix(web): align dropdown alerts with Figma design and fix hardcoded credits total
- Expose totalCredits from useTrialCredits hook instead of hardcoding 10,000
- Align CreditsExhaustedAlert with Figma: dynamic progress bar, correct
  design tokens (components-progress-error-bg/progress), sm-medium/xs-regular
  typography
- Align CreditsFallbackAlert typography to sm-medium/xs-regular
- Fix ApiKeySection empty state: horizontal gradient, sm-medium title,
  Figma-aligned padding (pl-7 for API KEYS label)
- Hoist empty credentials array constant to stabilize memo (rerender-memo-with-default-value)
- Remove redundant useCallback wrapper in ApiKeySection
- Replace nested ternary with Record lookup in TextLabel
- Remove dead || 0 guard in useTrialCredits
- Update all test mocks with totalCredits field
2026-03-05 10:09:51 +08:00
yyh
b597d52c11 refactor(web): remove dialog description from system model selector
Remove the DialogDescription and its i18n key (modelProvider.systemModelSettingsLink)
from the system model settings dialog across all 23 locales.
2026-03-05 10:05:01 +08:00
yyh
34c42fe666 Revert "temp: remove cloud condition"
This reverts commit 29e344ac8b.
2026-03-05 09:44:19 +08:00
yyh
dc109c99f0 test(web): expand credential panel and dropdown test coverage for all 8 card variants
Add comprehensive behavioral tests covering all discriminated union variants,
destructive/default styling, warning icons, CreditsFallbackAlert conditions,
credential CRUD interactions, AlertDialog delete confirmation, and Popover behavior.
2026-03-05 09:41:48 +08:00
hj24
4a0770192e feat: add export app messages
fix: tests
2026-03-05 09:41:32 +08:00
yyh
223b9d89c1 refactor(web): migrate priority change to oRPC contract with useMutation
- Add changePreferredProviderType contract in model-providers.ts
- Register in consoleRouterContract
- Replace raw async changeModelProviderPriority with useMutation
- Use Toast.notify (static API) instead of useToastContext hook
- Pass isPending as isChangingPriority to disable buttons during switch
- Add disabled prop to UsagePrioritySection
- Fix pre-existing test assertions for api-unavailable variant
- Update all specs with isChangingPriority prop and oRPC mock pattern
2026-03-05 09:30:38 +08:00
yyh
dd119eb44f fix(web): align UsagePrioritySection with Figma design and fix i18n key ordering
- Single-row layout for icon, label, and option cards
- Icon: arrow-up-double-line matching design spec
- Buttons: flexible width with whitespace-nowrap instead of fixed w-[72px]
- Add min-w-0 + truncate for text overflow, focus-visible ring for a11y
- Sort modelProvider.card.* i18n keys alphabetically
2026-03-05 09:15:16 +08:00
Eric Cao
164ccb7c48 chore: add TypedDict related prompt to AGENTS.md (#32995)
Co-authored-by: caoergou <caogou123@163.com>
2026-03-05 10:08:18 +09:00
木之本澪
2b5ce196ad test: migrate test_document_service_rename_document SQL tests to testcontainers (#32542)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-03-05 10:07:29 +09:00
dependabot[bot]
2977a4d2a4 chore(deps): bump immutable from 5.1.4 to 5.1.5 in /web (#33000)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-05 09:50:52 +09:00
dependabot[bot]
a0331b8b45 chore(deps): bump authlib from 1.6.6 to 1.6.7 in /api (#32998)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-05 09:50:12 +09:00
dependabot[bot]
914bd4d00d chore(deps): bump fickling from 0.1.8 to 0.1.9 in /api (#32999)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-05 09:49:43 +09:00
yyh
9c9cb50981 chore: expand overlay migration ESLint rules to popover, dropdown, dialog (#33001) 2026-03-05 08:47:54 +08:00
yyh
970493fa85 test(web): update tests for credential panel refactoring and new ModelAuthDropdown components
Rewrite credential-panel.spec.tsx to match the new discriminated union
state model and variant-driven rendering. Add new test files for
useCredentialPanelState hook, SystemQuotaCard Label enhancement,
and all ModelAuthDropdown sub-components.
2026-03-05 08:41:17 +08:00
yyh
ab87ac333a temp: remove IS_CLOUD_EDITION guard from supportsCredits for local testing 2026-03-05 08:34:10 +08:00
yyh
b8b70da9ad refactor(web): rewrite CredentialPanel with declarative variant-driven state and new ModelAuthDropdown
- Extract useCredentialPanelState hook with discriminated union CardVariant type replacing scattered boolean conditions
- Create ModelAuthDropdown compound component (Popover-based) with UsagePrioritySection, CreditsExhaustedAlert, and ApiKeySection
- Enhance SystemQuotaCard.Label to accept className override for flexible styling
- Add i18n keys for new card states and dropdown content (en-US, zh-Hans)
2026-03-05 08:33:04 +08:00
木之本澪
df3c66a8ac test: migrate duplicate_document_indexing_task SQL tests to testcontainers (#32540)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-03-05 02:35:24 +09:00
Stephen Zhou
7252ce6f26 chore: add react refresh plugin for vinext (#32996) 2026-03-05 01:09:32 +08:00
yyh
26d96f97a7 Merge branch 'feat/model-provider-refactor' into deploy/dev
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
2026-03-04 23:39:41 +08:00
yyh
77d81aebe8 Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-04 23:35:20 +08:00
yyh
deb4cd3ece fix: i18n 2026-03-04 23:35:13 +08:00
yyh
648d9ef1f9 refactor(web): extract SystemQuotaCard compound component and shared useTrialCredits hook
Extract trial credits calculation into a shared useTrialCredits hook to prevent
logic drift between QuotaPanel and CredentialPanel. Add SystemQuotaCard compound
component with explicit default/destructive variants for the system quota UI
state in provider cards, replacing inline conditional styling with composable
Label and Actions slots. Remove unnecessary useMemo for simple derived values.
2026-03-04 23:30:25 +08:00
yyh
5ed4797078 fix 2026-03-04 22:53:29 +08:00
yyh
62631658e9 fix(web): update tests for AlertDialog migration and component API changes
- Replace deprecated Confirm mock with real AlertDialog role-based queries
- Add useInvalidateCheckInstalled mock for QueryClient dependency
- Wrap model-list-item renders in QueryClientProvider
- Migrate PluginVersionPicker from PortalToFollowElem to Popover
- Migrate UpdatePluginModal from Modal to Dialog
- Update version picker offset props (sideOffset/alignOffset)
2026-03-04 22:52:21 +08:00
yyh
22a4100dd7 fix(web): invalidate plugin checkInstalled cache after version updates 2026-03-04 22:33:17 +08:00
yyh
0f7ed6f67e refactor(web): align provider badges with figma and remove dead add-model-button 2026-03-04 22:29:51 +08:00
yyh
4d9fcbec57 refactor(web): migrate remove-plugin dialog to base UI AlertDialog and improve UX
- Replace deprecated Confirm component with AlertDialog primitives
- Add forceRender backdrop for proper overlay rendering
- Add success Toast notification after plugin removal
- Update "View Detail" text to "View on Marketplace" (en/zh-Hans)
- Add i18n keys for delete success message
- Prune stale eslint suppression for header-modals
2026-03-04 22:14:19 +08:00
yyh
4d7a9bc798 fix(web): align model provider cache invalidation with oRPC keys 2026-03-04 22:06:27 +08:00
yyh
d6d04ed657 fix 2026-03-04 22:03:06 +08:00
yyh
f594a71dae fix: icon 2026-03-04 22:02:36 +08:00
yyh
04e0ab7eda refactor(web): migrate provider-added-card model list to oRPC query-driven state 2026-03-04 21:55:34 +08:00
yyh
784bda9c86 refactor(web): migrate operation-dropdown to base UI and align provider card styles with Figma
- Migrate OperationDropdown from legacy portal-to-follow-elem to base UI DropdownMenu primitives
- Add placement, sideOffset, alignOffset, popupClassName props for flexible positioning
- Fix version badge font size: system-2xs-medium-uppercase (10px) → system-xs-medium-uppercase (12px)
- Set provider card dropdown to bottom-start placement with 192px width per Figma spec
- Fix PluginVersionPicker toggle: clicking badge now opens and closes the picker
- Add max-h-[224px] overflow scroll to version list
- Replace Remix icon imports with Tailwind CSS icon classes
- Prune stale eslint suppressions for migrated files
2026-03-04 21:55:23 +08:00
yyh
1af1fb6913 feat(web): add version badge and actions menu to provider cards
Integrate plugin version management into model provider cards by
reusing existing plugin detail panel hooks and components. Batch
query installed plugins at list level to avoid N+1 requests.
2026-03-04 21:29:52 +08:00
yyh
1f0c36e9f7 fix: style 2026-03-04 21:07:42 +08:00
yyh
455ae65025 fix: style 2026-03-04 20:58:14 +08:00
yyh
d44682e957 refactor(web): align quota panel with Figma design and migrate to base UI tooltip
- Rename title from "Quota" to "AI Credits" and update tooltip copy
  (Message Credits → AI Credits, free → Trial)
- Show "Credits exhausted" in destructive text when credits reach zero
  instead of displaying the number "0"
- Migrate from deprecated Tooltip to base UI Tooltip compound component
- Add 4px grid background with radial fade mask via CSS module
- Simplify provider icon tooltip text for uninstalled state
- Update i18n keys for both en-US and zh-Hans
2026-03-04 20:52:30 +08:00
yyh
8c4afc0c18 fix(model-selector): align empty trigger with default trigger style 2026-03-04 20:14:49 +08:00
yyh
539cbcae6a fix(account-settings): render nested system model backdrop via base ui 2026-03-04 19:57:53 +08:00
yyh
8d257fea7c chore(web): commit dialog overlay follow-up changes 2026-03-04 19:37:10 +08:00
yyh
c3364ac350 refactor(web): align account settings dialogs with base UI 2026-03-04 19:31:14 +08:00
yyh
f991644989 refactor(pricing): migrate to base ui dialog and extract category types 2026-03-04 19:26:54 +08:00
yyh
29e344ac8b temp: remove cloud condition 2026-03-04 18:50:38 +08:00
yyh
1ad9305732 fix(web): avoid quota panel flicker on account-setting tab switch
- remove mount-time workspace invalidate in model provider page

- read quota with useCurrentWorkspace and keep loading only for initial empty fetch

- reuse existing useSystemFeaturesQuery for marketplace and trial models

- update model provider and quota panel tests for new query/loading behavior
2026-03-04 18:43:01 +08:00
yyh
17f38f171d lint 2026-03-04 18:21:59 +08:00
yyh
802088c8eb test(web): fix trivial assertion and add useInvalidateDefaultModel tests
Replace the no-provider test assertion from checking a nonexistent i18n
key to verifying actual warning keys are absent. Add unit tests for
useInvalidateDefaultModel following the useUpdateModelList pattern.
2026-03-04 17:51:20 +08:00
yyh
cad6d94491 refactor(web): replace remixicon imports with Tailwind CSS icons in system-model-selector 2026-03-04 17:45:41 +08:00
yyh
621d0fb2c9 fix 2026-03-04 17:42:34 +08:00
yyh
efdd88f78a Merge branch 'feat/system-model-settings-4-states' into deploy/dev
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
2026-03-04 17:40:55 +08:00
yyh
a92fb3244b fix(web): skip top warning for no-provider state and remove unused i18n key
The empty state card below already prompts users to install a provider,
so the top warning bar is redundant for the no-provider case. Remove
the unused noProviderInstalled i18n key and replace the lookup map with
a ternary to preserve i18n literal types without assertions.
2026-03-04 17:39:49 +08:00
yyh
97508f8d7b fix(web): invalidate default model cache after saving system model settings
After saving system models, only the model list cache was invalidated
but not the default model cache, causing stale config status in the UI.
Add useInvalidateDefaultModel hook and call it for all 5 model types
after a successful save.
2026-03-04 17:26:24 +08:00
yyh
b2f84bf081 Merge branch 'feat/system-model-settings-4-states' into deploy/dev 2026-03-04 16:59:42 +08:00
yyh
70e677a6ac feat(web): refine system model settings to 4 distinct config states
Replace the single `defaultModelNotConfigured` boolean with a derived
`systemModelConfigStatus` that distinguishes between no-provider,
none-configured, partially-configured, and fully-configured states,
each showing a context-appropriate warning message. Also updates the
button label from "System Model Settings" to "Default Model Settings"
and migrates remixicon imports to Tailwind CSS icon classes.
2026-03-04 16:58:46 +08:00
Yansong Zhang
2330aac623 fix comment
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
2026-03-04 12:02:38 +08:00
Yansong Zhang
97769c5c7a Merge branch 'feat/notification' of github.com:langgenius/dify into feat/notification 2026-03-04 11:39:26 +08:00
Yansong Zhang
09ae3a9b52 return array 2026-03-04 11:38:45 +08:00
autofix-ci[bot]
9589bba713 [autofix.ci] apply automated fixes 2026-03-02 07:00:50 +00:00
Yansong Zhang
6473c1419b Merge remote-tracking branch 'origin/main' into feat/notification
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
2026-03-02 14:58:00 +08:00
Yansong Zhang
d1a0b9695c fix linter 2026-03-02 14:55:24 +08:00
Yansong Zhang
3147e44a0b add notification use saas 2026-03-02 14:55:15 +08:00
autofix-ci[bot]
c243e91668 [autofix.ci] apply automated fixes 2026-02-10 08:30:13 +00:00
Yansong Zhang
004fbbe52b add notification logic for backend 2026-02-10 16:13:06 +08:00
Yansong Zhang
63fb0ddde5 add notification logic for backend 2026-02-10 16:12:59 +08:00
131 changed files with 5966 additions and 2390 deletions

View File

@@ -29,7 +29,7 @@ The codebase is split into:
## Language Style
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation.
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
## General Practices

View File

@@ -2668,3 +2668,68 @@ def clean_expired_messages(
raise
click.echo(click.style("messages cleanup completed.", fg="green"))
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
@click.option("--app-id", required=True, help="Application ID to export messages for.")
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
help="Upper bound (exclusive) for created_at.",
)
@click.option("--filename", required=True, help="Output filename (local path or cloud storage key).")
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
def export_app_messages(
app_id: str,
start_from: datetime.datetime | None,
end_before: datetime.datetime,
filename: str,
use_cloud_storage: bool,
batch_size: int,
dry_run: bool,
):
if start_from and start_from >= end_before:
raise click.UsageError("--start-from must be before --end-before.")
from services.retention.conversation.message_export_service import AppMessageExportService
click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
start_at = time.perf_counter()
try:
service = AppMessageExportService(
app_id=app_id,
end_before=end_before,
filename=filename,
start_from=start_from,
batch_size=batch_size,
use_cloud_storage=use_cloud_storage,
dry_run=dry_run,
)
stats = service.run()
elapsed = time.perf_counter() - start_at
click.echo(
click.style(
f"export_app_messages: completed in {elapsed:.2f}s\n"
f" - Batches: {stats.batches}\n"
f" - Total messages: {stats.total_messages}\n"
f" - Messages with feedback: {stats.messages_with_feedback}\n"
f" - Total feedbacks: {stats.total_feedbacks}",
fg="green",
)
)
except Exception as e:
elapsed = time.perf_counter() - start_at
logger.exception("export_app_messages failed")
click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
raise

View File

@@ -39,6 +39,7 @@ from . import (
feature,
human_input_form,
init_validate,
notification,
ping,
setup,
spec,
@@ -184,6 +185,7 @@ __all__ = [
"model_config",
"model_providers",
"models",
"notification",
"oauth",
"oauth_server",
"ops_trace",

View File

@@ -1,3 +1,5 @@
import csv
import io
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
@@ -6,7 +8,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound, Unauthorized
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from configs import dify_config
from constants.languages import supported_language
@@ -16,6 +18,7 @@ from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from services.billing_service import BillingService
P = ParamSpec("P")
R = TypeVar("R")
@@ -277,3 +280,168 @@ class DeleteExploreBannerApi(Resource):
db.session.commit()
return {"result": "success"}, 204
class LangContentPayload(BaseModel):
lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'")
title: str = Field(...)
subtitle: str | None = Field(default=None)
body: str = Field(...)
title_pic_url: str | None = Field(default=None)
class UpsertNotificationPayload(BaseModel):
notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update")
contents: list[LangContentPayload] = Field(..., min_length=1)
start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z")
end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z")
frequency: str = Field(default="once", description="'once' | 'every_page_load'")
status: str = Field(default="active", description="'active' | 'inactive'")
class BatchAddNotificationAccountsPayload(BaseModel):
notification_id: str = Field(...)
user_email: list[str] = Field(..., description="List of account email addresses")
console_ns.schema_model(
UpsertNotificationPayload.__name__,
UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
BatchAddNotificationAccountsPayload.__name__,
BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/admin/upsert_notification")
class UpsertNotificationApi(Resource):
@console_ns.doc("upsert_notification")
@console_ns.doc(
description=(
"Create or update an in-product notification. "
"Supply notification_id to update an existing one; omit it to create a new one. "
"Pass at least one language variant in contents (zh / en / jp)."
)
)
@console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__])
@console_ns.response(200, "Notification upserted successfully")
@only_edition_cloud
@admin_required
def post(self):
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
result = BillingService.upsert_notification(
contents=[c.model_dump() for c in payload.contents],
frequency=payload.frequency,
status=payload.status,
notification_id=payload.notification_id,
start_time=payload.start_time,
end_time=payload.end_time,
)
return {"result": "success", "notification_id": result.get("notificationId")}, 200
@console_ns.route("/admin/batch_add_notification_accounts")
class BatchAddNotificationAccountsApi(Resource):
@console_ns.doc("batch_add_notification_accounts")
@console_ns.doc(
description=(
"Register target accounts for a notification by email address. "
'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. '
"File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) "
"plus a 'notification_id' field. "
"Emails that do not match any account are silently skipped."
)
)
@console_ns.response(200, "Accounts added successfully")
@only_edition_cloud
@admin_required
def post(self):
from models.account import Account
if "file" in request.files:
notification_id = request.form.get("notification_id", "").strip()
if not notification_id:
raise BadRequest("notification_id is required.")
emails = self._parse_emails_from_file()
else:
payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload)
notification_id = payload.notification_id
emails = payload.user_email
if not emails:
raise BadRequest("No valid email addresses provided.")
# Resolve emails → account IDs in chunks to avoid large IN-clause
account_ids: list[str] = []
chunk_size = 500
for i in range(0, len(emails), chunk_size):
chunk = emails[i : i + chunk_size]
rows = db.session.execute(select(Account.id, Account.email).where(Account.email.in_(chunk))).all()
account_ids.extend(str(row.id) for row in rows)
if not account_ids:
raise BadRequest("None of the provided emails matched an existing account.")
# Send to dify-saas in batches of 1000
total_count = 0
batch_size = 1000
for i in range(0, len(account_ids), batch_size):
batch = account_ids[i : i + batch_size]
result = BillingService.batch_add_notification_accounts(
notification_id=notification_id,
account_ids=batch,
)
total_count += result.get("count", 0)
return {
"result": "success",
"emails_provided": len(emails),
"accounts_matched": len(account_ids),
"count": total_count,
}, 200
@staticmethod
def _parse_emails_from_file() -> list[str]:
"""Parse email addresses from an uploaded CSV or TXT file."""
file = request.files["file"]
if not file.filename:
raise BadRequest("Uploaded file has no filename.")
filename_lower = file.filename.lower()
if not filename_lower.endswith((".csv", ".txt")):
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
try:
content = file.read().decode("utf-8")
except UnicodeDecodeError:
try:
file.seek(0)
content = file.read().decode("gbk")
except UnicodeDecodeError:
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
emails: list[str] = []
if filename_lower.endswith(".csv"):
reader = csv.reader(io.StringIO(content))
for row in reader:
for cell in row:
cell = cell.strip()
if cell:
emails.append(cell)
else:
for line in content.splitlines():
line = line.strip()
if line:
emails.append(line)
# Deduplicate while preserving order
seen: set[str] = set()
unique_emails: list[str] = []
for email in emails:
if email.lower() not in seen:
seen.add(email.lower())
unique_emails.append(email)
return unique_emails

View File

@@ -0,0 +1,108 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
# Notification content is stored under three lang tags.
_FALLBACK_LANG = "en"
# Maps dify interface_language prefixes to notification lang tags.
# Any unrecognised prefix falls back to _FALLBACK_LANG.
_LANG_MAP: dict[str, str] = {
"zh": "zh",
"ja": "jp",
}
def _resolve_lang(interface_language: str | None) -> str:
"""Derive the notification lang tag from the user's interface_language.
e.g. "zh-Hans""zh", "ja-JP""jp", "en-US" / None → "en"
"""
if not interface_language:
return _FALLBACK_LANG
prefix = interface_language.split("-")[0].lower()
return _LANG_MAP.get(prefix, _FALLBACK_LANG)
def _pick_lang_content(contents: dict, lang: str) -> dict:
"""Return the single LangContent for *lang*, falling back to English."""
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
class DismissNotificationPayload(BaseModel):
notification_id: str = Field(...)
@console_ns.route("/notification")
class NotificationApi(Resource):
@console_ns.doc("get_notification")
@console_ns.doc(
description=(
"Return the active in-product notification for the current user "
"in their interface language (falls back to English if unavailable). "
"The notification is NOT marked as seen here; call POST /notification/dismiss "
"when the user explicitly closes the modal."
),
responses={
200: "Success — inspect should_show to decide whether to render the modal",
401: "Unauthorized",
},
)
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, _ = current_account_with_tenant()
result = BillingService.get_account_notification(str(current_user.id))
# Proto JSON uses camelCase field names (Kratos default marshaling).
if not result.get("shouldShow"):
return {"should_show": False, "notifications": []}, 200
lang = _resolve_lang(current_user.interface_language)
notifications = []
for notification in result.get("notifications") or []:
contents: dict = notification.get("contents") or {}
lang_content = _pick_lang_content(contents, lang)
notifications.append(
{
"notification_id": notification.get("notificationId"),
"frequency": notification.get("frequency"),
"lang": lang_content.get("lang", lang),
"title": lang_content.get("title", ""),
"subtitle": lang_content.get("subtitle", ""),
"body": lang_content.get("body", ""),
"title_pic_url": lang_content.get("titlePicUrl", ""),
}
)
return {"should_show": bool(notifications), "notifications": notifications}, 200
@console_ns.route("/notification/dismiss")
class NotificationDismissApi(Resource):
@console_ns.doc("dismiss_notification")
@console_ns.doc(
description="Mark a notification as dismissed for the current user.",
responses={200: "Success", 401: "Unauthorized"},
)
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def post(self):
current_user, _ = current_account_with_tenant()
payload = DismissNotificationPayload.model_validate(request.get_json())
BillingService.dismiss_notification(
notification_id=payload.notification_id,
account_id=str(current_user.id),
)
return {"result": "success"}, 200

View File

@@ -13,6 +13,7 @@ def init_app(app: DifyApp):
convert_to_agent_apps,
create_tenant,
delete_archived_workflow_runs,
export_app_messages,
extract_plugins,
extract_unique_plugins,
file_usage,
@@ -66,6 +67,7 @@ def init_app(app: DifyApp):
restore_workflow_runs,
clean_workflow_runs,
clean_expired_messages,
export_app_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@@ -393,3 +393,78 @@ class BillingService:
for item in data:
tenant_whitelist.append(item["tenant_id"])
return tenant_whitelist
@classmethod
def get_account_notification(cls, account_id: str) -> dict:
"""Return the active in-product notification for account_id, if any.
Calling this endpoint also marks the notification as seen; subsequent
calls will return should_show=false when frequency='once'.
Response shape (mirrors GetAccountNotificationReply):
{
"should_show": bool,
"notification": { # present only when should_show=true
"notification_id": str,
"contents": { # lang -> LangContent
"en": {"lang": "en", "title": ..., "body": ..., "cta_label": ..., "cta_url": ...},
...
},
"frequency": "once" | "every_page_load"
}
}
"""
return cls._send_request("GET", "/notifications/active", params={"account_id": account_id})
@classmethod
def upsert_notification(
cls,
contents: list[dict],
frequency: str = "once",
status: str = "active",
notification_id: str | None = None,
start_time: str | None = None,
end_time: str | None = None,
) -> dict:
"""Create or update a notification.
contents: list of {"lang": str, "title": str, "subtitle": str, "body": str, "title_pic_url": str}
start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional.
Returns {"notification_id": str}.
"""
payload: dict = {
"contents": contents,
"frequency": frequency,
"status": status,
}
if notification_id:
payload["notification_id"] = notification_id
if start_time:
payload["start_time"] = start_time
if end_time:
payload["end_time"] = end_time
return cls._send_request("POST", "/notifications", json=payload)
@classmethod
def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict:
"""Register target account IDs for a notification (max 1000 per call).
Returns {"count": int}.
"""
return cls._send_request(
"POST",
f"/notifications/{notification_id}/accounts",
json={"account_ids": account_ids},
)
@classmethod
def dismiss_notification(cls, notification_id: str, account_id: str) -> dict:
"""Mark a notification as dismissed for an account.
Returns {"success": bool}.
"""
return cls._send_request(
"POST",
f"/notifications/{notification_id}/dismiss",
json={"account_id": account_id},
)

View File

@@ -0,0 +1,258 @@
"""
Export app messages to JSONL.GZ format.
Outputs: conversation_id, message_id, query, answer, inputs (raw JSON),
retriever_resources (from message_metadata), feedback (user feedbacks array).
Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1.
Does NOT touch Message.inputs / Message.user_feedback properties.
"""
import datetime
import gzip
import json
import logging
import tempfile
from collections import defaultdict
from collections.abc import Generator, Iterable
from typing import Any, BinaryIO, cast
import orjson
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select, tuple_
from sqlalchemy.orm import Session
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import Message, MessageFeedback
logger = logging.getLogger(__name__)
class AppMessageExportFeedback(BaseModel):
id: str
app_id: str
conversation_id: str
message_id: str
rating: str
content: str | None = None
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
created_at: str
updated_at: str
model_config = ConfigDict(extra="forbid")
class AppMessageExportRecord(BaseModel):
conversation_id: str
message_id: str
query: str
answer: str
inputs: dict[str, Any]
retriever_resources: list[Any] = Field(default_factory=list)
feedback: list[AppMessageExportFeedback] = Field(default_factory=list)
model_config = ConfigDict(extra="forbid")
class AppMessageExportStats(BaseModel):
batches: int = 0
total_messages: int = 0
messages_with_feedback: int = 0
total_feedbacks: int = 0
model_config = ConfigDict(extra="forbid")
class AppMessageExportService:
def __init__(
self,
app_id: str,
end_before: datetime.datetime,
filename: str,
*,
start_from: datetime.datetime | None = None,
batch_size: int = 1000,
use_cloud_storage: bool = False,
dry_run: bool = False,
) -> None:
if start_from and start_from >= end_before:
raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})")
self._app_id = app_id
self._end_before = end_before
self._start_from = start_from
self._filename = filename
self._batch_size = batch_size
self._use_cloud_storage = use_cloud_storage
self._dry_run = dry_run
def run(self) -> AppMessageExportStats:
stats = AppMessageExportStats()
logger.info(
"export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s",
self._app_id,
self._start_from,
self._end_before,
self._dry_run,
self._use_cloud_storage,
)
if self._dry_run:
for _ in self._iter_records_with_stats(stats):
pass
self._finalize_stats(stats)
return stats
if self._use_cloud_storage:
self._export_to_cloud(stats)
else:
self._export_to_local(stats)
self._finalize_stats(stats)
return stats
def iter_records(self) -> Generator[AppMessageExportRecord, None, None]:
for batch in self._iter_record_batches():
yield from batch
@staticmethod
def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None:
with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz:
for record in records:
gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n")
def _export_to_local(self, stats: AppMessageExportStats) -> None:
with open(self._filename, "wb") as output_file:
self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file)
def _export_to_cloud(self, stats: AppMessageExportStats) -> None:
with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp:
self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp))
tmp.seek(0)
data = tmp.read()
storage.save(self._filename, data)
logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self._filename)
def _iter_records_with_stats(
self, stats: AppMessageExportStats
) -> Generator[AppMessageExportRecord, None, None]:
for record in self.iter_records():
self._update_stats(stats, record)
yield record
@staticmethod
def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None:
stats.total_messages += 1
if record.feedback:
stats.messages_with_feedback += 1
stats.total_feedbacks += len(record.feedback)
def _finalize_stats(self, stats: AppMessageExportStats) -> None:
if stats.total_messages == 0:
stats.batches = 0
return
stats.batches = (stats.total_messages + self._batch_size - 1) // self._batch_size
def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]:
cursor: tuple[datetime.datetime, str] | None = None
while True:
rows, cursor = self._fetch_batch(cursor)
if not rows:
break
message_ids = [str(row.id) for row in rows]
feedbacks_map = self._fetch_feedbacks(message_ids)
yield [self._build_record(row, feedbacks_map) for row in rows]
def _fetch_batch(
self, cursor: tuple[datetime.datetime, str] | None
) -> tuple[list[Any], tuple[datetime.datetime, str] | None]:
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(
Message.id,
Message.conversation_id,
Message.query,
Message.answer,
Message._inputs, # pyright: ignore[reportPrivateUsage]
Message.message_metadata,
Message.created_at,
)
.where(
Message.app_id == self._app_id,
Message.created_at < self._end_before,
)
.order_by(Message.created_at, Message.id)
.limit(self._batch_size)
)
if self._start_from:
stmt = stmt.where(Message.created_at >= self._start_from)
if cursor:
stmt = stmt.where(
tuple_(Message.created_at, Message.id)
> tuple_(
sa.literal(cursor[0], type_=sa.DateTime()),
sa.literal(cursor[1], type_=Message.id.type),
)
)
rows = list(session.execute(stmt).all())
if not rows:
return [], cursor
last = rows[-1]
return rows, (last.created_at, last.id)
def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]:
if not message_ids:
return {}
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(MessageFeedback)
.where(
MessageFeedback.message_id.in_(message_ids),
MessageFeedback.from_source == "user",
)
.order_by(MessageFeedback.message_id, MessageFeedback.created_at)
)
feedbacks = list(session.scalars(stmt).all())
result: dict[str, list[AppMessageExportFeedback]] = defaultdict(list)
for feedback in feedbacks:
result[str(feedback.message_id)].append(AppMessageExportFeedback.model_validate(feedback.to_dict()))
return result
@staticmethod
def _build_record(
row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]
) -> AppMessageExportRecord:
retriever_resources: list[Any] = []
if row.message_metadata:
try:
metadata = json.loads(row.message_metadata)
value = metadata.get("retriever_resources", [])
if isinstance(value, list):
retriever_resources = value
except (json.JSONDecodeError, TypeError):
pass
message_id = str(row.id)
return AppMessageExportRecord(
conversation_id=str(row.conversation_id),
message_id=message_id,
query=row.query,
answer=row.answer,
inputs=row._inputs if isinstance(row._inputs, dict) else {},
retriever_resources=retriever_resources,
feedback=feedbacks_map.get(message_id, []),
)

View File

@@ -0,0 +1,252 @@
"""Container-backed integration tests for DocumentService.rename_document real SQL paths."""
import datetime
import json
from unittest.mock import create_autospec, patch
from uuid import uuid4
import pytest
from models import Account
from models.dataset import Dataset, Document
from models.enums import CreatorUserRole
from models.model import UploadFile
from services.dataset_service import DocumentService
FIXED_UPLOAD_CREATED_AT = datetime.datetime(2024, 1, 1, 0, 0, 0)
@pytest.fixture
def mock_env():
"""Patch only non-SQL dependency used by rename_document: current_user context."""
with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user:
current_user.current_tenant_id = str(uuid4())
current_user.id = str(uuid4())
yield {"current_user": current_user}
def make_dataset(db_session_with_containers, dataset_id=None, tenant_id=None, built_in_field_enabled=False):
"""Persist a dataset row for rename_document integration scenarios."""
dataset_id = dataset_id or str(uuid4())
tenant_id = tenant_id or str(uuid4())
dataset = Dataset(
tenant_id=tenant_id,
name=f"dataset-{uuid4()}",
data_source_type="upload_file",
created_by=str(uuid4()),
)
dataset.id = dataset_id
dataset.built_in_field_enabled = built_in_field_enabled
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
def make_document(
db_session_with_containers,
document_id=None,
dataset_id=None,
tenant_id=None,
name="Old Name",
data_source_info=None,
doc_metadata=None,
):
"""Persist a document row used by rename_document integration scenarios."""
document_id = document_id or str(uuid4())
dataset_id = dataset_id or str(uuid4())
tenant_id = tenant_id or str(uuid4())
doc = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
data_source_info=json.dumps(data_source_info or {}),
batch=f"batch-{uuid4()}",
name=name,
created_from="web",
created_by=str(uuid4()),
doc_form="text_model",
)
doc.id = document_id
doc.indexing_status = "completed"
doc.doc_metadata = dict(doc_metadata or {})
db_session_with_containers.add(doc)
db_session_with_containers.commit()
return doc
def make_upload_file(db_session_with_containers, tenant_id: str, file_id: str, name: str):
"""Persist an upload file row referenced by document.data_source_info."""
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
key=f"uploads/{uuid4()}",
name=name,
size=128,
extension="pdf",
mime_type="application/pdf",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
created_at=FIXED_UPLOAD_CREATED_AT,
used=False,
)
upload_file.id = file_id
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
return upload_file
def test_rename_document_success(db_session_with_containers, mock_env):
"""Rename succeeds and returns the renamed document identity by id."""
# Arrange
dataset_id = str(uuid4())
document_id = str(uuid4())
new_name = "New Document Name"
dataset = make_dataset(db_session_with_containers, dataset_id, mock_env["current_user"].current_tenant_id)
document = make_document(
db_session_with_containers,
document_id=document_id,
dataset_id=dataset_id,
tenant_id=mock_env["current_user"].current_tenant_id,
)
# Act
result = DocumentService.rename_document(dataset.id, document_id, new_name)
# Assert
db_session_with_containers.refresh(document)
assert result.id == document.id
assert document.name == new_name
def test_rename_document_with_built_in_fields(db_session_with_containers, mock_env):
"""Built-in document_name metadata is updated while existing metadata keys are preserved."""
# Arrange
dataset_id = str(uuid4())
document_id = str(uuid4())
new_name = "Renamed"
dataset = make_dataset(
db_session_with_containers,
dataset_id,
mock_env["current_user"].current_tenant_id,
built_in_field_enabled=True,
)
document = make_document(
db_session_with_containers,
document_id=document_id,
dataset_id=dataset.id,
tenant_id=mock_env["current_user"].current_tenant_id,
doc_metadata={"foo": "bar"},
)
# Act
DocumentService.rename_document(dataset.id, document.id, new_name)
# Assert
db_session_with_containers.refresh(document)
assert document.name == new_name
assert document.doc_metadata["document_name"] == new_name
assert document.doc_metadata["foo"] == "bar"
def test_rename_document_updates_upload_file_when_present(db_session_with_containers, mock_env):
"""Rename propagates to UploadFile.name when upload_file_id is present in data_source_info."""
# Arrange
dataset_id = str(uuid4())
document_id = str(uuid4())
file_id = str(uuid4())
new_name = "Renamed"
dataset = make_dataset(db_session_with_containers, dataset_id, mock_env["current_user"].current_tenant_id)
document = make_document(
db_session_with_containers,
document_id=document_id,
dataset_id=dataset.id,
tenant_id=mock_env["current_user"].current_tenant_id,
data_source_info={"upload_file_id": file_id},
)
upload_file = make_upload_file(
db_session_with_containers,
tenant_id=mock_env["current_user"].current_tenant_id,
file_id=file_id,
name="old.pdf",
)
# Act
DocumentService.rename_document(dataset.id, document.id, new_name)
# Assert
db_session_with_containers.refresh(document)
db_session_with_containers.refresh(upload_file)
assert document.name == new_name
assert upload_file.name == new_name
def test_rename_document_does_not_update_upload_file_when_missing_id(db_session_with_containers, mock_env):
"""Rename does not update UploadFile when data_source_info lacks upload_file_id."""
# Arrange
dataset_id = str(uuid4())
document_id = str(uuid4())
new_name = "Another Name"
dataset = make_dataset(db_session_with_containers, dataset_id, mock_env["current_user"].current_tenant_id)
document = make_document(
db_session_with_containers,
document_id=document_id,
dataset_id=dataset.id,
tenant_id=mock_env["current_user"].current_tenant_id,
data_source_info={"url": "https://example.com"},
)
untouched_file = make_upload_file(
db_session_with_containers,
tenant_id=mock_env["current_user"].current_tenant_id,
file_id=str(uuid4()),
name="untouched.pdf",
)
# Act
DocumentService.rename_document(dataset.id, document.id, new_name)
# Assert
db_session_with_containers.refresh(document)
db_session_with_containers.refresh(untouched_file)
assert document.name == new_name
assert untouched_file.name == "untouched.pdf"
def test_rename_document_dataset_not_found(db_session_with_containers, mock_env):
"""Rename raises Dataset not found when dataset id does not exist."""
# Arrange
missing_dataset_id = str(uuid4())
# Act / Assert
with pytest.raises(ValueError, match="Dataset not found"):
DocumentService.rename_document(missing_dataset_id, str(uuid4()), "x")
def test_rename_document_not_found(db_session_with_containers, mock_env):
"""Rename raises Document not found when document id is absent in the dataset."""
# Arrange
dataset = make_dataset(db_session_with_containers, str(uuid4()), mock_env["current_user"].current_tenant_id)
# Act / Assert
with pytest.raises(ValueError, match="Document not found"):
DocumentService.rename_document(dataset.id, str(uuid4()), "x")
def test_rename_document_permission_denied_when_tenant_mismatch(db_session_with_containers, mock_env):
"""Rename raises No permission when document tenant differs from current_user tenant."""
# Arrange
dataset = make_dataset(db_session_with_containers, str(uuid4()), mock_env["current_user"].current_tenant_id)
document = make_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=str(uuid4()),
)
# Act / Assert
with pytest.raises(ValueError, match="No permission"):
DocumentService.rename_document(dataset.id, document.id, "x")

View File

@@ -0,0 +1,233 @@
import datetime
import json
import uuid
from decimal import Decimal
import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import (
App,
AppAnnotationHitHistory,
Conversation,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats
class TestAppMessageExportServiceIntegration:
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers: Session):
yield
db_session_with_containers.query(DatasetRetrieverResource).delete()
db_session_with_containers.query(AppAnnotationHitHistory).delete()
db_session_with_containers.query(SavedMessage).delete()
db_session_with_containers.query(MessageFile).delete()
db_session_with_containers.query(MessageAgentThought).delete()
db_session_with_containers.query(MessageChain).delete()
db_session_with_containers.query(MessageAnnotation).delete()
db_session_with_containers.query(MessageFeedback).delete()
db_session_with_containers.query(Message).delete()
db_session_with_containers.query(Conversation).delete()
db_session_with_containers.query(App).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
@staticmethod
def _create_app_context(session: Session) -> tuple[App, Conversation]:
account = Account(
email=f"test-{uuid.uuid4()}@example.com",
name="tester",
interface_language="en-US",
status="active",
)
session.add(account)
session.flush()
tenant = Tenant(name=f"tenant-{uuid.uuid4()}", status="normal")
session.add(tenant)
session.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
session.add(join)
session.flush()
app = App(
tenant_id=tenant.id,
name="export-app",
description="integration test app",
mode="chat",
enable_site=True,
enable_api=True,
api_rpm=60,
api_rph=3600,
is_demo=False,
is_public=False,
created_by=account.id,
updated_by=account.id,
)
session.add(app)
session.flush()
conversation = Conversation(
app_id=app.id,
app_model_config_id=str(uuid.uuid4()),
model_provider="openai",
model_id="gpt-4o-mini",
mode="chat",
name="conv",
inputs={"seed": 1},
status="normal",
from_source="api",
from_end_user_id=str(uuid.uuid4()),
)
session.add(conversation)
session.commit()
return app, conversation
@staticmethod
def _create_message(
session: Session,
app: App,
conversation: Conversation,
created_at: datetime.datetime,
*,
query: str,
answer: str,
inputs: dict,
message_metadata: str | None,
) -> Message:
message = Message(
app_id=app.id,
conversation_id=conversation.id,
model_provider="openai",
model_id="gpt-4o-mini",
inputs=inputs,
query=query,
answer=answer,
message=[{"role": "assistant", "content": answer}],
message_tokens=10,
message_unit_price=Decimal("0.001"),
answer_tokens=20,
answer_unit_price=Decimal("0.002"),
total_price=Decimal("0.003"),
currency="USD",
message_metadata=message_metadata,
from_source="api",
from_end_user_id=conversation.from_end_user_id,
created_at=created_at,
)
session.add(message)
session.flush()
return message
def test_iter_records_with_stats(self, db_session_with_containers: Session):
app, conversation = self._create_app_context(db_session_with_containers)
first_inputs = {
"plain": "v1",
"nested": {"a": 1, "b": [1, {"x": True}]},
"list": ["x", 2, {"y": "z"}],
}
second_inputs = {"other": "value", "items": [1, 2, 3]}
base_time = datetime.datetime(2026, 2, 25, 10, 0, 0)
first_message = self._create_message(
db_session_with_containers,
app,
conversation,
created_at=base_time,
query="q1",
answer="a1",
inputs=first_inputs,
message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}),
)
second_message = self._create_message(
db_session_with_containers,
app,
conversation,
created_at=base_time + datetime.timedelta(minutes=1),
query="q2",
answer="a2",
inputs=second_inputs,
message_metadata=None,
)
user_feedback_1 = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="user",
content="first",
from_end_user_id=conversation.from_end_user_id,
)
user_feedback_2 = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="dislike",
from_source="user",
content="second",
from_end_user_id=conversation.from_end_user_id,
)
admin_feedback = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="admin",
content="should-be-filtered",
from_account_id=str(uuid.uuid4()),
)
db_session_with_containers.add_all([user_feedback_1, user_feedback_2, admin_feedback])
user_feedback_1.created_at = base_time + datetime.timedelta(minutes=2)
user_feedback_2.created_at = base_time + datetime.timedelta(minutes=3)
admin_feedback.created_at = base_time + datetime.timedelta(minutes=4)
db_session_with_containers.commit()
service = AppMessageExportService(
app_id=app.id,
start_from=base_time - datetime.timedelta(minutes=1),
end_before=base_time + datetime.timedelta(minutes=10),
filename="unused.jsonl.gz",
batch_size=1,
dry_run=True,
)
stats = AppMessageExportStats()
records = list(service._iter_records_with_stats(stats))
service._finalize_stats(stats)
assert len(records) == 2
assert records[0].message_id == first_message.id
assert records[1].message_id == second_message.id
assert records[0].inputs == first_inputs
assert records[1].inputs == second_inputs
assert records[0].retriever_resources == [{"dataset_id": "ds-1"}]
assert records[1].retriever_resources == []
assert [feedback.rating for feedback in records[0].feedback] == ["like", "dislike"]
assert [feedback.content for feedback in records[0].feedback] == ["first", "second"]
assert records[1].feedback == []
assert stats.batches == 2
assert stats.total_messages == 2
assert stats.messages_with_feedback == 1
assert stats.total_feedbacks == 2

View File

@@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.indexing_runner import DocumentIsPausedError
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@@ -282,7 +283,7 @@ class TestDuplicateDocumentIndexingTasks:
return dataset, documents
def test_duplicate_document_indexing_task_success(
def _test_duplicate_document_indexing_task_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
@@ -324,7 +325,7 @@ class TestDuplicateDocumentIndexingTasks:
processed_documents = call_args[0][0] # First argument should be documents list
assert len(processed_documents) == 3
def test_duplicate_document_indexing_task_with_segment_cleanup(
def _test_duplicate_document_indexing_task_with_segment_cleanup(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
@@ -374,7 +375,7 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_duplicate_document_indexing_task_dataset_not_found(
def _test_duplicate_document_indexing_task_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
@@ -445,7 +446,7 @@ class TestDuplicateDocumentIndexingTasks:
processed_documents = call_args[0][0] # First argument should be documents list
assert len(processed_documents) == 2 # Only existing documents
def test_duplicate_document_indexing_task_indexing_runner_exception(
def _test_duplicate_document_indexing_task_indexing_runner_exception(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
@@ -486,7 +487,7 @@ class TestDuplicateDocumentIndexingTasks:
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
def test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit(
def _test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
@@ -549,7 +550,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify indexing runner was not called due to early validation error
mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
def test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded(
def _test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
@@ -783,3 +784,90 @@ class TestDuplicateDocumentIndexingTasks:
document_ids=document_ids,
)
mock_queue.delete_task_key.assert_not_called()
def test_successful_duplicate_document_indexing(
self, db_session_with_containers, mock_external_service_dependencies
):
"""Test successful duplicate document indexing flow."""
self._test_duplicate_document_indexing_task_success(
db_session_with_containers, mock_external_service_dependencies
)
def test_duplicate_document_indexing_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""Test duplicate document indexing when dataset is not found."""
self._test_duplicate_document_indexing_task_dataset_not_found(
db_session_with_containers, mock_external_service_dependencies
)
def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan(
self, db_session_with_containers, mock_external_service_dependencies
):
"""Test duplicate document indexing with billing enabled and sandbox plan."""
self._test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit(
db_session_with_containers, mock_external_service_dependencies
)
def test_duplicate_document_indexing_with_billing_limit_exceeded(
self, db_session_with_containers, mock_external_service_dependencies
):
"""Test duplicate document indexing when billing limit is exceeded."""
self._test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded(
db_session_with_containers, mock_external_service_dependencies
)
def test_duplicate_document_indexing_runner_error(
self, db_session_with_containers, mock_external_service_dependencies
):
"""Test duplicate document indexing when IndexingRunner raises an error."""
self._test_duplicate_document_indexing_task_indexing_runner_exception(
db_session_with_containers, mock_external_service_dependencies
)
def _test_duplicate_document_indexing_task_document_is_paused(
self, db_session_with_containers, mock_external_service_dependencies
):
"""Test duplicate document indexing when document is paused."""
# Arrange
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=2
)
for document in documents:
document.is_paused = True
db_session_with_containers.add(document)
db_session_with_containers.commit()
document_ids = [doc.id for doc in documents]
mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = DocumentIsPausedError(
"Document paused"
)
# Act
_duplicate_document_indexing_task(dataset.id, document_ids)
db_session_with_containers.expire_all()
# Assert
for doc_id in document_ids:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.is_paused is True
assert updated_document.indexing_status == "parsing"
assert updated_document.display_status == "paused"
assert updated_document.processing_started_at is not None
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_duplicate_document_indexing_document_is_paused(
self, db_session_with_containers, mock_external_service_dependencies
):
"""Test duplicate document indexing when document is paused."""
self._test_duplicate_document_indexing_task_document_is_paused(
db_session_with_containers, mock_external_service_dependencies
)
def test_duplicate_document_indexing_cleans_old_segments(
self, db_session_with_containers, mock_external_service_dependencies
):
"""Test that duplicate document indexing cleans old segments."""
self._test_duplicate_document_indexing_task_with_segment_cleanup(
db_session_with_containers, mock_external_service_dependencies
)

View File

@@ -1,176 +0,0 @@
from types import SimpleNamespace
from unittest.mock import Mock, create_autospec, patch
import pytest
from models import Account
from services.dataset_service import DocumentService
@pytest.fixture
def mock_env():
"""Patch dependencies used by DocumentService.rename_document.
Mocks:
- DatasetService.get_dataset
- DocumentService.get_document
- current_user (with current_tenant_id)
- db.session
"""
with (
patch("services.dataset_service.DatasetService.get_dataset") as get_dataset,
patch("services.dataset_service.DocumentService.get_document") as get_document,
patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user,
patch("extensions.ext_database.db.session") as db_session,
):
current_user.current_tenant_id = "tenant-123"
yield {
"get_dataset": get_dataset,
"get_document": get_document,
"current_user": current_user,
"db_session": db_session,
}
def make_dataset(dataset_id="dataset-123", tenant_id="tenant-123", built_in_field_enabled=False):
return SimpleNamespace(id=dataset_id, tenant_id=tenant_id, built_in_field_enabled=built_in_field_enabled)
def make_document(
document_id="document-123",
dataset_id="dataset-123",
tenant_id="tenant-123",
name="Old Name",
data_source_info=None,
doc_metadata=None,
):
doc = Mock()
doc.id = document_id
doc.dataset_id = dataset_id
doc.tenant_id = tenant_id
doc.name = name
doc.data_source_info = data_source_info or {}
# property-like usage in code relies on a dict
doc.data_source_info_dict = dict(doc.data_source_info)
doc.doc_metadata = dict(doc_metadata or {})
return doc
def test_rename_document_success(mock_env):
dataset_id = "dataset-123"
document_id = "document-123"
new_name = "New Document Name"
dataset = make_dataset(dataset_id)
document = make_document(document_id=document_id, dataset_id=dataset_id)
mock_env["get_dataset"].return_value = dataset
mock_env["get_document"].return_value = document
result = DocumentService.rename_document(dataset_id, document_id, new_name)
assert result is document
assert document.name == new_name
mock_env["db_session"].add.assert_called_once_with(document)
mock_env["db_session"].commit.assert_called_once()
def test_rename_document_with_built_in_fields(mock_env):
dataset_id = "dataset-123"
document_id = "document-123"
new_name = "Renamed"
dataset = make_dataset(dataset_id, built_in_field_enabled=True)
document = make_document(document_id=document_id, dataset_id=dataset_id, doc_metadata={"foo": "bar"})
mock_env["get_dataset"].return_value = dataset
mock_env["get_document"].return_value = document
DocumentService.rename_document(dataset_id, document_id, new_name)
assert document.name == new_name
# BuiltInField.document_name == "document_name" in service code
assert document.doc_metadata["document_name"] == new_name
assert document.doc_metadata["foo"] == "bar"
def test_rename_document_updates_upload_file_when_present(mock_env):
dataset_id = "dataset-123"
document_id = "document-123"
new_name = "Renamed"
file_id = "file-123"
dataset = make_dataset(dataset_id)
document = make_document(
document_id=document_id,
dataset_id=dataset_id,
data_source_info={"upload_file_id": file_id},
)
mock_env["get_dataset"].return_value = dataset
mock_env["get_document"].return_value = document
# Intercept UploadFile rename UPDATE chain
mock_query = Mock()
mock_query.where.return_value = mock_query
mock_env["db_session"].query.return_value = mock_query
DocumentService.rename_document(dataset_id, document_id, new_name)
assert document.name == new_name
mock_env["db_session"].query.assert_called() # update executed
def test_rename_document_does_not_update_upload_file_when_missing_id(mock_env):
"""
When data_source_info_dict exists but does not contain "upload_file_id",
UploadFile should not be updated.
"""
dataset_id = "dataset-123"
document_id = "document-123"
new_name = "Another Name"
dataset = make_dataset(dataset_id)
# Ensure data_source_info_dict is truthy but lacks the key
document = make_document(
document_id=document_id,
dataset_id=dataset_id,
data_source_info={"url": "https://example.com"},
)
mock_env["get_dataset"].return_value = dataset
mock_env["get_document"].return_value = document
DocumentService.rename_document(dataset_id, document_id, new_name)
assert document.name == new_name
# Should NOT attempt to update UploadFile
mock_env["db_session"].query.assert_not_called()
def test_rename_document_dataset_not_found(mock_env):
mock_env["get_dataset"].return_value = None
with pytest.raises(ValueError, match="Dataset not found"):
DocumentService.rename_document("missing", "doc", "x")
def test_rename_document_not_found(mock_env):
dataset = make_dataset("dataset-123")
mock_env["get_dataset"].return_value = dataset
mock_env["get_document"].return_value = None
with pytest.raises(ValueError, match="Document not found"):
DocumentService.rename_document(dataset.id, "missing", "x")
def test_rename_document_permission_denied_when_tenant_mismatch(mock_env):
dataset = make_dataset("dataset-123")
# different tenant than current_user.current_tenant_id
document = make_document(dataset_id=dataset.id, tenant_id="tenant-other")
mock_env["get_dataset"].return_value = dataset
mock_env["get_document"].return_value = document
with pytest.raises(ValueError, match="No permission"):
DocumentService.rename_document(dataset.id, document.id, "x")

View File

@@ -1,158 +1,38 @@
"""
Unit tests for duplicate document indexing tasks.
This module tests the duplicate document indexing task functionality including:
- Task enqueuing to different queues (normal, priority, tenant-isolated)
- Batch processing of multiple duplicate documents
- Progress tracking through task lifecycle
- Error handling and retry mechanisms
- Cleanup of old document data before re-indexing
"""
"""Unit tests for queue/wrapper behaviors in duplicate document indexing tasks (non-database logic)."""
import uuid
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import Mock, patch
import pytest
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from models.dataset import Dataset, Document, DocumentSegment
from tasks.duplicate_document_indexing_task import (
_duplicate_document_indexing_task,
_duplicate_document_indexing_task_with_tenant_queue,
duplicate_document_indexing_task,
normal_duplicate_document_indexing_task,
priority_duplicate_document_indexing_task,
)
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def tenant_id():
"""Generate a unique tenant ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def dataset_id():
"""Generate a unique dataset ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def document_ids():
"""Generate a list of document IDs for testing."""
return [str(uuid.uuid4()) for _ in range(3)]
@pytest.fixture
def mock_dataset(dataset_id, tenant_id):
"""Create a mock Dataset object."""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.tenant_id = tenant_id
dataset.indexing_technique = "high_quality"
dataset.embedding_model_provider = "openai"
dataset.embedding_model = "text-embedding-ada-002"
return dataset
@pytest.fixture
def mock_documents(document_ids, dataset_id):
"""Create mock Document objects."""
documents = []
for doc_id in document_ids:
doc = Mock(spec=Document)
doc.id = doc_id
doc.dataset_id = dataset_id
doc.indexing_status = "waiting"
doc.error = None
doc.stopped_at = None
doc.processing_started_at = None
doc.doc_form = "text_model"
documents.append(doc)
return documents
@pytest.fixture
def mock_document_segments(document_ids):
"""Create mock DocumentSegment objects."""
segments = []
for doc_id in document_ids:
for i in range(3):
segment = Mock(spec=DocumentSegment)
segment.id = str(uuid.uuid4())
segment.document_id = doc_id
segment.index_node_id = f"node-{doc_id}-{i}"
segments.append(segment)
return segments
@pytest.fixture
def mock_db_session():
"""Mock database session via session_factory.create_session()."""
with patch("tasks.duplicate_document_indexing_task.session_factory", autospec=True) as mock_sf:
session = MagicMock()
# Allow tests to observe session.close() via context manager teardown
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
session.scalars.return_value = MagicMock()
yield session
@pytest.fixture
def mock_indexing_runner():
"""Mock IndexingRunner."""
with patch("tasks.duplicate_document_indexing_task.IndexingRunner", autospec=True) as mock_runner_class:
mock_runner = MagicMock(spec=IndexingRunner)
mock_runner_class.return_value = mock_runner
yield mock_runner
@pytest.fixture
def mock_feature_service():
"""Mock FeatureService."""
with patch("tasks.duplicate_document_indexing_task.FeatureService", autospec=True) as mock_service:
mock_features = Mock()
mock_features.billing = Mock()
mock_features.billing.enabled = False
mock_features.vector_space = Mock()
mock_features.vector_space.size = 0
mock_features.vector_space.limit = 1000
mock_service.get_features.return_value = mock_features
yield mock_service
@pytest.fixture
def mock_index_processor_factory():
"""Mock IndexProcessorFactory."""
with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory", autospec=True) as mock_factory:
mock_processor = MagicMock()
mock_processor.clean = Mock()
mock_factory.return_value.init_index_processor.return_value = mock_processor
yield mock_factory
@pytest.fixture
def mock_tenant_isolated_queue():
"""Mock TenantIsolatedTaskQueue."""
with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) as mock_queue_class:
mock_queue = MagicMock(spec=TenantIsolatedTaskQueue)
mock_queue = Mock(spec=TenantIsolatedTaskQueue)
mock_queue.pull_tasks.return_value = []
mock_queue.delete_task_key = Mock()
mock_queue.set_task_waiting_time = Mock()
@@ -160,11 +40,6 @@ def mock_tenant_isolated_queue():
yield mock_queue
# ============================================================================
# Tests for deprecated duplicate_document_indexing_task
# ============================================================================
class TestDuplicateDocumentIndexingTask:
"""Tests for the deprecated duplicate_document_indexing_task function."""
@@ -190,258 +65,6 @@ class TestDuplicateDocumentIndexingTask:
mock_core_func.assert_called_once_with(dataset_id, document_ids)
# ============================================================================
# Tests for _duplicate_document_indexing_task core function
# ============================================================================
class TestDuplicateDocumentIndexingTaskCore:
"""Tests for the _duplicate_document_indexing_task core function."""
def test_successful_duplicate_document_indexing(
self,
mock_db_session,
mock_indexing_runner,
mock_feature_service,
mock_index_processor_factory,
mock_dataset,
mock_documents,
mock_document_segments,
dataset_id,
document_ids,
):
"""Test successful duplicate document indexing flow."""
# Arrange
# Dataset via query.first()
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# scalars() call sequence:
# 1) documents list
# 2..N) segments per document
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
# First call returns documents; subsequent calls return segments
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = mock_document_segments
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
# Act
_duplicate_document_indexing_task(dataset_id, document_ids)
# Assert
# Verify IndexingRunner was called
mock_indexing_runner.run.assert_called_once()
# Verify all documents were set to parsing status
for doc in mock_documents:
assert doc.indexing_status == "parsing"
assert doc.processing_started_at is not None
# Verify session operations
assert mock_db_session.commit.called
assert mock_db_session.close.called
def test_duplicate_document_indexing_dataset_not_found(self, mock_db_session, dataset_id, document_ids):
"""Test duplicate document indexing when dataset is not found."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = None
# Act
_duplicate_document_indexing_task(dataset_id, document_ids)
# Assert
# Should close the session at least once
assert mock_db_session.close.called
def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan(
self,
mock_db_session,
mock_feature_service,
mock_dataset,
dataset_id,
document_ids,
):
"""Test duplicate document indexing with billing enabled and sandbox plan."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
mock_features = mock_feature_service.get_features.return_value
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.SANDBOX
# Act
_duplicate_document_indexing_task(dataset_id, document_ids)
# Assert
# For sandbox plan with multiple documents, should fail
mock_db_session.commit.assert_called()
def test_duplicate_document_indexing_with_billing_limit_exceeded(
self,
mock_db_session,
mock_feature_service,
mock_dataset,
mock_documents,
dataset_id,
document_ids,
):
"""Test duplicate document indexing when billing limit is exceeded."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# First scalars() -> documents; subsequent -> empty segments
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_features = mock_feature_service.get_features.return_value
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.TEAM
mock_features.vector_space.size = 990
mock_features.vector_space.limit = 1000
# Act
_duplicate_document_indexing_task(dataset_id, document_ids)
# Assert
# Should commit the session
assert mock_db_session.commit.called
# Should close the session
assert mock_db_session.close.called
def test_duplicate_document_indexing_runner_error(
self,
mock_db_session,
mock_indexing_runner,
mock_feature_service,
mock_index_processor_factory,
mock_dataset,
mock_documents,
dataset_id,
document_ids,
):
"""Test duplicate document indexing when IndexingRunner raises an error."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_indexing_runner.run.side_effect = Exception("Indexing error")
# Act
_duplicate_document_indexing_task(dataset_id, document_ids)
# Assert
# Should close the session even after error
mock_db_session.close.assert_called_once()
def test_duplicate_document_indexing_document_is_paused(
self,
mock_db_session,
mock_indexing_runner,
mock_feature_service,
mock_index_processor_factory,
mock_dataset,
mock_documents,
dataset_id,
document_ids,
):
"""Test duplicate document indexing when document is paused."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
# Act
_duplicate_document_indexing_task(dataset_id, document_ids)
# Assert
# Should handle DocumentIsPausedError gracefully
mock_db_session.close.assert_called_once()
def test_duplicate_document_indexing_cleans_old_segments(
self,
mock_db_session,
mock_indexing_runner,
mock_feature_service,
mock_index_processor_factory,
mock_dataset,
mock_documents,
mock_document_segments,
dataset_id,
document_ids,
):
"""Test that duplicate document indexing cleans old segments."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = mock_document_segments
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
# Act
_duplicate_document_indexing_task(dataset_id, document_ids)
# Assert
# Verify clean was called for each document
assert mock_processor.clean.call_count == len(mock_documents)
# Verify segments were deleted in batch (DELETE FROM document_segments)
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
# ============================================================================
# Tests for tenant queue wrapper function
# ============================================================================
class TestDuplicateDocumentIndexingTaskWithTenantQueue:
"""Tests for _duplicate_document_indexing_task_with_tenant_queue function."""
@@ -536,11 +159,6 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue:
mock_tenant_isolated_queue.pull_tasks.assert_called_once()
# ============================================================================
# Tests for normal_duplicate_document_indexing_task
# ============================================================================
class TestNormalDuplicateDocumentIndexingTask:
"""Tests for normal_duplicate_document_indexing_task function."""
@@ -581,11 +199,6 @@ class TestNormalDuplicateDocumentIndexingTask:
)
# ============================================================================
# Tests for priority_duplicate_document_indexing_task
# ============================================================================
class TestPriorityDuplicateDocumentIndexingTask:
"""Tests for priority_duplicate_document_indexing_task function."""

12
api/uv.lock generated
View File

@@ -441,14 +441,14 @@ wheels = [
[[package]]
name = "authlib"
version = "1.6.6"
version = "1.6.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cryptography" },
]
sdist = { url = "https://files.pythonhosted.org/packages/bb/9b/b1661026ff24bc641b76b78c5222d614776b0c085bcfdac9bd15a1cb4b35/authlib-1.6.6.tar.gz", hash = "sha256:45770e8e056d0f283451d9996fbb59b70d45722b45d854d58f32878d0a40c38e", size = 164894, upload-time = "2025-12-12T08:01:41.464Z" }
sdist = { url = "https://files.pythonhosted.org/packages/49/dc/ed1681bf1339dd6ea1ce56136bad4baabc6f7ad466e375810702b0237047/authlib-1.6.7.tar.gz", hash = "sha256:dbf10100011d1e1b34048c9d120e83f13b35d69a826ae762b93d2fb5aafc337b", size = 164950, upload-time = "2026-02-06T14:04:14.171Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/51/321e821856452f7386c4e9df866f196720b1ad0c5ea1623ea7399969ae3b/authlib-1.6.6-py2.py3-none-any.whl", hash = "sha256:7d9e9bc535c13974313a87f53e8430eb6ea3d1cf6ae4f6efcd793f2e949143fd", size = 244005, upload-time = "2025-12-12T08:01:40.209Z" },
{ url = "https://files.pythonhosted.org/packages/f8/00/3ed12264094ec91f534fae429945efbaa9f8c666f3aa7061cc3b2a26a0cd/authlib-1.6.7-py2.py3-none-any.whl", hash = "sha256:c637340d9a02789d2efa1d003a7437d10d3e565237bcb5fcbc6c134c7b95bab0", size = 244115, upload-time = "2026-02-06T14:04:12.141Z" },
]
[[package]]
@@ -1989,11 +1989,11 @@ wheels = [
[[package]]
name = "fickling"
version = "0.1.8"
version = "0.1.9"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/88/be/cd91e3921f064230ac9462479e4647fb91a7b0d01677103fce89f52e3042/fickling-0.1.8.tar.gz", hash = "sha256:25a0bc7acda76176a9087b405b05f7f5021f76079aa26c6fe3270855ec57d9bf", size = 336756, upload-time = "2026-02-21T00:57:26.106Z" }
sdist = { url = "https://files.pythonhosted.org/packages/25/bd/ca7127df0201596b0b30f9ab3d36e565bb9d6f8f4da1560758b817e81b65/fickling-0.1.9.tar.gz", hash = "sha256:bb518c2fd833555183bc46b6903bb4022f3ae0436a69c3fb149cfc75eebaac33", size = 336940, upload-time = "2026-03-03T23:32:19.449Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/02/92/af72f783ac57fa2452f8f921c9441366c42ae1f03f5af41718445114c82f/fickling-0.1.8-py3-none-any.whl", hash = "sha256:97218785cfe00a93150808dcf9e3eb512371e0484e3ce0b05bc460b97240f292", size = 52613, upload-time = "2026-02-21T00:57:24.82Z" },
{ url = "https://files.pythonhosted.org/packages/92/49/c597bad508c74917901432b41ae5a8f036839a7fb8d0d29a89765f5d3643/fickling-0.1.9-py3-none-any.whl", hash = "sha256:ccc3ce3b84733406ade2fe749717f6e428047335157c6431eefd3e7e970a06d1", size = 52786, upload-time = "2026-03-03T23:32:17.533Z" },
]
[[package]]

View File

@@ -295,24 +295,7 @@ describe('Pricing Modal Flow', () => {
})
})
// ─── 6. Close Handling ───────────────────────────────────────────────────
describe('Close handling', () => {
it('should call onCancel when pressing ESC key', () => {
render(<Pricing onCancel={onCancel} />)
// ahooks useKeyPress listens on document for keydown events
document.dispatchEvent(new KeyboardEvent('keydown', {
key: 'Escape',
code: 'Escape',
keyCode: 27,
bubbles: true,
}))
expect(onCancel).toHaveBeenCalledTimes(1)
})
})
// ─── 7. Pricing URL ─────────────────────────────────────────────────────
// ─── 6. Pricing URL ─────────────────────────────────────────────────────
describe('Pricing page URL', () => {
it('should render pricing link with correct URL', () => {
render(<Pricing onCancel={onCancel} />)

View File

@@ -8,7 +8,7 @@ import GotoAnything from '@/app/components/goto-anything'
import Header from '@/app/components/header'
import HeaderWrapper from '@/app/components/header/header-wrapper'
import ReadmePanel from '@/app/components/plugins/readme-panel'
import { AppContextProvider } from '@/context/app-context'
import { AppContextProvider } from '@/context/app-context-provider'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ModalContextProvider } from '@/context/modal-context'
import { ProviderContextProvider } from '@/context/provider-context'

View File

@@ -4,7 +4,7 @@ import { AppInitializer } from '@/app/components/app-initializer'
import AmplitudeProvider from '@/app/components/base/amplitude'
import GA, { GaType } from '@/app/components/base/ga'
import HeaderWrapper from '@/app/components/header/header-wrapper'
import { AppContextProvider } from '@/context/app-context'
import { AppContextProvider } from '@/context/app-context-provider'
import { EventEmitterContextProvider } from '@/context/event-emitter'
import { ModalContextProvider } from '@/context/modal-context'
import { ProviderContextProvider } from '@/context/provider-context'

View File

@@ -2,7 +2,7 @@
import Loading from '@/app/components/base/loading'
import Header from '@/app/signin/_header'
import { AppContextProvider } from '@/context/app-context'
import { AppContextProvider } from '@/context/app-context-provider'
import { useGlobalPublicStore } from '@/context/global-public-context'
import useDocumentTitle from '@/hooks/use-document-title'
import { useIsLogin } from '@/service/use-common'
@@ -38,7 +38,7 @@ export default function SignInLayout({ children }: any) {
</div>
</div>
{systemFeatures.branding.enabled === false && (
<div className="system-xs-regular px-8 py-6 text-text-tertiary">
<div className="px-8 py-6 text-text-tertiary system-xs-regular">
©
{' '}
{new Date().getFullYear()}

View File

@@ -26,11 +26,10 @@ export const AppInitializer = ({
// Tokens are now stored in cookies, no need to check localStorage
const pathname = usePathname()
const [init, setInit] = useState(false)
const [oauthNewUser, setOauthNewUser] = useQueryState(
const [oauthNewUser] = useQueryState(
'oauth_new_user',
parseAsBoolean.withOptions({ history: 'replace' }),
)
const isSetupFinished = useCallback(async () => {
try {
const setUpStatus = await fetchSetupStatusWithCache()
@@ -69,11 +68,12 @@ export const AppInitializer = ({
...utmInfo,
})
// Clean up: remove utm_info cookie and URL params
Cookies.remove('utm_info')
setOauthNewUser(null)
}
if (oauthNewUser !== null)
router.replace(pathname)
if (action === EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION)
localStorage.setItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, 'yes')
@@ -96,7 +96,7 @@ export const AppInitializer = ({
router.replace('/signin')
}
})()
}, [isSetupFinished, router, pathname, searchParams, oauthNewUser, setOauthNewUser])
}, [isSetupFinished, router, pathname, searchParams, oauthNewUser])
return init ? children : null
}

View File

@@ -0,0 +1,4 @@
<svg width="10" height="10" viewBox="0 0 10 10" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M2 5C2 3.44487 2.58482 1.98537 3.54004 1.04932C2.17681 1.34034 1 2.90001 1 5C1 7.09996 2.17685 8.65912 3.54004 8.9502C2.58496 8.01413 2 6.55501 2 5ZM3 5C3 7.33338 4.4528 9 6 9C7.5472 9 9 7.33338 9 5C9 2.66664 7.5472 1 6 1C4.4528 1 3 2.66664 3 5ZM10 5C10 7.63722 8.3188 10 6 10H4C1.6812 10 0 7.63722 0 5C0 2.3628 1.6812 0 4 0H6C8.3188 0 10 2.3628 10 5Z" fill="#676F83"/>
<path d="M6.71519 4.09259L6.45385 3.18667C6.42141 3.07421 6.34037 3 6.25 3C6.15963 3 6.07859 3.07421 6.04615 3.18667L5.78481 4.09259C5.74675 4.22464 5.66849 4.32899 5.56945 4.37978L4.88999 4.7282C4.80565 4.77146 4.75 4.87951 4.75 5C4.75 5.12049 4.80565 5.22854 4.88999 5.2718L5.56945 5.62022C5.66849 5.67101 5.74675 5.77536 5.78481 5.90741L6.04615 6.81333C6.07859 6.92579 6.15963 7 6.25 7C6.34037 7 6.42141 6.92579 6.45385 6.81333L6.71519 5.90741C6.75325 5.77536 6.83151 5.67101 6.93055 5.62022L7.61001 5.2718C7.69435 5.22854 7.75 5.12049 7.75 5C7.75 4.87951 7.69435 4.77146 7.61001 4.7282L6.93055 4.37978C6.83151 4.32899 6.75325 4.22464 6.71519 4.09259Z" fill="#676F83"/>
</svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@@ -0,0 +1,35 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "10",
"height": "10",
"viewBox": "0 0 10 10",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"d": "M2 5C2 3.44487 2.58482 1.98537 3.54004 1.04932C2.17681 1.34034 1 2.90001 1 5C1 7.09996 2.17685 8.65912 3.54004 8.9502C2.58496 8.01413 2 6.55501 2 5ZM3 5C3 7.33338 4.4528 9 6 9C7.5472 9 9 7.33338 9 5C9 2.66664 7.5472 1 6 1C4.4528 1 3 2.66664 3 5ZM10 5C10 7.63722 8.3188 10 6 10H4C1.6812 10 0 7.63722 0 5C0 2.3628 1.6812 0 4 0H6C8.3188 0 10 2.3628 10 5Z",
"fill": "currentColor"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M6.71519 4.09259L6.45385 3.18667C6.42141 3.07421 6.34037 3 6.25 3C6.15963 3 6.07859 3.07421 6.04615 3.18667L5.78481 4.09259C5.74675 4.22464 5.66849 4.32899 5.56945 4.37978L4.88999 4.7282C4.80565 4.77146 4.75 4.87951 4.75 5C4.75 5.12049 4.80565 5.22854 4.88999 5.2718L5.56945 5.62022C5.66849 5.67101 5.74675 5.77536 5.78481 5.90741L6.04615 6.81333C6.07859 6.92579 6.15963 7 6.25 7C6.34037 7 6.42141 6.92579 6.45385 6.81333L6.71519 5.90741C6.75325 5.77536 6.83151 5.67101 6.93055 5.62022L7.61001 5.2718C7.69435 5.22854 7.75 5.12049 7.75 5C7.75 4.87951 7.69435 4.77146 7.61001 4.7282L6.93055 4.37978C6.83151 4.32899 6.75325 4.22464 6.71519 4.09259Z",
"fill": "currentColor"
},
"children": []
}
]
},
"name": "CreditsCoin"
}

View File

@@ -0,0 +1,20 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import type { IconData } from '@/app/components/base/icons/IconBase'
import * as React from 'react'
import IconBase from '@/app/components/base/icons/IconBase'
import data from './CreditsCoin.json'
const Icon = (
{
ref,
...props
}: React.SVGProps<SVGSVGElement> & {
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>
},
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'CreditsCoin'
export default Icon

View File

@@ -1,5 +1,6 @@
export { default as Balance } from './Balance'
export { default as CoinsStacked01 } from './CoinsStacked01'
export { default as CreditsCoin } from './CreditsCoin'
export { default as GoldCoin } from './GoldCoin'
export { default as ReceiptList } from './ReceiptList'
export { default as Tag01 } from './Tag01'

View File

@@ -43,20 +43,24 @@ type DialogContentProps = {
children: React.ReactNode
className?: string
overlayClassName?: string
backdropProps?: React.ComponentPropsWithoutRef<typeof BaseDialog.Backdrop>
}
export function DialogContent({
children,
className,
overlayClassName,
backdropProps,
}: DialogContentProps) {
return (
<DialogPortal>
<BaseDialog.Backdrop
{...backdropProps}
className={cn(
'fixed inset-0 z-50 bg-background-overlay',
'transition-opacity duration-150 data-[ending-style]:opacity-0 data-[starting-style]:opacity-0 motion-reduce:transition-none',
overlayClassName,
backdropProps?.className,
)}
/>
<BaseDialog.Popup

View File

@@ -1,7 +1,7 @@
import { render, screen } from '@testing-library/react'
import * as React from 'react'
import { CategoryEnum } from '..'
import Footer from '../footer'
import { CategoryEnum } from '../types'
vi.mock('next/link', () => ({
default: ({ children, href, className, target }: { children: React.ReactNode, href: string, className?: string, target?: string }) => (

View File

@@ -1,7 +1,16 @@
import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { Dialog } from '@/app/components/base/ui/dialog'
import Header from '../header'
function renderHeader(onClose: () => void) {
return render(
<Dialog open>
<Header onClose={onClose} />
</Dialog>,
)
}
describe('Header', () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -11,7 +20,7 @@ describe('Header', () => {
it('should render title and description translations', () => {
const handleClose = vi.fn()
render(<Header onClose={handleClose} />)
renderHeader(handleClose)
expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument()
expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument()
@@ -22,7 +31,7 @@ describe('Header', () => {
describe('Props', () => {
it('should invoke onClose when close button is clicked', () => {
const handleClose = vi.fn()
render(<Header onClose={handleClose} />)
renderHeader(handleClose)
fireEvent.click(screen.getByRole('button'))
@@ -32,7 +41,7 @@ describe('Header', () => {
describe('Edge Cases', () => {
it('should render structural elements with translation keys', () => {
const { container } = render(<Header onClose={vi.fn()} />)
const { container } = renderHeader(vi.fn())
expect(container.querySelector('span')).toBeInTheDocument()
expect(container.querySelector('p')).toBeInTheDocument()

View File

@@ -74,15 +74,11 @@ describe('Pricing', () => {
})
describe('Props', () => {
it('should allow switching categories and handle esc key', () => {
const handleCancel = vi.fn()
render(<Pricing onCancel={handleCancel} />)
it('should allow switching categories', () => {
render(<Pricing onCancel={vi.fn()} />)
fireEvent.click(screen.getByText('billing.plansCommon.self'))
expect(screen.queryByRole('switch')).not.toBeInTheDocument()
fireEvent.keyDown(window, { key: 'Escape', keyCode: 27 })
expect(handleCancel).toHaveBeenCalled()
})
})

View File

@@ -1,10 +1,9 @@
import type { Category } from '.'
import { RiArrowRightUpLine } from '@remixicon/react'
import type { Category } from './types'
import Link from 'next/link'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { cn } from '@/utils/classnames'
import { CategoryEnum } from '.'
import { CategoryEnum } from './types'
type FooterProps = {
pricingPageURL: string
@@ -34,7 +33,7 @@ const Footer = ({
>
{t('plansCommon.comparePlanAndFeatures', { ns: 'billing' })}
</Link>
<RiArrowRightUpLine className="size-4" />
<span aria-hidden="true" className="i-ri-arrow-right-up-line size-4" />
</span>
</div>
</div>

View File

@@ -1,6 +1,6 @@
import { RiCloseLine } from '@remixicon/react'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { DialogDescription, DialogTitle } from '@/app/components/base/ui/dialog'
import Button from '../../base/button'
import DifyLogo from '../../base/logo/dify-logo'
@@ -20,19 +20,19 @@ const Header = ({
<div className="py-[5px]">
<DifyLogo className="h-[27px] w-[60px]" />
</div>
<span className="bg-billing-plan-title-bg bg-clip-text px-1.5 font-instrument text-[37px] italic leading-[1.2] text-transparent">
<DialogTitle className="m-0 bg-billing-plan-title-bg bg-clip-text px-1.5 font-instrument text-[37px] italic leading-[1.2] text-transparent">
{t('plansCommon.title.plans', { ns: 'billing' })}
</span>
</DialogTitle>
</div>
<p className="system-sm-regular text-text-tertiary">
<DialogDescription className="m-0 text-text-tertiary system-sm-regular">
{t('plansCommon.title.description', { ns: 'billing' })}
</p>
</DialogDescription>
<Button
variant="secondary"
className="absolute bottom-[40.5px] right-[-18px] z-10 size-9 rounded-full p-2"
onClick={onClose}
>
<RiCloseLine className="size-5" />
<span aria-hidden="true" className="i-ri-close-line size-5" />
</Button>
</div>
</div>

View File

@@ -1,9 +1,9 @@
'use client'
import type { FC } from 'react'
import { useKeyPress } from 'ahooks'
import type { Category } from './types'
import * as React from 'react'
import { useState } from 'react'
import { createPortal } from 'react-dom'
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
import { useAppContext } from '@/context/app-context'
import { useGetPricingPageLanguage } from '@/context/i18n'
import { useProviderContext } from '@/context/provider-context'
@@ -13,13 +13,7 @@ import Header from './header'
import PlanSwitcher from './plan-switcher'
import { PlanRange } from './plan-switcher/plan-range-switcher'
import Plans from './plans'
export enum CategoryEnum {
CLOUD = 'cloud',
SELF = 'self',
}
export type Category = CategoryEnum.CLOUD | CategoryEnum.SELF
import { CategoryEnum } from './types'
type PricingProps = {
onCancel: () => void
@@ -33,42 +27,47 @@ const Pricing: FC<PricingProps> = ({
const [planRange, setPlanRange] = React.useState<PlanRange>(PlanRange.monthly)
const [currentCategory, setCurrentCategory] = useState<Category>(CategoryEnum.CLOUD)
const canPay = isCurrentWorkspaceManager
useKeyPress(['esc'], onCancel)
const pricingPageLanguage = useGetPricingPageLanguage()
const pricingPageURL = pricingPageLanguage
? `https://dify.ai/${pricingPageLanguage}/pricing#plans-and-features`
: 'https://dify.ai/pricing#plans-and-features'
return createPortal(
<div
className="fixed inset-0 bottom-0 left-0 right-0 top-0 z-[1000] overflow-auto bg-saas-background"
onClick={e => e.stopPropagation()}
return (
<Dialog
open
onOpenChange={(open) => {
if (!open)
onCancel()
}}
>
<div className="relative grid min-h-full min-w-[1200px] grid-rows-[1fr_auto_auto_1fr] overflow-hidden">
<div className="absolute -top-12 left-0 right-0 -z-10">
<NoiseTop />
<DialogContent
className="inset-0 h-full max-h-none w-full max-w-none translate-x-0 translate-y-0 overflow-auto rounded-none border-none bg-saas-background p-0 shadow-none"
>
<div className="relative grid min-h-full min-w-[1200px] grid-rows-[1fr_auto_auto_1fr] overflow-hidden">
<div className="absolute -top-12 left-0 right-0 -z-10">
<NoiseTop />
</div>
<Header onClose={onCancel} />
<PlanSwitcher
currentCategory={currentCategory}
onChangeCategory={setCurrentCategory}
currentPlanRange={planRange}
onChangePlanRange={setPlanRange}
/>
<Plans
plan={plan}
currentPlan={currentCategory}
planRange={planRange}
canPay={canPay}
/>
<Footer pricingPageURL={pricingPageURL} currentCategory={currentCategory} />
<div className="absolute -bottom-12 left-0 right-0 -z-10">
<NoiseBottom />
</div>
</div>
<Header onClose={onCancel} />
<PlanSwitcher
currentCategory={currentCategory}
onChangeCategory={setCurrentCategory}
currentPlanRange={planRange}
onChangePlanRange={setPlanRange}
/>
<Plans
plan={plan}
currentPlan={currentCategory}
planRange={planRange}
canPay={canPay}
/>
<Footer pricingPageURL={pricingPageURL} currentCategory={currentCategory} />
<div className="absolute -bottom-12 left-0 right-0 -z-10">
<NoiseBottom />
</div>
</div>
</div>,
document.body,
</DialogContent>
</Dialog>
)
}
export default React.memo(Pricing)

View File

@@ -1,6 +1,6 @@
import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { CategoryEnum } from '../../index'
import { CategoryEnum } from '../../types'
import PlanSwitcher from '../index'
import { PlanRange } from '../plan-range-switcher'

View File

@@ -1,5 +1,5 @@
import type { FC } from 'react'
import type { Category } from '../index'
import type { Category } from '../types'
import type { PlanRange } from './plan-range-switcher'
import * as React from 'react'
import { useTranslation } from 'react-i18next'

View File

@@ -0,0 +1,6 @@
export enum CategoryEnum {
CLOUD = 'cloud',
SELF = 'self',
}
export type Category = CategoryEnum.CLOUD | CategoryEnum.SELF

View File

@@ -100,10 +100,10 @@ vi.mock('@/app/components/datasets/create/step-two', () => ({
}))
vi.mock('@/app/components/header/account-setting', () => ({
default: ({ activeTab, onCancel }: { activeTab?: string, onCancel?: () => void }) => (
default: ({ activeTab, onCancelAction }: { activeTab?: string, onCancelAction?: () => void }) => (
<div data-testid="account-setting">
<span data-testid="active-tab">{activeTab}</span>
<button onClick={onCancel} data-testid="close-setting">Close</button>
<button onClick={onCancelAction} data-testid="close-setting">Close</button>
</div>
),
}))

View File

@@ -1,3 +1,4 @@
import type { AccountSettingTab } from '@/app/components/header/account-setting/constants'
import type { DataSourceProvider, NotionPage } from '@/models/common'
import type {
CrawlOptions,
@@ -19,6 +20,7 @@ import AppUnavailable from '@/app/components/base/app-unavailable'
import Loading from '@/app/components/base/loading'
import StepTwo from '@/app/components/datasets/create/step-two'
import AccountSetting from '@/app/components/header/account-setting'
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import DatasetDetailContext from '@/context/dataset-detail'
@@ -33,8 +35,13 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
const { t } = useTranslation()
const router = useRouter()
const [isShowSetAPIKey, { setTrue: showSetAPIKey, setFalse: hideSetAPIkey }] = useBoolean()
const [accountSettingTab, setAccountSettingTab] = React.useState<AccountSettingTab>(ACCOUNT_SETTING_TAB.PROVIDER)
const { indexingTechnique, dataset } = useContext(DatasetDetailContext)
const { data: embeddingsDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
const handleOpenAccountSetting = React.useCallback(() => {
setAccountSettingTab(ACCOUNT_SETTING_TAB.PROVIDER)
showSetAPIKey()
}, [showSetAPIKey])
const invalidDocumentList = useInvalidDocumentList(datasetId)
const invalidDocumentDetail = useInvalidDocumentDetail()
@@ -135,7 +142,7 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
{dataset && documentDetail && (
<StepTwo
isAPIKeySet={!!embeddingsDefaultModel}
onSetting={showSetAPIKey}
onSetting={handleOpenAccountSetting}
datasetId={datasetId}
dataSourceType={documentDetail.data_source_type as DataSourceType}
notionPages={currentPage ? [currentPage as unknown as NotionPage] : []}
@@ -155,8 +162,9 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
</div>
{isShowSetAPIKey && (
<AccountSetting
activeTab="provider"
onCancel={async () => {
activeTab={accountSettingTab}
onTabChangeAction={setAccountSettingTab}
onCancelAction={async () => {
hideSetAPIkey()
}}
/>

View File

@@ -1,12 +1,16 @@
import type { AccountSettingTab } from './constants'
import type { AppContextValue } from '@/context/app-context'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { fireEvent, render, screen } from '@testing-library/react'
import { useState } from 'react'
import { useAppContext } from '@/context/app-context'
import { baseProviderContextValue, useProviderContext } from '@/context/provider-context'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { ACCOUNT_SETTING_TAB } from './constants'
import AccountSetting from './index'
const mockResetModelProviderListExpanded = vi.fn()
vi.mock('@/context/provider-context', async (importOriginal) => {
const actual = await importOriginal<typeof import('@/context/provider-context')>()
return {
@@ -47,10 +51,15 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', ()
useDefaultModel: vi.fn(() => ({ data: null, isLoading: false })),
useUpdateDefaultModel: vi.fn(() => ({ trigger: vi.fn() })),
useUpdateModelList: vi.fn(() => vi.fn()),
useInvalidateDefaultModel: vi.fn(() => vi.fn()),
useModelList: vi.fn(() => ({ data: [], isLoading: false })),
useSystemDefaultModelAndModelList: vi.fn(() => [null, vi.fn()]),
}))
vi.mock('@/app/components/header/account-setting/model-provider-page/atoms', () => ({
useResetModelProviderListExpanded: () => mockResetModelProviderListExpanded,
}))
vi.mock('@/service/use-datasource', () => ({
useGetDataSourceListAuth: vi.fn(() => ({ data: { result: [] } })),
}))
@@ -105,6 +114,38 @@ const baseAppContextValue: AppContextValue = {
describe('AccountSetting', () => {
const mockOnCancel = vi.fn()
const mockOnTabChange = vi.fn()
const renderAccountSetting = (props?: {
initialTab?: AccountSettingTab
onCancel?: () => void
onTabChange?: (tab: AccountSettingTab) => void
}) => {
const {
initialTab = ACCOUNT_SETTING_TAB.MEMBERS,
onCancel = mockOnCancel,
onTabChange = mockOnTabChange,
} = props ?? {}
const StatefulAccountSetting = () => {
const [activeTab, setActiveTab] = useState<AccountSettingTab>(initialTab)
return (
<AccountSetting
onCancelAction={onCancel}
activeTab={activeTab}
onTabChangeAction={(tab) => {
setActiveTab(tab)
onTabChange(tab)
}}
/>
)
}
return render(
<QueryClientProvider client={new QueryClient()}>
<StatefulAccountSetting />
</QueryClientProvider>,
)
}
beforeEach(() => {
vi.clearAllMocks()
@@ -120,11 +161,7 @@ describe('AccountSetting', () => {
describe('Rendering', () => {
it('should render the sidebar with correct menu items', () => {
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
renderAccountSetting()
// Assert
expect(screen.getByText('common.userProfile.settings')).toBeInTheDocument()
@@ -137,13 +174,9 @@ describe('AccountSetting', () => {
expect(screen.getAllByText('common.settings.language').length).toBeGreaterThan(0)
})
it('should respect the activeTab prop', () => {
it('should respect the initial tab', () => {
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} activeTab={ACCOUNT_SETTING_TAB.DATA_SOURCE} />
</QueryClientProvider>,
)
renderAccountSetting({ initialTab: ACCOUNT_SETTING_TAB.DATA_SOURCE })
// Assert
// Check that the active item title is Data Source
@@ -157,11 +190,7 @@ describe('AccountSetting', () => {
vi.mocked(useBreakpoints).mockReturnValue(MediaType.mobile)
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
renderAccountSetting()
// Assert
// On mobile, the labels should not be rendered as per the implementation
@@ -176,11 +205,7 @@ describe('AccountSetting', () => {
})
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
renderAccountSetting()
// Assert
expect(screen.queryByText('common.settings.provider')).not.toBeInTheDocument()
@@ -197,11 +222,7 @@ describe('AccountSetting', () => {
})
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
renderAccountSetting()
// Assert
expect(screen.queryByText('common.settings.billing')).not.toBeInTheDocument()
@@ -212,11 +233,7 @@ describe('AccountSetting', () => {
describe('Tab Navigation', () => {
it('should change active tab when clicking on menu item', () => {
// Arrange
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} onTabChange={mockOnTabChange} />
</QueryClientProvider>,
)
renderAccountSetting({ onTabChange: mockOnTabChange })
// Act
fireEvent.click(screen.getByText('common.settings.provider'))
@@ -229,11 +246,7 @@ describe('AccountSetting', () => {
it('should navigate through various tabs and show correct details', () => {
// Act & Assert
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
renderAccountSetting()
// Billing
fireEvent.click(screen.getByText('common.settings.billing'))
@@ -267,13 +280,11 @@ describe('AccountSetting', () => {
describe('Interactions', () => {
it('should call onCancel when clicking close button', () => {
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
const buttons = screen.getAllByRole('button')
fireEvent.click(buttons[0])
renderAccountSetting()
const closeIcon = document.querySelector('.i-ri-close-line')
const closeButton = closeIcon?.closest('button')
expect(closeButton).not.toBeNull()
fireEvent.click(closeButton!)
// Assert
expect(mockOnCancel).toHaveBeenCalled()
@@ -281,11 +292,7 @@ describe('AccountSetting', () => {
it('should call onCancel when pressing Escape key', () => {
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
renderAccountSetting()
fireEvent.keyDown(document, { key: 'Escape' })
// Assert
@@ -294,12 +301,7 @@ describe('AccountSetting', () => {
it('should update search value in provider tab', () => {
// Arrange
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
fireEvent.click(screen.getByText('common.settings.provider'))
renderAccountSetting({ initialTab: ACCOUNT_SETTING_TAB.PROVIDER })
// Act
const input = screen.getByRole('textbox')
@@ -312,11 +314,7 @@ describe('AccountSetting', () => {
it('should handle scroll event in panel', () => {
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
renderAccountSetting()
const scrollContainer = screen.getByRole('dialog').querySelector('.overflow-y-auto')
// Assert

View File

@@ -1,6 +1,6 @@
'use client'
import type { AccountSettingTab } from '@/app/components/header/account-setting/constants'
import { useEffect, useRef, useState } from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import SearchInput from '@/app/components/base/search-input'
import BillingPage from '@/app/components/billing/billing-page'
@@ -20,15 +20,16 @@ import DataSourcePage from './data-source-page-new'
import LanguagePage from './language-page'
import MembersPage from './members-page'
import ModelProviderPage from './model-provider-page'
import { useResetModelProviderListExpanded } from './model-provider-page/atoms'
const iconClassName = `
w-5 h-5 mr-2
`
type IAccountSettingProps = {
onCancel: () => void
activeTab?: AccountSettingTab
onTabChange?: (tab: AccountSettingTab) => void
onCancelAction: () => void
activeTab: AccountSettingTab
onTabChangeAction: (tab: AccountSettingTab) => void
}
type GroupItem = {
@@ -40,14 +41,12 @@ type GroupItem = {
}
export default function AccountSetting({
onCancel,
activeTab = ACCOUNT_SETTING_TAB.MEMBERS,
onTabChange,
onCancelAction,
activeTab,
onTabChangeAction,
}: IAccountSettingProps) {
const [activeMenu, setActiveMenu] = useState<AccountSettingTab>(activeTab)
useEffect(() => {
setActiveMenu(activeTab)
}, [activeTab])
const resetModelProviderListExpanded = useResetModelProviderListExpanded()
const activeMenu = activeTab
const { t } = useTranslation()
const { enableBilling, enableReplaceWebAppLogo } = useProviderContext()
const { isCurrentWorkspaceDatasetOperator } = useAppContext()
@@ -148,10 +147,22 @@ export default function AccountSetting({
const [searchValue, setSearchValue] = useState<string>('')
const handleTabChange = useCallback((tab: AccountSettingTab) => {
if (tab === ACCOUNT_SETTING_TAB.PROVIDER)
resetModelProviderListExpanded()
onTabChangeAction(tab)
}, [onTabChangeAction, resetModelProviderListExpanded])
const handleClose = useCallback(() => {
resetModelProviderListExpanded()
onCancelAction()
}, [onCancelAction, resetModelProviderListExpanded])
return (
<MenuDialog
show
onClose={onCancel}
onClose={handleClose}
>
<div className="mx-auto flex h-[100vh] max-w-[1048px]">
<div className="flex w-[44px] flex-col border-r border-divider-burn pl-4 pr-6 sm:w-[224px]">
@@ -166,21 +177,22 @@ export default function AccountSetting({
<div>
{
menuItem.items.map(item => (
<div
<button
type="button"
key={item.key}
className={cn(
'mb-0.5 flex h-[37px] cursor-pointer items-center rounded-lg p-1 pl-3 text-sm',
'mb-0.5 flex h-[37px] w-full items-center rounded-lg p-1 pl-3 text-left text-sm',
activeMenu === item.key ? 'bg-state-base-active text-components-menu-item-text-active system-sm-semibold' : 'text-components-menu-item-text system-sm-medium',
)}
aria-label={item.name}
title={item.name}
onClick={() => {
setActiveMenu(item.key)
onTabChange?.(item.key)
handleTabChange(item.key)
}}
>
{activeMenu === item.key ? item.activeIcon : item.icon}
{!isMobile && <div className="truncate">{item.name}</div>}
</div>
</button>
))
}
</div>
@@ -195,7 +207,8 @@ export default function AccountSetting({
variant="tertiary"
size="large"
className="px-2"
onClick={onCancel}
aria-label={t('operation.close', { ns: 'common' })}
onClick={handleClose}
>
<span className="i-ri-close-line h-5 w-5" />
</Button>

View File

@@ -40,8 +40,7 @@ describe('MenuDialog', () => {
)
// Assert
const panel = screen.getByRole('dialog').querySelector('.custom-class')
expect(panel).toBeInTheDocument()
expect(screen.getByRole('dialog')).toHaveClass('custom-class')
})
})

View File

@@ -1,7 +1,6 @@
import type { ReactNode } from 'react'
import { Dialog, DialogPanel, Transition, TransitionChild } from '@headlessui/react'
import { noop } from 'es-toolkit/function'
import { Fragment, useCallback, useEffect } from 'react'
import { useCallback } from 'react'
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
import { cn } from '@/utils/classnames'
type DialogProps = {
@@ -19,42 +18,25 @@ const MenuDialog = ({
}: DialogProps) => {
const close = useCallback(() => onClose?.(), [onClose])
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === 'Escape') {
event.preventDefault()
close()
}
}
document.addEventListener('keydown', handleKeyDown)
return () => {
document.removeEventListener('keydown', handleKeyDown)
}
}, [close])
return (
<Transition appear show={show} as={Fragment}>
<Dialog as="div" className="relative z-[60]" onClose={noop}>
<div className="fixed inset-0">
<div className="flex min-h-full flex-col items-center justify-center">
<TransitionChild>
<DialogPanel className={cn(
'relative h-full w-full grow overflow-hidden bg-background-sidenav-bg p-0 text-left align-middle backdrop-blur-md transition-all',
'duration-300 ease-in data-[closed]:scale-95 data-[closed]:opacity-0',
'data-[enter]:scale-100 data-[enter]:opacity-100',
'data-[enter]:scale-95 data-[leave]:opacity-0',
className,
)}
>
<div className="absolute right-0 top-0 h-full w-1/2 bg-components-panel-bg" />
{children}
</DialogPanel>
</TransitionChild>
</div>
</div>
</Dialog>
</Transition>
<Dialog
open={show}
onOpenChange={(open) => {
if (!open)
close()
}}
>
<DialogContent
overlayClassName="bg-transparent"
className={cn(
'left-0 top-0 h-full max-h-none w-full max-w-none translate-x-0 translate-y-0 overflow-hidden rounded-none border-none bg-background-sidenav-bg p-0 shadow-none backdrop-blur-md',
className,
)}
>
<div className="absolute right-0 top-0 h-full w-1/2 bg-components-panel-bg" />
{children}
</DialogContent>
</Dialog>
)
}

View File

@@ -0,0 +1,64 @@
import { atom, useAtomValue, useSetAtom } from 'jotai'
import { selectAtom } from 'jotai/utils'
import { useCallback, useMemo } from 'react'
const modelProviderListExpandedAtom = atom<Record<string, boolean>>({})
const setModelProviderListExpandedAtom = atom(
null,
(get, set, params: { providerName: string, expanded: boolean }) => {
const { providerName, expanded } = params
const current = get(modelProviderListExpandedAtom)
if (expanded) {
if (current[providerName])
return
set(modelProviderListExpandedAtom, {
...current,
[providerName]: true,
})
return
}
if (!current[providerName])
return
const next = { ...current }
delete next[providerName]
set(modelProviderListExpandedAtom, next)
},
)
const resetModelProviderListExpandedAtom = atom(
null,
(_get, set) => {
set(modelProviderListExpandedAtom, {})
},
)
export function useModelProviderListExpanded(providerName: string) {
const selectedAtom = useMemo(
() => selectAtom(modelProviderListExpandedAtom, state => state[providerName] ?? false),
[providerName],
)
return useAtomValue(selectedAtom)
}
export function useSetModelProviderListExpanded(providerName: string) {
const setExpanded = useSetAtom(setModelProviderListExpandedAtom)
return useCallback((expanded: boolean) => {
setExpanded({ providerName, expanded })
}, [providerName, setExpanded])
}
export function useExpandModelProviderList() {
const setExpanded = useSetAtom(setModelProviderListExpandedAtom)
return useCallback((providerName: string) => {
setExpanded({ providerName, expanded: true })
}, [setExpanded])
}
export function useResetModelProviderListExpanded() {
return useSetAtom(resetModelProviderListExpandedAtom)
}

View File

@@ -9,6 +9,7 @@ import type {
} from './declarations'
import { act, renderHook, waitFor } from '@testing-library/react'
import { useLocale } from '@/context/i18n'
import { consoleQuery } from '@/service/client'
import { fetchDefaultModal, fetchModelList, fetchModelProviderCredentials } from '@/service/common'
import {
ConfigurationMethodEnum,
@@ -23,6 +24,7 @@ import {
useAnthropicBuyQuota,
useCurrentProviderAndModel,
useDefaultModel,
useInvalidateDefaultModel,
useLanguage,
useMarketplaceAllPlugins,
useModelList,
@@ -36,7 +38,6 @@ import {
useUpdateModelList,
useUpdateModelProviders,
} from './hooks'
import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
// Mock dependencies
vi.mock('@tanstack/react-query', () => ({
@@ -78,14 +79,6 @@ vi.mock('@/context/modal-context', () => ({
}),
}))
vi.mock('@/context/event-emitter', () => ({
useEventEmitterContextContext: vi.fn(() => ({
eventEmitter: {
emit: vi.fn(),
},
})),
}))
vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
useMarketplacePlugins: vi.fn(() => ({
plugins: [],
@@ -99,12 +92,16 @@ vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
})),
}))
vi.mock('./atoms', () => ({
useExpandModelProviderList: vi.fn(() => vi.fn()),
}))
const { useQuery, useQueryClient } = await import('@tanstack/react-query')
const { getPayUrl } = await import('@/service/common')
const { useProviderContext } = await import('@/context/provider-context')
const { useModalContextSelector } = await import('@/context/modal-context')
const { useEventEmitterContextContext } = await import('@/context/event-emitter')
const { useMarketplacePlugins, useMarketplacePluginsByCollectionId } = await import('@/app/components/plugins/marketplace/hooks')
const { useExpandModelProviderList } = await import('./atoms')
describe('hooks', () => {
beforeEach(() => {
@@ -864,6 +861,38 @@ describe('hooks', () => {
})
})
describe('useInvalidateDefaultModel', () => {
it('should invalidate default model queries', () => {
const invalidateQueries = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
const { result } = renderHook(() => useInvalidateDefaultModel())
act(() => {
result.current(ModelTypeEnum.textGeneration)
})
expect(invalidateQueries).toHaveBeenCalledWith({
queryKey: ['default-model', ModelTypeEnum.textGeneration],
})
})
it('should handle multiple model types', () => {
const invalidateQueries = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
const { result } = renderHook(() => useInvalidateDefaultModel())
act(() => {
result.current(ModelTypeEnum.textGeneration)
result.current(ModelTypeEnum.textEmbedding)
result.current(ModelTypeEnum.rerank)
})
expect(invalidateQueries).toHaveBeenCalledTimes(3)
})
})
describe('useAnthropicBuyQuota', () => {
beforeEach(() => {
Object.defineProperty(window, 'location', {
@@ -1167,39 +1196,52 @@ describe('hooks', () => {
it('should refresh providers and model lists', () => {
const invalidateQueries = vi.fn()
const emit = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
; (useEventEmitterContextContext as Mock).mockReturnValue({
eventEmitter: { emit },
})
const provider = createMockProvider()
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
const { result } = renderHook(() => useRefreshModel())
act(() => {
result.current.handleRefreshModel(provider)
})
expect(invalidateQueries).toHaveBeenCalledWith({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'none',
})
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-providers'] })
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textGeneration] })
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textEmbedding] })
})
it('should emit event when refreshModelList is true and custom config is active', () => {
it('should expand target provider list when refreshModelList is true and custom config is active', () => {
const invalidateQueries = vi.fn()
const emit = vi.fn()
const expandModelProviderList = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
; (useEventEmitterContextContext as Mock).mockReturnValue({
eventEmitter: { emit },
})
; (useExpandModelProviderList as Mock).mockReturnValue(expandModelProviderList)
const provider = createMockProvider()
const customFields: CustomConfigurationModelFixedFields = {
__model_name: 'gpt-4',
__model_type: ModelTypeEnum.textGeneration,
}
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
const { result } = renderHook(() => useRefreshModel())
@@ -1207,23 +1249,30 @@ describe('hooks', () => {
result.current.handleRefreshModel(provider, customFields, true)
})
expect(emit).toHaveBeenCalledWith({
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
payload: 'openai',
expect(expandModelProviderList).toHaveBeenCalledWith('openai')
expect(invalidateQueries).toHaveBeenCalledWith({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'active',
})
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textGeneration] })
})
it('should not emit event when custom config is not active', () => {
it('should not expand provider list when custom config is not active', () => {
const invalidateQueries = vi.fn()
const emit = vi.fn()
const expandModelProviderList = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
; (useEventEmitterContextContext as Mock).mockReturnValue({
eventEmitter: { emit },
})
; (useExpandModelProviderList as Mock).mockReturnValue(expandModelProviderList)
const provider = { ...createMockProvider(), custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure } }
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
const { result } = renderHook(() => useRefreshModel())
@@ -1231,16 +1280,43 @@ describe('hooks', () => {
result.current.handleRefreshModel(provider, undefined, true)
})
expect(emit).not.toHaveBeenCalled()
expect(expandModelProviderList).not.toHaveBeenCalled()
expect(invalidateQueries).not.toHaveBeenCalledWith({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'active',
})
})
it('should refetch active model provider list when custom refresh callback is absent', () => {
const invalidateQueries = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
const provider = createMockProvider()
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
const { result } = renderHook(() => useRefreshModel())
act(() => {
result.current.handleRefreshModel(provider, undefined, true)
})
expect(invalidateQueries).toHaveBeenCalledWith({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'active',
})
})
it('should handle provider with single model type', () => {
const invalidateQueries = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
; (useEventEmitterContextContext as Mock).mockReturnValue({
eventEmitter: { emit: vi.fn() },
})
const provider = {
...createMockProvider(),

View File

@@ -21,10 +21,10 @@ import {
useMarketplacePluginsByCollectionId,
} from '@/app/components/plugins/marketplace/hooks'
import { PluginCategoryEnum } from '@/app/components/plugins/types'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useLocale } from '@/context/i18n'
import { useModalContextSelector } from '@/context/modal-context'
import { useProviderContext } from '@/context/provider-context'
import { consoleQuery } from '@/service/client'
import {
fetchDefaultModal,
fetchModelList,
@@ -32,12 +32,12 @@ import {
getPayUrl,
} from '@/service/common'
import { commonQueryKeys } from '@/service/use-common'
import { useExpandModelProviderList } from './atoms'
import {
ConfigurationMethodEnum,
CustomConfigurationStatusEnum,
ModelStatusEnum,
} from './declarations'
import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
type UseDefaultModelAndModelList = (
defaultModel: DefaultModelResponse | undefined,
@@ -222,6 +222,14 @@ export const useUpdateModelList = () => {
return updateModelList
}
export const useInvalidateDefaultModel = () => {
const queryClient = useQueryClient()
return useCallback((type: ModelTypeEnum) => {
queryClient.invalidateQueries({ queryKey: commonQueryKeys.defaultModel(type) })
}, [queryClient])
}
export const useAnthropicBuyQuota = () => {
const [loading, setLoading] = useState(false)
@@ -314,7 +322,8 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
}
export const useRefreshModel = () => {
const { eventEmitter } = useEventEmitterContextContext()
const expandModelProviderList = useExpandModelProviderList()
const queryClient = useQueryClient()
const updateModelProviders = useUpdateModelProviders()
const updateModelList = useUpdateModelList()
const handleRefreshModel = useCallback((
@@ -322,6 +331,19 @@ export const useRefreshModel = () => {
CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
refreshModelList?: boolean,
) => {
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
queryClient.invalidateQueries({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'none',
})
updateModelProviders()
provider.supported_model_types.forEach((type) => {
@@ -329,15 +351,17 @@ export const useRefreshModel = () => {
})
if (refreshModelList && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) {
eventEmitter?.emit({
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
payload: provider.provider,
} as any)
expandModelProviderList(provider.provider)
queryClient.invalidateQueries({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'active',
})
if (CustomConfigurationModelFixedFields?.__model_type)
updateModelList(CustomConfigurationModelFixedFields.__model_type)
}
}, [eventEmitter, updateModelList, updateModelProviders])
}, [expandModelProviderList, queryClient, updateModelList, updateModelProviders])
return {
handleRefreshModel,

View File

@@ -7,16 +7,7 @@ import {
} from './declarations'
import ModelProviderPage from './index'
vi.mock('@/context/app-context', () => ({
useAppContext: () => ({
mutateCurrentWorkspace: vi.fn(),
isValidatingCurrentWorkspace: false,
}),
}))
const mockGlobalState = {
systemFeatures: { enable_marketplace: true },
}
let mockEnableMarketplace = true
const mockQuotaConfig = {
quota_type: CurrentSystemQuotaTypeEnum.free,
@@ -28,7 +19,11 @@ const mockQuotaConfig = {
}
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: (selector: (s: { systemFeatures: { enable_marketplace: boolean } }) => unknown) => selector(mockGlobalState),
useSystemFeaturesQuery: () => ({
data: {
enable_marketplace: mockEnableMarketplace,
},
}),
}))
const mockProviders = [
@@ -60,13 +55,16 @@ vi.mock('@/context/provider-context', () => ({
}),
}))
const mockDefaultModelState = {
data: null,
isLoading: false,
const mockDefaultModels: Record<string, { data: unknown, isLoading: boolean }> = {
'llm': { data: null, isLoading: false },
'text-embedding': { data: null, isLoading: false },
'rerank': { data: null, isLoading: false },
'speech2text': { data: null, isLoading: false },
'tts': { data: null, isLoading: false },
}
vi.mock('./hooks', () => ({
useDefaultModel: () => mockDefaultModelState,
useDefaultModel: (type: string) => mockDefaultModels[type] ?? { data: null, isLoading: false },
}))
vi.mock('./install-from-marketplace', () => ({
@@ -85,13 +83,18 @@ vi.mock('./system-model-selector', () => ({
default: () => <div data-testid="system-model-selector" />,
}))
vi.mock('@/service/use-plugins', () => ({
useCheckInstalled: () => ({ data: undefined }),
}))
describe('ModelProviderPage', () => {
beforeEach(() => {
vi.useFakeTimers()
vi.clearAllMocks()
mockGlobalState.systemFeatures.enable_marketplace = true
mockDefaultModelState.data = null
mockDefaultModelState.isLoading = false
mockEnableMarketplace = true
Object.keys(mockDefaultModels).forEach((key) => {
mockDefaultModels[key] = { data: null, isLoading: false }
})
mockProviders.splice(0, mockProviders.length, {
provider: 'openai',
label: { en_US: 'OpenAI' },
@@ -149,13 +152,76 @@ describe('ModelProviderPage', () => {
})
it('should hide marketplace section when marketplace feature is disabled', () => {
mockGlobalState.systemFeatures.enable_marketplace = false
mockEnableMarketplace = false
render(<ModelProviderPage searchText="" />)
expect(screen.queryByTestId('install-from-marketplace')).not.toBeInTheDocument()
})
describe('system model config status', () => {
it('should not show top warning when no configured providers exist (empty state card handles it)', () => {
mockProviders.splice(0, mockProviders.length, {
provider: 'anthropic',
label: { en_US: 'Anthropic' },
custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure },
system_configuration: {
enabled: false,
current_quota_type: CurrentSystemQuotaTypeEnum.free,
quota_configurations: [mockQuotaConfig],
},
})
render(<ModelProviderPage searchText="" />)
expect(screen.queryByText('common.modelProvider.noneConfigured')).not.toBeInTheDocument()
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
expect(screen.getByText('common.modelProvider.emptyProviderTitle')).toBeInTheDocument()
})
it('should show none-configured warning when providers exist but no default models set', () => {
render(<ModelProviderPage searchText="" />)
expect(screen.getByText('common.modelProvider.noneConfigured')).toBeInTheDocument()
})
it('should show partially-configured warning when some default models are set', () => {
mockDefaultModels.llm = {
data: { model: 'gpt-4', model_type: 'llm', provider: { provider: 'openai', icon_small: { en_US: '' } } },
isLoading: false,
}
render(<ModelProviderPage searchText="" />)
expect(screen.getByText('common.modelProvider.notConfigured')).toBeInTheDocument()
})
it('should not show warning when all default models are configured', () => {
const makeModel = (model: string, type: string) => ({
data: { model, model_type: type, provider: { provider: 'openai', icon_small: { en_US: '' } } },
isLoading: false,
})
mockDefaultModels.llm = makeModel('gpt-4', 'llm')
mockDefaultModels['text-embedding'] = makeModel('text-embedding-3', 'text-embedding')
mockDefaultModels.rerank = makeModel('rerank-v3', 'rerank')
mockDefaultModels.speech2text = makeModel('whisper-1', 'speech2text')
mockDefaultModels.tts = makeModel('tts-1', 'tts')
render(<ModelProviderPage searchText="" />)
expect(screen.queryByText('common.modelProvider.noProviderInstalled')).not.toBeInTheDocument()
expect(screen.queryByText('common.modelProvider.noneConfigured')).not.toBeInTheDocument()
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
})
it('should not show warning while loading', () => {
Object.keys(mockDefaultModels).forEach((key) => {
mockDefaultModels[key] = { data: null, isLoading: true }
})
render(<ModelProviderPage searchText="" />)
expect(screen.queryByText('common.modelProvider.noProviderInstalled')).not.toBeInTheDocument()
expect(screen.queryByText('common.modelProvider.noneConfigured')).not.toBeInTheDocument()
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
})
})
it('should prioritize fixed providers in visible order', () => {
mockProviders.splice(0, mockProviders.length, {
provider: 'zeta-provider',

View File

@@ -1,17 +1,14 @@
import type {
ModelProvider,
} from './declarations'
import {
RiAlertFill,
RiBrainLine,
} from '@remixicon/react'
import type { PluginDetail } from '@/app/components/plugins/types'
import { useDebounce } from 'ahooks'
import { useEffect, useMemo } from 'react'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { IS_CLOUD_EDITION } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useSystemFeaturesQuery } from '@/context/global-public-context'
import { useProviderContext } from '@/context/provider-context'
import { useCheckInstalled } from '@/service/use-plugins'
import { cn } from '@/utils/classnames'
import {
CustomConfigurationStatusEnum,
@@ -24,6 +21,9 @@ import InstallFromMarketplace from './install-from-marketplace'
import ProviderAddedCard from './provider-added-card'
import QuotaPanel from './provider-added-card/quota-panel'
import SystemModelSelector from './system-model-selector'
import { providerToPluginId } from './utils'
type SystemModelConfigStatus = 'no-provider' | 'none-configured' | 'partially-configured' | 'fully-configured'
type Props = {
searchText: string
@@ -34,20 +34,35 @@ const FixedModelProvider = ['langgenius/openai/openai', 'langgenius/anthropic/an
const ModelProviderPage = ({ searchText }: Props) => {
const debouncedSearchText = useDebounce(searchText, { wait: 500 })
const { t } = useTranslation()
const { mutateCurrentWorkspace, isValidatingCurrentWorkspace } = useAppContext()
const { data: textGenerationDefaultModel, isLoading: isTextGenerationDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textGeneration)
const { data: embeddingsDefaultModel, isLoading: isEmbeddingsDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textEmbedding)
const { data: rerankDefaultModel, isLoading: isRerankDefaultModelLoading } = useDefaultModel(ModelTypeEnum.rerank)
const { data: speech2textDefaultModel, isLoading: isSpeech2textDefaultModelLoading } = useDefaultModel(ModelTypeEnum.speech2text)
const { data: ttsDefaultModel, isLoading: isTTSDefaultModelLoading } = useDefaultModel(ModelTypeEnum.tts)
const { modelProviders: providers } = useProviderContext()
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
const { data: systemFeatures } = useSystemFeaturesQuery()
const allPluginIds = useMemo(() => {
return [...new Set(providers.map(p => providerToPluginId(p.provider)).filter(Boolean))]
}, [providers])
const { data: installedPlugins } = useCheckInstalled({
pluginIds: allPluginIds,
enabled: allPluginIds.length > 0,
})
const pluginDetailMap = useMemo(() => {
const map = new Map<string, PluginDetail>()
if (installedPlugins?.plugins) {
for (const plugin of installedPlugins.plugins)
map.set(plugin.plugin_id, plugin)
}
return map
}, [installedPlugins])
const enableMarketplace = systemFeatures?.enable_marketplace ?? false
const isDefaultModelLoading = isTextGenerationDefaultModelLoading
|| isEmbeddingsDefaultModelLoading
|| isRerankDefaultModelLoading
|| isSpeech2textDefaultModelLoading
|| isTTSDefaultModelLoading
const defaultModelNotConfigured = !isDefaultModelLoading && !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel && !ttsDefaultModel
const [configuredProviders, notConfiguredProviders] = useMemo(() => {
const configuredProviders: ModelProvider[] = []
const notConfiguredProviders: ModelProvider[] = []
@@ -79,6 +94,26 @@ const ModelProviderPage = ({ searchText }: Props) => {
return [configuredProviders, notConfiguredProviders]
}, [providers])
const systemModelConfigStatus: SystemModelConfigStatus = useMemo(() => {
const defaultModels = [textGenerationDefaultModel, embeddingsDefaultModel, rerankDefaultModel, speech2textDefaultModel, ttsDefaultModel]
const configuredCount = defaultModels.filter(Boolean).length
if (configuredCount === 0 && configuredProviders.length === 0)
return 'no-provider'
if (configuredCount === 0)
return 'none-configured'
if (configuredCount < defaultModels.length)
return 'partially-configured'
return 'fully-configured'
}, [configuredProviders, textGenerationDefaultModel, embeddingsDefaultModel, rerankDefaultModel, speech2textDefaultModel, ttsDefaultModel])
const warningTextKey
= systemModelConfigStatus === 'none-configured'
? 'modelProvider.noneConfigured'
: systemModelConfigStatus === 'partially-configured'
? 'modelProvider.notConfigured'
: null
const showWarning = !isDefaultModelLoading && !!warningTextKey
const [filteredConfiguredProviders, filteredNotConfiguredProviders] = useMemo(() => {
const filteredConfiguredProviders = configuredProviders.filter(
provider => provider.provider.toLowerCase().includes(debouncedSearchText.toLowerCase())
@@ -92,28 +127,24 @@ const ModelProviderPage = ({ searchText }: Props) => {
return [filteredConfiguredProviders, filteredNotConfiguredProviders]
}, [configuredProviders, debouncedSearchText, notConfiguredProviders])
useEffect(() => {
mutateCurrentWorkspace()
}, [mutateCurrentWorkspace])
return (
<div className="relative -mt-2 pt-1">
<div className={cn('mb-2 flex items-center')}>
<div className="system-md-semibold grow text-text-primary">{t('modelProvider.models', { ns: 'common' })}</div>
<div className="grow text-text-primary system-md-semibold">{t('modelProvider.models', { ns: 'common' })}</div>
<div className={cn(
'relative flex shrink-0 items-center justify-end gap-2 rounded-lg border border-transparent p-px',
defaultModelNotConfigured && 'border-components-panel-border bg-components-panel-bg-blur pl-2 shadow-xs',
showWarning && 'border-components-panel-border bg-components-panel-bg-blur pl-2 shadow-xs',
)}
>
{defaultModelNotConfigured && <div className="absolute bottom-0 left-0 right-0 top-0 opacity-40" style={{ background: 'linear-gradient(92deg, rgba(247, 144, 9, 0.25) 0%, rgba(255, 255, 255, 0.00) 100%)' }} />}
{defaultModelNotConfigured && (
<div className="system-xs-medium flex items-center gap-1 text-text-primary">
<RiAlertFill className="h-4 w-4 text-text-warning-secondary" />
<span className="max-w-[460px] truncate" title={t('modelProvider.notConfigured', { ns: 'common' })}>{t('modelProvider.notConfigured', { ns: 'common' })}</span>
{showWarning && <div className="absolute bottom-0 left-0 right-0 top-0 opacity-40" style={{ background: 'linear-gradient(92deg, rgba(247, 144, 9, 0.25) 0%, rgba(255, 255, 255, 0.00) 100%)' }} />}
{showWarning && (
<div className="flex items-center gap-1 text-text-primary system-xs-medium">
<span className="i-ri-alert-fill h-4 w-4 text-text-warning-secondary" />
<span className="max-w-[460px] truncate" title={t(warningTextKey, { ns: 'common' })}>{t(warningTextKey, { ns: 'common' })}</span>
</div>
)}
<SystemModelSelector
notConfigured={defaultModelNotConfigured}
notConfigured={showWarning}
textGenerationDefaultModel={textGenerationDefaultModel}
embeddingsDefaultModel={embeddingsDefaultModel}
rerankDefaultModel={rerankDefaultModel}
@@ -123,14 +154,14 @@ const ModelProviderPage = ({ searchText }: Props) => {
/>
</div>
</div>
{IS_CLOUD_EDITION && <QuotaPanel providers={providers} isLoading={isValidatingCurrentWorkspace} />}
{IS_CLOUD_EDITION && <QuotaPanel providers={providers} />}
{!filteredConfiguredProviders?.length && (
<div className="mb-2 rounded-[10px] bg-workflow-process-bg p-4">
<div className="flex h-10 w-10 items-center justify-center rounded-[10px] border-[0.5px] border-components-card-border bg-components-card-bg shadow-lg backdrop-blur">
<RiBrainLine className="h-5 w-5 text-text-primary" />
<span className="i-ri-brain-line h-5 w-5 text-text-primary" />
</div>
<div className="system-sm-medium mt-2 text-text-secondary">{t('modelProvider.emptyProviderTitle', { ns: 'common' })}</div>
<div className="system-xs-regular mt-1 text-text-tertiary">{t('modelProvider.emptyProviderTip', { ns: 'common' })}</div>
<div className="mt-2 text-text-secondary system-sm-medium">{t('modelProvider.emptyProviderTitle', { ns: 'common' })}</div>
<div className="mt-1 text-text-tertiary system-xs-regular">{t('modelProvider.emptyProviderTip', { ns: 'common' })}</div>
</div>
)}
{!!filteredConfiguredProviders?.length && (
@@ -139,26 +170,28 @@ const ModelProviderPage = ({ searchText }: Props) => {
<ProviderAddedCard
key={provider.provider}
provider={provider}
pluginDetail={pluginDetailMap.get(providerToPluginId(provider.provider))}
/>
))}
</div>
)}
{!!filteredNotConfiguredProviders?.length && (
<>
<div className="system-md-semibold mb-2 flex items-center pt-2 text-text-primary">{t('modelProvider.toBeConfigured', { ns: 'common' })}</div>
<div className="mb-2 flex items-center pt-2 text-text-primary system-md-semibold">{t('modelProvider.toBeConfigured', { ns: 'common' })}</div>
<div className="relative">
{filteredNotConfiguredProviders?.map(provider => (
<ProviderAddedCard
notConfigured
key={provider.provider}
provider={provider}
pluginDetail={pluginDetailMap.get(providerToPluginId(provider.provider))}
/>
))}
</div>
</>
)}
{
enable_marketplace && (
enableMarketplace && (
<InstallFromMarketplace
providers={providers}
searchText={searchText}

View File

@@ -2,12 +2,6 @@ import type { Credential } from '../../declarations'
import { fireEvent, render, screen } from '@testing-library/react'
import CredentialItem from './credential-item'
vi.mock('@remixicon/react', () => ({
RiCheckLine: () => <div data-testid="check-icon" />,
RiDeleteBinLine: () => <div data-testid="delete-icon" />,
RiEqualizer2Line: () => <div data-testid="edit-icon" />,
}))
vi.mock('@/app/components/header/indicator', () => ({
default: () => <div data-testid="indicator" />,
}))
@@ -61,8 +55,12 @@ describe('CredentialItem', () => {
render(<CredentialItem credential={credential} onEdit={onEdit} onDelete={onDelete} />)
fireEvent.click(screen.getByTestId('edit-icon').closest('button') as HTMLButtonElement)
fireEvent.click(screen.getByTestId('delete-icon').closest('button') as HTMLButtonElement)
const buttons = screen.getAllByRole('button')
const editButton = buttons.find(b => b.querySelector('.i-ri-equalizer-2-line'))!
const deleteButton = buttons.find(b => b.querySelector('.i-ri-delete-bin-line'))!
fireEvent.click(editButton)
fireEvent.click(deleteButton)
expect(onEdit).toHaveBeenCalledWith(credential)
expect(onDelete).toHaveBeenCalledWith(credential)
@@ -81,7 +79,10 @@ describe('CredentialItem', () => {
/>,
)
fireEvent.click(screen.getByTestId('delete-icon').closest('button') as HTMLButtonElement)
const deleteButton = screen.getAllByRole('button')
.find(b => b.querySelector('.i-ri-delete-bin-line'))!
fireEvent.click(deleteButton)
expect(onDelete).not.toHaveBeenCalled()
})

View File

@@ -1,9 +1,4 @@
import type { Credential } from '../../declarations'
import {
RiCheckLine,
RiDeleteBinLine,
RiEqualizer2Line,
} from '@remixicon/react'
import {
memo,
useMemo,
@@ -11,7 +6,7 @@ import {
import { useTranslation } from 'react-i18next'
import ActionButton from '@/app/components/base/action-button'
import Badge from '@/app/components/base/badge'
import Tooltip from '@/app/components/base/tooltip'
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
import Indicator from '@/app/components/header/indicator'
import { cn } from '@/utils/classnames'
@@ -56,7 +51,7 @@ const CredentialItem = ({
key={credential.credential_id}
className={cn(
'group flex h-8 items-center rounded-lg p-1 hover:bg-state-base-hover',
(disabled || credential.not_allowed_to_use) && 'cursor-not-allowed opacity-50',
(disabled || credential.not_allowed_to_use) ? 'cursor-not-allowed opacity-50' : onItemClick && 'cursor-pointer',
)}
onClick={() => {
if (disabled || credential.not_allowed_to_use)
@@ -70,7 +65,7 @@ const CredentialItem = ({
<div className="h-4 w-4">
{
selectedCredentialId === credential.credential_id && (
<RiCheckLine className="h-4 w-4 text-text-accent" />
<span className="i-ri-check-line h-4 w-4 text-text-accent" />
)
}
</div>
@@ -78,7 +73,7 @@ const CredentialItem = ({
}
<Indicator className="ml-2 mr-1.5 shrink-0" />
<div
className="system-md-regular truncate text-text-secondary"
className="truncate text-text-secondary system-md-regular"
title={credential.credential_name}
>
{credential.credential_name}
@@ -96,38 +91,50 @@ const CredentialItem = ({
<div className="ml-2 hidden shrink-0 items-center group-hover:flex">
{
!disableEdit && !credential.not_allowed_to_use && (
<Tooltip popupContent={t('operation.edit', { ns: 'common' })}>
<ActionButton
disabled={disabled}
onClick={(e) => {
e.stopPropagation()
onEdit?.(credential)
}}
>
<RiEqualizer2Line className="h-4 w-4 text-text-tertiary" />
</ActionButton>
<Tooltip>
<TooltipTrigger
render={(
<ActionButton
disabled={disabled}
onClick={(e) => {
e.stopPropagation()
onEdit?.(credential)
}}
>
<span className="i-ri-equalizer-2-line h-4 w-4 text-text-tertiary" />
</ActionButton>
)}
/>
<TooltipContent>{t('operation.edit', { ns: 'common' })}</TooltipContent>
</Tooltip>
)
}
{
!disableDelete && (
<Tooltip popupContent={disableDeleteWhenSelected ? disableDeleteTip : t('operation.delete', { ns: 'common' })}>
<ActionButton
className="hover:bg-transparent"
onClick={(e) => {
if (disabled || disableDeleteWhenSelected)
return
e.stopPropagation()
onDelete?.(credential)
}}
>
<RiDeleteBinLine className={cn(
'h-4 w-4 text-text-tertiary',
!disableDeleteWhenSelected && 'hover:text-text-destructive',
disableDeleteWhenSelected && 'opacity-50',
<Tooltip>
<TooltipTrigger
render={(
<ActionButton
className="hover:bg-transparent"
onClick={(e) => {
if (disabled || disableDeleteWhenSelected)
return
e.stopPropagation()
onDelete?.(credential)
}}
>
<span className={cn(
'i-ri-delete-bin-line h-4 w-4 text-text-tertiary',
!disableDeleteWhenSelected && 'hover:text-text-destructive',
disableDeleteWhenSelected && 'opacity-50',
)}
/>
</ActionButton>
)}
/>
</ActionButton>
/>
<TooltipContent>
{disableDeleteWhenSelected ? disableDeleteTip : t('operation.delete', { ns: 'common' })}
</TooltipContent>
</Tooltip>
)
}
@@ -139,8 +146,9 @@ const CredentialItem = ({
if (credential.not_allowed_to_use) {
return (
<Tooltip popupContent={t('auth.customCredentialUnavailable', { ns: 'plugin' })}>
{Item}
<Tooltip>
<TooltipTrigger render={Item} />
<TooltipContent>{t('auth.customCredentialUnavailable', { ns: 'plugin' })}</TooltipContent>
</Tooltip>
)
}

View File

@@ -10,7 +10,7 @@ const ModelBadge: FC<ModelBadgeProps> = ({
children,
}) => {
return (
<div className={cn('system-2xs-medium-uppercase flex h-[18px] cursor-default items-center rounded-[5px] border border-divider-deep px-1 text-text-tertiary', className)}>
<div className={cn('inline-flex h-[18px] shrink-0 items-center justify-center whitespace-nowrap rounded-[5px] border border-divider-deep bg-components-badge-bg-dimm px-[5px] text-text-tertiary system-2xs-medium-uppercase', className)}>
{children}
</div>
)

View File

@@ -1,5 +1,5 @@
import type { Credential, CredentialFormSchema, ModelProvider } from '../declarations'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { fireEvent, render, screen, waitFor, within } from '@testing-library/react'
import {
ConfigurationMethodEnum,
CurrentSystemQuotaTypeEnum,
@@ -243,9 +243,10 @@ describe('ModelModal', () => {
const credential: Credential = { credential_id: 'cred-1' }
const { onCancel } = renderModal({ credential })
expect(screen.getByText('common.modelProvider.confirmDelete')).toBeInTheDocument()
const alertDialog = screen.getByRole('alertdialog', { hidden: true })
expect(alertDialog).toHaveTextContent('common.modelProvider.confirmDelete')
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
fireEvent.click(within(alertDialog).getByRole('button', { hidden: true, name: 'common.operation.confirm' }))
expect(mockHandlers.handleConfirmDelete).toHaveBeenCalledTimes(1)
expect(onCancel).toHaveBeenCalledTimes(1)

View File

@@ -9,11 +9,9 @@ import type {
FormRefObject,
FormSchema,
} from '@/app/components/base/form/types'
import { RiCloseLine } from '@remixicon/react'
import {
memo,
useCallback,
useEffect,
useMemo,
useRef,
useState,
@@ -21,15 +19,23 @@ import {
import { useTranslation } from 'react-i18next'
import Badge from '@/app/components/base/badge'
import Button from '@/app/components/base/button'
import Confirm from '@/app/components/base/confirm'
import AuthForm from '@/app/components/base/form/form-scenarios/auth'
import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general'
import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security'
import Loading from '@/app/components/base/loading'
import {
PortalToFollowElem,
PortalToFollowElemContent,
} from '@/app/components/base/portal-to-follow-elem'
AlertDialog,
AlertDialogActions,
AlertDialogCancelButton,
AlertDialogConfirmButton,
AlertDialogContent,
AlertDialogTitle,
} from '@/app/components/base/ui/alert-dialog'
import {
Dialog,
DialogCloseButton,
DialogContent,
} from '@/app/components/base/ui/dialog'
import {
useAuth,
useCredentialData,
@@ -197,7 +203,7 @@ const ModelModal: FC<ModelModalProps> = ({
}
return (
<div className="title-2xl-semi-bold text-text-primary">
<div className="text-text-primary title-2xl-semi-bold">
{label}
</div>
)
@@ -206,7 +212,7 @@ const ModelModal: FC<ModelModalProps> = ({
const modalDesc = useMemo(() => {
if (providerFormSchemaPredefined) {
return (
<div className="system-xs-regular mt-1 text-text-tertiary">
<div className="mt-1 text-text-tertiary system-xs-regular">
{t('modelProvider.auth.apiKeyModal.desc', { ns: 'common' })}
</div>
)
@@ -223,7 +229,7 @@ const ModelModal: FC<ModelModalProps> = ({
className="mr-2 h-4 w-4 shrink-0"
provider={provider}
/>
<div className="system-md-regular mr-1 text-text-secondary">{renderI18nObject(provider.label)}</div>
<div className="mr-1 text-text-secondary system-md-regular">{renderI18nObject(provider.label)}</div>
</div>
)
}
@@ -235,7 +241,7 @@ const ModelModal: FC<ModelModalProps> = ({
provider={provider}
modelName={model.model}
/>
<div className="system-md-regular mr-1 text-text-secondary">{model.model}</div>
<div className="mr-1 text-text-secondary system-md-regular">{model.model}</div>
<Badge>{model.model_type}</Badge>
</div>
)
@@ -275,174 +281,171 @@ const ModelModal: FC<ModelModalProps> = ({
}, [])
const notAllowCustomCredential = provider.allow_custom_token === false
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === 'Escape') {
event.stopPropagation()
onCancel()
}
}
document.addEventListener('keydown', handleKeyDown, true)
return () => {
document.removeEventListener('keydown', handleKeyDown, true)
}
const handleOpenChange = useCallback((open: boolean) => {
if (!open)
onCancel()
}, [onCancel])
const handleConfirmOpenChange = useCallback((open: boolean) => {
if (!open)
closeConfirmDelete()
}, [closeConfirmDelete])
return (
<PortalToFollowElem open>
<PortalToFollowElemContent className="z-[60] h-full w-full">
<div className="fixed inset-0 flex items-center justify-center bg-black/[.25]">
<div className="relative w-[640px] rounded-2xl bg-components-panel-bg shadow-xl">
<div
className="absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center"
onClick={onCancel}
>
<RiCloseLine className="h-4 w-4 text-text-tertiary" />
</div>
<div className="p-6 pb-3">
{modalTitle}
{modalDesc}
{modalModel}
</div>
<div className="max-h-[calc(100vh-320px)] overflow-y-auto px-6 py-3">
{
mode === ModelModalModeEnum.configCustomModel && (
<AuthForm
formSchemas={modelNameAndTypeFormSchemas.map((formSchema) => {
return {
...formSchema,
name: formSchema.variable,
}
}) as FormSchema[]}
defaultValues={modelNameAndTypeFormValues}
inputClassName="justify-start"
ref={formRef1}
onChange={handleModelNameAndTypeChange}
/>
)
}
{
mode === ModelModalModeEnum.addCustomModelToModelList && (
<CredentialSelector
credentials={available_credentials || []}
onSelect={setSelectedCredential}
selectedCredential={selectedCredential}
disabled={isLoading}
notAllowAddNewCredential={notAllowCustomCredential}
/>
)
}
{
showCredentialLabel && (
<div className="system-xs-medium-uppercase mb-3 mt-6 flex items-center text-text-tertiary">
{t('modelProvider.auth.modelCredential', { ns: 'common' })}
<div className="ml-2 h-px grow bg-gradient-to-r from-divider-regular to-background-gradient-mask-transparent" />
</div>
)
}
{
isLoading && (
<div className="mt-3 flex items-center justify-center">
<Loading />
</div>
)
}
{
!isLoading
&& showCredentialForm
&& (
<AuthForm
formSchemas={formSchemas.map((formSchema) => {
return {
...formSchema,
name: formSchema.variable,
showRadioUI: formSchema.type === FormTypeEnum.radio,
}
}) as FormSchema[]}
defaultValues={formValues}
inputClassName="justify-start"
ref={formRef2}
/>
)
}
</div>
<div className="flex justify-between p-6 pt-5">
{
(provider.help && (provider.help.title || provider.help.url))
? (
<a
href={provider.help?.url[language] || provider.help?.url.en_US}
target="_blank"
rel="noopener noreferrer"
className="system-xs-regular mt-2 inline-block align-middle text-text-accent"
onClick={e => !provider.help.url && e.preventDefault()}
>
{provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US}
<LinkExternal02 className="ml-1 mt-[-2px] inline-block h-3 w-3" />
</a>
)
: <div />
}
<div className="ml-2 flex items-center justify-end space-x-2">
{
isEditMode && (
<Button
variant="warning"
onClick={() => openConfirmDelete(credential, model)}
>
{t('operation.remove', { ns: 'common' })}
</Button>
)
}
<Button
onClick={onCancel}
>
{t('operation.cancel', { ns: 'common' })}
</Button>
<Button
variant="primary"
onClick={handleSave}
disabled={isLoading || doingAction}
>
{saveButtonText}
</Button>
</div>
</div>
{
(mode === ModelModalModeEnum.configCustomModel || mode === ModelModalModeEnum.configProviderCredential) && (
<div className="border-t-[0.5px] border-t-divider-regular">
<div className="flex items-center justify-center rounded-b-2xl bg-background-section-burn py-3 text-xs text-text-tertiary">
<Lock01 className="mr-1 h-3 w-3 text-text-tertiary" />
{t('modelProvider.encrypted.front', { ns: 'common' })}
<a
className="mx-1 text-text-accent"
target="_blank"
rel="noopener noreferrer"
href="https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html"
>
PKCS1_OAEP
</a>
{t('modelProvider.encrypted.back', { ns: 'common' })}
</div>
</div>
)
}
</div>
<Dialog open onOpenChange={handleOpenChange}>
<DialogContent
backdropProps={{ forceRender: true }}
className="w-[640px] max-w-[640px] overflow-hidden p-0"
>
<DialogCloseButton className="right-5 top-5 h-8 w-8" />
<div className="p-6 pb-3">
{modalTitle}
{modalDesc}
{modalModel}
</div>
<div className="max-h-[calc(100vh-320px)] overflow-y-auto px-6 py-3">
{
deleteCredentialId && (
<Confirm
isShow
title={t('modelProvider.confirmDelete', { ns: 'common' })}
isDisabled={doingAction}
onCancel={closeConfirmDelete}
onConfirm={handleDeleteCredential}
mode === ModelModalModeEnum.configCustomModel && (
<AuthForm
formSchemas={modelNameAndTypeFormSchemas.map((formSchema) => {
return {
...formSchema,
name: formSchema.variable,
}
}) as FormSchema[]}
defaultValues={modelNameAndTypeFormValues}
inputClassName="justify-start"
ref={formRef1}
onChange={handleModelNameAndTypeChange}
/>
)
}
{
mode === ModelModalModeEnum.addCustomModelToModelList && (
<CredentialSelector
credentials={available_credentials || []}
onSelect={setSelectedCredential}
selectedCredential={selectedCredential}
disabled={isLoading}
notAllowAddNewCredential={notAllowCustomCredential}
/>
)
}
{
showCredentialLabel && (
<div className="mb-3 mt-6 flex items-center text-text-tertiary system-xs-medium-uppercase">
{t('modelProvider.auth.modelCredential', { ns: 'common' })}
<div className="ml-2 h-px grow bg-gradient-to-r from-divider-regular to-background-gradient-mask-transparent" />
</div>
)
}
{
isLoading && (
<div className="mt-3 flex items-center justify-center">
<Loading />
</div>
)
}
{
!isLoading
&& showCredentialForm
&& (
<AuthForm
formSchemas={formSchemas.map((formSchema) => {
return {
...formSchema,
name: formSchema.variable,
showRadioUI: formSchema.type === FormTypeEnum.radio,
}
}) as FormSchema[]}
defaultValues={formValues}
inputClassName="justify-start"
ref={formRef2}
/>
)
}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
<div className="flex justify-between p-6 pt-5">
{
(provider.help && (provider.help.title || provider.help.url))
? (
<a
href={provider.help?.url[language] || provider.help?.url.en_US}
target="_blank"
rel="noopener noreferrer"
className="mt-2 inline-block align-middle text-text-accent system-xs-regular"
onClick={e => !provider.help.url && e.preventDefault()}
>
{provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US}
<LinkExternal02 className="ml-1 mt-[-2px] inline-block h-3 w-3" />
</a>
)
: <div />
}
<div className="ml-2 flex items-center justify-end space-x-2">
{
isEditMode && (
<Button
variant="warning"
onClick={() => openConfirmDelete(credential, model)}
>
{t('operation.remove', { ns: 'common' })}
</Button>
)
}
<Button
onClick={onCancel}
>
{t('operation.cancel', { ns: 'common' })}
</Button>
<Button
variant="primary"
onClick={handleSave}
disabled={isLoading || doingAction}
>
{saveButtonText}
</Button>
</div>
</div>
{
(mode === ModelModalModeEnum.configCustomModel || mode === ModelModalModeEnum.configProviderCredential) && (
<div className="border-t-[0.5px] border-t-divider-regular">
<div className="flex items-center justify-center rounded-b-2xl bg-background-section-burn py-3 text-xs text-text-tertiary">
<Lock01 className="mr-1 h-3 w-3 text-text-tertiary" />
{t('modelProvider.encrypted.front', { ns: 'common' })}
<a
className="mx-1 text-text-accent"
target="_blank"
rel="noopener noreferrer"
href="https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html"
>
PKCS1_OAEP
</a>
{t('modelProvider.encrypted.back', { ns: 'common' })}
</div>
</div>
)
}
</DialogContent>
<AlertDialog open={!!deleteCredentialId} onOpenChange={handleConfirmOpenChange}>
<AlertDialogContent backdropProps={{ forceRender: true }}>
<div className="flex flex-col gap-2 p-6 pb-4">
<AlertDialogTitle className="text-text-primary title-2xl-semi-bold">
{t('modelProvider.confirmDelete', { ns: 'common' })}
</AlertDialogTitle>
</div>
<AlertDialogActions>
<AlertDialogCancelButton>{t('operation.cancel', { ns: 'common' })}</AlertDialogCancelButton>
<AlertDialogConfirmButton
disabled={doingAction}
onClick={handleDeleteCredential}
>
{t('operation.confirm', { ns: 'common' })}
</AlertDialogConfirmButton>
</AlertDialogActions>
</AlertDialogContent>
</AlertDialog>
</Dialog>
)
}

View File

@@ -1,7 +1,5 @@
import type { FC } from 'react'
import { RiEqualizer2Line } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { CubeOutline } from '@/app/components/base/icons/src/vender/line/shapes'
import { cn } from '@/utils/classnames'
type ModelTriggerProps = {
@@ -16,24 +14,26 @@ const ModelTrigger: FC<ModelTriggerProps> = ({
return (
<div
className={cn(
'flex cursor-pointer items-center gap-0.5 rounded-lg bg-components-input-bg-normal p-1 hover:bg-components-input-bg-hover',
'group flex h-8 cursor-pointer items-center gap-0.5 rounded-lg bg-components-input-bg-normal p-1 hover:bg-components-input-bg-hover',
open && 'bg-components-input-bg-hover',
className,
)}
>
<div className="flex grow items-center">
<div className="mr-1.5 flex h-4 w-4 items-center justify-center rounded-[5px] border border-dashed border-divider-regular">
<CubeOutline className="h-3 w-3 text-text-quaternary" />
<div className="flex h-6 w-6 items-center justify-center">
<div className="flex h-5 w-5 items-center justify-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-background-default-subtle">
<span className="i-ri-brain-2-line h-3.5 w-3.5 text-text-quaternary" />
</div>
</div>
<div className="flex grow items-center gap-1 truncate px-1 py-[3px]">
<div
className="truncate text-[13px] text-text-tertiary"
className="grow truncate text-[13px] text-text-quaternary"
title="Configure model"
>
{t('detailPanel.configureModel', { ns: 'plugin' })}
</div>
</div>
<div className="flex h-4 w-4 shrink-0 items-center justify-center">
<RiEqualizer2Line className="h-3.5 w-3.5 text-text-tertiary" />
<div className="flex h-4 w-4 shrink-0 items-center justify-center">
<span className="i-ri-arrow-down-s-line h-3.5 w-3.5 text-text-tertiary" />
</div>
</div>
</div>
)

View File

@@ -1,17 +0,0 @@
import { fireEvent, render, screen } from '@testing-library/react'
import AddModelButton from './add-model-button'
describe('AddModelButton', () => {
it('should render button with text', () => {
render(<AddModelButton onClick={vi.fn()} />)
expect(screen.getByText('common.modelProvider.addModel')).toBeInTheDocument()
})
it('should call onClick when clicked', () => {
const handleClick = vi.fn()
render(<AddModelButton onClick={handleClick} />)
const button = screen.getByText('common.modelProvider.addModel')
fireEvent.click(button)
expect(handleClick).toHaveBeenCalledTimes(1)
})
})

View File

@@ -1,27 +0,0 @@
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'
import { PlusCircle } from '@/app/components/base/icons/src/vender/solid/general'
import { cn } from '@/utils/classnames'
type AddModelButtonProps = {
className?: string
onClick: () => void
}
const AddModelButton: FC<AddModelButtonProps> = ({
className,
onClick,
}) => {
const { t } = useTranslation()
return (
<span
className={cn('system-xs-medium flex h-6 shrink-0 cursor-pointer items-center rounded-md px-1.5 text-text-tertiary hover:bg-components-button-ghost-bg-hover hover:text-components-button-ghost-text', className)}
onClick={onClick}
>
<PlusCircle className="mr-1 h-3 w-3" />
{t('modelProvider.addModel', { ns: 'common' })}
</span>
)
}
export default AddModelButton

View File

@@ -1,51 +1,54 @@
import type { ModelProvider } from '../declarations'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { changeModelProviderPriority } from '@/service/common'
import { ConfigurationMethodEnum } from '../declarations'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import {
ConfigurationMethodEnum,
CurrentSystemQuotaTypeEnum,
CustomConfigurationStatusEnum,
PreferredProviderTypeEnum,
} from '../declarations'
import CredentialPanel from './credential-panel'
const mockEventEmitter = { emit: vi.fn() }
const mockNotify = vi.fn()
const mockUpdateModelList = vi.fn()
const mockUpdateModelProviders = vi.fn()
const mockCredentialStatus = {
hasCredential: true,
authorized: true,
authRemoved: false,
current_credential_name: 'test-credential',
notAllowedToUse: false,
}
const {
mockToastNotify,
mockUpdateModelList,
mockUpdateModelProviders,
mockTrialCredits,
mockChangePriorityFn,
} = vi.hoisted(() => ({
mockToastNotify: vi.fn(),
mockUpdateModelList: vi.fn(),
mockUpdateModelProviders: vi.fn(),
mockTrialCredits: { credits: 100, totalCredits: 10_000, isExhausted: false, isLoading: false, nextCreditResetDate: undefined },
mockChangePriorityFn: vi.fn().mockResolvedValue({ result: 'success' }),
}))
vi.mock('@/config', async (importOriginal) => {
const actual = await importOriginal<typeof import('@/config')>()
return {
...actual,
IS_CLOUD_EDITION: true,
}
return { ...actual, IS_CLOUD_EDITION: true }
})
vi.mock('@/app/components/base/toast', () => ({
useToastContext: () => ({
notify: mockNotify,
}),
default: { notify: mockToastNotify },
}))
vi.mock('@/context/event-emitter', () => ({
useEventEmitterContextContext: () => ({
eventEmitter: mockEventEmitter,
}),
}))
vi.mock('@/service/common', () => ({
changeModelProviderPriority: vi.fn(),
}))
vi.mock('@/app/components/header/account-setting/model-provider-page/model-auth', () => ({
ConfigProvider: () => <div data-testid="config-provider" />,
}))
vi.mock('@/app/components/header/account-setting/model-provider-page/model-auth/hooks', () => ({
useCredentialStatus: () => mockCredentialStatus,
vi.mock('@/service/client', () => ({
consoleQuery: {
modelProviders: {
models: {
queryKey: ({ input }: { input: { params: { provider: string } } }) => ['console', 'modelProviders', 'models', input.params.provider],
},
changePreferredProviderType: {
mutationOptions: (opts: Record<string, unknown>) => ({
mutationFn: (...args: unknown[]) => {
mockChangePriorityFn(...args)
return Promise.resolve({ result: 'success' })
},
...opts,
}),
},
},
},
}))
vi.mock('../hooks', () => ({
@@ -53,93 +56,375 @@ vi.mock('../hooks', () => ({
useUpdateModelProviders: () => mockUpdateModelProviders,
}))
vi.mock('./priority-selector', () => ({
default: ({ value, onSelect }: { value: string, onSelect: (key: string) => void }) => (
<button data-testid="priority-selector" onClick={() => onSelect('custom')}>
Priority Selector
{' '}
{value}
</button>
vi.mock('./use-trial-credits', () => ({
useTrialCredits: () => mockTrialCredits,
}))
vi.mock('./model-auth-dropdown', () => ({
default: ({ state, onChangePriority }: { state: { variant: string, hasCredentials: boolean }, onChangePriority: (key: string) => void }) => (
<div data-testid="model-auth-dropdown" data-variant={state.variant}>
<button data-testid="change-priority-btn" onClick={() => onChangePriority('custom')}>
Change Priority
</button>
</div>
),
}))
vi.mock('./priority-use-tip', () => ({
default: () => <div data-testid="priority-use-tip">Priority Tip</div>,
vi.mock('@/app/components/header/indicator', () => ({
default: ({ color }: { color: string }) => <div data-testid="indicator" data-color={color} />,
}))
vi.mock('@/app/components/header/indicator', () => ({
default: ({ color }: { color: string }) => <div data-testid="indicator">{color}</div>,
vi.mock('@/app/components/base/icons/src/vender/line/alertsAndFeedback/Warning', () => ({
default: (props: Record<string, unknown>) => <div data-testid="warning-icon" className={props.className as string} />,
}))
const createTestQueryClient = () => new QueryClient({
defaultOptions: {
queries: { retry: false, gcTime: 0 },
mutations: { retry: false },
},
})
const createProvider = (overrides: Partial<ModelProvider> = {}): ModelProvider => ({
provider: 'test-provider',
provider_credential_schema: { credential_form_schemas: [] },
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
current_credential_id: 'cred-1',
current_credential_name: 'test-credential',
available_credentials: [{ credential_id: 'cred-1', credential_name: 'test-credential' }],
},
system_configuration: { enabled: true, current_quota_type: 'trial', quota_configurations: [] },
preferred_provider_type: PreferredProviderTypeEnum.system,
configurate_methods: [ConfigurationMethodEnum.predefinedModel],
supported_model_types: ['llm'],
...overrides,
} as unknown as ModelProvider)
const renderWithQueryClient = (provider: ModelProvider) => {
const queryClient = createTestQueryClient()
return render(
<QueryClientProvider client={queryClient}>
<CredentialPanel provider={provider} />
</QueryClientProvider>,
)
}
describe('CredentialPanel', () => {
const mockProvider: ModelProvider = {
provider: 'test-provider',
provider_credential_schema: true,
custom_configuration: { status: 'active' },
system_configuration: { enabled: true },
preferred_provider_type: 'system',
configurate_methods: [ConfigurationMethodEnum.predefinedModel],
supported_model_types: ['gpt-4'],
} as unknown as ModelProvider
beforeEach(() => {
vi.clearAllMocks()
Object.assign(mockCredentialStatus, {
hasCredential: true,
authorized: true,
authRemoved: false,
current_credential_name: 'test-credential',
notAllowedToUse: false,
Object.assign(mockTrialCredits, { credits: 100, totalCredits: 10_000, isExhausted: false, isLoading: false })
})
describe('Text label variants', () => {
it('should show "AI credits in use" for credits-active variant', () => {
renderWithQueryClient(createProvider())
expect(screen.getByText(/aiCreditsInUse/)).toBeInTheDocument()
})
it('should show "Credits exhausted" for credits-exhausted variant (no credentials)', () => {
mockTrialCredits.isExhausted = true
mockTrialCredits.credits = 0
renderWithQueryClient(createProvider({
custom_configuration: {
status: CustomConfigurationStatusEnum.noConfigure,
available_credentials: [],
},
}))
expect(screen.getByText(/quotaExhausted/)).toBeInTheDocument()
})
it('should show "No available usage" for no-usage variant (exhausted + credential unauthorized)', () => {
mockTrialCredits.isExhausted = true
renderWithQueryClient(createProvider({
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
current_credential_id: undefined,
current_credential_name: undefined,
available_credentials: [{ credential_id: 'cred-1' }],
},
}))
expect(screen.getByText(/noAvailableUsage/)).toBeInTheDocument()
})
it('should show "AI credits in use" with warning for credits-fallback (custom priority, no credentials, credits available)', () => {
renderWithQueryClient(createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
custom_configuration: {
status: CustomConfigurationStatusEnum.noConfigure,
available_credentials: [],
},
}))
expect(screen.getByText(/aiCreditsInUse/)).toBeInTheDocument()
})
it('should show "AI credits in use" with warning for credits-fallback (custom priority, credential unauthorized, credits available)', () => {
renderWithQueryClient(createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
current_credential_id: undefined,
current_credential_name: undefined,
available_credentials: [{ credential_id: 'cred-1' }],
},
}))
expect(screen.getByText(/aiCreditsInUse/)).toBeInTheDocument()
})
it('should show warning icon for credits-fallback variant', () => {
renderWithQueryClient(createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
custom_configuration: {
status: CustomConfigurationStatusEnum.noConfigure,
available_credentials: [],
},
}))
expect(screen.getByTestId('warning-icon')).toBeInTheDocument()
})
})
it('should show credential name and configuration actions', () => {
render(<CredentialPanel provider={mockProvider} />)
describe('Status label variants', () => {
it('should show green indicator and credential name for api-fallback (exhausted + authorized key)', () => {
mockTrialCredits.isExhausted = true
renderWithQueryClient(createProvider())
expect(screen.getByTestId('indicator')).toHaveAttribute('data-color', 'green')
expect(screen.getByText('test-credential')).toBeInTheDocument()
})
expect(screen.getByText('test-credential')).toBeInTheDocument()
expect(screen.getByTestId('config-provider')).toBeInTheDocument()
expect(screen.getByTestId('priority-selector')).toBeInTheDocument()
})
it('should show warning icon for api-fallback variant', () => {
mockTrialCredits.isExhausted = true
renderWithQueryClient(createProvider())
expect(screen.getByTestId('warning-icon')).toBeInTheDocument()
})
it('should show unauthorized status label when credential is missing', () => {
mockCredentialStatus.hasCredential = false
render(<CredentialPanel provider={mockProvider} />)
it('should show green indicator for api-active (custom priority + authorized)', () => {
renderWithQueryClient(createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
}))
expect(screen.getByTestId('indicator')).toHaveAttribute('data-color', 'green')
expect(screen.getByText('test-credential')).toBeInTheDocument()
})
expect(screen.getByText(/modelProvider\.auth\.unAuthorized/)).toBeInTheDocument()
})
it('should NOT show warning icon for api-active variant', () => {
renderWithQueryClient(createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
}))
expect(screen.queryByTestId('warning-icon')).not.toBeInTheDocument()
})
it('should show removed credential label and priority tip for custom preference', () => {
mockCredentialStatus.authorized = false
mockCredentialStatus.authRemoved = true
render(<CredentialPanel provider={{ ...mockProvider, preferred_provider_type: 'custom' } as ModelProvider} />)
expect(screen.getByText(/modelProvider\.auth\.authRemoved/)).toBeInTheDocument()
expect(screen.getByTestId('priority-use-tip')).toBeInTheDocument()
})
it('should change priority and refresh related data after success', async () => {
const mockChangePriority = changeModelProviderPriority as ReturnType<typeof vi.fn>
mockChangePriority.mockResolvedValue({ result: 'success' })
render(<CredentialPanel provider={mockProvider} />)
fireEvent.click(screen.getByTestId('priority-selector'))
await waitFor(() => {
expect(mockChangePriority).toHaveBeenCalled()
expect(mockNotify).toHaveBeenCalled()
expect(mockUpdateModelProviders).toHaveBeenCalled()
expect(mockUpdateModelList).toHaveBeenCalledWith('gpt-4')
expect(mockEventEmitter.emit).toHaveBeenCalled()
it('should show red indicator and "Unavailable" for api-unavailable (exhausted + named unauthorized key)', () => {
mockTrialCredits.isExhausted = true
renderWithQueryClient(createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
current_credential_id: undefined,
current_credential_name: 'Bad Key',
available_credentials: [{ credential_id: 'cred-1', credential_name: 'Bad Key' }],
},
}))
expect(screen.getByTestId('indicator')).toHaveAttribute('data-color', 'red')
expect(screen.getByText(/unavailable/i)).toBeInTheDocument()
expect(screen.getByText('Bad Key')).toBeInTheDocument()
})
})
it('should render standalone priority selector without provider schema', () => {
const providerNoSchema = {
...mockProvider,
provider_credential_schema: null,
} as unknown as ModelProvider
render(<CredentialPanel provider={providerNoSchema} />)
expect(screen.getByTestId('priority-selector')).toBeInTheDocument()
expect(screen.queryByTestId('config-provider')).not.toBeInTheDocument()
describe('Destructive styling', () => {
it('should apply destructive container for credits-exhausted', () => {
mockTrialCredits.isExhausted = true
const { container } = renderWithQueryClient(createProvider({
custom_configuration: {
status: CustomConfigurationStatusEnum.noConfigure,
available_credentials: [],
},
}))
expect(container.querySelector('[class*="border-state-destructive"]')).toBeTruthy()
})
it('should apply destructive container for no-usage variant', () => {
mockTrialCredits.isExhausted = true
const { container } = renderWithQueryClient(createProvider({
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
current_credential_id: undefined,
current_credential_name: undefined,
available_credentials: [{ credential_id: 'cred-1' }],
},
}))
expect(container.querySelector('[class*="border-state-destructive"]')).toBeTruthy()
})
it('should apply destructive container for api-unavailable variant', () => {
mockTrialCredits.isExhausted = true
const { container } = renderWithQueryClient(createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
current_credential_id: undefined,
current_credential_name: 'Bad Key',
available_credentials: [{ credential_id: 'cred-1', credential_name: 'Bad Key' }],
},
}))
expect(container.querySelector('[class*="border-state-destructive"]')).toBeTruthy()
})
it('should apply default container for credits-active', () => {
const { container } = renderWithQueryClient(createProvider())
expect(container.querySelector('[class*="bg-white"]')).toBeTruthy()
})
it('should apply default container for api-active', () => {
const { container } = renderWithQueryClient(createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
}))
expect(container.querySelector('[class*="bg-white"]')).toBeTruthy()
})
it('should apply default container for api-fallback', () => {
mockTrialCredits.isExhausted = true
const { container } = renderWithQueryClient(createProvider())
expect(container.querySelector('[class*="bg-white"]')).toBeTruthy()
})
})
describe('Text color', () => {
it('should use destructive text color for credits-exhausted label', () => {
mockTrialCredits.isExhausted = true
const { container } = renderWithQueryClient(createProvider({
custom_configuration: {
status: CustomConfigurationStatusEnum.noConfigure,
available_credentials: [],
},
}))
expect(container.querySelector('.text-text-destructive')).toBeTruthy()
})
it('should use secondary text color for credits-active label', () => {
const { container } = renderWithQueryClient(createProvider())
expect(container.querySelector('.text-text-secondary')).toBeTruthy()
})
})
describe('Priority change', () => {
it('should call mutation with correct params on priority change', async () => {
renderWithQueryClient(createProvider())
await act(async () => {
fireEvent.click(screen.getByTestId('change-priority-btn'))
})
await waitFor(() => {
expect(mockChangePriorityFn.mock.calls[0]?.[0]).toEqual({
params: { provider: 'test-provider' },
body: { preferred_provider_type: 'custom' },
})
})
})
it('should show success toast and refresh data after successful mutation', async () => {
renderWithQueryClient(createProvider())
await act(async () => {
fireEvent.click(screen.getByTestId('change-priority-btn'))
})
await waitFor(() => {
expect(mockToastNotify).toHaveBeenCalledWith(
expect.objectContaining({ type: 'success' }),
)
expect(mockUpdateModelProviders).toHaveBeenCalled()
expect(mockUpdateModelList).toHaveBeenCalledWith('llm')
})
})
})
describe('ModelAuthDropdown integration', () => {
it('should pass credits-active variant to dropdown when credits available', () => {
renderWithQueryClient(createProvider())
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'credits-active')
})
it('should pass api-fallback variant to dropdown when exhausted with valid key', () => {
mockTrialCredits.isExhausted = true
renderWithQueryClient(createProvider())
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'api-fallback')
})
it('should pass credits-exhausted variant when exhausted with no credentials', () => {
mockTrialCredits.isExhausted = true
renderWithQueryClient(createProvider({
custom_configuration: {
status: CustomConfigurationStatusEnum.noConfigure,
available_credentials: [],
},
}))
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'credits-exhausted')
})
it('should pass api-active variant for custom priority with authorized key', () => {
renderWithQueryClient(createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
}))
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'api-active')
})
it('should pass credits-fallback variant for custom priority with no credentials and credits available', () => {
renderWithQueryClient(createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
custom_configuration: {
status: CustomConfigurationStatusEnum.noConfigure,
available_credentials: [],
},
}))
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'credits-fallback')
})
it('should pass credits-fallback variant for custom priority with named unauthorized key and credits available', () => {
renderWithQueryClient(createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
current_credential_id: undefined,
current_credential_name: 'Bad Key',
available_credentials: [{ credential_id: 'cred-1', credential_name: 'Bad Key' }],
},
}))
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'credits-fallback')
})
it('should pass no-usage variant when exhausted + credential but unauthorized', () => {
mockTrialCredits.isExhausted = true
renderWithQueryClient(createProvider({
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
current_credential_id: undefined,
current_credential_name: undefined,
available_credentials: [{ credential_id: 'cred-1' }],
},
}))
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'no-usage')
})
})
describe('apiKeyOnly priority (system disabled)', () => {
it('should derive api-required-add when system config disabled and no credentials', () => {
renderWithQueryClient(createProvider({
system_configuration: { enabled: false, current_quota_type: CurrentSystemQuotaTypeEnum.trial, quota_configurations: [] },
preferred_provider_type: PreferredProviderTypeEnum.system,
custom_configuration: {
status: CustomConfigurationStatusEnum.noConfigure,
available_credentials: [],
},
}))
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'api-required-add')
expect(screen.getByText(/apiKeyRequired/)).toBeInTheDocument()
})
it('should derive api-active when system config disabled but has authorized key', () => {
renderWithQueryClient(createProvider({
system_configuration: { enabled: false, current_quota_type: CurrentSystemQuotaTypeEnum.trial, quota_configurations: [] },
}))
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'api-active')
})
})
})

View File

@@ -1,149 +1,157 @@
import type {
ModelProvider,
PreferredProviderTypeEnum,
} from '../declarations'
import { useMemo } from 'react'
import type { CardVariant } from './use-credential-panel-state'
import { useMutation, useQueryClient } from '@tanstack/react-query'
import { useTranslation } from 'react-i18next'
import { useToastContext } from '@/app/components/base/toast'
import { ConfigProvider } from '@/app/components/header/account-setting/model-provider-page/model-auth'
import { useCredentialStatus } from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks'
import Warning from '@/app/components/base/icons/src/vender/line/alertsAndFeedback/Warning'
import Toast from '@/app/components/base/toast'
import Indicator from '@/app/components/header/indicator'
import { IS_CLOUD_EDITION } from '@/config'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { changeModelProviderPriority } from '@/service/common'
import { cn } from '@/utils/classnames'
import { consoleQuery } from '@/service/client'
import {
ConfigurationMethodEnum,
CustomConfigurationStatusEnum,
PreferredProviderTypeEnum,
} from '../declarations'
import {
useUpdateModelList,
useUpdateModelProviders,
} from '../hooks'
import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './index'
import PrioritySelector from './priority-selector'
import PriorityUseTip from './priority-use-tip'
import ModelAuthDropdown from './model-auth-dropdown'
import SystemQuotaCard from './system-quota-card'
import { isDestructiveVariant, useCredentialPanelState } from './use-credential-panel-state'
type CredentialPanelProps = {
provider: ModelProvider
}
const TEXT_LABEL_VARIANTS = new Set<CardVariant>([
'credits-active',
'credits-fallback',
'credits-exhausted',
'no-usage',
'api-required-add',
'api-required-configure',
])
const CredentialPanel = ({
provider,
}: CredentialPanelProps) => {
const { t } = useTranslation()
const { notify } = useToastContext()
const { eventEmitter } = useEventEmitterContextContext()
const queryClient = useQueryClient()
const updateModelList = useUpdateModelList()
const updateModelProviders = useUpdateModelProviders()
const customConfig = provider.custom_configuration
const systemConfig = provider.system_configuration
const priorityUseType = provider.preferred_provider_type
const isCustomConfigured = customConfig.status === CustomConfigurationStatusEnum.active
const configurateMethods = provider.configurate_methods
const {
hasCredential,
authorized,
authRemoved,
current_credential_name,
notAllowedToUse,
} = useCredentialStatus(provider)
const showPrioritySelector = systemConfig.enabled && isCustomConfigured && IS_CLOUD_EDITION
const handleChangePriority = async (key: PreferredProviderTypeEnum) => {
const res = await changeModelProviderPriority({
url: `/workspaces/current/model-providers/${provider.provider}/preferred-provider-type`,
body: {
preferred_provider_type: key,
const state = useCredentialPanelState(provider)
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
const { mutate: changePriority, isPending: isChangingPriority } = useMutation(
consoleQuery.modelProviders.changePreferredProviderType.mutationOptions({
onSuccess: () => {
Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })
queryClient.invalidateQueries({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'none',
})
updateModelProviders()
provider.configurate_methods.forEach((method) => {
if (method === ConfigurationMethodEnum.predefinedModel)
provider.supported_model_types.forEach(modelType => updateModelList(modelType))
})
},
onError: () => {
Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) })
},
}),
)
const handleChangePriority = (key: PreferredProviderTypeEnum) => {
changePriority({
params: { provider: provider.provider },
body: { preferred_provider_type: key },
})
if (res.result === 'success') {
notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })
updateModelProviders()
configurateMethods.forEach((method) => {
if (method === ConfigurationMethodEnum.predefinedModel)
provider.supported_model_types.forEach(modelType => updateModelList(modelType))
})
eventEmitter?.emit({
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
payload: provider.provider,
} as any)
}
}
const credentialLabel = useMemo(() => {
if (!hasCredential)
return t('modelProvider.auth.unAuthorized', { ns: 'common' })
if (authorized)
return current_credential_name
if (authRemoved)
return t('modelProvider.auth.authRemoved', { ns: 'common' })
return ''
}, [authorized, authRemoved, current_credential_name, hasCredential])
const { variant, credentialName } = state
const isDestructive = isDestructiveVariant(variant)
const isTextLabel = TEXT_LABEL_VARIANTS.has(variant)
const needsGap = !isTextLabel || variant === 'credits-fallback'
const color = useMemo(() => {
if (authRemoved || !hasCredential)
return 'red'
if (notAllowedToUse)
return 'gray'
return 'green'
}, [authRemoved, notAllowedToUse, hasCredential])
return (
<SystemQuotaCard variant={isDestructive ? 'destructive' : 'default'}>
<SystemQuotaCard.Label className={needsGap ? 'gap-1' : undefined}>
{isTextLabel
? <TextLabel variant={variant} />
: <StatusLabel variant={variant} credentialName={credentialName} />}
</SystemQuotaCard.Label>
<SystemQuotaCard.Actions>
<ModelAuthDropdown
provider={provider}
state={state}
isChangingPriority={isChangingPriority}
onChangePriority={handleChangePriority}
/>
</SystemQuotaCard.Actions>
</SystemQuotaCard>
)
}
const TEXT_LABEL_KEYS = {
'credits-active': 'modelProvider.card.aiCreditsInUse',
'credits-fallback': 'modelProvider.card.aiCreditsInUse',
'credits-exhausted': 'modelProvider.card.quotaExhausted',
'no-usage': 'modelProvider.card.noAvailableUsage',
'api-required-add': 'modelProvider.card.apiKeyRequired',
'api-required-configure': 'modelProvider.card.apiKeyRequired',
} as const satisfies Partial<Record<CardVariant, string>>
function TextLabel({ variant }: { variant: CardVariant }) {
const { t } = useTranslation()
const isDestructive = isDestructiveVariant(variant)
const labelKey = TEXT_LABEL_KEYS[variant as keyof typeof TEXT_LABEL_KEYS]
return (
<>
{
provider.provider_credential_schema && (
<div className={cn(
'relative ml-1 w-[120px] shrink-0 rounded-lg border-[0.5px] border-components-panel-border bg-white/[0.18] p-1',
authRemoved && 'border-state-destructive-border bg-state-destructive-hover',
)}
>
<div className="system-xs-medium mb-1 flex h-5 items-center justify-between pl-2 pr-[7px] pt-1 text-text-tertiary">
<div
className={cn(
'grow truncate',
authRemoved && 'text-text-destructive',
)}
title={credentialLabel}
>
{credentialLabel}
</div>
<Indicator className="shrink-0" color={color} />
</div>
<div className="flex items-center gap-0.5">
<ConfigProvider
provider={provider}
/>
{
showPrioritySelector && (
<PrioritySelector
value={priorityUseType}
onSelect={handleChangePriority}
/>
)
}
</div>
{
priorityUseType === PreferredProviderTypeEnum.custom && systemConfig.enabled && (
<PriorityUseTip />
)
}
</div>
)
}
{
showPrioritySelector && !provider.provider_credential_schema && (
<div className="ml-1">
<PrioritySelector
value={priorityUseType}
onSelect={handleChangePriority}
/>
</div>
)
}
<span className={isDestructive ? 'text-text-destructive' : 'text-text-secondary'}>
{t(labelKey, { ns: 'common' })}
</span>
{variant === 'credits-fallback' && (
<Warning className="h-3 w-3 shrink-0 text-text-warning" />
)}
</>
)
}
function StatusLabel({ variant, credentialName }: {
variant: CardVariant
credentialName: string | undefined
}) {
const { t } = useTranslation()
const dotColor = variant === 'api-unavailable' ? 'red' : 'green'
const showWarning = variant === 'api-fallback'
return (
<>
<Indicator className="shrink-0" color={dotColor} />
<span
className="truncate text-text-secondary"
title={credentialName}
>
{credentialName}
</span>
{showWarning && (
<Warning className="h-3 w-3 shrink-0 text-text-warning" />
)}
{variant === 'api-unavailable' && (
<span className="shrink-0 text-text-destructive system-2xs-medium">
{t('modelProvider.card.unavailable', { ns: 'common' })}
</span>
)}
</>
)
}

View File

@@ -1,17 +1,28 @@
import type { ModelItem, ModelProvider } from '../declarations'
import type { ReactNode } from 'react'
import type { ModelProvider } from '../declarations'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { fetchModelProviderModelList } from '@/service/common'
import { createStore, Provider as JotaiProvider } from 'jotai'
import { useExpandModelProviderList } from '../atoms'
import { ConfigurationMethodEnum } from '../declarations'
import ProviderAddedCard from './index'
let mockIsCurrentWorkspaceManager = true
const mockEventEmitter = {
useSubscription: vi.fn(),
emit: vi.fn(),
}
const mockFetchModelProviderModels = vi.fn()
const mockQueryOptions = vi.fn(({ input, ...options }: { input: { params: { provider: string } }, enabled?: boolean }) => ({
queryKey: ['console', 'modelProviders', 'models', input.params.provider],
queryFn: () => mockFetchModelProviderModels(input.params.provider),
...options,
}))
vi.mock('@/service/common', () => ({
fetchModelProviderModelList: vi.fn(),
vi.mock('@/service/client', () => ({
consoleQuery: {
modelProviders: {
models: {
queryOptions: (options: { input: { params: { provider: string } }, enabled?: boolean }) => mockQueryOptions(options),
},
},
},
}))
vi.mock('@/context/app-context', () => ({
@@ -20,12 +31,6 @@ vi.mock('@/context/app-context', () => ({
}),
}))
vi.mock('@/context/event-emitter', () => ({
useEventEmitterContextContext: () => ({
eventEmitter: mockEventEmitter,
}),
}))
// Mock internal components to simplify testing of the index file
vi.mock('./credential-panel', () => ({
default: () => <div data-testid="credential-panel" />,
@@ -53,6 +58,38 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/model-auth'
ManageCustomModelCredentials: () => <div data-testid="manage-custom-model" />,
}))
const createTestQueryClient = () => new QueryClient({
defaultOptions: {
queries: { retry: false, gcTime: 0 },
},
})
const renderWithQueryClient = (node: ReactNode) => {
const queryClient = createTestQueryClient()
const store = createStore()
return render(
<JotaiProvider store={store}>
<QueryClientProvider client={queryClient}>
{node}
</QueryClientProvider>
</JotaiProvider>,
)
}
const ExternalExpandControls = () => {
const expandModelProviderList = useExpandModelProviderList()
return (
<>
<button type="button" data-testid="expand-other-provider" onClick={() => expandModelProviderList('langgenius/anthropic/anthropic')}>
expand other
</button>
<button type="button" data-testid="expand-current-provider" onClick={() => expandModelProviderList('langgenius/openai/openai')}>
expand current
</button>
</>
)
}
describe('ProviderAddedCard', () => {
const mockProvider = {
provider: 'langgenius/openai/openai',
@@ -67,19 +104,21 @@ describe('ProviderAddedCard', () => {
})
it('should render provider added card component', () => {
render(<ProviderAddedCard provider={mockProvider} />)
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
expect(screen.getByTestId('provider-added-card')).toBeInTheDocument()
expect(screen.getByTestId('provider-icon')).toBeInTheDocument()
})
it('should open, refresh and collapse model list', async () => {
vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: [{ model: 'gpt-4' }] } as unknown as { data: ModelItem[] })
render(<ProviderAddedCard provider={mockProvider} />)
mockFetchModelProviderModels.mockResolvedValue({ data: [{ model: 'gpt-4' }] })
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
const showModelsBtn = screen.getByTestId('show-models-button')
fireEvent.click(showModelsBtn)
expect(fetchModelProviderModelList).toHaveBeenCalledWith(`/workspaces/current/model-providers/${mockProvider.provider}/models`)
await waitFor(() => {
expect(mockFetchModelProviderModels).toHaveBeenCalledWith(mockProvider.provider)
})
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
// Test line 71-72: Opening when already fetched
@@ -90,13 +129,13 @@ describe('ProviderAddedCard', () => {
// Explicitly re-find and click to re-open
fireEvent.click(screen.getByTestId('show-models-button'))
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) // Should not fetch again
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(2) // Re-open fetches again with default stale/gc behavior
// Refresh list from ModelList
const refreshBtn = screen.getByRole('button', { name: 'refresh list' })
fireEvent.click(refreshBtn)
await waitFor(() => {
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(2)
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(3)
})
})
@@ -105,18 +144,20 @@ describe('ProviderAddedCard', () => {
const promise = new Promise((resolve) => {
resolveOuter = resolve
})
vi.mocked(fetchModelProviderModelList).mockReturnValue(promise as unknown as ReturnType<typeof fetchModelProviderModelList>)
mockFetchModelProviderModels.mockReturnValue(promise)
render(<ProviderAddedCard provider={mockProvider} />)
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
const showModelsBtn = screen.getByTestId('show-models-button')
// First call sets loading to true
fireEvent.click(showModelsBtn)
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
await waitFor(() => {
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
})
// Second call should return early because loading is true
fireEvent.click(showModelsBtn)
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
await act(async () => {
resolveOuter({ data: [] })
@@ -125,56 +166,49 @@ describe('ProviderAddedCard', () => {
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
})
it('should only react to external expansion for the matching provider', async () => {
mockFetchModelProviderModels.mockResolvedValue({ data: [{ model: 'gpt-4' }] })
renderWithQueryClient(
<>
<ProviderAddedCard provider={mockProvider} />
<ExternalExpandControls />
</>,
)
fireEvent.click(screen.getByTestId('expand-other-provider'))
await waitFor(() => {
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(0)
})
fireEvent.click(screen.getByTestId('expand-current-provider'))
await waitFor(() => {
expect(mockFetchModelProviderModels).toHaveBeenCalledWith(mockProvider.provider)
})
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
})
it('should render configure tip when provider is not in quota list and not configured', () => {
const providerWithoutQuota = {
...mockProvider,
provider: 'custom/provider',
} as unknown as ModelProvider
render(<ProviderAddedCard provider={providerWithoutQuota} notConfigured />)
renderWithQueryClient(<ProviderAddedCard provider={providerWithoutQuota} notConfigured />)
expect(screen.getByText('common.modelProvider.configureTip')).toBeInTheDocument()
})
it('should refresh model list on event subscription', async () => {
let capturedHandler: (v: { type: string, payload: string } | null) => void = () => { }
mockEventEmitter.useSubscription.mockImplementation((handler: (v: unknown) => void) => {
capturedHandler = handler as (v: { type: string, payload: string } | null) => void
})
vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: [] } as unknown as { data: ModelItem[] })
render(<ProviderAddedCard provider={mockProvider} />)
expect(capturedHandler).toBeDefined()
act(() => {
capturedHandler({
type: 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST',
payload: mockProvider.provider,
})
})
await waitFor(() => {
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
})
// Should ignore non-matching events
act(() => {
capturedHandler({ type: 'OTHER', payload: '' })
capturedHandler(null)
})
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
})
it('should render custom model actions for workspace managers', () => {
const customConfigProvider = {
...mockProvider,
configurate_methods: [ConfigurationMethodEnum.customizableModel],
} as unknown as ModelProvider
const { rerender } = render(<ProviderAddedCard provider={customConfigProvider} />)
const { unmount } = renderWithQueryClient(<ProviderAddedCard provider={customConfigProvider} />)
expect(screen.getByTestId('manage-custom-model')).toBeInTheDocument()
expect(screen.getByTestId('add-custom-model')).toBeInTheDocument()
unmount()
mockIsCurrentWorkspaceManager = false
rerender(<ProviderAddedCard provider={customConfigProvider} />)
renderWithQueryClient(<ProviderAddedCard provider={customConfigProvider} />)
expect(screen.queryByTestId('manage-custom-model')).not.toBeInTheDocument()
})
})

View File

@@ -1,11 +1,12 @@
import type { FC } from 'react'
import type {
ModelItem,
ModelProvider,
} from '../declarations'
import type { ModelProviderQuotaGetPaid } from '../utils'
import type { PluginDetail } from '@/app/components/plugins/types'
import { useState } from 'react'
import { useQuery } from '@tanstack/react-query'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import {
AddCustomModel,
@@ -13,9 +14,10 @@ import {
} from '@/app/components/header/account-setting/model-provider-page/model-auth'
import { IS_CE_EDITION } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { fetchModelProviderModelList } from '@/service/common'
import { useProviderContext } from '@/context/provider-context'
import { consoleQuery } from '@/service/client'
import { cn } from '@/utils/classnames'
import { useModelProviderListExpanded, useSetModelProviderListExpanded } from '../atoms'
import { ConfigurationMethodEnum } from '../declarations'
import ModelBadge from '../model-badge'
import ProviderIcon from '../provider-icon'
@@ -25,121 +27,123 @@ import {
} from '../utils'
import CredentialPanel from './credential-panel'
import ModelList from './model-list'
import ProviderCardActions from './provider-card-actions'
export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST'
type ProviderAddedCardProps = {
notConfigured?: boolean
provider: ModelProvider
pluginDetail?: PluginDetail
}
const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
notConfigured,
provider,
pluginDetail,
}) => {
const { t } = useTranslation()
const { eventEmitter } = useEventEmitterContextContext()
const [fetched, setFetched] = useState(false)
const [loading, setLoading] = useState(false)
const [collapsed, setCollapsed] = useState(true)
const [modelList, setModelList] = useState<ModelItem[]>([])
const configurationMethods = provider.configurate_methods.filter(method => method !== ConfigurationMethodEnum.fetchFromRemote)
const { refreshModelProviders } = useProviderContext()
const currentProviderName = provider.provider
const expanded = useModelProviderListExpanded(currentProviderName)
const setExpanded = useSetModelProviderListExpanded(currentProviderName)
const supportsPredefinedModel = provider.configurate_methods.includes(ConfigurationMethodEnum.predefinedModel)
const supportsCustomizableModel = provider.configurate_methods.includes(ConfigurationMethodEnum.customizableModel)
const systemConfig = provider.system_configuration
const hasModelList = fetched && !!modelList.length
const {
data: modelList = [],
isFetching: loading,
isSuccess: hasFetchedModelList,
refetch: refetchModelList,
} = useQuery(consoleQuery.modelProviders.models.queryOptions({
input: { params: { provider: currentProviderName } },
enabled: expanded,
refetchOnWindowFocus: false,
select: response => response.data,
}))
const hasModelList = hasFetchedModelList && !!modelList.length
const showCollapsedSection = !expanded || !hasFetchedModelList
const { isCurrentWorkspaceManager } = useAppContext()
const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(provider.provider as ModelProviderQuotaGetPaid) && !IS_CE_EDITION
const showCredential = configurationMethods.includes(ConfigurationMethodEnum.predefinedModel) && isCurrentWorkspaceManager
const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(currentProviderName as ModelProviderQuotaGetPaid) && !IS_CE_EDITION
const showCredential = supportsPredefinedModel && isCurrentWorkspaceManager
const showCustomModelActions = supportsCustomizableModel && isCurrentWorkspaceManager
const getModelList = async (providerName: string) => {
const refreshModelList = useCallback((targetProviderName: string) => {
if (targetProviderName !== currentProviderName)
return
if (!expanded)
setExpanded(true)
refetchModelList().catch(() => {})
}, [currentProviderName, expanded, refetchModelList, setExpanded])
const handleOpenModelList = useCallback(() => {
if (loading)
return
try {
setLoading(true)
const modelsData = await fetchModelProviderModelList(`/workspaces/current/model-providers/${providerName}/models`)
setModelList(modelsData.data)
setCollapsed(false)
setFetched(true)
}
finally {
setLoading(false)
}
}
const handleOpenModelList = () => {
if (fetched) {
setCollapsed(false)
if (!expanded) {
setExpanded(true)
return
}
getModelList(provider.provider)
}
eventEmitter?.useSubscription((v: any) => {
if (v?.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST && v.payload === provider.provider)
getModelList(v.payload)
})
refetchModelList().catch(() => {})
}, [expanded, loading, refetchModelList, setExpanded])
return (
<div
data-testid="provider-added-card"
className={cn(
'mb-2 rounded-xl border-[0.5px] border-divider-regular bg-third-party-model-bg-default shadow-xs',
provider.provider === 'langgenius/openai/openai' && 'bg-third-party-model-bg-openai',
provider.provider === 'langgenius/anthropic/anthropic' && 'bg-third-party-model-bg-anthropic',
currentProviderName === 'langgenius/openai/openai' && 'bg-third-party-model-bg-openai',
currentProviderName === 'langgenius/anthropic/anthropic' && 'bg-third-party-model-bg-anthropic',
)}
>
<div className="flex rounded-t-xl py-2 pl-3 pr-2">
<div className="grow px-1 pb-0.5 pt-1">
<ProviderIcon
className="mb-2"
provider={provider}
/>
<div className="mb-2 flex items-center gap-1">
<ProviderIcon provider={provider} />
{pluginDetail && (
<ProviderCardActions
detail={pluginDetail}
onUpdate={refreshModelProviders}
/>
)}
</div>
<div className="flex gap-0.5">
{
provider.supported_model_types.map(modelType => (
<ModelBadge key={modelType}>
{modelTypeFormat(modelType)}
</ModelBadge>
))
}
{provider.supported_model_types.map(modelType => (
<ModelBadge key={modelType}>
{modelTypeFormat(modelType)}
</ModelBadge>
))}
</div>
</div>
{
showCredential && (
<CredentialPanel
provider={provider}
/>
)
}
{showCredential && (
<CredentialPanel
provider={provider}
/>
)}
</div>
{
collapsed && (
showCollapsedSection && (
<div className="group flex items-center justify-between border-t border-t-divider-subtle py-1.5 pl-2 pr-[11px] text-text-tertiary system-xs-medium">
{(showModelProvider || !notConfigured) && (
<>
<div className="flex h-6 items-center pl-1 pr-1.5 leading-6 group-hover:hidden">
{
hasModelList
? t('modelProvider.modelsNum', { ns: 'common', num: modelList.length })
: t('modelProvider.showModels', { ns: 'common' })
}
{!loading && <div className="i-ri-arrow-right-s-line h-4 w-4" />}
</div>
<div
data-testid="show-models-button"
className="hidden h-6 cursor-pointer items-center rounded-lg pl-1 pr-1.5 hover:bg-components-button-ghost-bg-hover group-hover:flex"
onClick={handleOpenModelList}
>
{
hasModelList
? t('modelProvider.showModelsNum', { ns: 'common', num: modelList.length })
: t('modelProvider.showModels', { ns: 'common' })
}
{!loading && <div className="i-ri-arrow-right-s-line h-4 w-4" />}
{
loading && (
<div className="i-ri-loader-2-line ml-0.5 h-3 w-3 animate-spin" />
)
}
</div>
</>
<button
type="button"
data-testid="show-models-button"
className="flex h-6 items-center rounded-lg pl-1 pr-1.5 hover:bg-components-button-ghost-bg-hover"
aria-label={t('modelProvider.showModels', { ns: 'common' })}
onClick={handleOpenModelList}
>
{
hasModelList
? t('modelProvider.modelsNum', { ns: 'common', num: modelList.length })
: t('modelProvider.showModels', { ns: 'common' })
}
{!loading && <div className="i-ri-arrow-right-s-line h-4 w-4" />}
{
loading && (
<div className="i-ri-loader-2-line ml-0.5 h-3 w-3 animate-spin" />
)
}
</button>
)}
{!showModelProvider && notConfigured && (
<div className="flex h-6 items-center pl-1 pr-1.5">
@@ -148,7 +152,7 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
</div>
)}
{
configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager && (
showCustomModelActions && (
<div className="flex grow justify-end">
<ManageCustomModelCredentials
provider={provider}
@@ -166,12 +170,12 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
)
}
{
!collapsed && (
!showCollapsedSection && (
<ModelList
provider={provider}
models={modelList}
onCollapse={() => setCollapsed(true)}
onChange={(provider: string) => getModelList(provider)}
onCollapse={() => setExpanded(false)}
onChange={refreshModelList}
/>
)
}

View File

@@ -0,0 +1,142 @@
import type { Credential, ModelProvider } from '../../declarations'
import { fireEvent, render, screen } from '@testing-library/react'
import { CustomConfigurationStatusEnum, PreferredProviderTypeEnum } from '../../declarations'
import ApiKeySection from './api-key-section'
const createCredential = (overrides: Partial<Credential> = {}): Credential => ({
credential_id: 'cred-1',
credential_name: 'Test API Key',
...overrides,
})
const createProvider = (overrides: Partial<ModelProvider> = {}): ModelProvider => ({
provider: 'test-provider',
allow_custom_token: true,
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
available_credentials: [],
},
system_configuration: { enabled: true, current_quota_type: 'trial', quota_configurations: [] },
preferred_provider_type: PreferredProviderTypeEnum.system,
...overrides,
} as unknown as ModelProvider)
describe('ApiKeySection', () => {
const handlers = {
onItemClick: vi.fn(),
onEdit: vi.fn(),
onDelete: vi.fn(),
onAdd: vi.fn(),
}
beforeEach(() => {
vi.clearAllMocks()
})
// Empty state
describe('Empty state (no credentials)', () => {
it('should show empty state message', () => {
render(
<ApiKeySection
provider={createProvider()}
credentials={[]}
selectedCredentialId={undefined}
{...handlers}
/>,
)
expect(screen.getByText(/noApiKeysTitle/)).toBeInTheDocument()
expect(screen.getByText(/noApiKeysDescription/)).toBeInTheDocument()
})
it('should show Add API Key button', () => {
render(
<ApiKeySection
provider={createProvider()}
credentials={[]}
selectedCredentialId={undefined}
{...handlers}
/>,
)
expect(screen.getByRole('button', { name: /addApiKey/ })).toBeInTheDocument()
})
it('should call onAdd when Add API Key is clicked', () => {
render(
<ApiKeySection
provider={createProvider()}
credentials={[]}
selectedCredentialId={undefined}
{...handlers}
/>,
)
fireEvent.click(screen.getByRole('button', { name: /addApiKey/ }))
expect(handlers.onAdd).toHaveBeenCalledTimes(1)
})
it('should hide Add API Key button when allow_custom_token is false', () => {
render(
<ApiKeySection
provider={createProvider({ allow_custom_token: false })}
credentials={[]}
selectedCredentialId={undefined}
{...handlers}
/>,
)
expect(screen.queryByRole('button', { name: /addApiKey/ })).not.toBeInTheDocument()
})
})
// With credentials
describe('With credentials', () => {
const credentials = [
createCredential({ credential_id: 'cred-1', credential_name: 'Key Alpha' }),
createCredential({ credential_id: 'cred-2', credential_name: 'Key Beta' }),
]
it('should render credential list with header', () => {
render(
<ApiKeySection
provider={createProvider()}
credentials={credentials}
selectedCredentialId="cred-1"
{...handlers}
/>,
)
expect(screen.getByText(/apiKeys/)).toBeInTheDocument()
expect(screen.getByText('Key Alpha')).toBeInTheDocument()
expect(screen.getByText('Key Beta')).toBeInTheDocument()
})
it('should show Add API Key button in footer', () => {
render(
<ApiKeySection
provider={createProvider()}
credentials={credentials}
selectedCredentialId="cred-1"
{...handlers}
/>,
)
expect(screen.getByRole('button', { name: /addApiKey/ })).toBeInTheDocument()
})
it('should hide Add API Key when allow_custom_token is false', () => {
render(
<ApiKeySection
provider={createProvider({ allow_custom_token: false })}
credentials={credentials}
selectedCredentialId="cred-1"
{...handlers}
/>,
)
expect(screen.queryByRole('button', { name: /addApiKey/ })).not.toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,91 @@
import type { Credential, CustomModel, ModelProvider } from '../../declarations'
import { memo } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import CredentialItem from '../../model-auth/authorized/credential-item'
type ApiKeySectionProps = {
provider: ModelProvider
credentials: Credential[]
selectedCredentialId: string | undefined
isActivating?: boolean
onItemClick: (credential: Credential, model?: CustomModel) => void
onEdit: (credential?: Credential) => void
onDelete: (credential?: Credential) => void
onAdd: () => void
}
function ApiKeySection({
provider,
credentials,
selectedCredentialId,
isActivating,
onItemClick,
onEdit,
onDelete,
onAdd,
}: ApiKeySectionProps) {
const { t } = useTranslation()
const notAllowCustomCredential = provider.allow_custom_token === false
if (!credentials.length) {
return (
<div className="flex flex-col gap-2 p-2">
<div className="rounded-[10px] bg-gradient-to-r from-state-base-hover to-transparent p-4">
<div className="flex flex-col gap-1">
<div className="text-text-secondary system-sm-medium">
{t('modelProvider.card.noApiKeysTitle', { ns: 'common' })}
</div>
<div className="text-text-tertiary system-xs-regular">
{t('modelProvider.card.noApiKeysDescription', { ns: 'common' })}
</div>
</div>
</div>
{!notAllowCustomCredential && (
<Button
onClick={onAdd}
className="w-full"
>
{t('modelProvider.auth.addApiKey', { ns: 'common' })}
</Button>
)}
</div>
)
}
return (
<div className="border-t border-t-divider-subtle">
<div className="px-1">
<div className="pb-1 pl-7 pr-2 pt-3 text-text-tertiary system-xs-medium-uppercase">
{t('modelProvider.auth.apiKeys', { ns: 'common' })}
</div>
<div className="max-h-[200px] overflow-y-auto">
{credentials.map(credential => (
<CredentialItem
key={credential.credential_id}
credential={credential}
disabled={isActivating}
showSelectedIcon
selectedCredentialId={selectedCredentialId}
onItemClick={onItemClick}
onEdit={onEdit}
onDelete={onDelete}
/>
))}
</div>
</div>
{!notAllowCustomCredential && (
<div className="p-2">
<Button
onClick={onAdd}
className="w-full"
>
{t('modelProvider.auth.addApiKey', { ns: 'common' })}
</Button>
</div>
)}
</div>
)
}
export default memo(ApiKeySection)

View File

@@ -0,0 +1,63 @@
import { render, screen } from '@testing-library/react'
import CreditsExhaustedAlert from './credits-exhausted-alert'
const mockTrialCredits = { credits: 0, totalCredits: 10_000, isExhausted: true, isLoading: false, nextCreditResetDate: undefined }
vi.mock('../use-trial-credits', () => ({
useTrialCredits: () => mockTrialCredits,
}))
describe('CreditsExhaustedAlert', () => {
beforeEach(() => {
vi.clearAllMocks()
Object.assign(mockTrialCredits, { credits: 0 })
})
// Without API key fallback
describe('Without API key fallback', () => {
it('should show exhausted message', () => {
render(<CreditsExhaustedAlert hasApiKeyFallback={false} />)
expect(screen.getByText(/creditsExhaustedMessage/)).toBeInTheDocument()
})
it('should show description with upgrade link', () => {
render(<CreditsExhaustedAlert hasApiKeyFallback={false} />)
expect(screen.getByText(/creditsExhaustedDescription/)).toBeInTheDocument()
})
})
// With API key fallback
describe('With API key fallback', () => {
it('should show fallback message', () => {
render(<CreditsExhaustedAlert hasApiKeyFallback />)
expect(screen.getByText(/creditsExhaustedFallback(?!Description)/)).toBeInTheDocument()
})
it('should show fallback description', () => {
render(<CreditsExhaustedAlert hasApiKeyFallback />)
expect(screen.getByText(/creditsExhaustedFallbackDescription/)).toBeInTheDocument()
})
})
// Usage display
describe('Usage display', () => {
it('should show usage label', () => {
render(<CreditsExhaustedAlert hasApiKeyFallback={false} />)
expect(screen.getByText(/usageLabel/)).toBeInTheDocument()
})
it('should show usage amounts', () => {
mockTrialCredits.credits = 200
render(<CreditsExhaustedAlert hasApiKeyFallback={false} />)
expect(screen.getByText(/9,800/)).toBeInTheDocument()
expect(screen.getByText(/10,000/)).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,65 @@
import { Trans, useTranslation } from 'react-i18next'
import { CreditsCoin } from '@/app/components/base/icons/src/vender/line/financeAndECommerce'
import { useModalContextSelector } from '@/context/modal-context'
import { formatNumber } from '@/utils/format'
import { useTrialCredits } from '../use-trial-credits'
type CreditsExhaustedAlertProps = {
hasApiKeyFallback: boolean
}
export default function CreditsExhaustedAlert({ hasApiKeyFallback }: CreditsExhaustedAlertProps) {
const { t } = useTranslation()
const setShowPricingModal = useModalContextSelector(s => s.setShowPricingModal)
const { credits, totalCredits } = useTrialCredits()
const titleKey = hasApiKeyFallback
? 'modelProvider.card.creditsExhaustedFallback'
: 'modelProvider.card.creditsExhaustedMessage'
const descriptionKey = hasApiKeyFallback
? 'modelProvider.card.creditsExhaustedFallbackDescription'
: 'modelProvider.card.creditsExhaustedDescription'
const usedCredits = totalCredits - credits
const usagePercent = totalCredits > 0 ? Math.min((usedCredits / totalCredits) * 100, 100) : 100
return (
<div className="mx-2 mb-1 mt-0.5 rounded-lg bg-background-section-burn p-3">
<div className="flex flex-col gap-1">
<div className="text-text-primary system-sm-medium">
{t(titleKey, { ns: 'common' })}
</div>
<div className="text-text-tertiary system-xs-regular">
<Trans
i18nKey={descriptionKey}
ns="common"
components={{
upgradeLink: <span className="cursor-pointer text-text-accent system-xs-medium" onClick={setShowPricingModal} />,
}}
/>
</div>
</div>
<div className="mt-3 flex flex-col gap-1">
<div className="flex items-center justify-between">
<span className="text-text-tertiary system-xs-medium">
{t('modelProvider.card.usageLabel', { ns: 'common' })}
</span>
<div className="flex items-center gap-0.5 text-text-tertiary system-xs-regular">
<CreditsCoin className="h-3 w-3" />
<span>
{formatNumber(usedCredits)}
/
{formatNumber(totalCredits)}
</span>
</div>
</div>
<div className="h-1 overflow-hidden rounded-[6px] bg-components-progress-error-bg">
<div
className="h-full rounded-l-[6px] bg-components-progress-error-progress"
style={{ width: `${usagePercent}%` }}
/>
</div>
</div>
</div>
)
}

View File

@@ -0,0 +1,28 @@
import { useTranslation } from 'react-i18next'
type CreditsFallbackAlertProps = {
hasCredentials: boolean
}
export default function CreditsFallbackAlert({ hasCredentials }: CreditsFallbackAlertProps) {
const { t } = useTranslation()
const titleKey = hasCredentials
? 'modelProvider.card.apiKeyUnavailableFallback'
: 'modelProvider.card.noApiKeysFallback'
return (
<div className="mx-2 mb-1 mt-0.5 rounded-lg bg-background-section-burn p-3">
<div className="flex flex-col gap-1">
<div className="text-text-primary system-sm-medium">
{t(titleKey, { ns: 'common' })}
</div>
{hasCredentials && (
<div className="text-text-tertiary system-xs-regular">
{t('modelProvider.card.apiKeyUnavailableFallbackDescription', { ns: 'common' })}
</div>
)}
</div>
</div>
)
}

View File

@@ -0,0 +1,435 @@
import type { ModelProvider } from '../../declarations'
import type { CredentialPanelState } from '../use-credential-panel-state'
import { fireEvent, render, screen } from '@testing-library/react'
import { CustomConfigurationStatusEnum, PreferredProviderTypeEnum } from '../../declarations'
import DropdownContent from './dropdown-content'
const mockHandleOpenModal = vi.fn()
const mockActivate = vi.fn()
const mockOpenConfirmDelete = vi.fn()
const mockCloseConfirmDelete = vi.fn()
const mockHandleConfirmDelete = vi.fn()
let mockDeleteCredentialId: string | null = null
vi.mock('../use-trial-credits', () => ({
useTrialCredits: () => ({ credits: 0, totalCredits: 10_000, isExhausted: true, isLoading: false }),
}))
vi.mock('./use-activate-credential', () => ({
useActivateCredential: () => ({
selectedCredentialId: 'cred-1',
isActivating: false,
activate: mockActivate,
}),
}))
vi.mock('../../model-auth/hooks', () => ({
useAuth: () => ({
openConfirmDelete: mockOpenConfirmDelete,
closeConfirmDelete: mockCloseConfirmDelete,
doingAction: false,
handleConfirmDelete: mockHandleConfirmDelete,
deleteCredentialId: mockDeleteCredentialId,
handleOpenModal: mockHandleOpenModal,
}),
}))
vi.mock('../../model-auth/authorized/credential-item', () => ({
default: ({ credential, onItemClick, onEdit, onDelete }: {
credential: { credential_id: string, credential_name: string }
onItemClick?: (c: unknown) => void
onEdit?: (c: unknown) => void
onDelete?: (c: unknown) => void
}) => (
<div data-testid={`credential-${credential.credential_id}`}>
<span>{credential.credential_name}</span>
<button data-testid={`click-${credential.credential_id}`} onClick={() => onItemClick?.(credential)}>select</button>
<button data-testid={`edit-${credential.credential_id}`} onClick={() => onEdit?.(credential)}>edit</button>
<button data-testid={`delete-${credential.credential_id}`} onClick={() => onDelete?.(credential)}>delete</button>
</div>
),
}))
const createProvider = (overrides: Partial<ModelProvider> = {}): ModelProvider => ({
provider: 'test',
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
current_credential_id: 'cred-1',
current_credential_name: 'My Key',
available_credentials: [
{ credential_id: 'cred-1', credential_name: 'My Key' },
{ credential_id: 'cred-2', credential_name: 'Other Key' },
],
},
system_configuration: { enabled: true, current_quota_type: 'trial', quota_configurations: [] },
preferred_provider_type: PreferredProviderTypeEnum.system,
configurate_methods: ['predefined-model'],
supported_model_types: ['llm'],
...overrides,
} as unknown as ModelProvider)
const createState = (overrides: Partial<CredentialPanelState> = {}): CredentialPanelState => ({
variant: 'api-active',
priority: 'apiKey',
supportsCredits: true,
showPrioritySwitcher: true,
hasCredentials: true,
isCreditsExhausted: false,
credentialName: 'My Key',
credits: 100,
...overrides,
})
describe('DropdownContent', () => {
const onChangePriority = vi.fn()
const onClose = vi.fn()
beforeEach(() => {
vi.clearAllMocks()
mockDeleteCredentialId = null
})
describe('UsagePrioritySection visibility', () => {
it('should show when showPrioritySwitcher is true', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState({ showPrioritySwitcher: true })}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.getByText(/usagePriority/)).toBeInTheDocument()
})
it('should hide when showPrioritySwitcher is false', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState({ showPrioritySwitcher: false })}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.queryByText(/usagePriority/)).not.toBeInTheDocument()
})
})
describe('CreditsExhaustedAlert', () => {
it('should show when credits exhausted and supports credits', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState({ isCreditsExhausted: true, supportsCredits: true })}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.getAllByText(/creditsExhausted/).length).toBeGreaterThan(0)
})
it('should hide when credits not exhausted', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState({ isCreditsExhausted: false })}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.queryByText(/creditsExhausted/)).not.toBeInTheDocument()
})
it('should hide when credits exhausted but supportsCredits is false', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState({ isCreditsExhausted: true, supportsCredits: false })}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.queryByText(/creditsExhausted/)).not.toBeInTheDocument()
})
it('should show fallback message when api-fallback variant with exhausted credits', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState({
variant: 'api-fallback',
isCreditsExhausted: true,
supportsCredits: true,
priority: 'credits',
})}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.getAllByText(/creditsExhaustedFallback/).length).toBeGreaterThan(0)
})
it('should show non-fallback message when credits-exhausted variant', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState({
variant: 'credits-exhausted',
isCreditsExhausted: true,
supportsCredits: true,
hasCredentials: false,
priority: 'credits',
})}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.getByText(/creditsExhaustedMessage/)).toBeInTheDocument()
})
})
describe('CreditsFallbackAlert', () => {
it('should show when priority is apiKey, supports credits, not exhausted, and variant is not api-active', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState({
variant: 'api-required-add',
priority: 'apiKey',
supportsCredits: true,
isCreditsExhausted: false,
hasCredentials: false,
})}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.getByText(/noApiKeysFallback/)).toBeInTheDocument()
})
it('should show unavailable message when priority is apiKey with credentials but not api-active', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState({
variant: 'api-unavailable',
priority: 'apiKey',
supportsCredits: true,
isCreditsExhausted: false,
hasCredentials: true,
})}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.getAllByText(/apiKeyUnavailableFallback/).length).toBeGreaterThan(0)
})
it('should NOT show when variant is api-active', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState({
variant: 'api-active',
priority: 'apiKey',
supportsCredits: true,
isCreditsExhausted: false,
})}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.queryByText(/noApiKeysFallback/)).not.toBeInTheDocument()
expect(screen.queryByText(/apiKeyUnavailableFallback/)).not.toBeInTheDocument()
})
it('should NOT show when priority is credits', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState({
variant: 'credits-active',
priority: 'credits',
supportsCredits: true,
isCreditsExhausted: false,
})}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.queryByText(/noApiKeysFallback/)).not.toBeInTheDocument()
expect(screen.queryByText(/apiKeyUnavailableFallback/)).not.toBeInTheDocument()
})
})
describe('API key section', () => {
it('should render all credential items', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState()}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.getByText('My Key')).toBeInTheDocument()
expect(screen.getByText('Other Key')).toBeInTheDocument()
})
it('should show empty state when no credentials', () => {
render(
<DropdownContent
provider={createProvider({
custom_configuration: {
status: CustomConfigurationStatusEnum.noConfigure,
available_credentials: [],
},
})}
state={createState({ hasCredentials: false })}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.getByText(/noApiKeysTitle/)).toBeInTheDocument()
expect(screen.getByText(/noApiKeysDescription/)).toBeInTheDocument()
})
it('should call activate without closing on credential item click', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState()}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
fireEvent.click(screen.getByTestId('click-cred-2'))
expect(mockActivate).toHaveBeenCalledWith(
expect.objectContaining({ credential_id: 'cred-2' }),
)
expect(onClose).not.toHaveBeenCalled()
})
it('should call handleOpenModal and close on edit credential', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState()}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
fireEvent.click(screen.getByTestId('edit-cred-2'))
expect(mockHandleOpenModal).toHaveBeenCalledWith(
expect.objectContaining({ credential_id: 'cred-2' }),
)
expect(onClose).toHaveBeenCalled()
})
it('should call openConfirmDelete on delete credential', () => {
render(
<DropdownContent
provider={createProvider()}
state={createState()}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
fireEvent.click(screen.getByTestId('delete-cred-2'))
expect(mockOpenConfirmDelete).toHaveBeenCalledWith(
expect.objectContaining({ credential_id: 'cred-2' }),
)
})
})
describe('Add API Key', () => {
it('should call handleOpenModal with no args and close on add', () => {
render(
<DropdownContent
provider={createProvider({
custom_configuration: {
status: CustomConfigurationStatusEnum.noConfigure,
available_credentials: [],
},
})}
state={createState({ hasCredentials: false })}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
fireEvent.click(screen.getByRole('button', { name: /addApiKey/ }))
expect(mockHandleOpenModal).toHaveBeenCalledWith()
expect(onClose).toHaveBeenCalled()
})
})
describe('AlertDialog for delete confirmation', () => {
it('should show confirm dialog when deleteCredentialId is set', () => {
mockDeleteCredentialId = 'cred-1'
render(
<DropdownContent
provider={createProvider()}
state={createState()}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.getByText(/confirmDelete/)).toBeInTheDocument()
})
it('should not show confirm dialog when deleteCredentialId is null', () => {
mockDeleteCredentialId = null
render(
<DropdownContent
provider={createProvider()}
state={createState()}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(screen.queryByText(/confirmDelete/)).not.toBeInTheDocument()
})
})
describe('Layout', () => {
it('should have 320px width container', () => {
const { container } = render(
<DropdownContent
provider={createProvider()}
state={createState()}
isChangingPriority={false}
onChangePriority={onChangePriority}
onClose={onClose}
/>,
)
expect(container.querySelector('.w-\\[320px\\]')).toBeTruthy()
})
})
})

View File

@@ -0,0 +1,131 @@
import type { Credential, ModelProvider, PreferredProviderTypeEnum } from '../../declarations'
import type { CredentialPanelState } from '../use-credential-panel-state'
import { memo, useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import {
AlertDialog,
AlertDialogActions,
AlertDialogCancelButton,
AlertDialogConfirmButton,
AlertDialogContent,
AlertDialogDescription,
AlertDialogTitle,
} from '@/app/components/base/ui/alert-dialog'
import { ConfigurationMethodEnum } from '../../declarations'
import { useAuth } from '../../model-auth/hooks'
import ApiKeySection from './api-key-section'
import CreditsExhaustedAlert from './credits-exhausted-alert'
import CreditsFallbackAlert from './credits-fallback-alert'
import UsagePrioritySection from './usage-priority-section'
import { useActivateCredential } from './use-activate-credential'
const EMPTY_CREDENTIALS: Credential[] = []
type DropdownContentProps = {
provider: ModelProvider
state: CredentialPanelState
isChangingPriority: boolean
onChangePriority: (key: PreferredProviderTypeEnum) => void
onClose: () => void
}
function DropdownContent({
provider,
state,
isChangingPriority,
onChangePriority,
onClose,
}: DropdownContentProps) {
const { t } = useTranslation()
const { available_credentials } = provider.custom_configuration
const {
openConfirmDelete,
closeConfirmDelete,
doingAction,
handleConfirmDelete,
deleteCredentialId,
handleOpenModal,
} = useAuth(provider, ConfigurationMethodEnum.predefinedModel)
const { selectedCredentialId, isActivating, activate } = useActivateCredential(provider)
const handleEdit = useCallback((credential?: Credential) => {
handleOpenModal(credential)
onClose()
}, [handleOpenModal, onClose])
const handleDelete = useCallback((credential?: Credential) => {
if (credential)
openConfirmDelete(credential)
}, [openConfirmDelete])
const handleAdd = useCallback(() => {
handleOpenModal()
onClose()
}, [handleOpenModal, onClose])
const showCreditsExhaustedAlert = state.isCreditsExhausted && state.supportsCredits
const hasApiKeyFallback = state.variant === 'api-fallback'
|| (state.variant === 'api-active' && state.priority === 'apiKey')
const showCreditsFallbackAlert = state.priority === 'apiKey'
&& state.supportsCredits
&& !state.isCreditsExhausted
&& state.variant !== 'api-active'
return (
<>
<div className="w-[320px]">
{state.showPrioritySwitcher && (
<UsagePrioritySection
value={state.priority}
disabled={isChangingPriority}
onSelect={onChangePriority}
/>
)}
{showCreditsFallbackAlert && (
<CreditsFallbackAlert hasCredentials={state.hasCredentials} />
)}
{showCreditsExhaustedAlert && (
<CreditsExhaustedAlert hasApiKeyFallback={hasApiKeyFallback} />
)}
<ApiKeySection
provider={provider}
credentials={available_credentials ?? EMPTY_CREDENTIALS}
selectedCredentialId={selectedCredentialId}
isActivating={isActivating}
onItemClick={activate}
onEdit={handleEdit}
onDelete={handleDelete}
onAdd={handleAdd}
/>
</div>
<AlertDialog
open={!!deleteCredentialId}
onOpenChange={(open) => {
if (!open)
closeConfirmDelete()
}}
>
<AlertDialogContent>
<div className="p-6 pb-0">
<AlertDialogTitle className="text-text-primary system-xl-semibold">
{t('modelProvider.confirmDelete', { ns: 'common' })}
</AlertDialogTitle>
<AlertDialogDescription className="mt-1 text-text-secondary system-sm-regular" />
</div>
<AlertDialogActions>
<AlertDialogCancelButton disabled={doingAction}>
{t('operation.cancel', { ns: 'common' })}
</AlertDialogCancelButton>
<AlertDialogConfirmButton disabled={doingAction} onClick={handleConfirmDelete}>
{t('operation.delete', { ns: 'common' })}
</AlertDialogConfirmButton>
</AlertDialogActions>
</AlertDialogContent>
</AlertDialog>
</>
)
}
export default memo(DropdownContent)

View File

@@ -0,0 +1,211 @@
import type { ModelProvider } from '../../declarations'
import type { CredentialPanelState } from '../use-credential-panel-state'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { CustomConfigurationStatusEnum, PreferredProviderTypeEnum } from '../../declarations'
import ModelAuthDropdown from './index'
vi.mock('../../model-auth/hooks', () => ({
useAuth: () => ({
openConfirmDelete: vi.fn(),
closeConfirmDelete: vi.fn(),
doingAction: false,
handleConfirmDelete: vi.fn(),
deleteCredentialId: null,
handleOpenModal: vi.fn(),
}),
}))
vi.mock('./use-activate-credential', () => ({
useActivateCredential: () => ({
selectedCredentialId: undefined,
isActivating: false,
activate: vi.fn(),
}),
}))
vi.mock('../use-trial-credits', () => ({
useTrialCredits: () => ({ credits: 0, totalCredits: 10_000, isExhausted: true, isLoading: false }),
}))
const createProvider = (overrides: Partial<ModelProvider> = {}): ModelProvider => ({
provider: 'test',
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
available_credentials: [],
},
system_configuration: { enabled: true, current_quota_type: 'trial', quota_configurations: [] },
preferred_provider_type: PreferredProviderTypeEnum.system,
...overrides,
} as unknown as ModelProvider)
const createState = (overrides: Partial<CredentialPanelState> = {}): CredentialPanelState => ({
variant: 'credits-active',
priority: 'credits',
supportsCredits: true,
showPrioritySwitcher: false,
hasCredentials: false,
isCreditsExhausted: false,
credentialName: undefined,
credits: 100,
...overrides,
})
describe('ModelAuthDropdown', () => {
const onChangePriority = vi.fn()
beforeEach(() => {
vi.clearAllMocks()
})
describe('Button text', () => {
it('should show "Add API Key" when no credentials for credits-active', () => {
render(
<ModelAuthDropdown
provider={createProvider()}
state={createState({ hasCredentials: false, variant: 'credits-active' })}
isChangingPriority={false}
onChangePriority={onChangePriority}
/>,
)
expect(screen.getByRole('button', { name: /addApiKey/ })).toBeInTheDocument()
})
it('should show "Configure" when has credentials for api-active', () => {
render(
<ModelAuthDropdown
provider={createProvider()}
state={createState({ hasCredentials: true, variant: 'api-active' })}
isChangingPriority={false}
onChangePriority={onChangePriority}
/>,
)
expect(screen.getByRole('button', { name: /config/i })).toBeInTheDocument()
})
it('should show "Add API Key" for api-required-add variant', () => {
render(
<ModelAuthDropdown
provider={createProvider()}
state={createState({ variant: 'api-required-add', hasCredentials: false })}
isChangingPriority={false}
onChangePriority={onChangePriority}
/>,
)
expect(screen.getByRole('button', { name: /addApiKey/ })).toBeInTheDocument()
})
it('should show "Configure" for api-required-configure variant', () => {
render(
<ModelAuthDropdown
provider={createProvider()}
state={createState({ variant: 'api-required-configure', hasCredentials: true })}
isChangingPriority={false}
onChangePriority={onChangePriority}
/>,
)
expect(screen.getByRole('button', { name: /config/i })).toBeInTheDocument()
})
it('should show "Configure" for credits-active when has credentials', () => {
render(
<ModelAuthDropdown
provider={createProvider()}
state={createState({ hasCredentials: true, variant: 'credits-active' })}
isChangingPriority={false}
onChangePriority={onChangePriority}
/>,
)
expect(screen.getByRole('button', { name: /config/i })).toBeInTheDocument()
})
it('should show "Add API Key" for credits-exhausted (no credentials)', () => {
render(
<ModelAuthDropdown
provider={createProvider()}
state={createState({ variant: 'credits-exhausted', hasCredentials: false })}
isChangingPriority={false}
onChangePriority={onChangePriority}
/>,
)
expect(screen.getByRole('button', { name: /addApiKey/ })).toBeInTheDocument()
})
it('should show "Configure" for api-unavailable (has credentials)', () => {
render(
<ModelAuthDropdown
provider={createProvider()}
state={createState({ variant: 'api-unavailable', hasCredentials: true })}
isChangingPriority={false}
onChangePriority={onChangePriority}
/>,
)
expect(screen.getByRole('button', { name: /config/i })).toBeInTheDocument()
})
it('should show "Configure" for api-fallback (has credentials)', () => {
render(
<ModelAuthDropdown
provider={createProvider()}
state={createState({ variant: 'api-fallback', hasCredentials: true })}
isChangingPriority={false}
onChangePriority={onChangePriority}
/>,
)
expect(screen.getByRole('button', { name: /config/i })).toBeInTheDocument()
})
})
describe('Button variant styling', () => {
it('should use secondary-accent for api-required-add', () => {
const { container } = render(
<ModelAuthDropdown
provider={createProvider()}
state={createState({ variant: 'api-required-add', hasCredentials: false })}
isChangingPriority={false}
onChangePriority={onChangePriority}
/>,
)
const button = container.querySelector('button')
expect(button?.getAttribute('data-variant') ?? button?.className).toMatch(/accent/)
})
it('should use secondary-accent for api-required-configure', () => {
const { container } = render(
<ModelAuthDropdown
provider={createProvider()}
state={createState({ variant: 'api-required-configure', hasCredentials: true })}
isChangingPriority={false}
onChangePriority={onChangePriority}
/>,
)
const button = container.querySelector('button')
expect(button?.getAttribute('data-variant') ?? button?.className).toMatch(/accent/)
})
})
describe('Popover behavior', () => {
it('should open popover on button click and show dropdown content', async () => {
render(
<ModelAuthDropdown
provider={createProvider({
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
available_credentials: [{ credential_id: 'c1', credential_name: 'Key 1' }],
current_credential_id: 'c1',
current_credential_name: 'Key 1',
},
})}
state={createState({ hasCredentials: true, variant: 'api-active' })}
isChangingPriority={false}
onChangePriority={onChangePriority}
/>,
)
fireEvent.click(screen.getByRole('button', { name: /config/i }))
await waitFor(() => {
expect(screen.getByText('Key 1')).toBeInTheDocument()
})
})
})
})

View File

@@ -0,0 +1,85 @@
import type { ModelProvider, PreferredProviderTypeEnum } from '../../declarations'
import type { CardVariant, CredentialPanelState } from '../use-credential-panel-state'
import { memo, useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import {
Popover,
PopoverContent,
PopoverTrigger,
} from '@/app/components/base/ui/popover'
import DropdownContent from './dropdown-content'
type ModelAuthDropdownProps = {
provider: ModelProvider
state: CredentialPanelState
isChangingPriority: boolean
onChangePriority: (key: PreferredProviderTypeEnum) => void
}
const ACCENT_VARIANTS = new Set<CardVariant>([
'api-required-add',
'api-required-configure',
])
function getButtonConfig(variant: CardVariant, hasCredentials: boolean, t: (key: string, opts?: Record<string, string>) => string) {
if (ACCENT_VARIANTS.has(variant)) {
return {
text: variant === 'api-required-add'
? t('modelProvider.auth.addApiKey', { ns: 'common' })
: t('operation.config', { ns: 'common' }),
variant: 'secondary-accent' as const,
}
}
const text = hasCredentials
? t('operation.config', { ns: 'common' })
: t('modelProvider.auth.addApiKey', { ns: 'common' })
return { text, variant: 'secondary' as const }
}
function ModelAuthDropdown({
provider,
state,
isChangingPriority,
onChangePriority,
}: ModelAuthDropdownProps) {
const { t } = useTranslation()
const [open, setOpen] = useState(false)
const handleClose = useCallback(() => setOpen(false), [])
const buttonConfig = getButtonConfig(state.variant, state.hasCredentials, t)
return (
<Popover open={open} onOpenChange={setOpen}>
<PopoverTrigger
render={(
<Button
className="flex grow"
size="small"
variant={buttonConfig.variant}
title={buttonConfig.text}
>
<span className="i-ri-equalizer-2-line mr-1 h-3.5 w-3.5 shrink-0" />
<span className="w-0 grow truncate text-left">
{buttonConfig.text}
</span>
</Button>
)}
/>
<PopoverContent placement="bottom-end">
<DropdownContent
provider={provider}
state={state}
isChangingPriority={isChangingPriority}
onChangePriority={onChangePriority}
onClose={handleClose}
/>
</PopoverContent>
</Popover>
)
}
export default memo(ModelAuthDropdown)

View File

@@ -0,0 +1,66 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { PreferredProviderTypeEnum } from '../../declarations'
import UsagePrioritySection from './usage-priority-section'
describe('UsagePrioritySection', () => {
const onSelect = vi.fn()
beforeEach(() => {
vi.clearAllMocks()
})
// Rendering
describe('Rendering', () => {
it('should render title and both option buttons', () => {
render(<UsagePrioritySection value="credits" onSelect={onSelect} />)
expect(screen.getByText(/usagePriority/)).toBeInTheDocument()
expect(screen.getAllByRole('button')).toHaveLength(2)
})
})
// Selection state
describe('Selection state', () => {
it('should highlight AI credits option when value is credits', () => {
render(<UsagePrioritySection value="credits" onSelect={onSelect} />)
const buttons = screen.getAllByRole('button')
expect(buttons[0].className).toContain('border-components-option-card-option-selected-border')
expect(buttons[1].className).not.toContain('border-components-option-card-option-selected-border')
})
it('should highlight API key option when value is apiKey', () => {
render(<UsagePrioritySection value="apiKey" onSelect={onSelect} />)
const buttons = screen.getAllByRole('button')
expect(buttons[0].className).not.toContain('border-components-option-card-option-selected-border')
expect(buttons[1].className).toContain('border-components-option-card-option-selected-border')
})
it('should highlight API key option when value is apiKeyOnly', () => {
render(<UsagePrioritySection value="apiKeyOnly" onSelect={onSelect} />)
const buttons = screen.getAllByRole('button')
expect(buttons[1].className).toContain('border-components-option-card-option-selected-border')
})
})
// User interactions
describe('User interactions', () => {
it('should call onSelect with system when clicking AI credits option', () => {
render(<UsagePrioritySection value="apiKey" onSelect={onSelect} />)
fireEvent.click(screen.getAllByRole('button')[0])
expect(onSelect).toHaveBeenCalledWith(PreferredProviderTypeEnum.system)
})
it('should call onSelect with custom when clicking API key option', () => {
render(<UsagePrioritySection value="credits" onSelect={onSelect} />)
fireEvent.click(screen.getAllByRole('button')[1])
expect(onSelect).toHaveBeenCalledWith(PreferredProviderTypeEnum.custom)
})
})
})

View File

@@ -0,0 +1,56 @@
import type { UsagePriority } from '../use-credential-panel-state'
import { useTranslation } from 'react-i18next'
import { cn } from '@/utils/classnames'
import { PreferredProviderTypeEnum } from '../../declarations'
type UsagePrioritySectionProps = {
value: UsagePriority
disabled?: boolean
onSelect: (key: PreferredProviderTypeEnum) => void
}
const options = [
{ key: PreferredProviderTypeEnum.system, labelKey: 'modelProvider.card.aiCreditsOption' },
{ key: PreferredProviderTypeEnum.custom, labelKey: 'modelProvider.card.apiKeyOption' },
] as const
export default function UsagePrioritySection({ value, disabled, onSelect }: UsagePrioritySectionProps) {
const { t } = useTranslation()
const selectedKey = value === 'credits'
? PreferredProviderTypeEnum.system
: PreferredProviderTypeEnum.custom
return (
<div className="p-1">
<div className="flex items-center gap-1 rounded-lg p-1">
<div className="shrink-0 px-0.5 py-1">
<span className="i-ri-arrow-up-double-line block h-4 w-4 text-text-tertiary" />
</div>
<div className="flex min-w-0 flex-1 items-center gap-0.5 py-0.5">
<span className="truncate text-text-secondary system-sm-medium">
{t('modelProvider.card.usagePriority', { ns: 'common' })}
</span>
<span className="i-ri-question-line h-3.5 w-3.5 shrink-0 text-text-quaternary" />
</div>
<div className="flex shrink-0 items-center gap-1">
{options.map(option => (
<button
key={option.key}
type="button"
className={cn(
'shrink-0 whitespace-nowrap rounded-md px-2 py-1 text-center transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-components-button-primary-border disabled:opacity-50',
selectedKey === option.key
? 'border-[1.5px] border-components-option-card-option-selected-border bg-components-panel-bg text-text-primary shadow-xs system-xs-medium'
: 'border border-components-option-card-option-border bg-components-option-card-option-bg text-text-secondary system-xs-regular hover:bg-components-option-card-option-bg-hover',
)}
disabled={disabled}
onClick={() => onSelect(option.key)}
>
{t(option.labelKey, { ns: 'common' })}
</button>
))}
</div>
</div>
</div>
)
}

View File

@@ -0,0 +1,51 @@
import type { Credential, ModelProvider } from '../../declarations'
import { useCallback, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Toast from '@/app/components/base/toast'
import { useActiveProviderCredential } from '@/service/use-models'
import {
useUpdateModelList,
useUpdateModelProviders,
} from '../../hooks'
export function useActivateCredential(provider: ModelProvider) {
const { t } = useTranslation()
const updateModelProviders = useUpdateModelProviders()
const updateModelList = useUpdateModelList()
const { mutate, isPending } = useActiveProviderCredential(provider.provider)
const [optimisticId, setOptimisticId] = useState<string>()
const currentId = provider.custom_configuration.current_credential_id
const selectedCredentialId = optimisticId ?? currentId
const selectedIdRef = useRef(selectedCredentialId)
selectedIdRef.current = selectedCredentialId
const supportedModelTypes = provider.supported_model_types
const activate = useCallback((credential: Credential) => {
if (credential.credential_id === selectedIdRef.current)
return
setOptimisticId(credential.credential_id)
mutate(
{ credential_id: credential.credential_id },
{
onSuccess: () => {
Toast.notify({ type: 'success', message: t('api.actionSuccess', { ns: 'common' }) })
updateModelProviders()
supportedModelTypes.forEach(type => updateModelList(type))
},
onError: () => {
setOptimisticId(undefined)
Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) })
},
},
)
}, [mutate, t, updateModelProviders, updateModelList, supportedModelTypes])
return {
selectedCredentialId,
isActivating: isPending,
activate,
}
}

View File

@@ -1,9 +1,17 @@
import type { ModelItem, ModelProvider } from '../declarations'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { disableModel, enableModel } from '@/service/common'
import { ModelStatusEnum } from '../declarations'
import ModelListItem from './model-list-item'
function createWrapper() {
const queryClient = new QueryClient({ defaultOptions: { queries: { retry: false } } })
return ({ children }: { children: React.ReactNode }) => (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
)
}
let mockModelLoadBalancingEnabled = false
vi.mock('@/context/app-context', () => ({
@@ -69,6 +77,7 @@ describe('ModelListItem', () => {
provider={mockProvider}
isConfigurable={false}
/>,
{ wrapper: createWrapper() },
)
expect(screen.getByTestId('model-icon')).toBeInTheDocument()
expect(screen.getByTestId('model-name')).toBeInTheDocument()
@@ -83,6 +92,7 @@ describe('ModelListItem', () => {
isConfigurable={false}
onChange={onChange}
/>,
{ wrapper: createWrapper() },
)
fireEvent.click(screen.getByRole('switch'))
@@ -102,6 +112,7 @@ describe('ModelListItem', () => {
isConfigurable={false}
onChange={onChange}
/>,
{ wrapper: createWrapper() },
)
fireEvent.click(screen.getByRole('switch'))
@@ -122,6 +133,7 @@ describe('ModelListItem', () => {
isConfigurable={false}
onModifyLoadBalancing={onModifyLoadBalancing}
/>,
{ wrapper: createWrapper() },
)
fireEvent.click(screen.getByRole('button', { name: 'modify load balancing' }))

View File

@@ -1,4 +1,5 @@
import type { ModelItem, ModelProvider } from '../declarations'
import { useQueryClient } from '@tanstack/react-query'
import { useDebounceFn } from 'ahooks'
import { memo, useCallback } from 'react'
import { useTranslation } from 'react-i18next'
@@ -9,6 +10,7 @@ import Tooltip from '@/app/components/base/tooltip'
import { Plan } from '@/app/components/billing/type'
import { useAppContext } from '@/context/app-context'
import { useProviderContext, useProviderContextSelector } from '@/context/provider-context'
import { consoleQuery } from '@/service/client'
import { disableModel, enableModel } from '@/service/common'
import { cn } from '@/utils/classnames'
import { ModelStatusEnum } from '../declarations'
@@ -30,16 +32,30 @@ const ModelListItem = ({ model, provider, isConfigurable, onChange, onModifyLoad
const { plan } = useProviderContext()
const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled)
const { isCurrentWorkspaceManager } = useAppContext()
const queryClient = useQueryClient()
const updateModelList = useUpdateModelList()
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
const toggleModelEnablingStatus = useCallback(async (enabled: boolean) => {
if (enabled)
await enableModel(`/workspaces/current/model-providers/${provider.provider}/models/enable`, { model: model.model, model_type: model.model_type })
else
await disableModel(`/workspaces/current/model-providers/${provider.provider}/models/disable`, { model: model.model, model_type: model.model_type })
queryClient.invalidateQueries({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'none',
})
updateModelList(model.model_type)
onChange?.(provider.provider)
}, [model.model, model.model_type, onChange, provider.provider, updateModelList])
}, [model.model, model.model_type, modelProviderModelListQueryKey, onChange, provider.provider, queryClient, updateModelList])
const { run: debouncedToggleModelEnablingStatus } = useDebounceFn(toggleModelEnablingStatus, { wait: 500 })
@@ -58,7 +74,7 @@ const ModelListItem = ({ model, provider, isConfigurable, onChange, onModifyLoad
modelName={model.model}
/>
<ModelName
className="system-md-regular grow text-text-secondary"
className="grow text-text-secondary system-md-regular"
modelItem={model}
showModelType
showMode

View File

@@ -0,0 +1,137 @@
import type { FC } from 'react'
import type { PluginDetail } from '@/app/components/plugins/types'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import { HeaderModals } from '@/app/components/plugins/plugin-detail-panel/detail-header/components'
import { useDetailHeaderState, usePluginOperations } from '@/app/components/plugins/plugin-detail-panel/detail-header/hooks'
import OperationDropdown from '@/app/components/plugins/plugin-detail-panel/operation-dropdown'
import { PluginSource } from '@/app/components/plugins/types'
import PluginVersionPicker from '@/app/components/plugins/update-plugin/plugin-version-picker'
import { useLocale } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { cn } from '@/utils/classnames'
import { getMarketplaceUrl } from '@/utils/var'
type Props = {
detail: PluginDetail
onUpdate?: () => void
}
const ProviderCardActions: FC<Props> = ({ detail, onUpdate }) => {
const { t } = useTranslation()
const { theme } = useTheme()
const locale = useLocale()
const { source, version, latest_version, latest_unique_identifier, meta } = detail
const author = detail.declaration?.author ?? ''
const name = detail.declaration?.name ?? detail.name
const {
modalStates,
versionPicker,
hasNewVersion,
isAutoUpgradeEnabled,
isFromMarketplace,
isFromGitHub,
} = useDetailHeaderState(detail)
const {
handleUpdate,
handleUpdatedFromMarketplace,
handleDelete,
} = usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace,
onUpdate,
})
const handleVersionSelect = (state: { version: string, unique_identifier: string, isDowngrade?: boolean }) => {
versionPicker.setTargetVersion(state)
handleUpdate(state.isDowngrade)
}
const handleTriggerLatestUpdate = () => {
if (isFromMarketplace) {
versionPicker.setTargetVersion({
version: latest_version,
unique_identifier: latest_unique_identifier,
})
}
handleUpdate()
}
const detailUrl = useMemo(() => {
if (source === PluginSource.github)
return meta?.repo ? `https://github.com/${meta.repo}` : ''
if (source === PluginSource.marketplace)
return getMarketplaceUrl(`/plugins/${author}/${name}`, { language: locale, theme })
return ''
}, [source, meta?.repo, author, name, locale, theme])
return (
<>
{!!version && (
<PluginVersionPicker
disabled={!isFromMarketplace}
isShow={versionPicker.isShow}
onShowChange={versionPicker.setIsShow}
pluginID={detail.plugin_id}
currentVersion={version}
onSelect={handleVersionSelect}
sideOffset={4}
alignOffset={0}
trigger={(
<span
className={cn(
'relative inline-flex min-w-5 items-center justify-center gap-[3px] rounded-md border border-divider-deep bg-state-base-hover px-[5px] py-[2px] text-text-tertiary system-xs-medium-uppercase',
isFromMarketplace && 'cursor-pointer hover:bg-state-base-hover-alt',
)}
>
<span>{version}</span>
{isFromMarketplace && <span aria-hidden className="i-ri-arrow-left-right-line h-3 w-3" />}
{hasNewVersion && (
<span className="absolute -right-0.5 -top-0.5 h-1.5 w-1.5 rounded-full bg-state-destructive-solid" />
)}
</span>
)}
/>
)}
{(hasNewVersion || isFromGitHub) && (
<Button
variant="secondary-accent"
size="small"
className="!h-5"
onClick={handleTriggerLatestUpdate}
>
{t('detailPanel.operation.update', { ns: 'plugin' })}
</Button>
)}
<OperationDropdown
source={source}
onInfo={modalStates.showPluginInfo}
onCheckVersion={() => handleUpdate()}
onRemove={modalStates.showDeleteConfirm}
detailUrl={detailUrl}
placement="bottom-start"
popupClassName="w-[192px]"
/>
<HeaderModals
detail={detail}
modalStates={modalStates}
targetVersion={versionPicker.targetVersion}
isDowngrade={versionPicker.isDowngrade}
isAutoUpgradeEnabled={isAutoUpgradeEnabled}
onUpdatedFromMarketplace={handleUpdatedFromMarketplace}
onDelete={handleDelete}
/>
</>
)
}
export default ProviderCardActions

View File

@@ -0,0 +1,8 @@
.gridBg {
background-size: 4px 4px;
background-image:
linear-gradient(to right, var(--color-divider-subtle) 0.5px, transparent 0.5px),
linear-gradient(to bottom, var(--color-divider-subtle) 0.5px, transparent 0.5px);
-webkit-mask-image: radial-gradient(ellipse at center, rgba(0, 0, 0, 0.6), transparent 70%);
mask-image: radial-gradient(ellipse at center, rgba(0, 0, 0, 0.6), transparent 70%);
}

View File

@@ -2,11 +2,16 @@ import type { ModelProvider } from '../declarations'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import QuotaPanel from './quota-panel'
let mockWorkspace = {
let mockWorkspaceData: {
trial_credits: number
trial_credits_used: number
next_credit_reset_date: string
} | undefined = {
trial_credits: 100,
trial_credits_used: 30,
next_credit_reset_date: '2024-12-31',
}
let mockWorkspaceIsPending = false
let mockTrialModels: string[] = ['langgenius/openai/openai']
let mockPlugins = [{
plugin_id: 'langgenius/openai',
@@ -25,15 +30,16 @@ vi.mock('@/app/components/base/icons/src/public/llm', () => {
}
})
vi.mock('@/context/app-context', () => ({
useAppContext: () => ({
currentWorkspace: mockWorkspace,
vi.mock('@/service/use-common', () => ({
useCurrentWorkspace: () => ({
data: mockWorkspaceData,
isPending: mockWorkspaceIsPending,
}),
}))
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: (selector: (state: { systemFeatures: { trial_models: string[] } }) => unknown) => selector({
systemFeatures: {
useSystemFeaturesQuery: () => ({
data: {
trial_models: mockTrialModels,
},
}),
@@ -71,22 +77,21 @@ describe('QuotaPanel', () => {
beforeEach(() => {
vi.clearAllMocks()
mockWorkspace = {
mockWorkspaceData = {
trial_credits: 100,
trial_credits_used: 30,
next_credit_reset_date: '2024-12-31',
}
mockWorkspaceIsPending = false
mockTrialModels = ['langgenius/openai/openai']
mockPlugins = [{ plugin_id: 'langgenius/openai', latest_package_identifier: 'openai@1.0.0' }]
})
it('should render loading state', () => {
render(
<QuotaPanel
providers={mockProviders}
isLoading
/>,
)
mockWorkspaceData = undefined
mockWorkspaceIsPending = true
render(<QuotaPanel providers={mockProviders} />)
expect(screen.getByRole('status')).toBeInTheDocument()
})
@@ -102,8 +107,17 @@ describe('QuotaPanel', () => {
expect(screen.getByText(/modelProvider\.resetDate/)).toBeInTheDocument()
})
it('should keep quota content during background refetch when cached workspace exists', () => {
mockWorkspaceIsPending = true
render(<QuotaPanel providers={mockProviders} />)
expect(screen.queryByRole('status')).not.toBeInTheDocument()
expect(screen.getByText('70')).toBeInTheDocument()
})
it('should floor credits at zero when usage is higher than quota', () => {
mockWorkspace = {
mockWorkspaceData = {
trial_credits: 10,
trial_credits_used: 999,
next_credit_reset_date: '',
@@ -111,7 +125,7 @@ describe('QuotaPanel', () => {
render(<QuotaPanel providers={mockProviders} />)
expect(screen.getByText('0')).toBeInTheDocument()
expect(screen.getByText(/modelProvider\.card\.quotaExhausted/)).toBeInTheDocument()
expect(screen.queryByText(/modelProvider\.resetDate/)).not.toBeInTheDocument()
})

View File

@@ -7,10 +7,9 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { AnthropicShortLight, Deepseek, Gemini, Grok, OpenaiSmall, Tongyi } from '@/app/components/base/icons/src/public/llm'
import Loading from '@/app/components/base/loading'
import Tooltip from '@/app/components/base/tooltip'
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
import InstallFromMarketplace from '@/app/components/plugins/install-plugin/install-from-marketplace'
import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useSystemFeaturesQuery } from '@/context/global-public-context'
import useTimestamp from '@/hooks/use-timestamp'
import { ModelProviderQuotaGetPaid } from '@/types/model-provider'
import { cn } from '@/utils/classnames'
@@ -18,8 +17,9 @@ import { formatNumber } from '@/utils/format'
import { PreferredProviderTypeEnum } from '../declarations'
import { useMarketplaceAllPlugins } from '../hooks'
import { MODEL_PROVIDER_QUOTA_GET_PAID, modelNameMap } from '../utils'
import styles from './quota-panel.module.css'
import { useTrialCredits } from './use-trial-credits'
// Icon map for each provider - single source of truth for provider icons
const providerIconMap: Record<ModelProviderQuotaGetPaid, ComponentType<{ className?: string }>> = {
[ModelProviderQuotaGetPaid.OPENAI]: OpenaiSmall,
[ModelProviderQuotaGetPaid.ANTHROPIC]: AnthropicShortLight,
@@ -29,14 +29,11 @@ const providerIconMap: Record<ModelProviderQuotaGetPaid, ComponentType<{ classNa
[ModelProviderQuotaGetPaid.TONGYI]: Tongyi,
}
// Derive allProviders from the shared constant
const allProviders = MODEL_PROVIDER_QUOTA_GET_PAID.map(key => ({
key,
Icon: providerIconMap[key],
}))
// Map provider key to plugin ID
// provider key format: langgenius/provider/model, plugin ID format: langgenius/provider
const providerKeyToPluginId: Record<ModelProviderQuotaGetPaid, string> = {
[ModelProviderQuotaGetPaid.OPENAI]: 'langgenius/openai',
[ModelProviderQuotaGetPaid.ANTHROPIC]: 'langgenius/anthropic',
@@ -48,16 +45,14 @@ const providerKeyToPluginId: Record<ModelProviderQuotaGetPaid, string> = {
type QuotaPanelProps = {
providers: ModelProvider[]
isLoading?: boolean
}
const QuotaPanel: FC<QuotaPanelProps> = ({
providers,
isLoading = false,
}) => {
const { t } = useTranslation()
const { currentWorkspace } = useAppContext()
const { trial_models } = useGlobalPublicStore(s => s.systemFeatures)
const credits = Math.max((currentWorkspace.trial_credits - currentWorkspace.trial_credits_used) || 0, 0)
const { credits, isExhausted, isLoading, nextCreditResetDate } = useTrialCredits()
const { data: systemFeatures } = useSystemFeaturesQuery()
const trialModels = systemFeatures?.trial_models ?? []
const providerMap = useMemo(() => new Map(
providers.map(p => [p.provider, p.preferred_provider_type]),
), [providers])
@@ -98,6 +93,11 @@ const QuotaPanel: FC<QuotaPanelProps> = ({
}
}, [providers, isShowInstallModal, hideInstallFromMarketplace])
const tipText = t('modelProvider.card.tip', {
ns: 'common',
modelNames: trialModels.map(key => modelNameMap[key as keyof typeof modelNameMap]).filter(Boolean).join(', '),
})
if (isLoading) {
return (
<div className="my-2 flex min-h-[72px] items-center justify-center rounded-xl border-[0.5px] border-components-panel-border bg-third-party-model-bg-default shadow-xs">
@@ -107,59 +107,88 @@ const QuotaPanel: FC<QuotaPanelProps> = ({
}
return (
<div className={cn('my-2 min-w-[72px] shrink-0 rounded-xl border-[0.5px] pb-2.5 pl-4 pr-2.5 pt-3 shadow-xs', credits <= 0 ? 'border-state-destructive-border hover:bg-state-destructive-hover' : 'border-components-panel-border bg-third-party-model-bg-default')}>
<div className="system-xs-medium-uppercase mb-2 flex h-4 items-center text-text-tertiary">
{t('modelProvider.quota', { ns: 'common' })}
<Tooltip popupContent={t('modelProvider.card.tip', { ns: 'common', modelNames: trial_models.map(key => modelNameMap[key as keyof typeof modelNameMap]).filter(Boolean).join(', ') })} />
</div>
<div className="flex items-center justify-between">
<div className="flex items-center gap-1 text-xs text-text-tertiary">
<span className="system-md-semibold-uppercase mr-0.5 text-text-secondary">{formatNumber(credits)}</span>
<span>{t('modelProvider.credits', { ns: 'common' })}</span>
{currentWorkspace.next_credit_reset_date
? (
<>
<span>·</span>
<span>
{t('modelProvider.resetDate', {
ns: 'common',
date: formatTime(currentWorkspace.next_credit_reset_date, t('dateFormat', { ns: 'appLog' })),
interpolation: { escapeValue: false },
})}
</span>
</>
)
: null}
<div className={cn(
'relative my-2 min-w-[72px] shrink-0 overflow-hidden rounded-xl border-[0.5px] pb-2.5 pl-4 pr-2.5 pt-3 shadow-xs',
isExhausted
? 'border-state-destructive-border hover:bg-state-destructive-hover'
: 'border-components-panel-border bg-third-party-model-bg-default',
)}
>
<div className={cn('pointer-events-none absolute inset-0', styles.gridBg)} />
<div className="relative">
<div className="mb-2 flex h-4 items-center text-text-tertiary system-xs-medium-uppercase">
{t('modelProvider.quota', { ns: 'common' })}
<Tooltip>
<TooltipTrigger
aria-label={tipText}
delay={0}
render={(
<span className="ml-0.5 flex h-4 w-4 shrink-0 items-center justify-center">
<span aria-hidden className="i-ri-question-line h-3.5 w-3.5 text-text-quaternary hover:text-text-tertiary" />
</span>
)}
/>
<TooltipContent>
{tipText}
</TooltipContent>
</Tooltip>
</div>
<div className="flex items-center gap-1">
{allProviders.filter(({ key }) => trial_models.includes(key)).map(({ key, Icon }) => {
const providerType = providerMap.get(key)
const isConfigured = (installedProvidersMap.get(key)?.length ?? 0) > 0 // means the provider is configured API key
const getTooltipKey = () => {
// if provider type is not set, it means the provider is not installed
if (!providerType)
return 'modelProvider.card.modelNotSupported'
if (isConfigured && providerType === PreferredProviderTypeEnum.custom)
return 'modelProvider.card.modelAPI'
return 'modelProvider.card.modelSupported'
}
return (
<Tooltip
key={key}
popupContent={t(getTooltipKey(), { modelName: modelNameMap[key], ns: 'common' })}
>
<div
className={cn('relative h-6 w-6', !providerType && 'cursor-pointer hover:opacity-80')}
onClick={() => handleIconClick(key)}
>
<Icon className="h-6 w-6 rounded-lg" />
{!providerType && (
<div className="absolute inset-0 rounded-lg border-[0.5px] border-components-panel-border-subtle bg-background-default-dodge opacity-30" />
)}
</div>
</Tooltip>
)
})}
<div className="flex items-center justify-between">
<div className="flex items-center gap-1 text-xs text-text-tertiary">
{credits > 0
? <span className="mr-0.5 text-text-secondary system-xl-semibold">{formatNumber(credits)}</span>
: <span className="mr-0.5 text-text-destructive system-xl-semibold">{t('modelProvider.card.quotaExhausted', { ns: 'common' })}</span>}
{nextCreditResetDate
? (
<>
<span>·</span>
<span>
{t('modelProvider.resetDate', {
ns: 'common',
date: formatTime(nextCreditResetDate!, t('dateFormat', { ns: 'appLog' })),
interpolation: { escapeValue: false },
})}
</span>
</>
)
: null}
</div>
<div className="flex items-center gap-1">
{allProviders.filter(({ key }) => trialModels.includes(key)).map(({ key, Icon }) => {
const providerType = providerMap.get(key)
const isConfigured = (installedProvidersMap.get(key)?.length ?? 0) > 0
const getTooltipKey = () => {
if (!providerType)
return 'modelProvider.card.modelNotSupported'
if (isConfigured && providerType === PreferredProviderTypeEnum.custom)
return 'modelProvider.card.modelAPI'
return 'modelProvider.card.modelSupported'
}
const tooltipText = t(getTooltipKey(), { modelName: modelNameMap[key], ns: 'common' })
return (
<Tooltip key={key}>
<TooltipTrigger
aria-label={tooltipText}
delay={0}
render={(
<div
className={cn('relative h-6 w-6', !providerType && 'cursor-pointer hover:opacity-80')}
onClick={() => handleIconClick(key)}
>
<Icon className="h-6 w-6 rounded-lg" />
{!providerType && (
<div className="absolute inset-0 rounded-lg border-[0.5px] border-components-panel-border-subtle bg-background-default-dodge opacity-30" />
)}
</div>
)}
/>
<TooltipContent>
{tooltipText}
</TooltipContent>
</Tooltip>
)
})}
</div>
</div>
</div>
{isShowInstallModal && selectedPlugin && (

View File

@@ -0,0 +1,89 @@
import { render, screen } from '@testing-library/react'
import SystemQuotaCard from './system-quota-card'
describe('SystemQuotaCard', () => {
// Renders container with children
describe('Rendering', () => {
it('should render children', () => {
render(
<SystemQuotaCard>
<span>content</span>
</SystemQuotaCard>,
)
expect(screen.getByText('content')).toBeInTheDocument()
})
it('should apply default variant styles', () => {
const { container } = render(
<SystemQuotaCard>
<span>test</span>
</SystemQuotaCard>,
)
const card = container.firstElementChild!
expect(card.className).toContain('bg-white')
})
it('should apply destructive variant styles', () => {
const { container } = render(
<SystemQuotaCard variant="destructive">
<span>test</span>
</SystemQuotaCard>,
)
const card = container.firstElementChild!
expect(card.className).toContain('border-state-destructive-border')
})
})
// Label sub-component
describe('Label', () => {
it('should apply default variant text color when no className provided', () => {
render(
<SystemQuotaCard>
<SystemQuotaCard.Label>Default label</SystemQuotaCard.Label>
</SystemQuotaCard>,
)
expect(screen.getByText('Default label').className).toContain('text-text-secondary')
})
it('should apply destructive variant text color when no className provided', () => {
render(
<SystemQuotaCard variant="destructive">
<SystemQuotaCard.Label>Error label</SystemQuotaCard.Label>
</SystemQuotaCard>,
)
expect(screen.getByText('Error label').className).toContain('text-text-destructive')
})
it('should override variant color with custom className', () => {
render(
<SystemQuotaCard variant="destructive">
<SystemQuotaCard.Label className="gap-1">Custom label</SystemQuotaCard.Label>
</SystemQuotaCard>,
)
const label = screen.getByText('Custom label')
expect(label.className).toContain('gap-1')
expect(label.className).not.toContain('text-text-destructive')
})
})
// Actions sub-component
describe('Actions', () => {
it('should render action children', () => {
render(
<SystemQuotaCard>
<SystemQuotaCard.Actions>
<button>Click me</button>
</SystemQuotaCard.Actions>
</SystemQuotaCard>,
)
expect(screen.getByRole('button', { name: /click me/i })).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,67 @@
import type { ReactNode } from 'react'
import { createContext, useContext } from 'react'
import { cn } from '@/utils/classnames'
import styles from './quota-panel.module.css'
type Variant = 'default' | 'destructive'
const VariantContext = createContext<Variant>('default')
const containerVariants: Record<Variant, string> = {
default: 'border-components-panel-border bg-white/[0.18]',
destructive: 'border-state-destructive-border bg-state-destructive-hover',
}
const labelVariants: Record<Variant, string> = {
default: 'text-text-secondary',
destructive: 'text-text-destructive',
}
type SystemQuotaCardProps = {
variant?: Variant
children: ReactNode
}
const SystemQuotaCard = ({
variant = 'default',
children,
}: SystemQuotaCardProps) => {
return (
<VariantContext.Provider value={variant}>
<div className={cn(
'relative isolate ml-1 flex w-[128px] shrink-0 flex-col justify-between rounded-lg border-[0.5px] p-1 shadow-xs',
containerVariants[variant],
)}
>
<div className={cn('pointer-events-none absolute inset-0 rounded-[7px]', styles.gridBg)} />
{children}
</div>
</VariantContext.Provider>
)
}
const Label = ({ children, className }: { children: ReactNode, className?: string }) => {
const variant = useContext(VariantContext)
return (
<div className={cn(
'relative z-[1] flex items-center gap-1 truncate px-1.5 pt-1 system-xs-medium',
className ?? labelVariants[variant],
)}
>
{children}
</div>
)
}
const Actions = ({ children }: { children: ReactNode }) => {
return (
<div className="relative z-[1] flex items-center gap-0.5">
{children}
</div>
)
}
SystemQuotaCard.Label = Label
SystemQuotaCard.Actions = Actions
export default SystemQuotaCard

View File

@@ -0,0 +1,235 @@
import type { ModelProvider } from '../declarations'
import { renderHook } from '@testing-library/react'
import {
ConfigurationMethodEnum,
CurrentSystemQuotaTypeEnum,
CustomConfigurationStatusEnum,
PreferredProviderTypeEnum,
} from '../declarations'
import { isDestructiveVariant, useCredentialPanelState } from './use-credential-panel-state'
const mockTrialCredits = { credits: 100, totalCredits: 10_000, isExhausted: false, isLoading: false, nextCreditResetDate: undefined }
vi.mock('./use-trial-credits', () => ({
useTrialCredits: () => mockTrialCredits,
}))
vi.mock('@/config', async (importOriginal) => {
const actual = await importOriginal<typeof import('@/config')>()
return { ...actual, IS_CLOUD_EDITION: true }
})
const createProvider = (overrides: Partial<ModelProvider> = {}): ModelProvider => ({
provider: 'test-provider',
provider_credential_schema: { credential_form_schemas: [] },
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
current_credential_id: 'cred-1',
current_credential_name: 'My Key',
available_credentials: [{ credential_id: 'cred-1', credential_name: 'My Key' }],
},
system_configuration: { enabled: true, current_quota_type: 'trial', quota_configurations: [] },
preferred_provider_type: PreferredProviderTypeEnum.system,
configurate_methods: [ConfigurationMethodEnum.predefinedModel],
supported_model_types: ['llm'],
...overrides,
} as unknown as ModelProvider)
describe('useCredentialPanelState', () => {
beforeEach(() => {
vi.clearAllMocks()
Object.assign(mockTrialCredits, { credits: 100, totalCredits: 10_000, isExhausted: false, isLoading: false })
})
// Credits priority variants
describe('Credits priority variants', () => {
it('should return credits-active when credits available', () => {
const { result } = renderHook(() => useCredentialPanelState(createProvider()))
expect(result.current.variant).toBe('credits-active')
expect(result.current.priority).toBe('credits')
expect(result.current.supportsCredits).toBe(true)
})
it('should return api-fallback when credits exhausted but API key authorized', () => {
mockTrialCredits.isExhausted = true
mockTrialCredits.credits = 0
const { result } = renderHook(() => useCredentialPanelState(createProvider()))
expect(result.current.variant).toBe('api-fallback')
})
it('should return no-usage when credits exhausted and API key unauthorized', () => {
mockTrialCredits.isExhausted = true
const provider = createProvider({
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
current_credential_id: undefined,
current_credential_name: undefined,
available_credentials: [{ credential_id: 'cred-1', credential_name: 'My Key' }],
},
})
const { result } = renderHook(() => useCredentialPanelState(provider))
expect(result.current.variant).toBe('no-usage')
})
it('should return credits-exhausted when credits exhausted and no credentials', () => {
mockTrialCredits.isExhausted = true
const provider = createProvider({
custom_configuration: {
status: CustomConfigurationStatusEnum.noConfigure,
available_credentials: [],
},
})
const { result } = renderHook(() => useCredentialPanelState(provider))
expect(result.current.variant).toBe('credits-exhausted')
})
})
// API key priority variants
describe('API key priority variants', () => {
it('should return api-active when API key authorized', () => {
const provider = createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
})
const { result } = renderHook(() => useCredentialPanelState(provider))
expect(result.current.variant).toBe('api-active')
expect(result.current.priority).toBe('apiKey')
})
it('should return credits-fallback when API key unauthorized and credits available', () => {
const provider = createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
current_credential_id: undefined,
current_credential_name: undefined,
available_credentials: [{ credential_id: 'cred-1', credential_name: 'My Key' }],
},
})
const { result } = renderHook(() => useCredentialPanelState(provider))
expect(result.current.variant).toBe('credits-fallback')
})
it('should return credits-fallback when no credentials and credits available', () => {
const provider = createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
custom_configuration: {
status: CustomConfigurationStatusEnum.noConfigure,
available_credentials: [],
},
})
const { result } = renderHook(() => useCredentialPanelState(provider))
expect(result.current.variant).toBe('credits-fallback')
})
it('should return no-usage when no credentials and credits exhausted', () => {
mockTrialCredits.isExhausted = true
mockTrialCredits.credits = 0
const provider = createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
custom_configuration: {
status: CustomConfigurationStatusEnum.noConfigure,
available_credentials: [],
},
})
const { result } = renderHook(() => useCredentialPanelState(provider))
expect(result.current.variant).toBe('no-usage')
})
it('should return api-unavailable when credential with name unauthorized and credits exhausted', () => {
mockTrialCredits.isExhausted = true
mockTrialCredits.credits = 0
const provider = createProvider({
preferred_provider_type: PreferredProviderTypeEnum.custom,
custom_configuration: {
status: CustomConfigurationStatusEnum.active,
current_credential_id: undefined,
current_credential_name: 'Bad Key',
available_credentials: [{ credential_id: 'cred-1', credential_name: 'Bad Key' }],
},
})
const { result } = renderHook(() => useCredentialPanelState(provider))
expect(result.current.variant).toBe('api-unavailable')
})
})
// apiKeyOnly priority
describe('apiKeyOnly priority (non-cloud / system disabled)', () => {
it('should return apiKeyOnly when system config disabled', () => {
const provider = createProvider({
system_configuration: { enabled: false, current_quota_type: CurrentSystemQuotaTypeEnum.trial, quota_configurations: [] },
})
const { result } = renderHook(() => useCredentialPanelState(provider))
expect(result.current.priority).toBe('apiKeyOnly')
expect(result.current.supportsCredits).toBe(false)
})
})
// Derived metadata
describe('Derived metadata', () => {
it('should show priority switcher when credits supported and custom config active', () => {
const provider = createProvider()
const { result } = renderHook(() => useCredentialPanelState(provider))
expect(result.current.showPrioritySwitcher).toBe(true)
})
it('should hide priority switcher when system config disabled', () => {
const provider = createProvider({
system_configuration: { enabled: false, current_quota_type: CurrentSystemQuotaTypeEnum.trial, quota_configurations: [] },
})
const { result } = renderHook(() => useCredentialPanelState(provider))
expect(result.current.showPrioritySwitcher).toBe(false)
})
it('should expose credential name from provider', () => {
const { result } = renderHook(() => useCredentialPanelState(createProvider()))
expect(result.current.credentialName).toBe('My Key')
})
it('should expose credits amount', () => {
mockTrialCredits.credits = 500
const { result } = renderHook(() => useCredentialPanelState(createProvider()))
expect(result.current.credits).toBe(500)
})
})
})
describe('isDestructiveVariant', () => {
it.each([
['credits-exhausted', true],
['no-usage', true],
['api-unavailable', true],
['credits-active', false],
['api-fallback', false],
['api-active', false],
['api-required-add', false],
['api-required-configure', false],
] as const)('should return %s for variant %s', (variant, expected) => {
expect(isDestructiveVariant(variant)).toBe(expected)
})
})

View File

@@ -0,0 +1,106 @@
import type { ModelProvider } from '../declarations'
import { useCredentialStatus } from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks'
import { IS_CLOUD_EDITION } from '@/config'
import {
PreferredProviderTypeEnum,
} from '../declarations'
import { useTrialCredits } from './use-trial-credits'
export type UsagePriority = 'credits' | 'apiKey' | 'apiKeyOnly'
export type CardVariant
= | 'credits-active'
| 'credits-fallback'
| 'credits-exhausted'
| 'no-usage'
| 'api-fallback'
| 'api-active'
| 'api-required-add'
| 'api-required-configure'
| 'api-unavailable'
export type CredentialPanelState = {
variant: CardVariant
priority: UsagePriority
supportsCredits: boolean
showPrioritySwitcher: boolean
hasCredentials: boolean
isCreditsExhausted: boolean
credentialName: string | undefined
credits: number
}
const DESTRUCTIVE_VARIANTS = new Set<CardVariant>([
'credits-exhausted',
'no-usage',
'api-unavailable',
])
export const isDestructiveVariant = (variant: CardVariant) =>
DESTRUCTIVE_VARIANTS.has(variant)
function deriveVariant(
priority: UsagePriority,
isExhausted: boolean,
hasCredential: boolean,
authorized: boolean | undefined,
credentialName: string | undefined,
): CardVariant {
if (priority === 'credits') {
if (!isExhausted)
return 'credits-active'
if (hasCredential && authorized)
return 'api-fallback'
if (hasCredential && !authorized)
return 'no-usage'
return 'credits-exhausted'
}
if (hasCredential && authorized)
return 'api-active'
if (priority === 'apiKey' && !isExhausted)
return 'credits-fallback'
if (priority === 'apiKey' && !hasCredential)
return 'no-usage'
if (hasCredential && !authorized)
return credentialName ? 'api-unavailable' : 'api-required-configure'
return 'api-required-add'
}
export function useCredentialPanelState(provider: ModelProvider): CredentialPanelState {
const { isExhausted, credits } = useTrialCredits()
const {
hasCredential,
authorized,
current_credential_name,
} = useCredentialStatus(provider)
const systemConfig = provider.system_configuration
const preferredType = provider.preferred_provider_type
const supportsCredits = systemConfig.enabled && IS_CLOUD_EDITION
const priority: UsagePriority = !supportsCredits
? 'apiKeyOnly'
: preferredType === PreferredProviderTypeEnum.system
? 'credits'
: 'apiKey'
const showPrioritySwitcher = supportsCredits
const variant = deriveVariant(priority, isExhausted, hasCredential, !!authorized, current_credential_name)
return {
variant,
priority,
supportsCredits,
showPrioritySwitcher,
hasCredentials: hasCredential,
isCreditsExhausted: isExhausted,
credentialName: current_credential_name,
credits,
}
}

View File

@@ -0,0 +1,15 @@
import { useCurrentWorkspace } from '@/service/use-common'
export const useTrialCredits = () => {
const { data: currentWorkspace, isPending } = useCurrentWorkspace()
const totalCredits = currentWorkspace?.trial_credits ?? 0
const credits = Math.max(totalCredits - (currentWorkspace?.trial_credits_used ?? 0), 0)
return {
credits,
totalCredits,
isExhausted: credits <= 0,
isLoading: isPending && !currentWorkspace,
nextCreditResetDate: currentWorkspace?.next_credit_reset_date,
}
}

View File

@@ -21,7 +21,7 @@ const ProviderIcon: FC<ProviderIconProps> = ({
if (provider.provider === 'langgenius/anthropic/anthropic') {
return (
<div className="mb-2 py-[7px]">
<div className={cn('py-[7px]', className)}>
{theme === Theme.dark && <AnthropicLight className="h-2.5 w-[90px]" />}
{theme === Theme.light && <AnthropicDark className="h-2.5 w-[90px]" />}
</div>
@@ -30,7 +30,7 @@ const ProviderIcon: FC<ProviderIconProps> = ({
if (provider.provider === 'langgenius/openai/openai') {
return (
<div className="mb-2">
<div className={className}>
<Openai className="h-6 w-auto text-text-inverted-dimmed" />
</div>
)

View File

@@ -26,6 +26,7 @@ vi.mock('react-i18next', async () => {
const mockNotify = vi.hoisted(() => vi.fn())
const mockUpdateModelList = vi.hoisted(() => vi.fn())
const mockInvalidateDefaultModel = vi.hoisted(() => vi.fn())
const mockUpdateDefaultModel = vi.hoisted(() => vi.fn(() => Promise.resolve({ result: 'success' })))
let mockIsCurrentWorkspaceManager = true
@@ -57,6 +58,7 @@ vi.mock('../hooks', () => ({
vi.fn(),
],
useUpdateModelList: () => mockUpdateModelList,
useInvalidateDefaultModel: () => mockInvalidateDefaultModel,
}))
vi.mock('@/service/common', () => ({
@@ -99,7 +101,7 @@ describe('SystemModel', () => {
expect(screen.getByRole('button', { name: /system model settings/i })).toBeInTheDocument()
})
it('should open modal when button is clicked', async () => {
it('should open dialog when button is clicked', async () => {
render(<SystemModel {...defaultProps} />)
const button = screen.getByRole('button', { name: /system model settings/i })
fireEvent.click(button)
@@ -113,7 +115,7 @@ describe('SystemModel', () => {
expect(screen.getByRole('button', { name: /system model settings/i })).toBeDisabled()
})
it('should close modal when cancel is clicked', async () => {
it('should close dialog when cancel is clicked', async () => {
render(<SystemModel {...defaultProps} />)
fireEvent.click(screen.getByRole('button', { name: /system model settings/i }))
await waitFor(() => {
@@ -144,6 +146,7 @@ describe('SystemModel', () => {
type: 'success',
message: 'Modified successfully',
})
expect(mockInvalidateDefaultModel).toHaveBeenCalledTimes(5)
expect(mockUpdateModelList).toHaveBeenCalledTimes(5)
})
})

View File

@@ -3,22 +3,27 @@ import type {
DefaultModel,
DefaultModelResponse,
} from '../declarations'
import { RiEqualizer2Line, RiLoader2Line } from '@remixicon/react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import { useToastContext } from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import {
Dialog,
DialogCloseButton,
DialogContent,
DialogTitle,
} from '@/app/components/base/ui/dialog'
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from '@/app/components/base/ui/tooltip'
import { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
import { updateDefaultModel } from '@/service/common'
import { ModelTypeEnum } from '../declarations'
import {
useInvalidateDefaultModel,
useModelList,
useSystemDefaultModelAndModelList,
useUpdateModelList,
@@ -34,6 +39,21 @@ type SystemModelSelectorProps = {
notConfigured: boolean
isLoading?: boolean
}
type SystemModelLabelKey
= | 'modelProvider.systemReasoningModel.key'
| 'modelProvider.embeddingModel.key'
| 'modelProvider.rerankModel.key'
| 'modelProvider.speechToTextModel.key'
| 'modelProvider.ttsModel.key'
type SystemModelTipKey
= | 'modelProvider.systemReasoningModel.tip'
| 'modelProvider.embeddingModel.tip'
| 'modelProvider.rerankModel.tip'
| 'modelProvider.speechToTextModel.tip'
| 'modelProvider.ttsModel.tip'
const SystemModel: FC<SystemModelSelectorProps> = ({
textGenerationDefaultModel,
embeddingsDefaultModel,
@@ -48,6 +68,7 @@ const SystemModel: FC<SystemModelSelectorProps> = ({
const { isCurrentWorkspaceManager } = useAppContext()
const { textGenerationModelList } = useProviderContext()
const updateModelList = useUpdateModelList()
const invalidateDefaultModel = useInvalidateDefaultModel()
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank)
const { data: speech2textModelList } = useModelList(ModelTypeEnum.speech2text)
@@ -106,154 +127,124 @@ const SystemModel: FC<SystemModelSelectorProps> = ({
notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })
setOpen(false)
changedModelTypes.forEach((modelType) => {
if (modelType === ModelTypeEnum.textGeneration)
updateModelList(modelType)
else if (modelType === ModelTypeEnum.textEmbedding)
updateModelList(modelType)
else if (modelType === ModelTypeEnum.rerank)
updateModelList(modelType)
else if (modelType === ModelTypeEnum.speech2text)
updateModelList(modelType)
else if (modelType === ModelTypeEnum.tts)
updateModelList(modelType)
})
const allModelTypes = [ModelTypeEnum.textGeneration, ModelTypeEnum.textEmbedding, ModelTypeEnum.rerank, ModelTypeEnum.speech2text, ModelTypeEnum.tts]
allModelTypes.forEach(type => invalidateDefaultModel(type))
changedModelTypes.forEach(type => updateModelList(type))
}
}
const renderModelLabel = (labelKey: SystemModelLabelKey, tipKey: SystemModelTipKey) => {
const tipText = t(tipKey, { ns: 'common' })
return (
<div className="flex min-h-6 items-center text-[13px] font-medium text-text-secondary">
{t(labelKey, { ns: 'common' })}
<Tooltip>
<TooltipTrigger
aria-label={tipText}
delay={0}
render={(
<span className="ml-0.5 flex h-4 w-4 shrink-0 items-center justify-center">
<span aria-hidden className="i-ri-question-line h-3.5 w-3.5 text-text-quaternary hover:text-text-tertiary" />
</span>
)}
/>
<TooltipContent>
<div className="w-[261px] text-text-tertiary">
{tipText}
</div>
</TooltipContent>
</Tooltip>
</div>
)
}
return (
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
placement="bottom-end"
offset={{
mainAxis: 4,
crossAxis: 8,
}}
>
<PortalToFollowElemTrigger asChild onClick={() => setOpen(v => !v)}>
<Button
className="relative"
variant={notConfigured ? 'primary' : 'secondary'}
size="small"
disabled={isLoading}
<>
<Button
className="relative"
variant={notConfigured ? 'primary' : 'secondary'}
size="small"
disabled={isLoading}
onClick={() => setOpen(true)}
>
{isLoading
? <span className="i-ri-loader-2-line mr-1 h-3.5 w-3.5 animate-spin" />
: <span className="i-ri-equalizer-2-line mr-1 h-3.5 w-3.5" />}
{t('modelProvider.systemModelSettings', { ns: 'common' })}
</Button>
<Dialog open={open} onOpenChange={setOpen}>
<DialogContent
backdropProps={{ forceRender: true }}
className="w-[480px] max-w-[480px] overflow-hidden p-0"
>
{isLoading
? <RiLoader2Line className="mr-1 h-3.5 w-3.5 animate-spin" />
: <RiEqualizer2Line className="mr-1 h-3.5 w-3.5" />}
{t('modelProvider.systemModelSettings', { ns: 'common' })}
</Button>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className="z-[60]">
<div className="w-[360px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg pt-4 shadow-xl">
<div className="px-6 py-1">
<div className="flex h-8 items-center text-[13px] font-medium text-text-primary">
{t('modelProvider.systemReasoningModel.key', { ns: 'common' })}
<Tooltip
popupContent={(
<div className="w-[261px] text-text-tertiary">
{t('modelProvider.systemReasoningModel.tip', { ns: 'common' })}
</div>
)}
triggerClassName="ml-0.5 w-4 h-4 shrink-0"
/>
<DialogCloseButton className="right-5 top-5" />
<div className="px-6 pb-3 pr-14 pt-6">
<DialogTitle className="text-text-primary title-2xl-semi-bold">
{t('modelProvider.systemModelSettings', { ns: 'common' })}
</DialogTitle>
</div>
<div className="flex flex-col gap-4 px-6 py-3">
<div className="flex flex-col gap-1">
{renderModelLabel('modelProvider.systemReasoningModel.key', 'modelProvider.systemReasoningModel.tip')}
<div>
<ModelSelector
defaultModel={currentTextGenerationDefaultModel}
modelList={textGenerationModelList}
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.textGeneration, model)}
/>
</div>
</div>
<div>
<ModelSelector
defaultModel={currentTextGenerationDefaultModel}
modelList={textGenerationModelList}
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.textGeneration, model)}
/>
<div className="flex flex-col gap-1">
{renderModelLabel('modelProvider.embeddingModel.key', 'modelProvider.embeddingModel.tip')}
<div>
<ModelSelector
defaultModel={currentEmbeddingsDefaultModel}
modelList={embeddingModelList}
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.textEmbedding, model)}
/>
</div>
</div>
<div className="flex flex-col gap-1">
{renderModelLabel('modelProvider.rerankModel.key', 'modelProvider.rerankModel.tip')}
<div>
<ModelSelector
defaultModel={currentRerankDefaultModel}
modelList={rerankModelList}
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.rerank, model)}
/>
</div>
</div>
<div className="flex flex-col gap-1">
{renderModelLabel('modelProvider.speechToTextModel.key', 'modelProvider.speechToTextModel.tip')}
<div>
<ModelSelector
defaultModel={currentSpeech2textDefaultModel}
modelList={speech2textModelList}
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.speech2text, model)}
/>
</div>
</div>
<div className="flex flex-col gap-1">
{renderModelLabel('modelProvider.ttsModel.key', 'modelProvider.ttsModel.tip')}
<div>
<ModelSelector
defaultModel={currentTTSDefaultModel}
modelList={ttsModelList}
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.tts, model)}
/>
</div>
</div>
</div>
<div className="px-6 py-1">
<div className="flex h-8 items-center text-[13px] font-medium text-text-primary">
{t('modelProvider.embeddingModel.key', { ns: 'common' })}
<Tooltip
popupContent={(
<div className="w-[261px] text-text-tertiary">
{t('modelProvider.embeddingModel.tip', { ns: 'common' })}
</div>
)}
triggerClassName="ml-0.5 w-4 h-4 shrink-0"
/>
</div>
<div>
<ModelSelector
defaultModel={currentEmbeddingsDefaultModel}
modelList={embeddingModelList}
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.textEmbedding, model)}
/>
</div>
</div>
<div className="px-6 py-1">
<div className="flex h-8 items-center text-[13px] font-medium text-text-primary">
{t('modelProvider.rerankModel.key', { ns: 'common' })}
<Tooltip
popupContent={(
<div className="w-[261px] text-text-tertiary">
{t('modelProvider.rerankModel.tip', { ns: 'common' })}
</div>
)}
triggerClassName="ml-0.5 w-4 h-4 shrink-0"
/>
</div>
<div>
<ModelSelector
defaultModel={currentRerankDefaultModel}
modelList={rerankModelList}
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.rerank, model)}
/>
</div>
</div>
<div className="px-6 py-1">
<div className="flex h-8 items-center text-[13px] font-medium text-text-primary">
{t('modelProvider.speechToTextModel.key', { ns: 'common' })}
<Tooltip
popupContent={(
<div className="w-[261px] text-text-tertiary">
{t('modelProvider.speechToTextModel.tip', { ns: 'common' })}
</div>
)}
triggerClassName="ml-0.5 w-4 h-4 shrink-0"
/>
</div>
<div>
<ModelSelector
defaultModel={currentSpeech2textDefaultModel}
modelList={speech2textModelList}
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.speech2text, model)}
/>
</div>
</div>
<div className="px-6 py-1">
<div className="flex h-8 items-center text-[13px] font-medium text-text-primary">
{t('modelProvider.ttsModel.key', { ns: 'common' })}
<Tooltip
popupContent={(
<div className="w-[261px] text-text-tertiary">
{t('modelProvider.ttsModel.tip', { ns: 'common' })}
</div>
)}
triggerClassName="ml-0.5 w-4 h-4 shrink-0"
/>
</div>
<div>
<ModelSelector
defaultModel={currentTTSDefaultModel}
modelList={ttsModelList}
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.tts, model)}
/>
</div>
</div>
<div className="flex items-center justify-end px-6 py-4">
<div className="flex items-center justify-end gap-2 px-6 pb-6 pt-5">
<Button
className="min-w-[72px]"
onClick={() => setOpen(false)}
>
{t('operation.cancel', { ns: 'common' })}
</Button>
<Button
className="ml-2"
className="min-w-[72px]"
variant="primary"
onClick={handleSave}
disabled={!isCurrentWorkspaceManager}
@@ -261,9 +252,9 @@ const SystemModel: FC<SystemModelSelectorProps> = ({
{t('operation.save', { ns: 'common' })}
</Button>
</div>
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
</DialogContent>
</Dialog>
</>
)
}

View File

@@ -21,6 +21,11 @@ import {
export { ModelProviderQuotaGetPaid } from '@/types/model-provider'
export const providerToPluginId = (providerKey: string): string => {
const lastSlash = providerKey.lastIndexOf('/')
return lastSlash > 0 ? providerKey.slice(0, lastSlash) : ''
}
export const MODEL_PROVIDER_QUOTA_GET_PAID = [ModelProviderQuotaGetPaid.OPENAI, ModelProviderQuotaGetPaid.ANTHROPIC, ModelProviderQuotaGetPaid.GEMINI, ModelProviderQuotaGetPaid.X, ModelProviderQuotaGetPaid.DEEPSEEK, ModelProviderQuotaGetPaid.TONGYI]
export const modelNameMap = {

View File

@@ -79,6 +79,10 @@ vi.mock('@/service/plugins', () => ({
uninstallPlugin: mockUninstallPlugin,
}))
vi.mock('@/service/use-plugins', () => ({
useInvalidateCheckInstalled: () => vi.fn(),
}))
vi.mock('@/service/use-tools', () => ({
useAllToolProviders: () => ({ data: [] }),
useInvalidateAllToolProviders: () => mockInvalidateAllToolProviders,
@@ -218,23 +222,6 @@ vi.mock('../../plugin-auth', () => ({
PluginAuth: () => <div data-testid="plugin-auth" />,
}))
// Mock Confirm component
vi.mock('@/app/components/base/confirm', () => ({
default: ({ isShow, onCancel, onConfirm, isLoading }: {
isShow: boolean
onCancel: () => void
onConfirm: () => void
isLoading: boolean
}) => isShow
? (
<div data-testid="delete-confirm">
<button data-testid="confirm-cancel" onClick={onCancel}>Cancel</button>
<button data-testid="confirm-ok" onClick={onConfirm} disabled={isLoading}>Confirm</button>
</div>
)
: null,
}))
const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({
id: 'test-id',
created_at: '2024-01-01',
@@ -801,7 +788,7 @@ describe('DetailHeader', () => {
fireEvent.click(screen.getByTestId('remove-btn'))
await waitFor(() => {
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
})
})
@@ -810,13 +797,13 @@ describe('DetailHeader', () => {
fireEvent.click(screen.getByTestId('remove-btn'))
await waitFor(() => {
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
})
fireEvent.click(screen.getByTestId('confirm-cancel'))
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
await waitFor(() => {
expect(screen.queryByTestId('delete-confirm')).not.toBeInTheDocument()
expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument()
})
})
@@ -825,10 +812,10 @@ describe('DetailHeader', () => {
fireEvent.click(screen.getByTestId('remove-btn'))
await waitFor(() => {
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
})
fireEvent.click(screen.getByTestId('confirm-ok'))
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
await waitFor(() => {
expect(mockUninstallPlugin).toHaveBeenCalledWith('test-id')
@@ -840,10 +827,10 @@ describe('DetailHeader', () => {
fireEvent.click(screen.getByTestId('remove-btn'))
await waitFor(() => {
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
})
fireEvent.click(screen.getByTestId('confirm-ok'))
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
await waitFor(() => {
expect(mockOnUpdate).toHaveBeenCalledWith(true)
@@ -861,10 +848,10 @@ describe('DetailHeader', () => {
fireEvent.click(screen.getByTestId('remove-btn'))
await waitFor(() => {
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
})
fireEvent.click(screen.getByTestId('confirm-ok'))
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
await waitFor(() => {
expect(mockRefreshModelProviders).toHaveBeenCalled()
@@ -876,10 +863,10 @@ describe('DetailHeader', () => {
fireEvent.click(screen.getByTestId('remove-btn'))
await waitFor(() => {
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
})
fireEvent.click(screen.getByTestId('confirm-ok'))
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
await waitFor(() => {
expect(mockInvalidateAllToolProviders).toHaveBeenCalled()
@@ -891,10 +878,10 @@ describe('DetailHeader', () => {
fireEvent.click(screen.getByTestId('remove-btn'))
await waitFor(() => {
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
})
fireEvent.click(screen.getByTestId('confirm-ok'))
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
await waitFor(() => {
expect(amplitude.trackEvent).toHaveBeenCalledWith('plugin_uninstalled', expect.any(Object))

View File

@@ -1,4 +1,6 @@
import type { ReactElement, ReactNode } from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import { cloneElement } from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { PluginSource } from '../../types'
import OperationDropdown from '../operation-dropdown'
@@ -12,24 +14,22 @@ vi.mock('@/utils/classnames', () => ({
cn: (...args: (string | undefined | false | null)[]) => args.filter(Boolean).join(' '),
}))
vi.mock('@/app/components/base/action-button', () => ({
default: ({ children, className, onClick }: { children: React.ReactNode, className?: string, onClick?: () => void }) => (
<button data-testid="action-button" className={className} onClick={onClick}>
{children}
</button>
vi.mock('@/app/components/base/ui/dropdown-menu', () => ({
DropdownMenu: ({ children, open }: { children: ReactNode, open: boolean }) => (
<div data-testid="dropdown-menu" data-open={open}>{children}</div>
),
}))
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => (
<div data-testid="portal-elem" data-open={open}>{children}</div>
DropdownMenuTrigger: ({ children, className }: { children: ReactNode, className?: string }) => (
<button data-testid="dropdown-trigger" className={className}>{children}</button>
),
PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick: () => void }) => (
<div data-testid="portal-trigger" onClick={onClick}>{children}</div>
),
PortalToFollowElemContent: ({ children, className }: { children: React.ReactNode, className?: string }) => (
<div data-testid="portal-content" className={className}>{children}</div>
DropdownMenuContent: ({ children }: { children: ReactNode }) => (
<div data-testid="dropdown-content">{children}</div>
),
DropdownMenuItem: ({ children, onClick, render, destructive }: { children: ReactNode, onClick?: () => void, render?: ReactElement, destructive?: boolean }) => {
if (render)
return cloneElement(render, { onClick, 'data-destructive': destructive } as Record<string, unknown>, children)
return <div data-testid="dropdown-item" data-destructive={destructive} onClick={onClick}>{children}</div>
},
DropdownMenuSeparator: () => <hr data-testid="dropdown-separator" />,
}))
describe('OperationDropdown', () => {
@@ -52,14 +52,13 @@ describe('OperationDropdown', () => {
it('should render trigger button', () => {
render(<OperationDropdown {...defaultProps} />)
expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
expect(screen.getByTestId('action-button')).toBeInTheDocument()
expect(screen.getByTestId('dropdown-trigger')).toBeInTheDocument()
})
it('should render dropdown content', () => {
render(<OperationDropdown {...defaultProps} />)
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
expect(screen.getByTestId('dropdown-content')).toBeInTheDocument()
})
it('should render info option for github source', () => {
@@ -118,14 +117,10 @@ describe('OperationDropdown', () => {
})
describe('User Interactions', () => {
it('should toggle dropdown when trigger is clicked', () => {
it('should render dropdown menu root', () => {
render(<OperationDropdown {...defaultProps} />)
const trigger = screen.getByTestId('portal-trigger')
fireEvent.click(trigger)
// The portal-elem should reflect the open state
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
expect(screen.getByTestId('dropdown-menu')).toBeInTheDocument()
})
it('should call onInfo when info option is clicked', () => {
@@ -174,7 +169,7 @@ describe('OperationDropdown', () => {
const { unmount } = render(
<OperationDropdown {...defaultProps} source={source} />,
)
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
expect(screen.getByTestId('dropdown-menu')).toBeInTheDocument()
expect(screen.getByText('plugin.detailPanel.operation.remove')).toBeInTheDocument()
unmount()
})
@@ -199,9 +194,7 @@ describe('OperationDropdown', () => {
describe('Memoization', () => {
it('should be wrapped with React.memo', () => {
// Verify the component is exported as a memo component
expect(OperationDropdown).toBeDefined()
// React.memo wraps the component, so it should have $$typeof
expect((OperationDropdown as { $$typeof?: symbol }).$$typeof).toBeDefined()
})
})

View File

@@ -9,24 +9,6 @@ vi.mock('@/context/i18n', () => ({
useGetLanguage: () => 'en_US',
}))
vi.mock('@/app/components/base/confirm', () => ({
default: ({ isShow, title, onCancel, onConfirm, isLoading }: {
isShow: boolean
title: string
onCancel: () => void
onConfirm: () => void
isLoading: boolean
}) => isShow
? (
<div data-testid="delete-confirm">
<div data-testid="delete-title">{title}</div>
<button data-testid="confirm-cancel" onClick={onCancel}>Cancel</button>
<button data-testid="confirm-ok" onClick={onConfirm} disabled={isLoading}>Confirm</button>
</div>
)
: null,
}))
vi.mock('@/app/components/plugins/plugin-page/plugin-info', () => ({
default: ({ repository, release, packageName, onHide }: {
repository: string
@@ -230,7 +212,7 @@ describe('HeaderModals', () => {
/>,
)
expect(screen.queryByTestId('delete-confirm')).not.toBeInTheDocument()
expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument()
})
it('should render delete confirm when isShowDeleteConfirm is true', () => {
@@ -247,7 +229,7 @@ describe('HeaderModals', () => {
/>,
)
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
})
it('should show correct delete title', () => {
@@ -264,7 +246,7 @@ describe('HeaderModals', () => {
/>,
)
expect(screen.getByTestId('delete-title')).toHaveTextContent('plugin.action.delete')
expect(screen.getByRole('alertdialog')).toHaveTextContent('plugin.action.delete')
})
it('should call hideDeleteConfirm when cancel is clicked', () => {
@@ -281,7 +263,7 @@ describe('HeaderModals', () => {
/>,
)
fireEvent.click(screen.getByTestId('confirm-cancel'))
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
expect(modalStates.hideDeleteConfirm).toHaveBeenCalled()
})
@@ -300,7 +282,7 @@ describe('HeaderModals', () => {
/>,
)
fireEvent.click(screen.getByTestId('confirm-ok'))
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
expect(mockOnDelete).toHaveBeenCalled()
})
@@ -319,7 +301,7 @@ describe('HeaderModals', () => {
/>,
)
expect(screen.getByTestId('confirm-ok')).toBeDisabled()
expect(screen.getByRole('button', { name: /common\.operation\.confirm/ })).toBeDisabled()
})
})
@@ -485,7 +467,7 @@ describe('HeaderModals', () => {
)
expect(screen.getByTestId('plugin-info')).toBeInTheDocument()
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
expect(screen.getByTestId('update-modal')).toBeInTheDocument()
})
})

View File

@@ -4,7 +4,15 @@ import type { FC } from 'react'
import type { PluginDetail } from '../../../types'
import type { ModalStates, VersionTarget } from '../hooks'
import { useTranslation } from 'react-i18next'
import Confirm from '@/app/components/base/confirm'
import {
AlertDialog,
AlertDialogActions,
AlertDialogCancelButton,
AlertDialogConfirmButton,
AlertDialogContent,
AlertDialogDescription,
AlertDialogTitle,
} from '@/app/components/base/ui/alert-dialog'
import PluginInfo from '@/app/components/plugins/plugin-page/plugin-info'
import UpdateFromMarketplace from '@/app/components/plugins/update-plugin/from-market-place'
import { useGetLanguage } from '@/context/i18n'
@@ -50,7 +58,6 @@ const HeaderModals: FC<HeaderModalsProps> = ({
return (
<>
{/* Plugin Info Modal */}
{isShowPluginInfo && (
<PluginInfo
repository={isFromGitHub ? meta?.repo : ''}
@@ -60,27 +67,35 @@ const HeaderModals: FC<HeaderModalsProps> = ({
/>
)}
{/* Delete Confirm Modal */}
{isShowDeleteConfirm && (
<Confirm
isShow
title={t(`${i18nPrefix}.delete`, { ns: 'plugin' })}
content={(
<div>
<AlertDialog
open={isShowDeleteConfirm}
onOpenChange={(open) => {
if (!open)
hideDeleteConfirm()
}}
>
<AlertDialogContent backdropProps={{ forceRender: true }}>
<div className="flex flex-col gap-2 px-6 pb-4 pt-6">
<AlertDialogTitle className="text-text-primary title-2xl-semi-bold">
{t(`${i18nPrefix}.delete`, { ns: 'plugin' })}
</AlertDialogTitle>
<AlertDialogDescription className="w-full whitespace-pre-wrap break-words text-text-tertiary system-md-regular">
{t(`${i18nPrefix}.deleteContentLeft`, { ns: 'plugin' })}
<span className="system-md-semibold">{label[locale]}</span>
<span className="text-text-secondary system-md-semibold">{label[locale]}</span>
{t(`${i18nPrefix}.deleteContentRight`, { ns: 'plugin' })}
<br />
</div>
)}
onCancel={hideDeleteConfirm}
onConfirm={onDelete}
isLoading={deleting}
isDisabled={deleting}
/>
)}
</AlertDialogDescription>
</div>
<AlertDialogActions>
<AlertDialogCancelButton disabled={deleting}>
{t('operation.cancel', { ns: 'common' })}
</AlertDialogCancelButton>
<AlertDialogConfirmButton loading={deleting} disabled={deleting} onClick={onDelete}>
{t('operation.confirm', { ns: 'common' })}
</AlertDialogConfirmButton>
</AlertDialogActions>
</AlertDialogContent>
</AlertDialog>
{/* Update from Marketplace Modal */}
{isShowUpdateModal && (
<UpdateFromMarketplace
pluginId={detail.plugin_id}

View File

@@ -15,6 +15,7 @@ type VersionPickerMock = {
const {
mockSetShowUpdatePluginModal,
mockRefreshModelProviders,
mockInvalidateCheckInstalled,
mockInvalidateAllToolProviders,
mockUninstallPlugin,
mockFetchReleases,
@@ -23,6 +24,7 @@ const {
return {
mockSetShowUpdatePluginModal: vi.fn(),
mockRefreshModelProviders: vi.fn(),
mockInvalidateCheckInstalled: vi.fn(),
mockInvalidateAllToolProviders: vi.fn(),
mockUninstallPlugin: vi.fn(() => Promise.resolve({ success: true })),
mockFetchReleases: vi.fn(() => Promise.resolve([{ tag_name: 'v2.0.0' }])),
@@ -46,6 +48,10 @@ vi.mock('@/service/plugins', () => ({
uninstallPlugin: mockUninstallPlugin,
}))
vi.mock('@/service/use-plugins', () => ({
useInvalidateCheckInstalled: () => mockInvalidateCheckInstalled,
}))
vi.mock('@/service/use-tools', () => ({
useInvalidateAllToolProviders: () => mockInvalidateAllToolProviders,
}))
@@ -178,6 +184,7 @@ describe('usePluginOperations', () => {
result.current.handleUpdatedFromMarketplace()
})
expect(mockInvalidateCheckInstalled).toHaveBeenCalled()
expect(mockOnUpdate).toHaveBeenCalled()
expect(modalStates.hideUpdateModal).toHaveBeenCalled()
})
@@ -251,6 +258,32 @@ describe('usePluginOperations', () => {
expect(mockSetShowUpdatePluginModal).toHaveBeenCalled()
})
it('should invalidate checkInstalled when GitHub update save callback fires', async () => {
const detail = createPluginDetail({
source: PluginSource.github,
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
})
const { result } = renderHook(() =>
usePluginOperations({
detail,
modalStates,
versionPicker,
isFromMarketplace: false,
onUpdate: mockOnUpdate,
}),
)
await act(async () => {
await result.current.handleUpdate()
})
const firstCall = mockSetShowUpdatePluginModal.mock.calls.at(0)?.[0]
firstCall?.onSaveCallback()
expect(mockInvalidateCheckInstalled).toHaveBeenCalled()
expect(mockOnUpdate).toHaveBeenCalled()
})
it('should not show modal when no releases found', async () => {
mockFetchReleases.mockResolvedValueOnce([])
const detail = createPluginDetail({
@@ -388,6 +421,7 @@ describe('usePluginOperations', () => {
await result.current.handleDelete()
})
expect(mockInvalidateCheckInstalled).toHaveBeenCalled()
expect(mockOnUpdate).toHaveBeenCalledWith(true)
})

View File

@@ -3,11 +3,13 @@
import type { PluginDetail } from '../../../types'
import type { ModalStates, VersionTarget } from './use-detail-header-state'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { trackEvent } from '@/app/components/base/amplitude'
import Toast from '@/app/components/base/toast'
import { useModalContext } from '@/context/modal-context'
import { useProviderContext } from '@/context/provider-context'
import { uninstallPlugin } from '@/service/plugins'
import { useInvalidateCheckInstalled } from '@/service/use-plugins'
import { useInvalidateAllToolProviders } from '@/service/use-tools'
import { useGitHubReleases } from '../../../install-plugin/hooks'
import { PluginCategoryEnum, PluginSource } from '../../../types'
@@ -36,13 +38,19 @@ export const usePluginOperations = ({
isFromMarketplace,
onUpdate,
}: UsePluginOperationsParams): UsePluginOperationsReturn => {
const { t } = useTranslation()
const { checkForUpdates, fetchReleases } = useGitHubReleases()
const { setShowUpdatePluginModal } = useModalContext()
const { refreshModelProviders } = useProviderContext()
const invalidateCheckInstalled = useInvalidateCheckInstalled()
const invalidateAllToolProviders = useInvalidateAllToolProviders()
const { id, meta, plugin_id } = detail
const { author, category, name } = detail.declaration || detail
const handlePluginUpdated = useCallback((isDelete?: boolean) => {
invalidateCheckInstalled()
onUpdate?.(isDelete)
}, [invalidateCheckInstalled, onUpdate])
const handleUpdate = useCallback(async (isDowngrade?: boolean) => {
if (isFromMarketplace) {
@@ -71,7 +79,7 @@ export const usePluginOperations = ({
if (needUpdate) {
setShowUpdatePluginModal({
onSaveCallback: () => {
onUpdate?.()
handlePluginUpdated()
},
payload: {
type: PluginSource.github,
@@ -97,15 +105,15 @@ export const usePluginOperations = ({
checkForUpdates,
setShowUpdatePluginModal,
detail,
onUpdate,
handlePluginUpdated,
modalStates,
versionPicker,
])
const handleUpdatedFromMarketplace = useCallback(() => {
onUpdate?.()
handlePluginUpdated()
modalStates.hideUpdateModal()
}, [onUpdate, modalStates])
}, [handlePluginUpdated, modalStates])
const handleDelete = useCallback(async () => {
modalStates.showDeleting()
@@ -114,7 +122,11 @@ export const usePluginOperations = ({
if (res.success) {
modalStates.hideDeleteConfirm()
onUpdate?.(true)
Toast.notify({
type: 'success',
message: t('action.deleteSuccess', { ns: 'plugin' }),
})
handlePluginUpdated(true)
if (PluginCategoryEnum.model.includes(category))
refreshModelProviders()
@@ -130,7 +142,7 @@ export const usePluginOperations = ({
plugin_id,
name,
modalStates,
onUpdate,
handlePluginUpdated,
refreshModelProviders,
invalidateAllToolProviders,
])

View File

@@ -1,16 +1,15 @@
'use client'
import type { FC } from 'react'
import { RiArrowRightUpLine, RiMoreFill } from '@remixicon/react'
import type { Placement } from '@/app/components/base/ui/placement'
import * as React from 'react'
import { useCallback, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import ActionButton from '@/app/components/base/action-button'
// import Button from '@/app/components/base/button'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from '@/app/components/base/ui/dropdown-menu'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { cn } from '@/utils/classnames'
import { PluginSource } from '../types'
@@ -21,6 +20,10 @@ type Props = {
onCheckVersion: () => void
onRemove: () => void
detailUrl: string
placement?: Placement
sideOffset?: number
alignOffset?: number
popupClassName?: string
}
const OperationDropdown: FC<Props> = ({
@@ -29,83 +32,52 @@ const OperationDropdown: FC<Props> = ({
onInfo,
onCheckVersion,
onRemove,
placement = 'bottom-end',
sideOffset = 4,
alignOffset = 0,
popupClassName,
}) => {
const { t } = useTranslation()
const [open, doSetOpen] = useState(false)
const openRef = useRef(open)
const setOpen = useCallback((v: boolean) => {
doSetOpen(v)
openRef.current = v
}, [doSetOpen])
const handleTrigger = useCallback(() => {
setOpen(!openRef.current)
}, [setOpen])
const [open, setOpen] = React.useState(false)
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
return (
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
placement="bottom-end"
offset={{
mainAxis: -12,
crossAxis: 36,
}}
>
<PortalToFollowElemTrigger onClick={handleTrigger}>
<div>
<ActionButton className={cn(open && 'bg-state-base-hover')}>
<RiMoreFill className="h-4 w-4" />
</ActionButton>
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className="z-50">
<div className="w-[160px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg">
{source === PluginSource.github && (
<div
onClick={() => {
onInfo()
handleTrigger()
}}
className="system-md-regular cursor-pointer rounded-lg px-3 py-1.5 text-text-secondary hover:bg-state-base-hover"
>
{t('detailPanel.operation.info', { ns: 'plugin' })}
</div>
)}
{source === PluginSource.github && (
<div
onClick={() => {
onCheckVersion()
handleTrigger()
}}
className="system-md-regular cursor-pointer rounded-lg px-3 py-1.5 text-text-secondary hover:bg-state-base-hover"
>
{t('detailPanel.operation.checkUpdate', { ns: 'plugin' })}
</div>
)}
{(source === PluginSource.marketplace || source === PluginSource.github) && enable_marketplace && (
<a href={detailUrl} target="_blank" className="system-md-regular flex cursor-pointer items-center rounded-lg px-3 py-1.5 text-text-secondary hover:bg-state-base-hover">
<span className="grow">{t('detailPanel.operation.viewDetail', { ns: 'plugin' })}</span>
<RiArrowRightUpLine className="h-3.5 w-3.5 shrink-0 text-text-tertiary" />
</a>
)}
{(source === PluginSource.marketplace || source === PluginSource.github) && enable_marketplace && (
<div className="my-1 h-px bg-divider-subtle"></div>
)}
<div
onClick={() => {
onRemove()
handleTrigger()
}}
className="system-md-regular cursor-pointer rounded-lg px-3 py-1.5 text-text-secondary hover:bg-state-destructive-hover hover:text-text-destructive"
>
{t('detailPanel.operation.remove', { ns: 'plugin' })}
</div>
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
<DropdownMenu open={open} onOpenChange={setOpen}>
<DropdownMenuTrigger
className={cn('action-btn action-btn-m', open && 'bg-state-base-hover')}
>
<span className="i-ri-more-fill h-4 w-4" />
</DropdownMenuTrigger>
<DropdownMenuContent
placement={placement}
sideOffset={sideOffset}
alignOffset={alignOffset}
popupClassName={cn('w-[160px]', popupClassName)}
>
{source === PluginSource.github && (
<DropdownMenuItem onClick={onInfo}>
{t('detailPanel.operation.info', { ns: 'plugin' })}
</DropdownMenuItem>
)}
{source === PluginSource.github && (
<DropdownMenuItem onClick={onCheckVersion}>
{t('detailPanel.operation.checkUpdate', { ns: 'plugin' })}
</DropdownMenuItem>
)}
{(source === PluginSource.marketplace || source === PluginSource.github) && enable_marketplace && (
<DropdownMenuItem render={<a href={detailUrl} target="_blank" rel="noopener noreferrer" />}>
<span className="grow">{t('detailPanel.operation.viewDetail', { ns: 'plugin' })}</span>
<span className="i-ri-arrow-right-up-line h-3.5 w-3.5 shrink-0 text-text-tertiary" />
</DropdownMenuItem>
)}
{(source === PluginSource.marketplace || source === PluginSource.github) && enable_marketplace && (
<DropdownMenuSeparator />
)}
<DropdownMenuItem destructive onClick={onRemove}>
{t('detailPanel.operation.remove', { ns: 'plugin' })}
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
)
}
export default React.memo(OperationDropdown)

View File

@@ -104,36 +104,6 @@ vi.mock('../../install-plugin/install-from-github', () => ({
),
}))
// Mock Portal components for PluginVersionPicker
let mockPortalOpen = false
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
PortalToFollowElem: ({ children, open, onOpenChange: _onOpenChange }: {
children: React.ReactNode
open: boolean
onOpenChange: (open: boolean) => void
}) => {
mockPortalOpen = open
return <div data-testid="portal-elem" data-open={open}>{children}</div>
},
PortalToFollowElemTrigger: ({ children, onClick, className }: {
children: React.ReactNode
onClick: () => void
className?: string
}) => (
<div data-testid="portal-trigger" onClick={onClick} className={className}>
{children}
</div>
),
PortalToFollowElemContent: ({ children, className }: {
children: React.ReactNode
className?: string
}) => {
if (!mockPortalOpen)
return null
return <div data-testid="portal-content" className={className}>{children}</div>
},
}))
// Mock semver
vi.mock('semver', () => ({
lt: (v1: string, v2: string) => {
@@ -247,7 +217,6 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
describe('update-plugin', () => {
beforeEach(() => {
vi.clearAllMocks()
mockPortalOpen = false
mockCheck.mockResolvedValue({ status: TaskStatus.success })
})
@@ -946,7 +915,7 @@ describe('update-plugin', () => {
onShowChange: vi.fn(),
pluginID: 'test-plugin-id',
currentVersion: '1.0.0',
trigger: <button>Select Version</button>,
trigger: <span>Select Version</span>,
onSelect: vi.fn(),
}
@@ -964,7 +933,7 @@ describe('update-plugin', () => {
render(<PluginVersionPicker {...defaultProps} isShow={false} />)
// Assert
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
expect(screen.queryByText('plugin.detailPanel.switchVersion')).not.toBeInTheDocument()
})
it('should render version list when isShow is true', () => {
@@ -972,7 +941,6 @@ describe('update-plugin', () => {
render(<PluginVersionPicker {...defaultProps} isShow={true} />)
// Assert
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
expect(screen.getByText('plugin.detailPanel.switchVersion')).toBeInTheDocument()
})
@@ -1002,7 +970,7 @@ describe('update-plugin', () => {
// Act
render(<PluginVersionPicker {...defaultProps} onShowChange={onShowChange} />)
fireEvent.click(screen.getByTestId('portal-trigger'))
fireEvent.click(screen.getByText('Select Version'))
// Assert
expect(onShowChange).toHaveBeenCalledWith(true)
@@ -1014,7 +982,7 @@ describe('update-plugin', () => {
// Act
render(<PluginVersionPicker {...defaultProps} disabled={true} onShowChange={onShowChange} />)
fireEvent.click(screen.getByTestId('portal-trigger'))
fireEvent.click(screen.getByText('Select Version'))
// Assert
expect(onShowChange).not.toHaveBeenCalled()
@@ -1116,7 +1084,7 @@ describe('update-plugin', () => {
)
// Assert
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
expect(screen.getByText('plugin.detailPanel.switchVersion')).toBeInTheDocument()
})
it('should support custom offset', () => {
@@ -1125,12 +1093,13 @@ describe('update-plugin', () => {
<PluginVersionPicker
{...defaultProps}
isShow={true}
offset={{ mainAxis: 10, crossAxis: 20 }}
sideOffset={10}
alignOffset={20}
/>,
)
// Assert
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
expect(screen.getByText('plugin.detailPanel.switchVersion')).toBeInTheDocument()
})
})
@@ -1190,7 +1159,7 @@ describe('update-plugin', () => {
onShowChange: vi.fn(),
pluginID: 'test',
currentVersion: '1.0.0',
trigger: <button>Select</button>,
trigger: <span>Select</span>,
onSelect: vi.fn(),
}}
/>,

View File

@@ -18,8 +18,8 @@ const DowngradeWarningModal = ({
return (
<>
<div className="flex flex-col items-start gap-2 self-stretch">
<div className="title-2xl-semi-bold text-text-primary">{t(`${i18nPrefix}.title`, { ns: 'plugin' })}</div>
<div className="system-md-regular text-text-secondary">
<div className="text-text-primary title-2xl-semi-bold">{t(`${i18nPrefix}.title`, { ns: 'plugin' })}</div>
<div className="text-text-secondary system-md-regular">
{t(`${i18nPrefix}.description`, { ns: 'plugin' })}
</div>
</div>

View File

@@ -6,7 +6,12 @@ import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Badge, { BadgeState } from '@/app/components/base/badge/index'
import Button from '@/app/components/base/button'
import Modal from '@/app/components/base/modal'
import {
Dialog,
DialogCloseButton,
DialogContent,
DialogTitle,
} from '@/app/components/base/ui/dialog'
import Card from '@/app/components/plugins/card'
import checkTaskStatus from '@/app/components/plugins/install-plugin/base/check-task-status'
import { pluginManifestToCardPluginProps } from '@/app/components/plugins/install-plugin/utils'
@@ -125,63 +130,65 @@ const UpdatePluginModal: FC<Props> = ({
const doShowDowngradeWarningModal = isShowDowngradeWarningModal && uploadStep === UploadStep.notStarted
return (
<Modal
isShow={true}
onClose={onCancel}
className={cn('min-w-[560px]', doShowDowngradeWarningModal && 'min-w-[640px]')}
closable
title={!doShowDowngradeWarningModal && t(`${i18nPrefix}.${uploadStep === UploadStep.installed ? 'successfulTitle' : 'title'}`, { ns: 'plugin' })}
>
{doShowDowngradeWarningModal && (
<DowngradeWarningModal
onCancel={onCancel}
onJustDowngrade={handleConfirm}
onExcludeAndDowngrade={handleExcludeAndDownload}
/>
)}
{!doShowDowngradeWarningModal && (
<>
<div className="system-md-regular mb-2 mt-3 text-text-secondary">
{t(`${i18nPrefix}.description`, { ns: 'plugin' })}
</div>
<div className="flex flex-wrap content-start items-start gap-1 self-stretch rounded-2xl bg-background-section-burn p-2">
<Card
installed={uploadStep === UploadStep.installed}
payload={pluginManifestToCardPluginProps({
...originalPackageInfo.payload,
icon: icon!,
})}
className="w-full"
titleLeft={(
<>
<Badge className="mx-1" size="s" state={BadgeState.Warning}>
{`${originalPackageInfo.payload.version} -> ${targetPackageInfo.version}`}
</Badge>
</>
<Dialog open onOpenChange={() => onCancel()}>
<DialogContent
backdropProps={{ forceRender: true }}
className={cn('min-w-[560px]', doShowDowngradeWarningModal && 'min-w-[640px]')}
>
<DialogCloseButton />
{doShowDowngradeWarningModal && (
<DowngradeWarningModal
onCancel={onCancel}
onJustDowngrade={handleConfirm}
onExcludeAndDowngrade={handleExcludeAndDownload}
/>
)}
{!doShowDowngradeWarningModal && (
<>
<DialogTitle className="text-text-primary title-2xl-semi-bold">
{t(`${i18nPrefix}.${uploadStep === UploadStep.installed ? 'successfulTitle' : 'title'}`, { ns: 'plugin' })}
</DialogTitle>
<div className="mb-2 mt-3 text-text-secondary system-md-regular">
{t(`${i18nPrefix}.description`, { ns: 'plugin' })}
</div>
<div className="flex flex-wrap content-start items-start gap-1 self-stretch rounded-2xl bg-background-section-burn p-2">
<Card
installed={uploadStep === UploadStep.installed}
payload={pluginManifestToCardPluginProps({
...originalPackageInfo.payload,
icon: icon!,
})}
className="w-full"
titleLeft={(
<>
<Badge className="mx-1" size="s" state={BadgeState.Warning}>
{`${originalPackageInfo.payload.version} -> ${targetPackageInfo.version}`}
</Badge>
</>
)}
/>
</div>
<div className="flex items-center justify-end gap-2 self-stretch pt-5">
{uploadStep === UploadStep.notStarted && (
<Button
onClick={handleCancel}
>
{t('operation.cancel', { ns: 'common' })}
</Button>
)}
/>
</div>
<div className="flex items-center justify-end gap-2 self-stretch pt-5">
{uploadStep === UploadStep.notStarted && (
<Button
onClick={handleCancel}
variant="primary"
loading={uploadStep === UploadStep.upgrading}
onClick={handleConfirm}
disabled={uploadStep === UploadStep.upgrading}
>
{t('operation.cancel', { ns: 'common' })}
{configBtnText}
</Button>
)}
<Button
variant="primary"
loading={uploadStep === UploadStep.upgrading}
onClick={handleConfirm}
disabled={uploadStep === UploadStep.upgrading}
>
{configBtnText}
</Button>
</div>
</>
)}
</Modal>
</div>
</>
)}
</DialogContent>
</Dialog>
)
}
export default React.memo(UpdatePluginModal)

View File

@@ -1,19 +1,16 @@
'use client'
import type {
OffsetOptions,
Placement,
} from '@floating-ui/react'
import type { FC } from 'react'
import type { Placement } from '@/app/components/base/ui/placement'
import * as React from 'react'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { lt } from 'semver'
import Badge from '@/app/components/base/badge'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
Popover,
PopoverContent,
PopoverTrigger,
} from '@/app/components/base/ui/popover'
import useTimestamp from '@/hooks/use-timestamp'
import { useVersionListOfPlugin } from '@/service/use-plugins'
import { cn } from '@/utils/classnames'
@@ -26,7 +23,8 @@ type Props = {
currentVersion: string
trigger: React.ReactNode
placement?: Placement
offset?: OffsetOptions
sideOffset?: number
alignOffset?: number
onSelect: ({
version,
unique_identifier,
@@ -46,22 +44,14 @@ const PluginVersionPicker: FC<Props> = ({
currentVersion,
trigger,
placement = 'bottom-start',
offset = {
mainAxis: 4,
crossAxis: -16,
},
sideOffset = 4,
alignOffset = -16,
onSelect,
}) => {
const { t } = useTranslation()
const format = t('dateTimeFormat', { ns: 'appLog' }).split(' ')[0]
const { formatDate } = useTimestamp()
const handleTriggerClick = () => {
if (disabled)
return
onShowChange(true)
}
const { data: res } = useVersionListOfPlugin(pluginID)
const handleSelect = useCallback(({ version, unique_identifier, isDowngrade }: {
@@ -76,49 +66,52 @@ const PluginVersionPicker: FC<Props> = ({
}, [currentVersion, onSelect, onShowChange])
return (
<PortalToFollowElem
placement={placement}
offset={offset}
<Popover
open={isShow}
onOpenChange={onShowChange}
onOpenChange={(open) => {
if (!disabled)
onShowChange(open)
}}
>
<PortalToFollowElemTrigger
<PopoverTrigger
className={cn('inline-flex cursor-pointer items-center', disabled && 'cursor-default')}
onClick={handleTriggerClick}
>
{trigger}
</PortalToFollowElemTrigger>
</PopoverTrigger>
<PortalToFollowElemContent className="z-[1000]">
<div className="relative w-[209px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg backdrop-blur-sm">
<div className="system-xs-medium-uppercase px-3 pb-0.5 pt-1 text-text-tertiary">
{t('detailPanel.switchVersion', { ns: 'plugin' })}
</div>
<div className="relative">
{res?.data.versions.map(version => (
<div
key={version.unique_identifier}
className={cn(
'flex h-7 cursor-pointer items-center gap-1 rounded-lg px-3 py-1 hover:bg-state-base-hover',
currentVersion === version.version && 'cursor-default opacity-30 hover:bg-transparent',
)}
onClick={() => handleSelect({
version: version.version,
unique_identifier: version.unique_identifier,
isDowngrade: lt(version.version, currentVersion),
})}
>
<div className="flex grow items-center">
<div className="system-sm-medium text-text-secondary">{version.version}</div>
{currentVersion === version.version && <Badge className="ml-1" text="CURRENT" />}
</div>
<div className="system-xs-regular shrink-0 text-text-tertiary">{formatDate(version.created_at, format)}</div>
</div>
))}
</div>
<PopoverContent
placement={placement}
sideOffset={sideOffset}
alignOffset={alignOffset}
popupClassName="relative w-[209px] bg-components-panel-bg-blur p-1 backdrop-blur-sm"
>
<div className="px-3 pb-0.5 pt-1 text-text-tertiary system-xs-medium-uppercase">
{t('detailPanel.switchVersion', { ns: 'plugin' })}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
<div className="relative max-h-[224px] overflow-y-auto">
{res?.data.versions.map(version => (
<div
key={version.unique_identifier}
className={cn(
'flex h-7 cursor-pointer items-center gap-1 rounded-lg px-3 py-1 hover:bg-state-base-hover',
currentVersion === version.version && 'cursor-default opacity-30 hover:bg-transparent',
)}
onClick={() => handleSelect({
version: version.version,
unique_identifier: version.unique_identifier,
isDowngrade: lt(version.version, currentVersion),
})}
>
<div className="flex grow items-center">
<div className="text-text-secondary system-sm-medium">{version.version}</div>
{currentVersion === version.version && <Badge className="ml-1" text="CURRENT" />}
</div>
<div className="shrink-0 text-text-tertiary system-xs-regular">{formatDate(version.created_at, format)}</div>
</div>
))}
</div>
</PopoverContent>
</Popover>
)
}

View File

@@ -3,13 +3,18 @@
import type { FC, ReactNode } from 'react'
import type { ICurrentWorkspace, LangGeniusVersionResponse, UserProfileResponse } from '@/models/common'
import { useQueryClient } from '@tanstack/react-query'
import { noop } from 'es-toolkit/function'
import { useCallback, useEffect, useMemo } from 'react'
import { createContext, useContext, useContextSelector } from 'use-context-selector'
import { setUserId, setUserProperties } from '@/app/components/base/amplitude'
import { setZendeskConversationFields } from '@/app/components/base/zendesk/utils'
import MaintenanceNotice from '@/app/components/header/maintenance-notice'
import { ZENDESK_FIELD_IDS } from '@/config'
import {
AppContext,
initialLangGeniusVersionInfo,
initialWorkspaceInfo,
userProfilePlaceholder,
useSelector,
} from '@/context/app-context'
import { env } from '@/env'
import {
useCurrentWorkspace,
@@ -18,72 +23,6 @@ import {
} from '@/service/use-common'
import { useGlobalPublicStore } from './global-public-context'
export type AppContextValue = {
userProfile: UserProfileResponse
mutateUserProfile: VoidFunction
currentWorkspace: ICurrentWorkspace
isCurrentWorkspaceManager: boolean
isCurrentWorkspaceOwner: boolean
isCurrentWorkspaceEditor: boolean
isCurrentWorkspaceDatasetOperator: boolean
mutateCurrentWorkspace: VoidFunction
langGeniusVersionInfo: LangGeniusVersionResponse
useSelector: typeof useSelector
isLoadingCurrentWorkspace: boolean
isValidatingCurrentWorkspace: boolean
}
const userProfilePlaceholder = {
id: '',
name: '',
email: '',
avatar: '',
avatar_url: '',
is_password_set: false,
}
const initialLangGeniusVersionInfo = {
current_env: '',
current_version: '',
latest_version: '',
release_date: '',
release_notes: '',
version: '',
can_auto_update: false,
}
const initialWorkspaceInfo: ICurrentWorkspace = {
id: '',
name: '',
plan: '',
status: '',
created_at: 0,
role: 'normal',
providers: [],
trial_credits: 200,
trial_credits_used: 0,
next_credit_reset_date: 0,
}
const AppContext = createContext<AppContextValue>({
userProfile: userProfilePlaceholder,
currentWorkspace: initialWorkspaceInfo,
isCurrentWorkspaceManager: false,
isCurrentWorkspaceOwner: false,
isCurrentWorkspaceEditor: false,
isCurrentWorkspaceDatasetOperator: false,
mutateUserProfile: noop,
mutateCurrentWorkspace: noop,
langGeniusVersionInfo: initialLangGeniusVersionInfo,
useSelector,
isLoadingCurrentWorkspace: false,
isValidatingCurrentWorkspace: false,
})
export function useSelector<T>(selector: (value: AppContextValue) => T): T {
return useContextSelector(AppContext, selector)
}
export type AppContextProviderProps = {
children: ReactNode
}
@@ -170,7 +109,7 @@ export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) =>
// Report user and workspace info to Amplitude when loaded
if (userProfile?.id) {
setUserId(userProfile.email)
const properties: Record<string, any> = {
const properties: Record<string, string | number | boolean> = {
email: userProfile.email,
name: userProfile.name,
has_password: userProfile.is_password_set,
@@ -213,7 +152,3 @@ export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) =>
</AppContext.Provider>
)
}
export const useAppContext = () => useContext(AppContext)
export default AppContext

View File

@@ -0,0 +1,73 @@
'use client'
import type { ICurrentWorkspace, LangGeniusVersionResponse, UserProfileResponse } from '@/models/common'
import { noop } from 'es-toolkit/function'
import { createContext, useContext, useContextSelector } from 'use-context-selector'
export type AppContextValue = {
userProfile: UserProfileResponse
mutateUserProfile: VoidFunction
currentWorkspace: ICurrentWorkspace
isCurrentWorkspaceManager: boolean
isCurrentWorkspaceOwner: boolean
isCurrentWorkspaceEditor: boolean
isCurrentWorkspaceDatasetOperator: boolean
mutateCurrentWorkspace: VoidFunction
langGeniusVersionInfo: LangGeniusVersionResponse
useSelector: typeof useSelector
isLoadingCurrentWorkspace: boolean
isValidatingCurrentWorkspace: boolean
}
export const userProfilePlaceholder = {
id: '',
name: '',
email: '',
avatar: '',
avatar_url: '',
is_password_set: false,
}
export const initialLangGeniusVersionInfo = {
current_env: '',
current_version: '',
latest_version: '',
release_date: '',
release_notes: '',
version: '',
can_auto_update: false,
}
export const initialWorkspaceInfo: ICurrentWorkspace = {
id: '',
name: '',
plan: '',
status: '',
created_at: 0,
role: 'normal',
providers: [],
trial_credits: 200,
trial_credits_used: 0,
next_credit_reset_date: 0,
}
export const AppContext = createContext<AppContextValue>({
userProfile: userProfilePlaceholder,
currentWorkspace: initialWorkspaceInfo,
isCurrentWorkspaceManager: false,
isCurrentWorkspaceOwner: false,
isCurrentWorkspaceEditor: false,
isCurrentWorkspaceDatasetOperator: false,
mutateUserProfile: noop,
mutateCurrentWorkspace: noop,
langGeniusVersionInfo: initialLangGeniusVersionInfo,
useSelector,
isLoadingCurrentWorkspace: false,
isValidatingCurrentWorkspace: false,
})
export function useSelector<T>(selector: (value: AppContextValue) => T): T {
return useContextSelector(AppContext, selector)
}
export const useAppContext = () => useContext(AppContext)

View File

@@ -343,8 +343,8 @@ export const ModalContextProvider = ({
accountSettingTab && (
<AccountSetting
activeTab={accountSettingTab}
onCancel={handleCancelAccountSettingModal}
onTabChange={handleAccountSettingTabChange}
onCancelAction={handleCancelAccountSettingModal}
onTabChangeAction={handleAccountSettingTabChange}
/>
)
}

View File

@@ -0,0 +1,33 @@
import type { ModelItem, PreferredProviderTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { CommonResponse } from '@/models/common'
import { type } from '@orpc/contract'
import { base } from '../base'
export const modelProvidersModelsContract = base
.route({
path: '/workspaces/current/model-providers/{provider}/models',
method: 'GET',
})
.input(type<{
params: {
provider: string
}
}>())
.output(type<{
data: ModelItem[]
}>())
export const changePreferredProviderTypeContract = base
.route({
path: '/workspaces/current/model-providers/{provider}/preferred-provider-type',
method: 'POST',
})
.input(type<{
params: {
provider: string
}
body: {
preferred_provider_type: PreferredProviderTypeEnum
}
}>())
.output(type<CommonResponse>())

View File

@@ -12,6 +12,7 @@ import {
exploreInstalledAppsContract,
exploreInstalledAppUninstallContract,
} from './console/explore'
import { changePreferredProviderTypeContract, modelProvidersModelsContract } from './console/model-providers'
import { systemFeaturesContract } from './console/system'
import {
triggerOAuthConfigContract,
@@ -63,6 +64,10 @@ export const consoleRouterContract = {
parameters: trialAppParametersContract,
workflows: trialAppWorkflowsContract,
},
modelProviders: {
models: modelProvidersModelsContract,
changePreferredProviderType: changePreferredProviderTypeContract,
},
billing: {
invoices: invoicesContract,
bindPartnerStack: bindPartnerStackContract,

View File

@@ -10,6 +10,9 @@ This document tracks the migration away from legacy overlay APIs.
- `@/app/components/base/modal`
- `@/app/components/base/confirm`
- `@/app/components/base/select` (including `custom` / `pure`)
- `@/app/components/base/popover`
- `@/app/components/base/dropdown`
- `@/app/components/base/dialog`
- Replacement primitives:
- `@/app/components/base/ui/tooltip`
- `@/app/components/base/ui/dropdown-menu`

Some files were not shown because too many files have changed in this diff Show More