mirror of
https://github.com/langgenius/dify.git
synced 2026-03-06 15:45:14 +00:00
refactor(web): replace model provider emitter refresh with jotai state
- add atom-based provider expansion state with reset/prune helpers - remove event-emitter dependency from model provider refresh flow - invalidate exact provider model-list query key on refresh - reset expansion state on model provider page mount/unmount - update and extend tests for external expansion and query invalidation - update eslint suppressions to match current code
This commit is contained in:
@@ -0,0 +1,64 @@
|
||||
import { atom, useAtomValue, useSetAtom } from 'jotai'
|
||||
import { selectAtom } from 'jotai/utils'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
|
||||
const modelProviderListExpandedAtom = atom<Record<string, boolean>>({})
|
||||
|
||||
const setModelProviderListExpandedAtom = atom(
|
||||
null,
|
||||
(get, set, params: { providerName: string, expanded: boolean }) => {
|
||||
const { providerName, expanded } = params
|
||||
const current = get(modelProviderListExpandedAtom)
|
||||
|
||||
if (expanded) {
|
||||
if (current[providerName])
|
||||
return
|
||||
|
||||
set(modelProviderListExpandedAtom, {
|
||||
...current,
|
||||
[providerName]: true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if (!current[providerName])
|
||||
return
|
||||
|
||||
const next = { ...current }
|
||||
delete next[providerName]
|
||||
set(modelProviderListExpandedAtom, next)
|
||||
},
|
||||
)
|
||||
|
||||
const resetModelProviderListExpandedAtom = atom(
|
||||
null,
|
||||
(_get, set) => {
|
||||
set(modelProviderListExpandedAtom, {})
|
||||
},
|
||||
)
|
||||
|
||||
export function useModelProviderListExpanded(providerName: string) {
|
||||
const selectedAtom = useMemo(
|
||||
() => selectAtom(modelProviderListExpandedAtom, state => state[providerName] ?? false),
|
||||
[providerName],
|
||||
)
|
||||
return useAtomValue(selectedAtom)
|
||||
}
|
||||
|
||||
export function useSetModelProviderListExpanded(providerName: string) {
|
||||
const setExpanded = useSetAtom(setModelProviderListExpandedAtom)
|
||||
return useCallback((expanded: boolean) => {
|
||||
setExpanded({ providerName, expanded })
|
||||
}, [providerName, setExpanded])
|
||||
}
|
||||
|
||||
export function useExpandModelProviderList() {
|
||||
const setExpanded = useSetAtom(setModelProviderListExpandedAtom)
|
||||
return useCallback((providerName: string) => {
|
||||
setExpanded({ providerName, expanded: true })
|
||||
}, [setExpanded])
|
||||
}
|
||||
|
||||
export function useResetModelProviderListExpanded() {
|
||||
return useSetAtom(resetModelProviderListExpandedAtom)
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import type {
|
||||
} from './declarations'
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { fetchDefaultModal, fetchModelList, fetchModelProviderCredentials } from '@/service/common'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
@@ -37,7 +38,6 @@ import {
|
||||
useUpdateModelList,
|
||||
useUpdateModelProviders,
|
||||
} from './hooks'
|
||||
import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@tanstack/react-query', () => ({
|
||||
@@ -79,14 +79,6 @@ vi.mock('@/context/modal-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/event-emitter', () => ({
|
||||
useEventEmitterContextContext: vi.fn(() => ({
|
||||
eventEmitter: {
|
||||
emit: vi.fn(),
|
||||
},
|
||||
})),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
|
||||
useMarketplacePlugins: vi.fn(() => ({
|
||||
plugins: [],
|
||||
@@ -100,12 +92,16 @@ vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
|
||||
})),
|
||||
}))
|
||||
|
||||
vi.mock('./atoms', () => ({
|
||||
useExpandModelProviderList: vi.fn(() => vi.fn()),
|
||||
}))
|
||||
|
||||
const { useQuery, useQueryClient } = await import('@tanstack/react-query')
|
||||
const { getPayUrl } = await import('@/service/common')
|
||||
const { useProviderContext } = await import('@/context/provider-context')
|
||||
const { useModalContextSelector } = await import('@/context/modal-context')
|
||||
const { useEventEmitterContextContext } = await import('@/context/event-emitter')
|
||||
const { useMarketplacePlugins, useMarketplacePluginsByCollectionId } = await import('@/app/components/plugins/marketplace/hooks')
|
||||
const { useExpandModelProviderList } = await import('./atoms')
|
||||
|
||||
describe('hooks', () => {
|
||||
beforeEach(() => {
|
||||
@@ -1200,39 +1196,52 @@ describe('hooks', () => {
|
||||
|
||||
it('should refresh providers and model lists', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
const emit = vi.fn()
|
||||
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useEventEmitterContextContext as Mock).mockReturnValue({
|
||||
eventEmitter: { emit },
|
||||
})
|
||||
|
||||
const provider = createMockProvider()
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
const { result } = renderHook(() => useRefreshModel())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshModel(provider)
|
||||
})
|
||||
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'none',
|
||||
})
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-providers'] })
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textGeneration] })
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textEmbedding] })
|
||||
})
|
||||
|
||||
it('should emit event when refreshModelList is true and custom config is active', () => {
|
||||
it('should expand target provider list when refreshModelList is true and custom config is active', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
const emit = vi.fn()
|
||||
const expandModelProviderList = vi.fn()
|
||||
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useEventEmitterContextContext as Mock).mockReturnValue({
|
||||
eventEmitter: { emit },
|
||||
})
|
||||
; (useExpandModelProviderList as Mock).mockReturnValue(expandModelProviderList)
|
||||
|
||||
const provider = createMockProvider()
|
||||
const customFields: CustomConfigurationModelFixedFields = {
|
||||
__model_name: 'gpt-4',
|
||||
__model_type: ModelTypeEnum.textGeneration,
|
||||
}
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useRefreshModel())
|
||||
|
||||
@@ -1240,23 +1249,30 @@ describe('hooks', () => {
|
||||
result.current.handleRefreshModel(provider, customFields, true)
|
||||
})
|
||||
|
||||
expect(emit).toHaveBeenCalledWith({
|
||||
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
|
||||
payload: 'openai',
|
||||
expect(expandModelProviderList).toHaveBeenCalledWith('openai')
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'active',
|
||||
})
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textGeneration] })
|
||||
})
|
||||
|
||||
it('should not emit event when custom config is not active', () => {
|
||||
it('should not expand provider list when custom config is not active', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
const emit = vi.fn()
|
||||
const expandModelProviderList = vi.fn()
|
||||
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useEventEmitterContextContext as Mock).mockReturnValue({
|
||||
eventEmitter: { emit },
|
||||
})
|
||||
; (useExpandModelProviderList as Mock).mockReturnValue(expandModelProviderList)
|
||||
|
||||
const provider = { ...createMockProvider(), custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure } }
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useRefreshModel())
|
||||
|
||||
@@ -1264,16 +1280,43 @@ describe('hooks', () => {
|
||||
result.current.handleRefreshModel(provider, undefined, true)
|
||||
})
|
||||
|
||||
expect(emit).not.toHaveBeenCalled()
|
||||
expect(expandModelProviderList).not.toHaveBeenCalled()
|
||||
expect(invalidateQueries).not.toHaveBeenCalledWith({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'active',
|
||||
})
|
||||
})
|
||||
|
||||
it('should refetch active model provider list when custom refresh callback is absent', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
|
||||
const provider = createMockProvider()
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
const { result } = renderHook(() => useRefreshModel())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshModel(provider, undefined, true)
|
||||
})
|
||||
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'active',
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle provider with single model type', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useEventEmitterContextContext as Mock).mockReturnValue({
|
||||
eventEmitter: { emit: vi.fn() },
|
||||
})
|
||||
|
||||
const provider = {
|
||||
...createMockProvider(),
|
||||
|
||||
@@ -21,7 +21,6 @@ import {
|
||||
useMarketplacePluginsByCollectionId,
|
||||
} from '@/app/components/plugins/marketplace/hooks'
|
||||
import { PluginCategoryEnum } from '@/app/components/plugins/types'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { useModalContextSelector } from '@/context/modal-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
@@ -33,12 +32,12 @@ import {
|
||||
getPayUrl,
|
||||
} from '@/service/common'
|
||||
import { commonQueryKeys } from '@/service/use-common'
|
||||
import { useExpandModelProviderList } from './atoms'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
CustomConfigurationStatusEnum,
|
||||
ModelStatusEnum,
|
||||
} from './declarations'
|
||||
import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
|
||||
|
||||
type UseDefaultModelAndModelList = (
|
||||
defaultModel: DefaultModelResponse | undefined,
|
||||
@@ -323,7 +322,7 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
|
||||
}
|
||||
|
||||
export const useRefreshModel = () => {
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
const expandModelProviderList = useExpandModelProviderList()
|
||||
const queryClient = useQueryClient()
|
||||
const updateModelProviders = useUpdateModelProviders()
|
||||
const updateModelList = useUpdateModelList()
|
||||
@@ -332,8 +331,16 @@ export const useRefreshModel = () => {
|
||||
CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
|
||||
refreshModelList?: boolean,
|
||||
) => {
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: consoleQuery.modelProviders.models.key(),
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'none',
|
||||
})
|
||||
|
||||
@@ -344,15 +351,17 @@ export const useRefreshModel = () => {
|
||||
})
|
||||
|
||||
if (refreshModelList && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) {
|
||||
eventEmitter?.emit({
|
||||
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
|
||||
payload: provider.provider,
|
||||
} as any)
|
||||
expandModelProviderList(provider.provider)
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'active',
|
||||
})
|
||||
|
||||
if (CustomConfigurationModelFixedFields?.__model_type)
|
||||
updateModelList(CustomConfigurationModelFixedFields.__model_type)
|
||||
}
|
||||
}, [eventEmitter, queryClient, updateModelList, updateModelProviders])
|
||||
}, [expandModelProviderList, queryClient, updateModelList, updateModelProviders])
|
||||
|
||||
return {
|
||||
handleRefreshModel,
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
import ModelProviderPage from './index'
|
||||
|
||||
let mockEnableMarketplace = true
|
||||
const mockResetModelProviderListExpanded = vi.fn()
|
||||
|
||||
const mockQuotaConfig = {
|
||||
quota_type: CurrentSystemQuotaTypeEnum.free,
|
||||
@@ -67,6 +68,10 @@ vi.mock('./hooks', () => ({
|
||||
useDefaultModel: (type: string) => mockDefaultModels[type] ?? { data: null, isLoading: false },
|
||||
}))
|
||||
|
||||
vi.mock('./atoms', () => ({
|
||||
useResetModelProviderListExpanded: () => mockResetModelProviderListExpanded,
|
||||
}))
|
||||
|
||||
vi.mock('./install-from-marketplace', () => ({
|
||||
default: () => <div data-testid="install-from-marketplace" />,
|
||||
}))
|
||||
|
||||
@@ -3,13 +3,14 @@ import type {
|
||||
} from './declarations'
|
||||
import type { PluginDetail } from '@/app/components/plugins/types'
|
||||
import { useDebounce } from 'ahooks'
|
||||
import { useMemo } from 'react'
|
||||
import { useEffect, useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { IS_CLOUD_EDITION } from '@/config'
|
||||
import { useSystemFeaturesQuery } from '@/context/global-public-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useCheckInstalled } from '@/service/use-plugins'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useResetModelProviderListExpanded } from './atoms'
|
||||
import {
|
||||
CustomConfigurationStatusEnum,
|
||||
ModelTypeEnum,
|
||||
@@ -34,6 +35,7 @@ const FixedModelProvider = ['langgenius/openai/openai', 'langgenius/anthropic/an
|
||||
const ModelProviderPage = ({ searchText }: Props) => {
|
||||
const debouncedSearchText = useDebounce(searchText, { wait: 500 })
|
||||
const { t } = useTranslation()
|
||||
const resetModelProviderListExpanded = useResetModelProviderListExpanded()
|
||||
const { data: textGenerationDefaultModel, isLoading: isTextGenerationDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textGeneration)
|
||||
const { data: embeddingsDefaultModel, isLoading: isEmbeddingsDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textEmbedding)
|
||||
const { data: rerankDefaultModel, isLoading: isRerankDefaultModelLoading } = useDefaultModel(ModelTypeEnum.rerank)
|
||||
@@ -127,6 +129,11 @@ const ModelProviderPage = ({ searchText }: Props) => {
|
||||
return [filteredConfiguredProviders, filteredNotConfiguredProviders]
|
||||
}, [configuredProviders, debouncedSearchText, notConfiguredProviders])
|
||||
|
||||
useEffect(() => {
|
||||
resetModelProviderListExpanded()
|
||||
return resetModelProviderListExpanded
|
||||
}, [resetModelProviderListExpanded])
|
||||
|
||||
return (
|
||||
<div className="relative -mt-2 pt-1">
|
||||
<div className={cn('mb-2 flex items-center')}>
|
||||
|
||||
@@ -2,6 +2,8 @@ import type { ReactNode } from 'react'
|
||||
import type { ModelProvider } from '../declarations'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { createStore, Provider as JotaiProvider } from 'jotai'
|
||||
import { useExpandModelProviderList } from '../atoms'
|
||||
import { ConfigurationMethodEnum } from '../declarations'
|
||||
import ProviderAddedCard from './index'
|
||||
|
||||
@@ -12,10 +14,6 @@ const mockQueryOptions = vi.fn(({ input, ...options }: { input: { params: { prov
|
||||
queryFn: () => mockFetchModelProviderModels(input.params.provider),
|
||||
...options,
|
||||
}))
|
||||
const mockEventEmitter = {
|
||||
useSubscription: vi.fn(),
|
||||
emit: vi.fn(),
|
||||
}
|
||||
|
||||
vi.mock('@/service/client', () => ({
|
||||
consoleQuery: {
|
||||
@@ -33,12 +31,6 @@ vi.mock('@/context/app-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/event-emitter', () => ({
|
||||
useEventEmitterContextContext: () => ({
|
||||
eventEmitter: mockEventEmitter,
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock internal components to simplify testing of the index file
|
||||
vi.mock('./credential-panel', () => ({
|
||||
default: () => <div data-testid="credential-panel" />,
|
||||
@@ -74,10 +66,27 @@ const createTestQueryClient = () => new QueryClient({
|
||||
|
||||
const renderWithQueryClient = (node: ReactNode) => {
|
||||
const queryClient = createTestQueryClient()
|
||||
const store = createStore()
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{node}
|
||||
</QueryClientProvider>,
|
||||
<JotaiProvider store={store}>
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{node}
|
||||
</QueryClientProvider>
|
||||
</JotaiProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
const ExternalExpandControls = () => {
|
||||
const expandModelProviderList = useExpandModelProviderList()
|
||||
return (
|
||||
<>
|
||||
<button type="button" data-testid="expand-other-provider" onClick={() => expandModelProviderList('langgenius/anthropic/anthropic')}>
|
||||
expand other
|
||||
</button>
|
||||
<button type="button" data-testid="expand-current-provider" onClick={() => expandModelProviderList('langgenius/openai/openai')}>
|
||||
expand current
|
||||
</button>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -157,6 +166,27 @@ describe('ProviderAddedCard', () => {
|
||||
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should only react to external expansion for the matching provider', async () => {
|
||||
mockFetchModelProviderModels.mockResolvedValue({ data: [{ model: 'gpt-4' }] })
|
||||
renderWithQueryClient(
|
||||
<>
|
||||
<ProviderAddedCard provider={mockProvider} />
|
||||
<ExternalExpandControls />
|
||||
</>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('expand-other-provider'))
|
||||
await waitFor(() => {
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(0)
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByTestId('expand-current-provider'))
|
||||
await waitFor(() => {
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledWith(mockProvider.provider)
|
||||
})
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should render configure tip when provider is not in quota list and not configured', () => {
|
||||
const providerWithoutQuota = {
|
||||
...mockProvider,
|
||||
@@ -166,56 +196,19 @@ describe('ProviderAddedCard', () => {
|
||||
expect(screen.getByText('common.modelProvider.configureTip')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should refresh model list on event subscription', async () => {
|
||||
let capturedHandler: (v: { type: string, payload: string } | null) => void = () => { }
|
||||
mockEventEmitter.useSubscription.mockImplementation((handler: (v: unknown) => void) => {
|
||||
capturedHandler = handler as (v: { type: string, payload: string } | null) => void
|
||||
})
|
||||
mockFetchModelProviderModels.mockResolvedValue({ data: [] })
|
||||
|
||||
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
|
||||
|
||||
expect(capturedHandler).toBeDefined()
|
||||
act(() => {
|
||||
capturedHandler({
|
||||
type: 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST',
|
||||
payload: mockProvider.provider,
|
||||
})
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
// Should ignore non-matching events
|
||||
act(() => {
|
||||
capturedHandler({ type: 'OTHER', payload: '' })
|
||||
capturedHandler(null)
|
||||
})
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should render custom model actions for workspace managers', () => {
|
||||
const customConfigProvider = {
|
||||
...mockProvider,
|
||||
configurate_methods: [ConfigurationMethodEnum.customizableModel],
|
||||
} as unknown as ModelProvider
|
||||
const queryClient = createTestQueryClient()
|
||||
const { rerender } = render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<ProviderAddedCard provider={customConfigProvider} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
const { unmount } = renderWithQueryClient(<ProviderAddedCard provider={customConfigProvider} />)
|
||||
|
||||
expect(screen.getByTestId('manage-custom-model')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('add-custom-model')).toBeInTheDocument()
|
||||
|
||||
unmount()
|
||||
mockIsCurrentWorkspaceManager = false
|
||||
rerender(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<ProviderAddedCard provider={customConfigProvider} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderWithQueryClient(<ProviderAddedCard provider={customConfigProvider} />)
|
||||
expect(screen.queryByTestId('manage-custom-model')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,10 +4,9 @@ import type {
|
||||
} from '../declarations'
|
||||
import type { ModelProviderQuotaGetPaid } from '../utils'
|
||||
import type { PluginDetail } from '@/app/components/plugins/types'
|
||||
import type { EventEmitterValue } from '@/context/event-emitter'
|
||||
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
AddCustomModel,
|
||||
@@ -15,10 +14,10 @@ import {
|
||||
} from '@/app/components/header/account-setting/model-provider-page/model-auth'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useModelProviderListExpanded, useSetModelProviderListExpanded } from '../atoms'
|
||||
import { ConfigurationMethodEnum } from '../declarations'
|
||||
import ModelBadge from '../model-badge'
|
||||
import ProviderIcon from '../provider-icon'
|
||||
@@ -30,22 +29,6 @@ import CredentialPanel from './credential-panel'
|
||||
import ModelList from './model-list'
|
||||
import ProviderCardActions from './provider-card-actions'
|
||||
|
||||
export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST'
|
||||
|
||||
const isModelProviderCustomModelListUpdateEvent = (
|
||||
value: EventEmitterValue,
|
||||
providerName: string,
|
||||
): value is {
|
||||
type: typeof UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST
|
||||
payload: string
|
||||
} => {
|
||||
return typeof value === 'object'
|
||||
&& value !== null
|
||||
&& value.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST
|
||||
&& typeof value.payload === 'string'
|
||||
&& value.payload === providerName
|
||||
}
|
||||
|
||||
type ProviderAddedCardProps = {
|
||||
notConfigured?: boolean
|
||||
provider: ModelProvider
|
||||
@@ -57,10 +40,10 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
|
||||
pluginDetail,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
const { refreshModelProviders } = useProviderContext()
|
||||
const [collapsed, setCollapsed] = useState(true)
|
||||
const currentProviderName = provider.provider
|
||||
const expanded = useModelProviderListExpanded(currentProviderName)
|
||||
const setExpanded = useSetModelProviderListExpanded(currentProviderName)
|
||||
const supportsPredefinedModel = provider.configurate_methods.includes(ConfigurationMethodEnum.predefinedModel)
|
||||
const supportsCustomizableModel = provider.configurate_methods.includes(ConfigurationMethodEnum.customizableModel)
|
||||
const systemConfig = provider.system_configuration
|
||||
@@ -71,12 +54,12 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
|
||||
refetch: refetchModelList,
|
||||
} = useQuery(consoleQuery.modelProviders.models.queryOptions({
|
||||
input: { params: { provider: currentProviderName } },
|
||||
enabled: !collapsed,
|
||||
enabled: expanded,
|
||||
refetchOnWindowFocus: false,
|
||||
select: response => response.data,
|
||||
}))
|
||||
const hasModelList = hasFetchedModelList && !!modelList.length
|
||||
const showCollapsedSection = collapsed || !hasFetchedModelList
|
||||
const showCollapsedSection = !expanded || !hasFetchedModelList
|
||||
const { isCurrentWorkspaceManager } = useAppContext()
|
||||
const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(currentProviderName as ModelProviderQuotaGetPaid) && !IS_CE_EDITION
|
||||
const showCredential = supportsPredefinedModel && isCurrentWorkspaceManager
|
||||
@@ -86,32 +69,23 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
|
||||
if (targetProviderName !== currentProviderName)
|
||||
return
|
||||
|
||||
if (collapsed)
|
||||
setCollapsed(false)
|
||||
if (!expanded)
|
||||
setExpanded(true)
|
||||
|
||||
refetchModelList().catch(() => {})
|
||||
}, [collapsed, currentProviderName, refetchModelList])
|
||||
}, [currentProviderName, expanded, refetchModelList, setExpanded])
|
||||
|
||||
const handleOpenModelList = useCallback(() => {
|
||||
if (loading)
|
||||
return
|
||||
|
||||
if (collapsed) {
|
||||
setCollapsed(false)
|
||||
if (!expanded) {
|
||||
setExpanded(true)
|
||||
return
|
||||
}
|
||||
|
||||
refetchModelList().catch(() => {})
|
||||
}, [collapsed, loading, refetchModelList])
|
||||
|
||||
const handleModelProviderCustomModelListUpdate = useCallback((value: EventEmitterValue) => {
|
||||
if (!isModelProviderCustomModelListUpdateEvent(value, currentProviderName))
|
||||
return
|
||||
|
||||
refreshModelList(currentProviderName)
|
||||
}, [currentProviderName, refreshModelList])
|
||||
|
||||
eventEmitter?.useSubscription(handleModelProviderCustomModelListUpdate)
|
||||
}, [expanded, loading, refetchModelList, setExpanded])
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -198,7 +172,7 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
|
||||
<ModelList
|
||||
provider={provider}
|
||||
models={modelList}
|
||||
onCollapse={() => setCollapsed(true)}
|
||||
onCollapse={() => setExpanded(false)}
|
||||
onChange={refreshModelList}
|
||||
/>
|
||||
)
|
||||
|
||||
@@ -4681,7 +4681,7 @@
|
||||
"count": 1
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 3
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/header/account-setting/model-provider-page/install-from-marketplace.tsx": {
|
||||
|
||||
Reference in New Issue
Block a user