refactor(web): replace model provider emitter refresh with jotai state

- add atom-based provider expansion state with reset/prune helpers
- remove event-emitter dependency from model provider refresh flow
- invalidate exact provider model-list query key on refresh
- reset expansion state on model provider page mount/unmount
- update and extend tests for external expansion and query invalidation
- update eslint suppressions to match current code
This commit is contained in:
yyh
2026-03-05 13:20:11 +08:00
parent 4a1032c628
commit 5f4ed4c6f6
8 changed files with 229 additions and 134 deletions

View File

@@ -0,0 +1,64 @@
import { atom, useAtomValue, useSetAtom } from 'jotai'
import { selectAtom } from 'jotai/utils'
import { useCallback, useMemo } from 'react'
const modelProviderListExpandedAtom = atom<Record<string, boolean>>({})
const setModelProviderListExpandedAtom = atom(
null,
(get, set, params: { providerName: string, expanded: boolean }) => {
const { providerName, expanded } = params
const current = get(modelProviderListExpandedAtom)
if (expanded) {
if (current[providerName])
return
set(modelProviderListExpandedAtom, {
...current,
[providerName]: true,
})
return
}
if (!current[providerName])
return
const next = { ...current }
delete next[providerName]
set(modelProviderListExpandedAtom, next)
},
)
const resetModelProviderListExpandedAtom = atom(
null,
(_get, set) => {
set(modelProviderListExpandedAtom, {})
},
)
export function useModelProviderListExpanded(providerName: string) {
const selectedAtom = useMemo(
() => selectAtom(modelProviderListExpandedAtom, state => state[providerName] ?? false),
[providerName],
)
return useAtomValue(selectedAtom)
}
export function useSetModelProviderListExpanded(providerName: string) {
const setExpanded = useSetAtom(setModelProviderListExpandedAtom)
return useCallback((expanded: boolean) => {
setExpanded({ providerName, expanded })
}, [providerName, setExpanded])
}
export function useExpandModelProviderList() {
const setExpanded = useSetAtom(setModelProviderListExpandedAtom)
return useCallback((providerName: string) => {
setExpanded({ providerName, expanded: true })
}, [setExpanded])
}
export function useResetModelProviderListExpanded() {
return useSetAtom(resetModelProviderListExpandedAtom)
}

View File

@@ -9,6 +9,7 @@ import type {
} from './declarations'
import { act, renderHook, waitFor } from '@testing-library/react'
import { useLocale } from '@/context/i18n'
import { consoleQuery } from '@/service/client'
import { fetchDefaultModal, fetchModelList, fetchModelProviderCredentials } from '@/service/common'
import {
ConfigurationMethodEnum,
@@ -37,7 +38,6 @@ import {
useUpdateModelList,
useUpdateModelProviders,
} from './hooks'
import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
// Mock dependencies
vi.mock('@tanstack/react-query', () => ({
@@ -79,14 +79,6 @@ vi.mock('@/context/modal-context', () => ({
}),
}))
vi.mock('@/context/event-emitter', () => ({
useEventEmitterContextContext: vi.fn(() => ({
eventEmitter: {
emit: vi.fn(),
},
})),
}))
vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
useMarketplacePlugins: vi.fn(() => ({
plugins: [],
@@ -100,12 +92,16 @@ vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
})),
}))
vi.mock('./atoms', () => ({
useExpandModelProviderList: vi.fn(() => vi.fn()),
}))
const { useQuery, useQueryClient } = await import('@tanstack/react-query')
const { getPayUrl } = await import('@/service/common')
const { useProviderContext } = await import('@/context/provider-context')
const { useModalContextSelector } = await import('@/context/modal-context')
const { useEventEmitterContextContext } = await import('@/context/event-emitter')
const { useMarketplacePlugins, useMarketplacePluginsByCollectionId } = await import('@/app/components/plugins/marketplace/hooks')
const { useExpandModelProviderList } = await import('./atoms')
describe('hooks', () => {
beforeEach(() => {
@@ -1200,39 +1196,52 @@ describe('hooks', () => {
it('should refresh providers and model lists', () => {
const invalidateQueries = vi.fn()
const emit = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
; (useEventEmitterContextContext as Mock).mockReturnValue({
eventEmitter: { emit },
})
const provider = createMockProvider()
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
const { result } = renderHook(() => useRefreshModel())
act(() => {
result.current.handleRefreshModel(provider)
})
expect(invalidateQueries).toHaveBeenCalledWith({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'none',
})
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-providers'] })
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textGeneration] })
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textEmbedding] })
})
it('should emit event when refreshModelList is true and custom config is active', () => {
it('should expand target provider list when refreshModelList is true and custom config is active', () => {
const invalidateQueries = vi.fn()
const emit = vi.fn()
const expandModelProviderList = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
; (useEventEmitterContextContext as Mock).mockReturnValue({
eventEmitter: { emit },
})
; (useExpandModelProviderList as Mock).mockReturnValue(expandModelProviderList)
const provider = createMockProvider()
const customFields: CustomConfigurationModelFixedFields = {
__model_name: 'gpt-4',
__model_type: ModelTypeEnum.textGeneration,
}
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
const { result } = renderHook(() => useRefreshModel())
@@ -1240,23 +1249,30 @@ describe('hooks', () => {
result.current.handleRefreshModel(provider, customFields, true)
})
expect(emit).toHaveBeenCalledWith({
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
payload: 'openai',
expect(expandModelProviderList).toHaveBeenCalledWith('openai')
expect(invalidateQueries).toHaveBeenCalledWith({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'active',
})
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textGeneration] })
})
it('should not emit event when custom config is not active', () => {
it('should not expand provider list when custom config is not active', () => {
const invalidateQueries = vi.fn()
const emit = vi.fn()
const expandModelProviderList = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
; (useEventEmitterContextContext as Mock).mockReturnValue({
eventEmitter: { emit },
})
; (useExpandModelProviderList as Mock).mockReturnValue(expandModelProviderList)
const provider = { ...createMockProvider(), custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure } }
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
const { result } = renderHook(() => useRefreshModel())
@@ -1264,16 +1280,43 @@ describe('hooks', () => {
result.current.handleRefreshModel(provider, undefined, true)
})
expect(emit).not.toHaveBeenCalled()
expect(expandModelProviderList).not.toHaveBeenCalled()
expect(invalidateQueries).not.toHaveBeenCalledWith({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'active',
})
})
it('should refetch active model provider list when custom refresh callback is absent', () => {
const invalidateQueries = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
const provider = createMockProvider()
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
const { result } = renderHook(() => useRefreshModel())
act(() => {
result.current.handleRefreshModel(provider, undefined, true)
})
expect(invalidateQueries).toHaveBeenCalledWith({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'active',
})
})
it('should handle provider with single model type', () => {
const invalidateQueries = vi.fn()
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
; (useEventEmitterContextContext as Mock).mockReturnValue({
eventEmitter: { emit: vi.fn() },
})
const provider = {
...createMockProvider(),

View File

@@ -21,7 +21,6 @@ import {
useMarketplacePluginsByCollectionId,
} from '@/app/components/plugins/marketplace/hooks'
import { PluginCategoryEnum } from '@/app/components/plugins/types'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useLocale } from '@/context/i18n'
import { useModalContextSelector } from '@/context/modal-context'
import { useProviderContext } from '@/context/provider-context'
@@ -33,12 +32,12 @@ import {
getPayUrl,
} from '@/service/common'
import { commonQueryKeys } from '@/service/use-common'
import { useExpandModelProviderList } from './atoms'
import {
ConfigurationMethodEnum,
CustomConfigurationStatusEnum,
ModelStatusEnum,
} from './declarations'
import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
type UseDefaultModelAndModelList = (
defaultModel: DefaultModelResponse | undefined,
@@ -323,7 +322,7 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
}
export const useRefreshModel = () => {
const { eventEmitter } = useEventEmitterContextContext()
const expandModelProviderList = useExpandModelProviderList()
const queryClient = useQueryClient()
const updateModelProviders = useUpdateModelProviders()
const updateModelList = useUpdateModelList()
@@ -332,8 +331,16 @@ export const useRefreshModel = () => {
CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
refreshModelList?: boolean,
) => {
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
input: {
params: {
provider: provider.provider,
},
},
})
queryClient.invalidateQueries({
queryKey: consoleQuery.modelProviders.models.key(),
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'none',
})
@@ -344,15 +351,17 @@ export const useRefreshModel = () => {
})
if (refreshModelList && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) {
eventEmitter?.emit({
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
payload: provider.provider,
} as any)
expandModelProviderList(provider.provider)
queryClient.invalidateQueries({
queryKey: modelProviderModelListQueryKey,
exact: true,
refetchType: 'active',
})
if (CustomConfigurationModelFixedFields?.__model_type)
updateModelList(CustomConfigurationModelFixedFields.__model_type)
}
}, [eventEmitter, queryClient, updateModelList, updateModelProviders])
}, [expandModelProviderList, queryClient, updateModelList, updateModelProviders])
return {
handleRefreshModel,

View File

@@ -8,6 +8,7 @@ import {
import ModelProviderPage from './index'
let mockEnableMarketplace = true
const mockResetModelProviderListExpanded = vi.fn()
const mockQuotaConfig = {
quota_type: CurrentSystemQuotaTypeEnum.free,
@@ -67,6 +68,10 @@ vi.mock('./hooks', () => ({
useDefaultModel: (type: string) => mockDefaultModels[type] ?? { data: null, isLoading: false },
}))
vi.mock('./atoms', () => ({
useResetModelProviderListExpanded: () => mockResetModelProviderListExpanded,
}))
vi.mock('./install-from-marketplace', () => ({
default: () => <div data-testid="install-from-marketplace" />,
}))

View File

@@ -3,13 +3,14 @@ import type {
} from './declarations'
import type { PluginDetail } from '@/app/components/plugins/types'
import { useDebounce } from 'ahooks'
import { useMemo } from 'react'
import { useEffect, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { IS_CLOUD_EDITION } from '@/config'
import { useSystemFeaturesQuery } from '@/context/global-public-context'
import { useProviderContext } from '@/context/provider-context'
import { useCheckInstalled } from '@/service/use-plugins'
import { cn } from '@/utils/classnames'
import { useResetModelProviderListExpanded } from './atoms'
import {
CustomConfigurationStatusEnum,
ModelTypeEnum,
@@ -34,6 +35,7 @@ const FixedModelProvider = ['langgenius/openai/openai', 'langgenius/anthropic/an
const ModelProviderPage = ({ searchText }: Props) => {
const debouncedSearchText = useDebounce(searchText, { wait: 500 })
const { t } = useTranslation()
const resetModelProviderListExpanded = useResetModelProviderListExpanded()
const { data: textGenerationDefaultModel, isLoading: isTextGenerationDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textGeneration)
const { data: embeddingsDefaultModel, isLoading: isEmbeddingsDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textEmbedding)
const { data: rerankDefaultModel, isLoading: isRerankDefaultModelLoading } = useDefaultModel(ModelTypeEnum.rerank)
@@ -127,6 +129,11 @@ const ModelProviderPage = ({ searchText }: Props) => {
return [filteredConfiguredProviders, filteredNotConfiguredProviders]
}, [configuredProviders, debouncedSearchText, notConfiguredProviders])
useEffect(() => {
resetModelProviderListExpanded()
return resetModelProviderListExpanded
}, [resetModelProviderListExpanded])
return (
<div className="relative -mt-2 pt-1">
<div className={cn('mb-2 flex items-center')}>

View File

@@ -2,6 +2,8 @@ import type { ReactNode } from 'react'
import type { ModelProvider } from '../declarations'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { createStore, Provider as JotaiProvider } from 'jotai'
import { useExpandModelProviderList } from '../atoms'
import { ConfigurationMethodEnum } from '../declarations'
import ProviderAddedCard from './index'
@@ -12,10 +14,6 @@ const mockQueryOptions = vi.fn(({ input, ...options }: { input: { params: { prov
queryFn: () => mockFetchModelProviderModels(input.params.provider),
...options,
}))
const mockEventEmitter = {
useSubscription: vi.fn(),
emit: vi.fn(),
}
vi.mock('@/service/client', () => ({
consoleQuery: {
@@ -33,12 +31,6 @@ vi.mock('@/context/app-context', () => ({
}),
}))
vi.mock('@/context/event-emitter', () => ({
useEventEmitterContextContext: () => ({
eventEmitter: mockEventEmitter,
}),
}))
// Mock internal components to simplify testing of the index file
vi.mock('./credential-panel', () => ({
default: () => <div data-testid="credential-panel" />,
@@ -74,10 +66,27 @@ const createTestQueryClient = () => new QueryClient({
const renderWithQueryClient = (node: ReactNode) => {
const queryClient = createTestQueryClient()
const store = createStore()
return render(
<QueryClientProvider client={queryClient}>
{node}
</QueryClientProvider>,
<JotaiProvider store={store}>
<QueryClientProvider client={queryClient}>
{node}
</QueryClientProvider>
</JotaiProvider>,
)
}
const ExternalExpandControls = () => {
const expandModelProviderList = useExpandModelProviderList()
return (
<>
<button type="button" data-testid="expand-other-provider" onClick={() => expandModelProviderList('langgenius/anthropic/anthropic')}>
expand other
</button>
<button type="button" data-testid="expand-current-provider" onClick={() => expandModelProviderList('langgenius/openai/openai')}>
expand current
</button>
</>
)
}
@@ -157,6 +166,27 @@ describe('ProviderAddedCard', () => {
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
})
it('should only react to external expansion for the matching provider', async () => {
mockFetchModelProviderModels.mockResolvedValue({ data: [{ model: 'gpt-4' }] })
renderWithQueryClient(
<>
<ProviderAddedCard provider={mockProvider} />
<ExternalExpandControls />
</>,
)
fireEvent.click(screen.getByTestId('expand-other-provider'))
await waitFor(() => {
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(0)
})
fireEvent.click(screen.getByTestId('expand-current-provider'))
await waitFor(() => {
expect(mockFetchModelProviderModels).toHaveBeenCalledWith(mockProvider.provider)
})
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
})
it('should render configure tip when provider is not in quota list and not configured', () => {
const providerWithoutQuota = {
...mockProvider,
@@ -166,56 +196,19 @@ describe('ProviderAddedCard', () => {
expect(screen.getByText('common.modelProvider.configureTip')).toBeInTheDocument()
})
it('should refresh model list on event subscription', async () => {
let capturedHandler: (v: { type: string, payload: string } | null) => void = () => { }
mockEventEmitter.useSubscription.mockImplementation((handler: (v: unknown) => void) => {
capturedHandler = handler as (v: { type: string, payload: string } | null) => void
})
mockFetchModelProviderModels.mockResolvedValue({ data: [] })
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
expect(capturedHandler).toBeDefined()
act(() => {
capturedHandler({
type: 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST',
payload: mockProvider.provider,
})
})
await waitFor(() => {
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
})
// Should ignore non-matching events
act(() => {
capturedHandler({ type: 'OTHER', payload: '' })
capturedHandler(null)
})
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
})
it('should render custom model actions for workspace managers', () => {
const customConfigProvider = {
...mockProvider,
configurate_methods: [ConfigurationMethodEnum.customizableModel],
} as unknown as ModelProvider
const queryClient = createTestQueryClient()
const { rerender } = render(
<QueryClientProvider client={queryClient}>
<ProviderAddedCard provider={customConfigProvider} />
</QueryClientProvider>,
)
const { unmount } = renderWithQueryClient(<ProviderAddedCard provider={customConfigProvider} />)
expect(screen.getByTestId('manage-custom-model')).toBeInTheDocument()
expect(screen.getByTestId('add-custom-model')).toBeInTheDocument()
unmount()
mockIsCurrentWorkspaceManager = false
rerender(
<QueryClientProvider client={queryClient}>
<ProviderAddedCard provider={customConfigProvider} />
</QueryClientProvider>,
)
renderWithQueryClient(<ProviderAddedCard provider={customConfigProvider} />)
expect(screen.queryByTestId('manage-custom-model')).not.toBeInTheDocument()
})
})

View File

@@ -4,10 +4,9 @@ import type {
} from '../declarations'
import type { ModelProviderQuotaGetPaid } from '../utils'
import type { PluginDetail } from '@/app/components/plugins/types'
import type { EventEmitterValue } from '@/context/event-emitter'
import { useQuery } from '@tanstack/react-query'
import { useCallback, useState } from 'react'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import {
AddCustomModel,
@@ -15,10 +14,10 @@ import {
} from '@/app/components/header/account-setting/model-provider-page/model-auth'
import { IS_CE_EDITION } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useProviderContext } from '@/context/provider-context'
import { consoleQuery } from '@/service/client'
import { cn } from '@/utils/classnames'
import { useModelProviderListExpanded, useSetModelProviderListExpanded } from '../atoms'
import { ConfigurationMethodEnum } from '../declarations'
import ModelBadge from '../model-badge'
import ProviderIcon from '../provider-icon'
@@ -30,22 +29,6 @@ import CredentialPanel from './credential-panel'
import ModelList from './model-list'
import ProviderCardActions from './provider-card-actions'
export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST'
const isModelProviderCustomModelListUpdateEvent = (
value: EventEmitterValue,
providerName: string,
): value is {
type: typeof UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST
payload: string
} => {
return typeof value === 'object'
&& value !== null
&& value.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST
&& typeof value.payload === 'string'
&& value.payload === providerName
}
type ProviderAddedCardProps = {
notConfigured?: boolean
provider: ModelProvider
@@ -57,10 +40,10 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
pluginDetail,
}) => {
const { t } = useTranslation()
const { eventEmitter } = useEventEmitterContextContext()
const { refreshModelProviders } = useProviderContext()
const [collapsed, setCollapsed] = useState(true)
const currentProviderName = provider.provider
const expanded = useModelProviderListExpanded(currentProviderName)
const setExpanded = useSetModelProviderListExpanded(currentProviderName)
const supportsPredefinedModel = provider.configurate_methods.includes(ConfigurationMethodEnum.predefinedModel)
const supportsCustomizableModel = provider.configurate_methods.includes(ConfigurationMethodEnum.customizableModel)
const systemConfig = provider.system_configuration
@@ -71,12 +54,12 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
refetch: refetchModelList,
} = useQuery(consoleQuery.modelProviders.models.queryOptions({
input: { params: { provider: currentProviderName } },
enabled: !collapsed,
enabled: expanded,
refetchOnWindowFocus: false,
select: response => response.data,
}))
const hasModelList = hasFetchedModelList && !!modelList.length
const showCollapsedSection = collapsed || !hasFetchedModelList
const showCollapsedSection = !expanded || !hasFetchedModelList
const { isCurrentWorkspaceManager } = useAppContext()
const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(currentProviderName as ModelProviderQuotaGetPaid) && !IS_CE_EDITION
const showCredential = supportsPredefinedModel && isCurrentWorkspaceManager
@@ -86,32 +69,23 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
if (targetProviderName !== currentProviderName)
return
if (collapsed)
setCollapsed(false)
if (!expanded)
setExpanded(true)
refetchModelList().catch(() => {})
}, [collapsed, currentProviderName, refetchModelList])
}, [currentProviderName, expanded, refetchModelList, setExpanded])
const handleOpenModelList = useCallback(() => {
if (loading)
return
if (collapsed) {
setCollapsed(false)
if (!expanded) {
setExpanded(true)
return
}
refetchModelList().catch(() => {})
}, [collapsed, loading, refetchModelList])
const handleModelProviderCustomModelListUpdate = useCallback((value: EventEmitterValue) => {
if (!isModelProviderCustomModelListUpdateEvent(value, currentProviderName))
return
refreshModelList(currentProviderName)
}, [currentProviderName, refreshModelList])
eventEmitter?.useSubscription(handleModelProviderCustomModelListUpdate)
}, [expanded, loading, refetchModelList, setExpanded])
return (
<div
@@ -198,7 +172,7 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
<ModelList
provider={provider}
models={modelList}
onCollapse={() => setCollapsed(true)}
onCollapse={() => setExpanded(false)}
onChange={refreshModelList}
/>
)

View File

@@ -4681,7 +4681,7 @@
"count": 1
},
"ts/no-explicit-any": {
"count": 3
"count": 2
}
},
"app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx": {