diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts index 7e80183e0c..1a7ed1e1f5 100644 --- a/web/app/components/header/account-setting/model-provider-page/hooks.ts +++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts @@ -25,6 +25,7 @@ import { useEventEmitterContextContext } from '@/context/event-emitter' import { useLocale } from '@/context/i18n' import { useModalContextSelector } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' +import { consoleQuery } from '@/service/client' import { fetchDefaultModal, fetchModelList, @@ -323,6 +324,7 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText: export const useRefreshModel = () => { const { eventEmitter } = useEventEmitterContextContext() + const queryClient = useQueryClient() const updateModelProviders = useUpdateModelProviders() const updateModelList = useUpdateModelList() const handleRefreshModel = useCallback(( @@ -330,6 +332,11 @@ export const useRefreshModel = () => { CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, refreshModelList?: boolean, ) => { + queryClient.invalidateQueries({ + queryKey: consoleQuery.modelProviders.models.key(), + refetchType: 'none', + }) + updateModelProviders() provider.supported_model_types.forEach((type) => { @@ -345,7 +352,7 @@ export const useRefreshModel = () => { if (CustomConfigurationModelFixedFields?.__model_type) updateModelList(CustomConfigurationModelFixedFields.__model_type) } - }, [eventEmitter, updateModelList, updateModelProviders]) + }, [eventEmitter, queryClient, updateModelList, updateModelProviders]) return { handleRefreshModel, diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx index 554efc93d2..28f6094ded 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.spec.tsx @@ -1,4 +1,5 @@ import type { ModelProvider } from '../declarations' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { changeModelProviderPriority } from '@/service/common' import { ConfigurationMethodEnum } from '../declarations' @@ -71,6 +72,21 @@ vi.mock('@/app/components/header/indicator', () => ({ default: ({ color }: { color: string }) =>
{color}
, })) +const createTestQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { retry: false, gcTime: 0 }, + }, +}) + +const renderWithQueryClient = (provider: ModelProvider) => { + const queryClient = createTestQueryClient() + return render( + + + , + ) +} + describe('CredentialPanel', () => { const mockProvider: ModelProvider = { provider: 'test-provider', @@ -94,7 +110,7 @@ describe('CredentialPanel', () => { }) it('should show credential name and configuration actions', () => { - render() + renderWithQueryClient(mockProvider) expect(screen.getByText('test-credential')).toBeInTheDocument() expect(screen.getByTestId('config-provider')).toBeInTheDocument() @@ -103,7 +119,7 @@ describe('CredentialPanel', () => { it('should show unauthorized status label when credential is missing', () => { mockCredentialStatus.hasCredential = false - render() + renderWithQueryClient(mockProvider) expect(screen.getByText(/modelProvider\.auth\.unAuthorized/)).toBeInTheDocument() }) @@ -111,7 +127,7 @@ describe('CredentialPanel', () => { it('should show removed credential label and priority tip for custom preference', () => { mockCredentialStatus.authorized = false mockCredentialStatus.authRemoved = true - render() + renderWithQueryClient({ ...mockProvider, preferred_provider_type: 'custom' } as ModelProvider) expect(screen.getByText(/modelProvider\.auth\.authRemoved/)).toBeInTheDocument() expect(screen.getByTestId('priority-use-tip')).toBeInTheDocument() @@ -120,7 +136,7 @@ describe('CredentialPanel', () => { it('should change priority and refresh related data after success', async () => { const mockChangePriority = changeModelProviderPriority as ReturnType mockChangePriority.mockResolvedValue({ result: 'success' }) - render() + renderWithQueryClient(mockProvider) fireEvent.click(screen.getByTestId('priority-selector')) @@ -138,7 +154,7 @@ describe('CredentialPanel', () => { ...mockProvider, provider_credential_schema: null, } as unknown as ModelProvider - render() + renderWithQueryClient(providerNoSchema) expect(screen.getByTestId('priority-selector')).toBeInTheDocument() expect(screen.queryByTestId('config-provider')).not.toBeInTheDocument() }) diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx index c46f9d56bd..ba7079ef88 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/credential-panel.tsx @@ -1,6 +1,7 @@ import type { ModelProvider, } from '../declarations' +import { useQueryClient } from '@tanstack/react-query' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import { useToastContext } from '@/app/components/base/toast' @@ -9,6 +10,7 @@ import { useCredentialStatus } from '@/app/components/header/account-setting/mod import Indicator from '@/app/components/header/indicator' import { IS_CLOUD_EDITION } from '@/config' import { useEventEmitterContextContext } from '@/context/event-emitter' +import { consoleQuery } from '@/service/client' import { changeModelProviderPriority } from '@/service/common' import { cn } from '@/utils/classnames' import { @@ -34,6 +36,7 @@ const CredentialPanel = ({ const { t } = useTranslation() const { notify } = useToastContext() const { eventEmitter } = useEventEmitterContextContext() + const queryClient = useQueryClient() const updateModelList = useUpdateModelList() const updateModelProviders = useUpdateModelProviders() const customConfig = provider.custom_configuration @@ -60,6 +63,10 @@ const CredentialPanel = ({ }) if (res.result === 'success') { notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) + queryClient.invalidateQueries({ + queryKey: consoleQuery.modelProviders.models.key(), + refetchType: 'none', + }) updateModelProviders() configurateMethods.forEach((method) => { @@ -82,7 +89,7 @@ const CredentialPanel = ({ return t('modelProvider.auth.authRemoved', { ns: 'common' }) return '' - }, [authorized, authRemoved, current_credential_name, hasCredential]) + }, [authorized, authRemoved, current_credential_name, hasCredential, t]) const color = useMemo(() => { if (authRemoved || !hasCredential) diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.spec.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.spec.tsx index 28b95c891c..586dd67d44 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.spec.tsx @@ -120,13 +120,13 @@ describe('ProviderAddedCard', () => { // Explicitly re-find and click to re-open fireEvent.click(screen.getByTestId('show-models-button')) expect(await screen.findByTestId('model-list')).toBeInTheDocument() - expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1) // Should not fetch again + expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(2) // Re-open fetches again with default stale/gc behavior // Refresh list from ModelList const refreshBtn = screen.getByRole('button', { name: 'refresh list' }) fireEvent.click(refreshBtn) await waitFor(() => { - expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(2) + expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(3) }) }) diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx index a715e75b05..553b91da1b 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx @@ -83,14 +83,14 @@ const ProviderAddedCard: FC = ({ const showCustomModelActions = supportsCustomizableModel && isCurrentWorkspaceManager const refreshModelList = useCallback((targetProviderName: string) => { - if (targetProviderName !== currentProviderName || loading) + if (targetProviderName !== currentProviderName) return if (collapsed) setCollapsed(false) refetchModelList().catch(() => {}) - }, [collapsed, currentProviderName, loading, refetchModelList]) + }, [collapsed, currentProviderName, refetchModelList]) const handleOpenModelList = useCallback(() => { if (loading) diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx index 908d2f0e6c..a21ece2384 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/model-list-item.tsx @@ -1,4 +1,5 @@ import type { ModelItem, ModelProvider } from '../declarations' +import { useQueryClient } from '@tanstack/react-query' import { useDebounceFn } from 'ahooks' import { memo, useCallback } from 'react' import { useTranslation } from 'react-i18next' @@ -9,6 +10,7 @@ import Tooltip from '@/app/components/base/tooltip' import { Plan } from '@/app/components/billing/type' import { useAppContext } from '@/context/app-context' import { useProviderContext, useProviderContextSelector } from '@/context/provider-context' +import { consoleQuery } from '@/service/client' import { disableModel, enableModel } from '@/service/common' import { cn } from '@/utils/classnames' import { ModelStatusEnum } from '../declarations' @@ -30,6 +32,7 @@ const ModelListItem = ({ model, provider, isConfigurable, onChange, onModifyLoad const { plan } = useProviderContext() const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled) const { isCurrentWorkspaceManager } = useAppContext() + const queryClient = useQueryClient() const updateModelList = useUpdateModelList() const toggleModelEnablingStatus = useCallback(async (enabled: boolean) => { @@ -37,9 +40,14 @@ const ModelListItem = ({ model, provider, isConfigurable, onChange, onModifyLoad await enableModel(`/workspaces/current/model-providers/${provider.provider}/models/enable`, { model: model.model, model_type: model.model_type }) else await disableModel(`/workspaces/current/model-providers/${provider.provider}/models/disable`, { model: model.model, model_type: model.model_type }) + + queryClient.invalidateQueries({ + queryKey: consoleQuery.modelProviders.models.key(), + refetchType: 'none', + }) updateModelList(model.model_type) onChange?.(provider.provider) - }, [model.model, model.model_type, onChange, provider.provider, updateModelList]) + }, [model.model, model.model_type, onChange, provider.provider, queryClient, updateModelList]) const { run: debouncedToggleModelEnablingStatus } = useDebounceFn(toggleModelEnablingStatus, { wait: 500 })