Compare commits

...

117 Commits

Author SHA1 Message Date
yyh
b3c98e417d Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-09 23:47:27 +08:00
Stephen Zhou
a59c54b3e7 ci: update actions version, reuse workflow by composite action (#33177) 2026-03-09 23:44:17 +08:00
yyh
dfe389c017 Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-09 23:42:04 +08:00
yyh
b364b06e51 refactor(model-selector): migrate overlays to Popover/Tooltip and unify trigger component
- Migrate PortalToFollowElem to base-ui Popover in model-selector,
  model-parameter-modal, and plugin-detail-panel model-selector
- Migrate legacy Tooltip to compound Tooltip in popup-item and trigger
- Unify EmptyTrigger, ModelTrigger, DeprecatedModelTrigger into a
  single declarative ModelSelectorTrigger that derives state from props
- Remove showDeprecatedWarnIcon boolean prop anti-pattern; deprecated
  state always renders warn icon as part of component's visual contract
- Remove deprecatedClassName prop; component manages disabled styling
- Replace manual triggerRef width measurement with CSS var(--anchor-width)
- Remove tooltip scroll listener (base-ui auto-tracks anchor position)
- Restore conditional placement for workflow mode in plugin-detail-panel
- Prune stale ESLint suppressions for removed deprecated imports
2026-03-09 23:34:42 +08:00
dependabot[bot]
7737bdc699 chore(deps): bump the npm-dependencies group across 1 directory with 55 updates (#33170)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-09 23:24:48 +08:00
dependabot[bot]
65637fc6b7 chore(deps): bump the npm-dependencies group across 1 directory with 55 updates (#33170)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-09 23:24:36 +08:00
dependabot[bot]
be6f7b8712 chore(deps-dev): bump the eslint-group group in /web with 5 updates (#33168)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-03-09 19:59:44 +08:00
dependabot[bot]
b257e8ed44 chore(deps-dev): bump the storybook group in /web with 7 updates (#33163) 2026-03-09 19:36:00 +08:00
Stephen Zhou
176d3c8c3a ci: ignore ky and tailwind-merge in update (#33167) 2026-03-09 19:21:18 +08:00
CodingOnStar
ce0197b107 fix(provider): handle undefined provider in credential status and panel state 2026-03-09 18:20:02 +08:00
yyh
164cefc65c Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-09 17:41:13 +08:00
Stephen Zhou
c72ac8a434 ci: ignore some major update (#33161) 2026-03-09 17:24:56 +08:00
yyh
f6d80b9fa7 fix(workflow): derive plugin install state in render
Remove useEffect-based sync of _pluginInstallLocked/_dimmed in workflow nodes to avoid render-update loops.\n\nMove plugin-missing checks to pure utilities and use them in checklist.\nOptimize node installation hooks by enabling only relevant queries and narrowing memo dependencies.
2026-03-09 17:18:09 +08:00
rajatagarwal-oss
497feac48e test: unit test case for controllers.console.workspace module (#32181)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
2026-03-09 17:07:40 +08:00
yyh
e845fa7e6a fix(plugin-install): support bundle marketplace dependency shape 2026-03-09 17:07:27 +08:00
rajatagarwal-oss
8906ab8e52 test: unit test cases for console.datasets module (#32179)
Co-authored-by: akashseth-ifp <akash.seth@infocusp.com>
2026-03-09 17:07:13 +08:00
yyh
bab7bd5ecc Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-09 17:03:54 +08:00
yyh
cfb02bceaf feat(workflow): open install bundle from checklist and strict marketplace parsing 2026-03-09 17:03:43 +08:00
yyh
694ca840e1 feat(web): add warning dot indicator on LLM panel field labels synced with checklist
Store checklist items in zustand WorkflowStore so both the checklist UI
and node panels share a single source of truth. The LLM panel reads from
the store to show a Figma-aligned warning dot (absolute-positioned, no
layout shift) on the MODEL field label when the node has checklist warnings.
2026-03-09 16:38:31 +08:00
非法操作
03dcbeafdf fix: stop responding icon not display (#33154)
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-03-09 16:27:45 +08:00
yyh
2d979e2cec fix(web): silence toast for model parameter rules fetch on missing provider
Add silent option to useModelParameterRules API call so uninstalled
provider errors are swallowed instead of surfacing a raw backend toast.
2026-03-09 16:17:09 +08:00
yyh
5cee7cf8ce feat(web): add LLM model plugin check to workflow checklist
Detect uninstalled model plugins for LLM nodes in the checklist and
publish-gate. Migrate ChecklistItem.errorMessage to errorMessages[]
so a single node can surface multiple validation issues at once.

- Extract shared extractPluginId utility for checklist and prompt editor
- Build installed-plugin Set (O(1) lookup) from ProviderContext
- Remove short-circuit between checkValid and variable validation
- Sync the same check into handleCheckBeforePublish
- Adapt node-group, use-last-run, and test assertions
2026-03-09 16:16:16 +08:00
wangxiaolei
bbfa28e8a7 refactor: file saver decouple db engine and ssrf proxy (#33076)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-09 16:09:44 +08:00
Dev Sharma
6c19e75969 test: improve unit tests for controllers.web (#32150)
Co-authored-by: Rajat Agarwal <rajat.agarwal@infocusp.com>
2026-03-09 15:58:34 +08:00
wangxiaolei
9970f4449a refactor: reuse redis connection instead of create new one (#32678)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-09 15:53:21 +08:00
yyh
0c17823c8b fix 2026-03-09 15:38:46 +08:00
yyh
49c6696d08 fix: use css icons 2026-03-09 15:27:08 +08:00
yyh
292c98a8f3 refactor(web): redesign workflow checklist panel with grouped tree view and Popover primitive
Migrate checklist from flat card list using deprecated PortalToFollowElem to
grouped tree view using base-ui Popover. Split into checklist/ directory with
separate components: plugin group with batch install, per-node groups with
sub-items and "Go to fix" hover action, and tree-line SVG indicators.
2026-03-09 15:23:34 +08:00
CodingOnStar
0e0a6ad043 test(web): enhance unit tests for credential and popup components
- Updated tests for CredentialItem to improve delete button interaction and check icon rendering.
- Enhanced PopupItem tests by mocking credential panel state for various scenarios, ensuring accurate rendering based on credit status.
- Adjusted Popup tests to include trial credits mock for better coverage of credit management logic.
- Refactored model list item tests to include wrapper for consistent rendering context.
2026-03-09 14:20:12 +08:00
yyh
456c95adb1 refactor(web): trigger error tooltip on entire variable badge hover 2026-03-09 14:03:52 +08:00
yyh
1abbaf9fd5 feat(web): differentiate invalid variable tooltips by model plugin status
Replace the generic "Invalid variable" message in prompt editor variable
labels with two context-aware messages: one for missing nodes and another
for uninstalled model plugins. Add useLlmModelPluginInstalled hook that
checks LLM node model providers against installed providers via
useProviderContextSelector. Migrate Tooltip usage to base-ui primitives
and replace RiErrorWarningFill with Warning icon in warning color.
2026-03-09 14:02:26 +08:00
CodingOnStar
1a26e1669b refactor(web): streamline PopupItem component for credit management
- Removed unused context and variables related to workspace and custom configuration.
- Simplified credit usage logic by leveraging state management for better clarity and performance.
- Enhanced readability by restructuring the code for determining credit status and API key activity.
2026-03-09 13:10:29 +08:00
CodingOnStar
02444af2e3 feat(web): enhance Popup and CreditsFallbackAlert components for better credit management
- Integrated trial credits check in the Popup component to conditionally display the CreditsExhaustedAlert.
- Updated the CreditsFallbackAlert to show a message only when API keys are unavailable.
- Removed the fallback description from translation files as it is no longer used.
2026-03-09 12:57:41 +08:00
CodingOnStar
56038e3684 feat(web): update credits fallback alert to include new description for no API keys
- Modified the CreditsFallbackAlert component to display a different message based on the presence of API keys.
- Added a new translation key for the fallback description in both English and Chinese JSON files.
2026-03-09 12:34:41 +08:00
CodingOnStar
eb9341e7ec feat(web): integrate CreditsCoin icon in PopupItem for enhanced UI
- Replaced the existing credits coin span with the CreditsCoin component for improved visual consistency.
- Updated imports to include the new CreditsCoin icon component.
2026-03-09 12:28:13 +08:00
CodingOnStar
e40b31b9c4 refactor(web): enhance model selector functionality and improve UI consistency
- Removed unnecessary ESLint suppressions for better code quality.
- Updated the ModelParameterModal and ModelSelector components to ensure consistent class ordering.
- Added onHide prop to ModelSelector for better control over dropdown visibility.
- Introduced useChangeProviderPriority hook to manage provider priority changes more effectively.
- Integrated CreditsExhaustedAlert in the Popup component to handle API key status more gracefully.
2026-03-09 12:24:54 +08:00
yyh
b89ee4807f Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing
# Conflicts:
#	web/app/components/header/account-setting/model-provider-page/index.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/model-modal/index.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/model-selector/popup-item.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/model-selector/popup.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/index.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/quota-panel.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx
2026-03-09 12:12:27 +08:00
yyh
9907cf9e06 Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-08 22:27:42 +08:00
yyh
208a31719f Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-08 01:10:51 +08:00
yyh
3d1ef1f7f5 Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-06 21:45:37 +08:00
CodingOnStar
24b14e2c1a Merge remote-tracking branch 'origin/main' into feat/model-plugins-implementing 2026-03-06 19:00:17 +08:00
CodingOnStar
53f122f717 Merge branch 'feat/model-provider-refactor' into feat/model-plugins-implementing 2026-03-06 17:33:38 +08:00
CodingOnStar
fced2f9e65 refactor: enhance plugin management UI with error handling, improved rendering, and new components 2026-03-06 16:27:26 +08:00
yyh
0c08c4016d Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-06 14:57:48 +08:00
CodingOnStar
ff4e4a8d64 refactor: enhance model trigger component with internationalization support and improved tooltip handling 2026-03-06 14:50:23 +08:00
yyh
948efa129f Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-06 14:47:56 +08:00
CodingOnStar
e371bfd676 refactor: enhance model provider management with new icons, improved UI elements, and marketplace integration 2026-03-06 14:18:29 +08:00
yyh
6d612c0909 test: improve Jotai atom test quality and add model-provider atoms tests
Replace dynamic imports with static imports in marketplace atom tests.
Convert type-only and not-toThrow assertions into proper state-change
verifications. Add comprehensive test suite for model-provider-page
atoms covering all four hooks, cross-hook interaction, selectAtom
granularity, and Provider isolation.
2026-03-05 22:49:09 +08:00
yyh
56e0dc0ae6 trigger ci
Signed-off-by: yyh <yuanyouhuilyz@gmail.com>
2026-03-05 21:22:03 +08:00
yyh
975eca00c3 Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-05 20:25:53 +08:00
yyh
f049bafcc3 refactor: simplify Jotai atoms by removing redundant write-only atoms
Replace 2 write-only derived atoms with primitive atom's built-in
updater functions. The selectAtom on the read side already prevents
unnecessary re-renders, making the manual guard logic redundant.
2026-03-05 20:25:29 +08:00
CodingOnStar
dd9c526447 refactor: update model-selector popup-item to support collapsible items and improve icon color handling 2026-03-05 16:45:37 +08:00
yyh
922dc71e36 fix 2026-03-05 16:17:38 +08:00
yyh
f03ec7f671 Merge branch 'main' into feat/model-provider-refactor 2026-03-05 16:14:36 +08:00
yyh
29f275442d Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor
# Conflicts:
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx
#	web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx
2026-03-05 16:13:40 +08:00
yyh
c9532ffd43 add stories 2026-03-05 15:55:21 +08:00
yyh
840dc33b8b Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-05 15:12:32 +08:00
yyh
cae58a0649 Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-05 15:08:13 +08:00
yyh
1752edc047 refactor(web): optimize model provider re-render and remove useEffect state sync
- Replace useEffect state sync with derived state pattern in useSystemDefaultModelAndModelList
- Use useCallback instead of useMemo for function memoization in useProviderCredentialsAndLoadBalancing
- Add memo() to ProviderAddedCard and CredentialPanel to prevent unnecessary re-renders
- Switch to useProviderContextSelector for precise context subscription in ProviderAddedCard
- Stabilize activate callback ref in useActivateCredential via supportedModelTypes ref
- Add usage priority tooltip with i18n support
2026-03-05 15:07:53 +08:00
yyh
7471c32612 Revert "temp: remove IS_CLOUD_EDITION guard from supportsCredits for local testing"
This reverts commit ab87ac333a.
2026-03-05 14:33:48 +08:00
yyh
2d333bbbe5 refactor(web): extract credential activation into hook and migrate credential-item overlays
Extract credential switching logic from dropdown-content into a dedicated
useActivateCredential hook with optimistic updates and proper data flow
separation. Credential items now stay visible in the popover after clicking
(no auto-close), show cursor-pointer, and disable during activation.

Migrate credential-item from legacy Tooltip and remixicon imports to
base-ui Tooltip and CSS icon classes, pruning stale ESLint suppressions.
2026-03-05 14:22:39 +08:00
yyh
4af6788ce0 fix(web): wrap Header test in Dialog context for base-ui compatibility 2026-03-05 14:20:35 +08:00
yyh
24b072def9 fix: lint 2026-03-05 14:08:20 +08:00
yyh
909c8c3350 Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-05 13:58:51 +08:00
yyh
80e9c8bee0 refactor(web): make account setting fully controlled with action props 2026-03-05 13:39:36 +08:00
yyh
15b7b304d2 refactor(web): migrate model-modal overlays to base-ui Dialog and AlertDialog
Replace legacy PortalToFollowElem and Confirm with Dialog/AlertDialog
primitives. Remove manual ESC handler and backdrop div — now handled
natively by base-ui. Add backdropProps={{ forceRender: true }} for
correct nested overlay rendering.
2026-03-05 13:33:53 +08:00
yyh
61e2672b59 refactor(web): make provider reset event-driven and scope model invalidation
- remove provider-page lifecycle reset effect and handle reset in explicit tab/close actions
- switch account setting tab state to controlled/uncontrolled pattern without sync effect
- use provider-scoped model list queryKey with exact invalidation in credential and model toggle mutations
- update related tests and mocks for new behavior
2026-03-05 13:28:30 +08:00
yyh
5f4ed4c6f6 refactor(web): replace model provider emitter refresh with jotai state
- add atom-based provider expansion state with reset/prune helpers
- remove event-emitter dependency from model provider refresh flow
- invalidate exact provider model-list query key on refresh
- reset expansion state on model provider page mount/unmount
- update and extend tests for external expansion and query invalidation
- update eslint suppressions to match current code
2026-03-05 13:20:58 +08:00
yyh
4a1032c628 fix(web): remove redundant hover text swap on show models button
Merge the two hover-toggling divs into a single always-visible element
and remove the unused showModelsNum i18n key from all locales.
2026-03-05 13:16:04 +08:00
yyh
423c97a47e code style 2026-03-05 13:09:33 +08:00
yyh
a7e3fb2e33 fix(web): use triangle Warning icon instead of circle error icon
Replace i-ri-error-warning-fill (circle exclamation) with the
Warning component (triangle) for api-fallback and credits-fallback
variants to match Figma design.
2026-03-05 13:07:20 +08:00
yyh
ce34937a1c feat(web): add credits-fallback variant for API Key priority with available credits
When API Key is selected but unavailable/unconfigured and credits are
available, the card now shows "AI credits in use" with a warning icon
instead of "API key required". When both credits are exhausted and no
API key exists, it shows "No available usage" (destructive).

New deriveVariant logic for priority=apiKey:
- !exhausted + !authorized → credits-fallback (was api-required-*)
- exhausted + no credential → no-usage (was api-required-add)
- exhausted + named unauthorized → api-unavailable (unchanged)
2026-03-05 13:02:40 +08:00
yyh
ad9ac6978e fix(web): align alert card width with API key section in dropdown
Change mx-1 (4px) to mx-2 (8px) on CreditsFallbackAlert and
CreditsExhaustedAlert to match ApiKeySection's p-2 (8px) padding,
consistent with Figma design where both sections are 8px from the
dropdown edge.
2026-03-05 12:56:55 +08:00
yyh
57c1ba3543 fix(web): hide divider above empty API keys state in dropdown
Move the border from UsagePrioritySection (always visible) to
ApiKeySection's list variant (only when credentials exist). This
removes the unwanted divider line above the "No API Keys" empty
state card when on the AI Credits tab with no keys configured.
2026-03-05 12:25:11 +08:00
yyh
d7a5af2b9a Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor
# Conflicts:
#	web/app/components/header/account-setting/model-provider-page/index.tsx
2026-03-05 10:46:24 +08:00
yyh
d45edffaa3 fix(web): wire upgrade link to pricing modal and add credits-coin icon
Replace broken HTML string interpolation with Trans component and
useModalContextSelector so "upgrade your plan" opens the pricing modal.
Add custom credits-coin SVG icon to replace the generic ri-coin-line.
2026-03-05 10:39:31 +08:00
yyh
530515b6ef fix(web): prevent model list from expanding on priority switch
Remove UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST event emission from
changePriority onSuccess. This event was designed for custom model
add/edit/delete scenarios where the card should expand, but firing
it on priority switch caused ProviderAddedCard to unexpectedly
expand via refreshModelList → setCollapsed(false).
2026-03-05 10:35:03 +08:00
yyh
f13f0d1f9a fix(web): align dropdown alerts with Figma design and fix hardcoded credits total
- Expose totalCredits from useTrialCredits hook instead of hardcoding 10,000
- Align CreditsExhaustedAlert with Figma: dynamic progress bar, correct
  design tokens (components-progress-error-bg/progress), sm-medium/xs-regular
  typography
- Align CreditsFallbackAlert typography to sm-medium/xs-regular
- Fix ApiKeySection empty state: horizontal gradient, sm-medium title,
  Figma-aligned padding (pl-7 for API KEYS label)
- Hoist empty credentials array constant to stabilize memo (rerender-memo-with-default-value)
- Remove redundant useCallback wrapper in ApiKeySection
- Replace nested ternary with Record lookup in TextLabel
- Remove dead || 0 guard in useTrialCredits
- Update all test mocks with totalCredits field
2026-03-05 10:09:51 +08:00
yyh
b597d52c11 refactor(web): remove dialog description from system model selector
Remove the DialogDescription and its i18n key (modelProvider.systemModelSettingsLink)
from the system model settings dialog across all 23 locales.
2026-03-05 10:05:01 +08:00
yyh
34c42fe666 Revert "temp: remove cloud condition"
This reverts commit 29e344ac8b.
2026-03-05 09:44:19 +08:00
yyh
dc109c99f0 test(web): expand credential panel and dropdown test coverage for all 8 card variants
Add comprehensive behavioral tests covering all discriminated union variants,
destructive/default styling, warning icons, CreditsFallbackAlert conditions,
credential CRUD interactions, AlertDialog delete confirmation, and Popover behavior.
2026-03-05 09:41:48 +08:00
yyh
223b9d89c1 refactor(web): migrate priority change to oRPC contract with useMutation
- Add changePreferredProviderType contract in model-providers.ts
- Register in consoleRouterContract
- Replace raw async changeModelProviderPriority with useMutation
- Use Toast.notify (static API) instead of useToastContext hook
- Pass isPending as isChangingPriority to disable buttons during switch
- Add disabled prop to UsagePrioritySection
- Fix pre-existing test assertions for api-unavailable variant
- Update all specs with isChangingPriority prop and oRPC mock pattern
2026-03-05 09:30:38 +08:00
yyh
dd119eb44f fix(web): align UsagePrioritySection with Figma design and fix i18n key ordering
- Single-row layout for icon, label, and option cards
- Icon: arrow-up-double-line matching design spec
- Buttons: flexible width with whitespace-nowrap instead of fixed w-[72px]
- Add min-w-0 + truncate for text overflow, focus-visible ring for a11y
- Sort modelProvider.card.* i18n keys alphabetically
2026-03-05 09:15:16 +08:00
yyh
970493fa85 test(web): update tests for credential panel refactoring and new ModelAuthDropdown components
Rewrite credential-panel.spec.tsx to match the new discriminated union
state model and variant-driven rendering. Add new test files for
useCredentialPanelState hook, SystemQuotaCard Label enhancement,
and all ModelAuthDropdown sub-components.
2026-03-05 08:41:17 +08:00
yyh
ab87ac333a temp: remove IS_CLOUD_EDITION guard from supportsCredits for local testing 2026-03-05 08:34:10 +08:00
yyh
b8b70da9ad refactor(web): rewrite CredentialPanel with declarative variant-driven state and new ModelAuthDropdown
- Extract useCredentialPanelState hook with discriminated union CardVariant type replacing scattered boolean conditions
- Create ModelAuthDropdown compound component (Popover-based) with UsagePrioritySection, CreditsExhaustedAlert, and ApiKeySection
- Enhance SystemQuotaCard.Label to accept className override for flexible styling
- Add i18n keys for new card states and dropdown content (en-US, zh-Hans)
2026-03-05 08:33:04 +08:00
yyh
77d81aebe8 Merge remote-tracking branch 'origin/main' into feat/model-provider-refactor 2026-03-04 23:35:20 +08:00
yyh
deb4cd3ece fix: i18n 2026-03-04 23:35:13 +08:00
yyh
648d9ef1f9 refactor(web): extract SystemQuotaCard compound component and shared useTrialCredits hook
Extract trial credits calculation into a shared useTrialCredits hook to prevent
logic drift between QuotaPanel and CredentialPanel. Add SystemQuotaCard compound
component with explicit default/destructive variants for the system quota UI
state in provider cards, replacing inline conditional styling with composable
Label and Actions slots. Remove unnecessary useMemo for simple derived values.
2026-03-04 23:30:25 +08:00
yyh
5ed4797078 fix 2026-03-04 22:53:29 +08:00
yyh
62631658e9 fix(web): update tests for AlertDialog migration and component API changes
- Replace deprecated Confirm mock with real AlertDialog role-based queries
- Add useInvalidateCheckInstalled mock for QueryClient dependency
- Wrap model-list-item renders in QueryClientProvider
- Migrate PluginVersionPicker from PortalToFollowElem to Popover
- Migrate UpdatePluginModal from Modal to Dialog
- Update version picker offset props (sideOffset/alignOffset)
2026-03-04 22:52:21 +08:00
yyh
22a4100dd7 fix(web): invalidate plugin checkInstalled cache after version updates 2026-03-04 22:33:17 +08:00
yyh
0f7ed6f67e refactor(web): align provider badges with figma and remove dead add-model-button 2026-03-04 22:29:51 +08:00
yyh
4d9fcbec57 refactor(web): migrate remove-plugin dialog to base UI AlertDialog and improve UX
- Replace deprecated Confirm component with AlertDialog primitives
- Add forceRender backdrop for proper overlay rendering
- Add success Toast notification after plugin removal
- Update "View Detail" text to "View on Marketplace" (en/zh-Hans)
- Add i18n keys for delete success message
- Prune stale eslint suppression for header-modals
2026-03-04 22:14:19 +08:00
yyh
4d7a9bc798 fix(web): align model provider cache invalidation with oRPC keys 2026-03-04 22:06:27 +08:00
yyh
d6d04ed657 fix 2026-03-04 22:03:06 +08:00
yyh
f594a71dae fix: icon 2026-03-04 22:02:36 +08:00
yyh
04e0ab7eda refactor(web): migrate provider-added-card model list to oRPC query-driven state 2026-03-04 21:55:34 +08:00
yyh
784bda9c86 refactor(web): migrate operation-dropdown to base UI and align provider card styles with Figma
- Migrate OperationDropdown from legacy portal-to-follow-elem to base UI DropdownMenu primitives
- Add placement, sideOffset, alignOffset, popupClassName props for flexible positioning
- Fix version badge font size: system-2xs-medium-uppercase (10px) → system-xs-medium-uppercase (12px)
- Set provider card dropdown to bottom-start placement with 192px width per Figma spec
- Fix PluginVersionPicker toggle: clicking badge now opens and closes the picker
- Add max-h-[224px] overflow scroll to version list
- Replace Remix icon imports with Tailwind CSS icon classes
- Prune stale eslint suppressions for migrated files
2026-03-04 21:55:23 +08:00
yyh
1af1fb6913 feat(web): add version badge and actions menu to provider cards
Integrate plugin version management into model provider cards by
reusing existing plugin detail panel hooks and components. Batch
query installed plugins at list level to avoid N+1 requests.
2026-03-04 21:29:52 +08:00
yyh
1f0c36e9f7 fix: style 2026-03-04 21:07:42 +08:00
yyh
455ae65025 fix: style 2026-03-04 20:58:14 +08:00
yyh
d44682e957 refactor(web): align quota panel with Figma design and migrate to base UI tooltip
- Rename title from "Quota" to "AI Credits" and update tooltip copy
  (Message Credits → AI Credits, free → Trial)
- Show "Credits exhausted" in destructive text when credits reach zero
  instead of displaying the number "0"
- Migrate from deprecated Tooltip to base UI Tooltip compound component
- Add 4px grid background with radial fade mask via CSS module
- Simplify provider icon tooltip text for uninstalled state
- Update i18n keys for both en-US and zh-Hans
2026-03-04 20:52:30 +08:00
yyh
8c4afc0c18 fix(model-selector): align empty trigger with default trigger style 2026-03-04 20:14:49 +08:00
yyh
539cbcae6a fix(account-settings): render nested system model backdrop via base ui 2026-03-04 19:57:53 +08:00
yyh
8d257fea7c chore(web): commit dialog overlay follow-up changes 2026-03-04 19:37:10 +08:00
yyh
c3364ac350 refactor(web): align account settings dialogs with base UI 2026-03-04 19:31:14 +08:00
yyh
f991644989 refactor(pricing): migrate to base ui dialog and extract category types 2026-03-04 19:26:54 +08:00
yyh
29e344ac8b temp: remove cloud condition 2026-03-04 18:50:38 +08:00
yyh
1ad9305732 fix(web): avoid quota panel flicker on account-setting tab switch
- remove mount-time workspace invalidate in model provider page

- read quota with useCurrentWorkspace and keep loading only for initial empty fetch

- reuse existing useSystemFeaturesQuery for marketplace and trial models

- update model provider and quota panel tests for new query/loading behavior
2026-03-04 18:43:01 +08:00
yyh
17f38f171d lint 2026-03-04 18:21:59 +08:00
yyh
802088c8eb test(web): fix trivial assertion and add useInvalidateDefaultModel tests
Replace the no-provider test assertion from checking a nonexistent i18n
key to verifying actual warning keys are absent. Add unit tests for
useInvalidateDefaultModel following the useUpdateModelList pattern.
2026-03-04 17:51:20 +08:00
yyh
cad6d94491 refactor(web): replace remixicon imports with Tailwind CSS icons in system-model-selector 2026-03-04 17:45:41 +08:00
yyh
621d0fb2c9 fix 2026-03-04 17:42:34 +08:00
yyh
a92fb3244b fix(web): skip top warning for no-provider state and remove unused i18n key
The empty state card below already prompts users to install a provider,
so the top warning bar is redundant for the no-provider case. Remove
the unused noProviderInstalled i18n key and replace the lookup map with
a ternary to preserve i18n literal types without assertions.
2026-03-04 17:39:49 +08:00
yyh
97508f8d7b fix(web): invalidate default model cache after saving system model settings
After saving system models, only the model list cache was invalidated
but not the default model cache, causing stale config status in the UI.
Add useInvalidateDefaultModel hook and call it for all 5 model types
after a successful save.
2026-03-04 17:26:24 +08:00
yyh
70e677a6ac feat(web): refine system model settings to 4 distinct config states
Replace the single `defaultModelNotConfigured` boolean with a derived
`systemModelConfigStatus` that distinguishes between no-provider,
none-configured, partially-configured, and fully-configured states,
each showing a context-appropriate warning message. Also updates the
button label from "System Model Settings" to "Default Model Settings"
and migrates remixicon imports to Tailwind CSS icon classes.
2026-03-04 16:58:46 +08:00
318 changed files with 33653 additions and 7150 deletions

33
.github/actions/setup-web/action.yml vendored Normal file
View File

@@ -0,0 +1,33 @@
name: Setup Web Environment
description: Setup pnpm, Node.js, and install web dependencies.
inputs:
node-version:
description: Node.js version to use
required: false
default: "22"
install-dependencies:
description: Whether to install web dependencies after setting up Node.js
required: false
default: "true"
runs:
using: composite
steps:
- name: Install pnpm
uses: pnpm/action-setup@41ff72655975bd51cab0327fa583b6e92b6d3061 # v4.2.0
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
with:
node-version: ${{ inputs.node-version }}
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
if: ${{ inputs.install-dependencies == 'true' }}
shell: bash
run: pnpm --dir web install --frozen-lockfile

View File

@@ -24,6 +24,18 @@ updates:
schedule:
interval: "weekly"
open-pull-requests-limit: 2
ignore:
- dependency-name: "ky"
- dependency-name: "tailwind-merge"
update-types: ["version-update:semver-major"]
- dependency-name: "tailwindcss"
update-types: ["version-update:semver-major"]
- dependency-name: "react-markdown"
update-types: ["version-update:semver-major"]
- dependency-name: "react-syntax-highlighter"
update-types: ["version-update:semver-major"]
- dependency-name: "react-window"
update-types: ["version-update:semver-major"]
groups:
lexical:
patterns:
@@ -33,6 +45,9 @@ updates:
patterns:
- "storybook"
- "@storybook/*"
eslint-group:
patterns:
- "*eslint*"
npm-dependencies:
patterns:
- "*"
@@ -41,3 +56,4 @@ updates:
- "@lexical/*"
- "storybook"
- "@storybook/*"
- "*eslint*"

View File

@@ -22,12 +22,12 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -51,7 +51,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
uses: hoverkraft-tech/compose-action@v2
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.middleware.yaml

View File

@@ -12,22 +12,22 @@ jobs:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Check Docker Compose inputs
id: docker-compose-changes
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- uses: actions/setup-python@v6
- uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@v7
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'
@@ -84,4 +84,14 @@ jobs:
run: |
uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
- name: Setup web environment
uses: ./.github/actions/setup-web
with:
node-version: "24"
- name: ESLint autofix
run: |
cd web
pnpm eslint --concurrency=2 --prune-suppressions
- uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3

View File

@@ -53,26 +53,26 @@ jobs:
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Login to Docker Hub
uses: docker/login-action@v3
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Extract metadata for Docker
id: meta
uses: docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0
with:
images: ${{ env[matrix.image_name_env] }}
- name: Build Docker image
id: build
uses: docker/build-push-action@v6
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
with:
context: "{{defaultContext}}:${{ matrix.context }}"
platforms: ${{ matrix.platform }}
@@ -91,7 +91,7 @@ jobs:
touch "/tmp/digests/${sanitized_digest}"
- name: Upload digest
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
@@ -113,21 +113,21 @@ jobs:
context: "web"
steps:
- name: Download digests
uses: actions/download-artifact@v7
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
with:
path: /tmp/digests
pattern: digests-${{ matrix.context }}-*
merge-multiple: true
- name: Login to Docker Hub
uses: docker/login-action@v3
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Extract metadata for Docker
id: meta
uses: docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0
with:
images: ${{ env[matrix.image_name_env] }}
tags: |

View File

@@ -13,13 +13,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
python-version: "3.12"
@@ -40,7 +40,7 @@ jobs:
cp middleware.env.example middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@v2.0.2
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
@@ -63,13 +63,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
python-version: "3.12"
@@ -94,7 +94,7 @@ jobs:
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@v2.0.2
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.middleware.yaml

View File

@@ -19,7 +19,7 @@ jobs:
github.event.workflow_run.head_branch == 'deploy/agent-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v1
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
with:
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
username: ${{ secrets.SSH_USER }}

View File

@@ -16,7 +16,7 @@ jobs:
github.event.workflow_run.head_branch == 'deploy/dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v1
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
with:
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}

View File

@@ -16,7 +16,7 @@ jobs:
github.event.workflow_run.head_branch == 'build/feat/hitl'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v1
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
with:
host: ${{ secrets.HITL_SSH_HOST }}
username: ${{ secrets.SSH_USER }}

View File

@@ -32,13 +32,13 @@ jobs:
context: "web"
steps:
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Build Docker Image
uses: docker/build-push-action@v6
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
with:
push: false
context: "{{defaultContext}}:${{ matrix.context }}"

View File

@@ -9,6 +9,6 @@ jobs:
pull-requests: write
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@v6
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
with:
sync-labels: true

View File

@@ -27,8 +27,8 @@ jobs:
vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- uses: actions/checkout@v6
- uses: dorny/paths-filter@v3
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
id: changes
with:
filters: |
@@ -39,6 +39,7 @@ jobs:
web:
- 'web/**'
- '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**'
vdb:
- 'api/core/rag/datasource/**'
- 'docker/**'

View File

@@ -21,7 +21,7 @@ jobs:
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
steps:
- name: Download pyrefly diff artifact
uses: actions/github-script@v8
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
@@ -49,7 +49,7 @@ jobs:
run: unzip -o pyrefly_diff.zip
- name: Post comment
uses: actions/github-script@v8
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |

View File

@@ -17,12 +17,12 @@ jobs:
pull-requests: write
steps:
- name: Checkout PR branch
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@v5
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
@@ -55,7 +55,7 @@ jobs:
echo ${{ github.event.pull_request.number }} > pr_number.txt
- name: Upload pyrefly diff
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: pyrefly_diff
path: |
@@ -64,7 +64,7 @@ jobs:
- name: Comment PR with pyrefly diff
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
uses: actions/github-script@v8
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |

View File

@@ -16,6 +16,6 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check title
uses: amannn/action-semantic-pull-request@v6.1.1
uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/stale@v10
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
with:
days-before-issue-stale: 15
days-before-issue-close: 3

View File

@@ -19,13 +19,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
api/**
@@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: false
python-version: "3.12"
@@ -67,36 +67,22 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
web/**
.github/workflows/style.yml
.github/actions/setup-web/**
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup NodeJS
uses: actions/setup-node@v6
- name: Setup web environment
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm install --frozen-lockfile
uses: ./.github/actions/setup-web
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
@@ -134,14 +120,14 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
**.sh
@@ -152,7 +138,7 @@ jobs:
.editorconfig
- name: Super-linter
uses: super-linter/super-linter/slim@v8
uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning

View File

@@ -21,12 +21,12 @@ jobs:
working-directory: sdks/nodejs-client
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Use Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
with:
node-version: 22
cache: ''

View File

@@ -38,7 +38,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
@@ -48,18 +48,10 @@ jobs:
git config --global user.name "github-actions[bot]"
git config --global user.email "github-actions[bot]@users.noreply.github.com"
- name: Install pnpm
uses: pnpm/action-setup@v4
- name: Setup web environment
uses: ./.github/actions/setup-web
with:
package_json_file: web/package.json
run_install: false
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
install-dependencies: "false"
- name: Detect changed files and generate diff
id: detect_changes
@@ -130,7 +122,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@v1
uses: anthropics/claude-code-action@26ec041249acb0a944c0a47b6c0c13f05dbc5b44 # v1.0.70
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -21,7 +21,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
@@ -59,7 +59,7 @@ jobs:
- name: Trigger i18n sync workflow
if: steps.detect.outputs.has_changes == 'true'
uses: peter-evans/repository-dispatch@v3
uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1
with:
token: ${{ secrets.GITHUB_TOKEN }}
event-type: i18n-sync

View File

@@ -19,19 +19,19 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Free Disk Space
uses: endersonmenezes/free-disk-space@v3
uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2
with:
remove_dotnet: true
remove_haskell: true
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -60,7 +60,7 @@ jobs:
# tiflash
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
uses: hoverkraft-tech/compose-action@v2.0.2
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.yaml

View File

@@ -26,32 +26,19 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Run tests
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
- name: Upload blob report
if: ${{ !cancelled() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: blob-report-${{ matrix.shardIndex }}
path: web/.vitest-reports/*
@@ -70,28 +57,15 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Download blob reports
uses: actions/download-artifact@v6
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
with:
path: web/.vitest-reports
pattern: blob-report-*
@@ -419,7 +393,7 @@ jobs:
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: web-coverage-report
path: web/coverage
@@ -435,36 +409,22 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
web/**
.github/workflows/web-tests.yml
.github/actions/setup-web/**
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup NodeJS
uses: actions/setup-node@v6
- name: Setup web environment
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm install --frozen-lockfile
uses: ./.github/actions/setup-web
- name: Web build check
if: steps.changed-files.outputs.any_changed == 'true'

View File

@@ -44,7 +44,6 @@ forbidden_modules =
allow_indirect_imports = True
ignore_imports =
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.file_saver -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.tool.tool_node -> extensions.ext_database
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
@@ -114,7 +113,6 @@ ignore_imports =
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.tool.tool_node -> models
dify_graph.nodes.agent.agent_node -> models.model
dify_graph.nodes.llm.file_saver -> core.helper.ssrf_proxy
dify_graph.nodes.llm.node -> core.helper.code_executor
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
@@ -135,7 +133,6 @@ ignore_imports =
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
dify_graph.nodes.tool.tool_node -> core.tools.errors
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.file_saver -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.tool.tool_node -> extensions.ext_database
dify_graph.nodes.agent.agent_node -> models

View File

@@ -807,7 +807,7 @@ class DatasetApiKeyApi(Resource):
console_ns.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
custom="max_keys_exceeded",
)
key = ApiToken.generate_api_key(self.token_prefix, 24)

View File

@@ -10,7 +10,6 @@ from controllers.common.file_response import enforce_download_for_html
from controllers.files import files_ns
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db as global_db
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -57,7 +56,7 @@ class ToolFileApi(Resource):
raise Forbidden("Invalid request.")
try:
tool_file_manager = ToolFileManager(engine=global_db.engine)
tool_file_manager = ToolFileManager()
stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
file_id,
)

View File

@@ -239,7 +239,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotCompletionAppError()
raise NotChatAppError()
message_id = str(message_id)

View File

@@ -10,28 +10,18 @@ from typing import Union
from uuid import uuid4
import httpx
from sqlalchemy.orm import Session
from configs import dify_config
from core.db.session_factory import session_factory
from core.helper import ssrf_proxy
from extensions.ext_database import db as global_db
from extensions.ext_storage import storage
from models.model import MessageFile
from models.tools import ToolFile
logger = logging.getLogger(__name__)
from sqlalchemy.engine import Engine
class ToolFileManager:
_engine: Engine
def __init__(self, engine: Engine | None = None):
if engine is None:
engine = global_db.engine
self._engine = engine
@staticmethod
def sign_file(tool_file_id: str, extension: str) -> str:
"""
@@ -89,7 +79,7 @@ class ToolFileManager:
filepath = f"tools/{tenant_id}/{unique_filename}"
storage.save(filepath, file_binary)
with Session(self._engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
@@ -132,7 +122,7 @@ class ToolFileManager:
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
with Session(self._engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
@@ -157,7 +147,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with Session(self._engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(
@@ -181,7 +171,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with Session(self._engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
message_file: MessageFile | None = (
session.query(MessageFile)
.where(
@@ -225,7 +215,7 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with Session(self._engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.where(

View File

@@ -250,6 +250,7 @@ class DifyNodeFactory(NodeFactory):
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
http_client=self._http_request_http_client,
)
if node_type == NodeType.DATASOURCE:
@@ -292,6 +293,7 @@ class DifyNodeFactory(NodeFactory):
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
http_client=self._http_request_http_client,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:

View File

@@ -14,7 +14,6 @@ from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
from dify_graph.variables import (
ArrayFileSegment,
@@ -47,8 +46,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
# Output variable for file
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
def __init__(
self,
id: str,
@@ -56,8 +53,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
rag_retrieval: RAGRetrievalProtocol,
*,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
id=id,
@@ -69,14 +64,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
self._file_outputs = []
self._rag_retrieval = rag_retrieval
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
)
self._llm_file_saver = llm_file_saver
@classmethod
def version(cls):
return "1"

View File

@@ -1,14 +1,11 @@
import mimetypes
import typing as tp
from sqlalchemy import Engine
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.helper import ssrf_proxy
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.file import File, FileTransferMethod, FileType
from extensions.ext_database import db as global_db
from dify_graph.nodes.protocols import HttpClientProtocol
class LLMFileSaver(tp.Protocol):
@@ -59,30 +56,20 @@ class LLMFileSaver(tp.Protocol):
raise NotImplementedError()
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
class FileSaverImpl(LLMFileSaver):
_engine_factory: EngineFactory
_tenant_id: str
_user_id: str
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
if engine_factory is None:
def _factory():
return global_db.engine
engine_factory = _factory
self._engine_factory = engine_factory
def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol):
self._user_id = user_id
self._tenant_id = tenant_id
self._http_client = http_client
def _get_tool_file_manager(self):
return ToolFileManager(engine=self._engine_factory())
return ToolFileManager()
def save_remote_url(self, url: str, file_type: FileType) -> File:
http_response = ssrf_proxy.get(url)
http_response = self._http_client.get(url)
http_response.raise_for_status()
data = http_response.content
mime_type_from_header = http_response.headers.get("Content-Type")

View File

@@ -64,6 +64,7 @@ from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.runtime import VariablePool
from dify_graph.variables import (
ArrayFileSegment,
@@ -127,6 +128,7 @@ class LLMNode(Node[LLMNodeData]):
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
model_instance: ModelInstance,
http_client: HttpClientProtocol,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
@@ -149,6 +151,7 @@ class LLMNode(Node[LLMNodeData]):
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
http_client=http_client,
)
self._llm_file_saver = llm_file_saver

View File

@@ -28,6 +28,7 @@ from dify_graph.nodes.llm import (
)
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.protocols import HttpClientProtocol
from libs.json_in_md_parser import parse_and_check_json_markdown
from .entities import QuestionClassifierNodeData
@@ -68,6 +69,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
http_client: HttpClientProtocol,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
@@ -90,6 +92,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
llm_file_saver = FileSaverImpl(
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
http_client=http_client,
)
self._llm_file_saver = llm_file_saver

View File

@@ -21,6 +21,10 @@ celery_redis = Redis(
ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
# Add conservative socket timeouts and health checks to avoid long-lived half-open sockets
socket_timeout=5,
socket_connect_timeout=5,
health_check_interval=30,
)
logger = logging.getLogger(__name__)

View File

@@ -3,6 +3,7 @@ import math
import time
from collections.abc import Iterable, Sequence
from celery import group
from sqlalchemy import ColumnElement, and_, func, or_, select
from sqlalchemy.engine.row import Row
from sqlalchemy.orm import Session
@@ -85,20 +86,25 @@ def trigger_provider_refresh() -> None:
lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions)
acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)
enqueued: int = 0
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired):
if not is_locked:
continue
trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id)
enqueued += 1
if not any(acquired):
continue
jobs = [
trigger_subscription_refresh.s(tenant_id=tenant_id, subscription_id=subscription_id)
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired)
if is_locked
]
result = group(jobs).apply_async()
enqueued = len(jobs)
logger.info(
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d",
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d result=%s",
page + 1,
pages,
len(subscriptions),
sum(1 for x in acquired if x),
enqueued,
result,
)
logger.info("Trigger refresh scan done: due=%d", total_due)

View File

@@ -1,6 +1,6 @@
import logging
from celery import group, shared_task
from celery import current_app, group, shared_task
from sqlalchemy import and_, select
from sqlalchemy.orm import Session, sessionmaker
@@ -29,31 +29,27 @@ def poll_workflow_schedules() -> None:
with session_factory() as session:
total_dispatched = 0
# Process in batches until we've handled all due schedules or hit the limit
while True:
due_schedules = _fetch_due_schedules(session)
if not due_schedules:
break
dispatched_count = _process_schedules(session, due_schedules)
total_dispatched += dispatched_count
with current_app.producer_or_acquire() as producer: # type: ignore
dispatched_count = _process_schedules(session, due_schedules, producer)
total_dispatched += dispatched_count
logger.debug("Batch processed: %d dispatched", dispatched_count)
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
if (
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK > 0
and total_dispatched >= dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK
):
logger.warning(
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
)
break
logger.debug("Batch processed: %d dispatched", dispatched_count)
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
if 0 < dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK <= total_dispatched:
logger.warning(
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
)
break
if total_dispatched > 0:
logger.info("Total processed: %d dispatched", total_dispatched)
logger.info("Total processed: %d workflow schedule(s) dispatched", total_dispatched)
def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
@@ -90,7 +86,7 @@ def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
return list(due_schedules)
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) -> int:
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan], producer=None) -> int:
"""Process schedules: check quota, update next run time and dispatch to Celery in parallel."""
if not schedules:
return 0
@@ -107,7 +103,7 @@ def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan])
if tasks_to_dispatch:
job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch)
job.apply_async()
job.apply_async(producer=producer)
logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch))

View File

@@ -1,9 +1,10 @@
import logging
import time
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from typing import Any, Protocol
import click
from celery import shared_task
from celery import current_app, shared_task
from configs import dify_config
from core.db.session_factory import session_factory
@@ -19,6 +20,12 @@ from tasks.generate_summary_index_task import generate_summary_index_task
logger = logging.getLogger(__name__)
class CeleryTaskLike(Protocol):
def delay(self, *args: Any, **kwargs: Any) -> Any: ...
def apply_async(self, *args: Any, **kwargs: Any) -> Any: ...
@shared_task(queue="dataset")
def document_indexing_task(dataset_id: str, document_ids: list):
"""
@@ -179,8 +186,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
def _document_indexing_with_tenant_queue(
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
):
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: CeleryTaskLike
) -> None:
try:
_document_indexing(dataset_id, document_ids)
except Exception:
@@ -201,16 +208,20 @@ def _document_indexing_with_tenant_queue(
logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
if next_tasks:
for next_task in next_tasks:
document_task = DocumentTask(**next_task)
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=document_task.tenant_id,
dataset_id=document_task.dataset_id,
document_ids=document_task.document_ids,
)
with current_app.producer_or_acquire() as producer: # type: ignore
for next_task in next_tasks:
document_task = DocumentTask(**next_task)
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
task_func.apply_async(
kwargs={
"tenant_id": document_task.tenant_id,
"dataset_id": document_task.dataset_id,
"document_ids": document_task.document_ids,
},
producer=producer,
)
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()

View File

@@ -3,12 +3,13 @@ import json
import logging
import time
import uuid
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from itertools import islice
from typing import Any
import click
from celery import shared_task # type: ignore
from celery import group, shared_task
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
@@ -27,6 +28,11 @@ from services.file_service import FileService
logger = logging.getLogger(__name__)
def chunked(iterable: Sequence, size: int):
it = iter(iterable)
return iter(lambda: list(islice(it, size)), [])
@shared_task(queue="pipeline")
def rag_pipeline_run_task(
rag_pipeline_invoke_entities_file_id: str,
@@ -83,16 +89,24 @@ def rag_pipeline_run_task(
logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)
if next_file_ids:
for next_file_id in next_file_ids:
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_isolated_task_queue.set_task_waiting_time()
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
for batch in chunked(next_file_ids, 100):
jobs = []
for next_file_id in batch:
tenant_isolated_task_queue.set_task_waiting_time()
file_id = (
next_file_id.decode("utf-8") if isinstance(next_file_id, (bytes, bytearray)) else next_file_id
)
jobs.append(
rag_pipeline_run_task.s(
rag_pipeline_invoke_entities_file_id=file_id,
tenant_id=tenant_id,
)
)
if jobs:
group(jobs).apply_async()
else:
# No more waiting tasks, clear the flag
tenant_isolated_task_queue.delete_task_key()

View File

@@ -11,6 +11,7 @@ from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.node_events import StreamCompletedEvent
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from extensions.ext_database import db
@@ -74,6 +75,7 @@ def init_llm_node(config: dict) -> LLMNode:
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
http_client=MagicMock(spec=HttpClientProtocol),
)
return node

View File

@@ -322,11 +322,14 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
task_dispatch_spy.delay.assert_called_once_with(
tenant_id=next_task["tenant_id"],
dataset_id=next_task["dataset_id"],
document_ids=next_task["document_ids"],
)
# apply_async is used by implementation; assert it was called once with expected kwargs
assert task_dispatch_spy.apply_async.call_count == 1
call_kwargs = task_dispatch_spy.apply_async.call_args.kwargs.get("kwargs", {})
assert call_kwargs == {
"tenant_id": next_task["tenant_id"],
"dataset_id": next_task["dataset_id"],
"document_ids": next_task["document_ids"],
}
set_waiting_spy.assert_called_once()
delete_key_spy.assert_not_called()
@@ -352,7 +355,7 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
task_dispatch_spy.delay.assert_not_called()
task_dispatch_spy.apply_async.assert_not_called()
delete_key_spy.assert_called_once()
def test_validation_failure_sets_error_status_when_vector_space_at_limit(
@@ -447,7 +450,7 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
task_dispatch_spy.delay.assert_called_once()
task_dispatch_spy.apply_async.assert_called_once()
def test_sessions_close_on_successful_indexing(
self,
@@ -534,7 +537,7 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
assert task_dispatch_spy.delay.call_count == concurrency_limit
assert task_dispatch_spy.apply_async.call_count == concurrency_limit
assert set_waiting_spy.call_count == concurrency_limit
def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies):
@@ -565,9 +568,10 @@ class TestDatasetIndexingTaskIntegration:
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
# Assert
assert task_dispatch_spy.delay.call_count == 3
assert task_dispatch_spy.apply_async.call_count == 3
for index, expected_task in enumerate(ordered_tasks):
assert task_dispatch_spy.delay.call_args_list[index].kwargs["document_ids"] == expected_task["document_ids"]
call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {})
assert call_kwargs.get("document_ids") == expected_task["document_ids"]
def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies):
"""Skip limit checks when billing feature is disabled."""

View File

@@ -762,11 +762,12 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify task function was called for each waiting task
assert mock_task_func.delay.call_count == 1
assert mock_task_func.apply_async.call_count == 1
# Verify correct parameters for each call
calls = mock_task_func.delay.call_args_list
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
calls = mock_task_func.apply_async.call_args_list
sent_kwargs = calls[0][1]["kwargs"]
assert sent_kwargs == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
# Verify queue is empty after processing (tasks were pulled)
remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added
@@ -830,11 +831,15 @@ class TestDocumentIndexingTasks:
assert updated_document.processing_started_at is not None
# Verify waiting task was still processed despite core processing error
mock_task_func.delay.assert_called_once()
mock_task_func.apply_async.assert_called_once()
# Verify correct parameters for the call
call = mock_task_func.delay.call_args
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
call = mock_task_func.apply_async.call_args
assert call[1]["kwargs"] == {
"tenant_id": tenant_id,
"dataset_id": dataset_id,
"document_ids": ["waiting-doc-1"],
}
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
@@ -896,9 +901,13 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify only tenant1's waiting task was processed
mock_task_func.delay.assert_called_once()
call = mock_task_func.delay.call_args
assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}
mock_task_func.apply_async.assert_called_once()
call = mock_task_func.apply_async.call_args
assert call[1]["kwargs"] == {
"tenant_id": tenant1_id,
"dataset_id": dataset1_id,
"document_ids": ["tenant1-doc-1"],
}
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)

View File

@@ -1,6 +1,6 @@
import json
import uuid
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
@@ -388,8 +388,10 @@ class TestRagPipelineRunTasks:
# Set the task key to indicate there are waiting tasks (legacy behavior)
redis_client.set(legacy_task_key, 1, ex=60 * 60)
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the priority task with new code but legacy queue data
rag_pipeline_run_task(file_id, tenant.id)
@@ -398,13 +400,14 @@ class TestRagPipelineRunTasks:
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed, pull 1 task a time by default
assert mock_delay.call_count == 1
# Verify waiting tasks were processed via group, pull 1 task a time by default
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
assert first_kwargs.get("tenant_id") == tenant.id
# Verify that new code can process legacy queue entries
# The new TenantIsolatedTaskQueue should be able to read from the legacy format
@@ -446,8 +449,10 @@ class TestRagPipelineRunTasks:
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
queue.push_tasks(waiting_file_ids)
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the regular task
rag_pipeline_run_task(file_id, tenant.id)
@@ -456,13 +461,14 @@ class TestRagPipelineRunTasks:
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed, pull 1 task a time by default
assert mock_delay.call_count == 1
# Verify waiting tasks were processed via group.apply_async
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
assert first_kwargs.get("tenant_id") == tenant.id
# Verify queue still has remaining tasks (only 1 was pulled)
remaining_tasks = queue.pull_tasks(count=10)
@@ -557,8 +563,10 @@ class TestRagPipelineRunTasks:
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the regular task (should not raise exception)
rag_pipeline_run_task(file_id, tenant.id)
@@ -569,12 +577,13 @@ class TestRagPipelineRunTasks:
assert mock_pipeline_generator.call_count == 1
# Verify waiting task was still processed despite core processing error
mock_delay.assert_called_once()
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert first_kwargs.get("tenant_id") == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
@@ -684,8 +693,10 @@ class TestRagPipelineRunTasks:
queue1.push_tasks([waiting_file_id1])
queue2.push_tasks([waiting_file_id2])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the regular task for tenant1 only
rag_pipeline_run_task(file_id1, tenant1.id)
@@ -694,11 +705,12 @@ class TestRagPipelineRunTasks:
assert mock_file_service["delete_file"].call_count == 1
assert mock_pipeline_generator.call_count == 1
# Verify only tenant1's waiting task was processed
mock_delay.assert_called_once()
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
assert call_kwargs.get("tenant_id") == tenant1.id
# Verify only tenant1's waiting task was processed (via group)
assert mock_group.return_value.apply_async.called
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
assert first_kwargs.get("tenant_id") == tenant1.id
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)
@@ -913,8 +925,10 @@ class TestRagPipelineRunTasks:
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act & Assert: Execute the regular task (should raise Exception)
with pytest.raises(Exception, match="File not found"):
rag_pipeline_run_task(file_id, tenant.id)
@@ -924,12 +938,13 @@ class TestRagPipelineRunTasks:
mock_pipeline_generator.assert_not_called()
# Verify waiting task was still processed despite file error
mock_delay.assert_called_once()
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert first_kwargs.get("tenant_id") == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)

View File

@@ -105,18 +105,26 @@ def app_model(
class MockCeleryGroup:
"""Mock for celery group() function that collects dispatched tasks."""
"""Mock for celery group() function that collects dispatched tasks.
Matches the Celery group API loosely, accepting arbitrary kwargs on apply_async
(e.g. producer) so production code can pass broker-related options without
breaking tests.
"""
def __init__(self) -> None:
self.collected: list[dict[str, Any]] = []
self._applied = False
self.last_apply_async_kwargs: dict[str, Any] | None = None
def __call__(self, items: Any) -> MockCeleryGroup:
self.collected = list(items)
return self
def apply_async(self) -> None:
def apply_async(self, **kwargs: Any) -> None:
# Accept arbitrary kwargs like producer to be compatible with Celery
self._applied = True
self.last_apply_async_kwargs = kwargs
@property
def applied(self) -> bool:

View File

@@ -0,0 +1,817 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, NotFound
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.datasource_auth import (
DatasourceAuth,
DatasourceAuthDefaultApi,
DatasourceAuthDeleteApi,
DatasourceAuthListApi,
DatasourceAuthOauthCustomClient,
DatasourceAuthUpdateApi,
DatasourceHardCodeAuthListApi,
DatasourceOAuthCallback,
DatasourcePluginOAuthAuthorizationUrl,
DatasourceUpdateProviderNameApi,
)
from core.plugin.impl.oauth import OAuthHandler
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestDatasourcePluginOAuthAuthorizationUrl:
def test_get_success(self, app):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
user = MagicMock(id="user-1")
with (
app.test_request_context("/?credential_id=cred-1"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthProxyService,
"create_proxy_context",
return_value="ctx-1",
),
patch.object(
OAuthHandler,
"get_authorization_url",
return_value={"url": "http://auth"},
),
):
response = method(api, "notion")
assert response.status_code == 200
def test_get_no_oauth_config(self, app):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value=None,
),
):
with pytest.raises(ValueError):
method(api, "notion")
def test_get_without_credential_id_sets_cookie(self, app):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
user = MagicMock(id="user-1")
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthProxyService,
"create_proxy_context",
return_value="ctx-123",
),
patch.object(
OAuthHandler,
"get_authorization_url",
return_value={"url": "http://auth"},
),
):
response = method(api, "notion")
assert response.status_code == 200
assert "context_id" in response.headers.get("Set-Cookie")
class TestDatasourceOAuthCallback:
def test_callback_success_new_credential(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
oauth_response = MagicMock()
oauth_response.credentials = {"token": "abc"}
oauth_response.expires_at = None
oauth_response.metadata = {"name": "test"}
context = {
"user_id": "user-1",
"tenant_id": "tenant-1",
"credential_id": None,
}
with (
app.test_request_context("/?context_id=ctx"),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=context,
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthHandler,
"get_credentials",
return_value=oauth_response,
),
patch.object(
DatasourceProviderService,
"add_datasource_oauth_provider",
return_value=None,
),
):
response = method(api, "notion")
assert response.status_code == 302
def test_callback_missing_context(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(Forbidden):
method(api, "notion")
def test_callback_invalid_context(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
with (
app.test_request_context("/?context_id=bad"),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=None,
),
):
with pytest.raises(Forbidden):
method(api, "notion")
def test_callback_oauth_config_not_found(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
context = {"user_id": "u", "tenant_id": "t"}
with (
app.test_request_context("/?context_id=ctx"),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=context,
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "notion")
def test_callback_reauthorize_existing_credential(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
oauth_response = MagicMock()
oauth_response.credentials = {"token": "abc"}
oauth_response.expires_at = None
oauth_response.metadata = {} # avatar + name missing
context = {
"user_id": "user-1",
"tenant_id": "tenant-1",
"credential_id": "cred-1",
}
with (
app.test_request_context("/?context_id=ctx"),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=context,
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthHandler,
"get_credentials",
return_value=oauth_response,
),
patch.object(
DatasourceProviderService,
"reauthorize_datasource_oauth_provider",
return_value=None,
),
):
response = method(api, "notion")
assert response.status_code == 302
assert "/oauth-callback" in response.location
def test_callback_context_id_from_cookie(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
oauth_response = MagicMock()
oauth_response.credentials = {"token": "abc"}
oauth_response.expires_at = None
oauth_response.metadata = {}
context = {
"user_id": "user-1",
"tenant_id": "tenant-1",
"credential_id": None,
}
with (
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=context,
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthHandler,
"get_credentials",
return_value=oauth_response,
),
patch.object(
DatasourceProviderService,
"add_datasource_oauth_provider",
return_value=None,
),
):
response = method(api, "notion")
assert response.status_code == 302
class TestDatasourceAuth:
def test_post_success(self, app):
api = DatasourceAuth()
method = unwrap(api.post)
payload = {"credentials": {"key": "val"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"add_datasource_api_key_provider",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_post_invalid_credentials(self, app):
api = DatasourceAuth()
method = unwrap(api.post)
payload = {"credentials": {"key": "bad"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"add_datasource_api_key_provider",
side_effect=CredentialsValidateFailedError("invalid"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
def test_get_success(self, app):
api = DatasourceAuth()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"list_datasource_credentials",
return_value=[{"id": "1"}],
),
):
response, status = method(api, "notion")
assert status == 200
assert response["result"]
def test_post_missing_credentials(self, app):
api = DatasourceAuth()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
def test_get_empty_list(self, app):
api = DatasourceAuth()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"list_datasource_credentials",
return_value=[],
),
):
response, status = method(api, "notion")
assert status == 200
assert response["result"] == []
class TestDatasourceAuthDeleteApi:
def test_delete_success(self, app):
api = DatasourceAuthDeleteApi()
method = unwrap(api.post)
payload = {"credential_id": "cred-1"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"remove_datasource_credentials",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_delete_missing_credential_id(self, app):
api = DatasourceAuthDeleteApi()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
class TestDatasourceAuthUpdateApi:
def test_update_success(self, app):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "credentials": {"k": "v"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 201
def test_update_with_credentials_none(self, app):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "credentials": None}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
) as update_mock,
):
response, status = method(api, "notion")
update_mock.assert_called_once()
assert status == 201
def test_update_name_only(self, app):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "name": "New Name"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
),
):
_, status = method(api, "notion")
assert status == 201
def test_update_with_empty_credentials_dict(self, app):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "credentials": {}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
) as update_mock,
):
_, status = method(api, "notion")
update_mock.assert_called_once()
assert status == 201
class TestDatasourceAuthListApi:
def test_list_success(self, app):
api = DatasourceAuthListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_all_datasource_credentials",
return_value=[{"id": "1"}],
),
):
response, status = method(api)
assert status == 200
def test_auth_list_empty(self, app):
api = DatasourceAuthListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_all_datasource_credentials",
return_value=[],
),
):
response, status = method(api)
assert status == 200
assert response["result"] == []
def test_hardcode_list_empty(self, app):
api = DatasourceHardCodeAuthListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_hard_code_datasource_credentials",
return_value=[],
),
):
response, status = method(api)
assert status == 200
assert response["result"] == []
class TestDatasourceHardCodeAuthListApi:
def test_list_success(self, app):
api = DatasourceHardCodeAuthListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_hard_code_datasource_credentials",
return_value=[{"id": "1"}],
),
):
response, status = method(api)
assert status == 200
class TestDatasourceAuthOauthCustomClient:
def test_post_success(self, app):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
payload = {"client_params": {}, "enable_oauth_custom_client": True}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"setup_oauth_custom_client_params",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_delete_success(self, app):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"remove_oauth_custom_client_params",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_post_empty_payload(self, app):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"setup_oauth_custom_client_params",
return_value=None,
),
):
_, status = method(api, "notion")
assert status == 200
def test_post_disabled_flag(self, app):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
payload = {
"client_params": {"a": 1},
"enable_oauth_custom_client": False,
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"setup_oauth_custom_client_params",
return_value=None,
) as setup_mock,
):
_, status = method(api, "notion")
setup_mock.assert_called_once()
assert status == 200
class TestDatasourceAuthDefaultApi:
def test_set_default_success(self, app):
api = DatasourceAuthDefaultApi()
method = unwrap(api.post)
payload = {"id": "cred-1"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"set_default_datasource_provider",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_default_missing_id(self, app):
api = DatasourceAuthDefaultApi()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
class TestDatasourceUpdateProviderNameApi:
def test_update_name_success(self, app):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "name": "New Name"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_provider_name",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_update_name_too_long(self, app):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
payload = {
"credential_id": "id",
"name": "x" * 101,
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
def test_update_name_missing_credential_id(self, app):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
payload = {"name": "Valid"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")

View File

@@ -0,0 +1,143 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.datasource_content_preview import (
DataSourceContentPreviewApi,
)
from models import Account
from models.dataset import Pipeline
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestDataSourceContentPreviewApi:
def _valid_payload(self):
return {
"inputs": {"query": "hello"},
"datasource_type": "notion",
"credential_id": "cred-1",
}
def test_post_success(self, app):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
payload = self._valid_payload()
pipeline = MagicMock(spec=Pipeline)
node_id = "node-1"
account = MagicMock(spec=Account)
preview_result = {"content": "preview data"}
service_instance = MagicMock()
service_instance.run_datasource_node_preview.return_value = preview_result
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
account,
),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
return_value=service_instance,
),
):
response, status = method(api, pipeline, node_id)
service_instance.run_datasource_node_preview.assert_called_once_with(
pipeline=pipeline,
node_id=node_id,
user_inputs=payload["inputs"],
account=account,
datasource_type=payload["datasource_type"],
is_published=True,
credential_id=payload["credential_id"],
)
assert status == 200
assert response == preview_result
def test_post_forbidden_non_account_user(self, app):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
payload = self._valid_payload()
pipeline = MagicMock(spec=Pipeline)
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
MagicMock(), # NOT Account
),
):
with pytest.raises(Forbidden):
method(api, pipeline, "node-1")
def test_post_invalid_payload(self, app):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
payload = {
"inputs": {"query": "hello"},
# datasource_type missing
}
pipeline = MagicMock(spec=Pipeline)
account = MagicMock(spec=Account)
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
account,
),
):
with pytest.raises(ValueError):
method(api, pipeline, "node-1")
def test_post_without_credential_id(self, app):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
payload = {
"inputs": {"query": "hello"},
"datasource_type": "notion",
"credential_id": None,
}
pipeline = MagicMock(spec=Pipeline)
account = MagicMock(spec=Account)
service_instance = MagicMock()
service_instance.run_datasource_node_preview.return_value = {"ok": True}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
account,
),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
return_value=service_instance,
),
):
response, status = method(api, pipeline, "node-1")
service_instance.run_datasource_node_preview.assert_called_once()
assert status == 200
assert response == {"ok": True}

View File

@@ -0,0 +1,187 @@
from unittest.mock import MagicMock, patch
import pytest
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.rag_pipeline import (
CustomizedPipelineTemplateApi,
PipelineTemplateDetailApi,
PipelineTemplateListApi,
PublishCustomizedPipelineTemplateApi,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestPipelineTemplateListApi:
def test_get_success(self, app):
api = PipelineTemplateListApi()
method = unwrap(api.get)
templates = [{"id": "t1"}]
with (
app.test_request_context("/?type=built-in&language=en-US"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.get_pipeline_templates",
return_value=templates,
),
):
response, status = method(api)
assert status == 200
assert response == templates
class TestPipelineTemplateDetailApi:
def test_get_success(self, app):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
template = {"id": "tpl-1"}
service = MagicMock()
service.get_pipeline_template_detail.return_value = template
with (
app.test_request_context("/?type=built-in"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
return_value=service,
),
):
response, status = method(api, "tpl-1")
assert status == 200
assert response == template
class TestCustomizedPipelineTemplateApi:
def test_patch_success(self, app):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.patch)
payload = {
"name": "Template",
"description": "Desc",
"icon_info": {"icon": "📘"},
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.update_customized_pipeline_template"
) as update_mock,
):
response = method(api, "tpl-1")
update_mock.assert_called_once()
assert response == 200
def test_delete_success(self, app):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.delete_customized_pipeline_template"
) as delete_mock,
):
response = method(api, "tpl-1")
delete_mock.assert_called_once_with("tpl-1")
assert response == 200
def test_post_success(self, app):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
template = MagicMock()
template.yaml_content = "yaml-data"
fake_db = MagicMock()
fake_db.engine = MagicMock()
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = template
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
return_value=session_ctx,
),
):
response, status = method(api, "tpl-1")
assert status == 200
assert response == {"data": "yaml-data"}
def test_post_template_not_found(self, app):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
fake_db = MagicMock()
fake_db.engine = MagicMock()
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = None
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
return_value=session_ctx,
),
):
with pytest.raises(ValueError):
method(api, "tpl-1")
class TestPublishCustomizedPipelineTemplateApi:
def test_post_success(self, app):
api = PublishCustomizedPipelineTemplateApi()
method = unwrap(api.post)
payload = {
"name": "Template",
"description": "Desc",
"icon_info": {"icon": "📘"},
}
service = MagicMock()
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
return_value=service,
),
):
response = method(api, "pipeline-1")
service.publish_customized_pipeline_template.assert_called_once()
assert response == {"result": "success"}

View File

@@ -0,0 +1,187 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
import services
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.datasets.rag_pipeline.rag_pipeline_datasets import (
CreateEmptyRagPipelineDatasetApi,
CreateRagPipelineDatasetApi,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestCreateRagPipelineDatasetApi:
def _valid_payload(self):
return {"yaml_content": "name: test"}
def test_post_success(self, app):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
payload = self._valid_payload()
user = MagicMock(is_dataset_editor=True)
import_info = {"dataset_id": "ds-1"}
mock_service = MagicMock()
mock_service.create_rag_pipeline_dataset.return_value = import_info
mock_session_ctx = MagicMock()
mock_session_ctx.__enter__.return_value = MagicMock()
mock_session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
return_value=mock_session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
return_value=mock_service,
),
):
response, status = method(api)
assert status == 201
assert response == import_info
def test_post_forbidden_non_editor(self, app):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
payload = self._valid_payload()
user = MagicMock(is_dataset_editor=False)
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
):
with pytest.raises(Forbidden):
method(api)
def test_post_dataset_name_duplicate(self, app):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
payload = self._valid_payload()
user = MagicMock(is_dataset_editor=True)
mock_service = MagicMock()
mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError()
mock_session_ctx = MagicMock()
mock_session_ctx.__enter__.return_value = MagicMock()
mock_session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
return_value=mock_session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
return_value=mock_service,
),
):
with pytest.raises(DatasetNameDuplicateError):
method(api)
def test_post_invalid_payload(self, app):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
payload = {}
user = MagicMock(is_dataset_editor=True)
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api)
class TestCreateEmptyRagPipelineDatasetApi:
def test_post_success(self, app):
api = CreateEmptyRagPipelineDatasetApi()
method = unwrap(api.post)
user = MagicMock(is_dataset_editor=True)
dataset = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.DatasetService.create_empty_rag_pipeline_dataset",
return_value=dataset,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.marshal",
return_value={"id": "ds-1"},
),
):
response, status = method(api)
assert status == 201
assert response == {"id": "ds-1"}
def test_post_forbidden_non_editor(self, app):
api = CreateEmptyRagPipelineDatasetApi()
method = unwrap(api.post)
user = MagicMock(is_dataset_editor=False)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
):
with pytest.raises(Forbidden):
method(api)

View File

@@ -0,0 +1,324 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Response
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist
from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable import (
RagPipelineEnvironmentVariableCollectionApi,
RagPipelineNodeVariableCollectionApi,
RagPipelineSystemVariableCollectionApi,
RagPipelineVariableApi,
RagPipelineVariableCollectionApi,
RagPipelineVariableResetApi,
)
from controllers.web.error import InvalidArgumentError, NotFoundError
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.variables.types import SegmentType
from models.account import Account
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def fake_db():
db = MagicMock()
db.engine = MagicMock()
db.session.return_value = MagicMock()
return db
@pytest.fixture
def editor_user():
user = MagicMock(spec=Account)
user.has_edit_permission = True
return user
@pytest.fixture
def restx_config(app):
return patch.dict(app.config, {"RESTX_MASK_HEADER": "X-Fields"})
class TestRagPipelineVariableCollectionApi:
def test_get_variables_success(self, app, fake_db, editor_user, restx_config):
api = RagPipelineVariableCollectionApi()
method = unwrap(api.get)
pipeline = MagicMock(id="p1")
rag_srv = MagicMock()
rag_srv.is_workflow_exist.return_value = True
# IMPORTANT: RESTX expects .variables
var_list = MagicMock()
var_list.variables = []
draft_srv = MagicMock()
draft_srv.list_variables_without_values.return_value = var_list
with (
app.test_request_context("/?page=1&limit=10"),
restx_config,
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
return_value=rag_srv,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
return_value=draft_srv,
),
):
result = method(api, pipeline)
assert result["items"] == []
def test_get_variables_workflow_not_exist(self, app, fake_db, editor_user):
api = RagPipelineVariableCollectionApi()
method = unwrap(api.get)
pipeline = MagicMock()
rag_srv = MagicMock()
rag_srv.is_workflow_exist.return_value = False
with (
app.test_request_context("/"),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
return_value=rag_srv,
),
):
with pytest.raises(DraftWorkflowNotExist):
method(api, pipeline)
def test_delete_variables_success(self, app, fake_db, editor_user):
api = RagPipelineVariableCollectionApi()
method = unwrap(api.delete)
pipeline = MagicMock(id="p1")
with (
app.test_request_context("/"),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService"),
):
result = method(api, pipeline)
assert isinstance(result, Response)
assert result.status_code == 204
class TestRagPipelineNodeVariableCollectionApi:
def test_get_node_variables_success(self, app, fake_db, editor_user, restx_config):
api = RagPipelineNodeVariableCollectionApi()
method = unwrap(api.get)
pipeline = MagicMock(id="p1")
var_list = MagicMock()
var_list.variables = []
srv = MagicMock()
srv.list_node_variables.return_value = var_list
with (
app.test_request_context("/"),
restx_config,
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
return_value=srv,
),
):
result = method(api, pipeline, "node1")
assert result["items"] == []
def test_get_node_variables_invalid_node(self, app, editor_user):
api = RagPipelineNodeVariableCollectionApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
):
with pytest.raises(InvalidArgumentError):
method(api, MagicMock(), SYSTEM_VARIABLE_NODE_ID)
class TestRagPipelineVariableApi:
def test_get_variable_not_found(self, app, fake_db, editor_user):
api = RagPipelineVariableApi()
method = unwrap(api.get)
srv = MagicMock()
srv.get_variable.return_value = None
with (
app.test_request_context("/"),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
return_value=srv,
),
):
with pytest.raises(NotFoundError):
method(api, MagicMock(), "v1")
def test_patch_variable_invalid_file_payload(self, app, fake_db, editor_user):
api = RagPipelineVariableApi()
method = unwrap(api.patch)
pipeline = MagicMock(id="p1", tenant_id="t1")
variable = MagicMock(app_id="p1", value_type=SegmentType.FILE)
srv = MagicMock()
srv.get_variable.return_value = variable
payload = {"value": "invalid"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
return_value=srv,
),
):
with pytest.raises(InvalidArgumentError):
method(api, pipeline, "v1")
def test_delete_variable_success(self, app, fake_db, editor_user):
api = RagPipelineVariableApi()
method = unwrap(api.delete)
pipeline = MagicMock(id="p1")
variable = MagicMock(app_id="p1")
srv = MagicMock()
srv.get_variable.return_value = variable
with (
app.test_request_context("/"),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
return_value=srv,
),
):
result = method(api, pipeline, "v1")
assert result.status_code == 204
class TestRagPipelineVariableResetApi:
def test_reset_variable_success(self, app, fake_db, editor_user):
api = RagPipelineVariableResetApi()
method = unwrap(api.put)
pipeline = MagicMock(id="p1")
workflow = MagicMock()
variable = MagicMock(app_id="p1")
srv = MagicMock()
srv.get_variable.return_value = variable
srv.reset_variable.return_value = variable
rag_srv = MagicMock()
rag_srv.get_draft_workflow.return_value = workflow
with (
app.test_request_context("/"),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
return_value=rag_srv,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
return_value=srv,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.marshal",
return_value={"id": "v1"},
),
):
result = method(api, pipeline, "v1")
assert result == {"id": "v1"}
class TestSystemAndEnvironmentVariablesApi:
def test_system_variables_success(self, app, fake_db, editor_user, restx_config):
api = RagPipelineSystemVariableCollectionApi()
method = unwrap(api.get)
pipeline = MagicMock(id="p1")
var_list = MagicMock()
var_list.variables = []
srv = MagicMock()
srv.list_system_variables.return_value = var_list
with (
app.test_request_context("/"),
restx_config,
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
return_value=srv,
),
):
result = method(api, pipeline)
assert result["items"] == []
def test_environment_variables_success(self, app, editor_user):
api = RagPipelineEnvironmentVariableCollectionApi()
method = unwrap(api.get)
env_var = MagicMock(
id="e1",
name="ENV",
description="d",
selector="s",
value_type=MagicMock(value="string"),
value="x",
)
workflow = MagicMock(environment_variables=[env_var])
pipeline = MagicMock(id="p1")
rag_srv = MagicMock()
rag_srv.get_draft_workflow.return_value = workflow
with (
app.test_request_context("/"),
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
return_value=rag_srv,
),
):
result = method(api, pipeline)
assert len(result["items"]) == 1

View File

@@ -0,0 +1,329 @@
from unittest.mock import MagicMock, patch
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
RagPipelineExportApi,
RagPipelineImportApi,
RagPipelineImportCheckDependenciesApi,
RagPipelineImportConfirmApi,
)
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestRagPipelineImportApi:
def _payload(self, mode="create"):
return {
"mode": mode,
"yaml_content": "content",
"name": "Test",
}
def test_post_success_200(self, app):
api = RagPipelineImportApi()
method = unwrap(api.post)
payload = self._payload()
user = MagicMock()
result = MagicMock()
result.status = "completed"
result.model_dump.return_value = {"status": "success"}
service = MagicMock()
service.import_rag_pipeline.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api)
assert status == 200
assert response == {"status": "success"}
def test_post_failed_400(self, app):
api = RagPipelineImportApi()
method = unwrap(api.post)
payload = self._payload()
user = MagicMock()
result = MagicMock()
result.status = ImportStatus.FAILED
result.model_dump.return_value = {"status": "failed"}
service = MagicMock()
service.import_rag_pipeline.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api)
assert status == 400
assert response == {"status": "failed"}
def test_post_pending_202(self, app):
api = RagPipelineImportApi()
method = unwrap(api.post)
payload = self._payload()
user = MagicMock()
result = MagicMock()
result.status = ImportStatus.PENDING
result.model_dump.return_value = {"status": "pending"}
service = MagicMock()
service.import_rag_pipeline.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api)
assert status == 202
assert response == {"status": "pending"}
class TestRagPipelineImportConfirmApi:
def test_confirm_success(self, app):
api = RagPipelineImportConfirmApi()
method = unwrap(api.post)
user = MagicMock()
result = MagicMock()
result.status = "completed"
result.model_dump.return_value = {"ok": True}
service = MagicMock()
service.confirm_import.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api, "import-1")
assert status == 200
assert response == {"ok": True}
def test_confirm_failed(self, app):
api = RagPipelineImportConfirmApi()
method = unwrap(api.post)
user = MagicMock()
result = MagicMock()
result.status = ImportStatus.FAILED
result.model_dump.return_value = {"ok": False}
service = MagicMock()
service.confirm_import.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api, "import-1")
assert status == 400
assert response == {"ok": False}
class TestRagPipelineImportCheckDependenciesApi:
def test_get_success(self, app):
api = RagPipelineImportCheckDependenciesApi()
method = unwrap(api.get)
pipeline = MagicMock(spec=Pipeline)
result = MagicMock()
result.model_dump.return_value = {"deps": []}
service = MagicMock()
service.check_dependencies.return_value = result
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api, pipeline)
assert status == 200
assert response == {"deps": []}
class TestRagPipelineExportApi:
def test_get_with_include_secret(self, app):
api = RagPipelineExportApi()
method = unwrap(api.get)
pipeline = MagicMock(spec=Pipeline)
service = MagicMock()
service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
fake_db = MagicMock()
fake_db.engine = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = MagicMock()
session_ctx.__exit__.return_value = None
with (
app.test_request_context("/?include_secret=true"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api, pipeline)
assert status == 200
assert response == {"data": {"yaml": "data"}}

View File

@@ -0,0 +1,688 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
DefaultRagPipelineBlockConfigApi,
DraftRagPipelineApi,
DraftRagPipelineRunApi,
PublishedAllRagPipelineApi,
PublishedRagPipelineApi,
PublishedRagPipelineRunApi,
RagPipelineByIdApi,
RagPipelineDatasourceVariableApi,
RagPipelineDraftNodeRunApi,
RagPipelineDraftRunIterationNodeApi,
RagPipelineDraftRunLoopNodeApi,
RagPipelineRecommendedPluginApi,
RagPipelineTaskStopApi,
RagPipelineTransformApi,
RagPipelineWorkflowLastRunApi,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestDraftWorkflowApi:
def test_get_draft_success(self, app):
api = DraftRagPipelineApi()
method = unwrap(api.get)
pipeline = MagicMock()
workflow = MagicMock()
service = MagicMock()
service.get_draft_workflow.return_value = workflow
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline)
assert result == workflow
def test_get_draft_not_exist(self, app):
api = DraftRagPipelineApi()
method = unwrap(api.get)
pipeline = MagicMock()
service = MagicMock()
service.get_draft_workflow.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(DraftWorkflowNotExist):
method(api, pipeline)
def test_sync_hash_not_match(self, app):
api = DraftRagPipelineApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
service = MagicMock()
service.sync_draft_workflow.side_effect = WorkflowHashNotEqualError()
with (
app.test_request_context("/", json={"graph": {}, "features": {}}),
patch.object(type(console_ns), "payload", {"graph": {}, "features": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(DraftWorkflowNotSync):
method(api, pipeline)
def test_sync_invalid_text_plain(self, app):
api = DraftRagPipelineApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
with (
app.test_request_context("/", data="bad-json", headers={"Content-Type": "text/plain"}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
):
response, status = method(api, pipeline)
assert status == 400
class TestDraftRunNodes:
def test_iteration_node_success(self, app):
api = RagPipelineDraftRunIterationNodeApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
with (
app.test_request_context("/", json={"inputs": {}}),
patch.object(type(console_ns), "payload", {"inputs": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
return_value={"ok": True},
),
):
result = method(api, pipeline, "node")
assert result == {"ok": True}
def test_iteration_node_conversation_not_exists(self, app):
api = RagPipelineDraftRunIterationNodeApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
with (
app.test_request_context("/", json={"inputs": {}}),
patch.object(type(console_ns), "payload", {"inputs": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
side_effect=services.errors.conversation.ConversationNotExistsError(),
),
):
with pytest.raises(NotFound):
method(api, pipeline, "node")
def test_loop_node_success(self, app):
api = RagPipelineDraftRunLoopNodeApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
with (
app.test_request_context("/", json={"inputs": {}}),
patch.object(type(console_ns), "payload", {"inputs": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_loop",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
return_value={"ok": True},
),
):
assert method(api, pipeline, "node") == {"ok": True}
class TestPipelineRunApis:
def test_draft_run_success(self, app):
api = DraftRagPipelineRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
payload = {
"inputs": {},
"datasource_type": "x",
"datasource_info_list": [],
"start_node_id": "n",
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
return_value={"ok": True},
),
):
assert method(api, pipeline) == {"ok": True}
def test_draft_run_rate_limit(self, app):
api = DraftRagPipelineRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
with (
app.test_request_context(
"/", json={"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"}
),
patch.object(
type(console_ns),
"payload",
{"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"},
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
side_effect=InvokeRateLimitError("limit"),
),
):
with pytest.raises(InvokeRateLimitHttpError):
method(api, pipeline)
class TestDraftNodeRun:
def test_execution_not_found(self, app):
api = RagPipelineDraftNodeRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
service = MagicMock()
service.run_draft_workflow_node.return_value = None
with (
app.test_request_context("/", json={"inputs": {}}),
patch.object(type(console_ns), "payload", {"inputs": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(ValueError):
method(api, pipeline, "node")
class TestPublishedPipelineApis:
def test_publish_success(self, app):
api = PublishedRagPipelineApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock(id="u1")
workflow = MagicMock(
id="w1",
created_at=datetime.utcnow(),
)
session = MagicMock()
session.merge.return_value = pipeline
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
service = MagicMock()
service.publish_workflow.return_value = workflow
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline)
assert result["result"] == "success"
assert "created_at" in result
class TestMiscApis:
def test_task_stop(self, app):
api = RagPipelineTaskStopApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock(id="u1")
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.AppQueueManager.set_stop_flag"
) as stop_mock,
):
result = method(api, pipeline, "task-1")
stop_mock.assert_called_once()
assert result["result"] == "success"
def test_transform_forbidden(self, app):
api = RagPipelineTransformApi()
method = unwrap(api.post)
user = MagicMock(has_edit_permission=False, is_dataset_operator=False)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
):
with pytest.raises(Forbidden):
method(api, "ds1")
def test_recommended_plugins(self, app):
api = RagPipelineRecommendedPluginApi()
method = unwrap(api.get)
service = MagicMock()
service.get_recommended_plugins.return_value = [{"id": "p1"}]
with (
app.test_request_context("/?type=all"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api)
assert result == [{"id": "p1"}]
class TestPublishedRagPipelineRunApi:
def test_published_run_success(self, app):
api = PublishedRagPipelineRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
payload = {
"inputs": {},
"datasource_type": "x",
"datasource_info_list": [],
"start_node_id": "n",
"response_mode": "blocking",
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
return_value={"ok": True},
),
):
result = method(api, pipeline)
assert result == {"ok": True}
def test_published_run_rate_limit(self, app):
api = PublishedRagPipelineRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
payload = {
"inputs": {},
"datasource_type": "x",
"datasource_info_list": [],
"start_node_id": "n",
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
side_effect=InvokeRateLimitError("limit"),
),
):
with pytest.raises(InvokeRateLimitHttpError):
method(api, pipeline)
class TestDefaultBlockConfigApi:
def test_get_block_config_success(self, app):
api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get)
pipeline = MagicMock()
service = MagicMock()
service.get_default_block_config.return_value = {"k": "v"}
with (
app.test_request_context("/?q={}"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline, "llm")
assert result == {"k": "v"}
def test_get_block_config_invalid_json(self, app):
api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get)
pipeline = MagicMock()
with app.test_request_context("/?q=bad-json"):
with pytest.raises(ValueError):
method(api, pipeline, "llm")
class TestPublishedAllRagPipelineApi:
def test_get_published_workflows_success(self, app):
api = PublishedAllRagPipelineApi()
method = unwrap(api.get)
pipeline = MagicMock()
user = MagicMock(id="u1")
service = MagicMock()
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline)
assert result["items"] == [{"id": "w1"}]
assert result["has_more"] is False
def test_get_published_workflows_forbidden(self, app):
api = PublishedAllRagPipelineApi()
method = unwrap(api.get)
pipeline = MagicMock()
user = MagicMock(id="u1")
with (
app.test_request_context("/?user_id=u2"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
):
with pytest.raises(Forbidden):
method(api, pipeline)
class TestRagPipelineByIdApi:
def test_patch_success(self, app):
api = RagPipelineByIdApi()
method = unwrap(api.patch)
pipeline = MagicMock(tenant_id="t1")
user = MagicMock(id="u1")
workflow = MagicMock()
service = MagicMock()
service.update_workflow.return_value = workflow
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
payload = {"marked_name": "test"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline, "w1")
assert result == workflow
def test_patch_no_fields(self, app):
api = RagPipelineByIdApi()
method = unwrap(api.patch)
pipeline = MagicMock()
user = MagicMock()
with (
app.test_request_context("/", json={}),
patch.object(type(console_ns), "payload", {}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
):
result, status = method(api, pipeline, "w1")
assert status == 400
class TestRagPipelineWorkflowLastRunApi:
def test_last_run_success(self, app):
api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get)
pipeline = MagicMock()
workflow = MagicMock()
node_exec = MagicMock()
service = MagicMock()
service.get_draft_workflow.return_value = workflow
service.get_node_last_run.return_value = node_exec
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline, "node1")
assert result == node_exec
def test_last_run_not_found(self, app):
api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get)
pipeline = MagicMock()
service = MagicMock()
service.get_draft_workflow.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(NotFound):
method(api, pipeline, "node1")
class TestRagPipelineDatasourceVariableApi:
def test_set_datasource_variables_success(self, app):
api = RagPipelineDatasourceVariableApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
payload = {
"datasource_type": "db",
"datasource_info": {},
"start_node_id": "n1",
"start_node_title": "Node",
}
service = MagicMock()
service.set_datasource_variables.return_value = MagicMock()
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline)
assert result is not None

View File

@@ -0,0 +1,444 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from werkzeug.exceptions import NotFound
from controllers.console.datasets import data_source
from controllers.console.datasets.data_source import (
DataSourceApi,
DataSourceNotionApi,
DataSourceNotionDatasetSyncApi,
DataSourceNotionDocumentSyncApi,
DataSourceNotionListApi,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def tenant_ctx():
return (MagicMock(id="u1"), "tenant-1")
@pytest.fixture
def patch_tenant(tenant_ctx):
with patch(
"controllers.console.datasets.data_source.current_account_with_tenant",
return_value=tenant_ctx,
):
yield
@pytest.fixture
def mock_engine():
with patch.object(
type(data_source.db),
"engine",
new_callable=PropertyMock,
return_value=MagicMock(),
):
yield
class TestDataSourceApi:
def test_get_success(self, app, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
binding = MagicMock(
id="b1",
provider="notion",
created_at="now",
disabled=False,
source_info={},
)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.db.session.scalars",
return_value=MagicMock(all=lambda: [binding]),
),
):
response, status = method(api)
assert status == 200
assert response["data"][0]["is_bound"] is True
def test_get_no_bindings(self, app, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.db.session.scalars",
return_value=MagicMock(all=lambda: []),
),
):
response, status = method(api)
assert status == 200
assert response["data"] == []
def test_patch_enable_binding(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
binding = MagicMock(id="b1", disabled=True)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch("controllers.console.datasets.data_source.db.session.add"),
patch("controllers.console.datasets.data_source.db.session.commit"),
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
response, status = method(api, "b1", "enable")
assert status == 200
assert binding.disabled is False
def test_patch_disable_binding(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
binding = MagicMock(id="b1", disabled=False)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch("controllers.console.datasets.data_source.db.session.add"),
patch("controllers.console.datasets.data_source.db.session.commit"),
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
response, status = method(api, "b1", "disable")
assert status == 200
assert binding.disabled is True
def test_patch_binding_not_found(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = None
with pytest.raises(NotFound):
method(api, "b1", "enable")
def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
binding = MagicMock(id="b1", disabled=False)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
with pytest.raises(ValueError):
method(api, "b1", "enable")
def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
binding = MagicMock(id="b1", disabled=True)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
with pytest.raises(ValueError):
method(api, "b1", "disable")
class TestDataSourceNotionListApi:
def test_get_credential_not_found(self, app, patch_tenant):
api = DataSourceNotionListApi()
method = unwrap(api.get)
with (
app.test_request_context("/?credential_id=c1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api)
def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
page = MagicMock(
page_id="p1",
page_name="Page 1",
type="page",
parent_id="parent",
page_icon=None,
)
online_document_message = MagicMock(
result=[
MagicMock(
workspace_id="w1",
workspace_name="My Workspace",
workspace_icon="icon",
pages=[page],
)
]
)
with (
app.test_request_context("/?credential_id=c1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value={"token": "t"},
),
patch(
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
return_value=MagicMock(
get_online_document_pages=lambda **kw: iter([online_document_message]),
datasource_provider_type=lambda: None,
),
),
):
response, status = method(api)
assert status == 200
def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
page = MagicMock(
page_id="p1",
page_name="Page 1",
type="page",
parent_id="parent",
page_icon=None,
)
online_document_message = MagicMock(
result=[
MagicMock(
workspace_id="w1",
workspace_name="My Workspace",
workspace_icon="icon",
pages=[page],
)
]
)
dataset = MagicMock(data_source_type="notion_import")
document = MagicMock(data_source_info='{"notion_page_id": "p1"}')
with (
app.test_request_context("/?credential_id=c1&dataset_id=ds1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value={"token": "t"},
),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=dataset,
),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch(
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
return_value=MagicMock(
get_online_document_pages=lambda **kw: iter([online_document_message]),
datasource_provider_type=lambda: None,
),
),
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.scalars.return_value.all.return_value = [document]
response, status = method(api)
assert status == 200
def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
dataset = MagicMock(data_source_type="other_type")
with (
app.test_request_context("/?credential_id=c1&dataset_id=ds1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value={"token": "t"},
),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=dataset,
),
patch("controllers.console.datasets.data_source.Session"),
):
with pytest.raises(ValueError):
method(api)
class TestDataSourceNotionApi:
def test_get_preview_success(self, app, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.get)
extractor = MagicMock(extract=lambda: [MagicMock(page_content="hello")])
with (
app.test_request_context("/?credential_id=c1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value={"integration_secret": "t"},
),
patch(
"controllers.console.datasets.data_source.NotionExtractor",
return_value=extractor,
),
):
response, status = method(api, "p1", "page")
assert status == 200
def test_post_indexing_estimate_success(self, app, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.post)
payload = {
"notion_info_list": [
{
"workspace_id": "w1",
"credential_id": "c1",
"pages": [{"page_id": "p1", "type": "page"}],
}
],
"process_rule": {"rules": {}},
"doc_form": "text_model",
"doc_language": "English",
}
with (
app.test_request_context("/", method="POST", json=payload, headers={"Content-Type": "application/json"}),
patch(
"controllers.console.datasets.data_source.DocumentService.estimate_args_validate",
),
patch(
"controllers.console.datasets.data_source.IndexingRunner.indexing_estimate",
return_value=MagicMock(model_dump=lambda: {"total_pages": 1}),
),
):
response, status = method(api)
assert status == 200
class TestDataSourceNotionDatasetSyncApi:
def test_get_success(self, app, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.data_source.DocumentService.get_document_by_dataset_id",
return_value=[MagicMock(id="d1")],
),
patch(
"controllers.console.datasets.data_source.document_indexing_sync_task.delay",
return_value=None,
),
):
response, status = method(api, "ds-1")
assert status == 200
def test_get_dataset_not_found(self, app, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "ds-1")
class TestDataSourceNotionDocumentSyncApi:
def test_get_success(self, app, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.data_source.DocumentService.get_document",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.data_source.document_indexing_sync_task.delay",
return_value=None,
),
):
response, status = method(api, "ds-1", "doc-1")
assert status == 200
def test_get_document_not_found(self, app, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.data_source.DocumentService.get_document",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,399 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.datasets.external import (
BedrockRetrievalApi,
ExternalApiTemplateApi,
ExternalApiTemplateListApi,
ExternalDatasetCreateApi,
ExternalKnowledgeHitTestingApi,
)
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_external_dataset")
app.config["TESTING"] = True
return app
@pytest.fixture
def current_user():
user = MagicMock()
user.id = "user-1"
user.is_dataset_editor = True
user.has_edit_permission = True
user.is_dataset_operator = True
return user
@pytest.fixture(autouse=True)
def mock_auth(mocker, current_user):
mocker.patch(
"controllers.console.datasets.external.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
)
class TestExternalApiTemplateListApi:
def test_get_success(self, app):
api = ExternalApiTemplateListApi()
method = unwrap(api.get)
api_item = MagicMock()
api_item.to_dict.return_value = {"id": "1"}
with (
app.test_request_context("/?page=1&limit=20"),
patch.object(
ExternalDatasetService,
"get_external_knowledge_apis",
return_value=([api_item], 1),
),
):
resp, status = method(api)
assert status == 200
assert resp["total"] == 1
assert resp["data"][0]["id"] == "1"
def test_post_forbidden(self, app, current_user):
current_user.is_dataset_editor = False
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
payload = {"name": "x", "settings": {"k": "v"}}
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(ExternalDatasetService, "validate_api_list"),
):
with pytest.raises(Forbidden):
method(api)
def test_post_duplicate_name(self, app):
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
payload = {"name": "x", "settings": {"k": "v"}}
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(ExternalDatasetService, "validate_api_list"),
patch.object(
ExternalDatasetService,
"create_external_knowledge_api",
side_effect=services.errors.dataset.DatasetNameDuplicateError(),
),
):
with pytest.raises(DatasetNameDuplicateError):
method(api)
class TestExternalApiTemplateApi:
def test_get_not_found(self, app):
api = ExternalApiTemplateApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
ExternalDatasetService,
"get_external_knowledge_api",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "api-id")
def test_delete_forbidden(self, app, current_user):
current_user.has_edit_permission = False
current_user.is_dataset_operator = False
api = ExternalApiTemplateApi()
method = unwrap(api.delete)
with app.test_request_context("/"):
with pytest.raises(Forbidden):
method(api, "api-id")
class TestExternalDatasetCreateApi:
def test_create_success(self, app):
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
payload = {
"external_knowledge_api_id": "api",
"external_knowledge_id": "kid",
"name": "dataset",
}
dataset = MagicMock()
dataset.embedding_available = False
dataset.built_in_field_enabled = False
dataset.is_published = False
dataset.enable_api = False
dataset.enable_qa = False
dataset.enable_vector_store = False
dataset.vector_store_setting = None
dataset.is_multimodal = False
dataset.retrieval_model_dict = {}
dataset.tags = []
dataset.external_knowledge_info = None
dataset.external_retrieval_model = None
dataset.doc_metadata = []
dataset.icon_info = None
dataset.summary_index_setting = MagicMock()
dataset.summary_index_setting.enable = False
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(
ExternalDatasetService,
"create_external_dataset",
return_value=dataset,
),
):
_, status = method(api)
assert status == 201
def test_create_forbidden(self, app, current_user):
current_user.is_dataset_editor = False
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
payload = {
"external_knowledge_api_id": "api",
"external_knowledge_id": "kid",
"name": "dataset",
}
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
):
with pytest.raises(Forbidden):
method(api)
class TestExternalKnowledgeHitTestingApi:
def test_hit_testing_dataset_not_found(self, app):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
with (
app.test_request_context("/"),
patch.object(
DatasetService,
"get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "dataset-id")
def test_hit_testing_success(self, app):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
payload = {"query": "hello"}
dataset = MagicMock()
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(DatasetService, "get_dataset", return_value=dataset),
patch.object(DatasetService, "check_dataset_permission"),
patch.object(
HitTestingService,
"external_retrieve",
return_value={"ok": True},
),
):
resp = method(api, "dataset-id")
assert resp["ok"] is True
class TestBedrockRetrievalApi:
def test_bedrock_retrieval(self, app):
api = BedrockRetrievalApi()
method = unwrap(api.post)
payload = {
"retrieval_setting": {},
"query": "hello",
"knowledge_id": "kid",
}
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(
ExternalDatasetTestService,
"knowledge_retrieval",
return_value={"ok": True},
),
):
resp, status = method()
assert status == 200
assert resp["ok"] is True
class TestExternalApiTemplateListApiAdvanced:
def test_post_duplicate_name_error(self, app, mock_auth, current_user):
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
payload = {"name": "duplicate_api", "settings": {"key": "value"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch("controllers.console.datasets.external.ExternalDatasetService.validate_api_list"),
patch(
"controllers.console.datasets.external.ExternalDatasetService.create_external_knowledge_api",
side_effect=services.errors.dataset.DatasetNameDuplicateError("Duplicate"),
),
):
with pytest.raises(DatasetNameDuplicateError):
method(api)
def test_get_with_pagination(self, app, mock_auth, current_user):
api = ExternalApiTemplateListApi()
method = unwrap(api.get)
templates = [MagicMock(id=f"api-{i}") for i in range(3)]
with (
app.test_request_context("/?page=1&limit=20"),
patch(
"controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis",
return_value=(templates, 25),
),
):
resp, status = method(api)
assert status == 200
assert resp["total"] == 25
assert len(resp["data"]) == 3
class TestExternalDatasetCreateApiAdvanced:
def test_create_forbidden(self, app, mock_auth, current_user):
"""Test creating external dataset without permission"""
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
current_user.is_dataset_editor = False
payload = {
"external_knowledge_api_id": "api-1",
"external_knowledge_id": "ek-1",
"name": "new_dataset",
"description": "A dataset",
}
with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload):
with pytest.raises(Forbidden):
method(api)
class TestExternalKnowledgeHitTestingApiAdvanced:
def test_hit_testing_dataset_not_found(self, app, mock_auth, current_user):
"""Test hit testing on non-existent dataset"""
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
payload = {
"query": "test query",
"external_retrieval_model": None,
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.external.DatasetService.get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "ds-1")
def test_hit_testing_with_custom_retrieval_model(self, app, mock_auth, current_user):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
dataset = MagicMock()
payload = {
"query": "test query",
"external_retrieval_model": {"type": "bm25"},
"metadata_filtering_conditions": {"status": "active"},
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.external.DatasetService.get_dataset",
return_value=dataset,
),
patch("controllers.console.datasets.external.DatasetService.check_dataset_permission"),
patch(
"controllers.console.datasets.external.HitTestingService.external_retrieve",
return_value={"results": []},
),
):
resp = method(api, "ds-1")
assert resp["results"] == []
class TestBedrockRetrievalApiAdvanced:
def test_bedrock_retrieval_with_invalid_setting(self, app, mock_auth, current_user):
api = BedrockRetrievalApi()
method = unwrap(api.post)
payload = {
"retrieval_setting": {},
"query": "test",
"knowledge_id": "k-1",
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.external.ExternalDatasetTestService.knowledge_retrieval",
side_effect=ValueError("Invalid settings"),
),
):
with pytest.raises(ValueError):
method()

View File

@@ -0,0 +1,160 @@
import uuid
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
from controllers.console.datasets.hit_testing import HitTestingApi
from controllers.console.datasets.hit_testing_base import HitTestingPayload
def unwrap(func):
"""Recursively unwrap decorated functions."""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_hit_testing")
app.config["TESTING"] = True
return app
@pytest.fixture
def dataset_id():
return uuid.uuid4()
@pytest.fixture
def dataset():
return MagicMock(id="dataset-1")
@pytest.fixture(autouse=True)
def bypass_decorators(mocker):
"""Bypass all decorators on the API method."""
mocker.patch(
"controllers.console.datasets.hit_testing.setup_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.hit_testing.login_required",
return_value=lambda f: f,
)
mocker.patch(
"controllers.console.datasets.hit_testing.account_initialization_required",
return_value=lambda f: f,
)
mocker.patch(
"controllers.console.datasets.hit_testing.cloud_edition_billing_rate_limit_check",
return_value=lambda *_: (lambda f: f),
)
class TestHitTestingApi:
def test_hit_testing_success(self, app, dataset, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)
payload = {
"query": "what is vector search",
"top_k": 3,
}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch.object(
HitTestingPayload,
"model_validate",
return_value=MagicMock(model_dump=lambda **_: payload),
),
patch.object(
HitTestingApi,
"get_and_validate_dataset",
return_value=dataset,
),
patch.object(
HitTestingApi,
"hit_testing_args_check",
),
patch.object(
HitTestingApi,
"perform_hit_testing",
return_value={"query": "what is vector search", "records": []},
),
):
result = method(api, dataset_id)
assert "query" in result
assert "records" in result
assert result["records"] == []
def test_hit_testing_dataset_not_found(self, app, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)
payload = {
"query": "test",
}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch.object(
HitTestingApi,
"get_and_validate_dataset",
side_effect=NotFound("Dataset not found"),
),
):
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
def test_hit_testing_invalid_args(self, app, dataset, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)
payload = {
"query": "",
}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch.object(
HitTestingPayload,
"model_validate",
return_value=MagicMock(model_dump=lambda **_: payload),
),
patch.object(
HitTestingApi,
"get_and_validate_dataset",
return_value=dataset,
),
patch.object(
HitTestingApi,
"hit_testing_args_check",
side_effect=ValueError("Invalid parameters"),
),
):
with pytest.raises(ValueError, match="Invalid parameters"):
method(api, dataset_id)

View File

@@ -0,0 +1,207 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from controllers.console.datasets.hit_testing_base import (
DatasetsHitTestingBase,
)
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from dify_graph.model_runtime.errors.invoke import InvokeError
from models.account import Account
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
@pytest.fixture
def account():
acc = MagicMock(spec=Account)
return acc
@pytest.fixture(autouse=True)
def patch_current_user(mocker, account):
"""Patch current_user to a valid Account."""
mocker.patch(
"controllers.console.datasets.hit_testing_base.current_user",
account,
)
@pytest.fixture
def dataset():
return MagicMock(id="dataset-1")
class TestGetAndValidateDataset:
def test_success(self, dataset):
with (
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
):
result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
assert result == dataset
def test_dataset_not_found(self):
with patch.object(
DatasetService,
"get_dataset",
return_value=None,
):
with pytest.raises(NotFound, match="Dataset not found"):
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
def test_permission_denied(self, dataset):
with (
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
side_effect=services.errors.account.NoPermissionError("no access"),
),
):
with pytest.raises(Forbidden, match="no access"):
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
class TestHitTestingArgsCheck:
def test_args_check_called(self):
args = {"query": "test"}
with patch.object(
HitTestingService,
"hit_testing_args_check",
) as check_mock:
DatasetsHitTestingBase.hit_testing_args_check(args)
check_mock.assert_called_once_with(args)
class TestParseArgs:
def test_parse_args_success(self):
payload = {"query": "hello"}
result = DatasetsHitTestingBase.parse_args(payload)
assert result["query"] == "hello"
def test_parse_args_invalid(self):
payload = {"query": "x" * 300}
with pytest.raises(ValueError):
DatasetsHitTestingBase.parse_args(payload)
class TestPerformHitTesting:
def test_success(self, dataset):
response = {
"query": "hello",
"records": [],
}
with patch.object(
HitTestingService,
"retrieve",
return_value=response,
):
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
assert result["query"] == "hello"
assert result["records"] == []
def test_index_not_initialized(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=services.errors.index.IndexNotInitializedError(),
):
with pytest.raises(DatasetNotInitializedError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_provider_token_not_init(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=ProviderTokenNotInitError("token missing"),
):
with pytest.raises(ProviderNotInitializeError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_quota_exceeded(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=QuotaExceededError(),
):
with pytest.raises(ProviderQuotaExceededError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_model_not_supported(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=ModelCurrentlyNotSupportError(),
):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_llm_bad_request(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=LLMBadRequestError("bad request"),
):
with pytest.raises(ProviderNotInitializeError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_invoke_error(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=InvokeError("invoke failed"),
):
with pytest.raises(CompletionRequestError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_value_error(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=ValueError("bad args"),
):
with pytest.raises(ValueError, match="bad args"):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_unexpected_error(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=Exception("boom"),
):
with pytest.raises(InternalServerError, match="boom"):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})

View File

@@ -0,0 +1,362 @@
import uuid
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
from controllers.console.datasets.metadata import (
DatasetMetadataApi,
DatasetMetadataBuiltInFieldActionApi,
DatasetMetadataBuiltInFieldApi,
DatasetMetadataCreateApi,
DocumentMetadataEditApi,
)
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
MetadataOperationData,
)
from services.metadata_service import MetadataService
def unwrap(func):
"""Recursively unwrap decorated functions."""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_dataset_metadata")
app.config["TESTING"] = True
return app
@pytest.fixture
def current_user():
user = MagicMock()
user.id = "user-1"
return user
@pytest.fixture
def dataset():
ds = MagicMock()
ds.id = "dataset-1"
return ds
@pytest.fixture
def dataset_id():
return uuid.uuid4()
@pytest.fixture
def metadata_id():
return uuid.uuid4()
@pytest.fixture(autouse=True)
def bypass_decorators(mocker):
"""Bypass setup/login/license decorators."""
mocker.patch(
"controllers.console.datasets.metadata.setup_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.metadata.login_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.metadata.account_initialization_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.metadata.enterprise_license_required",
lambda f: f,
)
class TestDatasetMetadataCreateApi:
def test_create_metadata_success(self, app, current_user, dataset, dataset_id):
api = DatasetMetadataCreateApi()
method = unwrap(api.post)
payload = {"name": "author"}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
MetadataArgs,
"model_validate",
return_value=MagicMock(),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataService,
"create_metadata",
return_value={"id": "m1", "name": "author"},
),
):
result, status = method(api, dataset_id)
assert status == 201
assert result["name"] == "author"
def test_create_metadata_dataset_not_found(self, app, current_user, dataset_id):
api = DatasetMetadataCreateApi()
method = unwrap(api.post)
valid_payload = {
"type": "string",
"name": "author",
}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=valid_payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
MetadataArgs,
"model_validate",
return_value=MagicMock(),
),
patch.object(
DatasetService,
"get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
class TestDatasetMetadataGetApi:
def test_get_metadata_success(self, app, dataset, dataset_id):
api = DatasetMetadataCreateApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
MetadataService,
"get_dataset_metadatas",
return_value=[{"id": "m1"}],
),
):
result, status = method(api, dataset_id)
assert status == 200
assert isinstance(result, list)
def test_get_metadata_dataset_not_found(self, app, dataset_id):
api = DatasetMetadataCreateApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
DatasetService,
"get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, dataset_id)
class TestDatasetMetadataApi:
def test_update_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id):
api = DatasetMetadataApi()
method = unwrap(api.patch)
payload = {"name": "updated-name"}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataService,
"update_metadata_name",
return_value={"id": "m1", "name": "updated-name"},
),
):
result, status = method(api, dataset_id, metadata_id)
assert status == 200
assert result["name"] == "updated-name"
def test_delete_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id):
api = DatasetMetadataApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataService,
"delete_metadata",
),
):
result, status = method(api, dataset_id, metadata_id)
assert status == 204
assert result["result"] == "success"
class TestDatasetMetadataBuiltInFieldApi:
def test_get_built_in_fields(self, app):
api = DatasetMetadataBuiltInFieldApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
MetadataService,
"get_built_in_fields",
return_value=["title", "source"],
),
):
result, status = method(api)
assert status == 200
assert result["fields"] == ["title", "source"]
class TestDatasetMetadataBuiltInFieldActionApi:
def test_enable_built_in_field(self, app, current_user, dataset, dataset_id):
api = DatasetMetadataBuiltInFieldActionApi()
method = unwrap(api.post)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataService,
"enable_built_in_field",
),
):
result, status = method(api, dataset_id, "enable")
assert status == 200
assert result["result"] == "success"
class TestDocumentMetadataEditApi:
def test_update_document_metadata_success(self, app, current_user, dataset, dataset_id):
api = DocumentMetadataEditApi()
method = unwrap(api.post)
payload = {"operation": "add", "metadata": {}}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataOperationData,
"model_validate",
return_value=MagicMock(),
),
patch.object(
MetadataService,
"update_documents_metadata",
),
):
result, status = method(api, dataset_id)
assert status == 200
assert result["result"] == "success"

View File

@@ -0,0 +1,233 @@
from unittest.mock import Mock, PropertyMock, patch
import pytest
from flask import Flask
from controllers.console import console_ns
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.datasets.website import (
WebsiteCrawlApi,
WebsiteCrawlStatusApi,
)
from services.website_service import (
WebsiteCrawlApiRequest,
WebsiteCrawlStatusApiRequest,
WebsiteService,
)
def unwrap(func):
"""Recursively unwrap decorated functions."""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_website_crawl")
app.config["TESTING"] = True
return app
@pytest.fixture(autouse=True)
def bypass_auth_and_setup(mocker):
"""Bypass setup/login/account decorators."""
mocker.patch(
"controllers.console.datasets.website.login_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.website.setup_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.website.account_initialization_required",
lambda f: f,
)
class TestWebsiteCrawlApi:
def test_crawl_success(self, app, mocker):
api = WebsiteCrawlApi()
method = unwrap(api.post)
payload = {
"provider": "firecrawl",
"url": "https://example.com",
"options": {"depth": 1},
}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
):
mock_request = Mock(spec=WebsiteCrawlApiRequest)
mocker.patch.object(
WebsiteCrawlApiRequest,
"from_args",
return_value=mock_request,
)
mocker.patch.object(
WebsiteService,
"crawl_url",
return_value={"job_id": "job-1"},
)
result, status = method(api)
assert status == 200
assert result["job_id"] == "job-1"
def test_crawl_invalid_payload(self, app, mocker):
api = WebsiteCrawlApi()
method = unwrap(api.post)
payload = {
"provider": "firecrawl",
"url": "bad-url",
"options": {},
}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
):
mocker.patch.object(
WebsiteCrawlApiRequest,
"from_args",
side_effect=ValueError("invalid payload"),
)
with pytest.raises(WebsiteCrawlError, match="invalid payload"):
method(api)
def test_crawl_service_error(self, app, mocker):
api = WebsiteCrawlApi()
method = unwrap(api.post)
payload = {
"provider": "firecrawl",
"url": "https://example.com",
"options": {},
}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
):
mock_request = Mock(spec=WebsiteCrawlApiRequest)
mocker.patch.object(
WebsiteCrawlApiRequest,
"from_args",
return_value=mock_request,
)
mocker.patch.object(
WebsiteService,
"crawl_url",
side_effect=Exception("crawl failed"),
)
with pytest.raises(WebsiteCrawlError, match="crawl failed"):
method(api)
class TestWebsiteCrawlStatusApi:
def test_get_status_success(self, app, mocker):
api = WebsiteCrawlStatusApi()
method = unwrap(api.get)
job_id = "job-123"
args = {"provider": "firecrawl"}
with app.test_request_context("/?provider=firecrawl"):
mocker.patch(
"controllers.console.datasets.website.request.args.to_dict",
return_value=args,
)
mock_request = Mock(spec=WebsiteCrawlStatusApiRequest)
mocker.patch.object(
WebsiteCrawlStatusApiRequest,
"from_args",
return_value=mock_request,
)
mocker.patch.object(
WebsiteService,
"get_crawl_status_typed",
return_value={"status": "completed"},
)
result, status = method(api, job_id)
assert status == 200
assert result["status"] == "completed"
def test_get_status_invalid_provider(self, app, mocker):
api = WebsiteCrawlStatusApi()
method = unwrap(api.get)
job_id = "job-123"
args = {"provider": "firecrawl"}
with app.test_request_context("/?provider=firecrawl"):
mocker.patch(
"controllers.console.datasets.website.request.args.to_dict",
return_value=args,
)
mocker.patch.object(
WebsiteCrawlStatusApiRequest,
"from_args",
side_effect=ValueError("invalid provider"),
)
with pytest.raises(WebsiteCrawlError, match="invalid provider"):
method(api, job_id)
def test_get_status_service_error(self, app, mocker):
api = WebsiteCrawlStatusApi()
method = unwrap(api.get)
job_id = "job-123"
args = {"provider": "firecrawl"}
with app.test_request_context("/?provider=firecrawl"):
mocker.patch(
"controllers.console.datasets.website.request.args.to_dict",
return_value=args,
)
mock_request = Mock(spec=WebsiteCrawlStatusApiRequest)
mocker.patch.object(
WebsiteCrawlStatusApiRequest,
"from_args",
return_value=mock_request,
)
mocker.patch.object(
WebsiteService,
"get_crawl_status_typed",
side_effect=Exception("status lookup failed"),
)
with pytest.raises(WebsiteCrawlError, match="status lookup failed"):
method(api, job_id)

View File

@@ -0,0 +1,117 @@
from unittest.mock import Mock
import pytest
from controllers.console.datasets.error import PipelineNotFoundError
from controllers.console.datasets.wraps import get_rag_pipeline
from models.dataset import Pipeline
class TestGetRagPipeline:
def test_missing_pipeline_id(self):
@get_rag_pipeline
def dummy_view(**kwargs):
return "ok"
with pytest.raises(ValueError, match="missing pipeline_id"):
dummy_view()
def test_pipeline_not_found(self, mocker):
@get_rag_pipeline
def dummy_view(**kwargs):
return "ok"
mocker.patch(
"controllers.console.datasets.wraps.current_account_with_tenant",
return_value=(Mock(), "tenant-1"),
)
mock_query = Mock()
mock_query.where.return_value.first.return_value = None
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
)
with pytest.raises(PipelineNotFoundError):
dummy_view(pipeline_id="pipeline-1")
def test_pipeline_found_and_injected(self, mocker):
pipeline = Mock(spec=Pipeline)
pipeline.id = "pipeline-1"
pipeline.tenant_id = "tenant-1"
@get_rag_pipeline
def dummy_view(**kwargs):
return kwargs["pipeline"]
mocker.patch(
"controllers.console.datasets.wraps.current_account_with_tenant",
return_value=(Mock(), "tenant-1"),
)
mock_query = Mock()
mock_query.where.return_value.first.return_value = pipeline
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
)
result = dummy_view(pipeline_id="pipeline-1")
assert result is pipeline
def test_pipeline_id_removed_from_kwargs(self, mocker):
pipeline = Mock(spec=Pipeline)
@get_rag_pipeline
def dummy_view(**kwargs):
assert "pipeline_id" not in kwargs
return "ok"
mocker.patch(
"controllers.console.datasets.wraps.current_account_with_tenant",
return_value=(Mock(), "tenant-1"),
)
mock_query = Mock()
mock_query.where.return_value.first.return_value = pipeline
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
)
result = dummy_view(pipeline_id="pipeline-1")
assert result == "ok"
def test_pipeline_id_cast_to_string(self, mocker):
pipeline = Mock(spec=Pipeline)
@get_rag_pipeline
def dummy_view(**kwargs):
return kwargs["pipeline"]
mocker.patch(
"controllers.console.datasets.wraps.current_account_with_tenant",
return_value=(Mock(), "tenant-1"),
)
def where_side_effect(*args, **kwargs):
assert args[0].right.value == "123"
return Mock(first=lambda: pipeline)
mock_query = Mock()
mock_query.where.side_effect = where_side_effect
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
)
result = dummy_view(pipeline_id=123)
assert result is pipeline

View File

@@ -0,0 +1,341 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
EmailCodeError,
)
from controllers.console.error import AccountInFreezeError
from controllers.console.workspace.account import (
AccountAvatarApi,
AccountDeleteApi,
AccountDeleteVerifyApi,
AccountInitApi,
AccountIntegrateApi,
AccountInterfaceLanguageApi,
AccountInterfaceThemeApi,
AccountNameApi,
AccountPasswordApi,
AccountProfileApi,
AccountTimezoneApi,
ChangeEmailCheckApi,
ChangeEmailResetApi,
CheckEmailUnique,
)
from controllers.console.workspace.error import (
AccountAlreadyInitedError,
CurrentPasswordIncorrectError,
InvalidAccountDeletionCodeError,
)
from services.errors.account import CurrentPasswordIncorrectError as ServicePwdError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestAccountInitApi:
def test_init_success(self, app):
api = AccountInitApi()
method = unwrap(api.post)
account = MagicMock(status="inactive")
payload = {
"interface_language": "en-US",
"timezone": "UTC",
"invitation_code": "code123",
}
with (
app.test_request_context("/account/init", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
patch("controllers.console.workspace.account.db.session.commit", return_value=None),
patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
patch("controllers.console.workspace.account.db.session.query") as query_mock,
):
query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused")
resp = method(api)
assert resp["result"] == "success"
def test_init_already_initialized(self, app):
api = AccountInitApi()
method = unwrap(api.post)
account = MagicMock(status="active")
with (
app.test_request_context("/account/init"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
):
with pytest.raises(AccountAlreadyInitedError):
method(api)
class TestAccountProfileApi:
def test_get_profile_success(self, app):
api = AccountProfileApi()
method = unwrap(api.get)
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
with (
app.test_request_context("/account/profile"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
):
result = method(api)
assert result["id"] == "u1"
class TestAccountUpdateApis:
@pytest.mark.parametrize(
("api_cls", "payload"),
[
(AccountNameApi, {"name": "test"}),
(AccountAvatarApi, {"avatar": "img.png"}),
(AccountInterfaceLanguageApi, {"interface_language": "en-US"}),
(AccountInterfaceThemeApi, {"interface_theme": "dark"}),
(AccountTimezoneApi, {"timezone": "UTC"}),
],
)
def test_update_success(self, app, api_cls, payload):
api = api_cls()
method = unwrap(api.post)
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.account.AccountService.update_account", return_value=user),
):
result = method(api)
assert result["id"] == "u1"
class TestAccountPasswordApi:
def test_password_success(self, app):
api = AccountPasswordApi()
method = unwrap(api.post)
payload = {
"password": "old",
"new_password": "new123",
"repeat_new_password": "new123",
}
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.account.AccountService.update_account_password", return_value=None),
):
result = method(api)
assert result["id"] == "u1"
def test_password_wrong_current(self, app):
api = AccountPasswordApi()
method = unwrap(api.post)
payload = {
"password": "bad",
"new_password": "new123",
"repeat_new_password": "new123",
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.update_account_password",
side_effect=ServicePwdError(),
),
):
with pytest.raises(CurrentPasswordIncorrectError):
method(api)
class TestAccountIntegrateApi:
def test_get_integrates(self, app):
api = AccountIntegrateApi()
method = unwrap(api.get)
account = MagicMock(id="acc1")
with (
app.test_request_context("/"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
patch("controllers.console.workspace.account.db.session.scalars") as scalars_mock,
):
scalars_mock.return_value.all.return_value = []
result = method(api)
assert "data" in result
assert len(result["data"]) == 2
class TestAccountDeleteApi:
def test_delete_verify_success(self, app):
api = AccountDeleteVerifyApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.generate_account_deletion_verification_code",
return_value=("token", "1234"),
),
patch(
"controllers.console.workspace.account.AccountService.send_account_deletion_verification_email",
return_value=None,
),
):
result = method(api)
assert result["result"] == "success"
def test_delete_invalid_code(self, app):
api = AccountDeleteApi()
method = unwrap(api.post)
payload = {"token": "t", "code": "x"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.verify_account_deletion_code",
return_value=False,
),
):
with pytest.raises(InvalidAccountDeletionCodeError):
method(api)
class TestChangeEmailApis:
def test_check_email_code_invalid(self, app):
api = ChangeEmailCheckApi()
method = unwrap(api.post)
payload = {"email": "a@test.com", "code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit",
return_value=False,
),
patch(
"controllers.console.workspace.account.AccountService.get_change_email_data",
return_value={"email": "a@test.com", "code": "y"},
),
):
with pytest.raises(EmailCodeError):
method(api)
def test_reset_email_already_used(self, app):
api = ChangeEmailResetApi()
method = unwrap(api.post)
payload = {"new_email": "x@test.com", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False),
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=False),
):
with pytest.raises(EmailAlreadyInUseError):
method(api)
class TestCheckEmailUniqueApi:
def test_email_unique_success(self, app):
api = CheckEmailUnique()
method = unwrap(api.post)
payload = {"email": "ok@test.com"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False),
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=True),
):
result = method(api)
assert result["result"] == "success"
def test_email_in_freeze(self, app):
api = CheckEmailUnique()
method = unwrap(api.post)
payload = {"email": "x@test.com"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=True),
):
with pytest.raises(AccountInFreezeError):
method(api)

View File

@@ -0,0 +1,139 @@
from unittest.mock import MagicMock, patch
import pytest
from controllers.console.error import AccountNotFound
from controllers.console.workspace.agent_providers import (
AgentProviderApi,
AgentProviderListApi,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestAgentProviderListApi:
def test_get_success(self, app):
api = AgentProviderListApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
providers = [{"name": "openai"}, {"name": "anthropic"}]
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
return_value=providers,
),
):
result = method(api)
assert result == providers
def test_get_empty_list(self, app):
api = AgentProviderListApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
return_value=[],
),
):
result = method(api)
assert result == []
def test_get_account_not_found(self, app):
api = AgentProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
side_effect=AccountNotFound(),
),
):
with pytest.raises(AccountNotFound):
method(api)
class TestAgentProviderApi:
def test_get_success(self, app):
api = AgentProviderApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
provider_name = "openai"
provider_data = {"name": "openai", "models": ["gpt-4"]}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
return_value=provider_data,
),
):
result = method(api, provider_name)
assert result == provider_data
def test_get_provider_not_found(self, app):
api = AgentProviderApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
provider_name = "unknown"
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
return_value=None,
),
):
result = method(api, provider_name)
assert result is None
def test_get_account_not_found(self, app):
api = AgentProviderApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
side_effect=AccountNotFound(),
),
):
with pytest.raises(AccountNotFound):
method(api, "openai")

View File

@@ -0,0 +1,305 @@
from unittest.mock import MagicMock, patch
import pytest
from controllers.console.workspace.endpoint import (
EndpointCreateApi,
EndpointDeleteApi,
EndpointDisableApi,
EndpointEnableApi,
EndpointListApi,
EndpointListForSinglePluginApi,
EndpointUpdateApi,
)
from core.plugin.impl.exc import PluginPermissionDeniedError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def user_and_tenant():
return MagicMock(id="u1"), "t1"
@pytest.fixture
def patch_current_account(user_and_tenant):
with patch(
"controllers.console.workspace.endpoint.current_account_with_tenant",
return_value=user_and_tenant,
):
yield
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointCreateApi:
def test_create_success(self, app):
api = EndpointCreateApi()
method = unwrap(api.post)
payload = {
"plugin_unique_identifier": "plugin-1",
"name": "endpoint",
"settings": {"a": 1},
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_create_permission_denied(self, app):
api = EndpointCreateApi()
method = unwrap(api.post)
payload = {
"plugin_unique_identifier": "plugin-1",
"name": "endpoint",
"settings": {},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.endpoint.EndpointService.create_endpoint",
side_effect=PluginPermissionDeniedError("denied"),
),
):
with pytest.raises(ValueError):
method(api)
def test_create_validation_error(self, app):
api = EndpointCreateApi()
method = unwrap(api.post)
payload = {
"plugin_unique_identifier": "p1",
"name": "",
"settings": {},
}
with (
app.test_request_context("/", json=payload),
):
with pytest.raises(ValueError):
method(api)
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListApi:
def test_list_success(self, app):
api = EndpointListApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10"),
patch("controllers.console.workspace.endpoint.EndpointService.list_endpoints", return_value=[{"id": "e1"}]),
):
result = method(api)
assert "endpoints" in result
assert len(result["endpoints"]) == 1
def test_list_invalid_query(self, app):
api = EndpointListApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=0&page_size=10"),
):
with pytest.raises(ValueError):
method(api)
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListForSinglePluginApi:
def test_list_for_plugin_success(self, app):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10&plugin_id=p1"),
patch(
"controllers.console.workspace.endpoint.EndpointService.list_endpoints_for_single_plugin",
return_value=[{"id": "e1"}],
),
):
result = method(api)
assert "endpoints" in result
def test_list_for_plugin_missing_param(self, app):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10"),
):
with pytest.raises(ValueError):
method(api)
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointDeleteApi:
def test_delete_success(self, app):
api = EndpointDeleteApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_delete_invalid_payload(self, app):
api = EndpointDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)
def test_delete_service_failure(self, app):
api = EndpointDeleteApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False),
):
result = method(api)
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointUpdateApi:
def test_update_success(self, app):
api = EndpointUpdateApi()
method = unwrap(api.post)
payload = {
"endpoint_id": "e1",
"name": "new-name",
"settings": {"x": 1},
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_update_validation_error(self, app):
api = EndpointUpdateApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1", "settings": {}}
with (
app.test_request_context("/", json=payload),
):
with pytest.raises(ValueError):
method(api)
def test_update_service_failure(self, app):
api = EndpointUpdateApi()
method = unwrap(api.post)
payload = {
"endpoint_id": "e1",
"name": "n",
"settings": {},
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False),
):
result = method(api)
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointEnableApi:
def test_enable_success(self, app):
api = EndpointEnableApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_enable_invalid_payload(self, app):
api = EndpointEnableApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)
def test_enable_service_failure(self, app):
api = EndpointEnableApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=False),
):
result = method(api)
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointDisableApi:
def test_disable_success(self, app):
api = EndpointDisableApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.disable_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_disable_invalid_payload(self, app):
api = EndpointDisableApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)

View File

@@ -0,0 +1,607 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import HTTPException
import services
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
EmailCodeError,
InvalidEmailError,
InvalidTokenError,
MemberNotInTenantError,
NotOwnerError,
OwnerTransferLimitError,
)
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
from controllers.console.workspace.members import (
DatasetOperatorMemberListApi,
MemberCancelInviteApi,
MemberInviteEmailApi,
MemberListApi,
MemberUpdateRoleApi,
OwnerTransfer,
OwnerTransferCheckApi,
SendOwnerTransferEmailApi,
)
from services.errors.account import AccountAlreadyInTenantError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestMemberListApi:
def test_get_success(self, app):
api = MemberListApi()
method = unwrap(api.get)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
member = MagicMock()
member.id = "m1"
member.name = "Member"
member.email = "member@test.com"
member.avatar = "avatar.png"
member.role = "admin"
member.status = "active"
members = [member]
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=members),
):
result, status = method(api)
assert status == 200
assert len(result["accounts"]) == 1
def test_get_no_tenant(self, app):
api = MemberListApi()
method = unwrap(api.get)
user = MagicMock(current_tenant=None)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
):
with pytest.raises(ValueError):
method(api)
class TestMemberInviteEmailApi:
def test_invite_success(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = True
payload = {
"emails": ["a@test.com"],
"role": "normal",
"language": "en-US",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.members.RegisterService.invite_new_member", return_value="token"),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
):
result, status = method(api)
assert status == 201
assert result["result"] == "success"
def test_invite_limit_exceeded(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = False
payload = {
"emails": ["a@test.com"],
"role": "normal",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
):
with pytest.raises(WorkspaceMembersLimitExceeded):
method(api)
def test_invite_already_member(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = True
payload = {
"emails": ["a@test.com"],
"role": "normal",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch(
"controllers.console.workspace.members.RegisterService.invite_new_member",
side_effect=AccountAlreadyInTenantError(),
),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
):
result, status = method(api)
assert result["invitation_results"][0]["status"] == "success"
def test_invite_invalid_role(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
payload = {
"emails": ["a@test.com"],
"role": "owner",
}
with app.test_request_context("/", json=payload):
result, status = method(api)
assert status == 400
assert result["code"] == "invalid-role"
def test_invite_generic_exception(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = True
payload = {
"emails": ["a@test.com"],
"role": "normal",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch(
"controllers.console.workspace.members.RegisterService.invite_new_member",
side_effect=Exception("boom"),
),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
):
result, _ = method(api)
assert result["invitation_results"][0]["status"] == "failed"
class TestMemberCancelInviteApi:
def test_cancel_success(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 200
assert result["result"] == "success"
def test_cancel_not_found(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
):
q.return_value.where.return_value.first.return_value = None
with pytest.raises(HTTPException):
method(api, "x")
def test_cancel_cannot_operate_self(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.CannotOperateSelfError("x"),
),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 400
def test_cancel_no_permission(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.NoPermissionError("x"),
),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 403
def test_cancel_member_not_in_tenant(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.MemberNotInTenantError(),
),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 404
class TestMemberUpdateRoleApi:
def test_update_success(self, app):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
member = MagicMock()
payload = {"role": "normal"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get", return_value=member),
patch("controllers.console.workspace.members.TenantService.update_member_role"),
):
result = method(api, "id")
if isinstance(result, tuple):
result = result[0]
assert result["result"] == "success"
def test_update_invalid_role(self, app):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
payload = {"role": "invalid-role"}
with app.test_request_context("/", json=payload):
result, status = method(api, "id")
assert status == 400
def test_update_member_not_found(self, app):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
payload = {"role": "normal"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.members.current_account_with_tenant",
return_value=(MagicMock(current_tenant=MagicMock()), "t1"),
),
patch("controllers.console.workspace.members.db.session.get", return_value=None),
):
with pytest.raises(HTTPException):
method(api, "id")
class TestDatasetOperatorMemberListApi:
def test_get_success(self, app):
api = DatasetOperatorMemberListApi()
method = unwrap(api.get)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
member = MagicMock()
member.id = "op1"
member.name = "Operator"
member.email = "operator@test.com"
member.avatar = "avatar.png"
member.role = "operator"
member.status = "active"
members = [member]
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch(
"controllers.console.workspace.members.TenantService.get_dataset_operator_members", return_value=members
),
):
result, status = method(api)
assert status == 200
assert len(result["accounts"]) == 1
def test_get_no_tenant(self, app):
api = DatasetOperatorMemberListApi()
method = unwrap(api.get)
user = MagicMock(current_tenant=None)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
):
with pytest.raises(ValueError):
method(api)
class TestSendOwnerTransferEmailApi:
def test_send_success(self, app):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
tenant = MagicMock(name="ws")
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.send_owner_transfer_email", return_value="token"
),
):
result = method(api)
assert result["result"] == "success"
def test_send_ip_limit(self, app):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=True),
):
with pytest.raises(EmailSendIpLimitError):
method(api)
def test_send_not_owner(self, app):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/", json={}),
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=False),
):
with pytest.raises(NotOwnerError):
method(api)
class TestOwnerTransferCheckApi:
def test_check_invalid_code(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=False,
),
patch(
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
return_value={"email": "a@test.com", "code": "y"},
),
):
with pytest.raises(EmailCodeError):
method(api)
def test_rate_limited(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=True,
),
):
with pytest.raises(OwnerTransferLimitError):
method(api)
def test_invalid_token(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=False,
),
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
):
with pytest.raises(InvalidTokenError):
method(api)
def test_invalid_email(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=False,
),
patch(
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
return_value={"email": "b@test.com", "code": "x"},
),
):
with pytest.raises(InvalidEmailError):
method(api)
class TestOwnerTransferApi:
def test_transfer_self(self, app):
api = OwnerTransfer()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
payload = {"token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
):
with pytest.raises(CannotTransferOwnerToSelfError):
method(api, "1")
def test_invalid_token(self, app):
api = OwnerTransfer()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
payload = {"token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
):
with pytest.raises(InvalidTokenError):
method(api, "2")
def test_member_not_in_tenant(self, app):
api = OwnerTransfer()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
member = MagicMock()
payload = {"token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
return_value={"email": "a@test.com"},
),
patch("controllers.console.workspace.members.db.session.get", return_value=member),
patch("controllers.console.workspace.members.TenantService.is_member", return_value=False),
):
with pytest.raises(MemberNotInTenantError):
method(api, "2")

View File

@@ -0,0 +1,388 @@
from unittest.mock import MagicMock, patch
import pytest
from pydantic_core import ValidationError
from werkzeug.exceptions import Forbidden
from controllers.console.workspace.model_providers import (
ModelProviderCredentialApi,
ModelProviderCredentialSwitchApi,
ModelProviderIconApi,
ModelProviderListApi,
ModelProviderPaymentCheckoutUrlApi,
ModelProviderValidateApi,
PreferredProviderTypeUpdateApi,
)
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
INVALID_UUID = "123"
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestModelProviderListApi:
def test_get_success(self, app):
api = ModelProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/?model_type=llm"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_list",
return_value=[{"name": "openai"}],
),
):
result = method(api)
assert "data" in result
class TestModelProviderCredentialApi:
def test_get_success(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.get)
with (
app.test_request_context(f"/?credential_id={VALID_UUID}"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_credential",
return_value={"key": "value"},
),
):
result = method(api, provider="openai")
assert "credentials" in result
def test_get_invalid_uuid(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.get)
with (
app.test_request_context(f"/?credential_id={INVALID_UUID}"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with pytest.raises(ValidationError):
method(api, provider="openai")
def test_post_create_success(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.post)
payload = {"credentials": {"a": "b"}, "name": "test"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
return_value=None,
),
):
result, status = method(api, provider="openai")
assert result["result"] == "success"
assert status == 201
def test_post_create_validation_error(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.post)
payload = {"credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
side_effect=CredentialsValidateFailedError("bad"),
),
):
with pytest.raises(ValueError):
method(api, provider="openai")
def test_put_update_success(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.put)
payload = {"credential_id": VALID_UUID, "credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.update_provider_credential",
return_value=None,
),
):
result = method(api, provider="openai")
assert result["result"] == "success"
def test_put_invalid_uuid(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.put)
payload = {"credential_id": INVALID_UUID, "credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with pytest.raises(ValidationError):
method(api, provider="openai")
def test_delete_success(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.delete)
payload = {"credential_id": VALID_UUID}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.remove_provider_credential",
return_value=None,
),
):
result, status = method(api, provider="openai")
assert result["result"] == "success"
assert status == 204
class TestModelProviderCredentialSwitchApi:
def test_switch_success(self, app):
api = ModelProviderCredentialSwitchApi()
method = unwrap(api.post)
payload = {"credential_id": VALID_UUID}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.switch_active_provider_credential",
return_value=None,
),
):
result = method(api, provider="openai")
assert result["result"] == "success"
def test_switch_invalid_uuid(self, app):
api = ModelProviderCredentialSwitchApi()
method = unwrap(api.post)
payload = {"credential_id": INVALID_UUID}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with pytest.raises(ValidationError):
method(api, provider="openai")
class TestModelProviderValidateApi:
def test_validate_success(self, app):
api = ModelProviderValidateApi()
method = unwrap(api.post)
payload = {"credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
return_value=None,
),
):
result = method(api, provider="openai")
assert result["result"] == "success"
def test_validate_failure(self, app):
api = ModelProviderValidateApi()
method = unwrap(api.post)
payload = {"credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
side_effect=CredentialsValidateFailedError("bad"),
),
):
result = method(api, provider="openai")
assert result["result"] == "error"
class TestModelProviderIconApi:
def test_icon_success(self, app):
api = ModelProviderIconApi()
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon",
return_value=(b"123", "image/png"),
),
):
response = api.get("t1", "openai", "logo", "en")
assert response.mimetype == "image/png"
def test_icon_not_found(self, app):
api = ModelProviderIconApi()
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon",
return_value=(None, None),
),
):
with pytest.raises(ValueError):
api.get("t1", "openai", "logo", "en")
class TestPreferredProviderTypeUpdateApi:
def test_update_success(self, app):
api = PreferredProviderTypeUpdateApi()
method = unwrap(api.post)
payload = {"preferred_provider_type": "custom"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.switch_preferred_provider",
return_value=None,
),
):
result = method(api, provider="openai")
assert result["result"] == "success"
def test_invalid_enum(self, app):
api = PreferredProviderTypeUpdateApi()
method = unwrap(api.post)
payload = {"preferred_provider_type": "invalid"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with pytest.raises(ValidationError):
method(api, provider="openai")
class TestModelProviderPaymentCheckoutUrlApi:
def test_checkout_success(self, app):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
user = MagicMock(id="u1", email="x@test.com")
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(user, "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
return_value=None,
),
patch(
"controllers.console.workspace.model_providers.BillingService.get_model_provider_payment_link",
return_value={"url": "x"},
),
):
result = method(api, provider="anthropic")
assert "url" in result
def test_invalid_provider(self, app):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(ValueError):
method(api, provider="openai")
def test_permission_denied(self, app):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
user = MagicMock(id="u1", email="x@test.com")
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(user, "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
side_effect=Forbidden(),
),
):
with pytest.raises(Forbidden):
method(api, provider="anthropic")

View File

@@ -0,0 +1,447 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.workspace.models import (
DefaultModelApi,
ModelProviderAvailableModelApi,
ModelProviderModelApi,
ModelProviderModelCredentialApi,
ModelProviderModelCredentialSwitchApi,
ModelProviderModelDisableApi,
ModelProviderModelEnableApi,
ModelProviderModelParameterRuleApi,
ModelProviderModelValidateApi,
)
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestDefaultModelApi:
def test_get_success(self, app: Flask):
api = DefaultModelApi()
method = unwrap(api.get)
with (
app.test_request_context(
"/",
query_string={"model_type": ModelType.LLM.value},
),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_default_model_of_model_type.return_value = {"model": "gpt-4"}
result = method(api)
assert "data" in result
def test_post_success(self, app: Flask):
api = DefaultModelApi()
method = unwrap(api.post)
payload = {
"model_settings": [
{
"model_type": ModelType.LLM.value,
"provider": "openai",
"model": "gpt-4",
}
]
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api)
assert result["result"] == "success"
def test_get_returns_empty_when_no_default(self, app):
api = DefaultModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/", query_string={"model_type": ModelType.LLM.value}),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_default_model_of_model_type.return_value = None
result = method(api)
assert "data" in result
class TestModelProviderModelApi:
def test_get_models_success(self, app: Flask):
api = ModelProviderModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_models_by_provider.return_value = []
result = method(api, "openai")
assert "data" in result
def test_post_models_success(self, app: Flask):
api = ModelProviderModelApi()
method = unwrap(api.post)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"load_balancing": {
"configs": [{"weight": 1}],
"enabled": True,
},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
patch("controllers.console.workspace.models.ModelLoadBalancingService"),
):
result, status = method(api, "openai")
assert status == 200
def test_delete_model_success(self, app: Flask):
api = ModelProviderModelApi()
method = unwrap(api.delete)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result, status = method(api, "openai")
assert status == 204
def test_get_models_returns_empty(self, app):
api = ModelProviderModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_models_by_provider.return_value = []
result = method(api, "openai")
assert "data" in result
class TestModelProviderModelCredentialApi:
def test_get_credentials_success(self, app: Flask):
api = ModelProviderModelCredentialApi()
method = unwrap(api.get)
with (
app.test_request_context(
"/",
query_string={
"model": "gpt-4",
"model_type": ModelType.LLM.value,
},
),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as provider_service,
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb_service,
):
provider_service.return_value.get_model_credential.return_value = {
"credentials": {},
"current_credential_id": None,
"current_credential_name": None,
}
provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
lb_service.return_value.get_load_balancing_configs.return_value = (False, [])
result = method(api, "openai")
assert "credentials" in result
def test_create_credential_success(self, app: Flask):
api = ModelProviderModelCredentialApi()
method = unwrap(api.post)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"credentials": {"key": "val"},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result, status = method(api, "openai")
assert status == 201
def test_get_empty_credentials(self, app):
api = ModelProviderModelCredentialApi()
method = unwrap(api.get)
with (
app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM.value}),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb,
):
service.return_value.get_model_credential.return_value = None
service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
lb.return_value.get_load_balancing_configs.return_value = (False, [])
result = method(api, "openai")
assert result["credentials"] == {}
def test_delete_success(self, app):
api = ModelProviderModelCredentialApi()
method = unwrap(api.delete)
payload = {
"model": "gpt",
"model_type": ModelType.LLM.value,
"credential_id": "123e4567-e89b-12d3-a456-426614174000",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result, status = method(api, "openai")
assert status == 204
class TestModelProviderModelCredentialSwitchApi:
def test_switch_success(self, app: Flask):
api = ModelProviderModelCredentialSwitchApi()
method = unwrap(api.post)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"credential_id": "abc",
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "openai")
assert result["result"] == "success"
class TestModelEnableDisableApis:
def test_enable_model(self, app: Flask):
api = ModelProviderModelEnableApi()
method = unwrap(api.patch)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "openai")
assert result["result"] == "success"
def test_disable_model(self, app: Flask):
api = ModelProviderModelDisableApi()
method = unwrap(api.patch)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "openai")
assert result["result"] == "success"
class TestModelProviderModelValidateApi:
def test_validate_success(self, app: Flask):
api = ModelProviderModelValidateApi()
method = unwrap(api.post)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"credentials": {"key": "val"},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "openai")
assert result["result"] == "success"
@pytest.mark.parametrize("model_name", ["gpt-4", "gpt"])
def test_validate_failure(self, app: Flask, model_name: str):
api = ModelProviderModelValidateApi()
method = unwrap(api.post)
payload = {
"model": model_name,
"model_type": ModelType.LLM.value,
"credentials": {},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.validate_model_credentials.side_effect = CredentialsValidateFailedError("invalid")
result = method(api, "openai")
assert result["result"] == "error"
class TestParameterAndAvailableModels:
def test_parameter_rules(self, app: Flask):
api = ModelProviderModelParameterRuleApi()
method = unwrap(api.get)
with (
app.test_request_context("/", query_string={"model": "gpt-4"}),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_model_parameter_rules.return_value = []
result = method(api, "openai")
assert "data" in result
def test_available_models(self, app: Flask):
api = ModelProviderAvailableModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_models_by_model_type.return_value = []
result = method(api, ModelType.LLM.value)
assert "data" in result
def test_empty_rules(self, app):
api = ModelProviderModelParameterRuleApi()
method = unwrap(api.get)
with (
app.test_request_context("/", query_string={"model": "gpt"}),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_model_parameter_rules.return_value = []
result = method(api, "openai")
assert result["data"] == []
def test_no_models(self, app):
api = ModelProviderAvailableModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_models_by_model_type.return_value = []
result = method(api, ModelType.LLM.value)
assert result["data"] == []

File diff suppressed because it is too large Load Diff

View File

@@ -4,16 +4,52 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_restx import Api
from werkzeug.exceptions import Forbidden
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
from controllers.console.workspace.tool_providers import (
ToolApiListApi,
ToolApiProviderAddApi,
ToolApiProviderDeleteApi,
ToolApiProviderGetApi,
ToolApiProviderGetRemoteSchemaApi,
ToolApiProviderListToolsApi,
ToolApiProviderUpdateApi,
ToolBuiltinListApi,
ToolBuiltinProviderAddApi,
ToolBuiltinProviderCredentialsSchemaApi,
ToolBuiltinProviderDeleteApi,
ToolBuiltinProviderGetCredentialInfoApi,
ToolBuiltinProviderGetCredentialsApi,
ToolBuiltinProviderGetOauthClientSchemaApi,
ToolBuiltinProviderIconApi,
ToolBuiltinProviderInfoApi,
ToolBuiltinProviderListToolsApi,
ToolBuiltinProviderSetDefaultApi,
ToolBuiltinProviderUpdateApi,
ToolLabelsApi,
ToolOAuthCallback,
ToolOAuthCustomClient,
ToolPluginOAuthApi,
ToolProviderListApi,
ToolProviderMCPApi,
ToolWorkflowListApi,
ToolWorkflowProviderCreateApi,
ToolWorkflowProviderDeleteApi,
ToolWorkflowProviderGetApi,
ToolWorkflowProviderUpdateApi,
is_valid_url,
)
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import ReconnectResult
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
# They are intentionally no-ops because the test already patches the required
# behaviors explicitly via @patch and context managers below.
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def _mock_cache():
return
@@ -107,3 +143,602 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_
# 若 transform 后包含 tools 字段,确保非空
assert isinstance(body.get("tools"), list)
assert body["tools"]
class TestUtils:
def test_is_valid_url(self):
assert is_valid_url("https://example.com")
assert is_valid_url("http://example.com")
assert not is_valid_url("")
assert not is_valid_url("ftp://example.com")
assert not is_valid_url("not-a-url")
assert not is_valid_url(None)
class TestToolProviderListApi:
def test_get_success(self, app):
api = ToolProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u1"), "t1"),
),
patch(
"controllers.console.workspace.tool_providers.ToolCommonService.list_tool_providers",
return_value=["p1"],
),
):
assert method(api) == ["p1"]
class TestBuiltinProviderApis:
def test_list_tools(self, app):
api = ToolBuiltinProviderListToolsApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t1"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tool_provider_tools",
return_value=[{"a": 1}],
),
):
assert method(api, "provider") == [{"a": 1}]
def test_info(self, app):
api = ToolBuiltinProviderInfoApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t1"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_info",
return_value={"x": 1},
),
):
assert method(api, "provider") == {"x": 1}
def test_delete(self, app):
api = ToolBuiltinProviderDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credential_id": "cid"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t1"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_builtin_tool_provider",
return_value={"result": "success"},
),
):
assert method(api, "provider")["result"] == "success"
def test_add_invalid_type(self, app):
api = ToolBuiltinProviderAddApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}, "type": "invalid"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
):
with pytest.raises(ValueError):
method(api, "provider")
def test_add_success(self, app):
api = ToolBuiltinProviderAddApi()
method = unwrap(api.post)
payload = {"credentials": {}, "type": "oauth2", "name": "n"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.add_builtin_tool_provider",
return_value={"id": 1},
),
):
assert method(api, "provider")["id"] == 1
def test_update(self, app):
api = ToolBuiltinProviderUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "c1", "credentials": {}, "name": "n"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.update_builtin_tool_provider",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]
def test_get_credentials(self, app):
api = ToolBuiltinProviderGetCredentialsApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credentials",
return_value={"k": "v"},
),
):
assert method(api, "provider") == {"k": "v"}
def test_icon(self, app):
api = ToolBuiltinProviderIconApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_icon",
return_value=(b"x", "image/png"),
),
):
response = method(api, "provider")
assert response.mimetype == "image/png"
def test_credentials_schema(self, app):
api = ToolBuiltinProviderCredentialsSchemaApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_provider_credentials_schema",
return_value={"schema": {}},
),
):
assert method(api, "provider", "oauth2") == {"schema": {}}
def test_set_default_credential(self, app):
api = ToolBuiltinProviderSetDefaultApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"id": "c1"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.set_default_provider",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]
def test_get_credential_info(self, app):
api = ToolBuiltinProviderGetCredentialInfoApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credential_info",
return_value={"info": "x"},
),
):
assert method(api, "provider") == {"info": "x"}
def test_get_oauth_client_schema(self, app):
api = ToolBuiltinProviderGetOauthClientSchemaApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema",
return_value={"schema": {}},
),
):
assert method(api, "provider") == {"schema": {}}
class TestApiProviderApis:
def test_add(self, app):
api = ToolApiProviderAddApi()
method = unwrap(api.post)
payload = {
"credentials": {},
"schema_type": "openapi",
"schema": "{}",
"provider": "p",
"icon": {},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.create_api_tool_provider",
return_value={"id": 1},
),
):
assert method(api)["id"] == 1
def test_remote_schema(self, app):
api = ToolApiProviderGetRemoteSchemaApi()
method = unwrap(api.get)
with (
app.test_request_context("/?url=http://x.com"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider_remote_schema",
return_value={"schema": "x"},
),
):
assert method(api)["schema"] == "x"
def test_list_tools(self, app):
api = ToolApiProviderListToolsApi()
method = unwrap(api.get)
with (
app.test_request_context("/?provider=p"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tool_provider_tools",
return_value=[{"tool": 1}],
),
):
assert method(api) == [{"tool": 1}]
def test_update(self, app):
api = ToolApiProviderUpdateApi()
method = unwrap(api.post)
payload = {
"credentials": {},
"schema_type": "openapi",
"schema": "{}",
"provider": "p",
"original_provider": "o",
"icon": {},
"privacy_policy": "",
"custom_disclaimer": "",
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.update_api_tool_provider",
return_value={"ok": True},
),
):
assert method(api)["ok"]
def test_delete(self, app):
api = ToolApiProviderDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"provider": "p"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.delete_api_tool_provider",
return_value={"result": "success"},
),
):
assert method(api)["result"] == "success"
def test_get(self, app):
api = ToolApiProviderGetApi()
method = unwrap(api.get)
with (
app.test_request_context("/?provider=p"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider",
return_value={"x": 1},
),
):
assert method(api) == {"x": 1}
class TestWorkflowApis:
def test_create(self, app):
api = ToolWorkflowProviderCreateApi()
method = unwrap(api.post)
payload = {
"workflow_app_id": "123e4567-e89b-12d3-a456-426614174000",
"name": "n",
"label": "l",
"description": "d",
"icon": {},
"parameters": [],
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.create_workflow_tool",
return_value={"id": 1},
),
):
assert method(api)["id"] == 1
def test_update_invalid(self, app):
api = ToolWorkflowProviderUpdateApi()
method = unwrap(api.post)
payload = {
"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000",
"name": "Tool",
"label": "Tool Label",
"description": "A tool",
"icon": {},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.update_workflow_tool",
return_value={"ok": True},
),
):
result = method(api)
assert result["ok"]
def test_delete(self, app):
api = ToolWorkflowProviderDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.delete_workflow_tool",
return_value={"ok": True},
),
):
assert method(api)["ok"]
def test_get_error(self, app):
api = ToolWorkflowProviderGetApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
):
with pytest.raises(ValueError):
method(api)
class TestLists:
def test_builtin_list(self, app):
api = ToolBuiltinListApi()
method = unwrap(api.get)
m = MagicMock()
m.to_dict.return_value = {"x": 1}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tools",
return_value=[m],
),
):
assert method(api) == [{"x": 1}]
def test_api_list(self, app):
api = ToolApiListApi()
method = unwrap(api.get)
m = MagicMock()
m.to_dict.return_value = {"x": 1}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tools",
return_value=[m],
),
):
assert method(api) == [{"x": 1}]
def test_workflow_list(self, app):
api = ToolWorkflowListApi()
method = unwrap(api.get)
m = MagicMock()
m.to_dict.return_value = {"x": 1}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.list_tenant_workflow_tools",
return_value=[m],
),
):
assert method(api) == [{"x": 1}]
class TestLabels:
def test_labels(self, app):
api = ToolLabelsApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.ToolLabelsService.list_tool_labels",
return_value=["l1"],
),
):
assert method(api) == ["l1"]
class TestOAuth:
def test_oauth_no_client(self, app):
api = ToolPluginOAuthApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_oauth_client",
return_value=None,
),
):
with pytest.raises(Forbidden):
method(api, "provider")
def test_oauth_callback_no_cookie(self, app):
api = ToolOAuthCallback()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(Forbidden):
method(api, "provider")
class TestOAuthCustomClient:
def test_save_custom_client(self, app):
api = ToolOAuthCustomClient()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"client_params": {"a": 1}}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.save_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]
def test_get_custom_client(self, app):
api = ToolOAuthCustomClient()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_custom_oauth_client_params",
return_value={"client_id": "x"},
),
):
assert method(api, "provider") == {"client_id": "x"}
def test_delete_custom_client(self, app):
api = ToolOAuthCustomClient()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]

View File

@@ -0,0 +1,558 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import BadRequest, Forbidden
from controllers.console.workspace.trigger_providers import (
TriggerOAuthAuthorizeApi,
TriggerOAuthCallbackApi,
TriggerOAuthClientManageApi,
TriggerProviderIconApi,
TriggerProviderInfoApi,
TriggerProviderListApi,
TriggerSubscriptionBuilderBuildApi,
TriggerSubscriptionBuilderCreateApi,
TriggerSubscriptionBuilderGetApi,
TriggerSubscriptionBuilderLogsApi,
TriggerSubscriptionBuilderUpdateApi,
TriggerSubscriptionBuilderVerifyApi,
TriggerSubscriptionDeleteApi,
TriggerSubscriptionListApi,
TriggerSubscriptionUpdateApi,
TriggerSubscriptionVerifyApi,
)
from controllers.web.error import NotFoundError
from core.plugin.entities.plugin_daemon import CredentialType
from models.account import Account
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def mock_user():
user = MagicMock(spec=Account)
user.id = "u1"
user.current_tenant_id = "t1"
return user
class TestTriggerProviderApis:
def test_icon_success(self, app):
api = TriggerProviderIconApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_plugin_icon",
return_value="icon",
),
):
assert method(api, "github") == "icon"
def test_list_providers(self, app):
api = TriggerProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_providers",
return_value=[],
),
):
assert method(api) == []
def test_provider_info(self, app):
api = TriggerProviderInfoApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_trigger_provider",
return_value={"id": "p1"},
),
):
assert method(api, "github") == {"id": "p1"}
class TestTriggerSubscriptionListApi:
def test_list_success(self, app):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
return_value=[],
),
):
assert method(api, "github") == []
def test_list_invalid_provider(self, app):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
side_effect=ValueError("bad"),
),
):
result, status = method(api, "bad")
assert status == 404
class TestTriggerSubscriptionBuilderApis:
def test_create_builder(self, app):
api = TriggerSubscriptionBuilderCreateApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credential_type": "UNAUTHORIZED"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
return_value={"id": "b1"},
),
):
result = method(api, "github")
assert "subscription_builder" in result
def test_get_builder(self, app):
api = TriggerSubscriptionBuilderGetApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.get_subscription_builder_by_id",
return_value={"id": "b1"},
),
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_verify_builder(self, app):
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {"a": 1}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
return_value={"ok": True},
),
):
assert method(api, "github", "b1") == {"ok": True}
def test_verify_builder_error(self, app):
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
side_effect=Exception("err"),
),
):
with pytest.raises(ValueError):
method(api, "github", "b1")
def test_update_builder(self, app):
api = TriggerSubscriptionBuilderUpdateApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "n"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder",
return_value={"id": "b1"},
),
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_logs(self, app):
api = TriggerSubscriptionBuilderLogsApi()
method = unwrap(api.get)
log = MagicMock()
log.model_dump.return_value = {"a": 1}
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.list_logs",
return_value=[log],
),
):
assert "logs" in method(api, "github", "b1")
def test_build(self, app):
api = TriggerSubscriptionBuilderBuildApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "x"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_build_builder",
return_value=None,
),
):
assert method(api, "github", "b1") == 200
class TestTriggerSubscriptionCrud:
def test_update_rename_only(self, app):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
sub = MagicMock()
sub.provider_id = "github"
sub.credential_type = CredentialType.UNAUTHORIZED
with (
app.test_request_context("/", json={"name": "x"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
return_value=sub,
),
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.update_trigger_subscription"),
):
assert method(api, "s1") == 200
def test_update_not_found(self, app):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "x"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
return_value=None,
),
):
with pytest.raises(NotFoundError):
method(api, "x")
def test_update_rebuild(self, app):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
sub = MagicMock()
sub.provider_id = "github"
sub.credential_type = CredentialType.OAUTH2
sub.credentials = {}
sub.parameters = {}
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
return_value=sub,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.rebuild_trigger_subscription"
),
):
assert method(api, "s1") == 200
def test_delete_subscription(self, app):
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
mock_session = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
patch("controllers.console.workspace.trigger_providers.Session") as mock_session_cls,
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription"
),
):
mock_db.engine = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
result = method(api, "sub1")
assert result["result"] == "success"
def test_delete_subscription_value_error(self, app):
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
patch("controllers.console.workspace.trigger_providers.Session") as session_cls,
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider",
side_effect=ValueError("bad"),
),
):
mock_db.engine = MagicMock()
session_cls.return_value.__enter__.return_value = MagicMock()
with pytest.raises(BadRequest):
method(api, "sub1")
class TestTriggerOAuthApis:
def test_oauth_authorize_success(self, app):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value={"a": 1},
),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
return_value=MagicMock(id="b1"),
),
patch(
"controllers.console.workspace.trigger_providers.OAuthProxyService.create_proxy_context",
return_value="ctx",
),
patch(
"controllers.console.workspace.trigger_providers.OAuthHandler.get_authorization_url",
return_value=MagicMock(authorization_url="url"),
),
):
resp = method(api, "github")
assert resp.status_code == 200
def test_oauth_authorize_no_client(self, app):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value=None,
),
):
with pytest.raises(NotFoundError):
method(api, "github")
def test_oauth_callback_forbidden(self, app):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_success(self, app):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
ctx = {
"user_id": "u1",
"tenant_id": "t1",
"subscription_builder_id": "b1",
}
with (
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
patch(
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", return_value=ctx
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value={"a": 1},
),
patch(
"controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials",
return_value=MagicMock(credentials={"a": 1}, expires_at=1),
),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder"
),
):
resp = method(api, "github")
assert resp.status_code == 302
def test_oauth_callback_no_oauth_client(self, app):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
ctx = {
"user_id": "u1",
"tenant_id": "t1",
"subscription_builder_id": "b1",
}
with (
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
patch(
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context",
return_value=ctx,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value=None,
),
):
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_empty_credentials(self, app):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
ctx = {
"user_id": "u1",
"tenant_id": "t1",
"subscription_builder_id": "b1",
}
with (
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
patch(
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context",
return_value=ctx,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value={"a": 1},
),
patch(
"controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials",
return_value=MagicMock(credentials=None, expires_at=None),
),
):
with pytest.raises(ValueError):
method(api, "github")
class TestTriggerOAuthClientManageApi:
def test_get_client(self, app):
api = TriggerOAuthClientManageApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_custom_oauth_client_params",
return_value={},
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_custom_client_enabled",
return_value=False,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_system_client_exists",
return_value=True,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_provider",
return_value=MagicMock(get_oauth_client_schema=lambda: {}),
),
):
result = method(api, "github")
assert "configured" in result
def test_post_client(self, app):
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"enabled": True}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "github") == {"ok": True}
def test_delete_client(self, app):
api = TriggerOAuthClientManageApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "github") == {"ok": True}
def test_oauth_client_post_value_error(self, app):
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"enabled": True}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
side_effect=ValueError("bad"),
),
):
with pytest.raises(BadRequest):
method(api, "github")
class TestTriggerSubscriptionVerifyApi:
def test_verify_success(self, app):
api = TriggerSubscriptionVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
return_value={"ok": True},
),
):
assert method(api, "github", "s1") == {"ok": True}
@pytest.mark.parametrize("raised_exception", [ValueError("bad"), Exception("boom")])
def test_verify_errors(self, app, raised_exception):
api = TriggerSubscriptionVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
side_effect=raised_exception,
),
):
with pytest.raises(BadRequest):
method(api, "github", "s1")

View File

@@ -0,0 +1,605 @@
from datetime import datetime
from io import BytesIO
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Unauthorized
import services
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.workspace.workspace import (
CustomConfigWorkspaceApi,
SwitchWorkspaceApi,
TenantApi,
TenantListApi,
WebappLogoWorkspaceApi,
WorkspaceInfoApi,
WorkspaceListApi,
WorkspacePermissionApi,
)
from enums.cloud_plan import CloudPlan
from models.account import TenantStatus
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestTenantListApi:
def test_get_success(self, app):
api = TenantListApi()
method = unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=datetime.utcnow(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=datetime.utcnow(),
)
features = MagicMock()
features.billing.enabled = True
features.billing.subscription.plan = CloudPlan.SANDBOX
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features),
):
result, status = method(api)
assert status == 200
assert len(result["workspaces"]) == 2
assert result["workspaces"][0]["current"] is True
def test_get_billing_disabled(self, app):
api = TenantListApi()
method = unwrap(api.get)
tenant = MagicMock(
id="t1",
name="Tenant",
status="active",
created_at=datetime.utcnow(),
)
features = MagicMock()
features.billing.enabled = False
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant],
),
patch(
"controllers.console.workspace.workspace.FeatureService.get_features",
return_value=features,
),
):
result, status = method(api)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
class TestWorkspaceListApi:
def test_get_success(self, app):
api = WorkspaceListApi()
method = unwrap(api.get)
tenant = MagicMock(id="t1", name="T", status="active", created_at=datetime.utcnow())
paginate_result = MagicMock(
items=[tenant],
has_next=False,
total=1,
)
with (
app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 20}),
patch("controllers.console.workspace.workspace.db.paginate", return_value=paginate_result),
):
result, status = method(api)
assert status == 200
assert result["total"] == 1
assert result["has_more"] is False
def test_get_has_next_true(self, app):
api = WorkspaceListApi()
method = unwrap(api.get)
tenant = MagicMock(
id="t1",
name="T",
status="active",
created_at=datetime.utcnow(),
)
paginate_result = MagicMock(
items=[tenant],
has_next=True,
total=10,
)
with (
app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 1}),
patch(
"controllers.console.workspace.workspace.db.paginate",
return_value=paginate_result,
),
):
result, status = method(api)
assert status == 200
assert result["has_more"] is True
class TestTenantApi:
def test_post_active_tenant(self, app):
api = TenantApi()
method = unwrap(api.post)
tenant = MagicMock(status="active")
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/workspaces/current"),
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
),
):
result, status = method(api)
assert status == 200
assert result["id"] == "t1"
def test_post_archived_with_switch(self, app):
api = TenantApi()
method = unwrap(api.post)
archived = MagicMock(status=TenantStatus.ARCHIVE)
new_tenant = MagicMock(status="active")
user = MagicMock(current_tenant=archived)
with (
app.test_request_context("/workspaces/current"),
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[new_tenant]),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "new"}
),
):
result, status = method(api)
assert result["id"] == "new"
def test_post_archived_no_tenant(self, app):
api = TenantApi()
method = unwrap(api.post)
user = MagicMock(current_tenant=MagicMock(status=TenantStatus.ARCHIVE))
with (
app.test_request_context("/workspaces/current"),
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[]),
):
with pytest.raises(Unauthorized):
method(api)
def test_post_info_path(self, app):
api = TenantApi()
method = unwrap(api.post)
tenant = MagicMock(status="active")
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/info"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(user, "t1"),
),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
return_value={"id": "t1"},
),
patch("controllers.console.workspace.workspace.logger.warning") as warn_mock,
):
result, status = method(api)
warn_mock.assert_called_once()
assert status == 200
class TestSwitchWorkspaceApi:
def test_switch_success(self, app):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
payload = {"tenant_id": "t2"}
tenant = MagicMock(id="t2")
with (
app.test_request_context("/workspaces/switch", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"}
),
):
query_mock.return_value.get.return_value = tenant
result = method(api)
assert result["result"] == "success"
def test_switch_not_linked(self, app):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
payload = {"tenant_id": "bad"}
with (
app.test_request_context("/workspaces/switch", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant", side_effect=Exception),
):
with pytest.raises(AccountNotLinkTenantError):
method(api)
def test_switch_tenant_not_found(self, app):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
payload = {"tenant_id": "missing"}
with (
app.test_request_context("/workspaces/switch", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
):
query_mock.return_value.get.return_value = None
with pytest.raises(ValueError):
method(api)
class TestCustomConfigWorkspaceApi:
def test_post_success(self, app):
api = CustomConfigWorkspaceApi()
method = unwrap(api.post)
tenant = MagicMock(custom_config_dict={})
payload = {"remove_webapp_brand": True}
with (
app.test_request_context("/workspaces/custom-config", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
patch("controllers.console.workspace.workspace.db.session.commit"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
),
):
result = method(api)
assert result["result"] == "success"
def test_logo_fallback(self, app):
api = CustomConfigWorkspaceApi()
method = unwrap(api.post)
tenant = MagicMock(custom_config_dict={"replace_webapp_logo": "old-logo"})
payload = {"remove_webapp_brand": False}
with (
app.test_request_context("/workspaces/custom-config", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch(
"controllers.console.workspace.workspace.db.get_or_404",
return_value=tenant,
),
patch("controllers.console.workspace.workspace.db.session.commit"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
return_value={"id": "t1"},
),
):
result = method(api)
assert tenant.custom_config_dict["replace_webapp_logo"] == "old-logo"
assert result["result"] == "success"
class TestWebappLogoWorkspaceApi:
def test_no_file(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
with (
app.test_request_context("/upload", data={}),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
):
with pytest.raises(NoFileUploadedError):
method(api)
def test_too_many_files(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
data = {
"file": MagicMock(),
"extra": MagicMock(),
}
with (
app.test_request_context("/upload", data=data),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
):
with pytest.raises(TooManyFilesError):
method(api)
def test_invalid_extension(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = MagicMock(filename="test.txt")
with (
app.test_request_context("/upload", data={"file": file}),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
):
with pytest.raises(UnsupportedFileTypeError):
method(api)
def test_upload_success(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"data"),
filename="logo.png",
content_type="image/png",
)
upload = MagicMock(id="file1")
with (
app.test_request_context(
"/upload",
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.FileService") as fs,
patch("controllers.console.workspace.workspace.db") as mock_db,
):
mock_db.engine = MagicMock()
fs.return_value.upload_file.return_value = upload
result, status = method(api)
assert status == 201
assert result["id"] == "file1"
def test_filename_missing(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"data"),
filename="",
content_type="image/png",
)
with (
app.test_request_context(
"/upload",
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
):
with pytest.raises(FilenameNotExistsError):
method(api)
def test_file_too_large(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"x"),
filename="logo.png",
content_type="image/png",
)
with (
app.test_request_context(
"/upload",
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.FileService") as fs,
patch("controllers.console.workspace.workspace.db") as mock_db,
):
mock_db.engine = MagicMock()
fs.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError("too big")
with pytest.raises(FileTooLargeError):
method(api)
def test_service_unsupported_file(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"x"),
filename="logo.png",
content_type="image/png",
)
with (
app.test_request_context(
"/upload",
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.FileService") as fs,
patch("controllers.console.workspace.workspace.db") as mock_db,
):
mock_db.engine = MagicMock()
fs.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError()
with pytest.raises(UnsupportedFileTypeError):
method(api)
class TestWorkspaceInfoApi:
def test_post_success(self, app):
api = WorkspaceInfoApi()
method = unwrap(api.post)
tenant = MagicMock()
payload = {"name": "New Name"}
with (
app.test_request_context("/workspaces/info", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
patch("controllers.console.workspace.workspace.db.session.commit"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
return_value={"name": "New Name"},
),
):
result = method(api)
assert result["result"] == "success"
def test_no_current_tenant(self, app):
api = WorkspaceInfoApi()
method = unwrap(api.post)
payload = {"name": "X"}
with (
app.test_request_context("/workspaces/info", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), None),
),
):
with pytest.raises(ValueError):
method(api)
class TestWorkspacePermissionApi:
def test_get_success(self, app):
api = WorkspacePermissionApi()
method = unwrap(api.get)
permission = MagicMock(
workspace_id="t1",
allow_member_invite=True,
allow_owner_transfer=False,
)
with (
app.test_request_context("/permission"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.EnterpriseService.WorkspacePermissionService.get_permission",
return_value=permission,
),
):
result, status = method(api)
assert status == 200
assert result["workspace_id"] == "t1"
def test_no_current_tenant(self, app):
api = WorkspacePermissionApi()
method = unwrap(api.get)
with (
app.test_request_context("/permission"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), None),
),
):
with pytest.raises(ValueError):
method(api)

View File

@@ -0,0 +1,142 @@
from __future__ import annotations
import importlib
from types import SimpleNamespace
import pytest
from werkzeug.exceptions import Forbidden
from controllers.console.workspace import plugin_permission_required
from models.account import TenantPluginPermission
class _SessionStub:
def __init__(self, permission):
self._permission = permission
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def query(self, *_args, **_kwargs):
return self
def where(self, *_args, **_kwargs):
return self
def first(self):
return self._permission
def _workspace_module():
return importlib.import_module(plugin_permission_required.__module__)
def _patch_session(monkeypatch: pytest.MonkeyPatch, permission):
module = _workspace_module()
monkeypatch.setattr(module, "Session", lambda *_args, **_kwargs: _SessionStub(permission))
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
def test_plugin_permission_allows_without_permission(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=False)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, None)
@plugin_permission_required()
def handler():
return "ok"
assert handler() == "ok"
def test_plugin_permission_install_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=True)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.NOBODY,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(install_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_plugin_permission_install_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=False)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(install_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_plugin_permission_install_admin_allows_admin(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=True)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(install_required=True)
def handler():
return "ok"
assert handler() == "ok"
def test_plugin_permission_debug_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=True)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
debug_permission=TenantPluginPermission.DebugPermission.NOBODY,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(debug_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_plugin_permission_debug_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=False)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
debug_permission=TenantPluginPermission.DebugPermission.ADMINS,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(debug_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()

View File

@@ -0,0 +1,85 @@
"""Shared fixtures for controllers.web unit tests."""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
import pytest
from flask import Flask
@pytest.fixture
def app() -> Flask:
"""Minimal Flask app for request contexts."""
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
return flask_app
class FakeSession:
"""Stand-in for db.session that returns pre-seeded objects by model class name."""
def __init__(self, mapping: dict[str, Any] | None = None):
self._mapping: dict[str, Any] = mapping or {}
self._model_name: str | None = None
def query(self, model: type) -> FakeSession:
self._model_name = model.__name__
return self
def where(self, *_args: object, **_kwargs: object) -> FakeSession:
return self
def first(self) -> Any:
assert self._model_name is not None
return self._mapping.get(self._model_name)
class FakeDB:
"""Minimal db stub exposing engine and session."""
def __init__(self, session: FakeSession | None = None):
self.session = session or FakeSession()
self.engine = object()
def make_app_model(
*,
app_id: str = "app-1",
tenant_id: str = "tenant-1",
mode: str = "chat",
enable_site: bool = True,
status: str = "normal",
) -> SimpleNamespace:
"""Build a fake App model with common defaults."""
tenant = SimpleNamespace(
id=tenant_id,
status="normal",
plan="basic",
custom_config_dict={},
)
return SimpleNamespace(
id=app_id,
tenant_id=tenant_id,
tenant=tenant,
mode=mode,
enable_site=enable_site,
status=status,
workflow=None,
app_model_config=None,
)
def make_end_user(
*,
user_id: str = "end-user-1",
session_id: str = "session-1",
external_user_id: str = "ext-user-1",
) -> SimpleNamespace:
"""Build a fake EndUser model with common defaults."""
return SimpleNamespace(
id=user_id,
session_id=session_id,
external_user_id=external_user_id,
)

View File

@@ -0,0 +1,165 @@
"""Unit tests for controllers.web.app endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.app import AppAccessMode, AppMeta, AppParameterApi, AppWebAuthPermission
from controllers.web.error import AppUnavailableError
# ---------------------------------------------------------------------------
# AppParameterApi
# ---------------------------------------------------------------------------
class TestAppParameterApi:
def test_advanced_chat_mode_uses_workflow(self, app: Flask) -> None:
features_dict = {"opening_statement": "Hello"}
workflow = SimpleNamespace(
features_dict=features_dict,
user_input_form=lambda to_old_structure=False: [],
)
app_model = SimpleNamespace(mode="advanced-chat", workflow=workflow)
with (
app.test_request_context("/parameters"),
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
patch("controllers.web.app.fields.Parameters") as mock_fields,
):
mock_fields.model_validate.return_value.model_dump.return_value = {"result": "ok"}
result = AppParameterApi().get(app_model, SimpleNamespace())
mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[])
assert result == {"result": "ok"}
def test_workflow_mode_uses_workflow(self, app: Flask) -> None:
features_dict = {}
workflow = SimpleNamespace(
features_dict=features_dict,
user_input_form=lambda to_old_structure=False: [{"var": "x"}],
)
app_model = SimpleNamespace(mode="workflow", workflow=workflow)
with (
app.test_request_context("/parameters"),
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
patch("controllers.web.app.fields.Parameters") as mock_fields,
):
mock_fields.model_validate.return_value.model_dump.return_value = {}
AppParameterApi().get(app_model, SimpleNamespace())
mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[{"var": "x"}])
def test_advanced_chat_mode_no_workflow_raises(self, app: Flask) -> None:
app_model = SimpleNamespace(mode="advanced-chat", workflow=None)
with app.test_request_context("/parameters"):
with pytest.raises(AppUnavailableError):
AppParameterApi().get(app_model, SimpleNamespace())
def test_standard_mode_uses_app_model_config(self, app: Flask) -> None:
config = SimpleNamespace(to_dict=lambda: {"user_input_form": [{"var": "y"}], "key": "val"})
app_model = SimpleNamespace(mode="chat", app_model_config=config)
with (
app.test_request_context("/parameters"),
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
patch("controllers.web.app.fields.Parameters") as mock_fields,
):
mock_fields.model_validate.return_value.model_dump.return_value = {}
AppParameterApi().get(app_model, SimpleNamespace())
call_kwargs = mock_params.call_args
assert call_kwargs.kwargs["user_input_form"] == [{"var": "y"}]
def test_standard_mode_no_config_raises(self, app: Flask) -> None:
app_model = SimpleNamespace(mode="chat", app_model_config=None)
with app.test_request_context("/parameters"):
with pytest.raises(AppUnavailableError):
AppParameterApi().get(app_model, SimpleNamespace())
# ---------------------------------------------------------------------------
# AppMeta
# ---------------------------------------------------------------------------
class TestAppMeta:
@patch("controllers.web.app.AppService")
def test_get_returns_meta(self, mock_service_cls: MagicMock, app: Flask) -> None:
mock_service_cls.return_value.get_app_meta.return_value = {"tool_icons": {}}
app_model = SimpleNamespace(id="app-1")
with app.test_request_context("/meta"):
result = AppMeta().get(app_model, SimpleNamespace())
assert result == {"tool_icons": {}}
# ---------------------------------------------------------------------------
# AppAccessMode
# ---------------------------------------------------------------------------
class TestAppAccessMode:
@patch("controllers.web.app.FeatureService.get_system_features")
def test_returns_public_when_webapp_auth_disabled(self, mock_features: MagicMock, app: Flask) -> None:
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
with app.test_request_context("/webapp/access-mode?appId=app-1"):
result = AppAccessMode().get()
assert result == {"accessMode": "public"}
@patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
@patch("controllers.web.app.FeatureService.get_system_features")
def test_returns_access_mode_with_app_id(
self, mock_features: MagicMock, mock_access: MagicMock, app: Flask
) -> None:
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
mock_access.return_value = SimpleNamespace(access_mode="internal")
with app.test_request_context("/webapp/access-mode?appId=app-1"):
result = AppAccessMode().get()
assert result == {"accessMode": "internal"}
mock_access.assert_called_once_with("app-1")
@patch("controllers.web.app.AppService.get_app_id_by_code", return_value="resolved-id")
@patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
@patch("controllers.web.app.FeatureService.get_system_features")
def test_resolves_app_code_to_id(
self, mock_features: MagicMock, mock_access: MagicMock, mock_resolve: MagicMock, app: Flask
) -> None:
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
mock_access.return_value = SimpleNamespace(access_mode="external")
with app.test_request_context("/webapp/access-mode?appCode=code1"):
result = AppAccessMode().get()
mock_resolve.assert_called_once_with("code1")
mock_access.assert_called_once_with("resolved-id")
assert result == {"accessMode": "external"}
@patch("controllers.web.app.FeatureService.get_system_features")
def test_raises_when_no_app_id_or_code(self, mock_features: MagicMock, app: Flask) -> None:
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
with app.test_request_context("/webapp/access-mode"):
with pytest.raises(ValueError, match="appId or appCode"):
AppAccessMode().get()
# ---------------------------------------------------------------------------
# AppWebAuthPermission
# ---------------------------------------------------------------------------
class TestAppWebAuthPermission:
@patch("controllers.web.app.WebAppAuthService.is_app_require_permission_check", return_value=False)
def test_returns_true_when_no_permission_check_required(self, mock_check: MagicMock, app: Flask) -> None:
with app.test_request_context("/webapp/permission?appId=app-1", headers={"X-App-Code": "code1"}):
result = AppWebAuthPermission().get()
assert result == {"result": True}
def test_raises_when_missing_app_id(self, app: Flask) -> None:
with app.test_request_context("/webapp/permission", headers={"X-App-Code": "code1"}):
with pytest.raises(ValueError, match="appId"):
AppWebAuthPermission().get()

View File

@@ -0,0 +1,135 @@
"""Unit tests for controllers.web.audio endpoints."""
from __future__ import annotations
from io import BytesIO
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.audio import AudioApi, TextApi
from controllers.web.error import (
AudioTooLargeError,
CompletionRequestError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from dify_graph.model_runtime.errors.invoke import InvokeError
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
UnsupportedAudioTypeServiceError,
)
def _app_model() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="chat")
def _end_user() -> SimpleNamespace:
return SimpleNamespace(id="eu-1", external_user_id="ext-1")
# ---------------------------------------------------------------------------
# AudioApi (audio-to-text)
# ---------------------------------------------------------------------------
class TestAudioApi:
@patch("controllers.web.audio.AudioService.transcript_asr", return_value={"text": "hello"})
def test_happy_path(self, mock_asr: MagicMock, app: Flask) -> None:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
data = {"file": (BytesIO(b"fake-audio"), "test.mp3")}
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
result = AudioApi().post(_app_model(), _end_user())
assert result == {"text": "hello"}
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=NoAudioUploadedServiceError())
def test_no_audio_uploaded(self, mock_asr: MagicMock, app: Flask) -> None:
data = {"file": (BytesIO(b""), "empty.mp3")}
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
with pytest.raises(NoAudioUploadedError):
AudioApi().post(_app_model(), _end_user())
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=AudioTooLargeServiceError("too big"))
def test_audio_too_large(self, mock_asr: MagicMock, app: Flask) -> None:
data = {"file": (BytesIO(b"big"), "big.mp3")}
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
with pytest.raises(AudioTooLargeError):
AudioApi().post(_app_model(), _end_user())
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=UnsupportedAudioTypeServiceError())
def test_unsupported_type(self, mock_asr: MagicMock, app: Flask) -> None:
data = {"file": (BytesIO(b"bad"), "bad.xyz")}
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
with pytest.raises(UnsupportedAudioTypeError):
AudioApi().post(_app_model(), _end_user())
@patch(
"controllers.web.audio.AudioService.transcript_asr",
side_effect=ProviderNotSupportSpeechToTextServiceError(),
)
def test_provider_not_support(self, mock_asr: MagicMock, app: Flask) -> None:
data = {"file": (BytesIO(b"x"), "x.mp3")}
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
with pytest.raises(ProviderNotSupportSpeechToTextError):
AudioApi().post(_app_model(), _end_user())
@patch(
"controllers.web.audio.AudioService.transcript_asr",
side_effect=ProviderTokenNotInitError(description="no token"),
)
def test_provider_not_init(self, mock_asr: MagicMock, app: Flask) -> None:
data = {"file": (BytesIO(b"x"), "x.mp3")}
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
with pytest.raises(ProviderNotInitializeError):
AudioApi().post(_app_model(), _end_user())
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=QuotaExceededError())
def test_quota_exceeded(self, mock_asr: MagicMock, app: Flask) -> None:
data = {"file": (BytesIO(b"x"), "x.mp3")}
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
with pytest.raises(ProviderQuotaExceededError):
AudioApi().post(_app_model(), _end_user())
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=ModelCurrentlyNotSupportError())
def test_model_not_support(self, mock_asr: MagicMock, app: Flask) -> None:
data = {"file": (BytesIO(b"x"), "x.mp3")}
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
AudioApi().post(_app_model(), _end_user())
# ---------------------------------------------------------------------------
# TextApi (text-to-audio)
# ---------------------------------------------------------------------------
class TestTextApi:
@patch("controllers.web.audio.AudioService.transcript_tts", return_value="audio-bytes")
@patch("controllers.web.audio.web_ns")
def test_happy_path(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None:
mock_ns.payload = {"text": "hello", "voice": "alloy"}
with app.test_request_context("/text-to-audio", method="POST"):
result = TextApi().post(_app_model(), _end_user())
assert result == "audio-bytes"
mock_tts.assert_called_once()
@patch(
"controllers.web.audio.AudioService.transcript_tts",
side_effect=InvokeError(description="invoke failed"),
)
@patch("controllers.web.audio.web_ns")
def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None:
mock_ns.payload = {"text": "hello"}
with app.test_request_context("/text-to-audio", method="POST"):
with pytest.raises(CompletionRequestError):
TextApi().post(_app_model(), _end_user())

View File

@@ -0,0 +1,161 @@
"""Unit tests for controllers.web.completion endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
from controllers.web.error import (
CompletionRequestError,
NotChatAppError,
NotCompletionAppError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from dify_graph.model_runtime.errors.invoke import InvokeError
def _completion_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="completion")
def _chat_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="chat")
def _end_user() -> SimpleNamespace:
return SimpleNamespace(id="eu-1")
# ---------------------------------------------------------------------------
# CompletionApi
# ---------------------------------------------------------------------------
class TestCompletionApi:
def test_wrong_mode_raises(self, app: Flask) -> None:
with app.test_request_context("/completion-messages", method="POST"):
with pytest.raises(NotCompletionAppError):
CompletionApi().post(_chat_app(), _end_user())
@patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "hi"})
@patch("controllers.web.completion.AppGenerateService.generate")
@patch("controllers.web.completion.web_ns")
def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
mock_ns.payload = {"inputs": {}, "query": "test"}
mock_gen.return_value = "response-obj"
with app.test_request_context("/completion-messages", method="POST"):
result = CompletionApi().post(_completion_app(), _end_user())
assert result == {"answer": "hi"}
@patch(
"controllers.web.completion.AppGenerateService.generate",
side_effect=ProviderTokenNotInitError(description="not init"),
)
@patch("controllers.web.completion.web_ns")
def test_provider_not_init_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
mock_ns.payload = {"inputs": {}}
with app.test_request_context("/completion-messages", method="POST"):
with pytest.raises(ProviderNotInitializeError):
CompletionApi().post(_completion_app(), _end_user())
@patch(
"controllers.web.completion.AppGenerateService.generate",
side_effect=QuotaExceededError(),
)
@patch("controllers.web.completion.web_ns")
def test_quota_exceeded_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
mock_ns.payload = {"inputs": {}}
with app.test_request_context("/completion-messages", method="POST"):
with pytest.raises(ProviderQuotaExceededError):
CompletionApi().post(_completion_app(), _end_user())
@patch(
"controllers.web.completion.AppGenerateService.generate",
side_effect=ModelCurrentlyNotSupportError(),
)
@patch("controllers.web.completion.web_ns")
def test_model_not_support_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
mock_ns.payload = {"inputs": {}}
with app.test_request_context("/completion-messages", method="POST"):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
CompletionApi().post(_completion_app(), _end_user())
# ---------------------------------------------------------------------------
# CompletionStopApi
# ---------------------------------------------------------------------------
class TestCompletionStopApi:
def test_wrong_mode_raises(self, app: Flask) -> None:
with app.test_request_context("/completion-messages/task-1/stop", method="POST"):
with pytest.raises(NotCompletionAppError):
CompletionStopApi().post(_chat_app(), _end_user(), "task-1")
@patch("controllers.web.completion.AppTaskService.stop_task")
def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None:
with app.test_request_context("/completion-messages/task-1/stop", method="POST"):
result, status = CompletionStopApi().post(_completion_app(), _end_user(), "task-1")
assert status == 200
assert result == {"result": "success"}
# ---------------------------------------------------------------------------
# ChatApi
# ---------------------------------------------------------------------------
class TestChatApi:
def test_wrong_mode_raises(self, app: Flask) -> None:
with app.test_request_context("/chat-messages", method="POST"):
with pytest.raises(NotChatAppError):
ChatApi().post(_completion_app(), _end_user())
@patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "reply"})
@patch("controllers.web.completion.AppGenerateService.generate")
@patch("controllers.web.completion.web_ns")
def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
mock_ns.payload = {"inputs": {}, "query": "hi"}
mock_gen.return_value = "response"
with app.test_request_context("/chat-messages", method="POST"):
result = ChatApi().post(_chat_app(), _end_user())
assert result == {"answer": "reply"}
@patch(
"controllers.web.completion.AppGenerateService.generate",
side_effect=InvokeError(description="rate limit"),
)
@patch("controllers.web.completion.web_ns")
def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
mock_ns.payload = {"inputs": {}, "query": "x"}
with app.test_request_context("/chat-messages", method="POST"):
with pytest.raises(CompletionRequestError):
ChatApi().post(_chat_app(), _end_user())
# ---------------------------------------------------------------------------
# ChatStopApi
# ---------------------------------------------------------------------------
class TestChatStopApi:
def test_wrong_mode_raises(self, app: Flask) -> None:
with app.test_request_context("/chat-messages/task-1/stop", method="POST"):
with pytest.raises(NotChatAppError):
ChatStopApi().post(_completion_app(), _end_user(), "task-1")
@patch("controllers.web.completion.AppTaskService.stop_task")
def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None:
with app.test_request_context("/chat-messages/task-1/stop", method="POST"):
result, status = ChatStopApi().post(_chat_app(), _end_user(), "task-1")
assert status == 200
assert result == {"result": "success"}

View File

@@ -0,0 +1,183 @@
"""Unit tests for controllers.web.conversation endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.web.conversation import (
ConversationApi,
ConversationListApi,
ConversationPinApi,
ConversationRenameApi,
ConversationUnPinApi,
)
from controllers.web.error import NotChatAppError
from services.errors.conversation import ConversationNotExistsError
def _chat_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="chat")
def _completion_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="completion")
def _end_user() -> SimpleNamespace:
return SimpleNamespace(id="eu-1")
# ---------------------------------------------------------------------------
# ConversationListApi
# ---------------------------------------------------------------------------
class TestConversationListApi:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context("/conversations"):
with pytest.raises(NotChatAppError):
ConversationListApi().get(_completion_app(), _end_user())
@patch("controllers.web.conversation.WebConversationService.pagination_by_last_id")
@patch("controllers.web.conversation.db")
def test_happy_path(self, mock_db: MagicMock, mock_paginate: MagicMock, app: Flask) -> None:
conv_id = str(uuid4())
conv = SimpleNamespace(
id=conv_id,
name="Test",
inputs={},
status="normal",
introduction="",
created_at=1700000000,
updated_at=1700000000,
)
mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[conv])
mock_db.engine = "engine"
session_mock = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
with (
app.test_request_context("/conversations?limit=20"),
patch("controllers.web.conversation.Session", return_value=session_ctx),
):
result = ConversationListApi().get(_chat_app(), _end_user())
assert result["limit"] == 20
assert result["has_more"] is False
# ---------------------------------------------------------------------------
# ConversationApi (delete)
# ---------------------------------------------------------------------------
class TestConversationApi:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context(f"/conversations/{uuid4()}"):
with pytest.raises(NotChatAppError):
ConversationApi().delete(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.ConversationService.delete")
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}"):
result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id)
assert status == 204
assert result["result"] == "success"
@patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError())
def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}"):
with pytest.raises(NotFound, match="Conversation Not Exists"):
ConversationApi().delete(_chat_app(), _end_user(), c_id)
# ---------------------------------------------------------------------------
# ConversationRenameApi
# ---------------------------------------------------------------------------
class TestConversationRenameApi:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}):
with pytest.raises(NotChatAppError):
ConversationRenameApi().post(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.ConversationService.rename")
@patch("controllers.web.conversation.web_ns")
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
c_id = uuid4()
mock_ns.payload = {"name": "New Name", "auto_generate": False}
conv = SimpleNamespace(
id=str(c_id),
name="New Name",
inputs={},
status="normal",
introduction="",
created_at=1700000000,
updated_at=1700000000,
)
mock_rename.return_value = conv
with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "New Name"}):
result = ConversationRenameApi().post(_chat_app(), _end_user(), c_id)
assert result["name"] == "New Name"
@patch(
"controllers.web.conversation.ConversationService.rename",
side_effect=ConversationNotExistsError(),
)
@patch("controllers.web.conversation.web_ns")
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
c_id = uuid4()
mock_ns.payload = {"name": "X", "auto_generate": False}
with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "X"}):
with pytest.raises(NotFound, match="Conversation Not Exists"):
ConversationRenameApi().post(_chat_app(), _end_user(), c_id)
# ---------------------------------------------------------------------------
# ConversationPinApi / ConversationUnPinApi
# ---------------------------------------------------------------------------
class TestConversationPinApi:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"):
with pytest.raises(NotChatAppError):
ConversationPinApi().patch(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.WebConversationService.pin")
def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
assert result["result"] == "success"
@patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError())
def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
with pytest.raises(NotFound):
ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
class TestConversationUnPinApi:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"):
with pytest.raises(NotChatAppError):
ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.WebConversationService.unpin")
def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"):
result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id)
assert result["result"] == "success"

View File

@@ -0,0 +1,75 @@
"""Unit tests for controllers.web.error HTTP exception classes."""
from __future__ import annotations
import pytest
from controllers.web.error import (
AppMoreLikeThisDisabledError,
AppSuggestedQuestionsAfterAnswerDisabledError,
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
ConversationCompletedError,
InvalidArgumentError,
InvokeRateLimitError,
NoAudioUploadedError,
NotChatAppError,
NotCompletionAppError,
NotFoundError,
NotWorkflowAppError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
WebAppAuthAccessDeniedError,
WebAppAuthRequiredError,
WebFormRateLimitExceededError,
)
_ERROR_SPECS: list[tuple[type, str, int]] = [
(AppUnavailableError, "app_unavailable", 400),
(NotCompletionAppError, "not_completion_app", 400),
(NotChatAppError, "not_chat_app", 400),
(NotWorkflowAppError, "not_workflow_app", 400),
(ConversationCompletedError, "conversation_completed", 400),
(ProviderNotInitializeError, "provider_not_initialize", 400),
(ProviderQuotaExceededError, "provider_quota_exceeded", 400),
(ProviderModelCurrentlyNotSupportError, "model_currently_not_support", 400),
(CompletionRequestError, "completion_request_error", 400),
(AppMoreLikeThisDisabledError, "app_more_like_this_disabled", 403),
(AppSuggestedQuestionsAfterAnswerDisabledError, "app_suggested_questions_after_answer_disabled", 403),
(NoAudioUploadedError, "no_audio_uploaded", 400),
(AudioTooLargeError, "audio_too_large", 413),
(UnsupportedAudioTypeError, "unsupported_audio_type", 415),
(ProviderNotSupportSpeechToTextError, "provider_not_support_speech_to_text", 400),
(WebAppAuthRequiredError, "web_sso_auth_required", 401),
(WebAppAuthAccessDeniedError, "web_app_access_denied", 401),
(InvokeRateLimitError, "rate_limit_error", 429),
(WebFormRateLimitExceededError, "web_form_rate_limit_exceeded", 429),
(NotFoundError, "not_found", 404),
(InvalidArgumentError, "invalid_param", 400),
]
@pytest.mark.parametrize(
("cls", "expected_code", "expected_status"),
_ERROR_SPECS,
ids=[cls.__name__ for cls, _, _ in _ERROR_SPECS],
)
def test_error_class_attributes(cls: type, expected_code: str, expected_status: int) -> None:
"""Each error class exposes the correct error_code and HTTP status code."""
assert cls.error_code == expected_code
assert cls.code == expected_status
def test_error_classes_have_description() -> None:
"""Every error class has a description (string or None for generic errors)."""
# NotFoundError and InvalidArgumentError use None description by design
_NO_DESCRIPTION = {NotFoundError, InvalidArgumentError}
for cls, _, _ in _ERROR_SPECS:
if cls in _NO_DESCRIPTION:
continue
assert isinstance(cls.description, str), f"{cls.__name__} missing description"
assert len(cls.description) > 0, f"{cls.__name__} has empty description"

View File

@@ -0,0 +1,38 @@
"""Unit tests for controllers.web.feature endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from flask import Flask
from controllers.web.feature import SystemFeatureApi
class TestSystemFeatureApi:
@patch("controllers.web.feature.FeatureService.get_system_features")
def test_returns_system_features(self, mock_features: MagicMock, app: Flask) -> None:
mock_model = MagicMock()
mock_model.model_dump.return_value = {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}}
mock_features.return_value = mock_model
with app.test_request_context("/system-features"):
result = SystemFeatureApi().get()
assert result == {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}}
mock_features.assert_called_once()
@patch("controllers.web.feature.FeatureService.get_system_features")
def test_unauthenticated_access(self, mock_features: MagicMock, app: Flask) -> None:
"""SystemFeatureApi is unauthenticated by design — no WebApiResource decorator."""
mock_model = MagicMock()
mock_model.model_dump.return_value = {}
mock_features.return_value = mock_model
# Verify it's a bare Resource, not WebApiResource
from flask_restx import Resource
from controllers.web.wraps import WebApiResource
assert issubclass(SystemFeatureApi, Resource)
assert not issubclass(SystemFeatureApi, WebApiResource)

View File

@@ -0,0 +1,89 @@
"""Unit tests for controllers.web.files endpoints."""
from __future__ import annotations
from io import BytesIO
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
)
from controllers.web.files import FileApi
def _app_model() -> SimpleNamespace:
return SimpleNamespace(id="app-1")
def _end_user() -> SimpleNamespace:
return SimpleNamespace(id="eu-1")
class TestFileApi:
def test_no_file_uploaded(self, app: Flask) -> None:
with app.test_request_context("/files/upload", method="POST", content_type="multipart/form-data"):
with pytest.raises(NoFileUploadedError):
FileApi().post(_app_model(), _end_user())
def test_too_many_files(self, app: Flask) -> None:
data = {
"file": (BytesIO(b"a"), "a.txt"),
"file2": (BytesIO(b"b"), "b.txt"),
}
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
# Now has "file" key but len(request.files) > 1
with pytest.raises(TooManyFilesError):
FileApi().post(_app_model(), _end_user())
def test_filename_missing(self, app: Flask) -> None:
data = {"file": (BytesIO(b"content"), "")}
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
with pytest.raises(FilenameNotExistsError):
FileApi().post(_app_model(), _end_user())
@patch("controllers.web.files.FileService")
@patch("controllers.web.files.db")
def test_upload_success(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None:
mock_db.engine = "engine"
from datetime import datetime
upload_file = SimpleNamespace(
id="file-1",
name="test.txt",
size=100,
extension="txt",
mime_type="text/plain",
created_by="eu-1",
created_at=datetime(2024, 1, 1),
)
mock_file_svc_cls.return_value.upload_file.return_value = upload_file
data = {"file": (BytesIO(b"content"), "test.txt")}
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
result, status = FileApi().post(_app_model(), _end_user())
assert status == 201
assert result["id"] == "file-1"
assert result["name"] == "test.txt"
@patch("controllers.web.files.FileService")
@patch("controllers.web.files.db")
def test_file_too_large_from_service(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None:
import services.errors.file
mock_db.engine = "engine"
mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError(
description="max 10MB"
)
data = {"file": (BytesIO(b"big"), "big.txt")}
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
with pytest.raises(FileTooLargeError):
FileApi().post(_app_model(), _end_user())

View File

@@ -0,0 +1,156 @@
"""Unit tests for controllers.web.message — feedback, more-like-this, suggested questions."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.web.error import (
AppMoreLikeThisDisabledError,
NotChatAppError,
NotCompletionAppError,
)
from controllers.web.message import (
MessageFeedbackApi,
MessageMoreLikeThisApi,
MessageSuggestedQuestionApi,
)
from services.errors.app import MoreLikeThisDisabledError
from services.errors.message import MessageNotExistsError
def _chat_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="chat")
def _completion_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="completion")
def _end_user() -> SimpleNamespace:
return SimpleNamespace(id="eu-1")
# ---------------------------------------------------------------------------
# MessageFeedbackApi
# ---------------------------------------------------------------------------
class TestMessageFeedbackApi:
@patch("controllers.web.message.MessageService.create_feedback")
@patch("controllers.web.message.web_ns")
def test_feedback_success(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
mock_ns.payload = {"rating": "like", "content": "great"}
msg_id = uuid4()
with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)
assert result == {"result": "success"}
mock_create.assert_called_once()
@patch("controllers.web.message.MessageService.create_feedback")
@patch("controllers.web.message.web_ns")
def test_feedback_null_rating(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
mock_ns.payload = {"rating": None}
msg_id = uuid4()
with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)
assert result == {"result": "success"}
@patch(
"controllers.web.message.MessageService.create_feedback",
side_effect=MessageNotExistsError(),
)
@patch("controllers.web.message.web_ns")
def test_feedback_message_not_found(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
mock_ns.payload = {"rating": "dislike"}
msg_id = uuid4()
with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
with pytest.raises(NotFound, match="Message Not Exists"):
MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)
# ---------------------------------------------------------------------------
# MessageMoreLikeThisApi
# ---------------------------------------------------------------------------
class TestMessageMoreLikeThisApi:
def test_wrong_mode_raises(self, app: Flask) -> None:
msg_id = uuid4()
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
with pytest.raises(NotCompletionAppError):
MessageMoreLikeThisApi().get(_chat_app(), _end_user(), msg_id)
@patch("controllers.web.message.helper.compact_generate_response", return_value={"answer": "similar"})
@patch("controllers.web.message.AppGenerateService.generate_more_like_this")
def test_happy_path(self, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
msg_id = uuid4()
mock_gen.return_value = "response"
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
result = MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)
assert result == {"answer": "similar"}
@patch(
"controllers.web.message.AppGenerateService.generate_more_like_this",
side_effect=MessageNotExistsError(),
)
def test_message_not_found(self, mock_gen: MagicMock, app: Flask) -> None:
msg_id = uuid4()
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
with pytest.raises(NotFound, match="Message Not Exists"):
MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)
@patch(
"controllers.web.message.AppGenerateService.generate_more_like_this",
side_effect=MoreLikeThisDisabledError(),
)
def test_feature_disabled(self, mock_gen: MagicMock, app: Flask) -> None:
msg_id = uuid4()
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
with pytest.raises(AppMoreLikeThisDisabledError):
MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)
# ---------------------------------------------------------------------------
# MessageSuggestedQuestionApi
# ---------------------------------------------------------------------------
class TestMessageSuggestedQuestionApi:
def test_wrong_mode_raises(self, app: Flask) -> None:
msg_id = uuid4()
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
with pytest.raises(NotChatAppError):
MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
def test_wrong_mode_raises(self, app: Flask) -> None:
msg_id = uuid4()
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
with pytest.raises(NotChatAppError):
MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
@patch("controllers.web.message.MessageService.get_suggested_questions_after_answer")
def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None:
msg_id = uuid4()
mock_suggest.return_value = ["What about X?", "Tell me more about Y."]
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
result = MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id)
assert result["data"] == ["What about X?", "Tell me more about Y."]
@patch(
"controllers.web.message.MessageService.get_suggested_questions_after_answer",
side_effect=MessageNotExistsError(),
)
def test_message_not_found(self, mock_suggest: MagicMock, app: Flask) -> None:
msg_id = uuid4()
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
with pytest.raises(NotFound, match="Message not found"):
MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id)

View File

@@ -0,0 +1,103 @@
from __future__ import annotations
from types import SimpleNamespace
import pytest
from werkzeug.exceptions import NotFound, Unauthorized
from controllers.web.error import WebAppAuthRequiredError
from controllers.web.passport import (
PassportService,
decode_enterprise_webapp_user_id,
exchange_token_for_existing_web_user,
generate_session_id,
)
from services.webapp_auth_service import WebAppAuthType
def test_decode_enterprise_webapp_user_id_none() -> None:
assert decode_enterprise_webapp_user_id(None) is None
def test_decode_enterprise_webapp_user_id_invalid_source(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: {"token_source": "bad"})
with pytest.raises(Unauthorized):
decode_enterprise_webapp_user_id("token")
def test_decode_enterprise_webapp_user_id_valid(monkeypatch: pytest.MonkeyPatch) -> None:
decoded = {"token_source": "webapp_login_token", "user_id": "u1"}
monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: decoded)
assert decode_enterprise_webapp_user_id("token") == decoded
def test_exchange_token_public_flow(monkeypatch: pytest.MonkeyPatch) -> None:
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True)
def _scalar_side_effect(*_args, **_kwargs):
if not hasattr(_scalar_side_effect, "calls"):
_scalar_side_effect.calls = 0
_scalar_side_effect.calls += 1
return site if _scalar_side_effect.calls == 1 else app_model
db_session = SimpleNamespace(scalar=_scalar_side_effect)
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
monkeypatch.setattr("controllers.web.passport._exchange_for_public_app_token", lambda *_args, **_kwargs: "resp")
decoded = {"auth_type": "public"}
result = exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.PUBLIC)
assert result == "resp"
def test_exchange_token_requires_external(monkeypatch: pytest.MonkeyPatch) -> None:
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True)
def _scalar_side_effect(*_args, **_kwargs):
if not hasattr(_scalar_side_effect, "calls"):
_scalar_side_effect.calls = 0
_scalar_side_effect.calls += 1
return site if _scalar_side_effect.calls == 1 else app_model
db_session = SimpleNamespace(scalar=_scalar_side_effect)
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
decoded = {"auth_type": "internal"}
with pytest.raises(WebAppAuthRequiredError):
exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.EXTERNAL)
def test_exchange_token_missing_session_id(monkeypatch: pytest.MonkeyPatch) -> None:
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True, tenant_id="t1")
def _scalar_side_effect(*_args, **_kwargs):
if not hasattr(_scalar_side_effect, "calls"):
_scalar_side_effect.calls = 0
_scalar_side_effect.calls += 1
if _scalar_side_effect.calls == 1:
return site
if _scalar_side_effect.calls == 2:
return app_model
return None
db_session = SimpleNamespace(scalar=_scalar_side_effect, add=lambda *_a, **_k: None, commit=lambda: None)
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
decoded = {"auth_type": "internal"}
with pytest.raises(NotFound):
exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.INTERNAL)
def test_generate_session_id(monkeypatch: pytest.MonkeyPatch) -> None:
counts = [1, 0]
def _scalar(*_args, **_kwargs):
return counts.pop(0)
db_session = SimpleNamespace(scalar=_scalar)
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
session_id = generate_session_id()
assert session_id

View File

@@ -0,0 +1,423 @@
"""Unit tests for Pydantic models defined in controllers.web modules.
Covers validation logic, field defaults, constraints, and custom validators
for all ~15 Pydantic models across the web controller layer.
"""
from __future__ import annotations
from uuid import uuid4
import pytest
from pydantic import ValidationError
# ---------------------------------------------------------------------------
# app.py models
# ---------------------------------------------------------------------------
from controllers.web.app import AppAccessModeQuery
class TestAppAccessModeQuery:
def test_alias_resolution(self) -> None:
q = AppAccessModeQuery.model_validate({"appId": "abc", "appCode": "xyz"})
assert q.app_id == "abc"
assert q.app_code == "xyz"
def test_defaults_to_none(self) -> None:
q = AppAccessModeQuery.model_validate({})
assert q.app_id is None
assert q.app_code is None
def test_accepts_snake_case(self) -> None:
q = AppAccessModeQuery(app_id="id1", app_code="code1")
assert q.app_id == "id1"
assert q.app_code == "code1"
# ---------------------------------------------------------------------------
# audio.py models
# ---------------------------------------------------------------------------
from controllers.web.audio import TextToAudioPayload
class TestTextToAudioPayload:
def test_defaults(self) -> None:
p = TextToAudioPayload.model_validate({})
assert p.message_id is None
assert p.voice is None
assert p.text is None
assert p.streaming is None
def test_valid_uuid_message_id(self) -> None:
uid = str(uuid4())
p = TextToAudioPayload(message_id=uid)
assert p.message_id == uid
def test_none_message_id_passthrough(self) -> None:
p = TextToAudioPayload(message_id=None)
assert p.message_id is None
def test_invalid_uuid_message_id(self) -> None:
with pytest.raises(ValidationError, match="not a valid uuid"):
TextToAudioPayload(message_id="not-a-uuid")
# ---------------------------------------------------------------------------
# completion.py models
# ---------------------------------------------------------------------------
from controllers.web.completion import ChatMessagePayload, CompletionMessagePayload
class TestCompletionMessagePayload:
def test_defaults(self) -> None:
p = CompletionMessagePayload(inputs={})
assert p.query == ""
assert p.files is None
assert p.response_mode is None
assert p.retriever_from == "web_app"
def test_accepts_full_payload(self) -> None:
p = CompletionMessagePayload(
inputs={"key": "val"},
query="test",
files=[{"id": "f1"}],
response_mode="streaming",
)
assert p.response_mode == "streaming"
assert p.files == [{"id": "f1"}]
def test_invalid_response_mode(self) -> None:
with pytest.raises(ValidationError):
CompletionMessagePayload(inputs={}, response_mode="invalid")
class TestChatMessagePayload:
def test_valid_uuid_fields(self) -> None:
cid = str(uuid4())
pid = str(uuid4())
p = ChatMessagePayload(inputs={}, query="hi", conversation_id=cid, parent_message_id=pid)
assert p.conversation_id == cid
assert p.parent_message_id == pid
def test_none_uuid_fields(self) -> None:
p = ChatMessagePayload(inputs={}, query="hi")
assert p.conversation_id is None
assert p.parent_message_id is None
def test_invalid_conversation_id(self) -> None:
with pytest.raises(ValidationError, match="not a valid uuid"):
ChatMessagePayload(inputs={}, query="hi", conversation_id="bad")
def test_invalid_parent_message_id(self) -> None:
with pytest.raises(ValidationError, match="not a valid uuid"):
ChatMessagePayload(inputs={}, query="hi", parent_message_id="bad")
def test_query_required(self) -> None:
with pytest.raises(ValidationError):
ChatMessagePayload(inputs={})
# ---------------------------------------------------------------------------
# conversation.py models
# ---------------------------------------------------------------------------
from controllers.web.conversation import ConversationListQuery, ConversationRenamePayload
class TestConversationListQuery:
def test_defaults(self) -> None:
q = ConversationListQuery()
assert q.last_id is None
assert q.limit == 20
assert q.pinned is None
assert q.sort_by == "-updated_at"
def test_limit_lower_bound(self) -> None:
with pytest.raises(ValidationError):
ConversationListQuery(limit=0)
def test_limit_upper_bound(self) -> None:
with pytest.raises(ValidationError):
ConversationListQuery(limit=101)
def test_limit_boundaries_valid(self) -> None:
assert ConversationListQuery(limit=1).limit == 1
assert ConversationListQuery(limit=100).limit == 100
def test_valid_sort_by_options(self) -> None:
for opt in ("created_at", "-created_at", "updated_at", "-updated_at"):
assert ConversationListQuery(sort_by=opt).sort_by == opt
def test_invalid_sort_by(self) -> None:
with pytest.raises(ValidationError):
ConversationListQuery(sort_by="invalid")
def test_valid_last_id(self) -> None:
uid = str(uuid4())
assert ConversationListQuery(last_id=uid).last_id == uid
def test_invalid_last_id(self) -> None:
with pytest.raises(ValidationError, match="not a valid uuid"):
ConversationListQuery(last_id="not-uuid")
class TestConversationRenamePayload:
def test_auto_generate_true_no_name_required(self) -> None:
p = ConversationRenamePayload(auto_generate=True)
assert p.name is None
def test_auto_generate_false_requires_name(self) -> None:
with pytest.raises(ValidationError, match="name is required"):
ConversationRenamePayload(auto_generate=False)
def test_auto_generate_false_blank_name_rejected(self) -> None:
with pytest.raises(ValidationError, match="name is required"):
ConversationRenamePayload(auto_generate=False, name=" ")
def test_auto_generate_false_with_valid_name(self) -> None:
p = ConversationRenamePayload(auto_generate=False, name="My Chat")
assert p.name == "My Chat"
def test_defaults(self) -> None:
p = ConversationRenamePayload(name="test")
assert p.auto_generate is False
assert p.name == "test"
# ---------------------------------------------------------------------------
# message.py models
# ---------------------------------------------------------------------------
from controllers.web.message import MessageFeedbackPayload, MessageListQuery, MessageMoreLikeThisQuery
class TestMessageListQuery:
def test_valid_query(self) -> None:
cid = str(uuid4())
q = MessageListQuery(conversation_id=cid)
assert q.conversation_id == cid
assert q.first_id is None
assert q.limit == 20
def test_invalid_conversation_id(self) -> None:
with pytest.raises(ValidationError, match="not a valid uuid"):
MessageListQuery(conversation_id="bad")
def test_limit_bounds(self) -> None:
cid = str(uuid4())
with pytest.raises(ValidationError):
MessageListQuery(conversation_id=cid, limit=0)
with pytest.raises(ValidationError):
MessageListQuery(conversation_id=cid, limit=101)
def test_valid_first_id(self) -> None:
cid = str(uuid4())
fid = str(uuid4())
q = MessageListQuery(conversation_id=cid, first_id=fid)
assert q.first_id == fid
def test_invalid_first_id(self) -> None:
cid = str(uuid4())
with pytest.raises(ValidationError, match="not a valid uuid"):
MessageListQuery(conversation_id=cid, first_id="invalid")
class TestMessageFeedbackPayload:
def test_defaults(self) -> None:
p = MessageFeedbackPayload()
assert p.rating is None
assert p.content is None
def test_valid_ratings(self) -> None:
assert MessageFeedbackPayload(rating="like").rating == "like"
assert MessageFeedbackPayload(rating="dislike").rating == "dislike"
def test_invalid_rating(self) -> None:
with pytest.raises(ValidationError):
MessageFeedbackPayload(rating="neutral")
class TestMessageMoreLikeThisQuery:
def test_valid_modes(self) -> None:
assert MessageMoreLikeThisQuery(response_mode="blocking").response_mode == "blocking"
assert MessageMoreLikeThisQuery(response_mode="streaming").response_mode == "streaming"
def test_invalid_mode(self) -> None:
with pytest.raises(ValidationError):
MessageMoreLikeThisQuery(response_mode="invalid")
def test_required(self) -> None:
with pytest.raises(ValidationError):
MessageMoreLikeThisQuery()
# ---------------------------------------------------------------------------
# remote_files.py models
# ---------------------------------------------------------------------------
from controllers.web.remote_files import RemoteFileUploadPayload
class TestRemoteFileUploadPayload:
def test_valid_url(self) -> None:
p = RemoteFileUploadPayload(url="https://example.com/file.pdf")
assert str(p.url) == "https://example.com/file.pdf"
def test_invalid_url(self) -> None:
with pytest.raises(ValidationError):
RemoteFileUploadPayload(url="not-a-url")
def test_url_required(self) -> None:
with pytest.raises(ValidationError):
RemoteFileUploadPayload()
# ---------------------------------------------------------------------------
# saved_message.py models
# ---------------------------------------------------------------------------
from controllers.web.saved_message import SavedMessageCreatePayload, SavedMessageListQuery
class TestSavedMessageListQuery:
def test_defaults(self) -> None:
q = SavedMessageListQuery()
assert q.last_id is None
assert q.limit == 20
def test_limit_bounds(self) -> None:
with pytest.raises(ValidationError):
SavedMessageListQuery(limit=0)
with pytest.raises(ValidationError):
SavedMessageListQuery(limit=101)
def test_valid_last_id(self) -> None:
uid = str(uuid4())
q = SavedMessageListQuery(last_id=uid)
assert q.last_id == uid
def test_empty_last_id(self) -> None:
q = SavedMessageListQuery(last_id="")
assert q.last_id == ""
class TestSavedMessageCreatePayload:
def test_valid_message_id(self) -> None:
uid = str(uuid4())
p = SavedMessageCreatePayload(message_id=uid)
assert p.message_id == uid
def test_required(self) -> None:
with pytest.raises(ValidationError):
SavedMessageCreatePayload()
# ---------------------------------------------------------------------------
# workflow.py models
# ---------------------------------------------------------------------------
from controllers.web.workflow import WorkflowRunPayload
class TestWorkflowRunPayload:
def test_defaults(self) -> None:
p = WorkflowRunPayload(inputs={})
assert p.inputs == {}
assert p.files is None
def test_with_files(self) -> None:
p = WorkflowRunPayload(inputs={"k": "v"}, files=[{"id": "f1"}])
assert p.files == [{"id": "f1"}]
def test_inputs_required(self) -> None:
with pytest.raises(ValidationError):
WorkflowRunPayload()
# ---------------------------------------------------------------------------
# forgot_password.py models
# ---------------------------------------------------------------------------
from controllers.web.forgot_password import (
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordSendPayload,
)
class TestForgotPasswordSendPayload:
def test_valid_email(self) -> None:
p = ForgotPasswordSendPayload(email="user@example.com")
assert p.email == "user@example.com"
def test_invalid_email(self) -> None:
with pytest.raises(ValidationError, match="not a valid email"):
ForgotPasswordSendPayload(email="not-an-email")
def test_language_optional(self) -> None:
p = ForgotPasswordSendPayload(email="a@b.com")
assert p.language is None
class TestForgotPasswordCheckPayload:
def test_valid(self) -> None:
p = ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="tok")
assert p.email == "a@b.com"
assert p.code == "1234"
assert p.token == "tok"
def test_empty_token_rejected(self) -> None:
with pytest.raises(ValidationError):
ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="")
class TestForgotPasswordResetPayload:
def test_valid_passwords(self) -> None:
p = ForgotPasswordResetPayload(token="tok", new_password="Valid1234", password_confirm="Valid1234")
assert p.new_password == "Valid1234"
def test_weak_password_rejected(self) -> None:
with pytest.raises(ValidationError, match="Password must contain"):
ForgotPasswordResetPayload(token="tok", new_password="short", password_confirm="short")
def test_letters_only_password_rejected(self) -> None:
with pytest.raises(ValidationError, match="Password must contain"):
ForgotPasswordResetPayload(token="tok", new_password="abcdefghi", password_confirm="abcdefghi")
def test_digits_only_password_rejected(self) -> None:
with pytest.raises(ValidationError, match="Password must contain"):
ForgotPasswordResetPayload(token="tok", new_password="123456789", password_confirm="123456789")
# ---------------------------------------------------------------------------
# login.py models
# ---------------------------------------------------------------------------
from controllers.web.login import EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload, LoginPayload
class TestLoginPayload:
def test_valid(self) -> None:
p = LoginPayload(email="a@b.com", password="Valid1234")
assert p.email == "a@b.com"
def test_invalid_email(self) -> None:
with pytest.raises(ValidationError, match="not a valid email"):
LoginPayload(email="bad", password="Valid1234")
def test_weak_password(self) -> None:
with pytest.raises(ValidationError, match="Password must contain"):
LoginPayload(email="a@b.com", password="weak")
class TestEmailCodeLoginSendPayload:
def test_valid(self) -> None:
p = EmailCodeLoginSendPayload(email="a@b.com")
assert p.language is None
def test_with_language(self) -> None:
p = EmailCodeLoginSendPayload(email="a@b.com", language="zh-Hans")
assert p.language == "zh-Hans"
class TestEmailCodeLoginVerifyPayload:
def test_valid(self) -> None:
p = EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="tok")
assert p.code == "1234"
def test_empty_token_rejected(self) -> None:
with pytest.raises(ValidationError):
EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="")

View File

@@ -0,0 +1,147 @@
"""Unit tests for controllers.web.remote_files endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError
from controllers.web.remote_files import RemoteFileInfoApi, RemoteFileUploadApi
def _app_model() -> SimpleNamespace:
return SimpleNamespace(id="app-1")
def _end_user() -> SimpleNamespace:
return SimpleNamespace(id="eu-1")
# ---------------------------------------------------------------------------
# RemoteFileInfoApi
# ---------------------------------------------------------------------------
class TestRemoteFileInfoApi:
@patch("controllers.web.remote_files.ssrf_proxy")
def test_head_success(self, mock_proxy: MagicMock, app: Flask) -> None:
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.headers = {"Content-Type": "application/pdf", "Content-Length": "1024"}
mock_proxy.head.return_value = mock_resp
with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.pdf"):
result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.pdf")
assert result["file_type"] == "application/pdf"
assert result["file_length"] == 1024
@patch("controllers.web.remote_files.ssrf_proxy")
def test_fallback_to_get(self, mock_proxy: MagicMock, app: Flask) -> None:
head_resp = MagicMock()
head_resp.status_code = 405 # Method not allowed
get_resp = MagicMock()
get_resp.status_code = 200
get_resp.headers = {"Content-Type": "text/plain", "Content-Length": "42"}
get_resp.raise_for_status = MagicMock()
mock_proxy.head.return_value = head_resp
mock_proxy.get.return_value = get_resp
with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.txt"):
result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.txt")
assert result["file_type"] == "text/plain"
mock_proxy.get.assert_called_once()
# ---------------------------------------------------------------------------
# RemoteFileUploadApi
# ---------------------------------------------------------------------------
class TestRemoteFileUploadApi:
@patch("controllers.web.remote_files.file_helpers.get_signed_file_url", return_value="https://signed-url")
@patch("controllers.web.remote_files.FileService")
@patch("controllers.web.remote_files.helpers.guess_file_info_from_response")
@patch("controllers.web.remote_files.ssrf_proxy")
@patch("controllers.web.remote_files.web_ns")
@patch("controllers.web.remote_files.db")
def test_upload_success(
self,
mock_db: MagicMock,
mock_ns: MagicMock,
mock_proxy: MagicMock,
mock_guess: MagicMock,
mock_file_svc_cls: MagicMock,
mock_signed: MagicMock,
app: Flask,
) -> None:
mock_db.engine = "engine"
mock_ns.payload = {"url": "https://example.com/file.pdf"}
head_resp = MagicMock()
head_resp.status_code = 200
head_resp.content = b"pdf-content"
head_resp.request.method = "HEAD"
mock_proxy.head.return_value = head_resp
get_resp = MagicMock()
get_resp.content = b"pdf-content"
mock_proxy.get.return_value = get_resp
mock_guess.return_value = SimpleNamespace(
filename="file.pdf", extension="pdf", mimetype="application/pdf", size=100
)
mock_file_svc_cls.is_file_size_within_limit.return_value = True
from datetime import datetime
upload_file = SimpleNamespace(
id="f-1",
name="file.pdf",
size=100,
extension="pdf",
mime_type="application/pdf",
created_by="eu-1",
created_at=datetime(2024, 1, 1),
)
mock_file_svc_cls.return_value.upload_file.return_value = upload_file
with app.test_request_context("/remote-files/upload", method="POST"):
result, status = RemoteFileUploadApi().post(_app_model(), _end_user())
assert status == 201
assert result["id"] == "f-1"
@patch("controllers.web.remote_files.FileService.is_file_size_within_limit", return_value=False)
@patch("controllers.web.remote_files.helpers.guess_file_info_from_response")
@patch("controllers.web.remote_files.ssrf_proxy")
@patch("controllers.web.remote_files.web_ns")
def test_file_too_large(
self,
mock_ns: MagicMock,
mock_proxy: MagicMock,
mock_guess: MagicMock,
mock_size_check: MagicMock,
app: Flask,
) -> None:
mock_ns.payload = {"url": "https://example.com/big.zip"}
head_resp = MagicMock()
head_resp.status_code = 200
mock_proxy.head.return_value = head_resp
mock_guess.return_value = SimpleNamespace(
filename="big.zip", extension="zip", mimetype="application/zip", size=999999999
)
with app.test_request_context("/remote-files/upload", method="POST"):
with pytest.raises(FileTooLargeError):
RemoteFileUploadApi().post(_app_model(), _end_user())
@patch("controllers.web.remote_files.ssrf_proxy")
@patch("controllers.web.remote_files.web_ns")
def test_fetch_failure_raises(self, mock_ns: MagicMock, mock_proxy: MagicMock, app: Flask) -> None:
import httpx
mock_ns.payload = {"url": "https://example.com/bad"}
mock_proxy.head.side_effect = httpx.RequestError("connection failed")
with app.test_request_context("/remote-files/upload", method="POST"):
with pytest.raises(RemoteFileUploadError):
RemoteFileUploadApi().post(_app_model(), _end_user())

View File

@@ -0,0 +1,97 @@
"""Unit tests for controllers.web.saved_message endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.web.error import NotCompletionAppError
from controllers.web.saved_message import SavedMessageApi, SavedMessageListApi
from services.errors.message import MessageNotExistsError
def _completion_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="completion")
def _chat_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="chat")
def _end_user() -> SimpleNamespace:
return SimpleNamespace(id="eu-1")
# ---------------------------------------------------------------------------
# SavedMessageListApi (GET)
# ---------------------------------------------------------------------------
class TestSavedMessageListApiGet:
def test_non_completion_mode_raises(self, app: Flask) -> None:
with app.test_request_context("/saved-messages"):
with pytest.raises(NotCompletionAppError):
SavedMessageListApi().get(_chat_app(), _end_user())
@patch("controllers.web.saved_message.SavedMessageService.pagination_by_last_id")
def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None:
mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[])
with app.test_request_context("/saved-messages?limit=20"):
result = SavedMessageListApi().get(_completion_app(), _end_user())
assert result["limit"] == 20
assert result["has_more"] is False
# ---------------------------------------------------------------------------
# SavedMessageListApi (POST)
# ---------------------------------------------------------------------------
class TestSavedMessageListApiPost:
def test_non_completion_mode_raises(self, app: Flask) -> None:
with app.test_request_context("/saved-messages", method="POST"):
with pytest.raises(NotCompletionAppError):
SavedMessageListApi().post(_chat_app(), _end_user())
@patch("controllers.web.saved_message.SavedMessageService.save")
@patch("controllers.web.saved_message.web_ns")
def test_save_success(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None:
msg_id = str(uuid4())
mock_ns.payload = {"message_id": msg_id}
with app.test_request_context("/saved-messages", method="POST"):
result = SavedMessageListApi().post(_completion_app(), _end_user())
assert result["result"] == "success"
@patch("controllers.web.saved_message.SavedMessageService.save", side_effect=MessageNotExistsError())
@patch("controllers.web.saved_message.web_ns")
def test_save_not_found(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None:
mock_ns.payload = {"message_id": str(uuid4())}
with app.test_request_context("/saved-messages", method="POST"):
with pytest.raises(NotFound, match="Message Not Exists"):
SavedMessageListApi().post(_completion_app(), _end_user())
# ---------------------------------------------------------------------------
# SavedMessageApi (DELETE)
# ---------------------------------------------------------------------------
class TestSavedMessageApi:
def test_non_completion_mode_raises(self, app: Flask) -> None:
msg_id = uuid4()
with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"):
with pytest.raises(NotCompletionAppError):
SavedMessageApi().delete(_chat_app(), _end_user(), msg_id)
@patch("controllers.web.saved_message.SavedMessageService.delete")
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
msg_id = uuid4()
with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"):
result, status = SavedMessageApi().delete(_completion_app(), _end_user(), msg_id)
assert status == 204
assert result["result"] == "success"

View File

@@ -0,0 +1,126 @@
"""Unit tests for controllers.web.site endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
from controllers.web.site import AppSiteApi, AppSiteInfo
def _tenant(*, status: str = "normal") -> SimpleNamespace:
return SimpleNamespace(
id="tenant-1",
status=status,
plan="basic",
custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False},
)
def _site() -> SimpleNamespace:
return SimpleNamespace(
title="Site",
icon_type="emoji",
icon="robot",
icon_background="#fff",
description="desc",
default_language="en",
chat_color_theme="light",
chat_color_theme_inverted=False,
copyright=None,
privacy_policy=None,
custom_disclaimer=None,
prompt_public=False,
show_workflow_steps=True,
use_icon_as_answer_icon=False,
)
# ---------------------------------------------------------------------------
# AppSiteApi
# ---------------------------------------------------------------------------
class TestAppSiteApi:
@patch("controllers.web.site.FeatureService.get_features")
@patch("controllers.web.site.db")
def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
site_obj = _site()
mock_db.session.query.return_value.where.return_value.first.return_value = site_obj
tenant = _tenant()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
end_user = SimpleNamespace(id="eu-1")
with app.test_request_context("/site"):
result = AppSiteApi().get(app_model, end_user)
# marshal_with serializes AppSiteInfo to a dict
assert result["app_id"] == "app-1"
assert result["plan"] == "basic"
assert result["enable_site"] is True
@patch("controllers.web.site.db")
def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
mock_db.session.query.return_value.where.return_value.first.return_value = None
tenant = _tenant()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
end_user = SimpleNamespace(id="eu-1")
with app.test_request_context("/site"):
with pytest.raises(Forbidden):
AppSiteApi().get(app_model, end_user)
@patch("controllers.web.site.db")
def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
from models.account import TenantStatus
mock_db.session.query.return_value.where.return_value.first.return_value = _site()
tenant = SimpleNamespace(
id="tenant-1",
status=TenantStatus.ARCHIVE,
plan="basic",
custom_config_dict={},
)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
end_user = SimpleNamespace(id="eu-1")
with app.test_request_context("/site"):
with pytest.raises(Forbidden):
AppSiteApi().get(app_model, end_user)
# ---------------------------------------------------------------------------
# AppSiteInfo
# ---------------------------------------------------------------------------
class TestAppSiteInfo:
def test_basic_fields(self) -> None:
tenant = _tenant()
site_obj = _site()
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False)
assert info.app_id == "app-1"
assert info.end_user_id == "eu-1"
assert info.enable_site is True
assert info.plan == "basic"
assert info.can_replace_logo is False
assert info.model_config is None
@patch("controllers.web.site.dify_config", SimpleNamespace(FILES_URL="https://files.example.com"))
def test_can_replace_logo_sets_custom_config(self) -> None:
tenant = SimpleNamespace(
id="tenant-1",
plan="pro",
custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True},
)
site_obj = _site()
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True)
assert info.can_replace_logo is True
assert info.custom_config["remove_webapp_brand"] is True
assert "webapp-logo" in info.custom_config["replace_webapp_logo"]

View File

@@ -5,7 +5,8 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
import services.errors.account
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi, LoginApi, LoginStatusApi, LogoutApi
def encode_code(code: str) -> str:
@@ -89,3 +90,114 @@ class TestEmailCodeLoginApi:
mock_revoke_token.assert_called_once_with("token-123")
mock_login.assert_called_once()
mock_reset_login_rate.assert_called_once_with("user@example.com")
class TestLoginApi:
@patch("controllers.web.login.WebAppAuthService.login", return_value="access-tok")
@patch("controllers.web.login.WebAppAuthService.authenticate")
def test_login_success(self, mock_auth: MagicMock, mock_login: MagicMock, app: Flask) -> None:
mock_auth.return_value = MagicMock()
with app.test_request_context(
"/web/login",
method="POST",
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
):
response = LoginApi().post()
assert response.get_json()["data"]["access_token"] == "access-tok"
mock_auth.assert_called_once()
@patch(
"controllers.web.login.WebAppAuthService.authenticate",
side_effect=services.errors.account.AccountLoginError(),
)
def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None:
from controllers.console.error import AccountBannedError
with app.test_request_context(
"/web/login",
method="POST",
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
):
with pytest.raises(AccountBannedError):
LoginApi().post()
@patch(
"controllers.web.login.WebAppAuthService.authenticate",
side_effect=services.errors.account.AccountPasswordError(),
)
def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None:
from controllers.console.auth.error import AuthenticationFailedError
with app.test_request_context(
"/web/login",
method="POST",
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
):
with pytest.raises(AuthenticationFailedError):
LoginApi().post()
class TestLoginStatusApi:
@patch("controllers.web.login.extract_webapp_access_token", return_value=None)
def test_no_app_code_returns_logged_in_false(self, mock_extract: MagicMock, app: Flask) -> None:
with app.test_request_context("/web/login/status"):
result = LoginStatusApi().get()
assert result["logged_in"] is False
assert result["app_logged_in"] is False
@patch("controllers.web.login.decode_jwt_token")
@patch("controllers.web.login.PassportService")
@patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=False)
@patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1")
@patch("controllers.web.login.extract_webapp_access_token", return_value="tok")
def test_public_app_user_logged_in(
self,
mock_extract: MagicMock,
mock_app_id: MagicMock,
mock_perm: MagicMock,
mock_passport: MagicMock,
mock_decode: MagicMock,
app: Flask,
) -> None:
mock_decode.return_value = (MagicMock(), MagicMock())
with app.test_request_context("/web/login/status?app_code=code1"):
result = LoginStatusApi().get()
assert result["logged_in"] is True
assert result["app_logged_in"] is True
@patch("controllers.web.login.decode_jwt_token", side_effect=Exception("bad"))
@patch("controllers.web.login.PassportService")
@patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=True)
@patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1")
@patch("controllers.web.login.extract_webapp_access_token", return_value="tok")
def test_private_app_passport_fails(
self,
mock_extract: MagicMock,
mock_app_id: MagicMock,
mock_perm: MagicMock,
mock_passport_cls: MagicMock,
mock_decode: MagicMock,
app: Flask,
) -> None:
mock_passport_cls.return_value.verify.side_effect = Exception("bad")
with app.test_request_context("/web/login/status?app_code=code1"):
result = LoginStatusApi().get()
assert result["logged_in"] is False
assert result["app_logged_in"] is False
class TestLogoutApi:
@patch("controllers.web.login.clear_webapp_access_token_from_cookie")
def test_logout_success(self, mock_clear: MagicMock, app: Flask) -> None:
with app.test_request_context("/web/logout", method="POST"):
response = LogoutApi().post()
assert response.get_json() == {"result": "success"}
mock_clear.assert_called_once()

View File

@@ -0,0 +1,192 @@
"""Unit tests for controllers.web.passport — token issuance and enterprise auth exchange."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound, Unauthorized
from controllers.web.error import WebAppAuthRequiredError
from controllers.web.passport import (
PassportResource,
decode_enterprise_webapp_user_id,
exchange_token_for_existing_web_user,
generate_session_id,
)
from services.webapp_auth_service import WebAppAuthType
# ---------------------------------------------------------------------------
# decode_enterprise_webapp_user_id
# ---------------------------------------------------------------------------
class TestDecodeEnterpriseWebappUserId:
def test_none_token_returns_none(self) -> None:
assert decode_enterprise_webapp_user_id(None) is None
@patch("controllers.web.passport.PassportService")
def test_valid_token_returns_decoded(self, mock_passport_cls: MagicMock) -> None:
mock_passport_cls.return_value.verify.return_value = {
"token_source": "webapp_login_token",
"user_id": "u1",
}
result = decode_enterprise_webapp_user_id("valid-jwt")
assert result["user_id"] == "u1"
@patch("controllers.web.passport.PassportService")
def test_wrong_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None:
mock_passport_cls.return_value.verify.return_value = {
"token_source": "other_source",
}
with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"):
decode_enterprise_webapp_user_id("bad-jwt")
@patch("controllers.web.passport.PassportService")
def test_missing_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None:
mock_passport_cls.return_value.verify.return_value = {}
with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"):
decode_enterprise_webapp_user_id("no-source-jwt")
# ---------------------------------------------------------------------------
# generate_session_id
# ---------------------------------------------------------------------------
class TestGenerateSessionId:
@patch("controllers.web.passport.db")
def test_returns_unique_session_id(self, mock_db: MagicMock) -> None:
mock_db.session.scalar.return_value = 0
sid = generate_session_id()
assert isinstance(sid, str)
assert len(sid) == 36 # UUID format
@patch("controllers.web.passport.db")
def test_retries_on_collision(self, mock_db: MagicMock) -> None:
# First call returns count=1 (collision), second returns 0
mock_db.session.scalar.side_effect = [1, 0]
sid = generate_session_id()
assert isinstance(sid, str)
assert mock_db.session.scalar.call_count == 2
# ---------------------------------------------------------------------------
# exchange_token_for_existing_web_user
# ---------------------------------------------------------------------------
class TestExchangeTokenForExistingWebUser:
@patch("controllers.web.passport.PassportService")
@patch("controllers.web.passport.db")
def test_external_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
site = SimpleNamespace(code="code1", app_id="app-1")
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
mock_db.session.scalar.side_effect = [site, app_model]
decoded = {"user_id": "u1", "auth_type": "internal"} # mismatch: expected "external"
with pytest.raises(WebAppAuthRequiredError, match="external"):
exchange_token_for_existing_web_user(
app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL
)
@patch("controllers.web.passport.PassportService")
@patch("controllers.web.passport.db")
def test_internal_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
site = SimpleNamespace(code="code1", app_id="app-1")
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
mock_db.session.scalar.side_effect = [site, app_model]
decoded = {"user_id": "u1", "auth_type": "external"} # mismatch: expected "internal"
with pytest.raises(WebAppAuthRequiredError, match="internal"):
exchange_token_for_existing_web_user(
app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.INTERNAL
)
@patch("controllers.web.passport.PassportService")
@patch("controllers.web.passport.db")
def test_site_not_found_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
mock_db.session.scalar.return_value = None
decoded = {"user_id": "u1", "auth_type": "external"}
with pytest.raises(NotFound):
exchange_token_for_existing_web_user(
app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL
)
# ---------------------------------------------------------------------------
# PassportResource.get
# ---------------------------------------------------------------------------
class TestPassportResource:
@patch("controllers.web.passport.FeatureService.get_system_features")
def test_missing_app_code_raises_unauthorized(self, mock_features: MagicMock, app: Flask) -> None:
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
with app.test_request_context("/passport"):
with pytest.raises(Unauthorized, match="X-App-Code"):
PassportResource().get()
@patch("controllers.web.passport.PassportService")
@patch("controllers.web.passport.generate_session_id", return_value="new-sess-id")
@patch("controllers.web.passport.db")
@patch("controllers.web.passport.FeatureService.get_system_features")
def test_creates_new_end_user_when_no_user_id(
self,
mock_features: MagicMock,
mock_db: MagicMock,
mock_gen_session: MagicMock,
mock_passport_cls: MagicMock,
app: Flask,
) -> None:
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
site = SimpleNamespace(app_id="app-1", code="code1")
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
mock_db.session.scalar.side_effect = [site, app_model]
mock_passport_cls.return_value.issue.return_value = "issued-token"
with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
response = PassportResource().get()
assert response.get_json()["access_token"] == "issued-token"
mock_db.session.add.assert_called_once()
mock_db.session.commit.assert_called_once()
@patch("controllers.web.passport.PassportService")
@patch("controllers.web.passport.db")
@patch("controllers.web.passport.FeatureService.get_system_features")
def test_reuses_existing_end_user_when_user_id_provided(
self,
mock_features: MagicMock,
mock_db: MagicMock,
mock_passport_cls: MagicMock,
app: Flask,
) -> None:
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
site = SimpleNamespace(app_id="app-1", code="code1")
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
existing_user = SimpleNamespace(id="eu-1", session_id="sess-existing")
mock_db.session.scalar.side_effect = [site, app_model, existing_user]
mock_passport_cls.return_value.issue.return_value = "reused-token"
with app.test_request_context("/passport?user_id=sess-existing", headers={"X-App-Code": "code1"}):
response = PassportResource().get()
assert response.get_json()["access_token"] == "reused-token"
# Should not create a new end user
mock_db.session.add.assert_not_called()
@patch("controllers.web.passport.db")
@patch("controllers.web.passport.FeatureService.get_system_features")
def test_site_not_found_raises(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None:
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
mock_db.session.scalar.return_value = None
with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
with pytest.raises(NotFound):
PassportResource().get()
@patch("controllers.web.passport.db")
@patch("controllers.web.passport.FeatureService.get_system_features")
def test_disabled_app_raises_not_found(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None:
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
site = SimpleNamespace(app_id="app-1", code="code1")
disabled_app = SimpleNamespace(id="app-1", status="normal", enable_site=False)
mock_db.session.scalar.side_effect = [site, disabled_app]
with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
with pytest.raises(NotFound):
PassportResource().get()

View File

@@ -0,0 +1,95 @@
"""Unit tests for controllers.web.workflow endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.error import (
NotWorkflowAppError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.web.workflow import WorkflowRunApi, WorkflowTaskStopApi
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError
def _workflow_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="workflow")
def _chat_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", mode="chat")
def _end_user() -> SimpleNamespace:
return SimpleNamespace(id="eu-1")
# ---------------------------------------------------------------------------
# WorkflowRunApi
# ---------------------------------------------------------------------------
class TestWorkflowRunApi:
def test_wrong_mode_raises(self, app: Flask) -> None:
with app.test_request_context("/workflows/run", method="POST"):
with pytest.raises(NotWorkflowAppError):
WorkflowRunApi().post(_chat_app(), _end_user())
@patch("controllers.web.workflow.helper.compact_generate_response", return_value={"result": "ok"})
@patch("controllers.web.workflow.AppGenerateService.generate")
@patch("controllers.web.workflow.web_ns")
def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
mock_ns.payload = {"inputs": {"key": "val"}}
mock_gen.return_value = "response"
with app.test_request_context("/workflows/run", method="POST"):
result = WorkflowRunApi().post(_workflow_app(), _end_user())
assert result == {"result": "ok"}
@patch(
"controllers.web.workflow.AppGenerateService.generate",
side_effect=ProviderTokenNotInitError(description="not init"),
)
@patch("controllers.web.workflow.web_ns")
def test_provider_not_init(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
mock_ns.payload = {"inputs": {}}
with app.test_request_context("/workflows/run", method="POST"):
with pytest.raises(ProviderNotInitializeError):
WorkflowRunApi().post(_workflow_app(), _end_user())
@patch(
"controllers.web.workflow.AppGenerateService.generate",
side_effect=QuotaExceededError(),
)
@patch("controllers.web.workflow.web_ns")
def test_quota_exceeded(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
mock_ns.payload = {"inputs": {}}
with app.test_request_context("/workflows/run", method="POST"):
with pytest.raises(ProviderQuotaExceededError):
WorkflowRunApi().post(_workflow_app(), _end_user())
# ---------------------------------------------------------------------------
# WorkflowTaskStopApi
# ---------------------------------------------------------------------------
class TestWorkflowTaskStopApi:
def test_wrong_mode_raises(self, app: Flask) -> None:
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
with pytest.raises(NotWorkflowAppError):
WorkflowTaskStopApi().post(_chat_app(), _end_user(), "task-1")
@patch("controllers.web.workflow.GraphEngineManager.send_stop_command")
@patch("controllers.web.workflow.AppQueueManager.set_stop_flag_no_user_check")
def test_stop_calls_both_mechanisms(self, mock_legacy: MagicMock, mock_graph: MagicMock, app: Flask) -> None:
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
result = WorkflowTaskStopApi().post(_workflow_app(), _end_user(), "task-1")
assert result == {"result": "success"}
mock_legacy.assert_called_once_with("task-1")
mock_graph.assert_called_once_with("task-1")

View File

@@ -0,0 +1,127 @@
"""Unit tests for controllers.web.workflow_events endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.error import NotFoundError
from controllers.web.workflow_events import WorkflowEventsApi
from models.enums import CreatorUserRole
def _workflow_app() -> SimpleNamespace:
return SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="workflow")
def _end_user() -> SimpleNamespace:
return SimpleNamespace(id="eu-1")
# ---------------------------------------------------------------------------
# WorkflowEventsApi
# ---------------------------------------------------------------------------
class TestWorkflowEventsApi:
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
@patch("controllers.web.workflow_events.db")
def test_workflow_run_not_found(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
mock_db.engine = "engine"
mock_repo = MagicMock()
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = None
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
with app.test_request_context("/workflow/run-1/events"):
with pytest.raises(NotFoundError):
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
@patch("controllers.web.workflow_events.db")
def test_workflow_run_wrong_app(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
mock_db.engine = "engine"
run = SimpleNamespace(
id="run-1",
app_id="other-app",
created_by_role=CreatorUserRole.END_USER,
created_by="eu-1",
finished_at=None,
)
mock_repo = MagicMock()
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
with app.test_request_context("/workflow/run-1/events"):
with pytest.raises(NotFoundError):
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
@patch("controllers.web.workflow_events.db")
def test_workflow_run_not_created_by_end_user(
self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask
) -> None:
mock_db.engine = "engine"
run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.ACCOUNT,
created_by="eu-1",
finished_at=None,
)
mock_repo = MagicMock()
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
with app.test_request_context("/workflow/run-1/events"):
with pytest.raises(NotFoundError):
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
@patch("controllers.web.workflow_events.db")
def test_workflow_run_wrong_end_user(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
mock_db.engine = "engine"
run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="other-user",
finished_at=None,
)
mock_repo = MagicMock()
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
with app.test_request_context("/workflow/run-1/events"):
with pytest.raises(NotFoundError):
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
@patch("controllers.web.workflow_events.WorkflowResponseConverter")
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
@patch("controllers.web.workflow_events.db")
def test_finished_run_returns_sse_response(
self, mock_db: MagicMock, mock_factory: MagicMock, mock_converter: MagicMock, app: Flask
) -> None:
from datetime import datetime
mock_db.engine = "engine"
run = SimpleNamespace(
id="run-1",
app_id="app-1",
created_by_role=CreatorUserRole.END_USER,
created_by="eu-1",
finished_at=datetime(2024, 1, 1),
)
mock_repo = MagicMock()
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
finish_response = MagicMock()
finish_response.model_dump.return_value = {"task_id": "run-1"}
finish_response.event.value = "workflow_finished"
mock_converter.workflow_run_result_to_finish_response.return_value = finish_response
with app.test_request_context("/workflow/run-1/events"):
response = WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
assert response.mimetype == "text/event-stream"

View File

@@ -0,0 +1,393 @@
"""Unit tests for controllers.web.wraps — JWT auth decorator and validation helpers."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
from controllers.web.wraps import (
_validate_user_accessibility,
_validate_webapp_token,
decode_jwt_token,
)
# ---------------------------------------------------------------------------
# _validate_webapp_token
# ---------------------------------------------------------------------------
class TestValidateWebappToken:
def test_enterprise_enabled_and_app_auth_requires_webapp_source(self) -> None:
"""When both flags are true, a non-webapp source must raise."""
decoded = {"token_source": "other"}
with pytest.raises(WebAppAuthRequiredError):
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
def test_enterprise_enabled_and_app_auth_accepts_webapp_source(self) -> None:
decoded = {"token_source": "webapp"}
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
def test_enterprise_enabled_and_app_auth_missing_source_raises(self) -> None:
decoded = {}
with pytest.raises(WebAppAuthRequiredError):
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
def test_public_app_rejects_webapp_source(self) -> None:
"""When auth is not required, a webapp-sourced token must be rejected."""
decoded = {"token_source": "webapp"}
with pytest.raises(Unauthorized):
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
def test_public_app_accepts_non_webapp_source(self) -> None:
decoded = {"token_source": "other"}
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
def test_public_app_accepts_no_source(self) -> None:
decoded = {}
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
def test_system_enabled_but_app_public(self) -> None:
"""system_webapp_auth_enabled=True but app is public — webapp source rejected."""
decoded = {"token_source": "webapp"}
with pytest.raises(Unauthorized):
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=True)
# ---------------------------------------------------------------------------
# _validate_user_accessibility
# ---------------------------------------------------------------------------
class TestValidateUserAccessibility:
def test_skips_when_auth_disabled(self) -> None:
"""No checks when system or app auth is disabled."""
_validate_user_accessibility(
decoded={},
app_code="code",
app_web_auth_enabled=False,
system_webapp_auth_enabled=False,
webapp_settings=None,
)
def test_missing_user_id_raises(self) -> None:
decoded = {}
with pytest.raises(WebAppAuthRequiredError):
_validate_user_accessibility(
decoded=decoded,
app_code="code",
app_web_auth_enabled=True,
system_webapp_auth_enabled=True,
webapp_settings=SimpleNamespace(access_mode="internal"),
)
def test_missing_webapp_settings_raises(self) -> None:
decoded = {"user_id": "u1"}
with pytest.raises(WebAppAuthRequiredError, match="settings not found"):
_validate_user_accessibility(
decoded=decoded,
app_code="code",
app_web_auth_enabled=True,
system_webapp_auth_enabled=True,
webapp_settings=None,
)
def test_missing_auth_type_raises(self) -> None:
decoded = {"user_id": "u1", "granted_at": 1}
settings = SimpleNamespace(access_mode="public")
with pytest.raises(WebAppAuthAccessDeniedError, match="auth_type"):
_validate_user_accessibility(
decoded=decoded,
app_code="code",
app_web_auth_enabled=True,
system_webapp_auth_enabled=True,
webapp_settings=settings,
)
def test_missing_granted_at_raises(self) -> None:
decoded = {"user_id": "u1", "auth_type": "external"}
settings = SimpleNamespace(access_mode="public")
with pytest.raises(WebAppAuthAccessDeniedError, match="granted_at"):
_validate_user_accessibility(
decoded=decoded,
app_code="code",
app_web_auth_enabled=True,
system_webapp_auth_enabled=True,
webapp_settings=settings,
)
@patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time")
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
def test_external_auth_type_checks_sso_update_time(
self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
) -> None:
# granted_at is before SSO update time → denied
mock_sso_time.return_value = datetime.now(UTC)
old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": old_granted}
settings = SimpleNamespace(access_mode="public")
with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"):
_validate_user_accessibility(
decoded=decoded,
app_code="code",
app_web_auth_enabled=True,
system_webapp_auth_enabled=True,
webapp_settings=settings,
)
@patch("controllers.web.wraps.EnterpriseService.get_workspace_sso_settings_last_update_time")
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
def test_internal_auth_type_checks_workspace_sso_update_time(
self, mock_perm_check: MagicMock, mock_workspace_sso: MagicMock
) -> None:
mock_workspace_sso.return_value = datetime.now(UTC)
old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
decoded = {"user_id": "u1", "auth_type": "internal", "granted_at": old_granted}
settings = SimpleNamespace(access_mode="public")
with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"):
_validate_user_accessibility(
decoded=decoded,
app_code="code",
app_web_auth_enabled=True,
system_webapp_auth_enabled=True,
webapp_settings=settings,
)
@patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time")
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
def test_external_auth_passes_when_granted_after_sso_update(
self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
) -> None:
mock_sso_time.return_value = datetime.now(UTC) - timedelta(hours=2)
recent_granted = int(datetime.now(UTC).timestamp())
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": recent_granted}
settings = SimpleNamespace(access_mode="public")
# Should not raise
_validate_user_accessibility(
decoded=decoded,
app_code="code",
app_web_auth_enabled=True,
system_webapp_auth_enabled=True,
webapp_settings=settings,
)
@patch("controllers.web.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp", return_value=False)
@patch("controllers.web.wraps.AppService.get_app_id_by_code", return_value="app-id-1")
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=True)
def test_permission_check_denies_unauthorized_user(
self, mock_perm: MagicMock, mock_app_id: MagicMock, mock_allowed: MagicMock
) -> None:
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": int(datetime.now(UTC).timestamp())}
settings = SimpleNamespace(access_mode="internal")
with pytest.raises(WebAppAuthAccessDeniedError):
_validate_user_accessibility(
decoded=decoded,
app_code="code",
app_web_auth_enabled=True,
system_webapp_auth_enabled=True,
webapp_settings=settings,
)
# ---------------------------------------------------------------------------
# decode_jwt_token
# ---------------------------------------------------------------------------
class TestDecodeJwtToken:
@patch("controllers.web.wraps._validate_user_accessibility")
@patch("controllers.web.wraps._validate_webapp_token")
@patch("controllers.web.wraps.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
@patch("controllers.web.wraps.AppService.get_app_id_by_code")
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_happy_path(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
mock_app_id: MagicMock,
mock_access_mode: MagicMock,
mock_validate_token: MagicMock,
mock_validate_user: MagicMock,
app: Flask,
) -> None:
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
app_model = SimpleNamespace(id="app-1", enable_site=True)
site = SimpleNamespace(code="code1")
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
# Configure session mock to return correct objects via scalar()
session_mock = MagicMock()
session_mock.scalar.side_effect = [app_model, site, end_user]
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
result_app, result_user = decode_jwt_token()
assert result_app.id == "app-1"
assert result_user.id == "eu-1"
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.extract_webapp_passport")
def test_missing_token_raises_unauthorized(
self, mock_extract: MagicMock, mock_features: MagicMock, app: Flask
) -> None:
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
mock_extract.return_value = None
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(Unauthorized):
decode_jwt_token()
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_missing_app_raises_not_found(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app: Flask,
) -> None:
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
session_mock = MagicMock()
session_mock.scalar.return_value = None # No app found
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(NotFound):
decode_jwt_token()
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_disabled_site_raises_bad_request(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app: Flask,
) -> None:
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
app_model = SimpleNamespace(id="app-1", enable_site=False)
session_mock = MagicMock()
# scalar calls: app_model, site (code found), then end_user
session_mock.scalar.side_effect = [app_model, SimpleNamespace(code="code1"), None]
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(BadRequest, match="Site is disabled"):
decode_jwt_token()
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_missing_end_user_raises_not_found(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app: Flask,
) -> None:
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
app_model = SimpleNamespace(id="app-1", enable_site=True)
site = SimpleNamespace(code="code1")
session_mock = MagicMock()
session_mock.scalar.side_effect = [app_model, site, None] # end_user is None
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(NotFound):
decode_jwt_token()
@patch("controllers.web.wraps.FeatureService.get_system_features")
@patch("controllers.web.wraps.PassportService")
@patch("controllers.web.wraps.extract_webapp_passport")
@patch("controllers.web.wraps.db")
def test_user_id_mismatch_raises_unauthorized(
self,
mock_db: MagicMock,
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app: Flask,
) -> None:
mock_extract.return_value = "jwt-token"
mock_passport_cls.return_value.verify.return_value = {
"app_code": "code1",
"app_id": "app-1",
"end_user_id": "eu-1",
}
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
app_model = SimpleNamespace(id="app-1", enable_site=True)
site = SimpleNamespace(code="code1")
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
session_mock = MagicMock()
session_mock.scalar.side_effect = [app_model, site, end_user]
session_ctx = MagicMock()
session_ctx.__enter__ = MagicMock(return_value=session_mock)
session_ctx.__exit__ = MagicMock(return_value=False)
mock_db.engine = "engine"
with patch("controllers.web.wraps.Session", return_value=session_ctx):
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
with pytest.raises(Unauthorized, match="expired"):
decode_jwt_token(user_id="different-user")

View File

@@ -22,6 +22,7 @@ from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from dify_graph.nodes.llm import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.nodes.question_classifier import QuestionClassifierNode
from dify_graph.nodes.template_transform import TemplateTransformNode
from dify_graph.nodes.template_transform.template_renderer import (
@@ -65,6 +66,8 @@ class MockNodeMixin:
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
# LLM-like nodes now require an http_client; provide a mock by default for tests.
kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
# Ensure TemplateTransformNode receives a renderer now required by constructor
if isinstance(self, TemplateTransformNode):

View File

@@ -112,7 +112,6 @@ class TestKnowledgeRetrievalNode:
# Assert
assert node.id == node_id
assert node._rag_retrieval == mock_rag_retrieval
assert node._llm_file_saver is not None
def test_run_with_no_query_or_attachment(
self,

View File

@@ -1,10 +1,10 @@
import uuid
from typing import NamedTuple
from unittest import mock
from unittest.mock import MagicMock
import httpx
import pytest
from sqlalchemy import Engine
from core.helper import ssrf_proxy
from core.tools import signature
@@ -44,7 +44,6 @@ class TestFileSaverImpl:
)
mock_tool_file.id = _gen_id()
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
mocked_engine = mock.MagicMock(spec=Engine)
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
@@ -53,11 +52,12 @@ class TestFileSaverImpl:
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
mocked_sign_file.return_value = mock_signed_url
http_client = MagicMock()
storage_file_manager = FileSaverImpl(
user_id=user_id,
tenant_id=tenant_id,
engine_factory=mocked_engine,
http_client=http_client,
)
file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
@@ -87,16 +87,18 @@ class TestFileSaverImpl:
status_code=401,
request=mock_request,
)
http_client = MagicMock()
http_client.get.return_value = mock_response
file_saver = FileSaverImpl(
user_id=_gen_id(),
tenant_id=_gen_id(),
http_client=http_client,
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
with pytest.raises(httpx.HTTPStatusError) as exc:
file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
mock_get.assert_called_once_with(_TEST_URL)
http_client.get.assert_called_once_with(_TEST_URL)
assert exc.value.response.status_code == 401
def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch):
@@ -112,8 +114,10 @@ class TestFileSaverImpl:
headers={"Content-Type": mime_type},
request=mock_request,
)
http_client = MagicMock()
http_client.get.return_value = mock_response
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client)
mock_tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,

View File

@@ -111,6 +111,7 @@ def llm_node(
"id": "1",
"data": llm_node_data.model_dump(),
}
http_client = mock.MagicMock()
node = LLMNode(
id="1",
config=node_config,
@@ -120,6 +121,7 @@ def llm_node(
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
http_client=http_client,
)
return node
@@ -632,6 +634,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
"id": "1",
"data": llm_node_data.model_dump(),
}
http_client = mock.MagicMock()
node = LLMNode(
id="1",
config=node_config,
@@ -641,6 +644,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
http_client=http_client,
)
return node, mock_file_saver

File diff suppressed because it is too large Load Diff

View File

@@ -295,24 +295,7 @@ describe('Pricing Modal Flow', () => {
})
})
// ─── 6. Close Handling ───────────────────────────────────────────────────
describe('Close handling', () => {
it('should call onCancel when pressing ESC key', () => {
render(<Pricing onCancel={onCancel} />)
// ahooks useKeyPress listens on document for keydown events
document.dispatchEvent(new KeyboardEvent('keydown', {
key: 'Escape',
code: 'Escape',
keyCode: 27,
bubbles: true,
}))
expect(onCancel).toHaveBeenCalledTimes(1)
})
})
// ─── 7. Pricing URL ─────────────────────────────────────────────────────
// ─── 6. Pricing URL ─────────────────────────────────────────────────────
describe('Pricing page URL', () => {
it('should render pricing link with correct URL', () => {
render(<Pricing onCancel={onCancel} />)

View File

@@ -160,7 +160,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
isShow={isShowDeleteConfirm}
onClose={() => setIsShowDeleteConfirm(false)}
>
<div className="title-2xl-semi-bold mb-3 text-text-primary">{t('avatar.deleteTitle', { ns: 'common' })}</div>
<div className="mb-3 text-text-primary title-2xl-semi-bold">{t('avatar.deleteTitle', { ns: 'common' })}</div>
<p className="mb-8 text-text-secondary">{t('avatar.deleteDescription', { ns: 'common' })}</p>
<div className="flex w-full items-center justify-center gap-2">

View File

@@ -209,14 +209,14 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
</div>
{step === STEP.start && (
<>
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.title', { ns: 'common' })}</div>
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.title', { ns: 'common' })}</div>
<div className="space-y-0.5 pb-2 pt-1">
<div className="body-md-medium text-text-warning">{t('account.changeEmail.authTip', { ns: 'common' })}</div>
<div className="body-md-regular text-text-secondary">
<div className="text-text-warning body-md-medium">{t('account.changeEmail.authTip', { ns: 'common' })}</div>
<div className="text-text-secondary body-md-regular">
<Trans
i18nKey="account.changeEmail.content1"
ns="common"
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
values={{ email }}
/>
</div>
@@ -241,19 +241,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
)}
{step === STEP.verifyOrigin && (
<>
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.verifyEmail', { ns: 'common' })}</div>
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.verifyEmail', { ns: 'common' })}</div>
<div className="space-y-0.5 pb-2 pt-1">
<div className="body-md-regular text-text-secondary">
<div className="text-text-secondary body-md-regular">
<Trans
i18nKey="account.changeEmail.content2"
ns="common"
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
values={{ email }}
/>
</div>
</div>
<div className="pt-3">
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
<Input
className="!w-full"
placeholder={t('account.changeEmail.codePlaceholder', { ns: 'common' })}
@@ -278,25 +278,25 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
{t('operation.cancel', { ns: 'common' })}
</Button>
</div>
<div className="system-xs-regular mt-3 flex items-center gap-1 text-text-tertiary">
<div className="mt-3 flex items-center gap-1 text-text-tertiary system-xs-regular">
<span>{t('account.changeEmail.resendTip', { ns: 'common' })}</span>
{time > 0 && (
<span>{t('account.changeEmail.resendCount', { ns: 'common', count: time })}</span>
)}
{!time && (
<span onClick={sendCodeToOriginEmail} className="system-xs-medium cursor-pointer text-text-accent-secondary">{t('account.changeEmail.resend', { ns: 'common' })}</span>
<span onClick={sendCodeToOriginEmail} className="cursor-pointer text-text-accent-secondary system-xs-medium">{t('account.changeEmail.resend', { ns: 'common' })}</span>
)}
</div>
</>
)}
{step === STEP.newEmail && (
<>
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.newEmail', { ns: 'common' })}</div>
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.newEmail', { ns: 'common' })}</div>
<div className="space-y-0.5 pb-2 pt-1">
<div className="body-md-regular text-text-secondary">{t('account.changeEmail.content3', { ns: 'common' })}</div>
<div className="text-text-secondary body-md-regular">{t('account.changeEmail.content3', { ns: 'common' })}</div>
</div>
<div className="pt-3">
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.emailLabel', { ns: 'common' })}</div>
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.emailLabel', { ns: 'common' })}</div>
<Input
className="!w-full"
placeholder={t('account.changeEmail.emailPlaceholder', { ns: 'common' })}
@@ -305,10 +305,10 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
destructive={newEmailExited || unAvailableEmail}
/>
{newEmailExited && (
<div className="body-xs-regular mt-1 py-0.5 text-text-destructive">{t('account.changeEmail.existingEmail', { ns: 'common' })}</div>
<div className="mt-1 py-0.5 text-text-destructive body-xs-regular">{t('account.changeEmail.existingEmail', { ns: 'common' })}</div>
)}
{unAvailableEmail && (
<div className="body-xs-regular mt-1 py-0.5 text-text-destructive">{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}</div>
<div className="mt-1 py-0.5 text-text-destructive body-xs-regular">{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}</div>
)}
</div>
<div className="mt-3 space-y-2">
@@ -331,19 +331,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
)}
{step === STEP.verifyNew && (
<>
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.verifyNew', { ns: 'common' })}</div>
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.verifyNew', { ns: 'common' })}</div>
<div className="space-y-0.5 pb-2 pt-1">
<div className="body-md-regular text-text-secondary">
<div className="text-text-secondary body-md-regular">
<Trans
i18nKey="account.changeEmail.content4"
ns="common"
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
values={{ email: mail }}
/>
</div>
</div>
<div className="pt-3">
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
<Input
className="!w-full"
placeholder={t('account.changeEmail.codePlaceholder', { ns: 'common' })}
@@ -368,13 +368,13 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
{t('operation.cancel', { ns: 'common' })}
</Button>
</div>
<div className="system-xs-regular mt-3 flex items-center gap-1 text-text-tertiary">
<div className="mt-3 flex items-center gap-1 text-text-tertiary system-xs-regular">
<span>{t('account.changeEmail.resendTip', { ns: 'common' })}</span>
{time > 0 && (
<span>{t('account.changeEmail.resendCount', { ns: 'common', count: time })}</span>
)}
{!time && (
<span onClick={sendCodeToNewEmail} className="system-xs-medium cursor-pointer text-text-accent-secondary">{t('account.changeEmail.resend', { ns: 'common' })}</span>
<span onClick={sendCodeToNewEmail} className="cursor-pointer text-text-accent-secondary system-xs-medium">{t('account.changeEmail.resend', { ns: 'common' })}</span>
)}
</div>
</>

View File

@@ -138,7 +138,7 @@ export default function AccountPage() {
imageUrl={icon_url}
/>
</div>
<div className="system-sm-medium mt-[3px] text-text-secondary">{item.name}</div>
<div className="mt-[3px] text-text-secondary system-sm-medium">{item.name}</div>
</div>
)
}
@@ -146,12 +146,12 @@ export default function AccountPage() {
return (
<>
<div className="pb-3 pt-2">
<h4 className="title-2xl-semi-bold text-text-primary">{t('account.myAccount', { ns: 'common' })}</h4>
<h4 className="text-text-primary title-2xl-semi-bold">{t('account.myAccount', { ns: 'common' })}</h4>
</div>
<div className="mb-8 flex items-center rounded-xl bg-gradient-to-r from-background-gradient-bg-fill-chat-bg-2 to-background-gradient-bg-fill-chat-bg-1 p-6">
<AvatarWithEdit avatar={userProfile.avatar_url} name={userProfile.name} onSave={mutateUserProfile} size={64} />
<div className="ml-4">
<p className="system-xl-semibold text-text-primary">
<p className="text-text-primary system-xl-semibold">
{userProfile.name}
{isEducationAccount && (
<PremiumBadge size="s" color="blue" className="ml-1 !px-2">
@@ -160,16 +160,16 @@ export default function AccountPage() {
</PremiumBadge>
)}
</p>
<p className="system-xs-regular text-text-tertiary">{userProfile.email}</p>
<p className="text-text-tertiary system-xs-regular">{userProfile.email}</p>
</div>
</div>
<div className="mb-8">
<div className={titleClassName}>{t('account.name', { ns: 'common' })}</div>
<div className="mt-2 flex w-full items-center justify-between gap-2">
<div className="system-sm-regular flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled ">
<div className="flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled system-sm-regular">
<span className="pl-1">{userProfile.name}</span>
</div>
<div className="system-sm-medium cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text" onClick={handleEditName}>
<div className="cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text system-sm-medium" onClick={handleEditName}>
{t('operation.edit', { ns: 'common' })}
</div>
</div>
@@ -177,11 +177,11 @@ export default function AccountPage() {
<div className="mb-8">
<div className={titleClassName}>{t('account.email', { ns: 'common' })}</div>
<div className="mt-2 flex w-full items-center justify-between gap-2">
<div className="system-sm-regular flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled ">
<div className="flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled system-sm-regular">
<span className="pl-1">{userProfile.email}</span>
</div>
{systemFeatures.enable_change_email && (
<div className="system-sm-medium cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text" onClick={() => setShowUpdateEmail(true)}>
<div className="cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text system-sm-medium" onClick={() => setShowUpdateEmail(true)}>
{t('operation.change', { ns: 'common' })}
</div>
)}
@@ -191,8 +191,8 @@ export default function AccountPage() {
systemFeatures.enable_email_password_login && (
<div className="mb-8 flex justify-between gap-2">
<div>
<div className="system-sm-semibold mb-1 text-text-secondary">{t('account.password', { ns: 'common' })}</div>
<div className="body-xs-regular mb-2 text-text-tertiary">{t('account.passwordTip', { ns: 'common' })}</div>
<div className="mb-1 text-text-secondary system-sm-semibold">{t('account.password', { ns: 'common' })}</div>
<div className="mb-2 text-text-tertiary body-xs-regular">{t('account.passwordTip', { ns: 'common' })}</div>
</div>
<Button onClick={() => setEditPasswordModalVisible(true)}>{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</Button>
</div>
@@ -219,7 +219,7 @@ export default function AccountPage() {
onClose={() => setEditNameModalVisible(false)}
className="!w-[420px] !p-6"
>
<div className="title-2xl-semi-bold mb-6 text-text-primary">{t('account.editName', { ns: 'common' })}</div>
<div className="mb-6 text-text-primary title-2xl-semi-bold">{t('account.editName', { ns: 'common' })}</div>
<div className={titleClassName}>{t('account.name', { ns: 'common' })}</div>
<Input
className="mt-2"
@@ -249,7 +249,7 @@ export default function AccountPage() {
}}
className="!w-[420px] !p-6"
>
<div className="title-2xl-semi-bold mb-6 text-text-primary">{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</div>
<div className="mb-6 text-text-primary title-2xl-semi-bold">{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</div>
{userProfile.is_password_set && (
<>
<div className={titleClassName}>{t('account.currentPassword', { ns: 'common' })}</div>
@@ -272,7 +272,7 @@ export default function AccountPage() {
</div>
</>
)}
<div className="system-sm-semibold mt-8 text-text-secondary">
<div className="mt-8 text-text-secondary system-sm-semibold">
{userProfile.is_password_set ? t('account.newPassword', { ns: 'common' }) : t('account.password', { ns: 'common' })}
</div>
<div className="relative mt-2">
@@ -291,7 +291,7 @@ export default function AccountPage() {
</Button>
</div>
</div>
<div className="system-sm-semibold mt-8 text-text-secondary">{t('account.confirmPassword', { ns: 'common' })}</div>
<div className="mt-8 text-text-secondary system-sm-semibold">{t('account.confirmPassword', { ns: 'common' })}</div>
<div className="relative mt-2">
<Input
type={showConfirmPassword ? 'text' : 'password'}

Some files were not shown because too many files have changed in this diff Show More