Compare commits

..

104 Commits

Author SHA1 Message Date
yyh
b3c98e417d Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-09 23:47:27 +08:00
yyh
dfe389c017 Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-09 23:42:04 +08:00
yyh
b364b06e51 refactor(model-selector): migrate overlays to Popover/Tooltip and unify trigger component
- Migrate PortalToFollowElem to base-ui Popover in model-selector,
  model-parameter-modal, and plugin-detail-panel model-selector
- Migrate legacy Tooltip to compound Tooltip in popup-item and trigger
- Unify EmptyTrigger, ModelTrigger, DeprecatedModelTrigger into a
  single declarative ModelSelectorTrigger that derives state from props
- Remove showDeprecatedWarnIcon boolean prop anti-pattern; deprecated
  state always renders warn icon as part of component's visual contract
- Remove deprecatedClassName prop; component manages disabled styling
- Replace manual triggerRef width measurement with CSS var(--anchor-width)
- Remove tooltip scroll listener (base-ui auto-tracks anchor position)
- Restore conditional placement for workflow mode in plugin-detail-panel
- Prune stale ESLint suppressions for removed deprecated imports
2026-03-09 23:34:42 +08:00
CodingOnStar
ce0197b107 fix(provider): handle undefined provider in credential status and panel state 2026-03-09 18:20:02 +08:00
yyh
164cefc65c Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-09 17:41:13 +08:00
yyh
f6d80b9fa7 fix(workflow): derive plugin install state in render
Remove useEffect-based sync of _pluginInstallLocked/_dimmed in workflow nodes to avoid render-update loops.\n\nMove plugin-missing checks to pure utilities and use them in checklist.\nOptimize node installation hooks by enabling only relevant queries and narrowing memo dependencies.
2026-03-09 17:18:09 +08:00
yyh
e845fa7e6a fix(plugin-install): support bundle marketplace dependency shape 2026-03-09 17:07:27 +08:00
yyh
bab7bd5ecc Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-09 17:03:54 +08:00
yyh
cfb02bceaf feat(workflow): open install bundle from checklist and strict marketplace parsing 2026-03-09 17:03:43 +08:00
yyh
694ca840e1 feat(web): add warning dot indicator on LLM panel field labels synced with checklist
Store checklist items in zustand WorkflowStore so both the checklist UI
and node panels share a single source of truth. The LLM panel reads from
the store to show a Figma-aligned warning dot (absolute-positioned, no
layout shift) on the MODEL field label when the node has checklist warnings.
2026-03-09 16:38:31 +08:00
yyh
2d979e2cec fix(web): silence toast for model parameter rules fetch on missing provider
Add silent option to useModelParameterRules API call so uninstalled
provider errors are swallowed instead of surfacing a raw backend toast.
2026-03-09 16:17:09 +08:00
yyh
5cee7cf8ce feat(web): add LLM model plugin check to workflow checklist
Detect uninstalled model plugins for LLM nodes in the checklist and
publish-gate. Migrate ChecklistItem.errorMessage to errorMessages[]
so a single node can surface multiple validation issues at once.

- Extract shared extractPluginId utility for checklist and prompt editor
- Build installed-plugin Set (O(1) lookup) from ProviderContext
- Remove short-circuit between checkValid and variable validation
- Sync the same check into handleCheckBeforePublish
- Adapt node-group, use-last-run, and test assertions
2026-03-09 16:16:16 +08:00
yyh
0c17823c8b fix 2026-03-09 15:38:46 +08:00
yyh
49c6696d08 fix: use css icons 2026-03-09 15:27:08 +08:00
yyh
292c98a8f3 refactor(web): redesign workflow checklist panel with grouped tree view and Popover primitive
Migrate checklist from flat card list using deprecated PortalToFollowElem to
grouped tree view using base-ui Popover. Split into checklist/ directory with
separate components: plugin group with batch install, per-node groups with
sub-items and "Go to fix" hover action, and tree-line SVG indicators.
2026-03-09 15:23:34 +08:00
CodingOnStar
0e0a6ad043 test(web): enhance unit tests for credential and popup components
- Updated tests for CredentialItem to improve delete button interaction and check icon rendering.
- Enhanced PopupItem tests by mocking credential panel state for various scenarios, ensuring accurate rendering based on credit status.
- Adjusted Popup tests to include trial credits mock for better coverage of credit management logic.
- Refactored model list item tests to include wrapper for consistent rendering context.
2026-03-09 14:20:12 +08:00
yyh
456c95adb1 refactor(web): trigger error tooltip on entire variable badge hover 2026-03-09 14:03:52 +08:00
yyh
1abbaf9fd5 feat(web): differentiate invalid variable tooltips by model plugin status
Replace the generic "Invalid variable" message in prompt editor variable
labels with two context-aware messages: one for missing nodes and another
for uninstalled model plugins. Add useLlmModelPluginInstalled hook that
checks LLM node model providers against installed providers via
useProviderContextSelector. Migrate Tooltip usage to base-ui primitives
and replace RiErrorWarningFill with Warning icon in warning color.
2026-03-09 14:02:26 +08:00
CodingOnStar
1a26e1669b refactor(web): streamline PopupItem component for credit management
- Removed unused context and variables related to workspace and custom configuration.
- Simplified credit usage logic by leveraging state management for better clarity and performance.
- Enhanced readability by restructuring the code for determining credit status and API key activity.
2026-03-09 13:10:29 +08:00
CodingOnStar
02444af2e3 feat(web): enhance Popup and CreditsFallbackAlert components for better credit management
- Integrated trial credits check in the Popup component to conditionally display the CreditsExhaustedAlert.
- Updated the CreditsFallbackAlert to show a message only when API keys are unavailable.
- Removed the fallback description from translation files as it is no longer used.
2026-03-09 12:57:41 +08:00
CodingOnStar
56038e3684 feat(web): update credits fallback alert to include new description for no API keys
- Modified the CreditsFallbackAlert component to display a different message based on the presence of API keys.
- Added a new translation key for the fallback description in both English and Chinese JSON files.
2026-03-09 12:34:41 +08:00
CodingOnStar
eb9341e7ec feat(web): integrate CreditsCoin icon in PopupItem for enhanced UI
- Replaced the existing credits coin span with the CreditsCoin component for improved visual consistency.
- Updated imports to include the new CreditsCoin icon component.
2026-03-09 12:28:13 +08:00
CodingOnStar
e40b31b9c4 refactor(web): enhance model selector functionality and improve UI consistency
- Removed unnecessary ESLint suppressions for better code quality.
- Updated the ModelParameterModal and ModelSelector components to ensure consistent class ordering.
- Added onHide prop to ModelSelector for better control over dropdown visibility.
- Introduced useChangeProviderPriority hook to manage provider priority changes more effectively.
- Integrated CreditsExhaustedAlert in the Popup component to handle API key status more gracefully.
2026-03-09 12:24:54 +08:00
yyh
b89ee4807f Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing
# Conflicts:
#	web/app/components/header/account-setting/model-provider-page/index.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/model-modal/index.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/model-selector/popup-item.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/model-selector/popup.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/index.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/quota-panel.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx
2026-03-09 12:12:27 +08:00
yyh
9907cf9e06 Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-08 22:27:42 +08:00
yyh
208a31719f Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-08 01:10:51 +08:00
yyh
3d1ef1f7f5 Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-06 21:45:37 +08:00
CodingOnStar
24b14e2c1a Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-06 19:00:17 +08:00
CodingOnStar
53f122f717 Merge branch 'feat/model-provider-refactor' into feat/model-plugins-implementing 2026-03-06 17:33:38 +08:00
CodingOnStar
fced2f9e65 refactor: enhance plugin management UI with error handling, improved rendering, and new components 2026-03-06 16:27:26 +08:00
yyh
0c08c4016d Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-06 14:57:48 +08:00
CodingOnStar
ff4e4a8d64 refactor: enhance model trigger component with internationalization support and improved tooltip handling 2026-03-06 14:50:23 +08:00
yyh
948efa129f Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-06 14:47:56 +08:00
CodingOnStar
e371bfd676 refactor: enhance model provider management with new icons, improved UI elements, and marketplace integration 2026-03-06 14:18:29 +08:00
yyh
6d612c0909 test: improve Jotai atom test quality and add model-provider atoms tests
Replace dynamic imports with static imports in marketplace atom tests.
Convert type-only and not-toThrow assertions into proper state-change
verifications. Add comprehensive test suite for model-provider-page
atoms covering all four hooks, cross-hook interaction, selectAtom
granularity, and Provider isolation.
2026-03-05 22:49:09 +08:00
yyh
56e0dc0ae6 trigger ci
Signed-off-by: yyh <yuanyouhuilyz@gmail.com>
2026-03-05 21:22:03 +08:00
yyh
975eca00c3 Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-05 20:25:53 +08:00
yyh
f049bafcc3 refactor: simplify Jotai atoms by removing redundant write-only atoms
Replace 2 write-only derived atoms with primitive atom's built-in
updater functions. The selectAtom on the read side already prevents
unnecessary re-renders, making the manual guard logic redundant.
2026-03-05 20:25:29 +08:00
CodingOnStar
dd9c526447 refactor: update model-selector popup-item to support collapsible items and improve icon color handling 2026-03-05 16:45:37 +08:00
yyh
922dc71e36 fix 2026-03-05 16:17:38 +08:00
yyh
f03ec7f671 Merge branch 'main' into feat/model-provider-refactor 2026-03-05 16:14:36 +08:00
yyh
29f275442d Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor
# Conflicts:
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx
#	web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx
2026-03-05 16:13:40 +08:00
yyh
c9532ffd43 add stories 2026-03-05 15:55:21 +08:00
yyh
840dc33b8b Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-05 15:12:32 +08:00
yyh
cae58a0649 Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-05 15:08:13 +08:00
yyh
1752edc047 refactor(web): optimize model provider re-render and remove useEffect state sync
- Replace useEffect state sync with derived state pattern in useSystemDefaultModelAndModelList
- Use useCallback instead of useMemo for function memoization in useProviderCredentialsAndLoadBalancing
- Add memo() to ProviderAddedCard and CredentialPanel to prevent unnecessary re-renders
- Switch to useProviderContextSelector for precise context subscription in ProviderAddedCard
- Stabilize activate callback ref in useActivateCredential via supportedModelTypes ref
- Add usage priority tooltip with i18n support
2026-03-05 15:07:53 +08:00
yyh
7471c32612 Revert "temp: remove IS_CLOUD_EDITION guard from supportsCredits for local testing"
This reverts commit ab87ac333a.
2026-03-05 14:33:48 +08:00
yyh
2d333bbbe5 refactor(web): extract credential activation into hook and migrate credential-item overlays
Extract credential switching logic from dropdown-content into a dedicated
useActivateCredential hook with optimistic updates and proper data flow
separation. Credential items now stay visible in the popover after clicking
(no auto-close), show cursor-pointer, and disable during activation.

Migrate credential-item from legacy Tooltip and remixicon imports to
base-ui Tooltip and CSS icon classes, pruning stale ESLint suppressions.
2026-03-05 14:22:39 +08:00
yyh
4af6788ce0 fix(web): wrap Header test in Dialog context for base-ui compatibility 2026-03-05 14:20:35 +08:00
yyh
24b072def9 fix: lint 2026-03-05 14:08:20 +08:00
yyh
909c8c3350 Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-05 13:58:51 +08:00
yyh
80e9c8bee0 refactor(web): make account setting fully controlled with action props 2026-03-05 13:39:36 +08:00
yyh
15b7b304d2 refactor(web): migrate model-modal overlays to base-ui Dialog and AlertDialog
Replace legacy PortalToFollowElem and Confirm with Dialog/AlertDialog
primitives. Remove manual ESC handler and backdrop div — now handled
natively by base-ui. Add backdropProps={{ forceRender: true }} for
correct nested overlay rendering.
2026-03-05 13:33:53 +08:00
yyh
61e2672b59 refactor(web): make provider reset event-driven and scope model invalidation
- remove provider-page lifecycle reset effect and handle reset in explicit tab/close actions
- switch account setting tab state to controlled/uncontrolled pattern without sync effect
- use provider-scoped model list queryKey with exact invalidation in credential and model toggle mutations
- update related tests and mocks for new behavior
2026-03-05 13:28:30 +08:00
yyh
5f4ed4c6f6 refactor(web): replace model provider emitter refresh with jotai state
- add atom-based provider expansion state with reset/prune helpers
- remove event-emitter dependency from model provider refresh flow
- invalidate exact provider model-list query key on refresh
- reset expansion state on model provider page mount/unmount
- update and extend tests for external expansion and query invalidation
- update eslint suppressions to match current code
2026-03-05 13:20:58 +08:00
yyh
4a1032c628 fix(web): remove redundant hover text swap on show models button
Merge the two hover-toggling divs into a single always-visible element
and remove the unused showModelsNum i18n key from all locales.
2026-03-05 13:16:04 +08:00
yyh
423c97a47e code style 2026-03-05 13:09:33 +08:00
yyh
a7e3fb2e33 fix(web): use triangle Warning icon instead of circle error icon
Replace i-ri-error-warning-fill (circle exclamation) with the
Warning component (triangle) for api-fallback and credits-fallback
variants to match Figma design.
2026-03-05 13:07:20 +08:00
yyh
ce34937a1c feat(web): add credits-fallback variant for API Key priority with available credits
When API Key is selected but unavailable/unconfigured and credits are
available, the card now shows "AI credits in use" with a warning icon
instead of "API key required". When both credits are exhausted and no
API key exists, it shows "No available usage" (destructive).

New deriveVariant logic for priority=apiKey:
- !exhausted + !authorized → credits-fallback (was api-required-*)
- exhausted + no credential → no-usage (was api-required-add)
- exhausted + named unauthorized → api-unavailable (unchanged)
2026-03-05 13:02:40 +08:00
yyh
ad9ac6978e fix(web): align alert card width with API key section in dropdown
Change mx-1 (4px) to mx-2 (8px) on CreditsFallbackAlert and
CreditsExhaustedAlert to match ApiKeySection's p-2 (8px) padding,
consistent with Figma design where both sections are 8px from the
dropdown edge.
2026-03-05 12:56:55 +08:00
yyh
57c1ba3543 fix(web): hide divider above empty API keys state in dropdown
Move the border from UsagePrioritySection (always visible) to
ApiKeySection's list variant (only when credentials exist). This
removes the unwanted divider line above the "No API Keys" empty
state card when on the AI Credits tab with no keys configured.
2026-03-05 12:25:11 +08:00
yyh
d7a5af2b9a Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor
# Conflicts:
#	web/app/components/header/account-setting/model-provider-page/index.tsx
2026-03-05 10:46:24 +08:00
yyh
d45edffaa3 fix(web): wire upgrade link to pricing modal and add credits-coin icon
Replace broken HTML string interpolation with Trans component and
useModalContextSelector so "upgrade your plan" opens the pricing modal.
Add custom credits-coin SVG icon to replace the generic ri-coin-line.
2026-03-05 10:39:31 +08:00
yyh
530515b6ef fix(web): prevent model list from expanding on priority switch
Remove UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST event emission from
changePriority onSuccess. This event was designed for custom model
add/edit/delete scenarios where the card should expand, but firing
it on priority switch caused ProviderAddedCard to unexpectedly
expand via refreshModelList → setCollapsed(false).
2026-03-05 10:35:03 +08:00
yyh
f13f0d1f9a fix(web): align dropdown alerts with Figma design and fix hardcoded credits total
- Expose totalCredits from useTrialCredits hook instead of hardcoding 10,000
- Align CreditsExhaustedAlert with Figma: dynamic progress bar, correct
  design tokens (components-progress-error-bg/progress), sm-medium/xs-regular
  typography
- Align CreditsFallbackAlert typography to sm-medium/xs-regular
- Fix ApiKeySection empty state: horizontal gradient, sm-medium title,
  Figma-aligned padding (pl-7 for API KEYS label)
- Hoist empty credentials array constant to stabilize memo (rerender-memo-with-default-value)
- Remove redundant useCallback wrapper in ApiKeySection
- Replace nested ternary with Record lookup in TextLabel
- Remove dead || 0 guard in useTrialCredits
- Update all test mocks with totalCredits field
2026-03-05 10:09:51 +08:00
yyh
b597d52c11 refactor(web): remove dialog description from system model selector
Remove the DialogDescription and its i18n key (modelProvider.systemModelSettingsLink)
from the system model settings dialog across all 23 locales.
2026-03-05 10:05:01 +08:00
yyh
34c42fe666 Revert "temp: remove cloud condition"
This reverts commit 29e344ac8b.
2026-03-05 09:44:19 +08:00
yyh
dc109c99f0 test(web): expand credential panel and dropdown test coverage for all 8 card variants
Add comprehensive behavioral tests covering all discriminated union variants,
destructive/default styling, warning icons, CreditsFallbackAlert conditions,
credential CRUD interactions, AlertDialog delete confirmation, and Popover behavior.
2026-03-05 09:41:48 +08:00
yyh
223b9d89c1 refactor(web): migrate priority change to oRPC contract with useMutation
- Add changePreferredProviderType contract in model-providers.ts
- Register in consoleRouterContract
- Replace raw async changeModelProviderPriority with useMutation
- Use Toast.notify (static API) instead of useToastContext hook
- Pass isPending as isChangingPriority to disable buttons during switch
- Add disabled prop to UsagePrioritySection
- Fix pre-existing test assertions for api-unavailable variant
- Update all specs with isChangingPriority prop and oRPC mock pattern
2026-03-05 09:30:38 +08:00
yyh
dd119eb44f fix(web): align UsagePrioritySection with Figma design and fix i18n key ordering
- Single-row layout for icon, label, and option cards
- Icon: arrow-up-double-line matching design spec
- Buttons: flexible width with whitespace-nowrap instead of fixed w-[72px]
- Add min-w-0 + truncate for text overflow, focus-visible ring for a11y
- Sort modelProvider.card.* i18n keys alphabetically
2026-03-05 09:15:16 +08:00
yyh
970493fa85 test(web): update tests for credential panel refactoring and new ModelAuthDropdown components
Rewrite credential-panel.spec.tsx to match the new discriminated union
state model and variant-driven rendering. Add new test files for
useCredentialPanelState hook, SystemQuotaCard Label enhancement,
and all ModelAuthDropdown sub-components.
2026-03-05 08:41:17 +08:00
yyh
ab87ac333a temp: remove IS_CLOUD_EDITION guard from supportsCredits for local testing 2026-03-05 08:34:10 +08:00
yyh
b8b70da9ad refactor(web): rewrite CredentialPanel with declarative variant-driven state and new ModelAuthDropdown
- Extract useCredentialPanelState hook with discriminated union CardVariant type replacing scattered boolean conditions
- Create ModelAuthDropdown compound component (Popover-based) with UsagePrioritySection, CreditsExhaustedAlert, and ApiKeySection
- Enhance SystemQuotaCard.Label to accept className override for flexible styling
- Add i18n keys for new card states and dropdown content (en-US, zh-Hans)
2026-03-05 08:33:04 +08:00
yyh
77d81aebe8 Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-04 23:35:20 +08:00
yyh
deb4cd3ece fix: i18n 2026-03-04 23:35:13 +08:00
yyh
648d9ef1f9 refactor(web): extract SystemQuotaCard compound component and shared useTrialCredits hook
Extract trial credits calculation into a shared useTrialCredits hook to prevent
logic drift between QuotaPanel and CredentialPanel. Add SystemQuotaCard compound
component with explicit default/destructive variants for the system quota UI
state in provider cards, replacing inline conditional styling with composable
Label and Actions slots. Remove unnecessary useMemo for simple derived values.
2026-03-04 23:30:25 +08:00
yyh
5ed4797078 fix 2026-03-04 22:53:29 +08:00
yyh
62631658e9 fix(web): update tests for AlertDialog migration and component API changes
- Replace deprecated Confirm mock with real AlertDialog role-based queries
- Add useInvalidateCheckInstalled mock for QueryClient dependency
- Wrap model-list-item renders in QueryClientProvider
- Migrate PluginVersionPicker from PortalToFollowElem to Popover
- Migrate UpdatePluginModal from Modal to Dialog
- Update version picker offset props (sideOffset/alignOffset)
2026-03-04 22:52:21 +08:00
yyh
22a4100dd7 fix(web): invalidate plugin checkInstalled cache after version updates 2026-03-04 22:33:17 +08:00
yyh
0f7ed6f67e refactor(web): align provider badges with figma and remove dead add-model-button 2026-03-04 22:29:51 +08:00
yyh
4d9fcbec57 refactor(web): migrate remove-plugin dialog to base UI AlertDialog and improve UX
- Replace deprecated Confirm component with AlertDialog primitives
- Add forceRender backdrop for proper overlay rendering
- Add success Toast notification after plugin removal
- Update "View Detail" text to "View on Marketplace" (en/zh-Hans)
- Add i18n keys for delete success message
- Prune stale eslint suppression for header-modals
2026-03-04 22:14:19 +08:00
yyh
4d7a9bc798 fix(web): align model provider cache invalidation with oRPC keys 2026-03-04 22:06:27 +08:00
yyh
d6d04ed657 fix 2026-03-04 22:03:06 +08:00
yyh
f594a71dae fix: icon 2026-03-04 22:02:36 +08:00
yyh
04e0ab7eda refactor(web): migrate provider-added-card model list to oRPC query-driven state 2026-03-04 21:55:34 +08:00
yyh
784bda9c86 refactor(web): migrate operation-dropdown to base UI and align provider card styles with Figma
- Migrate OperationDropdown from legacy portal-to-follow-elem to base UI DropdownMenu primitives
- Add placement, sideOffset, alignOffset, popupClassName props for flexible positioning
- Fix version badge font size: system-2xs-medium-uppercase (10px) → system-xs-medium-uppercase (12px)
- Set provider card dropdown to bottom-start placement with 192px width per Figma spec
- Fix PluginVersionPicker toggle: clicking badge now opens and closes the picker
- Add max-h-[224px] overflow scroll to version list
- Replace Remix icon imports with Tailwind CSS icon classes
- Prune stale eslint suppressions for migrated files
2026-03-04 21:55:23 +08:00
yyh
1af1fb6913 feat(web): add version badge and actions menu to provider cards
Integrate plugin version management into model provider cards by
reusing existing plugin detail panel hooks and components. Batch
query installed plugins at list level to avoid N+1 requests.
2026-03-04 21:29:52 +08:00
yyh
1f0c36e9f7 fix: style 2026-03-04 21:07:42 +08:00
yyh
455ae65025 fix: style 2026-03-04 20:58:14 +08:00
yyh
d44682e957 refactor(web): align quota panel with Figma design and migrate to base UI tooltip
- Rename title from "Quota" to "AI Credits" and update tooltip copy
  (Message Credits → AI Credits, free → Trial)
- Show "Credits exhausted" in destructive text when credits reach zero
  instead of displaying the number "0"
- Migrate from deprecated Tooltip to base UI Tooltip compound component
- Add 4px grid background with radial fade mask via CSS module
- Simplify provider icon tooltip text for uninstalled state
- Update i18n keys for both en-US and zh-Hans
2026-03-04 20:52:30 +08:00
yyh
8c4afc0c18 fix(model-selector): align empty trigger with default trigger style 2026-03-04 20:14:49 +08:00
yyh
539cbcae6a fix(account-settings): render nested system model backdrop via base ui 2026-03-04 19:57:53 +08:00
yyh
8d257fea7c chore(web): commit dialog overlay follow-up changes 2026-03-04 19:37:10 +08:00
yyh
c3364ac350 refactor(web): align account settings dialogs with base UI 2026-03-04 19:31:14 +08:00
yyh
f991644989 refactor(pricing): migrate to base ui dialog and extract category types 2026-03-04 19:26:54 +08:00
yyh
29e344ac8b temp: remove cloud condition 2026-03-04 18:50:38 +08:00
yyh
1ad9305732 fix(web): avoid quota panel flicker on account-setting tab switch
- remove mount-time workspace invalidate in model provider page

- read quota with useCurrentWorkspace and keep loading only for initial empty fetch

- reuse existing useSystemFeaturesQuery for marketplace and trial models

- update model provider and quota panel tests for new query/loading behavior
2026-03-04 18:43:01 +08:00
yyh
17f38f171d lint 2026-03-04 18:21:59 +08:00
yyh
802088c8eb test(web): fix trivial assertion and add useInvalidateDefaultModel tests
Replace the no-provider test assertion from checking a nonexistent i18n
key to verifying actual warning keys are absent. Add unit tests for
useInvalidateDefaultModel following the useUpdateModelList pattern.
2026-03-04 17:51:20 +08:00
yyh
cad6d94491 refactor(web): replace remixicon imports with Tailwind CSS icons in system-model-selector 2026-03-04 17:45:41 +08:00
yyh
621d0fb2c9 fix 2026-03-04 17:42:34 +08:00
yyh
a92fb3244b fix(web): skip top warning for no-provider state and remove unused i18n key
The empty state card below already prompts users to install a provider,
so the top warning bar is redundant for the no-provider case. Remove
the unused noProviderInstalled i18n key and replace the lookup map with
a ternary to preserve i18n literal types without assertions.
2026-03-04 17:39:49 +08:00
yyh
97508f8d7b fix(web): invalidate default model cache after saving system model settings
After saving system models, only the model list cache was invalidated
but not the default model cache, causing stale config status in the UI.
Add useInvalidateDefaultModel hook and call it for all 5 model types
after a successful save.
2026-03-04 17:26:24 +08:00
yyh
70e677a6ac feat(web): refine system model settings to 4 distinct config states
Replace the single `defaultModelNotConfigured` boolean with a derived
`systemModelConfigStatus` that distinguishes between no-provider,
none-configured, partially-configured, and fully-configured states,
each showing a context-appropriate warning message. Also updates the
button label from "System Model Settings" to "Default Model Settings"
and migrates remixicon imports to Tailwind CSS icon classes.
2026-03-04 16:58:46 +08:00
234 changed files with 9489 additions and 8525 deletions

View File

@@ -1,92 +0,0 @@
from __future__ import annotations
from controllers.console.app import annotation as annotation_module
def test_annotation_reply_payload_valid():
"""Test AnnotationReplyPayload with valid data."""
payload = annotation_module.AnnotationReplyPayload(
score_threshold=0.5,
embedding_provider_name="openai",
embedding_model_name="text-embedding-3-small",
)
assert payload.score_threshold == 0.5
assert payload.embedding_provider_name == "openai"
assert payload.embedding_model_name == "text-embedding-3-small"
def test_annotation_setting_update_payload_valid():
"""Test AnnotationSettingUpdatePayload with valid data."""
payload = annotation_module.AnnotationSettingUpdatePayload(
score_threshold=0.75,
)
assert payload.score_threshold == 0.75
def test_annotation_list_query_defaults():
"""Test AnnotationListQuery with default parameters."""
query = annotation_module.AnnotationListQuery()
assert query.page == 1
assert query.limit == 20
assert query.keyword == ""
def test_annotation_list_query_custom_page():
"""Test AnnotationListQuery with custom page."""
query = annotation_module.AnnotationListQuery(page=3, limit=50)
assert query.page == 3
assert query.limit == 50
def test_annotation_list_query_with_keyword():
"""Test AnnotationListQuery with keyword."""
query = annotation_module.AnnotationListQuery(keyword="test")
assert query.keyword == "test"
def test_create_annotation_payload_with_message_id():
"""Test CreateAnnotationPayload with message ID."""
payload = annotation_module.CreateAnnotationPayload(
message_id="550e8400-e29b-41d4-a716-446655440000",
question="What is AI?",
)
assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000"
assert payload.question == "What is AI?"
def test_create_annotation_payload_with_text():
"""Test CreateAnnotationPayload with text content."""
payload = annotation_module.CreateAnnotationPayload(
question="What is ML?",
answer="Machine learning is...",
)
assert payload.question == "What is ML?"
assert payload.answer == "Machine learning is..."
def test_update_annotation_payload():
"""Test UpdateAnnotationPayload."""
payload = annotation_module.UpdateAnnotationPayload(
question="Updated question",
answer="Updated answer",
)
assert payload.question == "Updated question"
assert payload.answer == "Updated answer"
def test_annotation_reply_status_query_enable():
"""Test AnnotationReplyStatusQuery with enable action."""
query = annotation_module.AnnotationReplyStatusQuery(action="enable")
assert query.action == "enable"
def test_annotation_reply_status_query_disable():
"""Test AnnotationReplyStatusQuery with disable action."""
query = annotation_module.AnnotationReplyStatusQuery(action="disable")
assert query.action == "disable"
def test_annotation_file_payload_valid():
"""Test AnnotationFilePayload with valid message ID."""
payload = annotation_module.AnnotationFilePayload(message_id="550e8400-e29b-41d4-a716-446655440000")
assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000"

View File

@@ -13,9 +13,6 @@ from pandas.errors import ParserError
from werkzeug.datastructures import FileStorage
from configs import dify_config
from controllers.console.wraps import annotation_import_concurrency_limit, annotation_import_rate_limit
from services.annotation_service import AppAnnotationService
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
class TestAnnotationImportRateLimiting:
@@ -36,6 +33,8 @@ class TestAnnotationImportRateLimiting:
def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account):
"""Test that per-minute rate limit is enforced."""
from controllers.console.wraps import annotation_import_rate_limit
# Simulate exceeding per-minute limit
mock_redis.zcard.side_effect = [
dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1, # Minute check
@@ -55,6 +54,7 @@ class TestAnnotationImportRateLimiting:
def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account):
"""Test that per-hour rate limit is enforced."""
from controllers.console.wraps import annotation_import_rate_limit
# Simulate exceeding per-hour limit
mock_redis.zcard.side_effect = [
@@ -74,6 +74,7 @@ class TestAnnotationImportRateLimiting:
def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account):
"""Test that requests within limits are allowed."""
from controllers.console.wraps import annotation_import_rate_limit
# Simulate being under both limits
mock_redis.zcard.return_value = 2
@@ -109,6 +110,7 @@ class TestAnnotationImportConcurrencyControl:
def test_concurrency_limit_enforced(self, mock_redis, mock_current_account):
"""Test that concurrent task limit is enforced."""
from controllers.console.wraps import annotation_import_concurrency_limit
# Simulate max concurrent tasks already running
mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT
@@ -125,6 +127,7 @@ class TestAnnotationImportConcurrencyControl:
def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account):
"""Test that requests within concurrency limits are allowed."""
from controllers.console.wraps import annotation_import_concurrency_limit
# Simulate being under concurrent task limit
mock_redis.zcard.return_value = 1
@@ -139,6 +142,7 @@ class TestAnnotationImportConcurrencyControl:
def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account):
"""Test that old/stale job entries are removed."""
from controllers.console.wraps import annotation_import_concurrency_limit
mock_redis.zcard.return_value = 0
@@ -199,6 +203,7 @@ class TestAnnotationImportServiceValidation:
def test_max_records_limit_enforced(self, mock_app, mock_db_session):
"""Test that files with too many records are rejected."""
from services.annotation_service import AppAnnotationService
# Create CSV with too many records
max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS
@@ -224,6 +229,7 @@ class TestAnnotationImportServiceValidation:
def test_min_records_limit_enforced(self, mock_app, mock_db_session):
"""Test that files with too few valid records are rejected."""
from services.annotation_service import AppAnnotationService
# Create CSV with only header (no data rows)
csv_content = "question,answer\n"
@@ -243,6 +249,7 @@ class TestAnnotationImportServiceValidation:
def test_invalid_csv_format_handled(self, mock_app, mock_db_session):
"""Test that invalid CSV format is handled gracefully."""
from services.annotation_service import AppAnnotationService
# Any content is fine once we force ParserError
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
@@ -263,6 +270,7 @@ class TestAnnotationImportServiceValidation:
def test_valid_import_succeeds(self, mock_app, mock_db_session):
"""Test that valid import request succeeds."""
from services.annotation_service import AppAnnotationService
# Create valid CSV
csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n"
@@ -292,10 +300,18 @@ class TestAnnotationImportServiceValidation:
class TestAnnotationImportTaskOptimization:
"""Test optimizations in batch import task."""
def test_task_is_registered_with_queue(self):
"""Test that task is registered with the correct queue."""
assert hasattr(batch_import_annotations_task, "apply_async")
assert hasattr(batch_import_annotations_task, "delay")
def test_task_has_timeout_configured(self):
"""Test that task has proper timeout configuration."""
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
# Verify task configuration
assert hasattr(batch_import_annotations_task, "time_limit")
assert hasattr(batch_import_annotations_task, "soft_time_limit")
# Check timeout values are reasonable
# Hard limit should be 6 minutes (360s)
# Soft limit should be 5 minutes (300s)
# Note: actual values depend on Celery configuration
class TestConfigurationValues:

View File

@@ -1,585 +0,0 @@
"""
Additional tests to improve coverage for low-coverage modules in controllers/console/app.
Target: increase coverage for files with <75% coverage.
"""
from __future__ import annotations
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.app import (
annotation as annotation_module,
)
from controllers.console.app import (
completion as completion_module,
)
from controllers.console.app import (
message as message_module,
)
from controllers.console.app import (
ops_trace as ops_trace_module,
)
from controllers.console.app import (
site as site_module,
)
from controllers.console.app import (
statistic as statistic_module,
)
from controllers.console.app import (
workflow_app_log as workflow_app_log_module,
)
from controllers.console.app import (
workflow_draft_variable as workflow_draft_variable_module,
)
from controllers.console.app import (
workflow_statistic as workflow_statistic_module,
)
from controllers.console.app import (
workflow_trigger as workflow_trigger_module,
)
from controllers.console.app import (
wraps as wraps_module,
)
from controllers.console.app.completion import ChatMessagePayload, CompletionMessagePayload
from controllers.console.app.mcp_server import MCPServerCreatePayload, MCPServerUpdatePayload
from controllers.console.app.ops_trace import TraceConfigPayload, TraceProviderQuery
from controllers.console.app.site import AppSiteUpdatePayload
from controllers.console.app.workflow import AdvancedChatWorkflowRunPayload, SyncDraftWorkflowPayload
from controllers.console.app.workflow_app_log import WorkflowAppLogQuery
from controllers.console.app.workflow_draft_variable import WorkflowDraftVariableUpdatePayload
from controllers.console.app.workflow_statistic import WorkflowStatisticQuery
from controllers.console.app.workflow_trigger import Parser, ParserEnable
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _ConnContext:
def __init__(self, rows):
self._rows = rows
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _query, _args):
return self._rows
# ========== Completion Tests ==========
class TestCompletionEndpoints:
"""Tests for completion API endpoints."""
def test_completion_create_payload(self):
"""Test completion creation payload."""
payload = CompletionMessagePayload(inputs={"prompt": "test"}, model_config={})
assert payload.inputs == {"prompt": "test"}
def test_chat_message_payload_uuid_validation(self):
payload = ChatMessagePayload(
inputs={},
model_config={},
query="hi",
conversation_id=str(uuid.uuid4()),
parent_message_id=str(uuid.uuid4()),
)
assert payload.query == "hi"
def test_completion_api_success(self, app, monkeypatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
class DummyAccount:
pass
dummy_account = DummyAccount()
monkeypatch.setattr(completion_module, "current_user", dummy_account)
monkeypatch.setattr(completion_module, "Account", DummyAccount)
monkeypatch.setattr(
completion_module.AppGenerateService,
"generate",
lambda **_kwargs: {"text": "ok"},
)
monkeypatch.setattr(
completion_module.helper,
"compact_generate_response",
lambda response: {"result": response},
)
with app.test_request_context(
"/",
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
resp = method(app_model=MagicMock(id="app-1"))
assert resp == {"result": {"text": "ok"}}
def test_completion_api_conversation_not_exists(self, app, monkeypatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
class DummyAccount:
pass
dummy_account = DummyAccount()
monkeypatch.setattr(completion_module, "current_user", dummy_account)
monkeypatch.setattr(completion_module, "Account", DummyAccount)
monkeypatch.setattr(
completion_module.AppGenerateService,
"generate",
lambda **_kwargs: (_ for _ in ()).throw(
completion_module.services.errors.conversation.ConversationNotExistsError()
),
)
with app.test_request_context(
"/",
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
with pytest.raises(NotFound):
method(app_model=MagicMock(id="app-1"))
def test_completion_api_provider_not_initialized(self, app, monkeypatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
class DummyAccount:
pass
dummy_account = DummyAccount()
monkeypatch.setattr(completion_module, "current_user", dummy_account)
monkeypatch.setattr(completion_module, "Account", DummyAccount)
monkeypatch.setattr(
completion_module.AppGenerateService,
"generate",
lambda **_kwargs: (_ for _ in ()).throw(completion_module.ProviderTokenNotInitError("x")),
)
with app.test_request_context(
"/",
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
with pytest.raises(completion_module.ProviderNotInitializeError):
method(app_model=MagicMock(id="app-1"))
def test_completion_api_quota_exceeded(self, app, monkeypatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
class DummyAccount:
pass
dummy_account = DummyAccount()
monkeypatch.setattr(completion_module, "current_user", dummy_account)
monkeypatch.setattr(completion_module, "Account", DummyAccount)
monkeypatch.setattr(
completion_module.AppGenerateService,
"generate",
lambda **_kwargs: (_ for _ in ()).throw(completion_module.QuotaExceededError()),
)
with app.test_request_context(
"/",
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
with pytest.raises(completion_module.ProviderQuotaExceededError):
method(app_model=MagicMock(id="app-1"))
# ========== OpsTrace Tests ==========
class TestOpsTraceEndpoints:
"""Tests for ops_trace endpoint."""
def test_ops_trace_query_basic(self):
"""Test ops_trace query."""
query = TraceProviderQuery(tracing_provider="langfuse")
assert query.tracing_provider == "langfuse"
def test_ops_trace_config_payload(self):
payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"})
assert payload.tracing_config["api_key"] == "k"
def test_trace_app_config_get_empty(self, app, monkeypatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.get)
monkeypatch.setattr(
ops_trace_module.OpsService,
"get_tracing_app_config",
lambda **_kwargs: None,
)
with app.test_request_context("/?tracing_provider=langfuse"):
result = method(app_id="app-1")
assert result == {"has_not_configured": True}
def test_trace_app_config_post_invalid(self, app, monkeypatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.post)
monkeypatch.setattr(
ops_trace_module.OpsService,
"create_tracing_app_config",
lambda **_kwargs: {"error": True},
)
with app.test_request_context(
"/",
json={"tracing_provider": "langfuse", "tracing_config": {"api_key": "k"}},
):
with pytest.raises(BadRequest):
method(app_id="app-1")
def test_trace_app_config_delete_not_found(self, app, monkeypatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.delete)
monkeypatch.setattr(
ops_trace_module.OpsService,
"delete_tracing_app_config",
lambda **_kwargs: False,
)
with app.test_request_context("/?tracing_provider=langfuse"):
with pytest.raises(BadRequest):
method(app_id="app-1")
# ========== Site Tests ==========
class TestSiteEndpoints:
"""Tests for site endpoint."""
def test_site_response_structure(self):
"""Test site response structure."""
payload = AppSiteUpdatePayload(title="My Site", description="Test site")
assert payload.title == "My Site"
def test_site_default_language_validation(self):
payload = AppSiteUpdatePayload(default_language="en-US")
assert payload.default_language == "en-US"
def test_app_site_update_post(self, app, monkeypatch):
api = site_module.AppSite()
method = _unwrap(api.post)
site = MagicMock()
query = MagicMock()
query.where.return_value.first.return_value = site
monkeypatch.setattr(
site_module.db,
"session",
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
)
monkeypatch.setattr(
site_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
with app.test_request_context("/", json={"title": "My Site"}):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result is site
def test_app_site_access_token_reset(self, app, monkeypatch):
api = site_module.AppSiteAccessTokenReset()
method = _unwrap(api.post)
site = MagicMock()
query = MagicMock()
query.where.return_value.first.return_value = site
monkeypatch.setattr(
site_module.db,
"session",
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
)
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
monkeypatch.setattr(
site_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
with app.test_request_context("/"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result is site
# ========== Workflow Tests ==========
class TestWorkflowEndpoints:
"""Tests for workflow endpoints."""
def test_workflow_copy_payload(self):
"""Test workflow copy payload."""
payload = SyncDraftWorkflowPayload(graph={}, features={})
assert payload.graph == {}
def test_workflow_mode_query(self):
"""Test workflow mode query."""
payload = AdvancedChatWorkflowRunPayload(inputs={}, query="hi")
assert payload.query == "hi"
# ========== Workflow App Log Tests ==========
class TestWorkflowAppLogEndpoints:
"""Tests for workflow app log endpoints."""
def test_workflow_app_log_query(self):
"""Test workflow app log query."""
query = WorkflowAppLogQuery(keyword="test", page=1, limit=20)
assert query.keyword == "test"
def test_workflow_app_log_query_detail_bool(self):
query = WorkflowAppLogQuery(detail="true")
assert query.detail is True
def test_workflow_app_log_api_get(self, app, monkeypatch):
api = workflow_app_log_module.WorkflowAppLogApi()
method = _unwrap(api.get)
monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock()))
class DummySession:
def __enter__(self):
return "session"
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(workflow_app_log_module, "Session", lambda *args, **kwargs: DummySession())
def fake_get_paginate(self, **_kwargs):
return {"items": [], "total": 0}
monkeypatch.setattr(
workflow_app_log_module.WorkflowAppService,
"get_paginate_workflow_app_logs",
fake_get_paginate,
)
with app.test_request_context("/?page=1&limit=20"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result == {"items": [], "total": 0}
# ========== Workflow Draft Variable Tests ==========
class TestWorkflowDraftVariableEndpoints:
"""Tests for workflow draft variable endpoints."""
def test_workflow_variable_creation(self):
"""Test workflow variable creation."""
payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test")
assert payload.name == "var1"
def test_workflow_variable_collection_get(self, app, monkeypatch):
api = workflow_draft_variable_module.WorkflowVariableCollectionApi()
method = _unwrap(api.get)
monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock()))
class DummySession:
def __enter__(self):
return "session"
def __exit__(self, exc_type, exc, tb):
return False
class DummyDraftService:
def __init__(self, session):
self.session = session
def list_variables_without_values(self, **_kwargs):
return {"items": [], "total": 0}
monkeypatch.setattr(workflow_draft_variable_module, "Session", lambda *args, **kwargs: DummySession())
class DummyWorkflowService:
def is_workflow_exist(self, *args, **kwargs):
return True
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowDraftVariableService", DummyDraftService)
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowService", DummyWorkflowService)
with app.test_request_context("/?page=1&limit=20"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result == {"items": [], "total": 0}
# ========== Workflow Statistic Tests ==========
class TestWorkflowStatisticEndpoints:
"""Tests for workflow statistic endpoints."""
def test_workflow_statistic_time_range(self):
"""Test workflow statistic time range query."""
query = WorkflowStatisticQuery(start="2024-01-01", end="2024-12-31")
assert query.start == "2024-01-01"
def test_workflow_statistic_blank_to_none(self):
query = WorkflowStatisticQuery(start="", end="")
assert query.start is None
assert query.end is None
def test_workflow_daily_runs_statistic(self, app, monkeypatch):
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(
workflow_statistic_module.DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: SimpleNamespace(get_daily_runs_statistics=lambda **_kw: [{"date": "2024-01-01"}]),
)
monkeypatch.setattr(
workflow_statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
workflow_statistic_module,
"parse_time_range",
lambda *_args, **_kwargs: (None, None),
)
api = workflow_statistic_module.WorkflowDailyRunsStatistic()
method = _unwrap(api.get)
with app.test_request_context("/"):
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-01"}]}
def test_workflow_daily_terminals_statistic(self, app, monkeypatch):
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(
workflow_statistic_module.DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: SimpleNamespace(
get_daily_terminals_statistics=lambda **_kw: [{"date": "2024-01-02"}]
),
)
monkeypatch.setattr(
workflow_statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
workflow_statistic_module,
"parse_time_range",
lambda *_args, **_kwargs: (None, None),
)
api = workflow_statistic_module.WorkflowDailyTerminalsStatistic()
method = _unwrap(api.get)
with app.test_request_context("/"):
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-02"}]}
# ========== Workflow Trigger Tests ==========
class TestWorkflowTriggerEndpoints:
"""Tests for workflow trigger endpoints."""
def test_webhook_trigger_payload(self):
"""Test webhook trigger payload."""
payload = Parser(node_id="node-1")
assert payload.node_id == "node-1"
enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True)
assert enable_payload.enable_trigger is True
def test_webhook_trigger_api_get(self, app, monkeypatch):
api = workflow_trigger_module.WebhookTriggerApi()
method = _unwrap(api.get)
monkeypatch.setattr(workflow_trigger_module, "db", SimpleNamespace(engine=MagicMock()))
trigger = MagicMock()
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = trigger
class DummySession:
def __enter__(self):
return session
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(workflow_trigger_module, "Session", lambda *_args, **_kwargs: DummySession())
with app.test_request_context("/?node_id=node-1"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result is trigger
# ========== Wraps Tests ==========
class TestWrapsEndpoints:
"""Tests for wraps utility functions."""
def test_get_app_model_context(self):
"""Test get_app_model wrapper context."""
# These are decorator functions, so we test their availability
assert hasattr(wraps_module, "get_app_model")
# ========== MCP Server Tests ==========
class TestMCPServerEndpoints:
"""Tests for MCP server endpoints."""
def test_mcp_server_connection(self):
"""Test MCP server connection."""
payload = MCPServerCreatePayload(parameters={"url": "http://localhost:3000"})
assert payload.parameters["url"] == "http://localhost:3000"
def test_mcp_server_update_payload(self):
payload = MCPServerUpdatePayload(id="server-1", parameters={"timeout": 30}, status="active")
assert payload.status == "active"
# ========== Error Handling Tests ==========
class TestErrorHandling:
"""Tests for error handling in various endpoints."""
def test_annotation_list_query_validation(self):
"""Test annotation list query validation."""
with pytest.raises(ValueError):
annotation_module.AnnotationListQuery(page=0)
# ========== Integration-like Tests ==========
class TestPayloadIntegration:
"""Integration tests for payload handling."""
def test_multiple_payload_types(self):
"""Test handling of multiple payload types."""
payloads = [
annotation_module.AnnotationReplyPayload(
score_threshold=0.5, embedding_provider_name="openai", embedding_model_name="text-embedding-3-small"
),
message_module.MessageFeedbackPayload(message_id=str(uuid.uuid4()), rating="like"),
statistic_module.StatisticTimeRangeQuery(start="2024-01-01"),
]
assert len(payloads) == 3
assert all(p is not None for p in payloads)

View File

@@ -1,157 +0,0 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from controllers.console.app import app_import as app_import_module
from services.app_dsl_service import ImportStatus
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _Result:
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
self.status = status
self.app_id = app_id
def model_dump(self, mode: str = "json"):
return {"status": self.status, "app_id": self.app_id}
class _SessionContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return False
def _install_session(monkeypatch: pytest.MonkeyPatch, session: MagicMock) -> None:
monkeypatch.setattr(app_import_module, "Session", lambda *_: _SessionContext(session))
monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
def test_import_post_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
_install_features(monkeypatch, enabled=False)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.commit.assert_called_once()
assert status == 400
assert response["status"] == ImportStatus.FAILED
def test_import_post_returns_pending_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
_install_features(monkeypatch, enabled=False)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.commit.assert_called_once()
assert status == 202
assert response["status"] == ImportStatus.PENDING
def test_import_post_updates_webapp_auth_when_enabled(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
_install_features(monkeypatch, enabled=True)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
)
update_access = MagicMock()
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.commit.assert_called_once()
update_access.assert_called_once_with("app-123", "private")
assert status == 200
assert response["status"] == ImportStatus.COMPLETED
def test_import_confirm_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportConfirmApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
monkeypatch.setattr(
app_import_module.AppDslService,
"confirm_import",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
response, status = method(import_id="import-1")
session.commit.assert_called_once()
assert status == 400
assert response["status"] == ImportStatus.FAILED
def test_import_check_dependencies_returns_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportCheckDependenciesApi()
method = _unwrap(api.get)
session = MagicMock()
_install_session(monkeypatch, session)
monkeypatch.setattr(
app_import_module.AppDslService,
"check_dependencies",
lambda *_args, **_kwargs: SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []}),
)
with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
response, status = method(app_model=SimpleNamespace(id="app-1"))
assert status == 200
assert response["leaked_dependencies"] == []

View File

@@ -1,292 +0,0 @@
from __future__ import annotations
import io
from types import SimpleNamespace
import pytest
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import InternalServerError
from controllers.console.app.audio import ChatMessageAudioApi, ChatMessageTextApi, TextModesApi
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from dify_graph.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
ProviderNotSupportTextToSpeechLanageServiceError,
UnsupportedAudioTypeServiceError,
)
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def _file_data():
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
api = ChatMessageAudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
response = handler(app_model=app_model)
assert response == {"text": "ok"}
@pytest.mark.parametrize(
("exc", "expected"),
[
(AppModelConfigBrokenError(), AppUnavailableError),
(NoAudioUploadedServiceError(), NoAudioUploadedError),
(AudioTooLargeServiceError("too big"), AudioTooLargeError),
(UnsupportedAudioTypeServiceError(), UnsupportedAudioTypeError),
(ProviderNotSupportSpeechToTextServiceError(), ProviderNotSupportSpeechToTextError),
(ProviderTokenNotInitError("token"), ProviderNotInitializeError),
(QuotaExceededError(), ProviderQuotaExceededError),
(ModelCurrentlyNotSupportError(), ProviderModelCurrentlyNotSupportError),
(InvokeError("invoke"), CompletionRequestError),
],
)
def test_console_audio_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
api = ChatMessageAudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
with pytest.raises(expected):
handler(app_model=app_model)
def test_console_audio_api_unhandled_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
api = ChatMessageAudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
with pytest.raises(InternalServerError):
handler(app_model=app_model)
def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
api = ChatMessageTextApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context(
"/console/api/apps/app/text-to-audio",
method="POST",
json={"text": "hello", "voice": "v"},
):
response = handler(app_model=app_model)
assert response == {"audio": "ok"}
def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError()))
api = ChatMessageTextApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context(
"/console/api/apps/app/text-to-audio",
method="POST",
json={"text": "hello"},
):
with pytest.raises(ProviderQuotaExceededError):
handler(app_model=app_model)
def test_console_text_modes_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
api = TextModesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(tenant_id="t1")
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
response = handler(app_model=app_model)
assert response == ["voice-1"]
def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AudioService,
"transcript_tts_voices",
lambda **_kwargs: (_ for _ in ()).throw(ProviderNotSupportTextToSpeechLanageServiceError()),
)
api = TextModesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(tenant_id="t1")
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
with pytest.raises(AppUnavailableError):
handler(app_model=app_model)
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageAudioApi()
method = _unwrap(api.post)
response_payload = {"text": "hello"}
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: response_payload)
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
response = method(app_model=app_model)
assert response == response_payload
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageAudioApi()
method = _unwrap(api.post)
monkeypatch.setattr(
AudioService,
"transcript_asr",
lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")),
)
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
with pytest.raises(AudioTooLargeError):
method(app_model=app_model)
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageTextApi()
method = _unwrap(api.post)
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
app_model = SimpleNamespace(id="app-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio",
method="POST",
json={"text": "hello"},
):
response = method(app_model=app_model)
assert response == {"audio": "ok"}
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = TextModesApi()
method = _unwrap(api.get)
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
app_model = SimpleNamespace(tenant_id="tenant-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio/voices",
method="GET",
query_string={"language": "en-US"},
):
response = method(app_model=app_model)
assert response == ["voice-1"]
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageAudioApi()
method = _unwrap(api.post)
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
# Should not raise, AudioService is mocked
response = method(app_model=app_model)
assert response == {"text": "test"}
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageTextApi()
method = _unwrap(api.post)
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
app_model = SimpleNamespace(id="app-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio",
method="POST",
json={"text": "hello", "language": "en-US"},
):
response = method(app_model=app_model)
assert response == {"audio": "test"}
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = TextModesApi()
method = _unwrap(api.get)
monkeypatch.setattr(
AudioService,
"transcript_tts_voices",
lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}],
)
app_model = SimpleNamespace(tenant_id="tenant-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
method="GET",
):
response = method(app_model=app_model)
assert isinstance(response, list)

View File

@@ -1,156 +0,0 @@
from __future__ import annotations
import io
from types import SimpleNamespace
import pytest
from controllers.console.app import audio as audio_module
from controllers.console.app.error import AudioTooLargeError
from services.errors.audio import AudioTooLargeServiceError
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageAudioApi()
method = _unwrap(api.post)
response_payload = {"text": "hello"}
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: response_payload)
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
response = method(app_model=app_model)
assert response == response_payload
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageAudioApi()
method = _unwrap(api.post)
monkeypatch.setattr(
audio_module.AudioService,
"transcript_asr",
lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")),
)
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
with pytest.raises(AudioTooLargeError):
method(app_model=app_model)
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageTextApi()
method = _unwrap(api.post)
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
app_model = SimpleNamespace(id="app-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio",
method="POST",
json={"text": "hello"},
):
response = method(app_model=app_model)
assert response == {"audio": "ok"}
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.TextModesApi()
method = _unwrap(api.get)
monkeypatch.setattr(audio_module.AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
app_model = SimpleNamespace(tenant_id="tenant-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio/voices",
method="GET",
query_string={"language": "en-US"},
):
response = method(app_model=app_model)
assert response == ["voice-1"]
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageAudioApi()
method = _unwrap(api.post)
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
# Should not raise, AudioService is mocked
response = method(app_model=app_model)
assert response == {"text": "test"}
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageTextApi()
method = _unwrap(api.post)
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
app_model = SimpleNamespace(id="app-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio",
method="POST",
json={"text": "hello", "language": "en-US"},
):
response = method(app_model=app_model)
assert response == {"audio": "test"}
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.TextModesApi()
method = _unwrap(api.get)
monkeypatch.setattr(
audio_module.AudioService,
"transcript_tts_voices",
lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}],
)
app_model = SimpleNamespace(tenant_id="tenant-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
method="GET",
):
response = method(app_model=app_model)
assert isinstance(response, list)

View File

@@ -1,130 +0,0 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.app import conversation as conversation_module
from models.model import AppMode
from services.errors.conversation import ConversationNotExistsError
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def _make_account():
return SimpleNamespace(timezone="UTC", id="u1")
def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.CompletionConversationApi()
method = _unwrap(api.get)
account = _make_account()
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
paginate_result = MagicMock()
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response is paginate_result
def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.CompletionConversationApi()
method = _unwrap(api.get)
account = _make_account()
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
monkeypatch.setattr(
conversation_module,
"parse_time_range",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("bad range")),
)
with app.test_request_context(
"/console/api/apps/app-1/completion-conversations",
method="GET",
query_string={"start": "bad"},
):
with pytest.raises(BadRequest):
method(app_model=SimpleNamespace(id="app-1"))
def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.ChatConversationApi()
method = _unwrap(api.get)
account = _make_account()
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
paginate_result = MagicMock()
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
assert response is paginate_result
def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
conversation = SimpleNamespace(id="c1", app_id="app-1")
query = MagicMock()
query.where.return_value = query
query.first.return_value = conversation
session = MagicMock()
session.query.return_value = query
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(conversation_module.db, "session", session)
result = conversation_module._get_conversation(SimpleNamespace(id="app-1"), "c1")
assert result is conversation
session.execute.assert_called_once()
session.commit.assert_called_once()
session.refresh.assert_called_once_with(conversation)
def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
session = MagicMock()
session.query.return_value = query
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(conversation_module.db, "session", session)
with pytest.raises(NotFound):
conversation_module._get_conversation(SimpleNamespace(id="app-1"), "missing")
def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.CompletionConversationDetailApi()
method = _unwrap(api.delete)
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(
conversation_module.ConversationService,
"delete",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
)
with pytest.raises(NotFound):
method(app_model=SimpleNamespace(id="app-1"), conversation_id="c1")

View File

@@ -1,260 +0,0 @@
from __future__ import annotations
from types import SimpleNamespace
import pytest
from controllers.console.app import generator as generator_module
from controllers.console.app.error import ProviderNotInitializeError
from core.errors.error import ProviderTokenNotInitError
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def _model_config_payload():
return {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}
def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow):
class _Service:
def get_draft_workflow(self, app_model):
return workflow
monkeypatch.setattr(generator_module, "WorkflowService", lambda: _Service())
def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.RuleGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []})
with app.test_request_context(
"/console/api/rule-generate",
method="POST",
json={"instruction": "do it", "model_config": _model_config_payload()},
):
response = method()
assert response == {"rules": []}
def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.RuleCodeGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
def _raise(*_args, **_kwargs):
raise ProviderTokenNotInitError("missing token")
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", _raise)
with app.test_request_context(
"/console/api/rule-code-generate",
method="POST",
json={"instruction": "do it", "model_config": _model_config_payload()},
):
with pytest.raises(ProviderNotInitializeError):
method()
def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "node-1",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response, status = method()
assert status == 400
assert response["error"] == "app app-1 not found"
def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
_install_workflow_service(monkeypatch, workflow=None)
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "node-1",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response, status = method()
assert status == 400
assert response["error"] == "workflow app-1 not found"
def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
workflow = SimpleNamespace(graph_dict={"nodes": []})
_install_workflow_service(monkeypatch, workflow=workflow)
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "node-1",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response, status = method()
assert status == 400
assert response["error"] == "node node-1 not found"
def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
workflow = SimpleNamespace(
graph_dict={
"nodes": [
{"id": "node-1", "data": {"type": "code"}},
]
}
)
_install_workflow_service(monkeypatch, workflow=workflow)
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", lambda **_kwargs: {"code": "x"})
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "node-1",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response = method()
assert response == {"code": "x"}
def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(
generator_module.LLMGenerator,
"instruction_modify_legacy",
lambda **_kwargs: {"instruction": "ok"},
)
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "",
"current": "old",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response = method()
assert response == {"instruction": "ok"}
def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "",
"current": "",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response, status = method()
assert status == 400
assert response["error"] == "incompatible parameters"
def test_instruction_template_prompt(app) -> None:
api = generator_module.InstructionGenerationTemplateApi()
method = _unwrap(api.post)
with app.test_request_context(
"/console/api/instruction-generate/template",
method="POST",
json={"type": "prompt"},
):
response = method()
assert "data" in response
def test_instruction_template_invalid_type(app) -> None:
api = generator_module.InstructionGenerationTemplateApi()
method = _unwrap(api.post)
with app.test_request_context(
"/console/api/instruction-generate/template",
method="POST",
json={"type": "unknown"},
):
with pytest.raises(ValueError):
method()

View File

@@ -1,122 +0,0 @@
from __future__ import annotations
import pytest
from controllers.console.app import message as message_module
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test valid ChatMessagesQuery with all fields."""
query = message_module.ChatMessagesQuery(
conversation_id="550e8400-e29b-41d4-a716-446655440000",
first_id="550e8400-e29b-41d4-a716-446655440001",
limit=50,
)
assert query.limit == 50
def test_chat_messages_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test ChatMessagesQuery with defaults."""
query = message_module.ChatMessagesQuery(conversation_id="550e8400-e29b-41d4-a716-446655440000")
assert query.first_id is None
assert query.limit == 20
def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test ChatMessagesQuery converts empty first_id to None."""
query = message_module.ChatMessagesQuery(
conversation_id="550e8400-e29b-41d4-a716-446655440000",
first_id="",
)
assert query.first_id is None
def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageFeedbackPayload with like rating."""
payload = message_module.MessageFeedbackPayload(
message_id="550e8400-e29b-41d4-a716-446655440000",
rating="like",
content="Good answer",
)
assert payload.rating == "like"
assert payload.content == "Good answer"
def test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageFeedbackPayload with dislike rating."""
payload = message_module.MessageFeedbackPayload(
message_id="550e8400-e29b-41d4-a716-446655440000",
rating="dislike",
)
assert payload.rating == "dislike"
def test_message_feedback_payload_no_rating(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageFeedbackPayload without rating."""
payload = message_module.MessageFeedbackPayload(message_id="550e8400-e29b-41d4-a716-446655440000")
assert payload.rating is None
def test_feedback_export_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with default format."""
query = message_module.FeedbackExportQuery()
assert query.format == "csv"
assert query.from_source is None
def test_feedback_export_query_json_format(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with JSON format."""
query = message_module.FeedbackExportQuery(format="json")
assert query.format == "json"
def test_feedback_export_query_has_comment_true(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as true string."""
query = message_module.FeedbackExportQuery(has_comment="true")
assert query.has_comment is True
def test_feedback_export_query_has_comment_false(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as false string."""
query = message_module.FeedbackExportQuery(has_comment="false")
assert query.has_comment is False
def test_feedback_export_query_has_comment_1(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as 1."""
query = message_module.FeedbackExportQuery(has_comment="1")
assert query.has_comment is True
def test_feedback_export_query_has_comment_0(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as 0."""
query = message_module.FeedbackExportQuery(has_comment="0")
assert query.has_comment is False
def test_feedback_export_query_rating_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with rating filter."""
query = message_module.FeedbackExportQuery(rating="like")
assert query.rating == "like"
def test_annotation_count_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test AnnotationCountResponse creation."""
response = message_module.AnnotationCountResponse(count=10)
assert response.count == 10
def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test SuggestedQuestionsResponse creation."""
response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"])
assert len(response.data) == 2
assert response.data[0] == "What is AI?"

View File

@@ -1,151 +0,0 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from controllers.console.app import model_config as model_config_module
from models.model import AppMode, AppModelConfig
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = model_config_module.ModelConfigResource()
method = _unwrap(api.post)
app_model = SimpleNamespace(
id="app-1",
mode=AppMode.CHAT.value,
is_agent=False,
app_model_config_id=None,
updated_by=None,
updated_at=None,
)
monkeypatch.setattr(
model_config_module.AppModelConfigService,
"validate_configuration",
lambda **_kwargs: {"pre_prompt": "hi"},
)
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
session = MagicMock()
monkeypatch.setattr(model_config_module.db, "session", session)
def _from_model_config_dict(self, model_config):
self.pre_prompt = model_config["pre_prompt"]
self.id = "config-1"
return self
monkeypatch.setattr(AppModelConfig, "from_model_config_dict", _from_model_config_dict)
send_mock = MagicMock()
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
response = method(app_model=app_model)
session.add.assert_called_once()
session.flush.assert_called_once()
session.commit.assert_called_once()
send_mock.assert_called_once()
assert app_model.app_model_config_id == "config-1"
assert response["result"] == "success"
def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = model_config_module.ModelConfigResource()
method = _unwrap(api.post)
app_model = SimpleNamespace(
id="app-1",
mode=AppMode.AGENT_CHAT.value,
is_agent=True,
app_model_config_id="config-0",
updated_by=None,
updated_at=None,
)
original_config = AppModelConfig(app_id="app-1", created_by="u1", updated_by="u1")
original_config.agent_mode = json.dumps(
{
"enabled": True,
"strategy": "function-calling",
"tools": [
{
"provider_id": "provider",
"provider_type": "builtin",
"tool_name": "tool",
"tool_parameters": {"secret": "masked"},
}
],
"prompt": None,
}
)
session = MagicMock()
query = MagicMock()
query.where.return_value = query
query.first.return_value = original_config
session.query.return_value = query
monkeypatch.setattr(model_config_module.db, "session", session)
monkeypatch.setattr(
model_config_module.AppModelConfigService,
"validate_configuration",
lambda **_kwargs: {
"pre_prompt": "hi",
"agent_mode": {
"enabled": True,
"strategy": "function-calling",
"tools": [
{
"provider_id": "provider",
"provider_type": "builtin",
"tool_name": "tool",
"tool_parameters": {"secret": "masked"},
}
],
"prompt": None,
},
},
)
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
monkeypatch.setattr(model_config_module.ToolManager, "get_agent_tool_runtime", lambda **_kwargs: object())
class _ParamManager:
def __init__(self, **_kwargs):
self.delete_called = False
def decrypt_tool_parameters(self, _value):
return {"secret": "decrypted"}
def mask_tool_parameters(self, _value):
return {"secret": "masked"}
def encrypt_tool_parameters(self, _value):
return {"secret": "encrypted"}
def delete_tool_parameters_cache(self):
self.delete_called = True
monkeypatch.setattr(model_config_module, "ToolParameterConfigurationManager", _ParamManager)
send_mock = MagicMock()
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
response = method(app_model=app_model)
stored_config = session.add.call_args[0][0]
stored_agent_mode = json.loads(stored_config.agent_mode)
assert stored_agent_mode["tools"][0]["tool_parameters"]["secret"] == "encrypted"
assert response["result"] == "success"

View File

@@ -1,215 +0,0 @@
from __future__ import annotations
from decimal import Decimal
from types import SimpleNamespace
import pytest
from werkzeug.exceptions import BadRequest
from controllers.console.app import statistic as statistic_module
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _ConnContext:
def __init__(self, rows):
self._rows = rows
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _query, _args):
return self._rows
def _install_db(monkeypatch: pytest.MonkeyPatch, rows) -> None:
engine = SimpleNamespace(begin=lambda: _ConnContext(rows))
monkeypatch.setattr(statistic_module, "db", SimpleNamespace(engine=engine))
def _install_common(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
statistic_module,
"parse_time_range",
lambda *_args, **_kwargs: (None, None),
)
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-01", message_count=3)]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]}
def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyConversationStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTokenCostStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-03", token_count=10, total_price=0.25, currency="USD")]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 1
assert data["data"][0]["date"] == "2024-01-03"
assert data["data"][0]["token_count"] == 10
assert data["data"][0]["total_price"] == 0.25
def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTerminalsStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-04", terminal_count=7)]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]}
def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that AverageSessionInteractionStatistic is limited to chat/agent modes."""
# This just verifies the decorator is applied correctly
# Actual endpoint testing would require complex JOIN mocking
api = statistic_module.AverageSessionInteractionStatistic()
method = _unwrap(api.get)
assert callable(method)
def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
def mock_parse(*args, **kwargs):
raise ValueError("Invalid time range")
_install_db(monkeypatch, [])
monkeypatch.setattr(
statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(statistic_module, "parse_time_range", mock_parse)
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
with pytest.raises(BadRequest):
method(app_model=SimpleNamespace(id="app-1"))
def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
rows = [
SimpleNamespace(date="2024-01-01", message_count=10),
SimpleNamespace(date="2024-01-02", message_count=15),
SimpleNamespace(date="2024-01-03", message_count=12),
]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 3
def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
_install_common(monkeypatch)
_install_db(monkeypatch, [])
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": []}
def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyConversationStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
_install_db(monkeypatch, rows)
monkeypatch.setattr(
statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
statistic_module,
"parse_time_range",
lambda *_args, **_kwargs: ("s", "e"),
)
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTokenCostStatistic()
method = _unwrap(api.get)
rows = [
SimpleNamespace(date="2024-01-01", token_count=100, total_price=Decimal("0.50"), currency="USD"),
SimpleNamespace(date="2024-01-02", token_count=200, total_price=Decimal("1.00"), currency="USD"),
]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 2

View File

@@ -1,163 +0,0 @@
from __future__ import annotations
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from werkzeug.exceptions import HTTPException, NotFound
from controllers.console.app import workflow as workflow_module
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from dify_graph.file.enums import FileTransferMethod, FileType
from dify_graph.file.models import File
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def test_parse_file_no_config(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: None)
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
assert workflow_module._parse_file(workflow, files=[{"id": "f"}]) == []
def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
config = object()
file_list = [
File(
tenant_id="t1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="http://u",
)
]
build_mock = Mock(return_value=file_list)
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: config)
monkeypatch.setattr(workflow_module.file_factory, "build_from_mappings", build_mock)
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
result = workflow_module._parse_file(workflow, files=[{"id": "f"}])
assert result == file_list
build_mock.assert_called_once()
def test_sync_draft_workflow_invalid_content_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
with app.test_request_context("/apps/app/workflows/draft", method="POST", data="x", content_type="text/html"):
with pytest.raises(HTTPException) as exc:
handler(api, app_model=SimpleNamespace(id="app"))
assert exc.value.code == 415
def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
with app.test_request_context(
"/apps/app/workflows/draft",
method="POST",
data="[]",
content_type="application/json",
):
response, status = handler(api, app_model=SimpleNamespace(id="app"))
assert status == 400
assert response["message"] == "Invalid JSON data"
def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow = SimpleNamespace(
unique_hash="h",
updated_at=None,
created_at=datetime(2024, 1, 1),
)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
monkeypatch.setattr(
workflow_module.variable_factory, "build_environment_variable_from_mapping", lambda *_args: "env"
)
monkeypatch.setattr(
workflow_module.variable_factory, "build_conversation_variable_from_mapping", lambda *_args: "conv"
)
service = SimpleNamespace(sync_draft_workflow=lambda **_kwargs: workflow)
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/draft",
method="POST",
json={"graph": {}, "features": {}, "hash": "h"},
):
response = handler(api, app_model=SimpleNamespace(id="app"))
assert response["result"] == "success"
def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
def _raise(*_args, **_kwargs):
raise workflow_module.WorkflowHashNotEqualError()
service = SimpleNamespace(sync_draft_workflow=_raise)
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/draft",
method="POST",
json={"graph": {}, "features": {}, "hash": "h"},
):
with pytest.raises(DraftWorkflowNotSync):
handler(api, app_model=SimpleNamespace(id="app"))
def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None)
)
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.get)
with pytest.raises(DraftWorkflowNotExist):
handler(api, app_model=SimpleNamespace(id="app"))
def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
workflow_module.AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(
workflow_module.services.errors.conversation.ConversationNotExistsError()
),
)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
api = workflow_module.AdvancedChatDraftWorkflowRunApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/apps/app/advanced-chat/workflows/draft/run",
method="POST",
json={"inputs": {}},
):
with pytest.raises(NotFound):
handler(api, app_model=SimpleNamespace(id="app"))

View File

@@ -1,47 +0,0 @@
from __future__ import annotations
from types import SimpleNamespace
import pytest
from controllers.console.app import wraps as wraps_module
from controllers.console.app.error import AppNotFoundError
from models.model import AppMode
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
@wraps_module.get_app_model
def handler(app_model):
return app_model.id
assert handler(app_id="app-1") == "app-1"
def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None:
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
@wraps_module.get_app_model(mode=[AppMode.COMPLETION])
def handler(app_model):
return app_model.id
with pytest.raises(AppNotFoundError):
handler(app_id="app-1")
def test_get_app_model_requires_app_id() -> None:
@wraps_module.get_app_model
def handler(app_model):
return app_model.id
with pytest.raises(ValueError):
handler()

View File

@@ -1,483 +1,13 @@
"""Final working unit tests for admin endpoints - tests business logic directly."""
import uuid
from unittest.mock import Mock, PropertyMock, patch
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import NotFound, Unauthorized
from controllers.console.admin import (
DeleteExploreBannerApi,
InsertExploreAppApi,
InsertExploreAppListApi,
InsertExploreAppPayload,
InsertExploreBannerApi,
InsertExploreBannerPayload,
)
from models.model import App, InstalledApp, RecommendedApp
@pytest.fixture(autouse=True)
def bypass_only_edition_cloud(mocker):
"""
Bypass only_edition_cloud decorator by setting EDITION to "CLOUD".
"""
mocker.patch(
"controllers.console.wraps.dify_config.EDITION",
new="CLOUD",
)
@pytest.fixture
def mock_admin_auth(mocker):
"""
Provide valid admin authentication for controller tests.
"""
mocker.patch(
"controllers.console.admin.dify_config.ADMIN_API_KEY",
"test-admin-key",
)
mocker.patch(
"controllers.console.admin.extract_access_token",
return_value="test-admin-key",
)
@pytest.fixture
def mock_console_payload(mocker):
payload = {
"app_id": str(uuid.uuid4()),
"language": "en-US",
"category": "Productivity",
"position": 1,
}
mocker.patch(
"flask_restx.namespace.Namespace.payload",
new_callable=PropertyMock,
return_value=payload,
)
return payload
@pytest.fixture
def mock_banner_payload(mocker):
mocker.patch(
"flask_restx.namespace.Namespace.payload",
new_callable=PropertyMock,
return_value={
"title": "Test Banner",
"description": "Banner description",
"img-src": "https://example.com/banner.png",
"link": "https://example.com",
"sort": 1,
"category": "homepage",
},
)
@pytest.fixture
def mock_session_factory(mocker):
mock_session = Mock()
mock_session.execute = Mock()
mock_session.add = Mock()
mock_session.commit = Mock()
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: mock_session,
__exit__=Mock(return_value=False),
),
)
class TestDeleteExploreBannerApi:
def setup_method(self):
self.api = DeleteExploreBannerApi()
def test_delete_banner_not_found(self, mocker, mock_admin_auth):
mocker.patch(
"controllers.console.admin.db.session.execute",
return_value=Mock(scalar_one_or_none=lambda: None),
)
with pytest.raises(NotFound, match="is not found"):
self.api.delete(uuid.uuid4())
def test_delete_banner_success(self, mocker, mock_admin_auth):
mock_banner = Mock()
mocker.patch(
"controllers.console.admin.db.session.execute",
return_value=Mock(scalar_one_or_none=lambda: mock_banner),
)
mocker.patch("controllers.console.admin.db.session.delete")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.delete(uuid.uuid4())
assert status == 204
assert response["result"] == "success"
class TestInsertExploreBannerApi:
def setup_method(self):
self.api = InsertExploreBannerApi()
def test_insert_banner_success(self, mocker, mock_admin_auth, mock_banner_payload):
mocker.patch("controllers.console.admin.db.session.add")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 201
assert response["result"] == "success"
def test_banner_payload_valid_language(self):
payload = {
"title": "Test Banner",
"description": "Banner description",
"img-src": "https://example.com/banner.png",
"link": "https://example.com",
"sort": 1,
"category": "homepage",
"language": "en-US",
}
model = InsertExploreBannerPayload.model_validate(payload)
assert model.language == "en-US"
def test_banner_payload_invalid_language(self):
payload = {
"title": "Test Banner",
"description": "Banner description",
"img-src": "https://example.com/banner.png",
"link": "https://example.com",
"sort": 1,
"category": "homepage",
"language": "invalid-lang",
}
with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
InsertExploreBannerPayload.model_validate(payload)
class TestInsertExploreAppApiDelete:
def setup_method(self):
self.api = InsertExploreAppApi()
def test_delete_when_not_in_explore(self, mocker, mock_admin_auth):
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: s,
__exit__=Mock(return_value=False),
execute=lambda *_: Mock(scalar_one_or_none=lambda: None),
),
)
response, status = self.api.delete(uuid.uuid4())
assert status == 204
assert response["result"] == "success"
def test_delete_when_in_explore_with_trial_app(self, mocker, mock_admin_auth):
"""Test deleting an app from explore that has a trial app."""
app_id = uuid.uuid4()
mock_recommended = Mock(spec=RecommendedApp)
mock_recommended.app_id = "app-123"
mock_app = Mock(spec=App)
mock_app.is_public = True
mock_trial = Mock()
# Mock session context manager and its execute
mock_session = Mock()
mock_session.execute = Mock()
mock_session.delete = Mock()
# Set up side effects for execute calls
mock_session.execute.side_effect = [
Mock(scalar_one_or_none=lambda: mock_recommended),
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalars=Mock(return_value=Mock(all=lambda: []))),
Mock(scalar_one_or_none=lambda: mock_trial),
]
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: mock_session,
__exit__=Mock(return_value=False),
),
)
mocker.patch("controllers.console.admin.db.session.delete")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.delete(app_id)
assert status == 204
assert response["result"] == "success"
assert mock_app.is_public is False
def test_delete_with_installed_apps(self, mocker, mock_admin_auth):
"""Test deleting an app that has installed apps in other tenants."""
app_id = uuid.uuid4()
mock_recommended = Mock(spec=RecommendedApp)
mock_recommended.app_id = "app-123"
mock_app = Mock(spec=App)
mock_app.is_public = True
mock_installed_app = Mock(spec=InstalledApp)
# Mock session
mock_session = Mock()
mock_session.execute = Mock()
mock_session.delete = Mock()
mock_session.execute.side_effect = [
Mock(scalar_one_or_none=lambda: mock_recommended),
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalars=Mock(return_value=Mock(all=lambda: [mock_installed_app]))),
Mock(scalar_one_or_none=lambda: None),
]
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: mock_session,
__exit__=Mock(return_value=False),
),
)
mocker.patch("controllers.console.admin.db.session.delete")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.delete(app_id)
assert status == 204
assert mock_session.delete.called
class TestInsertExploreAppListApi:
def setup_method(self):
self.api = InsertExploreAppListApi()
def test_app_not_found(self, mocker, mock_admin_auth, mock_console_payload):
mocker.patch(
"controllers.console.admin.db.session.execute",
return_value=Mock(scalar_one_or_none=lambda: None),
)
with pytest.raises(NotFound, match="is not found"):
self.api.post()
def test_create_recommended_app(
self,
mocker,
mock_admin_auth,
mock_console_payload,
):
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.tenant_id = "tenant"
mock_app.is_public = False
# db.session.execute → fetch App
mocker.patch(
"controllers.console.admin.db.session.execute",
return_value=Mock(scalar_one_or_none=lambda: mock_app),
)
# session_factory.create_session → recommended_app lookup
mock_session = Mock()
mock_session.execute = Mock(return_value=Mock(scalar_one_or_none=lambda: None))
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: mock_session,
__exit__=Mock(return_value=False),
),
)
mocker.patch("controllers.console.admin.db.session.add")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 201
assert response["result"] == "success"
assert mock_app.is_public is True
def test_update_recommended_app(self, mocker, mock_admin_auth, mock_console_payload, mock_session_factory):
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.is_public = False
mock_recommended = Mock(spec=RecommendedApp)
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: mock_recommended),
],
)
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 200
assert response["result"] == "success"
assert mock_app.is_public is True
def test_site_data_overrides_payload(
self,
mocker,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
):
site = Mock()
site.description = "Site Desc"
site.copyright = "Site Copyright"
site.privacy_policy = "Site Privacy"
site.custom_disclaimer = "Site Disclaimer"
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = site
mock_app.tenant_id = "tenant"
mock_app.is_public = False
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: None),
Mock(scalar_one_or_none=lambda: None),
],
)
commit_spy = mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 200
assert response["result"] == "success"
assert mock_app.is_public is True
commit_spy.assert_called_once()
def test_create_trial_app_when_can_trial_enabled(
self,
mocker,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
):
mock_console_payload["can_trial"] = True
mock_console_payload["trial_limit"] = 5
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.tenant_id = "tenant"
mock_app.is_public = False
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: None),
Mock(scalar_one_or_none=lambda: None),
],
)
add_spy = mocker.patch("controllers.console.admin.db.session.add")
mocker.patch("controllers.console.admin.db.session.commit")
self.api.post()
assert any(call.args[0].__class__.__name__ == "TrialApp" for call in add_spy.call_args_list)
def test_update_recommended_app_with_trial(
self,
mocker,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
):
"""Test updating a recommended app when trial is enabled."""
mock_console_payload["can_trial"] = True
mock_console_payload["trial_limit"] = 10
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.is_public = False
mock_app.tenant_id = "tenant-123"
mock_recommended = Mock(spec=RecommendedApp)
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: mock_recommended),
Mock(scalar_one_or_none=lambda: None),
],
)
add_spy = mocker.patch("controllers.console.admin.db.session.add")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 200
assert response["result"] == "success"
assert mock_app.is_public is True
def test_update_recommended_app_without_trial(
self,
mocker,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
):
"""Test updating a recommended app without trial enabled."""
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.is_public = False
mock_recommended = Mock(spec=RecommendedApp)
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: mock_recommended),
],
)
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 200
assert response["result"] == "success"
assert mock_app.is_public is True
from controllers.console.admin import InsertExploreAppPayload
from models.model import App, RecommendedApp
class TestInsertExploreAppPayload:

View File

@@ -1,138 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.console.apikey import (
BaseApiKeyListResource,
BaseApiKeyResource,
_get_resource,
)
@pytest.fixture
def tenant_context_admin():
with patch("controllers.console.apikey.current_account_with_tenant") as mock:
user = MagicMock()
user.is_admin_or_owner = True
mock.return_value = (user, "tenant-123")
yield mock
@pytest.fixture
def tenant_context_non_admin():
with patch("controllers.console.apikey.current_account_with_tenant") as mock:
user = MagicMock()
user.is_admin_or_owner = False
mock.return_value = (user, "tenant-123")
yield mock
@pytest.fixture
def db_mock():
with patch("controllers.console.apikey.db") as mock_db:
mock_db.session = MagicMock()
yield mock_db
@pytest.fixture(autouse=True)
def bypass_permissions():
with patch(
"controllers.console.apikey.edit_permission_required",
lambda f: f,
):
yield
class DummyApiKeyListResource(BaseApiKeyListResource):
resource_type = "app"
resource_model = MagicMock()
resource_id_field = "app_id"
token_prefix = "app-"
class DummyApiKeyResource(BaseApiKeyResource):
resource_type = "app"
resource_model = MagicMock()
resource_id_field = "app_id"
class TestGetResource:
def test_get_resource_success(self):
fake_resource = MagicMock()
with (
patch("controllers.console.apikey.select") as mock_select,
patch("controllers.console.apikey.Session") as mock_session,
patch("controllers.console.apikey.db") as mock_db,
):
mock_db.engine = MagicMock()
mock_select.return_value.filter_by.return_value = MagicMock()
session = mock_session.return_value.__enter__.return_value
session.execute.return_value.scalar_one_or_none.return_value = fake_resource
result = _get_resource("rid", "tid", MagicMock)
assert result == fake_resource
def test_get_resource_not_found(self):
with (
patch("controllers.console.apikey.select") as mock_select,
patch("controllers.console.apikey.Session") as mock_session,
patch("controllers.console.apikey.db") as mock_db,
patch("controllers.console.apikey.flask_restx.abort") as abort,
):
mock_db.engine = MagicMock()
mock_select.return_value.filter_by.return_value = MagicMock()
session = mock_session.return_value.__enter__.return_value
session.execute.return_value.scalar_one_or_none.return_value = None
_get_resource("rid", "tid", MagicMock)
abort.assert_called_once()
class TestBaseApiKeyListResource:
def test_get_apikeys_success(self, tenant_context_admin, db_mock):
resource = DummyApiKeyListResource()
with patch("controllers.console.apikey._get_resource"):
db_mock.session.scalars.return_value.all.return_value = [MagicMock(), MagicMock()]
result = DummyApiKeyListResource.get.__wrapped__(resource, "resource-id")
assert "items" in result
class TestBaseApiKeyResource:
def test_delete_forbidden(self, tenant_context_non_admin, db_mock):
resource = DummyApiKeyResource()
with patch("controllers.console.apikey._get_resource"):
with pytest.raises(Forbidden):
DummyApiKeyResource.delete(resource, "rid", "kid")
def test_delete_key_not_found(self, tenant_context_admin, db_mock):
resource = DummyApiKeyResource()
db_mock.session.query.return_value.where.return_value.first.return_value = None
with patch("controllers.console.apikey._get_resource"):
with pytest.raises(Exception) as exc_info:
DummyApiKeyResource.delete(resource, "rid", "kid")
# flask_restx.abort raises HTTPException with message in data attribute
assert exc_info.value.data["message"] == "API key not found"
def test_delete_success(self, tenant_context_admin, db_mock):
resource = DummyApiKeyResource()
db_mock.session.query.return_value.where.return_value.first.return_value = MagicMock()
with (
patch("controllers.console.apikey._get_resource"),
patch("controllers.console.apikey.ApiTokenCache.delete"),
):
result, status = DummyApiKeyResource.delete(resource, "rid", "kid")
assert status == 204
assert result == {"result": "success"}
db_mock.session.commit.assert_called_once()

View File

@@ -0,0 +1,46 @@
import builtins
from unittest.mock import patch
import pytest
from flask import Flask
from flask.views import MethodView
from extensions import ext_fastopenapi
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
@pytest.fixture
def app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
app.secret_key = "test-secret-key"
return app
def test_console_init_get_returns_finished_when_no_init_password(app: Flask, monkeypatch: pytest.MonkeyPatch):
ext_fastopenapi.init_app(app)
monkeypatch.delenv("INIT_PASSWORD", raising=False)
with patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"):
client = app.test_client()
response = client.get("/console/api/init")
assert response.status_code == 200
assert response.get_json() == {"status": "finished"}
def test_console_init_post_returns_success(app: Flask, monkeypatch: pytest.MonkeyPatch):
ext_fastopenapi.init_app(app)
monkeypatch.setenv("INIT_PASSWORD", "test-init-password")
with (
patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"),
patch("controllers.console.init_validate.TenantService.get_tenant_count", return_value=0),
):
client = app.test_client()
response = client.post("/console/api/init", json={"password": "test-init-password"})
assert response.status_code == 201
assert response.get_json() == {"result": "success"}

View File

@@ -0,0 +1,286 @@
"""Tests for remote file upload API endpoints using Flask-RESTX."""
import contextlib
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import Mock, patch
import httpx
import pytest
from flask import Flask, g
@pytest.fixture
def app() -> Flask:
"""Create Flask app for testing."""
app = Flask(__name__)
app.config["TESTING"] = True
app.config["SECRET_KEY"] = "test-secret-key"
return app
@pytest.fixture
def client(app):
"""Create test client with console blueprint registered."""
from controllers.console import bp
app.register_blueprint(bp)
return app.test_client()
@pytest.fixture
def mock_account():
"""Create a mock account for testing."""
from models import Account
account = Mock(spec=Account)
account.id = "test-account-id"
account.current_tenant_id = "test-tenant-id"
return account
@pytest.fixture
def auth_ctx(app, mock_account):
"""Context manager to set auth/tenant context in flask.g for a request."""
@contextlib.contextmanager
def _ctx():
with app.test_request_context():
g._login_user = mock_account
g._current_tenant = mock_account.current_tenant_id
yield
return _ctx
class TestGetRemoteFileInfo:
"""Test GET /console/api/remote-files/<path:url> endpoint."""
def test_get_remote_file_info_success(self, app, client, mock_account):
"""Test successful retrieval of remote file info."""
response = httpx.Response(
200,
request=httpx.Request("HEAD", "http://example.com/file.txt"),
headers={"Content-Type": "text/plain", "Content-Length": "1024"},
)
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response),
patch("libs.login.check_csrf_token", return_value=None),
):
with app.test_request_context():
g._login_user = mock_account
g._current_tenant = mock_account.current_tenant_id
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt"
resp = client.get(f"/console/api/remote-files/{encoded_url}")
assert resp.status_code == 200
data = resp.get_json()
assert data["file_type"] == "text/plain"
assert data["file_length"] == 1024
def test_get_remote_file_info_fallback_to_get_on_head_failure(self, app, client, mock_account):
"""Test fallback to GET when HEAD returns non-200 status."""
head_response = httpx.Response(
404,
request=httpx.Request("HEAD", "http://example.com/file.pdf"),
)
get_response = httpx.Response(
200,
request=httpx.Request("GET", "http://example.com/file.pdf"),
headers={"Content-Type": "application/pdf", "Content-Length": "2048"},
)
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response),
patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_response),
patch("libs.login.check_csrf_token", return_value=None),
):
with app.test_request_context():
g._login_user = mock_account
g._current_tenant = mock_account.current_tenant_id
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.pdf"
resp = client.get(f"/console/api/remote-files/{encoded_url}")
assert resp.status_code == 200
data = resp.get_json()
assert data["file_type"] == "application/pdf"
assert data["file_length"] == 2048
class TestRemoteFileUpload:
"""Test POST /console/api/remote-files/upload endpoint."""
@pytest.mark.parametrize(
("head_status", "use_get"),
[
(200, False), # HEAD succeeds
(405, True), # HEAD fails -> fallback GET
],
)
def test_upload_remote_file_success_paths(self, client, mock_account, auth_ctx, head_status, use_get):
url = "http://example.com/file.pdf"
head_resp = httpx.Response(
head_status,
request=httpx.Request("HEAD", url),
headers={"Content-Type": "application/pdf", "Content-Length": "1024"},
)
get_resp = httpx.Response(
200,
request=httpx.Request("GET", url),
headers={"Content-Type": "application/pdf", "Content-Length": "1024"},
content=b"file content",
)
file_info = SimpleNamespace(
extension="pdf",
size=1024,
filename="file.pdf",
mimetype="application/pdf",
)
uploaded_file = SimpleNamespace(
id="uploaded-file-id",
name="file.pdf",
size=1024,
extension="pdf",
mime_type="application/pdf",
created_by="test-account-id",
created_at=datetime(2024, 1, 1, 12, 0, 0),
)
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp) as p_head,
patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_resp) as p_get,
patch(
"controllers.console.remote_files.helpers.guess_file_info_from_response",
return_value=file_info,
),
patch(
"controllers.console.remote_files.FileService.is_file_size_within_limit",
return_value=True,
),
patch("controllers.console.remote_files.db", spec=["engine"]),
patch("controllers.console.remote_files.FileService") as mock_file_service,
patch(
"controllers.console.remote_files.file_helpers.get_signed_file_url",
return_value="http://example.com/signed-url",
),
patch("libs.login.check_csrf_token", return_value=None),
):
mock_file_service.return_value.upload_file.return_value = uploaded_file
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": url},
)
assert resp.status_code == 201
p_head.assert_called_once()
# GET is used either for fallback (HEAD fails) or to fetch content after HEAD succeeds
p_get.assert_called_once()
mock_file_service.return_value.upload_file.assert_called_once()
data = resp.get_json()
assert data["id"] == "uploaded-file-id"
assert data["name"] == "file.pdf"
assert data["size"] == 1024
assert data["extension"] == "pdf"
assert data["url"] == "http://example.com/signed-url"
assert data["mime_type"] == "application/pdf"
assert data["created_by"] == "test-account-id"
@pytest.mark.parametrize(
("size_ok", "raises", "expected_status", "expected_msg"),
[
# When size check fails in controller, API returns 413 with message "File size exceeded..."
(False, None, 413, "file size exceeded"),
# When service raises unsupported type, controller maps to 415 with message "File type not allowed."
(True, "unsupported", 415, "file type not allowed"),
],
)
def test_upload_remote_file_errors(
self, client, mock_account, auth_ctx, size_ok, raises, expected_status, expected_msg
):
url = "http://example.com/x.pdf"
head_resp = httpx.Response(
200,
request=httpx.Request("HEAD", url),
headers={"Content-Type": "application/pdf", "Content-Length": "9"},
)
file_info = SimpleNamespace(extension="pdf", size=9, filename="x.pdf", mimetype="application/pdf")
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp),
patch(
"controllers.console.remote_files.helpers.guess_file_info_from_response",
return_value=file_info,
),
patch(
"controllers.console.remote_files.FileService.is_file_size_within_limit",
return_value=size_ok,
),
patch("controllers.console.remote_files.db", spec=["engine"]),
patch("libs.login.check_csrf_token", return_value=None),
):
if raises == "unsupported":
from services.errors.file import UnsupportedFileTypeError
with patch("controllers.console.remote_files.FileService") as mock_file_service:
mock_file_service.return_value.upload_file.side_effect = UnsupportedFileTypeError("bad")
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": url},
)
else:
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": url},
)
assert resp.status_code == expected_status
data = resp.get_json()
msg = (data.get("error") or {}).get("message") or data.get("message", "")
assert expected_msg in msg.lower()
def test_upload_remote_file_fetch_failure(self, client, mock_account, auth_ctx):
"""Test upload when fetching of remote file fails."""
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch(
"controllers.console.remote_files.ssrf_proxy.head",
side_effect=httpx.RequestError("Connection failed"),
),
patch("libs.login.check_csrf_token", return_value=None),
):
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": "http://unreachable.com/file.pdf"},
)
assert resp.status_code == 400
data = resp.get_json()
msg = (data.get("error") or {}).get("message") or data.get("message", "")
assert "failed to fetch" in msg.lower()

View File

@@ -1,81 +0,0 @@
from werkzeug.exceptions import Unauthorized
def unwrap(func):
"""
Recursively unwrap decorated functions.
"""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestFeatureApi:
def test_get_tenant_features_success(self, mocker):
from controllers.console.feature import FeatureApi
mocker.patch(
"controllers.console.feature.current_account_with_tenant",
return_value=("account_id", "tenant_123"),
)
mocker.patch("controllers.console.feature.FeatureService.get_features").return_value.model_dump.return_value = {
"features": {"feature_a": True}
}
api = FeatureApi()
raw_get = unwrap(FeatureApi.get)
result = raw_get(api)
assert result == {"features": {"feature_a": True}}
class TestSystemFeatureApi:
def test_get_system_features_authenticated(self, mocker):
"""
current_user.is_authenticated == True
"""
from controllers.console.feature import SystemFeatureApi
fake_user = mocker.Mock()
fake_user.is_authenticated = True
mocker.patch(
"controllers.console.feature.current_user",
fake_user,
)
mocker.patch(
"controllers.console.feature.FeatureService.get_system_features"
).return_value.model_dump.return_value = {"features": {"sys_feature": True}}
api = SystemFeatureApi()
result = api.get()
assert result == {"features": {"sys_feature": True}}
def test_get_system_features_unauthenticated(self, mocker):
"""
current_user.is_authenticated raises Unauthorized
"""
from controllers.console.feature import SystemFeatureApi
fake_user = mocker.Mock()
type(fake_user).is_authenticated = mocker.PropertyMock(side_effect=Unauthorized())
mocker.patch(
"controllers.console.feature.current_user",
fake_user,
)
mocker.patch(
"controllers.console.feature.FeatureService.get_system_features"
).return_value.model_dump.return_value = {"features": {"sys_feature": False}}
api = SystemFeatureApi()
result = api.get()
assert result == {"features": {"sys_feature": False}}

View File

@@ -1,300 +0,0 @@
import io
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
from constants import DOCUMENT_EXTENSIONS
from controllers.common.errors import (
BlockedFileExtensionError,
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console.files import (
FileApi,
FilePreviewApi,
FileSupportTypeApi,
)
def unwrap(func):
"""
Recursively unwrap decorated functions.
"""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask(__name__)
app.testing = True
return app
@pytest.fixture(autouse=True)
def mock_decorators():
"""
Make decorators no-ops so logic is directly testable
"""
with (
patch("controllers.console.files.setup_required", new=lambda f: f),
patch("controllers.console.files.login_required", new=lambda f: f),
patch("controllers.console.files.account_initialization_required", new=lambda f: f),
patch("controllers.console.files.cloud_edition_billing_resource_check", return_value=lambda f: f),
):
yield
@pytest.fixture
def mock_current_user():
user = MagicMock()
user.is_dataset_editor = True
return user
@pytest.fixture
def mock_account_context(mock_current_user):
with patch(
"controllers.console.files.current_account_with_tenant",
return_value=(mock_current_user, None),
):
yield
@pytest.fixture
def mock_db():
with patch("controllers.console.files.db") as db_mock:
db_mock.engine = MagicMock()
yield db_mock
@pytest.fixture
def mock_file_service(mock_db):
with patch("controllers.console.files.FileService") as fs:
instance = fs.return_value
yield instance
class TestFileApiGet:
def test_get_upload_config(self, app):
api = FileApi()
get_method = unwrap(api.get)
with app.test_request_context():
data, status = get_method(api)
assert status == 200
assert "file_size_limit" in data
assert "batch_count_limit" in data
class TestFileApiPost:
def test_no_file_uploaded(self, app, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
with app.test_request_context(method="POST", data={}):
with pytest.raises(NoFileUploadedError):
post_method(api)
def test_too_many_files(self, app, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
with app.test_request_context(method="POST"):
from unittest.mock import MagicMock, patch
with patch("controllers.console.files.request") as mock_request:
mock_request.files = MagicMock()
mock_request.files.__len__.return_value = 2
mock_request.files.__contains__.return_value = True
mock_request.form = MagicMock()
mock_request.form.get.return_value = None
with pytest.raises(TooManyFilesError):
post_method(api)
def test_filename_missing(self, app, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
data = {
"file": (io.BytesIO(b"abc"), ""),
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(FilenameNotExistsError):
post_method(api)
def test_dataset_upload_without_permission(self, app, mock_current_user):
mock_current_user.is_dataset_editor = False
with patch(
"controllers.console.files.current_account_with_tenant",
return_value=(mock_current_user, None),
):
api = FileApi()
post_method = unwrap(api.post)
data = {
"file": (io.BytesIO(b"abc"), "test.txt"),
"source": "datasets",
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(Forbidden):
post_method(api)
def test_successful_upload(self, app, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
mock_file = MagicMock()
mock_file.id = "file-id-123"
mock_file.filename = "test.txt"
mock_file.name = "test.txt"
mock_file.size = 1024
mock_file.extension = "txt"
mock_file.mime_type = "text/plain"
mock_file.created_by = "user-123"
mock_file.created_at = 1234567890
mock_file.preview_url = "http://example.com/preview/file-id-123"
mock_file.source_url = "http://example.com/source/file-id-123"
mock_file.original_url = None
mock_file.user_id = "user-123"
mock_file.tenant_id = "tenant-123"
mock_file.conversation_id = None
mock_file.file_key = "file-key-123"
mock_file_service.upload_file.return_value = mock_file
data = {
"file": (io.BytesIO(b"hello"), "test.txt"),
}
with app.test_request_context(method="POST", data=data):
response, status = post_method(api)
assert status == 201
assert response["id"] == "file-id-123"
assert response["name"] == "test.txt"
def test_upload_with_invalid_source(self, app, mock_account_context, mock_file_service):
"""Test that invalid source parameter gets normalized to None"""
api = FileApi()
post_method = unwrap(api.post)
# Create a properly structured mock file object
mock_file = MagicMock()
mock_file.id = "file-id-456"
mock_file.filename = "test.txt"
mock_file.name = "test.txt"
mock_file.size = 512
mock_file.extension = "txt"
mock_file.mime_type = "text/plain"
mock_file.created_by = "user-456"
mock_file.created_at = 1234567890
mock_file.preview_url = None
mock_file.source_url = None
mock_file.original_url = None
mock_file.user_id = "user-456"
mock_file.tenant_id = "tenant-456"
mock_file.conversation_id = None
mock_file.file_key = "file-key-456"
mock_file_service.upload_file.return_value = mock_file
data = {
"file": (io.BytesIO(b"content"), "test.txt"),
"source": "invalid_source", # Should be normalized to None
}
with app.test_request_context(method="POST", data=data):
response, status = post_method(api)
assert status == 201
assert response["id"] == "file-id-456"
# Verify that FileService was called with source=None
mock_file_service.upload_file.assert_called_once()
call_kwargs = mock_file_service.upload_file.call_args[1]
assert call_kwargs["source"] is None
def test_file_too_large_error(self, app, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
error = ServiceFileTooLargeError("File is too large")
mock_file_service.upload_file.side_effect = error
data = {
"file": (io.BytesIO(b"x" * 1000000), "big.txt"),
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(FileTooLargeError):
post_method(api)
def test_unsupported_file_type(self, app, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
error = ServiceUnsupportedFileTypeError()
mock_file_service.upload_file.side_effect = error
data = {
"file": (io.BytesIO(b"x"), "bad.exe"),
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(UnsupportedFileTypeError):
post_method(api)
def test_blocked_extension(self, app, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
from services.errors.file import BlockedFileExtensionError as ServiceBlockedFileExtensionError
error = ServiceBlockedFileExtensionError("File extension is blocked")
mock_file_service.upload_file.side_effect = error
data = {
"file": (io.BytesIO(b"x"), "blocked.txt"),
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(BlockedFileExtensionError):
post_method(api)
class TestFilePreviewApi:
def test_get_preview(self, app, mock_file_service):
api = FilePreviewApi()
get_method = unwrap(api.get)
mock_file_service.get_file_preview.return_value = "preview text"
with app.test_request_context():
result = get_method(api, "1234")
assert result == {"content": "preview text"}
class TestFileSupportTypeApi:
def test_get_supported_types(self, app):
api = FileSupportTypeApi()
get_method = unwrap(api.get)
with app.test_request_context():
result = get_method(api)
assert result == {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}

View File

@@ -1,293 +0,0 @@
from __future__ import annotations
import json
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask import Response
from controllers.console.human_input_form import (
ConsoleHumanInputFormApi,
ConsoleWorkflowEventsApi,
DifyAPIRepositoryFactory,
WorkflowResponseConverter,
_jsonify_form_definition,
)
from controllers.web.error import NotFoundError
from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def test_jsonify_form_definition() -> None:
expiration = datetime(2024, 1, 1, tzinfo=UTC)
definition = SimpleNamespace(model_dump=lambda: {"fields": []})
form = SimpleNamespace(get_definition=lambda: definition, expiration_time=expiration)
response = _jsonify_form_definition(form)
assert isinstance(response, Response)
payload = json.loads(response.get_data(as_text=True))
assert payload["expiration_time"] == int(expiration.timestamp())
def test_ensure_console_access_rejects(monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(tenant_id="tenant-1")
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-2"))
with pytest.raises(NotFoundError):
ConsoleHumanInputFormApi._ensure_console_access(form)
def test_get_form_definition_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
expiration = datetime(2024, 1, 1, tzinfo=UTC)
definition = SimpleNamespace(model_dump=lambda: {"fields": ["a"]})
form = SimpleNamespace(tenant_id="tenant-1", get_definition=lambda: definition, expiration_time=expiration)
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_definition_by_token_for_console(self, _token):
return form
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
response = handler(api, form_token="token")
payload = json.loads(response.get_data(as_text=True))
assert payload["fields"] == ["a"]
def test_get_form_definition_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_definition_by_token_for_console(self, _token):
return None
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
with pytest.raises(NotFoundError):
handler(api, form_token="token")
def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.EMAIL_MEMBER)
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_by_token(self, _token):
return form
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/console/api/form/human_input/token",
method="POST",
json={"inputs": {"content": "ok"}, "action": "approve"},
):
with pytest.raises(NotFoundError):
handler(api, form_token="token")
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
submit_mock = Mock()
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_by_token(self, _token):
return form
def submit_form_by_token(self, **kwargs):
submit_mock(**kwargs)
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/console/api/form/human_input/token",
method="POST",
json={"inputs": {"content": "ok"}, "action": "approve"},
):
response = handler(api, form_token="token")
assert response.get_json() == {}
submit_mock.assert_called_once()
def test_workflow_events_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
class _RepoStub:
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
return None
monkeypatch.setattr(
DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
handler(api, workflow_run_id="run-1")
def test_workflow_events_requires_account(app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
created_by_role=CreatorUserRole.END_USER,
created_by="user-1",
tenant_id="t1",
)
class _RepoStub:
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
return workflow_run
monkeypatch.setattr(
DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
handler(api, workflow_run_id="run-1")
def test_workflow_events_requires_creator(app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
created_by_role=CreatorUserRole.ACCOUNT,
created_by="user-2",
tenant_id="t1",
)
class _RepoStub:
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
return workflow_run
monkeypatch.setattr(
DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
handler(api, workflow_run_id="run-1")
def test_workflow_events_finished(app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
created_by_role=CreatorUserRole.ACCOUNT,
created_by="user-1",
tenant_id="t1",
app_id="app-1",
finished_at=datetime(2024, 1, 1, tzinfo=UTC),
)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW)
class _RepoStub:
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
return workflow_run
response_obj = SimpleNamespace(
event=SimpleNamespace(value="finished"),
model_dump=lambda mode="json": {"status": "done"},
)
monkeypatch.setattr(
DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form._retrieve_app_for_workflow_run",
lambda *_args, **_kwargs: app_model,
)
monkeypatch.setattr(
WorkflowResponseConverter,
"workflow_run_result_to_finish_response",
lambda **_kwargs: response_obj,
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
response = handler(api, workflow_run_id="run-1")
assert response.mimetype == "text/event-stream"
assert "data" in response.get_data(as_text=True)

View File

@@ -1,108 +0,0 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from controllers.console import init_validate
from controllers.console.error import AlreadySetupError, InitValidateFailedError
class _SessionStub:
def __init__(self, has_setup: bool):
self._has_setup = has_setup
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, *_args, **_kwargs):
return SimpleNamespace(scalar_one_or_none=lambda: Mock() if self._has_setup else None)
def test_get_init_status_finished(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: True)
result = init_validate.get_init_status()
assert result.status == "finished"
def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: False)
result = init_validate.get_init_status()
assert result.status == "not_started"
def test_validate_init_password_already_setup(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1)
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="POST"):
with pytest.raises(AlreadySetupError):
init_validate.validate_init_password(init_validate.InitValidatePayload(password="pw"))
def test_validate_init_password_wrong_password(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
monkeypatch.setenv("INIT_PASSWORD", "expected")
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="POST"):
with pytest.raises(InitValidateFailedError):
init_validate.validate_init_password(init_validate.InitValidatePayload(password="wrong"))
assert init_validate.session.get("is_init_validated") is False
def test_validate_init_password_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
monkeypatch.setenv("INIT_PASSWORD", "expected")
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="POST"):
result = init_validate.validate_init_password(init_validate.InitValidatePayload(password="expected"))
assert result.result == "success"
assert init_validate.session.get("is_init_validated") is True
def test_get_init_validate_status_not_self_hosted(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "CLOUD")
assert init_validate.get_init_validate_status() is True
def test_get_init_validate_status_validated_session(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setenv("INIT_PASSWORD", "expected")
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="GET"):
init_validate.session["is_init_validated"] = True
assert init_validate.get_init_validate_status() is True
def test_get_init_validate_status_setup_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setenv("INIT_PASSWORD", "expected")
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(True))
monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object()))
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="GET"):
init_validate.session.pop("is_init_validated", None)
assert init_validate.get_init_validate_status() is True
def test_get_init_validate_status_not_validated(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setenv("INIT_PASSWORD", "expected")
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(False))
monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object()))
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="GET"):
init_validate.session.pop("is_init_validated", None)
assert init_validate.get_init_validate_status() is False

View File

@@ -1,281 +0,0 @@
from __future__ import annotations
import urllib.parse
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import MagicMock
import httpx
import pytest
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError
from controllers.console import remote_files as remote_files_module
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class _FakeResponse:
def __init__(
self,
*,
status_code: int = 200,
headers: dict[str, str] | None = None,
method: str = "GET",
content: bytes = b"",
text: str = "",
error: Exception | None = None,
) -> None:
self.status_code = status_code
self.headers = headers or {}
self.request = SimpleNamespace(method=method)
self.content = content
self.text = text
self._error = error
def raise_for_status(self) -> None:
if self._error:
raise self._error
def _mock_upload_dependencies(
monkeypatch: pytest.MonkeyPatch,
*,
file_size_within_limit: bool = True,
):
file_info = SimpleNamespace(
filename="report.txt",
extension=".txt",
mimetype="text/plain",
size=3,
)
monkeypatch.setattr(
remote_files_module.helpers,
"guess_file_info_from_response",
MagicMock(return_value=file_info),
)
file_service_cls = MagicMock()
file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit
monkeypatch.setattr(remote_files_module, "FileService", file_service_cls)
monkeypatch.setattr(remote_files_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), None))
monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
remote_files_module.file_helpers,
"get_signed_file_url",
lambda upload_file_id: f"https://signed.example/{upload_file_id}",
)
return file_service_cls
def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.GetRemoteFileInfo()
handler = _unwrap(api.get)
decoded_url = "https://example.com/test.txt"
encoded_url = urllib.parse.quote(decoded_url, safe="")
head_resp = _FakeResponse(
status_code=200,
headers={"Content-Type": "text/plain", "Content-Length": "128"},
method="HEAD",
)
head_mock = MagicMock(return_value=head_resp)
get_mock = MagicMock()
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
with app.test_request_context(method="GET"):
payload = handler(api, url=encoded_url)
assert payload == {"file_type": "text/plain", "file_length": 128}
head_mock.assert_called_once_with(decoded_url)
get_mock.assert_not_called()
def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.GetRemoteFileInfo()
handler = _unwrap(api.get)
decoded_url = "https://example.com/test.txt"
encoded_url = urllib.parse.quote(decoded_url, safe="")
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=503)))
get_mock = MagicMock(return_value=_FakeResponse(status_code=200, headers={}, method="GET"))
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
with app.test_request_context(method="GET"):
payload = handler(api, url=encoded_url)
assert payload == {"file_type": "application/octet-stream", "file_length": 0}
get_mock.assert_called_once_with(decoded_url, timeout=3)
def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/report.txt"
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=404)))
get_resp = _FakeResponse(status_code=200, method="GET", content=b"fallback-content")
get_mock = MagicMock(return_value=get_resp)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
file_service_cls = _mock_upload_dependencies(monkeypatch)
upload_file = SimpleNamespace(
id="file-1",
name="report.txt",
size=16,
extension=".txt",
mime_type="text/plain",
created_by="u1",
created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
file_service_cls.return_value.upload_file.return_value = upload_file
with app.test_request_context(method="POST", json={"url": url}):
payload, status = handler(api)
assert status == 201
assert payload["id"] == "file-1"
assert payload["url"] == "https://signed.example/file-1"
get_mock.assert_called_once_with(url=url, timeout=3, follow_redirects=True)
file_service_cls.return_value.upload_file.assert_called_once_with(
filename="report.txt",
content=b"fallback-content",
mimetype="text/plain",
user=SimpleNamespace(id="u1"),
source_url=url,
)
def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
app, monkeypatch: pytest.MonkeyPatch
) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/photo.jpg"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="HEAD", content=b"head-content")),
)
extra_get_resp = _FakeResponse(status_code=200, method="GET", content=b"downloaded-content")
get_mock = MagicMock(return_value=extra_get_resp)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
file_service_cls = _mock_upload_dependencies(monkeypatch)
upload_file = SimpleNamespace(
id="file-2",
name="photo.jpg",
size=18,
extension=".jpg",
mime_type="image/jpeg",
created_by="u1",
created_at=datetime(2024, 1, 2, tzinfo=UTC),
)
file_service_cls.return_value.upload_file.return_value = upload_file
with app.test_request_context(method="POST", json={"url": url}):
payload, status = handler(api)
assert status == 201
assert payload["id"] == "file-2"
get_mock.assert_called_once_with(url)
assert file_service_cls.return_value.upload_file.call_args.kwargs["content"] == b"downloaded-content"
def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/fail.txt"
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=500)))
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"get",
MagicMock(return_value=_FakeResponse(status_code=502, text="bad gateway")),
)
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"):
handler(api)
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/fail.txt"
request = httpx.Request("HEAD", url)
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(side_effect=httpx.RequestError("network down", request=request)),
)
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"):
handler(api)
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/large.bin"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
_mock_upload_dependencies(monkeypatch, file_size_within_limit=False)
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(FileTooLargeError):
handler(api)
def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/large.bin"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
file_service_cls = _mock_upload_dependencies(monkeypatch)
file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded")
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(FileTooLargeError, match="size exceeded"):
handler(api)
def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/file.exe"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
file_service_cls = _mock_upload_dependencies(monkeypatch)
file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError()
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(UnsupportedFileTypeError):
handler(api)

View File

@@ -1,49 +0,0 @@
from unittest.mock import patch
import controllers.console.spec as spec_module
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestSpecSchemaDefinitionsApi:
def test_get_success(self):
api = spec_module.SpecSchemaDefinitionsApi()
method = unwrap(api.get)
schema_definitions = [{"type": "string"}]
with patch.object(
spec_module,
"SchemaManager",
) as schema_manager_cls:
schema_manager_cls.return_value.get_all_schema_definitions.return_value = schema_definitions
resp, status = method(api)
assert status == 200
assert resp == schema_definitions
def test_get_exception_returns_empty_list(self):
api = spec_module.SpecSchemaDefinitionsApi()
method = unwrap(api.get)
with (
patch.object(
spec_module,
"SchemaManager",
side_effect=Exception("boom"),
),
patch.object(
spec_module.logger,
"exception",
) as log_exception,
):
resp, status = method(api)
assert status == 200
assert resp == []
log_exception.assert_called_once()

View File

@@ -1,162 +0,0 @@
from unittest.mock import MagicMock, patch
import controllers.console.version as version_module
class TestHasNewVersion:
def test_has_new_version_true(self):
result = version_module._has_new_version(
latest_version="1.2.0",
current_version="1.1.0",
)
assert result is True
def test_has_new_version_false(self):
result = version_module._has_new_version(
latest_version="1.0.0",
current_version="1.1.0",
)
assert result is False
def test_has_new_version_invalid_version(self):
with patch.object(version_module.logger, "warning") as log_warning:
result = version_module._has_new_version(
latest_version="invalid",
current_version="1.0.0",
)
assert result is False
log_warning.assert_called_once()
class TestCheckVersionUpdate:
def test_no_check_update_url(self):
query = version_module.VersionQuery(current_version="1.0.0")
with (
patch.object(
version_module.dify_config,
"CHECK_UPDATE_URL",
"",
),
patch.object(
version_module.dify_config.project,
"version",
"1.0.0",
),
patch.object(
version_module.dify_config,
"CAN_REPLACE_LOGO",
True,
),
patch.object(
version_module.dify_config,
"MODEL_LB_ENABLED",
False,
),
):
result = version_module.check_version_update(query)
assert result.version == "1.0.0"
assert result.can_auto_update is False
assert result.features.can_replace_logo is True
assert result.features.model_load_balancing_enabled is False
def test_http_error_fallback(self):
query = version_module.VersionQuery(current_version="1.0.0")
with (
patch.object(
version_module.dify_config,
"CHECK_UPDATE_URL",
"http://example.com",
),
patch.object(
version_module.httpx,
"get",
side_effect=Exception("boom"),
),
patch.object(
version_module.logger,
"warning",
) as log_warning,
):
result = version_module.check_version_update(query)
assert result.version == "1.0.0"
log_warning.assert_called_once()
def test_new_version_available(self):
query = version_module.VersionQuery(current_version="1.0.0")
response = MagicMock()
response.json.return_value = {
"version": "1.2.0",
"releaseDate": "2024-01-01",
"releaseNotes": "New features",
"canAutoUpdate": True,
}
with (
patch.object(
version_module.dify_config,
"CHECK_UPDATE_URL",
"http://example.com",
),
patch.object(
version_module.httpx,
"get",
return_value=response,
),
patch.object(
version_module.dify_config.project,
"version",
"1.0.0",
),
patch.object(
version_module.dify_config,
"CAN_REPLACE_LOGO",
False,
),
patch.object(
version_module.dify_config,
"MODEL_LB_ENABLED",
True,
),
):
result = version_module.check_version_update(query)
assert result.version == "1.2.0"
assert result.release_date == "2024-01-01"
assert result.release_notes == "New features"
assert result.can_auto_update is True
def test_no_new_version(self):
query = version_module.VersionQuery(current_version="1.2.0")
response = MagicMock()
response.json.return_value = {
"version": "1.1.0",
}
with (
patch.object(
version_module.dify_config,
"CHECK_UPDATE_URL",
"http://example.com",
),
patch.object(
version_module.httpx,
"get",
return_value=response,
),
patch.object(
version_module.dify_config.project,
"version",
"1.2.0",
),
):
result = version_module.check_version_update(query)
assert result.version == "1.2.0"
assert result.can_auto_update is False

View File

@@ -295,24 +295,7 @@ describe('Pricing Modal Flow', () => {
})
})
// ─── 6. Close Handling ───────────────────────────────────────────────────
describe('Close handling', () => {
it('should call onCancel when pressing ESC key', () => {
render(<Pricing onCancel={onCancel} />)
// ahooks useKeyPress listens on document for keydown events
document.dispatchEvent(new KeyboardEvent('keydown', {
key: 'Escape',
code: 'Escape',
keyCode: 27,
bubbles: true,
}))
expect(onCancel).toHaveBeenCalledTimes(1)
})
})
// ─── 7. Pricing URL ─────────────────────────────────────────────────────
// ─── 6. Pricing URL ─────────────────────────────────────────────────────
describe('Pricing page URL', () => {
it('should render pricing link with correct URL', () => {
render(<Pricing onCancel={onCancel} />)

View File

@@ -160,7 +160,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
isShow={isShowDeleteConfirm}
onClose={() => setIsShowDeleteConfirm(false)}
>
<div className="title-2xl-semi-bold mb-3 text-text-primary">{t('avatar.deleteTitle', { ns: 'common' })}</div>
<div className="mb-3 text-text-primary title-2xl-semi-bold">{t('avatar.deleteTitle', { ns: 'common' })}</div>
<p className="mb-8 text-text-secondary">{t('avatar.deleteDescription', { ns: 'common' })}</p>
<div className="flex w-full items-center justify-center gap-2">

View File

@@ -209,14 +209,14 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
</div>
{step === STEP.start && (
<>
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.title', { ns: 'common' })}</div>
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.title', { ns: 'common' })}</div>
<div className="space-y-0.5 pb-2 pt-1">
<div className="body-md-medium text-text-warning">{t('account.changeEmail.authTip', { ns: 'common' })}</div>
<div className="body-md-regular text-text-secondary">
<div className="text-text-warning body-md-medium">{t('account.changeEmail.authTip', { ns: 'common' })}</div>
<div className="text-text-secondary body-md-regular">
<Trans
i18nKey="account.changeEmail.content1"
ns="common"
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
values={{ email }}
/>
</div>
@@ -241,19 +241,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
)}
{step === STEP.verifyOrigin && (
<>
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.verifyEmail', { ns: 'common' })}</div>
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.verifyEmail', { ns: 'common' })}</div>
<div className="space-y-0.5 pb-2 pt-1">
<div className="body-md-regular text-text-secondary">
<div className="text-text-secondary body-md-regular">
<Trans
i18nKey="account.changeEmail.content2"
ns="common"
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
values={{ email }}
/>
</div>
</div>
<div className="pt-3">
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
<Input
className="!w-full"
placeholder={t('account.changeEmail.codePlaceholder', { ns: 'common' })}
@@ -278,25 +278,25 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
{t('operation.cancel', { ns: 'common' })}
</Button>
</div>
<div className="system-xs-regular mt-3 flex items-center gap-1 text-text-tertiary">
<div className="mt-3 flex items-center gap-1 text-text-tertiary system-xs-regular">
<span>{t('account.changeEmail.resendTip', { ns: 'common' })}</span>
{time > 0 && (
<span>{t('account.changeEmail.resendCount', { ns: 'common', count: time })}</span>
)}
{!time && (
<span onClick={sendCodeToOriginEmail} className="system-xs-medium cursor-pointer text-text-accent-secondary">{t('account.changeEmail.resend', { ns: 'common' })}</span>
<span onClick={sendCodeToOriginEmail} className="cursor-pointer text-text-accent-secondary system-xs-medium">{t('account.changeEmail.resend', { ns: 'common' })}</span>
)}
</div>
</>
)}
{step === STEP.newEmail && (
<>
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.newEmail', { ns: 'common' })}</div>
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.newEmail', { ns: 'common' })}</div>
<div className="space-y-0.5 pb-2 pt-1">
<div className="body-md-regular text-text-secondary">{t('account.changeEmail.content3', { ns: 'common' })}</div>
<div className="text-text-secondary body-md-regular">{t('account.changeEmail.content3', { ns: 'common' })}</div>
</div>
<div className="pt-3">
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.emailLabel', { ns: 'common' })}</div>
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.emailLabel', { ns: 'common' })}</div>
<Input
className="!w-full"
placeholder={t('account.changeEmail.emailPlaceholder', { ns: 'common' })}
@@ -305,10 +305,10 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
destructive={newEmailExited || unAvailableEmail}
/>
{newEmailExited && (
<div className="body-xs-regular mt-1 py-0.5 text-text-destructive">{t('account.changeEmail.existingEmail', { ns: 'common' })}</div>
<div className="mt-1 py-0.5 text-text-destructive body-xs-regular">{t('account.changeEmail.existingEmail', { ns: 'common' })}</div>
)}
{unAvailableEmail && (
<div className="body-xs-regular mt-1 py-0.5 text-text-destructive">{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}</div>
<div className="mt-1 py-0.5 text-text-destructive body-xs-regular">{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}</div>
)}
</div>
<div className="mt-3 space-y-2">
@@ -331,19 +331,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
)}
{step === STEP.verifyNew && (
<>
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.verifyNew', { ns: 'common' })}</div>
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.verifyNew', { ns: 'common' })}</div>
<div className="space-y-0.5 pb-2 pt-1">
<div className="body-md-regular text-text-secondary">
<div className="text-text-secondary body-md-regular">
<Trans
i18nKey="account.changeEmail.content4"
ns="common"
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
values={{ email: mail }}
/>
</div>
</div>
<div className="pt-3">
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
<Input
className="!w-full"
placeholder={t('account.changeEmail.codePlaceholder', { ns: 'common' })}
@@ -368,13 +368,13 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
{t('operation.cancel', { ns: 'common' })}
</Button>
</div>
<div className="system-xs-regular mt-3 flex items-center gap-1 text-text-tertiary">
<div className="mt-3 flex items-center gap-1 text-text-tertiary system-xs-regular">
<span>{t('account.changeEmail.resendTip', { ns: 'common' })}</span>
{time > 0 && (
<span>{t('account.changeEmail.resendCount', { ns: 'common', count: time })}</span>
)}
{!time && (
<span onClick={sendCodeToNewEmail} className="system-xs-medium cursor-pointer text-text-accent-secondary">{t('account.changeEmail.resend', { ns: 'common' })}</span>
<span onClick={sendCodeToNewEmail} className="cursor-pointer text-text-accent-secondary system-xs-medium">{t('account.changeEmail.resend', { ns: 'common' })}</span>
)}
</div>
</>

View File

@@ -138,7 +138,7 @@ export default function AccountPage() {
imageUrl={icon_url}
/>
</div>
<div className="system-sm-medium mt-[3px] text-text-secondary">{item.name}</div>
<div className="mt-[3px] text-text-secondary system-sm-medium">{item.name}</div>
</div>
)
}
@@ -146,12 +146,12 @@ export default function AccountPage() {
return (
<>
<div className="pb-3 pt-2">
<h4 className="title-2xl-semi-bold text-text-primary">{t('account.myAccount', { ns: 'common' })}</h4>
<h4 className="text-text-primary title-2xl-semi-bold">{t('account.myAccount', { ns: 'common' })}</h4>
</div>
<div className="mb-8 flex items-center rounded-xl bg-gradient-to-r from-background-gradient-bg-fill-chat-bg-2 to-background-gradient-bg-fill-chat-bg-1 p-6">
<AvatarWithEdit avatar={userProfile.avatar_url} name={userProfile.name} onSave={mutateUserProfile} size={64} />
<div className="ml-4">
<p className="system-xl-semibold text-text-primary">
<p className="text-text-primary system-xl-semibold">
{userProfile.name}
{isEducationAccount && (
<PremiumBadge size="s" color="blue" className="ml-1 !px-2">
@@ -160,16 +160,16 @@ export default function AccountPage() {
</PremiumBadge>
)}
</p>
<p className="system-xs-regular text-text-tertiary">{userProfile.email}</p>
<p className="text-text-tertiary system-xs-regular">{userProfile.email}</p>
</div>
</div>
<div className="mb-8">
<div className={titleClassName}>{t('account.name', { ns: 'common' })}</div>
<div className="mt-2 flex w-full items-center justify-between gap-2">
<div className="system-sm-regular flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled ">
<div className="flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled system-sm-regular">
<span className="pl-1">{userProfile.name}</span>
</div>
<div className="system-sm-medium cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text" onClick={handleEditName}>
<div className="cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text system-sm-medium" onClick={handleEditName}>
{t('operation.edit', { ns: 'common' })}
</div>
</div>
@@ -177,11 +177,11 @@ export default function AccountPage() {
<div className="mb-8">
<div className={titleClassName}>{t('account.email', { ns: 'common' })}</div>
<div className="mt-2 flex w-full items-center justify-between gap-2">
<div className="system-sm-regular flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled ">
<div className="flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled system-sm-regular">
<span className="pl-1">{userProfile.email}</span>
</div>
{systemFeatures.enable_change_email && (
<div className="system-sm-medium cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text" onClick={() => setShowUpdateEmail(true)}>
<div className="cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text system-sm-medium" onClick={() => setShowUpdateEmail(true)}>
{t('operation.change', { ns: 'common' })}
</div>
)}
@@ -191,8 +191,8 @@ export default function AccountPage() {
systemFeatures.enable_email_password_login && (
<div className="mb-8 flex justify-between gap-2">
<div>
<div className="system-sm-semibold mb-1 text-text-secondary">{t('account.password', { ns: 'common' })}</div>
<div className="body-xs-regular mb-2 text-text-tertiary">{t('account.passwordTip', { ns: 'common' })}</div>
<div className="mb-1 text-text-secondary system-sm-semibold">{t('account.password', { ns: 'common' })}</div>
<div className="mb-2 text-text-tertiary body-xs-regular">{t('account.passwordTip', { ns: 'common' })}</div>
</div>
<Button onClick={() => setEditPasswordModalVisible(true)}>{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</Button>
</div>
@@ -219,7 +219,7 @@ export default function AccountPage() {
onClose={() => setEditNameModalVisible(false)}
className="!w-[420px] !p-6"
>
<div className="title-2xl-semi-bold mb-6 text-text-primary">{t('account.editName', { ns: 'common' })}</div>
<div className="mb-6 text-text-primary title-2xl-semi-bold">{t('account.editName', { ns: 'common' })}</div>
<div className={titleClassName}>{t('account.name', { ns: 'common' })}</div>
<Input
className="mt-2"
@@ -249,7 +249,7 @@ export default function AccountPage() {
}}
className="!w-[420px] !p-6"
>
<div className="title-2xl-semi-bold mb-6 text-text-primary">{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</div>
<div className="mb-6 text-text-primary title-2xl-semi-bold">{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</div>
{userProfile.is_password_set && (
<>
<div className={titleClassName}>{t('account.currentPassword', { ns: 'common' })}</div>
@@ -272,7 +272,7 @@ export default function AccountPage() {
</div>
</>
)}
<div className="system-sm-semibold mt-8 text-text-secondary">
<div className="mt-8 text-text-secondary system-sm-semibold">
{userProfile.is_password_set ? t('account.newPassword', { ns: 'common' }) : t('account.password', { ns: 'common' })}
</div>
<div className="relative mt-2">
@@ -291,7 +291,7 @@ export default function AccountPage() {
</Button>
</div>
</div>
<div className="system-sm-semibold mt-8 text-text-secondary">{t('account.confirmPassword', { ns: 'common' })}</div>
<div className="mt-8 text-text-secondary system-sm-semibold">{t('account.confirmPassword', { ns: 'common' })}</div>
<div className="relative mt-2">
<Input
type={showConfirmPassword ? 'text' : 'password'}

View File

@@ -94,7 +94,7 @@ const CSVUploader: FC<Props> = ({
/>
<div ref={dropRef}>
{!file && (
<div className={cn('system-sm-regular flex h-20 items-center rounded-xl border border-dashed border-components-dropzone-border bg-components-dropzone-bg', dragging && 'border border-components-dropzone-border-accent bg-components-dropzone-bg-accent')}>
<div className={cn('flex h-20 items-center rounded-xl border border-dashed border-components-dropzone-border bg-components-dropzone-bg system-sm-regular', dragging && 'border border-components-dropzone-border-accent bg-components-dropzone-bg-accent')}>
<div className="flex w-full items-center justify-center space-x-2">
<CSVIcon className="shrink-0" />
<div className="text-text-tertiary">

View File

@@ -178,7 +178,7 @@ const Prompt: FC<ISimplePromptInput> = ({
{!noTitle && (
<div className="flex h-11 items-center justify-between pl-3 pr-2.5">
<div className="flex items-center space-x-1">
<div className="h2 system-sm-semibold-uppercase text-text-secondary">{mode !== AppModeEnum.COMPLETION ? t('chatSubTitle', { ns: 'appDebug' }) : t('completionSubTitle', { ns: 'appDebug' })}</div>
<div className="h2 text-text-secondary system-sm-semibold-uppercase">{mode !== AppModeEnum.COMPLETION ? t('chatSubTitle', { ns: 'appDebug' }) : t('completionSubTitle', { ns: 'appDebug' })}</div>
{!readonly && (
<Tooltip
popupContent={(

View File

@@ -96,7 +96,7 @@ const Editor: FC<Props> = ({
)}
</div>
</div>
<div className={cn(editorHeight, ' min-h-[102px] overflow-y-auto px-4 text-sm text-gray-700')}>
<div className={cn(editorHeight, 'min-h-[102px] overflow-y-auto px-4 text-sm text-gray-700')}>
<PromptEditor
className={editorHeight}
value={value}

View File

@@ -3,8 +3,10 @@ import type { FormValue } from '@/app/components/header/account-setting/model-pr
import type { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
import type { GenRes } from '@/service/debug'
import type { AppModeEnum, CompletionParams, Model, ModelModeType } from '@/types/app'
import { useSessionStorageState } from 'ahooks'
import useBoolean from 'ahooks/lib/useBoolean'
import {
useBoolean,
useSessionStorageState,
} from 'ahooks'
import * as React from 'react'
import { useCallback, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
@@ -224,7 +226,7 @@ export const GetCodeGeneratorResModal: FC<IGetCodeGeneratorResProps> = (
</div>
<div>
<div className="text-[0px]">
<div className="system-sm-semibold-uppercase mb-1.5 text-text-secondary">{t('codegen.instruction', { ns: 'appDebug' })}</div>
<div className="mb-1.5 text-text-secondary system-sm-semibold-uppercase">{t('codegen.instruction', { ns: 'appDebug' })}</div>
<InstructionEditor
editorKey={editorKey}
value={instruction}
@@ -248,7 +250,7 @@ export const GetCodeGeneratorResModal: FC<IGetCodeGeneratorResProps> = (
disabled={isLoading}
>
<Generator className="h-4 w-4" />
<span className="text-xs font-semibold ">{t('codegen.generate', { ns: 'appDebug' })}</span>
<span className="text-xs font-semibold">{t('codegen.generate', { ns: 'appDebug' })}</span>
</Button>
</div>
</div>

View File

@@ -210,7 +210,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
<div className="overflow-y-auto border-b border-divider-regular p-6 pb-[68px] pt-5">
<div className={cn(rowClass, 'items-center')}>
<div className={labelClass}>
<div className="system-sm-semibold text-text-secondary">{t('form.name', { ns: 'datasetSettings' })}</div>
<div className="text-text-secondary system-sm-semibold">{t('form.name', { ns: 'datasetSettings' })}</div>
</div>
<Input
value={localeCurrentDataset.name}
@@ -221,7 +221,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
</div>
<div className={cn(rowClass)}>
<div className={labelClass}>
<div className="system-sm-semibold text-text-secondary">{t('form.desc', { ns: 'datasetSettings' })}</div>
<div className="text-text-secondary system-sm-semibold">{t('form.desc', { ns: 'datasetSettings' })}</div>
</div>
<div className="w-full">
<Textarea
@@ -234,7 +234,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
</div>
<div className={rowClass}>
<div className={labelClass}>
<div className="system-sm-semibold text-text-secondary">{t('form.permissions', { ns: 'datasetSettings' })}</div>
<div className="text-text-secondary system-sm-semibold">{t('form.permissions', { ns: 'datasetSettings' })}</div>
</div>
<div className="w-full">
<PermissionSelector
@@ -250,7 +250,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
{!!(currentDataset && currentDataset.indexing_technique) && (
<div className={cn(rowClass)}>
<div className={labelClass}>
<div className="system-sm-semibold text-text-secondary">{t('form.indexMethod', { ns: 'datasetSettings' })}</div>
<div className="text-text-secondary system-sm-semibold">{t('form.indexMethod', { ns: 'datasetSettings' })}</div>
</div>
<div className="grow">
<IndexMethod
@@ -267,7 +267,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
{indexMethod === IndexingType.QUALIFIED && (
<div className={cn(rowClass)}>
<div className={labelClass}>
<div className="system-sm-semibold text-text-secondary">{t('form.embeddingModel', { ns: 'datasetSettings' })}</div>
<div className="text-text-secondary system-sm-semibold">{t('form.embeddingModel', { ns: 'datasetSettings' })}</div>
</div>
<div className="w-full">
<div className="h-8 w-full rounded-lg bg-components-input-bg-normal opacity-60">

View File

@@ -394,7 +394,7 @@ const Debug: FC<IDebug> = ({
<>
<div className="shrink-0">
<div className="flex items-center justify-between px-4 pb-2 pt-3">
<div className="system-xl-semibold text-text-primary">{t('inputs.title', { ns: 'appDebug' })}</div>
<div className="text-text-primary system-xl-semibold">{t('inputs.title', { ns: 'appDebug' })}</div>
<div className="flex items-center">
{
debugWithMultipleModel
@@ -539,7 +539,7 @@ const Debug: FC<IDebug> = ({
{!completionRes && !isResponding && (
<div className="flex grow flex-col items-center justify-center gap-2">
<RiSparklingFill className="h-12 w-12 text-text-empty-state-icon" />
<div className="system-sm-regular text-text-quaternary">{t('noResult', { ns: 'appDebug' })}</div>
<div className="text-text-quaternary system-sm-regular">{t('noResult', { ns: 'appDebug' })}</div>
</div>
)}
</>

View File

@@ -966,10 +966,10 @@ const Configuration: FC = () => {
<div className="bg-default-subtle absolute left-0 top-0 h-14 w-full">
<div className="flex h-14 items-center justify-between px-6">
<div className="flex items-center">
<div className="system-xl-semibold text-text-primary">{t('orchestrate', { ns: 'appDebug' })}</div>
<div className="text-text-primary system-xl-semibold">{t('orchestrate', { ns: 'appDebug' })}</div>
<div className="flex h-[14px] items-center space-x-1 text-xs">
{isAdvancedMode && (
<div className="system-xs-medium-uppercase ml-1 flex h-5 items-center rounded-md border border-components-button-secondary-border px-1.5 uppercase text-text-tertiary">{t('promptMode.advanced', { ns: 'appDebug' })}</div>
<div className="ml-1 flex h-5 items-center rounded-md border border-components-button-secondary-border px-1.5 uppercase text-text-tertiary system-xs-medium-uppercase">{t('promptMode.advanced', { ns: 'appDebug' })}</div>
)}
</div>
</div>
@@ -1030,8 +1030,8 @@ const Configuration: FC = () => {
<Config />
</div>
{!isMobile && (
<div className="relative flex h-full w-1/2 grow flex-col overflow-y-auto " style={{ borderColor: 'rgba(0, 0, 0, 0.02)' }}>
<div className="flex grow flex-col rounded-tl-2xl border-l-[0.5px] border-t-[0.5px] border-components-panel-border bg-chatbot-bg ">
<div className="relative flex h-full w-1/2 grow flex-col overflow-y-auto" style={{ borderColor: 'rgba(0, 0, 0, 0.02)' }}>
<div className="flex grow flex-col rounded-tl-2xl border-l-[0.5px] border-t-[0.5px] border-components-panel-border bg-chatbot-bg">
<Debug
isAPIKeySet={isAPIKeySet}
onSetting={() => setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.PROVIDER })}

View File

@@ -217,7 +217,7 @@ const ExternalDataToolModal: FC<ExternalDataToolModalProps> = ({
<AppIcon
size="large"
onClick={() => { setShowEmojiPicker(true) }}
className="!h-9 !w-9 cursor-pointer rounded-lg border-[0.5px] border-components-panel-border "
className="!h-9 !w-9 cursor-pointer rounded-lg border-[0.5px] border-components-panel-border"
icon={localeData.icon}
background={localeData.icon_background}
/>

View File

@@ -117,10 +117,10 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
<div className="px-10">
<div className="h-6 w-full 2xl:h-[139px]" />
<div className="pb-6 pt-1">
<span className="title-2xl-semi-bold text-text-primary">{t('newApp.startFromBlank', { ns: 'app' })}</span>
<span className="text-text-primary title-2xl-semi-bold">{t('newApp.startFromBlank', { ns: 'app' })}</span>
</div>
<div className="mb-2 leading-6">
<span className="system-sm-semibold text-text-secondary">{t('newApp.chooseAppType', { ns: 'app' })}</span>
<span className="text-text-secondary system-sm-semibold">{t('newApp.chooseAppType', { ns: 'app' })}</span>
</div>
<div className="flex w-[660px] flex-col gap-4">
<div>
@@ -160,7 +160,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
className="flex cursor-pointer items-center border-0 bg-transparent p-0"
onClick={() => setIsAppTypeExpanded(!isAppTypeExpanded)}
>
<span className="system-2xs-medium-uppercase text-text-tertiary">{t('newApp.forBeginners', { ns: 'app' })}</span>
<span className="text-text-tertiary system-2xs-medium-uppercase">{t('newApp.forBeginners', { ns: 'app' })}</span>
<RiArrowRightSLine className={`ml-1 h-4 w-4 text-text-tertiary transition-transform ${isAppTypeExpanded ? 'rotate-90' : ''}`} />
</button>
</div>
@@ -212,7 +212,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
<div className="flex items-center space-x-3">
<div className="flex-1">
<div className="mb-1 flex h-6 items-center">
<label className="system-sm-semibold text-text-secondary">{t('newApp.captionName', { ns: 'app' })}</label>
<label className="text-text-secondary system-sm-semibold">{t('newApp.captionName', { ns: 'app' })}</label>
</div>
<Input
value={name}
@@ -243,8 +243,8 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
</div>
<div>
<div className="mb-1 flex h-6 items-center">
<label className="system-sm-semibold text-text-secondary">{t('newApp.captionDescription', { ns: 'app' })}</label>
<span className="system-xs-regular ml-1 text-text-tertiary">
<label className="text-text-secondary system-sm-semibold">{t('newApp.captionDescription', { ns: 'app' })}</label>
<span className="ml-1 text-text-tertiary system-xs-regular">
(
{t('newApp.optional', { ns: 'app' })}
)
@@ -260,7 +260,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
</div>
{isAppsFull && <AppsFull className="mt-4" loc="app-create" />}
<div className="flex items-center justify-between pb-10 pt-5">
<div className="system-xs-regular flex cursor-pointer items-center gap-1 text-text-tertiary" onClick={onCreateFromTemplate}>
<div className="flex cursor-pointer items-center gap-1 text-text-tertiary system-xs-regular" onClick={onCreateFromTemplate}>
<span>{t('newApp.noIdeaTip', { ns: 'app' })}</span>
<div className="p-[1px]">
<RiArrowRightLine className="h-3.5 w-3.5" />
@@ -334,8 +334,8 @@ function AppTypeCard({ icon, title, description, active, onClick }: AppTypeCardP
onClick={onClick}
>
{icon}
<div className="system-sm-semibold mb-0.5 mt-2 text-text-secondary">{title}</div>
<div className="system-xs-regular line-clamp-2 text-text-tertiary" title={description}>{description}</div>
<div className="mb-0.5 mt-2 text-text-secondary system-sm-semibold">{title}</div>
<div className="line-clamp-2 text-text-tertiary system-xs-regular" title={description}>{description}</div>
</div>
)
}
@@ -367,8 +367,8 @@ function AppPreview({ mode }: { mode: AppModeEnum }) {
const previewInfo = modeToPreviewInfoMap[mode]
return (
<div className="px-8 py-4">
<h4 className="system-sm-semibold-uppercase text-text-secondary">{previewInfo.title}</h4>
<div className="system-xs-regular mt-1 min-h-8 max-w-96 text-text-tertiary">
<h4 className="text-text-secondary system-sm-semibold-uppercase">{previewInfo.title}</h4>
<div className="mt-1 min-h-8 max-w-96 text-text-tertiary system-xs-regular">
<span>{previewInfo.description}</span>
</div>
</div>

View File

@@ -232,7 +232,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
isShow={show}
onClose={noop}
>
<div className="title-2xl-semi-bold flex items-center justify-between pb-3 pl-6 pr-5 pt-6 text-text-primary">
<div className="flex items-center justify-between pb-3 pl-6 pr-5 pt-6 text-text-primary title-2xl-semi-bold">
{t('importFromDSL', { ns: 'app' })}
<div
className="flex h-8 w-8 cursor-pointer items-center"
@@ -241,7 +241,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
<RiCloseLine className="h-5 w-5 text-text-tertiary" />
</div>
</div>
<div className="system-md-semibold flex h-9 items-center space-x-6 border-b border-divider-subtle px-6 text-text-tertiary">
<div className="flex h-9 items-center space-x-6 border-b border-divider-subtle px-6 text-text-tertiary system-md-semibold">
{
tabs.map(tab => (
<div
@@ -275,7 +275,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
{
currentTab === CreateFromDSLModalTab.FROM_URL && (
<div>
<div className="system-md-semibold mb-1 text-text-secondary">DSL URL</div>
<div className="mb-1 text-text-secondary system-md-semibold">DSL URL</div>
<Input
placeholder={t('importFromDSLUrlPlaceholder', { ns: 'app' }) || ''}
value={dslUrlValue}
@@ -309,8 +309,8 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
className="w-[480px]"
>
<div className="flex flex-col items-start gap-2 self-stretch pb-4">
<div className="title-2xl-semi-bold text-text-primary">{t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}</div>
<div className="system-md-regular flex grow flex-col text-text-secondary">
<div className="text-text-primary title-2xl-semi-bold">{t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}</div>
<div className="flex grow flex-col text-text-secondary system-md-regular">
<div>{t('newApp.appCreateDSLErrorPart1', { ns: 'app' })}</div>
<div>{t('newApp.appCreateDSLErrorPart2', { ns: 'app' })}</div>
<br />

View File

@@ -121,7 +121,7 @@ const Uploader: FC<Props> = ({
</div>
)}
{file && (
<div className={cn('group flex items-center rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-on-panel-item-bg shadow-xs', ' hover:bg-components-panel-on-panel-item-bg-hover')}>
<div className={cn('group flex items-center rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-on-panel-item-bg shadow-xs', 'hover:bg-components-panel-on-panel-item-bg-hover')}>
<div className="flex items-center justify-center p-3">
<YamlIcon className="h-6 w-6 shrink-0" />
</div>

View File

@@ -96,7 +96,7 @@ const statusTdRender = (statusCount: StatusCount) => {
if (statusCount.paused > 0) {
return (
<div className="system-xs-semibold-uppercase inline-flex items-center gap-1">
<div className="inline-flex items-center gap-1 system-xs-semibold-uppercase">
<Indicator color="yellow" />
<span className="text-util-colors-warning-warning-600">Pending</span>
</div>
@@ -104,7 +104,7 @@ const statusTdRender = (statusCount: StatusCount) => {
}
else if (statusCount.partial_success + statusCount.failed === 0) {
return (
<div className="system-xs-semibold-uppercase inline-flex items-center gap-1">
<div className="inline-flex items-center gap-1 system-xs-semibold-uppercase">
<Indicator color="green" />
<span className="text-util-colors-green-green-600">Success</span>
</div>
@@ -112,7 +112,7 @@ const statusTdRender = (statusCount: StatusCount) => {
}
else if (statusCount.failed === 0) {
return (
<div className="system-xs-semibold-uppercase inline-flex items-center gap-1">
<div className="inline-flex items-center gap-1 system-xs-semibold-uppercase">
<Indicator color="green" />
<span className="text-util-colors-green-green-600">Partial Success</span>
</div>
@@ -120,7 +120,7 @@ const statusTdRender = (statusCount: StatusCount) => {
}
else {
return (
<div className="system-xs-semibold-uppercase inline-flex items-center gap-1">
<div className="inline-flex items-center gap-1 system-xs-semibold-uppercase">
<Indicator color="red" />
<span className="text-util-colors-red-red-600">
{statusCount.failed}
@@ -562,9 +562,9 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
{/* Panel Header */}
<div className="flex shrink-0 items-center gap-2 rounded-t-xl bg-components-panel-bg pb-2 pl-4 pr-3 pt-3">
<div className="shrink-0">
<div className="system-xs-semibold-uppercase mb-0.5 text-text-primary">{isChatMode ? t('detail.conversationId', { ns: 'appLog' }) : t('detail.time', { ns: 'appLog' })}</div>
<div className="mb-0.5 text-text-primary system-xs-semibold-uppercase">{isChatMode ? t('detail.conversationId', { ns: 'appLog' }) : t('detail.time', { ns: 'appLog' })}</div>
{isChatMode && (
<div className="system-2xs-regular-uppercase flex items-center text-text-secondary">
<div className="flex items-center text-text-secondary system-2xs-regular-uppercase">
<Tooltip
popupContent={detail.id}
>
@@ -574,7 +574,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
</div>
)}
{!isChatMode && (
<div className="system-2xs-regular-uppercase text-text-secondary">{formatTime(detail.created_at, t('dateTimeFormat', { ns: 'appLog' }) as string)}</div>
<div className="text-text-secondary system-2xs-regular-uppercase">{formatTime(detail.created_at, t('dateTimeFormat', { ns: 'appLog' }) as string)}</div>
)}
</div>
<div className="flex grow flex-wrap items-center justify-end gap-y-1">
@@ -600,7 +600,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
? (
<div className="px-6 py-4">
<div className="flex h-[18px] items-center space-x-3">
<div className="system-xs-semibold-uppercase text-text-tertiary">{t('table.header.output', { ns: 'appLog' })}</div>
<div className="text-text-tertiary system-xs-semibold-uppercase">{t('table.header.output', { ns: 'appLog' })}</div>
<div
className="h-px grow"
style={{
@@ -692,7 +692,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
</div>
{hasMore && (
<div className="py-3 text-center">
<div className="system-xs-regular text-text-tertiary">
<div className="text-text-tertiary system-xs-regular">
{t('detail.loading', { ns: 'appLog' })}
...
</div>
@@ -950,7 +950,7 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
)}
popupClassName={(isHighlight && !isChatMode) ? '' : '!hidden'}
>
<div className={cn(isEmptyStyle ? 'text-text-quaternary' : 'text-text-secondary', !isHighlight ? '' : 'bg-orange-100', 'system-sm-regular overflow-hidden text-ellipsis whitespace-nowrap')}>
<div className={cn(isEmptyStyle ? 'text-text-quaternary' : 'text-text-secondary', !isHighlight ? '' : 'bg-orange-100', 'overflow-hidden text-ellipsis whitespace-nowrap system-sm-regular')}>
{value || '-'}
</div>
</Tooltip>
@@ -963,7 +963,7 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
return (
<div className="relative mt-2 grow overflow-x-auto">
<table className={cn('w-full min-w-[440px] border-collapse border-0')}>
<thead className="system-xs-medium-uppercase text-text-tertiary">
<thead className="text-text-tertiary system-xs-medium-uppercase">
<tr>
<td className="w-5 whitespace-nowrap rounded-l-lg bg-background-section-burn pl-2 pr-1"></td>
<td className="whitespace-nowrap bg-background-section-burn py-1.5 pl-3">{isChatMode ? t('table.header.summary', { ns: 'appLog' }) : t('table.header.input', { ns: 'appLog' })}</td>
@@ -976,7 +976,7 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
<td className="whitespace-nowrap rounded-r-lg bg-background-section-burn py-1.5 pl-3">{t('table.header.time', { ns: 'appLog' })}</td>
</tr>
</thead>
<tbody className="system-sm-regular text-text-secondary">
<tbody className="text-text-secondary system-sm-regular">
{logs.data.map((log: any) => {
const endUser = log.from_end_user_session_id || log.from_account_name
const leftValue = get(log, isChatMode ? 'name' : 'message.inputs.query') || (!isChatMode ? (get(log, 'message.query') || get(log, 'message.inputs.default_input')) : '') || ''

View File

@@ -231,12 +231,12 @@ const SettingsModal: FC<ISettingsModalProps> = ({
{/* header */}
<div className="pb-3 pl-6 pr-5 pt-5">
<div className="flex items-center gap-1">
<div className="title-2xl-semi-bold grow text-text-primary">{t(`${prefixSettings}.title`, { ns: 'appOverview' })}</div>
<div className="grow text-text-primary title-2xl-semi-bold">{t(`${prefixSettings}.title`, { ns: 'appOverview' })}</div>
<ActionButton className="shrink-0" onClick={onHide}>
<RiCloseLine className="h-4 w-4" />
</ActionButton>
</div>
<div className="system-xs-regular mt-0.5 text-text-tertiary">
<div className="mt-0.5 text-text-tertiary system-xs-regular">
<span>{t(`${prefixSettings}.modalTip`, { ns: 'appOverview' })}</span>
</div>
</div>
@@ -245,7 +245,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({
{/* name & icon */}
<div className="flex gap-4">
<div className="grow">
<div className={cn('system-sm-semibold mb-1 py-1 text-text-secondary')}>{t(`${prefixSettings}.webName`, { ns: 'appOverview' })}</div>
<div className={cn('mb-1 py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.webName`, { ns: 'appOverview' })}</div>
<Input
className="w-full"
value={inputInfo.title}
@@ -265,32 +265,32 @@ const SettingsModal: FC<ISettingsModalProps> = ({
</div>
{/* description */}
<div className="relative">
<div className={cn('system-sm-semibold py-1 text-text-secondary')}>{t(`${prefixSettings}.webDesc`, { ns: 'appOverview' })}</div>
<div className={cn('py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.webDesc`, { ns: 'appOverview' })}</div>
<Textarea
className="mt-1"
value={inputInfo.desc}
onChange={e => onDesChange(e.target.value)}
placeholder={t(`${prefixSettings}.webDescPlaceholder`, { ns: 'appOverview' }) as string}
/>
<p className={cn('body-xs-regular pb-0.5 text-text-tertiary')}>{t(`${prefixSettings}.webDescTip`, { ns: 'appOverview' })}</p>
<p className={cn('pb-0.5 text-text-tertiary body-xs-regular')}>{t(`${prefixSettings}.webDescTip`, { ns: 'appOverview' })}</p>
</div>
<Divider className="my-0 h-px" />
{/* answer icon */}
{isChat && (
<div className="w-full">
<div className="flex items-center justify-between">
<div className={cn('system-sm-semibold py-1 text-text-secondary')}>{t('answerIcon.title', { ns: 'app' })}</div>
<div className={cn('py-1 text-text-secondary system-sm-semibold')}>{t('answerIcon.title', { ns: 'app' })}</div>
<Switch
value={inputInfo.use_icon_as_answer_icon}
onChange={v => setInputInfo({ ...inputInfo, use_icon_as_answer_icon: v })}
/>
</div>
<p className="body-xs-regular pb-0.5 text-text-tertiary">{t('answerIcon.description', { ns: 'app' })}</p>
<p className="pb-0.5 text-text-tertiary body-xs-regular">{t('answerIcon.description', { ns: 'app' })}</p>
</div>
)}
{/* language */}
<div className="flex items-center">
<div className={cn('system-sm-semibold grow py-1 text-text-secondary')}>{t(`${prefixSettings}.language`, { ns: 'appOverview' })}</div>
<div className={cn('grow py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.language`, { ns: 'appOverview' })}</div>
<SimpleSelect
wrapperClassName="w-[200px]"
items={languages.filter(item => item.supported)}
@@ -303,8 +303,8 @@ const SettingsModal: FC<ISettingsModalProps> = ({
{isChat && (
<div className="flex items-center">
<div className="grow">
<div className={cn('system-sm-semibold py-1 text-text-secondary')}>{t(`${prefixSettings}.chatColorTheme`, { ns: 'appOverview' })}</div>
<div className="body-xs-regular pb-0.5 text-text-tertiary">{t(`${prefixSettings}.chatColorThemeDesc`, { ns: 'appOverview' })}</div>
<div className={cn('py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.chatColorTheme`, { ns: 'appOverview' })}</div>
<div className="pb-0.5 text-text-tertiary body-xs-regular">{t(`${prefixSettings}.chatColorThemeDesc`, { ns: 'appOverview' })}</div>
</div>
<div className="shrink-0">
<Input
@@ -314,7 +314,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({
placeholder="E.g #A020F0"
/>
<div className="flex items-center justify-between">
<p className={cn('body-xs-regular text-text-tertiary')}>{t(`${prefixSettings}.chatColorThemeInverted`, { ns: 'appOverview' })}</p>
<p className={cn('text-text-tertiary body-xs-regular')}>{t(`${prefixSettings}.chatColorThemeInverted`, { ns: 'appOverview' })}</p>
<Switch value={inputInfo.chatColorThemeInverted} onChange={v => setInputInfo({ ...inputInfo, chatColorThemeInverted: v })}></Switch>
</div>
</div>
@@ -323,22 +323,22 @@ const SettingsModal: FC<ISettingsModalProps> = ({
{/* workflow detail */}
<div className="w-full">
<div className="flex items-center justify-between">
<div className={cn('system-sm-semibold py-1 text-text-secondary')}>{t(`${prefixSettings}.workflow.subTitle`, { ns: 'appOverview' })}</div>
<div className={cn('py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.workflow.subTitle`, { ns: 'appOverview' })}</div>
<Switch
disabled={!(appInfo.mode === AppModeEnum.WORKFLOW || appInfo.mode === AppModeEnum.ADVANCED_CHAT)}
value={inputInfo.show_workflow_steps}
onChange={v => setInputInfo({ ...inputInfo, show_workflow_steps: v })}
/>
</div>
<p className="body-xs-regular pb-0.5 text-text-tertiary">{t(`${prefixSettings}.workflow.showDesc`, { ns: 'appOverview' })}</p>
<p className="pb-0.5 text-text-tertiary body-xs-regular">{t(`${prefixSettings}.workflow.showDesc`, { ns: 'appOverview' })}</p>
</div>
{/* more settings switch */}
<Divider className="my-0 h-px" />
{!isShowMore && (
<div className="flex cursor-pointer items-center" onClick={() => setIsShowMore(true)}>
<div className="grow">
<div className={cn('system-sm-semibold py-1 text-text-secondary')}>{t(`${prefixSettings}.more.entry`, { ns: 'appOverview' })}</div>
<p className={cn('body-xs-regular pb-0.5 text-text-tertiary')}>
<div className={cn('py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.more.entry`, { ns: 'appOverview' })}</div>
<p className={cn('pb-0.5 text-text-tertiary body-xs-regular')}>
{t(`${prefixSettings}.more.copyRightPlaceholder`, { ns: 'appOverview' })}
{' '}
&
@@ -356,7 +356,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({
<div className="w-full">
<div className="flex items-center">
<div className="flex grow items-center">
<div className={cn('system-sm-semibold mr-1 py-1 text-text-secondary')}>{t(`${prefixSettings}.more.copyright`, { ns: 'appOverview' })}</div>
<div className={cn('mr-1 py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.more.copyright`, { ns: 'appOverview' })}</div>
{/* upgrade button */}
{enableBilling && isFreePlan && (
<div className="h-[18px] select-none">
@@ -385,7 +385,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({
/>
</Tooltip>
</div>
<p className="body-xs-regular pb-0.5 text-text-tertiary">{t(`${prefixSettings}.more.copyrightTip`, { ns: 'appOverview' })}</p>
<p className="pb-0.5 text-text-tertiary body-xs-regular">{t(`${prefixSettings}.more.copyrightTip`, { ns: 'appOverview' })}</p>
{inputInfo.copyrightSwitchValue && (
<Input
className="mt-2 h-10"
@@ -397,8 +397,8 @@ const SettingsModal: FC<ISettingsModalProps> = ({
</div>
{/* privacy policy */}
<div className="w-full">
<div className={cn('system-sm-semibold py-1 text-text-secondary')}>{t(`${prefixSettings}.more.privacyPolicy`, { ns: 'appOverview' })}</div>
<p className={cn('body-xs-regular pb-0.5 text-text-tertiary')}>
<div className={cn('py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.more.privacyPolicy`, { ns: 'appOverview' })}</div>
<p className={cn('pb-0.5 text-text-tertiary body-xs-regular')}>
<Trans
i18nKey={`${prefixSettings}.more.privacyPolicyTip`}
ns="appOverview"
@@ -414,8 +414,8 @@ const SettingsModal: FC<ISettingsModalProps> = ({
</div>
{/* custom disclaimer */}
<div className="w-full">
<div className={cn('system-sm-semibold py-1 text-text-secondary')}>{t(`${prefixSettings}.more.customDisclaimer`, { ns: 'appOverview' })}</div>
<p className={cn('body-xs-regular pb-0.5 text-text-tertiary')}>{t(`${prefixSettings}.more.customDisclaimerTip`, { ns: 'appOverview' })}</p>
<div className={cn('py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.more.customDisclaimer`, { ns: 'appOverview' })}</div>
<p className={cn('pb-0.5 text-text-tertiary body-xs-regular')}>{t(`${prefixSettings}.more.customDisclaimerTip`, { ns: 'appOverview' })}</p>
<Textarea
className="mt-1"
value={inputInfo.customDisclaimer}

View File

@@ -200,14 +200,14 @@ const ChatInputArea = ({
<div className="relative flex w-full grow items-center">
<div
ref={textValueRef}
className="body-lg-regular pointer-events-none invisible absolute h-auto w-auto whitespace-pre p-1 leading-6"
className="pointer-events-none invisible absolute h-auto w-auto whitespace-pre p-1 leading-6 body-lg-regular"
>
{query}
</div>
<Textarea
ref={ref => textareaRef.current = ref as any}
className={cn(
'body-lg-regular w-full resize-none bg-transparent p-1 leading-6 text-text-primary outline-none',
'w-full resize-none bg-transparent p-1 leading-6 text-text-primary outline-none body-lg-regular',
)}
placeholder={decode(t(readonly ? 'chat.inputDisabledPlaceholder' : 'chat.inputPlaceholder', { ns: 'common', botName }) || '')}
autoFocus

View File

@@ -0,0 +1,4 @@
<svg width="10" height="10" viewBox="0 0 10 10" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M2 5C2 3.44487 2.58482 1.98537 3.54004 1.04932C2.17681 1.34034 1 2.90001 1 5C1 7.09996 2.17685 8.65912 3.54004 8.9502C2.58496 8.01413 2 6.55501 2 5ZM3 5C3 7.33338 4.4528 9 6 9C7.5472 9 9 7.33338 9 5C9 2.66664 7.5472 1 6 1C4.4528 1 3 2.66664 3 5ZM10 5C10 7.63722 8.3188 10 6 10H4C1.6812 10 0 7.63722 0 5C0 2.3628 1.6812 0 4 0H6C8.3188 0 10 2.3628 10 5Z" fill="#676F83"/>
<path d="M6.71519 4.09259L6.45385 3.18667C6.42141 3.07421 6.34037 3 6.25 3C6.15963 3 6.07859 3.07421 6.04615 3.18667L5.78481 4.09259C5.74675 4.22464 5.66849 4.32899 5.56945 4.37978L4.88999 4.7282C4.80565 4.77146 4.75 4.87951 4.75 5C4.75 5.12049 4.80565 5.22854 4.88999 5.2718L5.56945 5.62022C5.66849 5.67101 5.74675 5.77536 5.78481 5.90741L6.04615 6.81333C6.07859 6.92579 6.15963 7 6.25 7C6.34037 7 6.42141 6.92579 6.45385 6.81333L6.71519 5.90741C6.75325 5.77536 6.83151 5.67101 6.93055 5.62022L7.61001 5.2718C7.69435 5.22854 7.75 5.12049 7.75 5C7.75 4.87951 7.69435 4.77146 7.61001 4.7282L6.93055 4.37978C6.83151 4.32899 6.75325 4.22464 6.71519 4.09259Z" fill="#676F83"/>
</svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@@ -1,5 +1,5 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="arrow-down-round-fill">
<path id="Vector" d="M6.02913 6.23572C5.08582 6.23572 4.56482 7.33027 5.15967 8.06239L7.13093 10.4885C7.57922 11.0403 8.42149 11.0403 8.86986 10.4885L10.8411 8.06239C11.4359 7.33027 10.9149 6.23572 9.97158 6.23572H6.02913Z" fill="#101828"/>
<path id="Vector" d="M6.02913 6.23572C5.08582 6.23572 4.56482 7.33027 5.15967 8.06239L7.13093 10.4885C7.57922 11.0403 8.42149 11.0403 8.86986 10.4885L10.8411 8.06239C11.4359 7.33027 10.9149 6.23572 9.97158 6.23572H6.02913Z" fill="currentColor"/>
</g>
</svg>

Before

Width:  |  Height:  |  Size: 380 B

After

Width:  |  Height:  |  Size: 385 B

View File

@@ -1,3 +1,3 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path id="Solid" fill-rule="evenodd" clip-rule="evenodd" d="M8.00008 0.666016C3.94999 0.666016 0.666748 3.94926 0.666748 7.99935C0.666748 12.0494 3.94999 15.3327 8.00008 15.3327C12.0502 15.3327 15.3334 12.0494 15.3334 7.99935C15.3334 3.94926 12.0502 0.666016 8.00008 0.666016ZM10.4715 5.52794C10.7318 5.78829 10.7318 6.2104 10.4715 6.47075L8.94289 7.99935L10.4715 9.52794C10.7318 9.78829 10.7318 10.2104 10.4715 10.4708C10.2111 10.7311 9.78903 10.7311 9.52868 10.4708L8.00008 8.94216L6.47149 10.4708C6.21114 10.7311 5.78903 10.7311 5.52868 10.4708C5.26833 10.2104 5.26833 9.78829 5.52868 9.52794L7.05727 7.99935L5.52868 6.47075C5.26833 6.2104 5.26833 5.78829 5.52868 5.52794C5.78903 5.26759 6.21114 5.26759 6.47149 5.52794L8.00008 7.05654L9.52868 5.52794C9.78903 5.26759 10.2111 5.26759 10.4715 5.52794Z" fill="#98A2B3"/>
<path id="Solid" fill-rule="evenodd" clip-rule="evenodd" d="M8.00008 0.666016C3.94999 0.666016 0.666748 3.94926 0.666748 7.99935C0.666748 12.0494 3.94999 15.3327 8.00008 15.3327C12.0502 15.3327 15.3334 12.0494 15.3334 7.99935C15.3334 3.94926 12.0502 0.666016 8.00008 0.666016ZM10.4715 5.52794C10.7318 5.78829 10.7318 6.2104 10.4715 6.47075L8.94289 7.99935L10.4715 9.52794C10.7318 9.78829 10.7318 10.2104 10.4715 10.4708C10.2111 10.7311 9.78903 10.7311 9.52868 10.4708L8.00008 8.94216L6.47149 10.4708C6.21114 10.7311 5.78903 10.7311 5.52868 10.4708C5.26833 10.2104 5.26833 9.78829 5.52868 9.52794L7.05727 7.99935L5.52868 6.47075C5.26833 6.2104 5.26833 5.78829 5.52868 5.52794C5.78903 5.26759 6.21114 5.26759 6.47149 5.52794L8.00008 7.05654L9.52868 5.52794C9.78903 5.26759 10.2111 5.26759 10.4715 5.52794Z" fill="currentColor"/>
</svg>

Before

Width:  |  Height:  |  Size: 925 B

After

Width:  |  Height:  |  Size: 930 B

View File

@@ -0,0 +1,35 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "10",
"height": "10",
"viewBox": "0 0 10 10",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"d": "M2 5C2 3.44487 2.58482 1.98537 3.54004 1.04932C2.17681 1.34034 1 2.90001 1 5C1 7.09996 2.17685 8.65912 3.54004 8.9502C2.58496 8.01413 2 6.55501 2 5ZM3 5C3 7.33338 4.4528 9 6 9C7.5472 9 9 7.33338 9 5C9 2.66664 7.5472 1 6 1C4.4528 1 3 2.66664 3 5ZM10 5C10 7.63722 8.3188 10 6 10H4C1.6812 10 0 7.63722 0 5C0 2.3628 1.6812 0 4 0H6C8.3188 0 10 2.3628 10 5Z",
"fill": "currentColor"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M6.71519 4.09259L6.45385 3.18667C6.42141 3.07421 6.34037 3 6.25 3C6.15963 3 6.07859 3.07421 6.04615 3.18667L5.78481 4.09259C5.74675 4.22464 5.66849 4.32899 5.56945 4.37978L4.88999 4.7282C4.80565 4.77146 4.75 4.87951 4.75 5C4.75 5.12049 4.80565 5.22854 4.88999 5.2718L5.56945 5.62022C5.66849 5.67101 5.74675 5.77536 5.78481 5.90741L6.04615 6.81333C6.07859 6.92579 6.15963 7 6.25 7C6.34037 7 6.42141 6.92579 6.45385 6.81333L6.71519 5.90741C6.75325 5.77536 6.83151 5.67101 6.93055 5.62022L7.61001 5.2718C7.69435 5.22854 7.75 5.12049 7.75 5C7.75 4.87951 7.69435 4.77146 7.61001 4.7282L6.93055 4.37978C6.83151 4.32899 6.75325 4.22464 6.71519 4.09259Z",
"fill": "currentColor"
},
"children": []
}
]
},
"name": "CreditsCoin"
}

View File

@@ -0,0 +1,20 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import type { IconData } from '@/app/components/base/icons/IconBase'
import * as React from 'react'
import IconBase from '@/app/components/base/icons/IconBase'
import data from './CreditsCoin.json'
const Icon = (
{
ref,
...props
}: React.SVGProps<SVGSVGElement> & {
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>
},
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'CreditsCoin'
export default Icon

View File

@@ -1,5 +1,6 @@
export { default as Balance } from './Balance'
export { default as CoinsStacked01 } from './CoinsStacked01'
export { default as CreditsCoin } from './CreditsCoin'
export { default as GoldCoin } from './GoldCoin'
export { default as ReceiptList } from './ReceiptList'
export { default as Tag01 } from './Tag01'

View File

@@ -14,7 +14,7 @@ import {
} from 'react'
import { useTranslation } from 'react-i18next'
import { useReactFlow, useStoreApi } from 'reactflow'
import Tooltip from '@/app/components/base/tooltip'
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
import { isConversationVar, isENV, isGlobalVar, isRagVariableVar, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils'
import VarFullPathPanel from '@/app/components/workflow/nodes/_base/components/variable/var-full-path-panel'
import {
@@ -28,6 +28,7 @@ import {
UPDATE_WORKFLOW_NODES_MAP,
} from './index'
import { WorkflowVariableBlockNode } from './node'
import { useLlmModelPluginInstalled } from './use-llm-model-plugin-installed'
type WorkflowVariableBlockComponentProps = {
nodeKey: string
@@ -68,6 +69,8 @@ const WorkflowVariableBlockComponent = ({
const node = localWorkflowNodesMap![variables[isRagVar ? 1 : 0]]
const isException = isExceptionVariable(varName, node?.type)
const sourceNodeId = variables[isRagVar ? 1 : 0]
const isLlmModelInstalled = useLlmModelPluginInstalled(sourceNodeId, localWorkflowNodesMap)
const variableValid = useMemo(() => {
let variableValid = true
const isEnv = isENV(variables)
@@ -144,7 +147,13 @@ const WorkflowVariableBlockComponent = ({
handleVariableJump()
}}
isExceptionVariable={isException}
errorMsg={!variableValid ? t('errorMsg.invalidVariable', { ns: 'workflow' }) : undefined}
errorMsg={
!variableValid
? t('errorMsg.invalidVariable', { ns: 'workflow' })
: !isLlmModelInstalled
? t('errorMsg.modelPluginNotInstalled', { ns: 'workflow' })
: undefined
}
isSelected={isSelected}
ref={ref}
notShowFullPath={isShowAPart}
@@ -155,9 +164,9 @@ const WorkflowVariableBlockComponent = ({
return Item
return (
<Tooltip
noDecoration
popupContent={(
<Tooltip>
<TooltipTrigger disabled={!isShowAPart} render={<div>{Item}</div>} />
<TooltipContent variant="plain">
<VarFullPathPanel
nodeName={node.title}
path={variables.slice(1)}
@@ -169,10 +178,7 @@ const WorkflowVariableBlockComponent = ({
: Type.string}
nodeType={node?.type}
/>
)}
disabled={!isShowAPart}
>
<div>{Item}</div>
</TooltipContent>
</Tooltip>
)
}

View File

@@ -0,0 +1,23 @@
import type { WorkflowNodesMap } from '@/app/components/base/prompt-editor/types'
import { BlockEnum } from '@/app/components/workflow/types'
import { extractPluginId } from '@/app/components/workflow/utils/plugin'
import { useProviderContextSelector } from '@/context/provider-context'
export function useLlmModelPluginInstalled(
nodeId: string,
workflowNodesMap: WorkflowNodesMap | undefined,
): boolean {
const node = workflowNodesMap?.[nodeId]
const modelProvider = node?.type === BlockEnum.LLM
? node.modelProvider
: undefined
const modelPluginId = modelProvider ? extractPluginId(modelProvider) : undefined
return useProviderContextSelector((state) => {
if (!modelPluginId)
return true
return state.modelProviders.some(p =>
extractPluginId(p.provider) === modelPluginId,
)
})
}

View File

@@ -73,7 +73,7 @@ export type GetVarType = (payload: {
export type WorkflowVariableBlockType = {
show?: boolean
variables?: NodeOutPutVar[]
workflowNodesMap?: Record<string, Pick<Node['data'], 'title' | 'type' | 'height' | 'width' | 'position'>>
workflowNodesMap?: WorkflowNodesMap
onInsert?: () => void
onDelete?: () => void
getVarType?: GetVarType
@@ -81,12 +81,14 @@ export type WorkflowVariableBlockType = {
onManageInputField?: () => void
}
export type WorkflowNodesMap = Record<string, Pick<Node['data'], 'title' | 'type' | 'height' | 'width' | 'position'> & { modelProvider?: string }>
export type HITLInputBlockType = {
show?: boolean
nodeId: string
formInputs?: FormInputItem[]
variables?: NodeOutPutVar[]
workflowNodesMap?: Record<string, Pick<Node['data'], 'title' | 'type' | 'height' | 'width' | 'position'>>
workflowNodesMap?: WorkflowNodesMap
getVarType?: GetVarType
onFormInputsChange?: (inputs: FormInputItem[]) => void
onFormInputItemRemove: (varName: string) => void

View File

@@ -1,10 +1,14 @@
import type { ChangeEvent, FC, KeyboardEvent } from 'react'
import { useCallback, useState } from 'react'
import AutosizeInput from 'react-18-input-autosize'
import _AutosizeInput from 'react-18-input-autosize'
import { useTranslation } from 'react-i18next'
import { useToastContext } from '@/app/components/base/toast/context'
import { cn } from '@/utils/classnames'
// CJS/ESM interop: Turbopack may resolve the module namespace object instead of the default export
// eslint-disable-next-line ts/no-explicit-any
const AutosizeInput = ('default' in (_AutosizeInput as any) ? (_AutosizeInput as any).default : _AutosizeInput) as typeof _AutosizeInput
type TagInputProps = {
items: string[]
onChange: (items: string[]) => void

View File

@@ -43,20 +43,24 @@ type DialogContentProps = {
children: React.ReactNode
className?: string
overlayClassName?: string
backdropProps?: React.ComponentPropsWithoutRef<typeof BaseDialog.Backdrop>
}
export function DialogContent({
children,
className,
overlayClassName,
backdropProps,
}: DialogContentProps) {
return (
<DialogPortal>
<BaseDialog.Backdrop
{...backdropProps}
className={cn(
'fixed inset-0 z-50 bg-background-overlay',
'transition-opacity duration-150 data-[ending-style]:opacity-0 data-[starting-style]:opacity-0 motion-reduce:transition-none',
overlayClassName,
backdropProps?.className,
)}
/>
<BaseDialog.Popup

View File

@@ -1,7 +1,7 @@
import { render, screen } from '@testing-library/react'
import * as React from 'react'
import { CategoryEnum } from '..'
import Footer from '../footer'
import { CategoryEnum } from '../types'
vi.mock('next/link', () => ({
default: ({ children, href, className, target }: { children: React.ReactNode, href: string, className?: string, target?: string }) => (

View File

@@ -1,7 +1,16 @@
import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { Dialog } from '@/app/components/base/ui/dialog'
import Header from '../header'
function renderHeader(onClose: () => void) {
return render(
<Dialog open>
<Header onClose={onClose} />
</Dialog>,
)
}
describe('Header', () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -11,7 +20,7 @@ describe('Header', () => {
it('should render title and description translations', () => {
const handleClose = vi.fn()
render(<Header onClose={handleClose} />)
renderHeader(handleClose)
expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument()
expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument()
@@ -22,7 +31,7 @@ describe('Header', () => {
describe('Props', () => {
it('should invoke onClose when close button is clicked', () => {
const handleClose = vi.fn()
render(<Header onClose={handleClose} />)
renderHeader(handleClose)
fireEvent.click(screen.getByRole('button'))
@@ -32,7 +41,7 @@ describe('Header', () => {
describe('Edge Cases', () => {
it('should render structural elements with translation keys', () => {
const { container } = render(<Header onClose={vi.fn()} />)
const { container } = renderHeader(vi.fn())
expect(container.querySelector('span')).toBeInTheDocument()
expect(container.querySelector('p')).toBeInTheDocument()

View File

@@ -74,15 +74,11 @@ describe('Pricing', () => {
})
describe('Props', () => {
it('should allow switching categories and handle esc key', () => {
const handleCancel = vi.fn()
render(<Pricing onCancel={handleCancel} />)
it('should allow switching categories', () => {
render(<Pricing onCancel={vi.fn()} />)
fireEvent.click(screen.getByText('billing.plansCommon.self'))
expect(screen.queryByRole('switch')).not.toBeInTheDocument()
fireEvent.keyDown(window, { key: 'Escape', keyCode: 27 })
expect(handleCancel).toHaveBeenCalled()
})
})

View File

@@ -1,10 +1,9 @@
import type { Category } from '.'
import { RiArrowRightUpLine } from '@remixicon/react'
import type { Category } from './types'
import Link from 'next/link'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { cn } from '@/utils/classnames'
import { CategoryEnum } from '.'
import { CategoryEnum } from './types'
type FooterProps = {
pricingPageURL: string
@@ -34,7 +33,7 @@ const Footer = ({
>
{t('plansCommon.comparePlanAndFeatures', { ns: 'billing' })}
</Link>
<RiArrowRightUpLine className="size-4" />
<span aria-hidden="true" className="i-ri-arrow-right-up-line size-4" />
</span>
</div>
</div>

View File

@@ -1,6 +1,6 @@
import { RiCloseLine } from '@remixicon/react'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { DialogDescription, DialogTitle } from '@/app/components/base/ui/dialog'
import Button from '../../base/button'
import DifyLogo from '../../base/logo/dify-logo'
@@ -20,19 +20,19 @@ const Header = ({
<div className="py-[5px]">
<DifyLogo className="h-[27px] w-[60px]" />
</div>
<span className="bg-billing-plan-title-bg bg-clip-text px-1.5 font-instrument text-[37px] italic leading-[1.2] text-transparent">
<DialogTitle className="m-0 bg-billing-plan-title-bg bg-clip-text px-1.5 font-instrument text-[37px] italic leading-[1.2] text-transparent">
{t('plansCommon.title.plans', { ns: 'billing' })}
</span>
</DialogTitle>
</div>
<p className="system-sm-regular text-text-tertiary">
<DialogDescription className="m-0 text-text-tertiary system-sm-regular">
{t('plansCommon.title.description', { ns: 'billing' })}
</p>
</DialogDescription>
<Button
variant="secondary"
className="absolute bottom-[40.5px] right-[-18px] z-10 size-9 rounded-full p-2"
onClick={onClose}
>
<RiCloseLine className="size-5" />
<span aria-hidden="true" className="i-ri-close-line size-5" />
</Button>
</div>
</div>

View File

@@ -1,9 +1,9 @@
'use client'
import type { FC } from 'react'
import { useKeyPress } from 'ahooks'
import type { Category } from './types'
import * as React from 'react'
import { useState } from 'react'
import { createPortal } from 'react-dom'
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
import { useAppContext } from '@/context/app-context'
import { useGetPricingPageLanguage } from '@/context/i18n'
import { useProviderContext } from '@/context/provider-context'
@@ -13,13 +13,7 @@ import Header from './header'
import PlanSwitcher from './plan-switcher'
import { PlanRange } from './plan-switcher/plan-range-switcher'
import Plans from './plans'
export enum CategoryEnum {
CLOUD = 'cloud',
SELF = 'self',
}
export type Category = CategoryEnum.CLOUD | CategoryEnum.SELF
import { CategoryEnum } from './types'
type PricingProps = {
onCancel: () => void
@@ -33,42 +27,47 @@ const Pricing: FC<PricingProps> = ({
const [planRange, setPlanRange] = React.useState<PlanRange>(PlanRange.monthly)
const [currentCategory, setCurrentCategory] = useState<Category>(CategoryEnum.CLOUD)
const canPay = isCurrentWorkspaceManager
useKeyPress(['esc'], onCancel)
const pricingPageLanguage = useGetPricingPageLanguage()
const pricingPageURL = pricingPageLanguage
? `https://dify.ai/${pricingPageLanguage}/pricing#plans-and-features`
: 'https://dify.ai/pricing#plans-and-features'
return createPortal(
<div
className="fixed inset-0 bottom-0 left-0 right-0 top-0 z-[1000] overflow-auto bg-saas-background"
onClick={e => e.stopPropagation()}
return (
<Dialog
open
onOpenChange={(open) => {
if (!open)
onCancel()
}}
>
<div className="relative grid min-h-full min-w-[1200px] grid-rows-[1fr_auto_auto_1fr] overflow-hidden">
<div className="absolute -top-12 left-0 right-0 -z-10">
<NoiseTop />
<DialogContent
className="inset-0 h-full max-h-none w-full max-w-none translate-x-0 translate-y-0 overflow-auto rounded-none border-none bg-saas-background p-0 shadow-none"
>
<div className="relative grid min-h-full min-w-[1200px] grid-rows-[1fr_auto_auto_1fr] overflow-hidden">
<div className="absolute -top-12 left-0 right-0 -z-10">
<NoiseTop />
</div>
<Header onClose={onCancel} />
<PlanSwitcher
currentCategory={currentCategory}
onChangeCategory={setCurrentCategory}
currentPlanRange={planRange}
onChangePlanRange={setPlanRange}
/>
<Plans
plan={plan}
currentPlan={currentCategory}
planRange={planRange}
canPay={canPay}
/>
<Footer pricingPageURL={pricingPageURL} currentCategory={currentCategory} />
<div className="absolute -bottom-12 left-0 right-0 -z-10">
<NoiseBottom />
</div>
</div>
<Header onClose={onCancel} />
<PlanSwitcher
currentCategory={currentCategory}
onChangeCategory={setCurrentCategory}
currentPlanRange={planRange}
onChangePlanRange={setPlanRange}
/>
<Plans
plan={plan}
currentPlan={currentCategory}
planRange={planRange}
canPay={canPay}
/>
<Footer pricingPageURL={pricingPageURL} currentCategory={currentCategory} />
<div className="absolute -bottom-12 left-0 right-0 -z-10">
<NoiseBottom />
</div>
</div>
</div>,
document.body,
</DialogContent>
</Dialog>
)
}
export default React.memo(Pricing)

View File

@@ -1,6 +1,6 @@
import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { CategoryEnum } from '../../index'
import { CategoryEnum } from '../../types'
import PlanSwitcher from '../index'
import { PlanRange } from '../plan-range-switcher'

View File

@@ -1,5 +1,5 @@
import type { FC } from 'react'
import type { Category } from '../index'
import type { Category } from '../types'
import type { PlanRange } from './plan-range-switcher'
import * as React from 'react'
import { useTranslation } from 'react-i18next'

View File

@@ -0,0 +1,6 @@
export enum CategoryEnum {
CLOUD = 'cloud',
SELF = 'self',
}
export type Category = CategoryEnum.CLOUD | CategoryEnum.SELF

View File

@@ -204,7 +204,7 @@ const CSVUploader: FC<Props> = ({
/>
<div ref={dropRef}>
{!file && (
<div className={cn('flex h-20 items-center rounded-xl border border-dashed border-components-panel-border bg-components-panel-bg-blur text-sm font-normal', dragging && 'border border-divider-subtle bg-components-panel-on-panel-item-bg-hover')}>
<div className={cn('flex h-20 items-center rounded-xl border border-dashed border-components-panel-border bg-components-panel-bg-blur text-sm font-normal', dragging && 'border border-divider-subtle bg-components-panel-on-panel-item-bg-hover')}>
<div className="flex w-full items-center justify-center space-x-2">
<CSVIcon className="shrink-0" />
<div className="text-text-secondary">

View File

@@ -58,7 +58,7 @@ const NewChildSegmentModal: FC<NewChildSegmentModalProps> = ({
<Divider type="vertical" className="mx-1 h-3 bg-divider-regular" />
<button
type="button"
className="system-xs-semibold text-text-accent"
className="text-text-accent system-xs-semibold"
onClick={() => {
clearTimeout(refreshTimer.current)
viewNewlyAddedChildChunk?.()
@@ -120,11 +120,11 @@ const NewChildSegmentModal: FC<NewChildSegmentModalProps> = ({
<div className="flex h-full flex-col">
<div className={cn('flex items-center justify-between', fullScreen ? 'border border-divider-subtle py-3 pl-6 pr-4' : 'pl-4 pr-3 pt-3')}>
<div className="flex flex-col">
<div className="system-xl-semibold text-text-primary">{t('segment.addChildChunk', { ns: 'datasetDocuments' })}</div>
<div className="text-text-primary system-xl-semibold">{t('segment.addChildChunk', { ns: 'datasetDocuments' })}</div>
<div className="flex items-center gap-x-2">
<SegmentIndexTag label={t('segment.newChildChunk', { ns: 'datasetDocuments' }) as string} />
<Dot />
<span className="system-xs-medium text-text-tertiary">{wordCountText}</span>
<span className="text-text-tertiary system-xs-medium">{wordCountText}</span>
</div>
</div>
<div className="flex items-center">

View File

@@ -61,7 +61,7 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
<Divider type="vertical" className="mx-1 h-3 bg-divider-regular" />
<button
type="button"
className="system-xs-semibold text-text-accent"
className="text-text-accent system-xs-semibold"
onClick={() => {
clearTimeout(refreshTimer.current)
viewNewlyAddedChunk()
@@ -158,13 +158,13 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
className={cn('flex items-center justify-between', fullScreen ? 'border border-divider-subtle py-3 pl-6 pr-4' : 'pl-4 pr-3 pt-3')}
>
<div className="flex flex-col">
<div className="system-xl-semibold text-text-primary">
<div className="text-text-primary system-xl-semibold">
{t('segment.addChunk', { ns: 'datasetDocuments' })}
</div>
<div className="flex items-center gap-x-2">
<SegmentIndexTag label={t('segment.newChunk', { ns: 'datasetDocuments' })!} />
<Dot />
<span className="system-xs-medium text-text-tertiary">{wordCountText}</span>
<span className="text-text-tertiary system-xs-medium">{wordCountText}</span>
</div>
</div>
<div className="flex items-center">

View File

@@ -100,10 +100,10 @@ vi.mock('@/app/components/datasets/create/step-two', () => ({
}))
vi.mock('@/app/components/header/account-setting', () => ({
default: ({ activeTab, onCancel }: { activeTab?: string, onCancel?: () => void }) => (
default: ({ activeTab, onCancelAction }: { activeTab?: string, onCancelAction?: () => void }) => (
<div data-testid="account-setting">
<span data-testid="active-tab">{activeTab}</span>
<button onClick={onCancel} data-testid="close-setting">Close</button>
<button onClick={onCancelAction} data-testid="close-setting">Close</button>
</div>
),
}))

View File

@@ -1,3 +1,4 @@
import type { AccountSettingTab } from '@/app/components/header/account-setting/constants'
import type { DataSourceProvider, NotionPage } from '@/models/common'
import type {
CrawlOptions,
@@ -19,6 +20,7 @@ import AppUnavailable from '@/app/components/base/app-unavailable'
import Loading from '@/app/components/base/loading'
import StepTwo from '@/app/components/datasets/create/step-two'
import AccountSetting from '@/app/components/header/account-setting'
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import DatasetDetailContext from '@/context/dataset-detail'
@@ -33,8 +35,13 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
const { t } = useTranslation()
const router = useRouter()
const [isShowSetAPIKey, { setTrue: showSetAPIKey, setFalse: hideSetAPIkey }] = useBoolean()
const [accountSettingTab, setAccountSettingTab] = React.useState<AccountSettingTab>(ACCOUNT_SETTING_TAB.PROVIDER)
const { indexingTechnique, dataset } = useContext(DatasetDetailContext)
const { data: embeddingsDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
const handleOpenAccountSetting = React.useCallback(() => {
setAccountSettingTab(ACCOUNT_SETTING_TAB.PROVIDER)
showSetAPIKey()
}, [showSetAPIKey])
const invalidDocumentList = useInvalidDocumentList(datasetId)
const invalidDocumentDetail = useInvalidDocumentDetail()
@@ -135,7 +142,7 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
{dataset && documentDetail && (
<StepTwo
isAPIKeySet={!!embeddingsDefaultModel}
onSetting={showSetAPIKey}
onSetting={handleOpenAccountSetting}
datasetId={datasetId}
dataSourceType={documentDetail.data_source_type as DataSourceType}
notionPages={currentPage ? [currentPage as unknown as NotionPage] : []}
@@ -155,8 +162,9 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
</div>
{isShowSetAPIKey && (
<AccountSetting
activeTab="provider"
onCancel={async () => {
activeTab={accountSettingTab}
onTabChangeAction={setAccountSettingTab}
onCancelAction={async () => {
hideSetAPIkey()
}}
/>

View File

@@ -120,13 +120,13 @@ const AddExternalAPIModal: FC<AddExternalAPIModalProps> = ({ data, onSave, onCan
<div className="fixed inset-0 flex items-center justify-center bg-black/[.25]">
<div className="shadows-shadow-xl relative flex w-[480px] flex-col items-start rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg">
<div className="flex flex-col items-start gap-2 self-stretch pb-3 pl-6 pr-14 pt-6">
<div className="title-2xl-semi-bold grow self-stretch text-text-primary">
<div className="grow self-stretch text-text-primary title-2xl-semi-bold">
{
isEditMode ? t('editExternalAPIFormTitle', { ns: 'dataset' }) : t('createExternalAPI', { ns: 'dataset' })
}
</div>
{isEditMode && (datasetBindings?.length ?? 0) > 0 && (
<div className="system-xs-regular flex items-center text-text-tertiary">
<div className="flex items-center text-text-tertiary system-xs-regular">
{t('editExternalAPIFormWarning.front', { ns: 'dataset' })}
<span className="flex cursor-pointer items-center text-text-accent">
&nbsp;
@@ -139,12 +139,12 @@ const AddExternalAPIModal: FC<AddExternalAPIModalProps> = ({ data, onSave, onCan
popupContent={(
<div className="p-1">
<div className="flex items-start self-stretch pb-0.5 pl-2 pr-3 pt-1">
<div className="system-xs-medium-uppercase text-text-tertiary">{`${datasetBindings?.length} ${t('editExternalAPITooltipTitle', { ns: 'dataset' })}`}</div>
<div className="text-text-tertiary system-xs-medium-uppercase">{`${datasetBindings?.length} ${t('editExternalAPITooltipTitle', { ns: 'dataset' })}`}</div>
</div>
{datasetBindings?.map(binding => (
<div key={binding.id} className="flex items-center gap-1 self-stretch px-2 py-1">
<RiBook2Line className="h-4 w-4 text-text-secondary" />
<div className="system-sm-medium text-text-secondary">{binding.name}</div>
<div className="text-text-secondary system-sm-medium">{binding.name}</div>
</div>
))}
</div>
@@ -188,8 +188,8 @@ const AddExternalAPIModal: FC<AddExternalAPIModalProps> = ({ data, onSave, onCan
{t('externalAPIForm.save', { ns: 'dataset' })}
</Button>
</div>
<div className="system-xs-regular flex items-center justify-center gap-1 self-stretch rounded-b-2xl border-t-[0.5px]
border-divider-subtle bg-background-soft px-2 py-3 text-text-tertiary"
<div className="flex items-center justify-center gap-1 self-stretch rounded-b-2xl border-t-[0.5px] border-divider-subtle
bg-background-soft px-2 py-3 text-text-tertiary system-xs-regular"
>
<RiLock2Fill className="h-3 w-3 text-text-quaternary" />
{t('externalAPIForm.encrypted.front', { ns: 'dataset' })}

View File

@@ -63,7 +63,7 @@ const SummaryIndexSetting = ({
return (
<div>
<div className="flex h-6 items-center justify-between">
<div className="system-sm-semibold-uppercase flex items-center text-text-secondary">
<div className="flex items-center text-text-secondary system-sm-semibold-uppercase">
{t('form.summaryAutoGen', { ns: 'datasetSettings' })}
<Tooltip
triggerClassName="ml-1 h-4 w-4 shrink-0"
@@ -80,7 +80,7 @@ const SummaryIndexSetting = ({
{
summaryIndexSetting?.enable && (
<div>
<div className="system-xs-medium-uppercase mb-1.5 mt-2 flex h-6 items-center text-text-tertiary">
<div className="mb-1.5 mt-2 flex h-6 items-center text-text-tertiary system-xs-medium-uppercase">
{t('form.summaryModel', { ns: 'datasetSettings' })}
</div>
<ModelSelector
@@ -90,7 +90,7 @@ const SummaryIndexSetting = ({
readonly={readonly}
showDeprecatedWarnIcon
/>
<div className="system-xs-medium-uppercase mt-3 flex h-6 items-center text-text-tertiary">
<div className="mt-3 flex h-6 items-center text-text-tertiary system-xs-medium-uppercase">
{t('form.summaryInstructions', { ns: 'datasetSettings' })}
</div>
<Textarea
@@ -111,12 +111,12 @@ const SummaryIndexSetting = ({
<div className="space-y-4">
<div className="flex gap-x-1">
<div className="flex h-7 w-[180px] shrink-0 items-center pt-1">
<div className="system-sm-semibold text-text-secondary">
<div className="text-text-secondary system-sm-semibold">
{t('form.summaryAutoGen', { ns: 'datasetSettings' })}
</div>
</div>
<div className="py-1.5">
<div className="system-sm-semibold flex items-center text-text-secondary">
<div className="flex items-center text-text-secondary system-sm-semibold">
<Switch
className="mr-2"
value={summaryIndexSetting?.enable ?? false}
@@ -127,7 +127,7 @@ const SummaryIndexSetting = ({
summaryIndexSetting?.enable ? t('list.status.enabled', { ns: 'datasetDocuments' }) : t('list.status.disabled', { ns: 'datasetDocuments' })
}
</div>
<div className="system-sm-regular mt-2 text-text-tertiary">
<div className="mt-2 text-text-tertiary system-sm-regular">
{
summaryIndexSetting?.enable && t('form.summaryAutoGenTip', { ns: 'datasetSettings' })
}
@@ -142,7 +142,7 @@ const SummaryIndexSetting = ({
<>
<div className="flex gap-x-1">
<div className="flex h-7 w-[180px] shrink-0 items-center pt-1">
<div className="system-sm-medium text-text-tertiary">
<div className="text-text-tertiary system-sm-medium">
{t('form.summaryModel', { ns: 'datasetSettings' })}
</div>
</div>
@@ -159,7 +159,7 @@ const SummaryIndexSetting = ({
</div>
<div className="flex">
<div className="flex h-7 w-[180px] shrink-0 items-center pt-1">
<div className="system-sm-medium text-text-tertiary">
<div className="text-text-tertiary system-sm-medium">
{t('form.summaryInstructions', { ns: 'datasetSettings' })}
</div>
</div>
@@ -188,7 +188,7 @@ const SummaryIndexSetting = ({
onChange={handleSummaryIndexEnableChange}
size="md"
/>
<div className="system-sm-semibold text-text-secondary">
<div className="text-text-secondary system-sm-semibold">
{t('form.summaryAutoGen', { ns: 'datasetSettings' })}
</div>
</div>
@@ -196,7 +196,7 @@ const SummaryIndexSetting = ({
summaryIndexSetting?.enable && (
<>
<div>
<div className="system-sm-medium mb-1.5 flex h-6 items-center text-text-secondary">
<div className="mb-1.5 flex h-6 items-center text-text-secondary system-sm-medium">
{t('form.summaryModel', { ns: 'datasetSettings' })}
</div>
<ModelSelector
@@ -209,7 +209,7 @@ const SummaryIndexSetting = ({
/>
</div>
<div>
<div className="system-sm-medium mb-1.5 flex h-6 items-center text-text-secondary">
<div className="mb-1.5 flex h-6 items-center text-text-secondary system-sm-medium">
{t('form.summaryInstructions', { ns: 'datasetSettings' })}
</div>
<Textarea

View File

@@ -46,7 +46,7 @@ const WorkplaceSelector = () => {
<span className="h-6 bg-gradient-to-r from-components-avatar-shape-fill-stop-0 to-components-avatar-shape-fill-stop-100 bg-clip-text align-middle font-semibold uppercase leading-6 text-shadow-shadow-1 opacity-90">{currentWorkspace?.name[0]?.toLocaleUpperCase()}</span>
</div>
<div className="flex min-w-0 items-center">
<div className="system-sm-medium min-w-0 max-w-[149px] truncate text-text-secondary max-[800px]:hidden">{currentWorkspace?.name}</div>
<div className="min-w-0 max-w-[149px] truncate text-text-secondary system-sm-medium max-[800px]:hidden">{currentWorkspace?.name}</div>
<RiArrowDownSLine className="h-4 w-4 shrink-0 text-text-secondary" />
</div>
</MenuButton>
@@ -68,9 +68,9 @@ const WorkplaceSelector = () => {
`,
)}
>
<div className="flex w-full flex-col items-start self-stretch rounded-xl border-[0.5px] border-components-panel-border p-1 pb-2 shadow-lg ">
<div className="flex w-full flex-col items-start self-stretch rounded-xl border-[0.5px] border-components-panel-border p-1 pb-2 shadow-lg">
<div className="flex items-start self-stretch px-3 pb-0.5 pt-1">
<span className="system-xs-medium-uppercase flex-1 text-text-tertiary">{t('userProfile.workspace', { ns: 'common' })}</span>
<span className="flex-1 text-text-tertiary system-xs-medium-uppercase">{t('userProfile.workspace', { ns: 'common' })}</span>
</div>
{
workspaces.map(workspace => (
@@ -78,7 +78,7 @@ const WorkplaceSelector = () => {
<div className="flex h-6 w-6 shrink-0 items-center justify-center rounded-md bg-components-icon-bg-blue-solid text-[13px]">
<span className="h-6 bg-gradient-to-r from-components-avatar-shape-fill-stop-0 to-components-avatar-shape-fill-stop-100 bg-clip-text align-middle font-semibold uppercase leading-6 text-shadow-shadow-1 opacity-90">{workspace?.name[0]?.toLocaleUpperCase()}</span>
</div>
<div className="system-md-regular line-clamp-1 grow cursor-pointer overflow-hidden text-ellipsis text-text-secondary">{workspace.name}</div>
<div className="line-clamp-1 grow cursor-pointer overflow-hidden text-ellipsis text-text-secondary system-md-regular">{workspace.name}</div>
<PlanBadge plan={workspace.plan as Plan} />
</div>
))

View File

@@ -1,12 +1,16 @@
import type { AccountSettingTab } from './constants'
import type { AppContextValue } from '@/context/app-context'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { fireEvent, render, screen } from '@testing-library/react'
import { useState } from 'react'
import { useAppContext } from '@/context/app-context'
import { baseProviderContextValue, useProviderContext } from '@/context/provider-context'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { ACCOUNT_SETTING_TAB } from './constants'
import AccountSetting from './index'
const mockResetModelProviderListExpanded = vi.fn()
vi.mock('@/context/provider-context', async (importOriginal) => {
const actual = await importOriginal<typeof import('@/context/provider-context')>()
return {
@@ -47,10 +51,15 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', ()
useDefaultModel: vi.fn(() => ({ data: null, isLoading: false })),
useUpdateDefaultModel: vi.fn(() => ({ trigger: vi.fn() })),
useUpdateModelList: vi.fn(() => vi.fn()),
useInvalidateDefaultModel: vi.fn(() => vi.fn()),
useModelList: vi.fn(() => ({ data: [], isLoading: false })),
useSystemDefaultModelAndModelList: vi.fn(() => [null, vi.fn()]),
}))
vi.mock('@/app/components/header/account-setting/model-provider-page/atoms', () => ({
useResetModelProviderListExpanded: () => mockResetModelProviderListExpanded,
}))
vi.mock('@/service/use-datasource', () => ({
useGetDataSourceListAuth: vi.fn(() => ({ data: { result: [] } })),
}))
@@ -105,6 +114,38 @@ const baseAppContextValue: AppContextValue = {
describe('AccountSetting', () => {
const mockOnCancel = vi.fn()
const mockOnTabChange = vi.fn()
const renderAccountSetting = (props?: {
initialTab?: AccountSettingTab
onCancel?: () => void
onTabChange?: (tab: AccountSettingTab) => void
}) => {
const {
initialTab = ACCOUNT_SETTING_TAB.MEMBERS,
onCancel = mockOnCancel,
onTabChange = mockOnTabChange,
} = props ?? {}
const StatefulAccountSetting = () => {
const [activeTab, setActiveTab] = useState<AccountSettingTab>(initialTab)
return (
<AccountSetting
onCancelAction={onCancel}
activeTab={activeTab}
onTabChangeAction={(tab) => {
setActiveTab(tab)
onTabChange(tab)
}}
/>
)
}
return render(
<QueryClientProvider client={new QueryClient()}>
<StatefulAccountSetting />
</QueryClientProvider>,
)
}
beforeEach(() => {
vi.clearAllMocks()
@@ -120,11 +161,7 @@ describe('AccountSetting', () => {
describe('Rendering', () => {
it('should render the sidebar with correct menu items', () => {
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
renderAccountSetting()
// Assert
expect(screen.getByText('common.userProfile.settings')).toBeInTheDocument()
@@ -137,13 +174,9 @@ describe('AccountSetting', () => {
expect(screen.getAllByText('common.settings.language').length).toBeGreaterThan(0)
})
it('should respect the activeTab prop', () => {
it('should respect the initial tab', () => {
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} activeTab={ACCOUNT_SETTING_TAB.DATA_SOURCE} />
</QueryClientProvider>,
)
renderAccountSetting({ initialTab: ACCOUNT_SETTING_TAB.DATA_SOURCE })
// Assert
// Check that the active item title is Data Source
@@ -157,11 +190,7 @@ describe('AccountSetting', () => {
vi.mocked(useBreakpoints).mockReturnValue(MediaType.mobile)
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
renderAccountSetting()
// Assert
// On mobile, the labels should not be rendered as per the implementation
@@ -176,11 +205,7 @@ describe('AccountSetting', () => {
})
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
renderAccountSetting()
// Assert
expect(screen.queryByText('common.settings.provider')).not.toBeInTheDocument()
@@ -197,11 +222,7 @@ describe('AccountSetting', () => {
})
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
renderAccountSetting()
// Assert
expect(screen.queryByText('common.settings.billing')).not.toBeInTheDocument()
@@ -212,11 +233,7 @@ describe('AccountSetting', () => {
describe('Tab Navigation', () => {
it('should change active tab when clicking on menu item', () => {
// Arrange
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} onTabChange={mockOnTabChange} />
</QueryClientProvider>,
)
renderAccountSetting({ onTabChange: mockOnTabChange })
// Act
fireEvent.click(screen.getByText('common.settings.provider'))
@@ -229,11 +246,7 @@ describe('AccountSetting', () => {
it('should navigate through various tabs and show correct details', () => {
// Act & Assert
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
renderAccountSetting()
// Billing
fireEvent.click(screen.getByText('common.settings.billing'))
@@ -267,13 +280,11 @@ describe('AccountSetting', () => {
describe('Interactions', () => {
it('should call onCancel when clicking close button', () => {
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
const buttons = screen.getAllByRole('button')
fireEvent.click(buttons[0])
renderAccountSetting()
const closeIcon = document.querySelector('.i-ri-close-line')
const closeButton = closeIcon?.closest('button')
expect(closeButton).not.toBeNull()
fireEvent.click(closeButton!)
// Assert
expect(mockOnCancel).toHaveBeenCalled()
@@ -281,11 +292,7 @@ describe('AccountSetting', () => {
it('should call onCancel when pressing Escape key', () => {
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
renderAccountSetting()
fireEvent.keyDown(document, { key: 'Escape' })
// Assert
@@ -294,12 +301,7 @@ describe('AccountSetting', () => {
it('should update search value in provider tab', () => {
// Arrange
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
fireEvent.click(screen.getByText('common.settings.provider'))
renderAccountSetting({ initialTab: ACCOUNT_SETTING_TAB.PROVIDER })
// Act
const input = screen.getByRole('textbox')
@@ -312,11 +314,7 @@ describe('AccountSetting', () => {
it('should handle scroll event in panel', () => {
// Act
render(
<QueryClientProvider client={new QueryClient()}>
<AccountSetting onCancel={mockOnCancel} />
</QueryClientProvider>,
)
renderAccountSetting()
const scrollContainer = screen.getByRole('dialog').querySelector('.overflow-y-auto')
// Assert

View File

@@ -1,6 +1,6 @@
'use client'
import type { AccountSettingTab } from '@/app/components/header/account-setting/constants'
import { useEffect, useRef, useState } from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import SearchInput from '@/app/components/base/search-input'
import BillingPage from '@/app/components/billing/billing-page'
@@ -20,15 +20,16 @@ import DataSourcePage from './data-source-page-new'
import LanguagePage from './language-page'
import MembersPage from './members-page'
import ModelProviderPage from './model-provider-page'
import { useResetModelProviderListExpanded } from './model-provider-page/atoms'
const iconClassName = `
w-5 h-5 mr-2
`
type IAccountSettingProps = {
onCancel: () => void
activeTab?: AccountSettingTab
onTabChange?: (tab: AccountSettingTab) => void
onCancelAction: () => void
activeTab: AccountSettingTab
onTabChangeAction: (tab: AccountSettingTab) => void
}
type GroupItem = {
@@ -40,14 +41,12 @@ type GroupItem = {
}
export default function AccountSetting({
onCancel,
activeTab = ACCOUNT_SETTING_TAB.MEMBERS,
onTabChange,
onCancelAction,
activeTab,
onTabChangeAction,
}: IAccountSettingProps) {
const [activeMenu, setActiveMenu] = useState<AccountSettingTab>(activeTab)
useEffect(() => {
setActiveMenu(activeTab)
}, [activeTab])
const resetModelProviderListExpanded = useResetModelProviderListExpanded()
const activeMenu = activeTab
const { t } = useTranslation()
const { enableBilling, enableReplaceWebAppLogo } = useProviderContext()
const { isCurrentWorkspaceDatasetOperator } = useAppContext()
@@ -148,10 +147,22 @@ export default function AccountSetting({
const [searchValue, setSearchValue] = useState<string>('')
const handleTabChange = useCallback((tab: AccountSettingTab) => {
if (tab === ACCOUNT_SETTING_TAB.PROVIDER)
resetModelProviderListExpanded()
onTabChangeAction(tab)
}, [onTabChangeAction, resetModelProviderListExpanded])
const handleClose = useCallback(() => {
resetModelProviderListExpanded()
onCancelAction()
}, [onCancelAction, resetModelProviderListExpanded])
return (
<MenuDialog
show
onClose={onCancel}
onClose={handleClose}
>
<div className="mx-auto flex h-[100vh] max-w-[1048px]">
<div className="flex w-[44px] flex-col border-r border-divider-burn pl-4 pr-6 sm:w-[224px]">
@@ -166,21 +177,22 @@ export default function AccountSetting({
<div>
{
menuItem.items.map(item => (
<div
<button
type="button"
key={item.key}
className={cn(
'mb-0.5 flex h-[37px] cursor-pointer items-center rounded-lg p-1 pl-3 text-sm',
'mb-0.5 flex h-[37px] w-full items-center rounded-lg p-1 pl-3 text-left text-sm',
activeMenu === item.key ? 'bg-state-base-active text-components-menu-item-text-active system-sm-semibold' : 'text-components-menu-item-text system-sm-medium',
)}
aria-label={item.name}
title={item.name}
onClick={() => {
setActiveMenu(item.key)
onTabChange?.(item.key)
handleTabChange(item.key)
}}
>
{activeMenu === item.key ? item.activeIcon : item.icon}
{!isMobile && <div className="truncate">{item.name}</div>}
</div>
</button>
))
}
</div>
@@ -195,7 +207,8 @@ export default function AccountSetting({
variant="tertiary"
size="large"
className="px-2"
onClick={onCancel}
aria-label={t('operation.close', { ns: 'common' })}
onClick={handleClose}
>
<span className="i-ri-close-line h-5 w-5" />
</Button>

View File

@@ -97,7 +97,7 @@ const Operation = ({
offset={{ mainAxis: 4 }}
>
<PortalToFollowElemTrigger asChild onClick={() => setOpen(prev => !prev)}>
<div className={cn('system-sm-regular group flex h-full w-full cursor-pointer items-center justify-between px-3 text-text-secondary hover:bg-state-base-hover', open && 'bg-state-base-hover')}>
<div className={cn('group flex h-full w-full cursor-pointer items-center justify-between px-3 text-text-secondary system-sm-regular hover:bg-state-base-hover', open && 'bg-state-base-hover')}>
{RoleMap[member.role] || RoleMap.normal}
<ChevronDownIcon className={cn('h-4 w-4 shrink-0 group-hover:block', open ? 'block' : 'hidden')} />
</div>
@@ -114,8 +114,8 @@ const Operation = ({
: <div className="mr-1 mt-[2px] h-4 w-4 text-text-accent" />
}
<div>
<div className="system-sm-semibold whitespace-nowrap text-text-secondary">{t(roleI18nKeyMap[role].label, { ns: 'common' })}</div>
<div className="system-xs-regular whitespace-nowrap text-text-tertiary">{t(roleI18nKeyMap[role].tip, { ns: 'common' })}</div>
<div className="whitespace-nowrap text-text-secondary system-sm-semibold">{t(roleI18nKeyMap[role].label, { ns: 'common' })}</div>
<div className="whitespace-nowrap text-text-tertiary system-xs-regular">{t(roleI18nKeyMap[role].tip, { ns: 'common' })}</div>
</div>
</div>
))
@@ -125,8 +125,8 @@ const Operation = ({
<div className="flex cursor-pointer rounded-lg px-3 py-2 hover:bg-state-base-hover" onClick={handleDeleteMemberOrCancelInvitation}>
<div className="mr-1 mt-[2px] h-4 w-4 text-text-accent" />
<div>
<div className="system-sm-semibold whitespace-nowrap text-text-secondary">{t('members.removeFromTeam', { ns: 'common' })}</div>
<div className="system-xs-regular whitespace-nowrap text-text-tertiary">{t('members.removeFromTeamTip', { ns: 'common' })}</div>
<div className="whitespace-nowrap text-text-secondary system-sm-semibold">{t('members.removeFromTeam', { ns: 'common' })}</div>
<div className="whitespace-nowrap text-text-tertiary system-xs-regular">{t('members.removeFromTeamTip', { ns: 'common' })}</div>
</div>
</div>
</div>

View File

@@ -40,8 +40,7 @@ describe('MenuDialog', () => {
)
// Assert
const panel = screen.getByRole('dialog').querySelector('.custom-class')
expect(panel).toBeInTheDocument()
expect(screen.getByRole('dialog')).toHaveClass('custom-class')
})
})

View File

@@ -1,7 +1,6 @@
import type { ReactNode } from 'react'
import { Dialog, DialogPanel, Transition, TransitionChild } from '@headlessui/react'
import { noop } from 'es-toolkit/function'
import { Fragment, useCallback, useEffect } from 'react'
import { useCallback } from 'react'
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
import { cn } from '@/utils/classnames'
type DialogProps = {
@@ -19,42 +18,25 @@ const MenuDialog = ({
}: DialogProps) => {
const close = useCallback(() => onClose?.(), [onClose])
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === 'Escape') {
event.preventDefault()
close()
}
}
document.addEventListener('keydown', handleKeyDown)
return () => {
document.removeEventListener('keydown', handleKeyDown)
}
}, [close])
return (
<Transition appear show={show} as={Fragment}>
<Dialog as="div" className="relative z-[60]" onClose={noop}>
<div className="fixed inset-0">
<div className="flex min-h-full flex-col items-center justify-center">
<TransitionChild>
<DialogPanel className={cn(
'relative h-full w-full grow overflow-hidden bg-background-sidenav-bg p-0 text-left align-middle backdrop-blur-md transition-all',
'duration-300 ease-in data-[closed]:scale-95 data-[closed]:opacity-0',
'data-[enter]:scale-100 data-[enter]:opacity-100',
'data-[enter]:scale-95 data-[leave]:opacity-0',
className,
)}
>
<div className="absolute right-0 top-0 h-full w-1/2 bg-components-panel-bg" />
{children}
</DialogPanel>
</TransitionChild>
</div>
</div>
</Dialog>
</Transition>
<Dialog
open={show}
onOpenChange={(open) => {
if (!open)
close()
}}
>
<DialogContent
overlayClassName="bg-transparent"
className={cn(
'left-0 top-0 h-full max-h-none w-full max-w-none translate-x-0 translate-y-0 overflow-hidden rounded-none border-none bg-background-sidenav-bg p-0 shadow-none backdrop-blur-md',
className,
)}
>
<div className="absolute right-0 top-0 h-full w-1/2 bg-components-panel-bg" />
{children}
</DialogContent>
</Dialog>
)
}

View File

@@ -0,0 +1,399 @@
import type { ReactNode } from 'react'
import { act, renderHook } from '@testing-library/react'
import { Provider } from 'jotai'
import { beforeEach, describe, expect, it } from 'vitest'
import {
useExpandModelProviderList,
useModelProviderListExpanded,
useResetModelProviderListExpanded,
useSetModelProviderListExpanded,
} from './atoms'
const createWrapper = () => {
return ({ children }: { children: ReactNode }) => (
<Provider>{children}</Provider>
)
}
describe('atoms', () => {
let wrapper: ReturnType<typeof createWrapper>
beforeEach(() => {
wrapper = createWrapper()
})
// Read hook: returns whether a specific provider is expanded
describe('useModelProviderListExpanded', () => {
it('should return false when provider has not been expanded', () => {
const { result } = renderHook(
() => useModelProviderListExpanded('openai'),
{ wrapper },
)
expect(result.current).toBe(false)
})
it('should return false for any unknown provider name', () => {
const { result } = renderHook(
() => useModelProviderListExpanded('nonexistent-provider'),
{ wrapper },
)
expect(result.current).toBe(false)
})
it('should return true when provider has been expanded via setter', () => {
const { result } = renderHook(
() => ({
expanded: useModelProviderListExpanded('openai'),
setExpanded: useSetModelProviderListExpanded('openai'),
}),
{ wrapper },
)
act(() => {
result.current.setExpanded(true)
})
expect(result.current.expanded).toBe(true)
})
})
// Setter hook: toggles expanded state for a specific provider
describe('useSetModelProviderListExpanded', () => {
it('should expand a provider when called with true', () => {
const { result } = renderHook(
() => ({
expanded: useModelProviderListExpanded('anthropic'),
setExpanded: useSetModelProviderListExpanded('anthropic'),
}),
{ wrapper },
)
act(() => {
result.current.setExpanded(true)
})
expect(result.current.expanded).toBe(true)
})
it('should collapse a provider when called with false', () => {
const { result } = renderHook(
() => ({
expanded: useModelProviderListExpanded('anthropic'),
setExpanded: useSetModelProviderListExpanded('anthropic'),
}),
{ wrapper },
)
act(() => {
result.current.setExpanded(true)
})
act(() => {
result.current.setExpanded(false)
})
expect(result.current.expanded).toBe(false)
})
it('should not affect other providers when setting one', () => {
const { result } = renderHook(
() => ({
openaiExpanded: useModelProviderListExpanded('openai'),
anthropicExpanded: useModelProviderListExpanded('anthropic'),
setOpenai: useSetModelProviderListExpanded('openai'),
}),
{ wrapper },
)
act(() => {
result.current.setOpenai(true)
})
expect(result.current.openaiExpanded).toBe(true)
expect(result.current.anthropicExpanded).toBe(false)
})
})
// Expand hook: expands any provider by name
describe('useExpandModelProviderList', () => {
it('should expand the specified provider', () => {
const { result } = renderHook(
() => ({
expanded: useModelProviderListExpanded('google'),
expand: useExpandModelProviderList(),
}),
{ wrapper },
)
act(() => {
result.current.expand('google')
})
expect(result.current.expanded).toBe(true)
})
it('should expand multiple providers independently', () => {
const { result } = renderHook(
() => ({
openaiExpanded: useModelProviderListExpanded('openai'),
anthropicExpanded: useModelProviderListExpanded('anthropic'),
expand: useExpandModelProviderList(),
}),
{ wrapper },
)
act(() => {
result.current.expand('openai')
})
act(() => {
result.current.expand('anthropic')
})
expect(result.current.openaiExpanded).toBe(true)
expect(result.current.anthropicExpanded).toBe(true)
})
it('should not collapse already expanded providers when expanding another', () => {
const { result } = renderHook(
() => ({
openaiExpanded: useModelProviderListExpanded('openai'),
anthropicExpanded: useModelProviderListExpanded('anthropic'),
expand: useExpandModelProviderList(),
}),
{ wrapper },
)
act(() => {
result.current.expand('openai')
})
act(() => {
result.current.expand('anthropic')
})
expect(result.current.openaiExpanded).toBe(true)
})
})
// Reset hook: clears all expanded state back to empty
describe('useResetModelProviderListExpanded', () => {
it('should reset all expanded providers to false', () => {
const { result } = renderHook(
() => ({
openaiExpanded: useModelProviderListExpanded('openai'),
anthropicExpanded: useModelProviderListExpanded('anthropic'),
expand: useExpandModelProviderList(),
reset: useResetModelProviderListExpanded(),
}),
{ wrapper },
)
act(() => {
result.current.expand('openai')
})
act(() => {
result.current.expand('anthropic')
})
act(() => {
result.current.reset()
})
expect(result.current.openaiExpanded).toBe(false)
expect(result.current.anthropicExpanded).toBe(false)
})
it('should be safe to call when no providers are expanded', () => {
const { result } = renderHook(
() => ({
expanded: useModelProviderListExpanded('openai'),
reset: useResetModelProviderListExpanded(),
}),
{ wrapper },
)
act(() => {
result.current.reset()
})
expect(result.current.expanded).toBe(false)
})
it('should allow re-expanding providers after reset', () => {
const { result } = renderHook(
() => ({
expanded: useModelProviderListExpanded('openai'),
expand: useExpandModelProviderList(),
reset: useResetModelProviderListExpanded(),
}),
{ wrapper },
)
act(() => {
result.current.expand('openai')
})
act(() => {
result.current.reset()
})
act(() => {
result.current.expand('openai')
})
expect(result.current.expanded).toBe(true)
})
})
// Cross-hook interaction: verify hooks cooperate through the shared atom
describe('Cross-hook interaction', () => {
it('should reflect state set by useSetModelProviderListExpanded in useModelProviderListExpanded', () => {
const { result } = renderHook(
() => ({
expanded: useModelProviderListExpanded('openai'),
setExpanded: useSetModelProviderListExpanded('openai'),
}),
{ wrapper },
)
act(() => {
result.current.setExpanded(true)
})
expect(result.current.expanded).toBe(true)
})
it('should reflect state set by useExpandModelProviderList in useModelProviderListExpanded', () => {
const { result } = renderHook(
() => ({
expanded: useModelProviderListExpanded('anthropic'),
expand: useExpandModelProviderList(),
}),
{ wrapper },
)
act(() => {
result.current.expand('anthropic')
})
expect(result.current.expanded).toBe(true)
})
it('should allow useSetModelProviderListExpanded to collapse a provider expanded by useExpandModelProviderList', () => {
const { result } = renderHook(
() => ({
expanded: useModelProviderListExpanded('openai'),
expand: useExpandModelProviderList(),
setExpanded: useSetModelProviderListExpanded('openai'),
}),
{ wrapper },
)
act(() => {
result.current.expand('openai')
})
expect(result.current.expanded).toBe(true)
act(() => {
result.current.setExpanded(false)
})
expect(result.current.expanded).toBe(false)
})
it('should reset state set by useSetModelProviderListExpanded via useResetModelProviderListExpanded', () => {
const { result } = renderHook(
() => ({
expanded: useModelProviderListExpanded('openai'),
setExpanded: useSetModelProviderListExpanded('openai'),
reset: useResetModelProviderListExpanded(),
}),
{ wrapper },
)
act(() => {
result.current.setExpanded(true)
})
act(() => {
result.current.reset()
})
expect(result.current.expanded).toBe(false)
})
})
// selectAtom granularity: changing one provider should not affect unrelated reads
describe('selectAtom granularity', () => {
it('should not cause unrelated provider reads to change when one provider is toggled', () => {
const { result } = renderHook(
() => ({
openai: useModelProviderListExpanded('openai'),
anthropic: useModelProviderListExpanded('anthropic'),
google: useModelProviderListExpanded('google'),
setOpenai: useSetModelProviderListExpanded('openai'),
}),
{ wrapper },
)
const anthropicBefore = result.current.anthropic
const googleBefore = result.current.google
act(() => {
result.current.setOpenai(true)
})
expect(result.current.openai).toBe(true)
expect(result.current.anthropic).toBe(anthropicBefore)
expect(result.current.google).toBe(googleBefore)
})
it('should keep individual provider states independent across multiple expansions and collapses', () => {
const { result } = renderHook(
() => ({
openai: useModelProviderListExpanded('openai'),
anthropic: useModelProviderListExpanded('anthropic'),
setOpenai: useSetModelProviderListExpanded('openai'),
setAnthropic: useSetModelProviderListExpanded('anthropic'),
}),
{ wrapper },
)
act(() => {
result.current.setOpenai(true)
})
act(() => {
result.current.setAnthropic(true)
})
act(() => {
result.current.setOpenai(false)
})
expect(result.current.openai).toBe(false)
expect(result.current.anthropic).toBe(true)
})
})
// Isolation: separate Provider instances have independent state
describe('Provider isolation', () => {
it('should have independent state across different Provider instances', () => {
const wrapper1 = createWrapper()
const wrapper2 = createWrapper()
const { result: result1 } = renderHook(
() => ({
expanded: useModelProviderListExpanded('openai'),
setExpanded: useSetModelProviderListExpanded('openai'),
}),
{ wrapper: wrapper1 },
)
const { result: result2 } = renderHook(
() => useModelProviderListExpanded('openai'),
{ wrapper: wrapper2 },
)
act(() => {
result1.current.setExpanded(true)
})
expect(result1.current.expanded).toBe(true)
expect(result2.current).toBe(false)
})
})
})

View File

@@ -0,0 +1,35 @@
import { atom, useAtomValue, useSetAtom } from 'jotai'
import { selectAtom } from 'jotai/utils'
import { useCallback, useMemo } from 'react'
const expandedAtom = atom<Record<string, boolean>>({})
export function useModelProviderListExpanded(providerName: string) {
return useAtomValue(
useMemo(
() => selectAtom(expandedAtom, s => !!s[providerName]),
[providerName],
),
)
}
export function useSetModelProviderListExpanded(providerName: string) {
const set = useSetAtom(expandedAtom)
return useCallback(
(expanded: boolean) => set(prev => ({ ...prev, [providerName]: expanded })),
[providerName, set],
)
}
export function useExpandModelProviderList() {
const set = useSetAtom(expandedAtom)
return useCallback(
(providerName: string) => set(prev => ({ ...prev, [providerName]: true })),
[set],
)
}
export function useResetModelProviderListExpanded() {
const set = useSetAtom(expandedAtom)
return useCallback(() => set({}), [set])
}

View File

@@ -9,6 +9,7 @@ import type {
} from './declarations'
import { act, renderHook, waitFor } from '@testing-library/react'
import { useLocale } from '@/context/i18n'
import { consoleQuery } from '@/service/client'
import { fetchDefaultModal, fetchModelList, fetchModelProviderCredentials } from '@/service/common'
import {
ConfigurationMethodEnum,
@@ -23,6 +24,7 @@ import {
useAnthropicBuyQuota,
useCurrentProviderAndModel,
useDefaultModel,
useInvalidateDefaultModel,
useLanguage,
useMarketplaceAllPlugins,
useModelList,
@@ -36,7 +38,6 @@ import {
useUpdateModelList,
useUpdateModelProviders,
} from './hooks'
import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
// Mock dependencies
vi.mock('@tanstack/react-query', () => ({
@@ -78,14 +79,6 @@ vi.mock('@/context/modal-context', () => ({
}),
}))
vi.mock('@/context/event-emitter', () => ({
useEventEmitterContextContext: vi.fn(() => ({
eventEmitter: {
emit: vi.fn(),
},
})),
}))
vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
useMarketplacePlugins: vi.fn(() => ({
plugins: [],
@@ -99,12 +92,16 @@ vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
})),
}))
vi.mock('./atoms', () => ({
useExpandModelProviderList: vi.fn(() => vi.fn()),
}))
const { useQuery, useQueryClient } = await import('@tanstack/react-query')
const { getPayUrl } = await import('@/service/common')
const { useProviderContext } = await import('@/context/provider-context')
const { useModalContextSelector } = await import('@/context/modal-context')
const { useEventEmitterContextContext } = await import('@/context/event-emitter')
const { useMarketplacePlugins, useMarketplacePluginsByCollectionId } = await import('@/app/components/plugins/marketplace/hooks')
const { useExpandModelProviderList } = await import('./atoms')
describe('hooks', () => {
beforeEach(() => {
@@ -913,6 +910,38 @@ describe('hooks', () => {
})
})
describe('useInvalidateDefaultModel', () => {
it('should invalidate default model queries', () => {
const invalidateQueries = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
const { result } = renderHook(() => useInvalidateDefaultModel())
act(() => {
result.current(ModelTypeEnum.textGeneration)
})
expect(invalidateQueries).toHaveBeenCalledWith({
queryKey: ['default-model', ModelTypeEnum.textGeneration],
})
})
it('should handle multiple model types', () => {
const invalidateQueries = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
const { result } = renderHook(() => useInvalidateDefaultModel())
act(() => {
result.current(ModelTypeEnum.textGeneration)
result.current(ModelTypeEnum.textEmbedding)
result.current(ModelTypeEnum.rerank)
})
expect(invalidateQueries).toHaveBeenCalledTimes(3)
})
})
describe('useAnthropicBuyQuota', () => {
beforeEach(() => {
Object.defineProperty(window, 'location', {
@@ -1275,39 +1304,52 @@ describe('hooks', () => {
it('should refresh providers and model lists', () => {
const invalidateQueries = vi.fn()
const emit = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
; (useEventEmitterContextContext as Mock).mockReturnValue({
eventEmitter: { emit },
})
const provider = createMockProvider()
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
const { result } = renderHook(() => useRefreshModel())
act(() => {
result.current.handleRefreshModel(provider)
})
expect(invalidateQueries).toHaveBeenCalledWith({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'none',
})
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-providers'] })
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textGeneration] })
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textEmbedding] })
})
it('should emit event when refreshModelList is true and custom config is active', () => {
it('should expand target provider list when refreshModelList is true and custom config is active', () => {
const invalidateQueries = vi.fn()
const emit = vi.fn()
const expandModelProviderList = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
; (useEventEmitterContextContext as Mock).mockReturnValue({
eventEmitter: { emit },
})
; (useExpandModelProviderList as Mock).mockReturnValue(expandModelProviderList)
const provider = createMockProvider()
const customFields: CustomConfigurationModelFixedFields = {
__model_name: 'gpt-4',
__model_type: ModelTypeEnum.textGeneration,
}
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
const { result } = renderHook(() => useRefreshModel())
@@ -1315,23 +1357,30 @@ describe('hooks', () => {
result.current.handleRefreshModel(provider, customFields, true)
})
expect(emit).toHaveBeenCalledWith({
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
payload: 'openai',
expect(expandModelProviderList).toHaveBeenCalledWith('openai')
expect(invalidateQueries).toHaveBeenCalledWith({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'active',
})
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textGeneration] })
})
it('should not emit event when custom config is not active', () => {
it('should not expand provider list when custom config is not active', () => {
const invalidateQueries = vi.fn()
const emit = vi.fn()
const expandModelProviderList = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
; (useEventEmitterContextContext as Mock).mockReturnValue({
eventEmitter: { emit },
})
; (useExpandModelProviderList as Mock).mockReturnValue(expandModelProviderList)
const provider = { ...createMockProvider(), custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure } }
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
const { result } = renderHook(() => useRefreshModel())
@@ -1339,17 +1388,43 @@ describe('hooks', () => {
result.current.handleRefreshModel(provider, undefined, true)
})
expect(emit).not.toHaveBeenCalled()
expect(expandModelProviderList).not.toHaveBeenCalled()
expect(invalidateQueries).not.toHaveBeenCalledWith({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'active',
})
})
it('should emit event and invalidate all supported model types when __model_type is undefined', () => {
it('should refetch active model provider list when custom refresh callback is absent', () => {
const invalidateQueries = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
const provider = createMockProvider()
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
const { result } = renderHook(() => useRefreshModel())
act(() => {
result.current.handleRefreshModel(provider, undefined, true)
})
expect(invalidateQueries).toHaveBeenCalledWith({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'active',
})
})
it('should invalidate all supported model types when __model_type is undefined', () => {
const invalidateQueries = vi.fn()
const emit = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
; (useEventEmitterContextContext as Mock).mockReturnValue({
eventEmitter: { emit },
})
const provider = createMockProvider()
const customFields = { __model_name: 'my-model', __model_type: undefined } as unknown as CustomConfigurationModelFixedFields
@@ -1360,11 +1435,7 @@ describe('hooks', () => {
result.current.handleRefreshModel(provider, customFields, true)
})
expect(emit).toHaveBeenCalledWith({
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
payload: 'openai',
})
// When __model_type is undefined, all supported model types are invalidated
// When __model_type is undefined, all supported model types are invalidated.
const modelListCalls = invalidateQueries.mock.calls.filter(
call => call[0]?.queryKey?.[0] === 'model-list',
)
@@ -1375,9 +1446,6 @@ describe('hooks', () => {
const invalidateQueries = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
; (useEventEmitterContextContext as Mock).mockReturnValue({
eventEmitter: { emit: vi.fn() },
})
const provider = {
...createMockProvider(),

View File

@@ -21,10 +21,10 @@ import {
useMarketplacePluginsByCollectionId,
} from '@/app/components/plugins/marketplace/hooks'
import { PluginCategoryEnum } from '@/app/components/plugins/types'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useLocale } from '@/context/i18n'
import { useModalContextSelector } from '@/context/modal-context'
import { useProviderContext } from '@/context/provider-context'
import { consoleQuery } from '@/service/client'
import {
fetchDefaultModal,
fetchModelList,
@@ -32,12 +32,12 @@ import {
getPayUrl,
} from '@/service/common'
import { commonQueryKeys } from '@/service/use-common'
import { useExpandModelProviderList } from './atoms'
import {
ConfigurationMethodEnum,
CustomConfigurationStatusEnum,
ModelStatusEnum,
} from './declarations'
import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
type UseDefaultModelAndModelList = (
defaultModel: DefaultModelResponse | undefined,
@@ -57,15 +57,21 @@ export const useSystemDefaultModelAndModelList: UseDefaultModelAndModelList = (
return currentDefaultModel
}, [defaultModel, modelList])
const currentDefaultModelKey = currentDefaultModel
? `${currentDefaultModel.provider}:${currentDefaultModel.model}`
: ''
const [defaultModelState, setDefaultModelState] = useState<DefaultModel | undefined>(currentDefaultModel)
const handleDefaultModelChange = useCallback((model: DefaultModel) => {
setDefaultModelState(model)
}, [])
useEffect(() => {
setDefaultModelState(currentDefaultModel)
}, [currentDefaultModel])
const [defaultModelSourceKey, setDefaultModelSourceKey] = useState(currentDefaultModelKey)
const selectedDefaultModel = defaultModelSourceKey === currentDefaultModelKey
? defaultModelState
: currentDefaultModel
return [defaultModelState, handleDefaultModelChange]
const handleDefaultModelChange = useCallback((model: DefaultModel) => {
setDefaultModelSourceKey(currentDefaultModelKey)
setDefaultModelState(model)
}, [currentDefaultModelKey])
return [selectedDefaultModel, handleDefaultModelChange]
}
export const useLanguage = () => {
@@ -116,7 +122,7 @@ export const useProviderCredentialsAndLoadBalancing = (
predefinedFormSchemasValue?.credentials,
])
const mutate = useMemo(() => () => {
const mutate = useCallback(() => {
if (predefinedEnabled)
queryClient.invalidateQueries({ queryKey: ['model-providers', 'credentials', provider, credentialId] })
if (customEnabled)
@@ -222,6 +228,14 @@ export const useUpdateModelList = () => {
return updateModelList
}
export const useInvalidateDefaultModel = () => {
const queryClient = useQueryClient()
return useCallback((type: ModelTypeEnum) => {
queryClient.invalidateQueries({ queryKey: commonQueryKeys.defaultModel(type) })
}, [queryClient])
}
export const useAnthropicBuyQuota = () => {
const [loading, setLoading] = useState(false)
@@ -314,7 +328,8 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
}
export const useRefreshModel = () => {
const { eventEmitter } = useEventEmitterContextContext()
const expandModelProviderList = useExpandModelProviderList()
const queryClient = useQueryClient()
const updateModelProviders = useUpdateModelProviders()
const updateModelList = useUpdateModelList()
const handleRefreshModel = useCallback((
@@ -322,6 +337,19 @@ export const useRefreshModel = () => {
CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
refreshModelList?: boolean,
) => {
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
queryClient.invalidateQueries({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'none',
})
updateModelProviders()
provider.supported_model_types.forEach((type) => {
@@ -329,15 +357,17 @@ export const useRefreshModel = () => {
})
if (refreshModelList && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) {
eventEmitter?.emit({
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
payload: provider.provider,
} as any)
expandModelProviderList(provider.provider)
queryClient.invalidateQueries({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'active',
})
if (CustomConfigurationModelFixedFields?.__model_type)
updateModelList(CustomConfigurationModelFixedFields.__model_type)
}
}, [eventEmitter, updateModelList, updateModelProviders])
}, [expandModelProviderList, queryClient, updateModelList, updateModelProviders])
return {
handleRefreshModel,

View File

@@ -7,16 +7,7 @@ import {
} from './declarations'
import ModelProviderPage from './index'
vi.mock('@/context/app-context', () => ({
useAppContext: () => ({
mutateCurrentWorkspace: vi.fn(),
isValidatingCurrentWorkspace: false,
}),
}))
const mockGlobalState = {
systemFeatures: { enable_marketplace: true },
}
let mockEnableMarketplace = true
const mockQuotaConfig = {
quota_type: CurrentSystemQuotaTypeEnum.free,
@@ -28,7 +19,11 @@ const mockQuotaConfig = {
}
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: (selector: (s: { systemFeatures: { enable_marketplace: boolean } }) => unknown) => selector(mockGlobalState),
useSystemFeaturesQuery: () => ({
data: {
enable_marketplace: mockEnableMarketplace,
},
}),
}))
const mockProviders = [
@@ -60,21 +55,16 @@ vi.mock('@/context/provider-context', () => ({
}),
}))
type MockDefaultModelData = {
model: string
provider?: { provider: string }
} | null
const mockDefaultModelState: {
data: MockDefaultModelData
isLoading: boolean
} = {
data: null,
isLoading: false,
const mockDefaultModels: Record<string, { data: unknown, isLoading: boolean }> = {
'llm': { data: null, isLoading: false },
'text-embedding': { data: null, isLoading: false },
'rerank': { data: null, isLoading: false },
'speech2text': { data: null, isLoading: false },
'tts': { data: null, isLoading: false },
}
vi.mock('./hooks', () => ({
useDefaultModel: () => mockDefaultModelState,
useDefaultModel: (type: string) => mockDefaultModels[type] ?? { data: null, isLoading: false },
}))
vi.mock('./install-from-marketplace', () => ({
@@ -93,13 +83,18 @@ vi.mock('./system-model-selector', () => ({
default: () => <div data-testid="system-model-selector" />,
}))
vi.mock('@/service/use-plugins', () => ({
useCheckInstalled: () => ({ data: undefined }),
}))
describe('ModelProviderPage', () => {
beforeEach(() => {
vi.useFakeTimers()
vi.clearAllMocks()
mockGlobalState.systemFeatures.enable_marketplace = true
mockDefaultModelState.data = null
mockDefaultModelState.isLoading = false
mockEnableMarketplace = true
Object.keys(mockDefaultModels).forEach((key) => {
mockDefaultModels[key] = { data: null, isLoading: false }
})
mockProviders.splice(0, mockProviders.length, {
provider: 'openai',
label: { en_US: 'OpenAI' },
@@ -157,13 +152,76 @@ describe('ModelProviderPage', () => {
})
it('should hide marketplace section when marketplace feature is disabled', () => {
mockGlobalState.systemFeatures.enable_marketplace = false
mockEnableMarketplace = false
render(<ModelProviderPage searchText="" />)
expect(screen.queryByTestId('install-from-marketplace')).not.toBeInTheDocument()
})
describe('system model config status', () => {
it('should not show top warning when no configured providers exist (empty state card handles it)', () => {
mockProviders.splice(0, mockProviders.length, {
provider: 'anthropic',
label: { en_US: 'Anthropic' },
custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure },
system_configuration: {
enabled: false,
current_quota_type: CurrentSystemQuotaTypeEnum.free,
quota_configurations: [mockQuotaConfig],
},
})
render(<ModelProviderPage searchText="" />)
expect(screen.queryByText('common.modelProvider.noneConfigured')).not.toBeInTheDocument()
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
expect(screen.getByText('common.modelProvider.emptyProviderTitle')).toBeInTheDocument()
})
it('should show none-configured warning when providers exist but no default models set', () => {
render(<ModelProviderPage searchText="" />)
expect(screen.getByText('common.modelProvider.noneConfigured')).toBeInTheDocument()
})
it('should show partially-configured warning when some default models are set', () => {
mockDefaultModels.llm = {
data: { model: 'gpt-4', model_type: 'llm', provider: { provider: 'openai', icon_small: { en_US: '' } } },
isLoading: false,
}
render(<ModelProviderPage searchText="" />)
expect(screen.getByText('common.modelProvider.notConfigured')).toBeInTheDocument()
})
it('should not show warning when all default models are configured', () => {
const makeModel = (model: string, type: string) => ({
data: { model, model_type: type, provider: { provider: 'openai', icon_small: { en_US: '' } } },
isLoading: false,
})
mockDefaultModels.llm = makeModel('gpt-4', 'llm')
mockDefaultModels['text-embedding'] = makeModel('text-embedding-3', 'text-embedding')
mockDefaultModels.rerank = makeModel('rerank-v3', 'rerank')
mockDefaultModels.speech2text = makeModel('whisper-1', 'speech2text')
mockDefaultModels.tts = makeModel('tts-1', 'tts')
render(<ModelProviderPage searchText="" />)
expect(screen.queryByText('common.modelProvider.noProviderInstalled')).not.toBeInTheDocument()
expect(screen.queryByText('common.modelProvider.noneConfigured')).not.toBeInTheDocument()
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
})
it('should not show warning while loading', () => {
Object.keys(mockDefaultModels).forEach((key) => {
mockDefaultModels[key] = { data: null, isLoading: true }
})
render(<ModelProviderPage searchText="" />)
expect(screen.queryByText('common.modelProvider.noProviderInstalled')).not.toBeInTheDocument()
expect(screen.queryByText('common.modelProvider.noneConfigured')).not.toBeInTheDocument()
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
})
})
it('should prioritize fixed providers in visible order', () => {
mockProviders.splice(0, mockProviders.length, {
provider: 'zeta-provider',
@@ -204,129 +262,4 @@ describe('ModelProviderPage', () => {
])
expect(screen.queryByText('common.modelProvider.toBeConfigured')).not.toBeInTheDocument()
})
it('should show not configured alert when all default models are absent', () => {
mockDefaultModelState.data = null
mockDefaultModelState.isLoading = false
render(<ModelProviderPage searchText="" />)
expect(screen.getByText('common.modelProvider.notConfigured')).toBeInTheDocument()
})
it('should not show not configured alert when default model is loading', () => {
mockDefaultModelState.data = null
mockDefaultModelState.isLoading = true
render(<ModelProviderPage searchText="" />)
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
})
it('should filter providers by label text', () => {
render(<ModelProviderPage searchText="OpenAI" />)
act(() => {
vi.advanceTimersByTime(600)
})
expect(screen.getByText('openai')).toBeInTheDocument()
expect(screen.queryByText('anthropic')).not.toBeInTheDocument()
})
it('should classify system-enabled providers with matching quota as configured', () => {
mockProviders.splice(0, mockProviders.length, {
provider: 'sys-provider',
label: { en_US: 'System Provider' },
custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure },
system_configuration: {
enabled: true,
current_quota_type: CurrentSystemQuotaTypeEnum.free,
quota_configurations: [mockQuotaConfig],
},
})
render(<ModelProviderPage searchText="" />)
expect(screen.getByText('sys-provider')).toBeInTheDocument()
expect(screen.queryByText('common.modelProvider.toBeConfigured')).not.toBeInTheDocument()
})
it('should classify system-enabled provider with no matching quota as not configured', () => {
mockProviders.splice(0, mockProviders.length, {
provider: 'sys-no-quota',
label: { en_US: 'System No Quota' },
custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure },
system_configuration: {
enabled: true,
current_quota_type: CurrentSystemQuotaTypeEnum.free,
quota_configurations: [],
},
})
render(<ModelProviderPage searchText="" />)
expect(screen.getByText('sys-no-quota')).toBeInTheDocument()
expect(screen.getByText('common.modelProvider.toBeConfigured')).toBeInTheDocument()
})
it('should preserve order of two non-fixed providers (sort returns 0)', () => {
mockProviders.splice(0, mockProviders.length, {
provider: 'alpha-provider',
label: { en_US: 'Alpha Provider' },
custom_configuration: { status: CustomConfigurationStatusEnum.active },
system_configuration: {
enabled: false,
current_quota_type: CurrentSystemQuotaTypeEnum.free,
quota_configurations: [mockQuotaConfig],
},
}, {
provider: 'beta-provider',
label: { en_US: 'Beta Provider' },
custom_configuration: { status: CustomConfigurationStatusEnum.active },
system_configuration: {
enabled: false,
current_quota_type: CurrentSystemQuotaTypeEnum.free,
quota_configurations: [mockQuotaConfig],
},
})
render(<ModelProviderPage searchText="" />)
const renderedProviders = screen.getAllByTestId('provider-card').map(item => item.textContent)
expect(renderedProviders).toEqual(['alpha-provider', 'beta-provider'])
})
it('should not show not configured alert when shared default model mock has data', () => {
mockDefaultModelState.data = { model: 'embed-model' }
render(<ModelProviderPage searchText="" />)
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
})
it('should not show not configured alert when rerankDefaultModel has data', () => {
mockDefaultModelState.data = { model: 'rerank-model', provider: { provider: 'cohere' } }
mockDefaultModelState.isLoading = false
render(<ModelProviderPage searchText="" />)
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
})
it('should not show not configured alert when ttsDefaultModel has data', () => {
mockDefaultModelState.data = { model: 'tts-model', provider: { provider: 'openai' } }
mockDefaultModelState.isLoading = false
render(<ModelProviderPage searchText="" />)
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
})
it('should not show not configured alert when speech2textDefaultModel has data', () => {
mockDefaultModelState.data = { model: 'whisper', provider: { provider: 'openai' } }
mockDefaultModelState.isLoading = false
render(<ModelProviderPage searchText="" />)
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
})
})

View File

@@ -1,17 +1,14 @@
import type {
ModelProvider,
} from './declarations'
import {
RiAlertFill,
RiBrainLine,
} from '@remixicon/react'
import type { PluginDetail } from '@/app/components/plugins/types'
import { useDebounce } from 'ahooks'
import { useEffect, useMemo } from 'react'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { IS_CLOUD_EDITION } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useSystemFeaturesQuery } from '@/context/global-public-context'
import { useProviderContext } from '@/context/provider-context'
import { useCheckInstalled } from '@/service/use-plugins'
import { cn } from '@/utils/classnames'
import {
CustomConfigurationStatusEnum,
@@ -24,6 +21,9 @@ import InstallFromMarketplace from './install-from-marketplace'
import ProviderAddedCard from './provider-added-card'
import QuotaPanel from './provider-added-card/quota-panel'
import SystemModelSelector from './system-model-selector'
import { providerToPluginId } from './utils'
type SystemModelConfigStatus = 'no-provider' | 'none-configured' | 'partially-configured' | 'fully-configured'
type Props = {
searchText: string
@@ -34,20 +34,35 @@ const FixedModelProvider = ['langgenius/openai/openai', 'langgenius/anthropic/an
const ModelProviderPage = ({ searchText }: Props) => {
const debouncedSearchText = useDebounce(searchText, { wait: 500 })
const { t } = useTranslation()
const { mutateCurrentWorkspace, isValidatingCurrentWorkspace } = useAppContext()
const { data: textGenerationDefaultModel, isLoading: isTextGenerationDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textGeneration)
const { data: embeddingsDefaultModel, isLoading: isEmbeddingsDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textEmbedding)
const { data: rerankDefaultModel, isLoading: isRerankDefaultModelLoading } = useDefaultModel(ModelTypeEnum.rerank)
const { data: speech2textDefaultModel, isLoading: isSpeech2textDefaultModelLoading } = useDefaultModel(ModelTypeEnum.speech2text)
const { data: ttsDefaultModel, isLoading: isTTSDefaultModelLoading } = useDefaultModel(ModelTypeEnum.tts)
const { modelProviders: providers } = useProviderContext()
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
const { data: systemFeatures } = useSystemFeaturesQuery()
const allPluginIds = useMemo(() => {
return [...new Set(providers.map(p => providerToPluginId(p.provider)).filter(Boolean))]
}, [providers])
const { data: installedPlugins } = useCheckInstalled({
pluginIds: allPluginIds,
enabled: allPluginIds.length > 0,
})
const pluginDetailMap = useMemo(() => {
const map = new Map<string, PluginDetail>()
if (installedPlugins?.plugins) {
for (const plugin of installedPlugins.plugins)
map.set(plugin.plugin_id, plugin)
}
return map
}, [installedPlugins])
const enableMarketplace = systemFeatures?.enable_marketplace ?? false
const isDefaultModelLoading = isTextGenerationDefaultModelLoading
|| isEmbeddingsDefaultModelLoading
|| isRerankDefaultModelLoading
|| isSpeech2textDefaultModelLoading
|| isTTSDefaultModelLoading
const defaultModelNotConfigured = !isDefaultModelLoading && !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel && !ttsDefaultModel
const [configuredProviders, notConfiguredProviders] = useMemo(() => {
const configuredProviders: ModelProvider[] = []
const notConfiguredProviders: ModelProvider[] = []
@@ -79,6 +94,26 @@ const ModelProviderPage = ({ searchText }: Props) => {
return [configuredProviders, notConfiguredProviders]
}, [providers])
const systemModelConfigStatus: SystemModelConfigStatus = useMemo(() => {
const defaultModels = [textGenerationDefaultModel, embeddingsDefaultModel, rerankDefaultModel, speech2textDefaultModel, ttsDefaultModel]
const configuredCount = defaultModels.filter(Boolean).length
if (configuredCount === 0 && configuredProviders.length === 0)
return 'no-provider'
if (configuredCount === 0)
return 'none-configured'
if (configuredCount < defaultModels.length)
return 'partially-configured'
return 'fully-configured'
}, [configuredProviders, textGenerationDefaultModel, embeddingsDefaultModel, rerankDefaultModel, speech2textDefaultModel, ttsDefaultModel])
const warningTextKey
= systemModelConfigStatus === 'none-configured'
? 'modelProvider.noneConfigured'
: systemModelConfigStatus === 'partially-configured'
? 'modelProvider.notConfigured'
: null
const showWarning = !isDefaultModelLoading && !!warningTextKey
const [filteredConfiguredProviders, filteredNotConfiguredProviders] = useMemo(() => {
const filteredConfiguredProviders = configuredProviders.filter(
provider => provider.provider.toLowerCase().includes(debouncedSearchText.toLowerCase())
@@ -92,28 +127,24 @@ const ModelProviderPage = ({ searchText }: Props) => {
return [filteredConfiguredProviders, filteredNotConfiguredProviders]
}, [configuredProviders, debouncedSearchText, notConfiguredProviders])
useEffect(() => {
mutateCurrentWorkspace()
}, [mutateCurrentWorkspace])
return (
<div className="relative -mt-2 pt-1">
<div className={cn('mb-2 flex items-center')}>
<div className="grow text-text-primary system-md-semibold">{t('modelProvider.models', { ns: 'common' })}</div>
<div className={cn(
'relative flex shrink-0 items-center justify-end gap-2 rounded-lg border border-transparent p-px',
defaultModelNotConfigured && 'border-components-panel-border bg-components-panel-bg-blur pl-2 shadow-xs',
showWarning && 'border-components-panel-border bg-components-panel-bg-blur pl-2 shadow-xs',
)}
>
{defaultModelNotConfigured && <div className="absolute bottom-0 left-0 right-0 top-0 opacity-40" style={{ background: 'linear-gradient(92deg, rgba(247, 144, 9, 0.25) 0%, rgba(255, 255, 255, 0.00) 100%)' }} />}
{defaultModelNotConfigured && (
{showWarning && <div className="absolute bottom-0 left-0 right-0 top-0 opacity-40" style={{ background: 'linear-gradient(92deg, rgba(247, 144, 9, 0.25) 0%, rgba(255, 255, 255, 0.00) 100%)' }} />}
{showWarning && (
<div className="flex items-center gap-1 text-text-primary system-xs-medium">
<RiAlertFill className="h-4 w-4 text-text-warning-secondary" />
<span className="max-w-[460px] truncate" title={t('modelProvider.notConfigured', { ns: 'common' })}>{t('modelProvider.notConfigured', { ns: 'common' })}</span>
<span className="i-ri-alert-fill h-4 w-4 text-text-warning-secondary" />
<span className="max-w-[460px] truncate" title={t(warningTextKey, { ns: 'common' })}>{t(warningTextKey, { ns: 'common' })}</span>
</div>
)}
<SystemModelSelector
notConfigured={defaultModelNotConfigured}
notConfigured={showWarning}
textGenerationDefaultModel={textGenerationDefaultModel}
embeddingsDefaultModel={embeddingsDefaultModel}
rerankDefaultModel={rerankDefaultModel}
@@ -123,11 +154,11 @@ const ModelProviderPage = ({ searchText }: Props) => {
/>
</div>
</div>
{IS_CLOUD_EDITION && <QuotaPanel providers={providers} isLoading={isValidatingCurrentWorkspace} />}
{IS_CLOUD_EDITION && <QuotaPanel providers={providers} />}
{!filteredConfiguredProviders?.length && (
<div className="mb-2 rounded-[10px] bg-workflow-process-bg p-4">
<div className="flex h-10 w-10 items-center justify-center rounded-[10px] border-[0.5px] border-components-card-border bg-components-card-bg shadow-lg backdrop-blur">
<RiBrainLine className="h-5 w-5 text-text-primary" />
<span className="i-ri-brain-line h-5 w-5 text-text-primary" />
</div>
<div className="mt-2 text-text-secondary system-sm-medium">{t('modelProvider.emptyProviderTitle', { ns: 'common' })}</div>
<div className="mt-1 text-text-tertiary system-xs-regular">{t('modelProvider.emptyProviderTip', { ns: 'common' })}</div>
@@ -139,6 +170,7 @@ const ModelProviderPage = ({ searchText }: Props) => {
<ProviderAddedCard
key={provider.provider}
provider={provider}
pluginDetail={pluginDetailMap.get(providerToPluginId(provider.provider))}
/>
))}
</div>
@@ -152,13 +184,14 @@ const ModelProviderPage = ({ searchText }: Props) => {
notConfigured
key={provider.provider}
provider={provider}
pluginDetail={pluginDetailMap.get(providerToPluginId(provider.provider))}
/>
))}
</div>
</>
)}
{
enable_marketplace && (
enableMarketplace && (
<InstallFromMarketplace
providers={providers}
searchText={searchText}

View File

@@ -2,10 +2,6 @@ import type {
ModelProvider,
} from './declarations'
import type { Plugin } from '@/app/components/plugins/types'
import {
RiArrowDownSLine,
RiArrowRightUpLine,
} from '@remixicon/react'
import { useTheme } from 'next-themes'
import Link from 'next/link'
import { useCallback, useState } from 'react'
@@ -47,15 +43,15 @@ const InstallFromMarketplace = ({
<div className="mb-2">
<Divider className="!mt-4 h-px" />
<div className="flex items-center justify-between">
<div className="system-md-semibold flex cursor-pointer items-center gap-1 text-text-primary" onClick={() => setCollapse(!collapse)}>
<RiArrowDownSLine className={cn('h-4 w-4', collapse && '-rotate-90')} />
<div className="flex cursor-pointer items-center gap-1 text-text-primary system-md-semibold" onClick={() => setCollapse(!collapse)}>
<span className={cn('i-ri-arrow-down-s-line h-4 w-4', collapse && '-rotate-90')} />
{t('modelProvider.installProvider', { ns: 'common' })}
</div>
<div className="mb-2 flex items-center pt-2">
<span className="system-sm-regular pr-1 text-text-tertiary">{t('modelProvider.discoverMore', { ns: 'common' })}</span>
<Link target="_blank" href={getMarketplaceUrl('', { theme })} className="system-sm-medium inline-flex items-center text-text-accent">
<span className="pr-1 text-text-tertiary system-sm-regular">{t('modelProvider.discoverMore', { ns: 'common' })}</span>
<Link target="_blank" href={getMarketplaceUrl('', { theme })} className="inline-flex items-center text-text-accent system-sm-medium">
{t('marketplace.difyMarketplace', { ns: 'plugin' })}
<RiArrowRightUpLine className="h-4 w-4" />
<span className="i-ri-arrow-right-up-line h-4 w-4" />
</Link>
</div>
</div>

View File

@@ -2,12 +2,6 @@ import type { Credential } from '../../declarations'
import { fireEvent, render, screen } from '@testing-library/react'
import CredentialItem from './credential-item'
vi.mock('@remixicon/react', () => ({
RiCheckLine: () => <div data-testid="check-icon" />,
RiDeleteBinLine: () => <div data-testid="delete-icon" />,
RiEqualizer2Line: () => <div data-testid="edit-icon" />,
}))
vi.mock('@/app/components/header/indicator', () => ({
default: () => <div data-testid="indicator" />,
}))
@@ -61,8 +55,12 @@ describe('CredentialItem', () => {
render(<CredentialItem credential={credential} onEdit={onEdit} onDelete={onDelete} />)
fireEvent.click(screen.getByTestId('edit-icon').closest('button') as HTMLButtonElement)
fireEvent.click(screen.getByTestId('delete-icon').closest('button') as HTMLButtonElement)
const buttons = screen.getAllByRole('button')
const editButton = buttons.find(b => b.querySelector('.i-ri-equalizer-2-line'))!
const deleteButton = buttons.find(b => b.querySelector('.i-ri-delete-bin-line'))!
fireEvent.click(editButton)
fireEvent.click(deleteButton)
expect(onEdit).toHaveBeenCalledWith(credential)
expect(onDelete).toHaveBeenCalledWith(credential)
@@ -81,7 +79,10 @@ describe('CredentialItem', () => {
/>,
)
fireEvent.click(screen.getByTestId('delete-icon').closest('button') as HTMLButtonElement)
const deleteButton = screen.getAllByRole('button')
.find(b => b.querySelector('.i-ri-delete-bin-line'))!
fireEvent.click(deleteButton)
expect(onDelete).not.toHaveBeenCalled()
})
@@ -121,14 +122,16 @@ describe('CredentialItem', () => {
render(<CredentialItem credential={credential} disabled onDelete={onDelete} />)
fireEvent.click(screen.getByTestId('delete-icon').closest('button') as HTMLButtonElement)
const deleteButton = screen.getAllByRole('button')
.find(b => b.querySelector('.i-ri-delete-bin-line'))!
fireEvent.click(deleteButton)
expect(onDelete).not.toHaveBeenCalled()
})
// showSelectedIcon=true: check icon area is always rendered; check icon only appears when IDs match
it('should render check icon area when showSelectedIcon=true and selectedCredentialId matches', () => {
render(
const { container } = render(
<CredentialItem
credential={credential}
showSelectedIcon
@@ -136,7 +139,7 @@ describe('CredentialItem', () => {
/>,
)
expect(screen.getByTestId('check-icon')).toBeInTheDocument()
expect(container.querySelector('.i-ri-check-line')).toBeInTheDocument()
})
it('should not render check icon when showSelectedIcon=true but selectedCredentialId does not match', () => {

View File

@@ -1,9 +1,4 @@
import type { Credential } from '../../declarations'
import {
RiCheckLine,
RiDeleteBinLine,
RiEqualizer2Line,
} from '@remixicon/react'
import {
memo,
useMemo,
@@ -11,7 +6,7 @@ import {
import { useTranslation } from 'react-i18next'
import ActionButton from '@/app/components/base/action-button'
import Badge from '@/app/components/base/badge'
import Tooltip from '@/app/components/base/tooltip'
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
import Indicator from '@/app/components/header/indicator'
import { cn } from '@/utils/classnames'
@@ -56,7 +51,7 @@ const CredentialItem = ({
key={credential.credential_id}
className={cn(
'group flex h-8 items-center rounded-lg p-1 hover:bg-state-base-hover',
(disabled || credential.not_allowed_to_use) && 'cursor-not-allowed opacity-50',
(disabled || credential.not_allowed_to_use) ? 'cursor-not-allowed opacity-50' : onItemClick && 'cursor-pointer',
)}
onClick={() => {
if (disabled || credential.not_allowed_to_use)
@@ -70,7 +65,7 @@ const CredentialItem = ({
<div className="h-4 w-4">
{
selectedCredentialId === credential.credential_id && (
<RiCheckLine className="h-4 w-4 text-text-accent" />
<span className="i-ri-check-line h-4 w-4 text-text-accent" />
)
}
</div>
@@ -78,7 +73,7 @@ const CredentialItem = ({
}
<Indicator className="ml-2 mr-1.5 shrink-0" />
<div
className="system-md-regular truncate text-text-secondary"
className="truncate text-text-secondary system-md-regular"
title={credential.credential_name}
>
{credential.credential_name}
@@ -96,38 +91,50 @@ const CredentialItem = ({
<div className="ml-2 hidden shrink-0 items-center group-hover:flex">
{
!disableEdit && !credential.not_allowed_to_use && (
<Tooltip popupContent={t('operation.edit', { ns: 'common' })}>
<ActionButton
disabled={disabled}
onClick={(e) => {
e.stopPropagation()
onEdit?.(credential)
}}
>
<RiEqualizer2Line className="h-4 w-4 text-text-tertiary" />
</ActionButton>
<Tooltip>
<TooltipTrigger
render={(
<ActionButton
disabled={disabled}
onClick={(e) => {
e.stopPropagation()
onEdit?.(credential)
}}
>
<span className="i-ri-equalizer-2-line h-4 w-4 text-text-tertiary" />
</ActionButton>
)}
/>
<TooltipContent>{t('operation.edit', { ns: 'common' })}</TooltipContent>
</Tooltip>
)
}
{
!disableDelete && (
<Tooltip popupContent={disableDeleteWhenSelected ? disableDeleteTip : t('operation.delete', { ns: 'common' })}>
<ActionButton
className="hover:bg-transparent"
onClick={(e) => {
if (disabled || disableDeleteWhenSelected)
return
e.stopPropagation()
onDelete?.(credential)
}}
>
<RiDeleteBinLine className={cn(
'h-4 w-4 text-text-tertiary',
!disableDeleteWhenSelected && 'hover:text-text-destructive',
disableDeleteWhenSelected && 'opacity-50',
<Tooltip>
<TooltipTrigger
render={(
<ActionButton
className="hover:bg-transparent"
onClick={(e) => {
if (disabled || disableDeleteWhenSelected)
return
e.stopPropagation()
onDelete?.(credential)
}}
>
<span className={cn(
'i-ri-delete-bin-line h-4 w-4 text-text-tertiary',
!disableDeleteWhenSelected && 'hover:text-text-destructive',
disableDeleteWhenSelected && 'opacity-50',
)}
/>
</ActionButton>
)}
/>
</ActionButton>
/>
<TooltipContent>
{disableDeleteWhenSelected ? disableDeleteTip : t('operation.delete', { ns: 'common' })}
</TooltipContent>
</Tooltip>
)
}
@@ -139,8 +146,9 @@ const CredentialItem = ({
if (credential.not_allowed_to_use) {
return (
<Tooltip popupContent={t('auth.customCredentialUnavailable', { ns: 'plugin' })}>
{Item}
<Tooltip>
<TooltipTrigger render={Item} />
<TooltipContent>{t('auth.customCredentialUnavailable', { ns: 'plugin' })}</TooltipContent>
</Tooltip>
)
}

View File

@@ -53,4 +53,14 @@ describe('useCredentialStatus', () => {
expect(result.current.hasCredential).toBe(false)
expect(result.current.available_credentials).toBeUndefined()
})
it('handles undefined provider gracefully', () => {
const { result } = renderHook(() => useCredentialStatus(undefined))
expect(result.current.hasCredential).toBe(false)
expect(result.current.authorized).toBeFalsy()
expect(result.current.authRemoved).toBe(false)
expect(result.current.available_credentials).toBeUndefined()
expect(result.current.current_credential_id).toBeUndefined()
expect(result.current.current_credential_name).toBeUndefined()
})
})

View File

@@ -3,12 +3,12 @@ import type {
} from '../../declarations'
import { useMemo } from 'react'
export const useCredentialStatus = (provider: ModelProvider) => {
export const useCredentialStatus = (provider: ModelProvider | undefined) => {
const {
current_credential_id,
current_credential_name,
available_credentials,
} = provider.custom_configuration
} = provider?.custom_configuration ?? {}
const hasCredential = !!available_credentials?.length
const authorized = current_credential_id && current_credential_name
const authRemoved = hasCredential && !current_credential_id && !current_credential_name

View File

@@ -10,7 +10,7 @@ const ModelBadge: FC<ModelBadgeProps> = ({
children,
}) => {
return (
<div className={cn('system-2xs-medium-uppercase flex h-[18px] cursor-default items-center rounded-[5px] border border-divider-deep px-1 text-text-tertiary', className)}>
<div className={cn('inline-flex h-[18px] shrink-0 items-center justify-center whitespace-nowrap rounded-[5px] border border-divider-deep bg-components-badge-bg-dimm px-[5px] text-text-tertiary system-2xs-medium-uppercase', className)}>
{children}
</div>
)

View File

@@ -1,7 +1,5 @@
import type { ComponentProps } from 'react'
import type { Credential, CredentialFormSchema, CustomModel, ModelProvider } from '../declarations'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import type { Credential, CredentialFormSchema, ModelProvider } from '../declarations'
import { fireEvent, render, screen, waitFor, within } from '@testing-library/react'
import {
ConfigurationMethodEnum,
CurrentSystemQuotaTypeEnum,
@@ -45,6 +43,15 @@ const mockHandlers = vi.hoisted(() => ({
handleActiveCredential: vi.fn(),
}))
type FormResponse = {
isCheckValidated: boolean
values: Record<string, unknown>
}
const mockFormState = vi.hoisted(() => ({
responses: [] as FormResponse[],
setFieldValue: vi.fn(),
}))
vi.mock('../model-auth/hooks', () => ({
useCredentialData: () => ({
isLoading: mockState.isLoading,
@@ -79,6 +86,36 @@ vi.mock('../hooks', () => ({
useLanguage: () => 'en_US',
}))
vi.mock('@/app/components/base/form/form-scenarios/auth', async () => {
const React = await import('react')
const AuthForm = React.forwardRef(({
onChange,
}: {
onChange?: (field: string, value: string) => void
}, ref: React.ForwardedRef<{ getFormValues: () => FormResponse, getForm: () => { setFieldValue: (field: string, value: string) => void } }>) => {
React.useImperativeHandle(ref, () => ({
getFormValues: () => mockFormState.responses.shift() || { isCheckValidated: false, values: {} },
getForm: () => ({ setFieldValue: mockFormState.setFieldValue }),
}))
return (
<div>
<button type="button" onClick={() => onChange?.('__model_name', 'updated-model')}>Model Name Change</button>
</div>
)
})
return { default: AuthForm }
})
vi.mock('../model-auth', () => ({
CredentialSelector: ({ onSelect }: { onSelect: (credential: Credential & { addNewCredential?: boolean }) => void }) => (
<div>
<button type="button" onClick={() => onSelect({ credential_id: 'existing' })}>Choose Existing</button>
<button type="button" onClick={() => onSelect({ credential_id: 'new', addNewCredential: true })}>Add New</button>
</div>
),
}))
const createI18n = (text: string) => ({ en_US: text, zh_Hans: text })
const createProvider = (overrides?: Partial<ModelProvider>): ModelProvider => ({
@@ -121,7 +158,7 @@ const createProvider = (overrides?: Partial<ModelProvider>): ModelProvider => ({
...overrides,
})
const renderModal = (overrides?: Partial<ComponentProps<typeof ModelModal>>) => {
const renderModal = (overrides?: Partial<React.ComponentProps<typeof ModelModal>>) => {
const provider = createProvider()
const props = {
provider,
@@ -131,50 +168,13 @@ const renderModal = (overrides?: Partial<ComponentProps<typeof ModelModal>>) =>
onRemove: vi.fn(),
...overrides,
}
render(<ModelModal {...props} />)
return props
const view = render(<ModelModal {...props} />)
return {
...props,
unmount: view.unmount,
}
}
const mockFormRef1 = {
getFormValues: vi.fn(),
getForm: vi.fn(() => ({ setFieldValue: vi.fn() })),
}
const mockFormRef2 = {
getFormValues: vi.fn(),
getForm: vi.fn(() => ({ setFieldValue: vi.fn() })),
}
vi.mock('@/app/components/base/form/form-scenarios/auth', () => ({
default: React.forwardRef((props: { formSchemas: Record<string, unknown>[], onChange?: (f: string, v: string) => void }, ref: React.ForwardedRef<unknown>) => {
React.useImperativeHandle(ref, () => {
// Return the mock depending on schemas passed (hacky but works for refs)
if (props.formSchemas.length > 0 && props.formSchemas[0].name === '__model_name')
return mockFormRef1
return mockFormRef2
})
return (
<div data-testid="auth-form" onClick={() => props.onChange?.('test-field', 'val')}>
AuthForm Mock (
{props.formSchemas.length}
{' '}
fields)
</div>
)
}),
}))
vi.mock('../model-auth', () => ({
CredentialSelector: ({ onSelect }: { onSelect: (val: unknown) => void }) => (
<button onClick={() => onSelect({ addNewCredential: true })} data-testid="credential-selector">
Select Credential
</button>
),
useAuth: vi.fn(),
useCredentialData: vi.fn(),
useModelFormSchemas: vi.fn(),
}))
describe('ModelModal', () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -187,131 +187,168 @@ describe('ModelModal', () => {
mockState.formValues = {}
mockState.modelNameAndTypeFormSchemas = []
mockState.modelNameAndTypeFormValues = {}
// reset form refs
mockFormRef1.getFormValues.mockReturnValue({ isCheckValidated: true, values: { __model_name: 'test', __model_type: ModelTypeEnum.textGeneration } })
mockFormRef2.getFormValues.mockReturnValue({ isCheckValidated: true, values: { __authorization_name__: 'test_auth', api_key: 'sk-test' } })
mockFormState.responses = []
})
it('should render title and loading state for predefined credential modal', () => {
it('should show title, description, and loading state for predefined models', () => {
mockState.isLoading = true
renderModal()
const predefined = renderModal()
expect(screen.getByText('common.modelProvider.auth.apiKeyModal.title')).toBeInTheDocument()
expect(screen.getByText('common.modelProvider.auth.apiKeyModal.desc')).toBeInTheDocument()
})
expect(screen.getByRole('status')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeDisabled()
it('should render model credential title when mode is configModelCredential', () => {
renderModal({
mode: ModelModalModeEnum.configModelCredential,
model: { model: 'gpt-4', model_type: ModelTypeEnum.textGeneration },
})
predefined.unmount()
const customizable = renderModal({ configurateMethod: ConfigurationMethodEnum.customizableModel })
expect(screen.queryByText('common.modelProvider.auth.apiKeyModal.desc')).not.toBeInTheDocument()
customizable.unmount()
mockState.credentialData = { credentials: {}, available_credentials: [] }
renderModal({ mode: ModelModalModeEnum.configModelCredential, model: { model: 'gpt-4', model_type: ModelTypeEnum.textGeneration } })
expect(screen.getByText('common.modelProvider.auth.addModelCredential')).toBeInTheDocument()
})
it('should render edit credential title when credential exists', () => {
renderModal({
mode: ModelModalModeEnum.configModelCredential,
credential: { credential_id: '1' } as unknown as Credential,
})
expect(screen.getByText('common.modelProvider.auth.editModelCredential')).toBeInTheDocument()
it('should reveal the credential label when adding a new credential', () => {
renderModal({ mode: ModelModalModeEnum.addCustomModelToModelList })
expect(screen.queryByText('common.modelProvider.auth.modelCredential')).not.toBeInTheDocument()
fireEvent.click(screen.getByText('Add New'))
expect(screen.getByText('common.modelProvider.auth.modelCredential')).toBeInTheDocument()
})
it('should change title to Add Model when mode is configCustomModel', () => {
mockState.modelNameAndTypeFormSchemas = [{ variable: '__model_name', type: 'text' } as unknown as CredentialFormSchema]
renderModal({ mode: ModelModalModeEnum.configCustomModel })
expect(screen.getByText('common.modelProvider.auth.addModel')).toBeInTheDocument()
it('should call onCancel when the cancel button is clicked', () => {
const { onCancel } = renderModal()
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
expect(onCancel).toHaveBeenCalledTimes(1)
})
it('should validate and fail save if form is invalid in configCustomModel mode', async () => {
mockState.modelNameAndTypeFormSchemas = [{ variable: '__model_name', type: 'text' } as unknown as CredentialFormSchema]
mockFormRef1.getFormValues.mockReturnValue({ isCheckValidated: false, values: {} })
renderModal({ mode: ModelModalModeEnum.configCustomModel })
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
expect(mockHandlers.handleSaveCredential).not.toHaveBeenCalled()
})
it('should call onCancel when the escape key is pressed', () => {
const { onCancel } = renderModal()
it('should validate and save new credential and model in configCustomModel mode', async () => {
mockState.modelNameAndTypeFormSchemas = [{ variable: '__model_name', type: 'text' } as unknown as CredentialFormSchema]
const props = renderModal({ mode: ModelModalModeEnum.configCustomModel })
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
await waitFor(() => {
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
credential_id: undefined,
credentials: { api_key: 'sk-test' },
name: 'test_auth',
model: 'test',
model_type: ModelTypeEnum.textGeneration,
})
expect(props.onSave).toHaveBeenCalled()
})
})
it('should save credential only in standard configProviderCredential mode', async () => {
const { onSave } = renderModal({ mode: ModelModalModeEnum.configProviderCredential })
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
await waitFor(() => {
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
credential_id: undefined,
credentials: { api_key: 'sk-test' },
name: 'test_auth',
})
expect(onSave).toHaveBeenCalled()
})
})
it('should save active credential and cancel when picking existing credential in addCustomModelToModelList mode', async () => {
renderModal({ mode: ModelModalModeEnum.addCustomModelToModelList, model: { model: 'm1', model_type: ModelTypeEnum.textGeneration } as unknown as CustomModel })
// By default selected is undefined so button clicks form
// Let's not click credential selector, so it evaluates without it. If selectedCredential is undefined, form validation is checked.
mockFormRef2.getFormValues.mockReturnValue({ isCheckValidated: false, values: {} })
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
expect(mockHandlers.handleSaveCredential).not.toHaveBeenCalled()
})
it('should save active credential when picking existing credential in addCustomModelToModelList mode', async () => {
renderModal({ mode: ModelModalModeEnum.addCustomModelToModelList, model: { model: 'm2', model_type: ModelTypeEnum.textGeneration } as unknown as CustomModel })
// Select existing credential (addNewCredential: true simulates new but we can simulate false if we just hack the mocked state in the component, but it's internal.
// The credential selector sets selectedCredential.
fireEvent.click(screen.getByTestId('credential-selector')) // Sets addNewCredential = true internally, so it proceeds to form save
mockFormRef2.getFormValues.mockReturnValue({ isCheckValidated: true, values: { __authorization_name__: 'auth', api: 'key' } })
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
await waitFor(() => {
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
credential_id: undefined,
credentials: { api: 'key' },
name: 'auth',
model: 'm2',
model_type: ModelTypeEnum.textGeneration,
})
})
})
it('should open and confirm deletion of credential', () => {
mockState.credentialData = { credentials: { api_key: '123' }, available_credentials: [] }
mockState.formValues = { api_key: '123' } // To trigger isEditMode = true
const credential = { credential_id: 'c1' } as unknown as Credential
renderModal({ credential })
// Open Delete Confirm
fireEvent.click(screen.getByRole('button', { name: 'common.operation.remove' }))
expect(mockHandlers.openConfirmDelete).toHaveBeenCalledWith(credential, undefined)
// Simulate the dialog appearing and confirming
mockState.deleteCredentialId = 'c1'
renderModal({ credential }) // Re-render logic mock
fireEvent.click(screen.getAllByRole('button', { name: 'common.operation.confirm' })[0])
expect(mockHandlers.handleConfirmDelete).toHaveBeenCalled()
})
it('should bind escape key to cancel', () => {
const props = renderModal()
fireEvent.keyDown(document, { key: 'Escape' })
expect(props.onCancel).toHaveBeenCalled()
expect(onCancel).toHaveBeenCalledTimes(1)
})
it('should confirm deletion when a delete dialog is shown', () => {
mockState.credentialData = { credentials: { api_key: 'secret' }, available_credentials: [] }
mockState.deleteCredentialId = 'delete-id'
const credential: Credential = { credential_id: 'cred-1' }
const { onCancel } = renderModal({ credential })
const alertDialog = screen.getByRole('alertdialog', { hidden: true })
expect(alertDialog).toHaveTextContent('common.modelProvider.confirmDelete')
fireEvent.click(within(alertDialog).getByRole('button', { hidden: true, name: 'common.operation.confirm' }))
expect(mockHandlers.handleConfirmDelete).toHaveBeenCalledTimes(1)
expect(onCancel).toHaveBeenCalledTimes(1)
})
it('should handle save flows for different modal modes', async () => {
mockState.modelNameAndTypeFormSchemas = [{ variable: '__model_name', type: 'text-input' } as unknown as CredentialFormSchema]
mockState.formSchemas = [{ variable: 'api_key', type: 'secret-input' } as unknown as CredentialFormSchema]
mockFormState.responses = [
{ isCheckValidated: true, values: { __model_name: 'custom-model', __model_type: ModelTypeEnum.textGeneration } },
{ isCheckValidated: true, values: { __authorization_name__: 'Auth Name', api_key: 'secret' } },
]
const configCustomModel = renderModal({ mode: ModelModalModeEnum.configCustomModel })
fireEvent.click(screen.getAllByText('Model Name Change')[0])
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
expect(mockFormState.setFieldValue).toHaveBeenCalledWith('__model_name', 'updated-model')
await waitFor(() => {
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
credential_id: undefined,
credentials: { api_key: 'secret' },
name: 'Auth Name',
model: 'custom-model',
model_type: ModelTypeEnum.textGeneration,
})
})
expect(configCustomModel.onSave).toHaveBeenCalledWith({ __authorization_name__: 'Auth Name', api_key: 'secret' })
configCustomModel.unmount()
mockFormState.responses = [{ isCheckValidated: true, values: { __authorization_name__: 'Model Auth', api_key: 'abc' } }]
const model = { model: 'gpt-4', model_type: ModelTypeEnum.textGeneration }
const configModelCredential = renderModal({
mode: ModelModalModeEnum.configModelCredential,
model,
credential: { credential_id: 'cred-123' },
})
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
await waitFor(() => {
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
credential_id: 'cred-123',
credentials: { api_key: 'abc' },
name: 'Model Auth',
model: 'gpt-4',
model_type: ModelTypeEnum.textGeneration,
})
})
expect(configModelCredential.onSave).toHaveBeenCalledWith({ __authorization_name__: 'Model Auth', api_key: 'abc' })
configModelCredential.unmount()
mockFormState.responses = [{ isCheckValidated: true, values: { __authorization_name__: 'Provider Auth', api_key: 'provider-key' } }]
const configProviderCredential = renderModal({ mode: ModelModalModeEnum.configProviderCredential })
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
await waitFor(() => {
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
credential_id: undefined,
credentials: { api_key: 'provider-key' },
name: 'Provider Auth',
})
})
configProviderCredential.unmount()
const addToModelList = renderModal({
mode: ModelModalModeEnum.addCustomModelToModelList,
model,
})
fireEvent.click(screen.getByText('Choose Existing'))
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
expect(mockHandlers.handleActiveCredential).toHaveBeenCalledWith({ credential_id: 'existing' }, model)
expect(addToModelList.onCancel).toHaveBeenCalled()
addToModelList.unmount()
mockFormState.responses = [{ isCheckValidated: true, values: { __authorization_name__: 'New Auth', api_key: 'new-key' } }]
const addToModelListWithNew = renderModal({
mode: ModelModalModeEnum.addCustomModelToModelList,
model,
})
fireEvent.click(screen.getByText('Add New'))
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
await waitFor(() => {
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
credential_id: undefined,
credentials: { api_key: 'new-key' },
name: 'New Auth',
model: 'gpt-4',
model_type: ModelTypeEnum.textGeneration,
})
})
addToModelListWithNew.unmount()
mockFormState.responses = [{ isCheckValidated: false, values: {} }]
const invalidSave = renderModal({ mode: ModelModalModeEnum.configProviderCredential })
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
await waitFor(() => {
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledTimes(4)
})
invalidSave.unmount()
mockState.credentialData = { credentials: { api_key: 'value' }, available_credentials: [] }
mockState.formValues = { api_key: 'value' }
const removable = renderModal({ credential: { credential_id: 'remove-1' } })
fireEvent.click(screen.getByRole('button', { name: 'common.operation.remove' }))
expect(mockHandlers.openConfirmDelete).toHaveBeenCalledWith({ credential_id: 'remove-1' }, undefined)
removable.unmount()
})
})

View File

@@ -9,11 +9,9 @@ import type {
FormRefObject,
FormSchema,
} from '@/app/components/base/form/types'
import { RiCloseLine } from '@remixicon/react'
import {
memo,
useCallback,
useEffect,
useMemo,
useRef,
useState,
@@ -21,15 +19,23 @@ import {
import { useTranslation } from 'react-i18next'
import Badge from '@/app/components/base/badge'
import Button from '@/app/components/base/button'
import Confirm from '@/app/components/base/confirm'
import AuthForm from '@/app/components/base/form/form-scenarios/auth'
import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general'
import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security'
import Loading from '@/app/components/base/loading'
import {
PortalToFollowElem,
PortalToFollowElemContent,
} from '@/app/components/base/portal-to-follow-elem'
AlertDialog,
AlertDialogActions,
AlertDialogCancelButton,
AlertDialogConfirmButton,
AlertDialogContent,
AlertDialogTitle,
} from '@/app/components/base/ui/alert-dialog'
import {
Dialog,
DialogCloseButton,
DialogContent,
} from '@/app/components/base/ui/dialog'
import {
useAuth,
useCredentialData,
@@ -197,7 +203,7 @@ const ModelModal: FC<ModelModalProps> = ({
}
return (
<div className="title-2xl-semi-bold text-text-primary">
<div className="text-text-primary title-2xl-semi-bold">
{label}
</div>
)
@@ -206,7 +212,7 @@ const ModelModal: FC<ModelModalProps> = ({
const modalDesc = useMemo(() => {
if (providerFormSchemaPredefined) {
return (
<div className="system-xs-regular mt-1 text-text-tertiary">
<div className="mt-1 text-text-tertiary system-xs-regular">
{t('modelProvider.auth.apiKeyModal.desc', { ns: 'common' })}
</div>
)
@@ -223,7 +229,7 @@ const ModelModal: FC<ModelModalProps> = ({
className="mr-2 h-4 w-4 shrink-0"
provider={provider}
/>
<div className="system-md-regular mr-1 text-text-secondary">{renderI18nObject(provider.label)}</div>
<div className="mr-1 text-text-secondary system-md-regular">{renderI18nObject(provider.label)}</div>
</div>
)
}
@@ -235,7 +241,7 @@ const ModelModal: FC<ModelModalProps> = ({
provider={provider}
modelName={model.model}
/>
<div className="system-md-regular mr-1 text-text-secondary">{model.model}</div>
<div className="mr-1 text-text-secondary system-md-regular">{model.model}</div>
<Badge>{model.model_type}</Badge>
</div>
)
@@ -275,174 +281,171 @@ const ModelModal: FC<ModelModalProps> = ({
}, [])
const notAllowCustomCredential = provider.allow_custom_token === false
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === 'Escape') {
event.stopPropagation()
onCancel()
}
}
document.addEventListener('keydown', handleKeyDown, true)
return () => {
document.removeEventListener('keydown', handleKeyDown, true)
}
const handleOpenChange = useCallback((open: boolean) => {
if (!open)
onCancel()
}, [onCancel])
const handleConfirmOpenChange = useCallback((open: boolean) => {
if (!open)
closeConfirmDelete()
}, [closeConfirmDelete])
return (
<PortalToFollowElem open>
<PortalToFollowElemContent className="z-[60] h-full w-full">
<div className="fixed inset-0 flex items-center justify-center bg-black/[.25]">
<div className="relative w-[640px] rounded-2xl bg-components-panel-bg shadow-xl">
<div
className="absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center"
onClick={onCancel}
>
<RiCloseLine className="h-4 w-4 text-text-tertiary" />
</div>
<div className="p-6 pb-3">
{modalTitle}
{modalDesc}
{modalModel}
</div>
<div className="max-h-[calc(100vh-320px)] overflow-y-auto px-6 py-3">
{
mode === ModelModalModeEnum.configCustomModel && (
<AuthForm
formSchemas={modelNameAndTypeFormSchemas.map((formSchema) => {
return {
...formSchema,
name: formSchema.variable,
}
}) as FormSchema[]}
defaultValues={modelNameAndTypeFormValues}
inputClassName="justify-start"
ref={formRef1}
onChange={handleModelNameAndTypeChange}
/>
)
}
{
mode === ModelModalModeEnum.addCustomModelToModelList && (
<CredentialSelector
credentials={available_credentials || []}
onSelect={setSelectedCredential}
selectedCredential={selectedCredential}
disabled={isLoading}
notAllowAddNewCredential={notAllowCustomCredential}
/>
)
}
{
showCredentialLabel && (
<div className="system-xs-medium-uppercase mb-3 mt-6 flex items-center text-text-tertiary">
{t('modelProvider.auth.modelCredential', { ns: 'common' })}
<div className="ml-2 h-px grow bg-gradient-to-r from-divider-regular to-background-gradient-mask-transparent" />
</div>
)
}
{
isLoading && (
<div className="mt-3 flex items-center justify-center">
<Loading />
</div>
)
}
{
!isLoading
&& showCredentialForm
&& (
<AuthForm
formSchemas={formSchemas.map((formSchema) => {
return {
...formSchema,
name: formSchema.variable,
showRadioUI: formSchema.type === FormTypeEnum.radio,
}
}) as FormSchema[]}
defaultValues={formValues}
inputClassName="justify-start"
ref={formRef2}
/>
)
}
</div>
<div className="flex justify-between p-6 pt-5">
{
(provider.help && (provider.help.title || provider.help.url))
? (
<a
href={provider.help?.url[language] || provider.help?.url.en_US}
target="_blank"
rel="noopener noreferrer"
className="system-xs-regular mt-2 inline-block align-middle text-text-accent"
onClick={e => !provider.help.url && e.preventDefault()}
>
{provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US}
<LinkExternal02 className="ml-1 mt-[-2px] inline-block h-3 w-3" />
</a>
)
: <div />
}
<div className="ml-2 flex items-center justify-end space-x-2">
{
isEditMode && (
<Button
variant="warning"
onClick={() => openConfirmDelete(credential, model)}
>
{t('operation.remove', { ns: 'common' })}
</Button>
)
}
<Button
onClick={onCancel}
>
{t('operation.cancel', { ns: 'common' })}
</Button>
<Button
variant="primary"
onClick={handleSave}
disabled={isLoading || doingAction}
>
{saveButtonText}
</Button>
</div>
</div>
{
(mode === ModelModalModeEnum.configCustomModel || mode === ModelModalModeEnum.configProviderCredential) && (
<div className="border-t-[0.5px] border-t-divider-regular">
<div className="flex items-center justify-center rounded-b-2xl bg-background-section-burn py-3 text-xs text-text-tertiary">
<Lock01 className="mr-1 h-3 w-3 text-text-tertiary" />
{t('modelProvider.encrypted.front', { ns: 'common' })}
<a
className="mx-1 text-text-accent"
target="_blank"
rel="noopener noreferrer"
href="https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html"
>
PKCS1_OAEP
</a>
{t('modelProvider.encrypted.back', { ns: 'common' })}
</div>
</div>
)
}
</div>
<Dialog open onOpenChange={handleOpenChange}>
<DialogContent
backdropProps={{ forceRender: true }}
className="w-[640px] max-w-[640px] overflow-hidden p-0"
>
<DialogCloseButton className="right-5 top-5 h-8 w-8" />
<div className="p-6 pb-3">
{modalTitle}
{modalDesc}
{modalModel}
</div>
<div className="max-h-[calc(100vh-320px)] overflow-y-auto px-6 py-3">
{
deleteCredentialId && (
<Confirm
isShow
title={t('modelProvider.confirmDelete', { ns: 'common' })}
isDisabled={doingAction}
onCancel={closeConfirmDelete}
onConfirm={handleDeleteCredential}
mode === ModelModalModeEnum.configCustomModel && (
<AuthForm
formSchemas={modelNameAndTypeFormSchemas.map((formSchema) => {
return {
...formSchema,
name: formSchema.variable,
}
}) as FormSchema[]}
defaultValues={modelNameAndTypeFormValues}
inputClassName="justify-start"
ref={formRef1}
onChange={handleModelNameAndTypeChange}
/>
)
}
{
mode === ModelModalModeEnum.addCustomModelToModelList && (
<CredentialSelector
credentials={available_credentials || []}
onSelect={setSelectedCredential}
selectedCredential={selectedCredential}
disabled={isLoading}
notAllowAddNewCredential={notAllowCustomCredential}
/>
)
}
{
showCredentialLabel && (
<div className="mb-3 mt-6 flex items-center text-text-tertiary system-xs-medium-uppercase">
{t('modelProvider.auth.modelCredential', { ns: 'common' })}
<div className="ml-2 h-px grow bg-gradient-to-r from-divider-regular to-background-gradient-mask-transparent" />
</div>
)
}
{
isLoading && (
<div className="mt-3 flex items-center justify-center">
<Loading />
</div>
)
}
{
!isLoading
&& showCredentialForm
&& (
<AuthForm
formSchemas={formSchemas.map((formSchema) => {
return {
...formSchema,
name: formSchema.variable,
showRadioUI: formSchema.type === FormTypeEnum.radio,
}
}) as FormSchema[]}
defaultValues={formValues}
inputClassName="justify-start"
ref={formRef2}
/>
)
}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
<div className="flex justify-between p-6 pt-5">
{
(provider.help && (provider.help.title || provider.help.url))
? (
<a
href={provider.help?.url[language] || provider.help?.url.en_US}
target="_blank"
rel="noopener noreferrer"
className="mt-2 inline-block align-middle text-text-accent system-xs-regular"
onClick={e => !provider.help.url && e.preventDefault()}
>
{provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US}
<LinkExternal02 className="ml-1 mt-[-2px] inline-block h-3 w-3" />
</a>
)
: <div />
}
<div className="ml-2 flex items-center justify-end space-x-2">
{
isEditMode && (
<Button
variant="warning"
onClick={() => openConfirmDelete(credential, model)}
>
{t('operation.remove', { ns: 'common' })}
</Button>
)
}
<Button
onClick={onCancel}
>
{t('operation.cancel', { ns: 'common' })}
</Button>
<Button
variant="primary"
onClick={handleSave}
disabled={isLoading || doingAction}
>
{saveButtonText}
</Button>
</div>
</div>
{
(mode === ModelModalModeEnum.configCustomModel || mode === ModelModalModeEnum.configProviderCredential) && (
<div className="border-t-[0.5px] border-t-divider-regular">
<div className="flex items-center justify-center rounded-b-2xl bg-background-section-burn py-3 text-xs text-text-tertiary">
<Lock01 className="mr-1 h-3 w-3 text-text-tertiary" />
{t('modelProvider.encrypted.front', { ns: 'common' })}
<a
className="mx-1 text-text-accent"
target="_blank"
rel="noopener noreferrer"
href="https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html"
>
PKCS1_OAEP
</a>
{t('modelProvider.encrypted.back', { ns: 'common' })}
</div>
</div>
)
}
</DialogContent>
<AlertDialog open={!!deleteCredentialId} onOpenChange={handleConfirmOpenChange}>
<AlertDialogContent backdropProps={{ forceRender: true }}>
<div className="flex flex-col gap-2 p-6 pb-4">
<AlertDialogTitle className="text-text-primary title-2xl-semi-bold">
{t('modelProvider.confirmDelete', { ns: 'common' })}
</AlertDialogTitle>
</div>
<AlertDialogActions>
<AlertDialogCancelButton>{t('operation.cancel', { ns: 'common' })}</AlertDialogCancelButton>
<AlertDialogConfirmButton
disabled={doingAction}
onClick={handleDeleteCredential}
>
{t('operation.confirm', { ns: 'common' })}
</AlertDialogConfirmButton>
</AlertDialogActions>
</AlertDialogContent>
</AlertDialog>
</Dialog>
)
}

View File

@@ -14,10 +14,10 @@ import { useTranslation } from 'react-i18next'
import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows'
import Loading from '@/app/components/base/loading'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
Popover,
PopoverContent,
PopoverTrigger,
} from '@/app/components/base/ui/popover'
import { PROVIDER_WITH_PRESET_TONE, STOP_PARAMETER_RULE, TONE_LIST } from '@/config'
import { useProviderContext } from '@/context/provider-context'
import { useModelParameterRules } from '@/service/use-common'
@@ -129,117 +129,118 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
}
return (
<PortalToFollowElem
<Popover
open={open}
onOpenChange={setOpen}
placement={isInWorkflow ? 'left' : 'bottom-end'}
offset={4}
onOpenChange={(newOpen) => {
if (readonly)
return
setOpen(newOpen)
}}
>
<div className="relative">
<PortalToFollowElemTrigger
onClick={() => {
if (readonly)
return
setOpen(v => !v)
}}
className="block"
>
{
renderTrigger
? renderTrigger({
open,
disabled,
modelDisabled,
hasDeprecated,
currentProvider,
currentModel,
providerName: provider,
modelId,
})
: (
<Trigger
disabled={disabled}
isInWorkflow={isInWorkflow}
modelDisabled={modelDisabled}
hasDeprecated={hasDeprecated}
currentProvider={currentProvider}
currentModel={currentModel}
providerName={provider}
modelId={modelId}
/>
)
}
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className={cn('z-[60]', portalToFollowElemContentClassName)}>
<div className={cn(popupClassName, 'w-[389px] rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg')}>
<div className={cn('max-h-[420px] overflow-y-auto p-4 pt-3')}>
<div className="relative">
<div className={cn('system-sm-semibold mb-1 flex h-6 items-center text-text-secondary')}>
{t('modelProvider.model', { ns: 'common' }).toLocaleUpperCase()}
</div>
<ModelSelector
defaultModel={(provider || modelId) ? { provider, model: modelId } : undefined}
modelList={activeTextGenerationModelList}
onSelect={handleChangeModel}
/>
</div>
{
!!parameterRules.length && (
<div className="my-3 h-px bg-divider-subtle" />
)
}
{
isLoading && (
<div className="mt-5"><Loading /></div>
)
}
{
!isLoading && !!parameterRules.length && (
<div className="mb-2 flex items-center justify-between">
<div className={cn('system-sm-semibold flex h-6 items-center text-text-secondary')}>{t('modelProvider.parameters', { ns: 'common' })}</div>
{
PROVIDER_WITH_PRESET_TONE.includes(provider) && (
<PresetsParameter onSelect={handleSelectPresetParameter} />
)
}
</div>
)
}
{
!isLoading && !!parameterRules.length && (
[
...parameterRules,
...(isAdvancedMode ? [STOP_PARAMETER_RULE] : []),
].map(parameter => (
<ParameterItem
key={`${modelId}-${parameter.name}`}
parameterRule={parameter}
value={completionParams?.[parameter.name]}
onChange={v => handleParamChange(parameter.name, v)}
onSwitch={(checked, assignValue) => handleSwitch(parameter.name, checked, assignValue)}
<PopoverTrigger
render={(
<div className="block">
{
renderTrigger
? renderTrigger({
open,
disabled,
modelDisabled,
hasDeprecated,
currentProvider,
currentModel,
providerName: provider,
modelId,
})
: (
<Trigger
disabled={disabled}
isInWorkflow={isInWorkflow}
modelDisabled={modelDisabled}
hasDeprecated={hasDeprecated}
currentProvider={currentProvider}
currentModel={currentModel}
providerName={provider}
modelId={modelId}
/>
))
)
}
</div>
{!hideDebugWithMultipleModel && (
<div
className="bg-components-section-burn system-sm-regular flex h-[50px] cursor-pointer items-center justify-between rounded-b-xl border-t border-t-divider-subtle px-4 text-text-accent"
onClick={() => onDebugWithMultipleModelChange?.()}
>
{
debugWithMultipleModel
? t('debugAsSingleModel', { ns: 'appDebug' })
: t('debugAsMultipleModel', { ns: 'appDebug' })
}
<ArrowNarrowLeft className="h-3 w-3 rotate-180" />
</div>
)}
)
}
</div>
</PortalToFollowElemContent>
</div>
</PortalToFollowElem>
)}
/>
<PopoverContent
placement={isInWorkflow ? 'left' : 'bottom-end'}
sideOffset={4}
className={portalToFollowElemContentClassName}
popupClassName={cn(popupClassName, 'w-[389px] rounded-2xl')}
>
<div className="max-h-[420px] overflow-y-auto p-4 pt-3">
<div className="relative">
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-semibold">
{t('modelProvider.model', { ns: 'common' }).toLocaleUpperCase()}
</div>
<ModelSelector
defaultModel={(provider || modelId) ? { provider, model: modelId } : undefined}
modelList={activeTextGenerationModelList}
onSelect={handleChangeModel}
onHide={() => setOpen(false)}
/>
</div>
{
!!parameterRules.length && (
<div className="my-3 h-px bg-divider-subtle" />
)
}
{
isLoading && (
<div className="mt-5"><Loading /></div>
)
}
{
!isLoading && !!parameterRules.length && (
<div className="mb-2 flex items-center justify-between">
<div className="flex h-6 items-center text-text-secondary system-sm-semibold">{t('modelProvider.parameters', { ns: 'common' })}</div>
{
PROVIDER_WITH_PRESET_TONE.includes(provider) && (
<PresetsParameter onSelect={handleSelectPresetParameter} />
)
}
</div>
)
}
{
!isLoading && !!parameterRules.length && (
[
...parameterRules,
...(isAdvancedMode ? [STOP_PARAMETER_RULE] : []),
].map(parameter => (
<ParameterItem
key={`${modelId}-${parameter.name}`}
parameterRule={parameter}
value={completionParams?.[parameter.name]}
onChange={v => handleParamChange(parameter.name, v)}
onSwitch={(checked, assignValue) => handleSwitch(parameter.name, checked, assignValue)}
isInWorkflow={isInWorkflow}
/>
))
)
}
</div>
{!hideDebugWithMultipleModel && (
<div
className="flex h-[50px] cursor-pointer items-center justify-between rounded-b-xl border-t border-t-divider-subtle px-4 text-text-accent system-sm-regular"
onClick={() => onDebugWithMultipleModelChange?.()}
>
{
debugWithMultipleModel
? t('debugAsSingleModel', { ns: 'appDebug' })
: t('debugAsMultipleModel', { ns: 'appDebug' })
}
<ArrowNarrowLeft className="h-3 w-3 rotate-180" />
</div>
)}
</PopoverContent>
</Popover>
)
}

View File

@@ -1,61 +0,0 @@
import { fireEvent, render, screen } from '@testing-library/react'
import DeprecatedModelTrigger from './deprecated-model-trigger'
vi.mock('../model-icon', () => ({
default: ({ modelName }: { modelName: string }) => <span>{modelName}</span>,
}))
const mockUseProviderContext = vi.hoisted(() => vi.fn())
vi.mock('@/context/provider-context', () => ({
useProviderContext: mockUseProviderContext,
}))
describe('DeprecatedModelTrigger', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseProviderContext.mockReturnValue({
modelProviders: [{ provider: 'someone-else' }, { provider: 'openai' }],
})
})
it('should render model name', () => {
render(<DeprecatedModelTrigger modelName="gpt-deprecated" providerName="openai" />)
expect(screen.getAllByText('gpt-deprecated').length).toBeGreaterThan(0)
})
it('should show deprecated tooltip when warn icon is hovered', async () => {
const { container } = render(
<DeprecatedModelTrigger
modelName="gpt-deprecated"
providerName="openai"
showWarnIcon
/>,
)
const tooltipTrigger = container.querySelector('[data-state]') as HTMLElement
fireEvent.mouseEnter(tooltipTrigger)
expect(await screen.findByText('common.modelProvider.deprecated')).toBeInTheDocument()
})
it('should render when provider is not found', () => {
mockUseProviderContext.mockReturnValue({
modelProviders: [{ provider: 'someone-else' }],
})
render(<DeprecatedModelTrigger modelName="gpt-deprecated" providerName="openai" />)
expect(screen.getAllByText('gpt-deprecated').length).toBeGreaterThan(0)
})
it('should not show deprecated tooltip when warn icon is disabled', async () => {
render(
<DeprecatedModelTrigger
modelName="gpt-deprecated"
providerName="openai"
showWarnIcon={false}
/>,
)
expect(screen.queryByText('common.modelProvider.deprecated')).not.toBeInTheDocument()
})
})

View File

@@ -1,54 +0,0 @@
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback'
import Tooltip from '@/app/components/base/tooltip'
import { useProviderContext } from '@/context/provider-context'
import { cn } from '@/utils/classnames'
import ModelIcon from '../model-icon'
type ModelTriggerProps = {
modelName: string
providerName: string
className?: string
showWarnIcon?: boolean
contentClassName?: string
}
const ModelTrigger: FC<ModelTriggerProps> = ({
modelName,
providerName,
className,
showWarnIcon,
contentClassName,
}) => {
const { t } = useTranslation()
const { modelProviders } = useProviderContext()
const currentProvider = modelProviders.find(provider => provider.provider === providerName)
return (
<div
className={cn('group box-content flex h-8 grow cursor-pointer items-center gap-1 rounded-lg bg-components-input-bg-disabled p-[3px] pl-1', className)}
>
<div className={cn('flex w-full items-center', contentClassName)}>
<div className="flex min-w-0 flex-1 items-center gap-1 py-[1px]">
<ModelIcon
className="h-4 w-4"
provider={currentProvider}
modelName={modelName}
/>
<div className="system-sm-regular truncate text-components-input-text-filled">
{modelName}
</div>
</div>
<div className="flex shrink-0 items-center justify-center">
{showWarnIcon && (
<Tooltip popupContent={t('modelProvider.deprecated', { ns: 'common' })}>
<AlertTriangle className="h-4 w-4 text-text-warning-secondary" />
</Tooltip>
)}
</div>
</div>
</div>
)
}
export default ModelTrigger

View File

@@ -1,31 +0,0 @@
import { render, screen } from '@testing-library/react'
import EmptyTrigger from './empty-trigger'
describe('EmptyTrigger', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should render configure model text', () => {
render(<EmptyTrigger open={false} />)
expect(screen.getByText('plugin.detailPanel.configureModel')).toBeInTheDocument()
})
// open=true: hover bg class present
it('should apply hover background class when open is true', () => {
// Act
const { container } = render(<EmptyTrigger open={true} />)
// Assert
expect(container.firstChild).toHaveClass('bg-components-input-bg-hover')
})
// className prop truthy: custom className appears on root
it('should apply custom className when provided', () => {
// Act
const { container } = render(<EmptyTrigger open={false} className="custom-class" />)
// Assert
expect(container.firstChild).toHaveClass('custom-class')
})
})

View File

@@ -1,42 +0,0 @@
import type { FC } from 'react'
import { RiEqualizer2Line } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { CubeOutline } from '@/app/components/base/icons/src/vender/line/shapes'
import { cn } from '@/utils/classnames'
type ModelTriggerProps = {
open: boolean
className?: string
}
const ModelTrigger: FC<ModelTriggerProps> = ({
open,
className,
}) => {
const { t } = useTranslation()
return (
<div
className={cn(
'flex cursor-pointer items-center gap-0.5 rounded-lg bg-components-input-bg-normal p-1 hover:bg-components-input-bg-hover',
open && 'bg-components-input-bg-hover',
className,
)}
>
<div className="flex grow items-center">
<div className="mr-1.5 flex h-4 w-4 items-center justify-center rounded-[5px] border border-dashed border-divider-regular">
<CubeOutline className="h-3 w-3 text-text-quaternary" />
</div>
<div
className="truncate text-[13px] text-text-tertiary"
title="Configure model"
>
{t('detailPanel.configureModel', { ns: 'plugin' })}
</div>
</div>
<div className="flex h-4 w-4 shrink-0 items-center justify-center">
<RiEqualizer2Line className="h-3.5 w-3.5 text-text-tertiary" />
</div>
</div>
)
}
export default ModelTrigger

View File

@@ -7,15 +7,13 @@ import type {
} from '../declarations'
import { useState } from 'react'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
Popover,
PopoverContent,
PopoverTrigger,
} from '@/app/components/base/ui/popover'
import { cn } from '@/utils/classnames'
import { useCurrentProviderAndModel } from '../hooks'
import DeprecatedModelTrigger from './deprecated-model-trigger'
import EmptyTrigger from './empty-trigger'
import ModelTrigger from './model-trigger'
import ModelSelectorTrigger from './model-selector-trigger'
import Popup from './popup'
type ModelSelectorProps = {
@@ -24,6 +22,7 @@ type ModelSelectorProps = {
triggerClassName?: string
popupClassName?: string
onSelect?: (model: DefaultModel) => void
onHide?: () => void
readonly?: boolean
scopeFeatures?: ModelFeatureEnum[]
deprecatedClassName?: string
@@ -35,10 +34,11 @@ const ModelSelector: FC<ModelSelectorProps> = ({
triggerClassName,
popupClassName,
onSelect,
onHide,
readonly,
scopeFeatures = [],
deprecatedClassName,
showDeprecatedWarnIcon = false,
showDeprecatedWarnIcon = true,
}) => {
const [open, setOpen] = useState(false)
const {
@@ -56,67 +56,60 @@ const ModelSelector: FC<ModelSelectorProps> = ({
onSelect({ provider, model: model.model })
}
const handleToggle = () => {
if (readonly)
return
setOpen(v => !v)
}
return (
<PortalToFollowElem
<Popover
open={open}
onOpenChange={setOpen}
placement="bottom-start"
offset={4}
onOpenChange={(newOpen) => {
if (readonly)
return
setOpen(newOpen)
}}
>
<div className={cn('relative')}>
<PortalToFollowElemTrigger
onClick={handleToggle}
className="block"
>
{
currentModel && currentProvider && (
<ModelTrigger
open={open}
provider={currentProvider}
model={currentModel}
className={triggerClassName}
readonly={readonly}
/>
)
}
{
!currentModel && defaultModel && (
<DeprecatedModelTrigger
modelName={defaultModel?.model || ''}
providerName={defaultModel?.provider || ''}
className={triggerClassName}
showWarnIcon={showDeprecatedWarnIcon}
contentClassName={deprecatedClassName}
/>
)
}
{
!defaultModel && (
<EmptyTrigger
open={open}
className={triggerClassName}
/>
)
}
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className={`z-[1002] ${popupClassName}`}>
<Popup
defaultModel={defaultModel}
modelList={modelList}
onSelect={handleSelect}
scopeFeatures={scopeFeatures}
onHide={() => setOpen(false)}
/>
</PortalToFollowElemContent>
</div>
</PortalToFollowElem>
<PopoverTrigger
render={(
<button
type="button"
className="block w-full border-0 bg-transparent p-0 text-left"
disabled={readonly}
>
<ModelSelectorTrigger
currentProvider={currentProvider}
currentModel={currentModel}
defaultModel={defaultModel}
open={open}
readonly={readonly}
className={triggerClassName}
deprecatedClassName={deprecatedClassName}
showDeprecatedWarnIcon={showDeprecatedWarnIcon}
/>
</button>
)}
/>
{/*
* TODO(overlay-migration): temporary layering hack.
* Some callers still render ModelSelector inside legacy high-z modals
* (e.g. code/automatic generators at z-[1000]). Keep this selector above
* them until those call sites are fully migrated to unified base/ui overlays.
*/}
<PopoverContent
placement="bottom-start"
sideOffset={4}
className={cn('z-[1002]', popupClassName)}
popupClassName="overflow-hidden rounded-lg"
popupProps={{ style: { minWidth: '320px', width: 'var(--anchor-width, auto)' } }}
>
<Popup
defaultModel={defaultModel}
modelList={modelList}
onSelect={handleSelect}
scopeFeatures={scopeFeatures}
onHide={() => {
setOpen(false)
onHide?.()
}}
/>
</PopoverContent>
</Popover>
)
}

View File

@@ -0,0 +1,193 @@
import type { Model, ModelItem } from '../declarations'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import {
ConfigurationMethodEnum,
ModelFeatureEnum,
ModelStatusEnum,
ModelTypeEnum,
} from '../declarations'
import ModelSelectorTrigger from './model-selector-trigger'
const mockUseProviderContext = vi.hoisted(() => vi.fn())
vi.mock('@/context/provider-context', () => ({
useProviderContext: mockUseProviderContext,
}))
const createModelItem = (overrides: Partial<ModelItem> = {}): ModelItem => ({
model: 'gpt-4',
label: { en_US: 'GPT-4', zh_Hans: 'GPT-4' },
model_type: ModelTypeEnum.textGeneration,
features: [ModelFeatureEnum.vision],
fetch_from: ConfigurationMethodEnum.predefinedModel,
status: ModelStatusEnum.active,
model_properties: { mode: 'chat', context_size: 4096 },
load_balancing_enabled: false,
...overrides,
})
const createModel = (overrides: Partial<Model> = {}): Model => ({
provider: 'openai',
icon_small: {
en_US: 'https://example.com/openai-light.png',
zh_Hans: 'https://example.com/openai-light.png',
},
icon_small_dark: {
en_US: 'https://example.com/openai-dark.png',
zh_Hans: 'https://example.com/openai-dark.png',
},
label: { en_US: 'OpenAI', zh_Hans: 'OpenAI' },
models: [createModelItem()],
status: ModelStatusEnum.active,
...overrides,
})
describe('ModelSelectorTrigger', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseProviderContext.mockReturnValue({
modelProviders: [createModel()],
})
})
describe('Rendering', () => {
it('should render empty state when no model is selected', () => {
const { container } = render(<ModelSelectorTrigger />)
expect(screen.getByText('plugin.detailPanel.configureModel')).toBeInTheDocument()
expect(container.querySelector('.i-ri-arrow-down-s-line')).toBeInTheDocument()
expect(container.firstElementChild).toHaveClass('bg-components-input-bg-normal')
})
it('should render selected model details when model is active', () => {
const currentProvider = createModel()
const currentModel = createModelItem()
const { container } = render(
<ModelSelectorTrigger
currentProvider={currentProvider}
currentModel={currentModel}
/>,
)
expect(screen.getByText('GPT-4')).toBeInTheDocument()
expect(screen.getByText('CHAT')).toBeInTheDocument()
expect(container.querySelector('.i-ri-arrow-down-s-line')).toBeInTheDocument()
expect(container.firstElementChild).toHaveClass('bg-components-input-bg-normal')
})
it('should render deprecated default model and disabled style when selection is missing', () => {
const { container } = render(
<ModelSelectorTrigger
defaultModel={{ provider: 'openai', model: 'legacy-model' }}
/>,
)
expect(screen.getByText('legacy-model')).toBeInTheDocument()
expect(container.firstElementChild).toHaveClass('bg-components-input-bg-disabled')
expect(container.querySelector('.i-ri-arrow-down-s-line')).not.toBeInTheDocument()
})
})
describe('Props', () => {
it('should apply custom className to root element', () => {
const { container } = render(<ModelSelectorTrigger className="custom-trigger" />)
expect(container.firstElementChild).toHaveClass('custom-trigger')
})
it('should apply open background style when open is true and model is active', () => {
const { container } = render(
<ModelSelectorTrigger
currentProvider={createModel()}
currentModel={createModelItem()}
open
/>,
)
expect(container.firstElementChild).toHaveClass('bg-components-input-bg-hover')
})
it('should hide the expand arrow when readonly is true', () => {
const { container } = render(
<ModelSelectorTrigger
currentProvider={createModel()}
currentModel={createModelItem()}
readonly
/>,
)
expect(container.querySelector('.i-ri-arrow-down-s-line')).not.toBeInTheDocument()
})
})
describe('Status Handling', () => {
it('should show status badge when selected model is not active and not readonly', () => {
render(
<ModelSelectorTrigger
currentProvider={createModel()}
currentModel={createModelItem({ status: ModelStatusEnum.noConfigure })}
/>,
)
expect(screen.getByText('common.modelProvider.selector.configureRequired')).toBeInTheDocument()
})
it('should not show status badge when selected model is readonly', () => {
render(
<ModelSelectorTrigger
currentProvider={createModel()}
currentModel={createModelItem({ status: ModelStatusEnum.noConfigure })}
readonly
/>,
)
expect(screen.queryByText('common.modelProvider.selector.configureRequired')).not.toBeInTheDocument()
})
it('should show incompatible tooltip when hovering no-permission status badge', async () => {
const user = userEvent.setup()
render(
<ModelSelectorTrigger
currentProvider={createModel()}
currentModel={createModelItem({ status: ModelStatusEnum.noPermission })}
/>,
)
await user.hover(screen.getByText('common.modelProvider.selector.incompatible'))
expect(await screen.findByText('common.modelProvider.selector.incompatibleTip')).toBeInTheDocument()
})
})
describe('Edge Cases', () => {
it('should show deprecated tooltip when hovering warn icon', async () => {
const user = userEvent.setup()
const { container } = render(
<ModelSelectorTrigger
defaultModel={{ provider: 'openai', model: 'legacy-model' }}
/>,
)
const warnIcon = container.querySelector('.i-ri-alert-line')
expect(warnIcon).toBeInTheDocument()
await user.hover(warnIcon as HTMLElement)
expect(await screen.findByText('common.modelProvider.deprecated')).toBeInTheDocument()
})
it('should render fallback icon when deprecated provider is not found', () => {
mockUseProviderContext.mockReturnValue({
modelProviders: [],
})
const { container } = render(
<ModelSelectorTrigger
defaultModel={{ provider: 'unknown-provider', model: 'legacy-model' }}
/>,
)
expect(container.querySelector('img[alt="model-icon"]')).not.toBeInTheDocument()
expect(container.querySelector('svg')).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,146 @@
import type { FC } from 'react'
import type {
DefaultModel,
Model,
ModelItem,
} from '../declarations'
import { useTranslation } from 'react-i18next'
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
import { useProviderContext } from '@/context/provider-context'
import { cn } from '@/utils/classnames'
import { ModelStatusEnum } from '../declarations'
import ModelIcon from '../model-icon'
import ModelName from '../model-name'
const STATUS_I18N_KEY: Partial<Record<ModelStatusEnum, string>> = {
[ModelStatusEnum.quotaExceeded]: 'modelProvider.selector.creditsExhausted',
[ModelStatusEnum.noConfigure]: 'modelProvider.selector.configureRequired',
[ModelStatusEnum.noPermission]: 'modelProvider.selector.incompatible',
[ModelStatusEnum.disabled]: 'modelProvider.selector.disabled',
[ModelStatusEnum.credentialRemoved]: 'modelProvider.selector.apiKeyUnavailable',
}
type ModelSelectorTriggerProps = {
currentProvider?: Model
currentModel?: ModelItem
defaultModel?: DefaultModel
open?: boolean
readonly?: boolean
className?: string
deprecatedClassName?: string
showDeprecatedWarnIcon?: boolean
}
const ModelSelectorTrigger: FC<ModelSelectorTriggerProps> = ({
currentProvider,
currentModel,
defaultModel,
open,
readonly,
className,
deprecatedClassName,
showDeprecatedWarnIcon = true,
}) => {
const { t } = useTranslation()
const { modelProviders } = useProviderContext()
const isSelected = !!currentProvider && !!currentModel
const isDeprecated = !isSelected && !!defaultModel
const isEmpty = !isSelected && !defaultModel
const isActive = isSelected && currentModel.status === ModelStatusEnum.active
const isDisabled = isDeprecated || (isSelected && !isActive)
const statusI18nKey = isSelected ? STATUS_I18N_KEY[currentModel.status] : undefined
const deprecatedProvider = isDeprecated
? modelProviders.find(p => p.provider === defaultModel.provider)
: undefined
return (
<div
className={cn(
'group flex h-8 items-center gap-0.5 rounded-lg p-1',
isDisabled
? 'bg-components-input-bg-disabled'
: 'bg-components-input-bg-normal',
!readonly && !isDisabled && 'cursor-pointer hover:bg-components-input-bg-hover',
open && !isDisabled && 'bg-components-input-bg-hover',
className,
)}
>
{isEmpty
? (
<div className="flex h-6 w-6 items-center justify-center">
<div className="flex h-5 w-5 items-center justify-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-background-default-subtle">
<span className="i-ri-brain-2-line h-3.5 w-3.5 text-text-quaternary" />
</div>
</div>
)
: (
<ModelIcon
className="p-0.5"
provider={isSelected ? currentProvider : deprecatedProvider}
modelName={isSelected ? currentModel.model : defaultModel?.model}
/>
)}
<div className={cn('flex grow items-center gap-1 truncate px-1 py-[3px]', isDeprecated && deprecatedClassName)}>
{isSelected && (
<ModelName
className="grow"
modelItem={currentModel}
showMode
showFeatures
/>
)}
{isDeprecated && (
<div className="grow truncate text-components-input-text-filled system-sm-regular">
{defaultModel.model}
</div>
)}
{isEmpty && (
<div className="grow truncate text-[13px] text-text-quaternary">
{t('detailPanel.configureModel', { ns: 'plugin' })}
</div>
)}
{isSelected && !readonly && !isActive && statusI18nKey && (
<Tooltip>
<TooltipTrigger
disabled={currentModel.status !== ModelStatusEnum.noPermission}
render={(
<div className="flex shrink-0 items-center gap-[3px] rounded-md border border-text-warning px-[5px] py-0.5">
<span className="i-ri-alert-fill h-3 w-3 text-text-warning" />
<span className="whitespace-nowrap text-text-warning system-xs-medium">
{t(statusI18nKey as 'modelProvider.selector.creditsExhausted', { ns: 'common' })}
</span>
</div>
)}
/>
<TooltipContent placement="top">
{t('modelProvider.selector.incompatibleTip', { ns: 'common' })}
</TooltipContent>
</Tooltip>
)}
{isDeprecated && showDeprecatedWarnIcon && (
<Tooltip>
<TooltipTrigger render={(
<span className="i-ri-alert-line h-4 w-4 shrink-0 text-text-warning-secondary" />
)}
/>
<TooltipContent placement="top">
{t('modelProvider.deprecated', { ns: 'common' })}
</TooltipContent>
</Tooltip>
)}
{!readonly && (isActive || isEmpty) && (
<span className="i-ri-arrow-down-s-line h-3.5 w-3.5 shrink-0 text-text-tertiary" />
)}
</div>
</div>
)
}
export default ModelSelectorTrigger

View File

@@ -1,91 +0,0 @@
import type { Model, ModelItem } from '../declarations'
import { fireEvent, render, screen } from '@testing-library/react'
import {
ConfigurationMethodEnum,
ModelStatusEnum,
ModelTypeEnum,
} from '../declarations'
import ModelTrigger from './model-trigger'
vi.mock('../hooks', async () => {
const actual = await vi.importActual<typeof import('../hooks')>('../hooks')
return {
...actual,
useLanguage: () => 'en_US',
}
})
vi.mock('../model-icon', () => ({
default: ({ modelName }: { modelName: string }) => <span>{modelName}</span>,
}))
vi.mock('../model-name', () => ({
default: ({ modelItem }: { modelItem: ModelItem }) => <span>{modelItem.label.en_US}</span>,
}))
const makeModelItem = (overrides: Partial<ModelItem> = {}): ModelItem => ({
model: 'gpt-4',
label: { en_US: 'GPT-4', zh_Hans: 'GPT-4' },
model_type: ModelTypeEnum.textGeneration,
fetch_from: ConfigurationMethodEnum.predefinedModel,
status: ModelStatusEnum.active,
model_properties: {},
load_balancing_enabled: false,
...overrides,
})
const makeModel = (overrides: Partial<Model> = {}): Model => ({
provider: 'openai',
icon_small: { en_US: '', zh_Hans: '' },
label: { en_US: 'OpenAI', zh_Hans: 'OpenAI' },
models: [makeModelItem()],
status: ModelStatusEnum.active,
...overrides,
})
describe('ModelTrigger', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should show model name', () => {
render(
<ModelTrigger
open
provider={makeModel()}
model={makeModelItem()}
/>,
)
expect(screen.getByText('GPT-4')).toBeInTheDocument()
})
it('should show status tooltip content when model is not active', async () => {
const { container } = render(
<ModelTrigger
open={false}
provider={makeModel()}
model={makeModelItem({ status: ModelStatusEnum.noConfigure })}
/>,
)
const tooltipTrigger = container.querySelector('[data-state]') as HTMLElement
fireEvent.mouseEnter(tooltipTrigger)
expect(await screen.findByText('No Configure')).toBeInTheDocument()
})
it('should not show status icon when readonly', () => {
render(
<ModelTrigger
open={false}
provider={makeModel()}
model={makeModelItem({ status: ModelStatusEnum.noConfigure })}
readonly
/>,
)
expect(screen.getByText('GPT-4')).toBeInTheDocument()
expect(screen.queryByText('No Configure')).not.toBeInTheDocument()
})
})

View File

@@ -1,78 +0,0 @@
import type { FC } from 'react'
import type {
Model,
ModelItem,
} from '../declarations'
import { RiArrowDownSLine } from '@remixicon/react'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback'
import Tooltip from '@/app/components/base/tooltip'
import { cn } from '@/utils/classnames'
import {
MODEL_STATUS_TEXT,
ModelStatusEnum,
} from '../declarations'
import { useLanguage } from '../hooks'
import ModelIcon from '../model-icon'
import ModelName from '../model-name'
type ModelTriggerProps = {
open: boolean
provider: Model
model: ModelItem
className?: string
readonly?: boolean
}
const ModelTrigger: FC<ModelTriggerProps> = ({
open,
provider,
model,
className,
readonly,
}) => {
const language = useLanguage()
return (
<div
className={cn(
'group flex h-8 items-center gap-0.5 rounded-lg bg-components-input-bg-normal p-1',
!readonly && 'cursor-pointer hover:bg-components-input-bg-hover',
open && 'bg-components-input-bg-hover',
model.status !== ModelStatusEnum.active && 'bg-components-input-bg-disabled hover:bg-components-input-bg-disabled',
className,
)}
>
<ModelIcon
className="p-0.5"
provider={provider}
modelName={model.model}
/>
<div className="flex grow items-center gap-1 truncate px-1 py-[3px]">
<ModelName
className="grow"
modelItem={model}
showMode
showFeatures
/>
{!readonly && (
<div className="flex h-4 w-4 shrink-0 items-center justify-center">
{
model.status !== ModelStatusEnum.active
? (
<Tooltip popupContent={MODEL_STATUS_TEXT[model.status][language]}>
<AlertTriangle className="h-4 w-4 text-text-warning-secondary" />
</Tooltip>
)
: (
<RiArrowDownSLine
className="h-3.5 w-3.5 text-text-tertiary"
/>
)
}
</div>
)}
</div>
</div>
)
}
export default ModelTrigger

Some files were not shown because too many files have changed in this diff Show More