mirror of
https://github.com/langgenius/dify.git
synced 2026-03-16 04:37:04 +00:00
Compare commits
17 Commits
main
...
fix/use-ba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92872d820c | ||
|
|
b5c5cb72f0 | ||
|
|
1863cdbff0 | ||
|
|
cc769b9dd8 | ||
|
|
6f8ff490a8 | ||
|
|
a9ff96113a | ||
|
|
0d549af0a9 | ||
|
|
2a04f59984 | ||
|
|
a9f507cd2e | ||
|
|
96fd3bde1b | ||
|
|
f9ded4960a | ||
|
|
c1e8735f7a | ||
|
|
8f5d65a7fc | ||
|
|
8f8477a19f | ||
|
|
214f24bcd3 | ||
|
|
1677a52900 | ||
|
|
e054d5d8b3 |
@@ -187,53 +187,13 @@ const Template = useMemo(() => {
|
||||
|
||||
**When**: Component directly handles API calls, data transformation, or complex async operations.
|
||||
|
||||
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.
|
||||
|
||||
```typescript
|
||||
// ❌ Before: API logic in component
|
||||
const MCPServiceCard = () => {
|
||||
const [basicAppConfig, setBasicAppConfig] = useState({})
|
||||
|
||||
useEffect(() => {
|
||||
if (isBasicApp && appId) {
|
||||
(async () => {
|
||||
const res = await fetchAppDetail({ url: '/apps', id: appId })
|
||||
setBasicAppConfig(res?.model_config || {})
|
||||
})()
|
||||
}
|
||||
}, [appId, isBasicApp])
|
||||
|
||||
// More API-related logic...
|
||||
}
|
||||
|
||||
// ✅ After: Extract to data hook using React Query
|
||||
// use-app-config.ts
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { get } from '@/service/base'
|
||||
|
||||
const NAME_SPACE = 'appConfig'
|
||||
|
||||
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
|
||||
return useQuery({
|
||||
enabled: isBasicApp && !!appId,
|
||||
queryKey: [NAME_SPACE, 'detail', appId],
|
||||
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
|
||||
select: data => data?.model_config || {},
|
||||
})
|
||||
}
|
||||
|
||||
// Component becomes cleaner
|
||||
const MCPServiceCard = () => {
|
||||
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
|
||||
// UI only
|
||||
}
|
||||
```
|
||||
|
||||
**React Query Best Practices in Dify**:
|
||||
- Define `NAME_SPACE` for query key organization
|
||||
- Use `enabled` option for conditional fetching
|
||||
- Use `select` for data transformation
|
||||
- Export invalidation hooks: `useInvalidXxx`
|
||||
**Dify Convention**:
|
||||
- This skill is for component decomposition, not query/mutation design.
|
||||
- When refactoring data fetching, follow `web/AGENTS.md`.
|
||||
- Use `orpc-contract-first` for contracts, query shape, data-fetching wrappers, and query/mutation call-site patterns.
|
||||
- Use `web/docs/query-mutation.md` for Dify-specific conditional query, invalidation, and mutation error-handling rules.
|
||||
- Do not introduce deprecated `useInvalid` / `useReset`.
|
||||
- Do not add thin passthrough `useQuery` wrappers during refactoring; only extract a custom hook when it truly orchestrates multiple queries/mutations or shared derived state.
|
||||
|
||||
**Dify Examples**:
|
||||
- `web/service/use-workflow.ts`
|
||||
|
||||
@@ -155,48 +155,15 @@ const Configuration: FC = () => {
|
||||
|
||||
## Common Hook Patterns in Dify
|
||||
|
||||
### 1. Data Fetching Hook (React Query)
|
||||
### 1. Data Fetching / Mutation Hooks
|
||||
|
||||
```typescript
|
||||
// Pattern: Use @tanstack/react-query for data fetching
|
||||
import { useQuery, useQueryClient } from '@tanstack/react-query'
|
||||
import { get } from '@/service/base'
|
||||
import { useInvalid } from '@/service/use-base'
|
||||
When hook extraction touches query or mutation code, do not use this reference as the source of truth for data-layer patterns.
|
||||
|
||||
const NAME_SPACE = 'appConfig'
|
||||
|
||||
// Query keys for cache management
|
||||
export const appConfigQueryKeys = {
|
||||
detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
|
||||
}
|
||||
|
||||
// Main data hook
|
||||
export const useAppConfig = (appId: string) => {
|
||||
return useQuery({
|
||||
enabled: !!appId,
|
||||
queryKey: appConfigQueryKeys.detail(appId),
|
||||
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
|
||||
select: data => data?.model_config || null,
|
||||
})
|
||||
}
|
||||
|
||||
// Invalidation hook for refreshing data
|
||||
export const useInvalidAppConfig = () => {
|
||||
return useInvalid([NAME_SPACE])
|
||||
}
|
||||
|
||||
// Usage in component
|
||||
const Component = () => {
|
||||
const { data: config, isLoading, error, refetch } = useAppConfig(appId)
|
||||
const invalidAppConfig = useInvalidAppConfig()
|
||||
|
||||
const handleRefresh = () => {
|
||||
invalidAppConfig() // Invalidates cache and triggers refetch
|
||||
}
|
||||
|
||||
return <div>...</div>
|
||||
}
|
||||
```
|
||||
- Follow `web/AGENTS.md` first.
|
||||
- Use `orpc-contract-first` for contracts, query shape, data-fetching wrappers, and query/mutation call-site patterns.
|
||||
- Use `web/docs/query-mutation.md` for conditional query, invalidation, and mutation error-handling rules.
|
||||
- Do not introduce deprecated `useInvalid` / `useReset`.
|
||||
- Do not extract thin passthrough `useQuery` hooks; only extract orchestration hooks.
|
||||
|
||||
### 2. Form State Hook
|
||||
|
||||
|
||||
44
.agents/skills/frontend-query-mutation/SKILL.md
Normal file
44
.agents/skills/frontend-query-mutation/SKILL.md
Normal file
@@ -0,0 +1,44 @@
|
||||
---
|
||||
name: frontend-query-mutation
|
||||
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions() directly or extract a helper or use-* hook, handling conditional queries, cache invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
|
||||
---
|
||||
|
||||
# Frontend Query & Mutation
|
||||
|
||||
## Intent
|
||||
|
||||
- Keep contract as the single source of truth in `web/contract/*`.
|
||||
- Prefer contract-shaped `queryOptions()` and `mutationOptions()`.
|
||||
- Keep invalidation and mutation flow knowledge in the service layer.
|
||||
- Keep abstractions minimal to preserve TypeScript inference.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Identify the change surface.
|
||||
- Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape.
|
||||
- Read `references/runtime-rules.md` for conditional queries, invalidation, error handling, and legacy migrations.
|
||||
- Read both references when a task spans contract shape and runtime behavior.
|
||||
2. Implement the smallest abstraction that fits the task.
|
||||
- Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site.
|
||||
- Extract a small shared query helper only when multiple call sites share the same extra options.
|
||||
- Create `web/service/use-{domain}.ts` only for orchestration or shared domain behavior.
|
||||
3. Preserve Dify conventions.
|
||||
- Keep contract inputs in `{ params, query?, body? }` shape.
|
||||
- Bind invalidation in the service-layer mutation definition.
|
||||
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required.
|
||||
|
||||
## Files Commonly Touched
|
||||
|
||||
- `web/contract/console/*.ts`
|
||||
- `web/contract/marketplace.ts`
|
||||
- `web/contract/router.ts`
|
||||
- `web/service/client.ts`
|
||||
- `web/service/use-*.ts`
|
||||
- component and hook call sites using `consoleQuery` or `marketplaceQuery`
|
||||
|
||||
## References
|
||||
|
||||
- Use `references/contract-patterns.md` for contract shape, router registration, query and mutation helpers, and anti-patterns that degrade inference.
|
||||
- Use `references/runtime-rules.md` for conditional queries, invalidation, `mutate` versus `mutateAsync`, and legacy migration rules.
|
||||
|
||||
Treat this skill as the single query and mutation entry point for Dify frontend work. Keep detailed rules in the reference files instead of duplicating them in project docs.
|
||||
@@ -0,0 +1,4 @@
|
||||
interface:
|
||||
display_name: "Frontend Query & Mutation"
|
||||
short_description: "Dify TanStack Query and oRPC patterns"
|
||||
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, conditional queries, invalidation, or legacy query/mutation migrations."
|
||||
@@ -0,0 +1,98 @@
|
||||
# Contract Patterns
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- Intent
|
||||
- Minimal structure
|
||||
- Core workflow
|
||||
- Query usage decision rule
|
||||
- Mutation usage decision rule
|
||||
- Anti-patterns
|
||||
- Contract rules
|
||||
- Type export
|
||||
|
||||
## Intent
|
||||
|
||||
- Keep contract as the single source of truth in `web/contract/*`.
|
||||
- Default query usage to call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
|
||||
- Keep abstractions minimal and preserve TypeScript inference.
|
||||
|
||||
## Minimal Structure
|
||||
|
||||
```text
|
||||
web/contract/
|
||||
├── base.ts
|
||||
├── router.ts
|
||||
├── marketplace.ts
|
||||
└── console/
|
||||
├── billing.ts
|
||||
└── ...other domains
|
||||
web/service/client.ts
|
||||
```
|
||||
|
||||
## Core Workflow
|
||||
|
||||
1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`.
|
||||
- Use `base.route({...}).output(type<...>())` as the baseline.
|
||||
- Add `.input(type<...>())` only when the request has `params`, `query`, or `body`.
|
||||
- For `GET` without input, omit `.input(...)`; do not use `.input(type<unknown>())`.
|
||||
2. Register contract in `web/contract/router.ts`.
|
||||
- Import directly from domain files and nest by API prefix.
|
||||
3. Consume from UI call sites via oRPC query utilities.
|
||||
|
||||
```typescript
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
|
||||
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
|
||||
staleTime: 5 * 60 * 1000,
|
||||
throwOnError: true,
|
||||
select: invoice => invoice.url,
|
||||
}))
|
||||
```
|
||||
|
||||
## Query Usage Decision Rule
|
||||
|
||||
1. Default to direct `*.queryOptions(...)` usage at the call site.
|
||||
2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook.
|
||||
3. Create `web/service/use-{domain}.ts` only for orchestration.
|
||||
- Combine multiple queries or mutations.
|
||||
- Share domain-level derived state or invalidation helpers.
|
||||
|
||||
```typescript
|
||||
const invoicesBaseQueryOptions = () =>
|
||||
consoleQuery.billing.invoices.queryOptions({ retry: false })
|
||||
|
||||
const invoiceQuery = useQuery({
|
||||
...invoicesBaseQueryOptions(),
|
||||
throwOnError: true,
|
||||
})
|
||||
```
|
||||
|
||||
## Mutation Usage Decision Rule
|
||||
|
||||
1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
|
||||
2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic.
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
|
||||
- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case.
|
||||
- Do not create thin `use-*` passthrough hooks for a single endpoint.
|
||||
- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection.
|
||||
|
||||
## Contract Rules
|
||||
|
||||
- Input structure: always use `{ params, query?, body? }`.
|
||||
- No-input `GET`: omit `.input(...)`; do not use `.input(type<unknown>())`.
|
||||
- Path params: use `{paramName}` in the path and match it in the `params` object.
|
||||
- Router nesting: group by API prefix, for example `/billing/*` becomes `billing: {}`.
|
||||
- No barrel files: import directly from specific files.
|
||||
- Types: import from `@/types/` and use the `type<T>()` helper.
|
||||
- Mutations: prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults, filtering, and devtools.
|
||||
|
||||
## Type Export
|
||||
|
||||
```typescript
|
||||
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
|
||||
```
|
||||
@@ -0,0 +1,133 @@
|
||||
# Runtime Rules
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- Conditional queries
|
||||
- Cache invalidation
|
||||
- Key API guide
|
||||
- `mutate` vs `mutateAsync`
|
||||
- Legacy migration
|
||||
|
||||
## Conditional Queries
|
||||
|
||||
Prefer contract-shaped `queryOptions(...)`.
|
||||
When required input is missing, prefer `input: skipToken` instead of placeholder params or non-null assertions.
|
||||
Use `enabled` only for extra business gating after the input itself is already valid.
|
||||
|
||||
```typescript
|
||||
import { skipToken, useQuery } from '@tanstack/react-query'
|
||||
|
||||
// Disable the query by skipping input construction.
|
||||
function useAccessMode(appId: string | undefined) {
|
||||
return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
|
||||
input: appId
|
||||
? { params: { appId } }
|
||||
: skipToken,
|
||||
}))
|
||||
}
|
||||
|
||||
// Avoid runtime-only guards that bypass type checking.
|
||||
function useBadAccessMode(appId: string | undefined) {
|
||||
return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
|
||||
input: { params: { appId: appId! } },
|
||||
enabled: !!appId,
|
||||
}))
|
||||
}
|
||||
```
|
||||
|
||||
## Cache Invalidation
|
||||
|
||||
Bind invalidation in the service-layer mutation definition.
|
||||
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate.
|
||||
|
||||
Use:
|
||||
|
||||
- `.key()` for namespace or prefix invalidation
|
||||
- `.queryKey(...)` only for exact cache reads or writes such as `getQueryData` and `setQueryData`
|
||||
- `queryClient.invalidateQueries(...)` in mutation `onSuccess`
|
||||
|
||||
Do not use deprecated `useInvalid` from `use-base.ts`.
|
||||
|
||||
```typescript
|
||||
// Service layer owns cache invalidation.
|
||||
export const useUpdateAccessMode = () => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
return useMutation(consoleQuery.accessControl.updateAccessMode.mutationOptions({
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
|
||||
})
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// Component only adds UI behavior.
|
||||
updateAccessMode({ appId, mode }, {
|
||||
onSuccess: () => Toast.notify({ type: 'success', message: '...' }),
|
||||
})
|
||||
|
||||
// Avoid putting invalidation knowledge in the component.
|
||||
mutate({ appId, mode }, {
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
|
||||
})
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
## Key API Guide
|
||||
|
||||
- `.key(...)`
|
||||
- Use for partial matching operations.
|
||||
- Prefer it for invalidation, refetch, and cancel patterns.
|
||||
- Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
|
||||
- `.queryKey(...)`
|
||||
- Use for a specific query's full key.
|
||||
- Prefer it for exact cache addressing and direct reads or writes.
|
||||
- `.mutationKey(...)`
|
||||
- Use for a specific mutation's full key.
|
||||
- Prefer it for mutation defaults registration, mutation-status filtering, and devtools grouping.
|
||||
|
||||
## `mutate` vs `mutateAsync`
|
||||
|
||||
Prefer `mutate` by default.
|
||||
Use `mutateAsync` only when Promise semantics are truly required, such as parallel mutations or sequential steps with result dependencies.
|
||||
|
||||
Rules:
|
||||
|
||||
- Event handlers should usually call `mutate(...)` with `onSuccess` or `onError`.
|
||||
- Every `await mutateAsync(...)` must be wrapped in `try/catch`.
|
||||
- Do not use `mutateAsync` when callbacks already express the flow clearly.
|
||||
|
||||
```typescript
|
||||
// Default case.
|
||||
mutation.mutate(data, {
|
||||
onSuccess: result => router.push(result.url),
|
||||
})
|
||||
|
||||
// Promise semantics are required.
|
||||
try {
|
||||
const order = await createOrder.mutateAsync(orderData)
|
||||
await confirmPayment.mutateAsync({ orderId: order.id, token })
|
||||
router.push(`/orders/${order.id}`)
|
||||
}
|
||||
catch (error) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: error instanceof Error ? error.message : 'Unknown error',
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
## Legacy Migration
|
||||
|
||||
When touching old code, migrate it toward these rules:
|
||||
|
||||
| Old pattern | New pattern |
|
||||
|---|---|
|
||||
| `useInvalid(key)` in service layer | `queryClient.invalidateQueries(...)` inside mutation `onSuccess` |
|
||||
| component-triggered invalidation after mutation | move invalidation into the service-layer mutation definition |
|
||||
| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` |
|
||||
| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` |
|
||||
@@ -1,103 +0,0 @@
|
||||
---
|
||||
name: orpc-contract-first
|
||||
description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Trigger when creating or updating contracts in web/contract, wiring router composition, integrating TanStack Query with typed contracts, migrating legacy service calls to oRPC, or deciding whether to call queryOptions directly vs extracting a helper or use-* hook in web/service.
|
||||
---
|
||||
|
||||
# oRPC Contract-First Development
|
||||
|
||||
## Intent
|
||||
|
||||
- Keep contract as single source of truth in `web/contract/*`.
|
||||
- Default query usage: call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
|
||||
- Keep abstractions minimal and preserve TypeScript inference.
|
||||
|
||||
## Minimal Structure
|
||||
|
||||
```text
|
||||
web/contract/
|
||||
├── base.ts
|
||||
├── router.ts
|
||||
├── marketplace.ts
|
||||
└── console/
|
||||
├── billing.ts
|
||||
└── ...other domains
|
||||
web/service/client.ts
|
||||
```
|
||||
|
||||
## Core Workflow
|
||||
|
||||
1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`
|
||||
- Use `base.route({...}).output(type<...>())` as baseline.
|
||||
- Add `.input(type<...>())` only when request has `params/query/body`.
|
||||
- For `GET` without input, omit `.input(...)` (do not use `.input(type<unknown>())`).
|
||||
2. Register contract in `web/contract/router.ts`
|
||||
- Import directly from domain files and nest by API prefix.
|
||||
3. Consume from UI call sites via oRPC query utils.
|
||||
|
||||
```typescript
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
|
||||
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
|
||||
staleTime: 5 * 60 * 1000,
|
||||
throwOnError: true,
|
||||
select: invoice => invoice.url,
|
||||
}))
|
||||
```
|
||||
|
||||
## Query Usage Decision Rule
|
||||
|
||||
1. Default: call site directly uses `*.queryOptions(...)`.
|
||||
2. If 3+ call sites share the same extra options (for example `retry: false`), extract a small queryOptions helper, not a `use-*` passthrough hook.
|
||||
3. Create `web/service/use-{domain}.ts` only for orchestration:
|
||||
- Combine multiple queries/mutations.
|
||||
- Share domain-level derived state or invalidation helpers.
|
||||
|
||||
```typescript
|
||||
const invoicesBaseQueryOptions = () =>
|
||||
consoleQuery.billing.invoices.queryOptions({ retry: false })
|
||||
|
||||
const invoiceQuery = useQuery({
|
||||
...invoicesBaseQueryOptions(),
|
||||
throwOnError: true,
|
||||
})
|
||||
```
|
||||
|
||||
## Mutation Usage Decision Rule
|
||||
|
||||
1. Default: call mutation helpers from `consoleQuery` / `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
|
||||
2. If mutation flow is heavily custom, use oRPC clients as `mutationFn` (for example `consoleClient.xxx` / `marketplaceClient.xxx`), instead of generic handwritten non-oRPC mutation logic.
|
||||
|
||||
## Key API Guide (`.key` vs `.queryKey` vs `.mutationKey`)
|
||||
|
||||
- `.key(...)`:
|
||||
- Use for partial matching operations (recommended for invalidation/refetch/cancel patterns).
|
||||
- Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
|
||||
- `.queryKey(...)`:
|
||||
- Use for a specific query's full key (exact query identity / direct cache addressing).
|
||||
- `.mutationKey(...)`:
|
||||
- Use for a specific mutation's full key.
|
||||
- Typical use cases: mutation defaults registration, mutation-status filtering (`useIsMutating`, `queryClient.isMutating`), or explicit devtools grouping.
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
|
||||
- Do not split local `queryKey/queryFn` when oRPC `queryOptions` already exists and fits the use case.
|
||||
- Do not create thin `use-*` passthrough hooks for a single endpoint.
|
||||
- Reason: these patterns can degrade inference (`data` may become `unknown`, especially around `throwOnError`/`select`) and add unnecessary indirection.
|
||||
|
||||
## Contract Rules
|
||||
|
||||
- **Input structure**: Always use `{ params, query?, body? }` format
|
||||
- **No-input GET**: Omit `.input(...)`; do not use `.input(type<unknown>())`
|
||||
- **Path params**: Use `{paramName}` in path, match in `params` object
|
||||
- **Router nesting**: Group by API prefix (e.g., `/billing/*` -> `billing: {}`)
|
||||
- **No barrel files**: Import directly from specific files
|
||||
- **Types**: Import from `@/types/`, use `type<T>()` helper
|
||||
- **Mutations**: Prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults/filtering/devtools
|
||||
|
||||
## Type Export
|
||||
|
||||
```typescript
|
||||
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
|
||||
```
|
||||
1
.claude/skills/frontend-query-mutation
Symbolic link
1
.claude/skills/frontend-query-mutation
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/frontend-query-mutation
|
||||
@@ -1 +0,0 @@
|
||||
../../.agents/skills/orpc-contract-first
|
||||
2
.github/workflows/api-tests.yml
vendored
2
.github/workflows/api-tests.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
2
.github/workflows/autofix.yml
vendored
2
.github/workflows/autofix.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
- uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
|
||||
2
.github/workflows/build-push.yml
vendored
2
.github/workflows/build-push.yml
vendored
@@ -113,7 +113,7 @@ jobs:
|
||||
context: "web"
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
|
||||
with:
|
||||
path: /tmp/digests
|
||||
pattern: digests-${{ matrix.context }}-*
|
||||
|
||||
4
.github/workflows/db-migration-test.yml
vendored
4
.github/workflows/db-migration-test.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -69,7 +69,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
2
.github/workflows/main-ci.yml
vendored
2
.github/workflows/main-ci.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
migration-changed: ${{ steps.changes.outputs.migration }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
|
||||
- uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
|
||||
id: changes
|
||||
with:
|
||||
filters: |
|
||||
|
||||
2
.github/workflows/pyrefly-diff.yml
vendored
2
.github/workflows/pyrefly-diff.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
||||
2
.github/workflows/style.yml
vendored
2
.github/workflows/style.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@@ -120,7 +120,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.detect_changes.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@cd77b50d2b0808657f8e6774085c8bf54484351c # v1.0.72
|
||||
uses: anthropics/claude-code-action@26ec041249acb0a944c0a47b6c0c13f05dbc5b44 # v1.0.70
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/vdb-tests.yml
vendored
2
.github/workflows/vdb-tests.yml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
|
||||
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
2
.github/workflows/web-tests.yml
vendored
2
.github/workflows/web-tests.yml
vendored
@@ -77,7 +77,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Download blob reports
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
|
||||
with:
|
||||
path: web/.vitest-reports
|
||||
pattern: blob-report-*
|
||||
|
||||
@@ -103,6 +103,7 @@ ignore_imports =
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.llm.node -> core.helper.code_executor
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
dify_graph.nodes.llm.node -> core.model_manager
|
||||
|
||||
@@ -97,7 +97,7 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
|
||||
|
||||
# Download nltk data
|
||||
RUN mkdir -p /usr/local/share/nltk_data \
|
||||
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
|
||||
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \
|
||||
&& chmod -R 755 /usr/local/share/nltk_data
|
||||
|
||||
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
|
||||
|
||||
@@ -1,45 +1,16 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from flask import request
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from controllers.console.error import UnauthorizedAndForceLogout
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Console bootstrap APIs exempt from license check.
|
||||
# Defined at module level to avoid per-request tuple construction.
|
||||
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
|
||||
# - setup: install/setup status check (AppInitializer)
|
||||
# - init: init password validation for fresh install (InitPasswordPopup)
|
||||
# - login: auto-login after setup completion (InstallForm)
|
||||
# - features: billing/plan features (ProviderContextProvider)
|
||||
# - account/profile: login check + user profile (AppContextProvider, useIsLogin)
|
||||
# - workspaces/current: workspace + model providers (AppContextProvider)
|
||||
# - version: version check (AppContextProvider)
|
||||
# - activate/check: invitation link validation (signin page)
|
||||
# Without these exemptions, the signin page triggers location.reload()
|
||||
# on unauthorized_and_force_logout, causing an infinite loop.
|
||||
_CONSOLE_EXEMPT_PREFIXES = (
|
||||
"/console/api/system-features",
|
||||
"/console/api/setup",
|
||||
"/console/api/init",
|
||||
"/console/api/login",
|
||||
"/console/api/features",
|
||||
"/console/api/account/profile",
|
||||
"/console/api/workspaces/current",
|
||||
"/console/api/version",
|
||||
"/console/api/activate/check",
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Application Factory Function
|
||||
@@ -60,39 +31,6 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
init_request_context()
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
# Enterprise license validation for API endpoints (both console and webapp)
|
||||
# When license expires, block all API access except bootstrap endpoints needed
|
||||
# for the frontend to load the license expiration page without infinite reloads.
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
is_console_api = request.path.startswith("/console/api/")
|
||||
is_webapp_api = request.path.startswith("/api/")
|
||||
|
||||
if is_console_api or is_webapp_api:
|
||||
if is_console_api:
|
||||
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
|
||||
else: # webapp API
|
||||
is_exempt = request.path.startswith("/api/system-features")
|
||||
|
||||
if not is_exempt:
|
||||
try:
|
||||
# Check license status (cached — see EnterpriseService for TTL details)
|
||||
license_status = EnterpriseService.get_cached_license_status()
|
||||
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
|
||||
raise UnauthorizedAndForceLogout(
|
||||
f"Enterprise license is {license_status}. Please contact your administrator."
|
||||
)
|
||||
if license_status is None:
|
||||
raise UnauthorizedAndForceLogout(
|
||||
"Unable to verify enterprise license. Please contact your administrator."
|
||||
)
|
||||
except UnauthorizedAndForceLogout:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to check enterprise license status")
|
||||
raise UnauthorizedAndForceLogout(
|
||||
"Unable to verify enterprise license. Please contact your administrator."
|
||||
)
|
||||
|
||||
# add after request hook for injecting trace headers from OpenTelemetry span context
|
||||
# Only adds headers when OTEL is enabled and has valid context
|
||||
@dify_app.after_request
|
||||
|
||||
@@ -3,7 +3,7 @@ import time
|
||||
from collections.abc import Callable
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast, overload
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
@@ -44,22 +44,10 @@ class FetchUserArg(BaseModel):
|
||||
required: bool = False
|
||||
|
||||
|
||||
@overload
|
||||
def validate_app_token(view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def validate_app_token(
|
||||
view: None = None, *, fetch_user_arg: FetchUserArg | None = None
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def validate_app_token(
|
||||
view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token("app")
|
||||
|
||||
app_model = db.session.query(App).where(App.id == api_token.app_id).first()
|
||||
@@ -225,20 +213,10 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
@overload
|
||||
def validate_dataset_token(view: Callable[Concatenate[T, P], R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def validate_dataset_token(view: None = None) -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def validate_dataset_token(
|
||||
view: Callable[Concatenate[T, P], R] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[Concatenate[T, P], R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[T, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
|
||||
# get url path dataset_id from positional args or kwargs
|
||||
@@ -309,7 +287,7 @@ def validate_dataset_token(
|
||||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
else:
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
return view_func(api_token.tenant_id, *args, **kwargs) # type: ignore[arg-type]
|
||||
return view(api_token.tenant_id, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
@@ -305,7 +305,9 @@ class ProviderManager:
|
||||
available_models = provider_configurations.get_models(model_type=model_type, only_active=True)
|
||||
|
||||
if available_models:
|
||||
available_model = available_models[0]
|
||||
available_model = next(
|
||||
(model for model in available_models if model.model == "gpt-4"), available_models[0]
|
||||
)
|
||||
|
||||
default_model = TenantDefaultModel(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -135,8 +135,8 @@ class PGVectoRS(BaseVector):
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
result = None
|
||||
with Session(self._client) as session:
|
||||
select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>:key = :value")
|
||||
result = session.execute(select_statement, {"key": key, "value": value}).fetchall()
|
||||
select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ")
|
||||
result = session.execute(select_statement).fetchall()
|
||||
if result:
|
||||
return [item[0] for item in result]
|
||||
else:
|
||||
@@ -172,9 +172,9 @@ class PGVectoRS(BaseVector):
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with Session(self._client) as session:
|
||||
select_statement = sql_text(
|
||||
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = :doc_id limit 1"
|
||||
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
|
||||
)
|
||||
result = session.execute(select_statement, {"doc_id": id}).fetchall()
|
||||
result = session.execute(select_statement).fetchall()
|
||||
return len(result) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
|
||||
@@ -154,8 +154,10 @@ class RelytVector(BaseVector):
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
result = None
|
||||
with Session(self.client) as session:
|
||||
select_statement = sql_text(f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>:key = :value""")
|
||||
result = session.execute(select_statement, {"key": key, "value": value}).fetchall()
|
||||
select_statement = sql_text(
|
||||
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'{key}' = '{value}'; """
|
||||
)
|
||||
result = session.execute(select_statement).fetchall()
|
||||
if result:
|
||||
return [item[0] for item in result]
|
||||
else:
|
||||
@@ -199,10 +201,11 @@ class RelytVector(BaseVector):
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
with Session(self.client) as session:
|
||||
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
|
||||
select_statement = sql_text(
|
||||
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = ANY(:doc_ids)"""
|
||||
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
|
||||
)
|
||||
result = session.execute(select_statement, {"doc_ids": ids}).fetchall()
|
||||
result = session.execute(select_statement).fetchall()
|
||||
if result:
|
||||
ids = [item[0] for item in result]
|
||||
self.delete_by_uuids(ids)
|
||||
@@ -215,9 +218,9 @@ class RelytVector(BaseVector):
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with Session(self.client) as session:
|
||||
select_statement = sql_text(
|
||||
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = :doc_id limit 1"""
|
||||
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = '{id}' limit 1; """
|
||||
)
|
||||
result = session.execute(select_statement, {"doc_id": id}).fetchall()
|
||||
result = session.execute(select_statement).fetchall()
|
||||
return len(result) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
|
||||
@@ -294,7 +294,7 @@ class BaseIndexProcessor(ABC):
|
||||
logging.warning("Error downloading image from %s: %s", image_url, str(e))
|
||||
return None
|
||||
except Exception:
|
||||
logging.warning("Unexpected error downloading image from %s", image_url, exc_info=True)
|
||||
logging.exception("Unexpected error downloading image from %s", image_url)
|
||||
return None
|
||||
|
||||
def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None:
|
||||
|
||||
@@ -45,7 +45,6 @@ from dify_graph.nodes.document_extractor import UnstructuredApiConfig
|
||||
from dify_graph.nodes.http_request import build_http_request_config
|
||||
from dify_graph.nodes.llm.entities import LLMNodeData
|
||||
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from dify_graph.nodes.llm.protocols import TemplateRenderer
|
||||
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from dify_graph.nodes.template_transform.template_renderer import (
|
||||
@@ -229,16 +228,6 @@ class DefaultWorkflowCodeExecutor:
|
||||
return isinstance(error, CodeExecutionError)
|
||||
|
||||
|
||||
class DefaultLLMTemplateRenderer(TemplateRenderer):
|
||||
def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2,
|
||||
code=template,
|
||||
inputs=inputs,
|
||||
)
|
||||
return str(result.get("result", ""))
|
||||
|
||||
|
||||
@final
|
||||
class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
@@ -265,7 +254,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
|
||||
self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer()
|
||||
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
self._http_request_http_client = ssrf_proxy
|
||||
self._http_request_tool_file_manager_factory = ToolFileManager
|
||||
@@ -403,8 +391,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
model_instance=model_instance,
|
||||
),
|
||||
}
|
||||
if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
|
||||
node_init_kwargs["template_renderer"] = self._llm_template_renderer
|
||||
if include_http_client:
|
||||
node_init_kwargs["http_client"] = self._http_request_http_client
|
||||
return node_init_kwargs
|
||||
|
||||
@@ -1,53 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
from typing import cast
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from dify_graph.file import FileType, file_manager
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.model_runtime.entities import (
|
||||
from dify_graph.model_runtime.entities import PromptMessageRole
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageContentUnionTypes,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
|
||||
from dify_graph.model_runtime.memory import PromptMessageMemory
|
||||
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables import ArrayFileSegment, FileSegment
|
||||
from dify_graph.variables.segments import ArrayAnySegment, NoneSegment
|
||||
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
|
||||
|
||||
from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig
|
||||
from .exc import (
|
||||
InvalidVariableTypeError,
|
||||
MemoryRolePrefixRequiredError,
|
||||
NoPromptFoundError,
|
||||
TemplateTypeNotSupportError,
|
||||
)
|
||||
from .protocols import TemplateRenderer
|
||||
from .exc import InvalidVariableTypeError
|
||||
|
||||
|
||||
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
|
||||
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
|
||||
model_instance.model_name,
|
||||
dict(model_instance.credentials),
|
||||
model_instance.credentials,
|
||||
)
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model schema not found for {model_instance.model_name}")
|
||||
return model_schema
|
||||
|
||||
|
||||
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence[File]:
|
||||
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
|
||||
variable = variable_pool.get(selector)
|
||||
if variable is None:
|
||||
return []
|
||||
@@ -108,366 +89,3 @@ def fetch_memory_text(
|
||||
human_prefix=human_prefix,
|
||||
ai_prefix=ai_prefix,
|
||||
)
|
||||
|
||||
|
||||
def fetch_prompt_messages(
|
||||
*,
|
||||
sys_query: str | None = None,
|
||||
sys_files: Sequence[File],
|
||||
context: str | None = None,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
model_instance: ModelInstance,
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
|
||||
stop: Sequence[str] | None = None,
|
||||
memory_config: MemoryConfig | None = None,
|
||||
vision_enabled: bool = False,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL,
|
||||
variable_pool: VariablePool,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
context_files: list[File] | None = None,
|
||||
template_renderer: TemplateRenderer | None = None,
|
||||
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
model_schema = fetch_model_schema(model_instance=model_instance)
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
prompt_messages.extend(
|
||||
handle_list_messages(
|
||||
messages=prompt_template,
|
||||
context=context,
|
||||
jinja2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
vision_detail_config=vision_detail,
|
||||
template_renderer=template_renderer,
|
||||
)
|
||||
)
|
||||
|
||||
prompt_messages.extend(
|
||||
handle_memory_chat_mode(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
)
|
||||
|
||||
if sys_query:
|
||||
prompt_messages.extend(
|
||||
handle_list_messages(
|
||||
messages=[
|
||||
LLMNodeChatModelMessage(
|
||||
text=sys_query,
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
],
|
||||
context="",
|
||||
jinja2_variables=[],
|
||||
variable_pool=variable_pool,
|
||||
vision_detail_config=vision_detail,
|
||||
template_renderer=template_renderer,
|
||||
)
|
||||
)
|
||||
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||
prompt_messages.extend(
|
||||
handle_completion_template(
|
||||
template=prompt_template,
|
||||
context=context,
|
||||
jinja2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
template_renderer=template_renderer,
|
||||
)
|
||||
)
|
||||
|
||||
memory_text = handle_memory_completion_mode(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
prompt_content = prompt_messages[0].content
|
||||
if isinstance(prompt_content, str):
|
||||
prompt_content = str(prompt_content)
|
||||
if "#histories#" in prompt_content:
|
||||
prompt_content = prompt_content.replace("#histories#", memory_text)
|
||||
else:
|
||||
prompt_content = memory_text + "\n" + prompt_content
|
||||
prompt_messages[0].content = prompt_content
|
||||
elif isinstance(prompt_content, list):
|
||||
for content_item in prompt_content:
|
||||
if isinstance(content_item, TextPromptMessageContent):
|
||||
if "#histories#" in content_item.data:
|
||||
content_item.data = content_item.data.replace("#histories#", memory_text)
|
||||
else:
|
||||
content_item.data = memory_text + "\n" + content_item.data
|
||||
else:
|
||||
raise ValueError("Invalid prompt content type")
|
||||
|
||||
if sys_query:
|
||||
if isinstance(prompt_content, str):
|
||||
prompt_messages[0].content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
|
||||
elif isinstance(prompt_content, list):
|
||||
for content_item in prompt_content:
|
||||
if isinstance(content_item, TextPromptMessageContent):
|
||||
content_item.data = sys_query + "\n" + content_item.data
|
||||
else:
|
||||
raise ValueError("Invalid prompt content type")
|
||||
else:
|
||||
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
|
||||
|
||||
_append_file_prompts(
|
||||
prompt_messages=prompt_messages,
|
||||
files=sys_files,
|
||||
vision_enabled=vision_enabled,
|
||||
vision_detail=vision_detail,
|
||||
)
|
||||
_append_file_prompts(
|
||||
prompt_messages=prompt_messages,
|
||||
files=context_files or [],
|
||||
vision_enabled=vision_enabled,
|
||||
vision_detail=vision_detail,
|
||||
)
|
||||
|
||||
filtered_prompt_messages: list[PromptMessage] = []
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message_content: list[PromptMessageContentUnionTypes] = []
|
||||
for content_item in prompt_message.content:
|
||||
if not model_schema.features:
|
||||
if content_item.type == PromptMessageContentType.TEXT:
|
||||
prompt_message_content.append(content_item)
|
||||
continue
|
||||
|
||||
if (
|
||||
(
|
||||
content_item.type == PromptMessageContentType.IMAGE
|
||||
and ModelFeature.VISION not in model_schema.features
|
||||
)
|
||||
or (
|
||||
content_item.type == PromptMessageContentType.DOCUMENT
|
||||
and ModelFeature.DOCUMENT not in model_schema.features
|
||||
)
|
||||
or (
|
||||
content_item.type == PromptMessageContentType.VIDEO
|
||||
and ModelFeature.VIDEO not in model_schema.features
|
||||
)
|
||||
or (
|
||||
content_item.type == PromptMessageContentType.AUDIO
|
||||
and ModelFeature.AUDIO not in model_schema.features
|
||||
)
|
||||
):
|
||||
continue
|
||||
prompt_message_content.append(content_item)
|
||||
if prompt_message_content:
|
||||
prompt_message.content = prompt_message_content
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
elif not prompt_message.is_empty():
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
|
||||
if len(filtered_prompt_messages) == 0:
|
||||
raise NoPromptFoundError(
|
||||
"No prompt found in the LLM configuration. Please ensure a prompt is properly configured before proceeding."
|
||||
)
|
||||
|
||||
return filtered_prompt_messages, stop
|
||||
|
||||
|
||||
def handle_list_messages(
|
||||
*,
|
||||
messages: Sequence[LLMNodeChatModelMessage],
|
||||
context: str | None,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||
template_renderer: TemplateRenderer | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
for message in messages:
|
||||
if message.edition_type == "jinja2":
|
||||
result_text = render_jinja2_message(
|
||||
template=message.jinja2_text or "",
|
||||
jinja2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
template_renderer=template_renderer,
|
||||
)
|
||||
prompt_messages.append(
|
||||
combine_message_content_with_role(
|
||||
contents=[TextPromptMessageContent(data=result_text)],
|
||||
role=message.role,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
template = message.text.replace("{#context#}", context) if context else message.text
|
||||
segment_group = variable_pool.convert_template(template)
|
||||
file_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for segment in segment_group.value:
|
||||
if isinstance(segment, ArrayFileSegment):
|
||||
for file in segment.value:
|
||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||
file_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
|
||||
)
|
||||
elif isinstance(segment, FileSegment):
|
||||
file = segment.value
|
||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||
file_contents.append(
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
|
||||
)
|
||||
|
||||
if segment_group.text:
|
||||
prompt_messages.append(
|
||||
combine_message_content_with_role(
|
||||
contents=[TextPromptMessageContent(data=segment_group.text)],
|
||||
role=message.role,
|
||||
)
|
||||
)
|
||||
if file_contents:
|
||||
prompt_messages.append(combine_message_content_with_role(contents=file_contents, role=message.role))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
|
||||
def render_jinja2_message(
|
||||
*,
|
||||
template: str,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
template_renderer: TemplateRenderer | None = None,
|
||||
) -> str:
|
||||
if not template:
|
||||
return ""
|
||||
if template_renderer is None:
|
||||
raise ValueError("template_renderer is required for jinja2 prompt rendering")
|
||||
|
||||
jinja2_inputs: dict[str, Any] = {}
|
||||
for jinja2_variable in jinja2_variables:
|
||||
variable = variable_pool.get(jinja2_variable.value_selector)
|
||||
jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
|
||||
return template_renderer.render_jinja2(template=template, inputs=jinja2_inputs)
|
||||
|
||||
|
||||
def handle_completion_template(
|
||||
*,
|
||||
template: LLMNodeCompletionModelPromptTemplate,
|
||||
context: str | None,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
template_renderer: TemplateRenderer | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
if template.edition_type == "jinja2":
|
||||
result_text = render_jinja2_message(
|
||||
template=template.jinja2_text or "",
|
||||
jinja2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
template_renderer=template_renderer,
|
||||
)
|
||||
else:
|
||||
template_text = template.text.replace("{#context#}", context) if context else template.text
|
||||
result_text = variable_pool.convert_template(template_text).text
|
||||
return [
|
||||
combine_message_content_with_role(
|
||||
contents=[TextPromptMessageContent(data=result_text)],
|
||||
role=PromptMessageRole.USER,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def combine_message_content_with_role(
|
||||
*,
|
||||
contents: str | list[PromptMessageContentUnionTypes] | None = None,
|
||||
role: PromptMessageRole,
|
||||
) -> PromptMessage:
|
||||
match role:
|
||||
case PromptMessageRole.USER:
|
||||
return UserPromptMessage(content=contents)
|
||||
case PromptMessageRole.ASSISTANT:
|
||||
return AssistantPromptMessage(content=contents)
|
||||
case PromptMessageRole.SYSTEM:
|
||||
return SystemPromptMessage(content=contents)
|
||||
case _:
|
||||
raise NotImplementedError(f"Role {role} is not supported")
|
||||
|
||||
|
||||
def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int:
|
||||
rest_tokens = 2000
|
||||
runtime_model_schema = fetch_model_schema(model_instance=model_instance)
|
||||
runtime_model_parameters = model_instance.parameters
|
||||
|
||||
model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
if model_context_tokens:
|
||||
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in runtime_model_schema.parameter_rules:
|
||||
if parameter_rule.name == "max_tokens" or (
|
||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||
):
|
||||
max_tokens = (
|
||||
runtime_model_parameters.get(parameter_rule.name)
|
||||
or runtime_model_parameters.get(str(parameter_rule.use_template))
|
||||
or 0
|
||||
)
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
|
||||
return rest_tokens
|
||||
|
||||
|
||||
def handle_memory_chat_mode(
|
||||
*,
|
||||
memory: PromptMessageMemory | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
model_instance: ModelInstance,
|
||||
) -> Sequence[PromptMessage]:
|
||||
if not memory or not memory_config:
|
||||
return []
|
||||
rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance)
|
||||
return memory.get_history_prompt_messages(
|
||||
max_token_limit=rest_tokens,
|
||||
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||
)
|
||||
|
||||
|
||||
def handle_memory_completion_mode(
|
||||
*,
|
||||
memory: PromptMessageMemory | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
model_instance: ModelInstance,
|
||||
) -> str:
|
||||
if not memory or not memory_config:
|
||||
return ""
|
||||
|
||||
rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance)
|
||||
if not memory_config.role_prefix:
|
||||
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
|
||||
|
||||
return fetch_memory_text(
|
||||
memory=memory,
|
||||
max_token_limit=rest_tokens,
|
||||
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||
human_prefix=memory_config.role_prefix.user,
|
||||
ai_prefix=memory_config.role_prefix.assistant,
|
||||
)
|
||||
|
||||
|
||||
def _append_file_prompts(
|
||||
*,
|
||||
prompt_messages: list[PromptMessage],
|
||||
files: Sequence[File],
|
||||
vision_enabled: bool,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL,
|
||||
) -> None:
|
||||
if not vision_enabled or not files:
|
||||
return
|
||||
|
||||
file_prompts = [file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in files]
|
||||
if (
|
||||
prompt_messages
|
||||
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||
and isinstance(prompt_messages[-1].content, list)
|
||||
):
|
||||
existing_contents = prompt_messages[-1].content
|
||||
assert isinstance(existing_contents, list)
|
||||
prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_manager import ModelInstance
|
||||
@@ -27,10 +28,11 @@ from dify_graph.enums import (
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from dify_graph.file import File, FileTransferMethod, FileType, file_manager
|
||||
from dify_graph.model_runtime.entities import (
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.llm_entities import (
|
||||
@@ -41,7 +43,14 @@ from dify_graph.model_runtime.entities.llm_entities import (
|
||||
LLMStructuredOutput,
|
||||
LLMUsage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageContentUnionTypes,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from dify_graph.model_runtime.memory import PromptMessageMemory
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.node_events import (
|
||||
@@ -55,12 +64,13 @@ from dify_graph.node_events import (
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables import (
|
||||
ArrayFileSegment,
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
StringSegment,
|
||||
@@ -79,6 +89,9 @@ from .exc import (
|
||||
InvalidContextStructureError,
|
||||
InvalidVariableTypeError,
|
||||
LLMNodeError,
|
||||
MemoryRolePrefixRequiredError,
|
||||
NoPromptFoundError,
|
||||
TemplateTypeNotSupportError,
|
||||
VariableNotFoundError,
|
||||
)
|
||||
from .file_saver import FileSaverImpl, LLMFileSaver
|
||||
@@ -105,7 +118,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
_model_factory: ModelFactory
|
||||
_model_instance: ModelInstance
|
||||
_memory: PromptMessageMemory | None
|
||||
_template_renderer: TemplateRenderer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -118,7 +130,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
model_factory: ModelFactory,
|
||||
model_instance: ModelInstance,
|
||||
http_client: HttpClientProtocol,
|
||||
template_renderer: TemplateRenderer,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
@@ -135,7 +146,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
self._model_factory = model_factory
|
||||
self._model_instance = model_instance
|
||||
self._memory = memory
|
||||
self._template_renderer = template_renderer
|
||||
|
||||
if llm_file_saver is None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
@@ -230,7 +240,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||
context_files=context_files,
|
||||
template_renderer=self._template_renderer,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
@@ -764,24 +773,182 @@ class LLMNode(Node[LLMNodeData]):
|
||||
variable_pool: VariablePool,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
context_files: list[File] | None = None,
|
||||
template_renderer: TemplateRenderer | None = None,
|
||||
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
||||
return llm_utils.fetch_prompt_messages(
|
||||
sys_query=sys_query,
|
||||
sys_files=sys_files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_instance=model_instance,
|
||||
prompt_template=prompt_template,
|
||||
stop=stop,
|
||||
memory_config=memory_config,
|
||||
vision_enabled=vision_enabled,
|
||||
vision_detail=vision_detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=jinja2_variables,
|
||||
context_files=context_files,
|
||||
template_renderer=template_renderer,
|
||||
)
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
# For chat model
|
||||
prompt_messages.extend(
|
||||
LLMNode.handle_list_messages(
|
||||
messages=prompt_template,
|
||||
context=context,
|
||||
jinja2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
vision_detail_config=vision_detail,
|
||||
)
|
||||
)
|
||||
|
||||
# Get memory messages for chat mode
|
||||
memory_messages = _handle_memory_chat_mode(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
# Extend prompt_messages with memory messages
|
||||
prompt_messages.extend(memory_messages)
|
||||
|
||||
# Add current query to the prompt messages
|
||||
if sys_query:
|
||||
message = LLMNodeChatModelMessage(
|
||||
text=sys_query,
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
prompt_messages.extend(
|
||||
LLMNode.handle_list_messages(
|
||||
messages=[message],
|
||||
context="",
|
||||
jinja2_variables=[],
|
||||
variable_pool=variable_pool,
|
||||
vision_detail_config=vision_detail,
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||
# For completion model
|
||||
prompt_messages.extend(
|
||||
_handle_completion_template(
|
||||
template=prompt_template,
|
||||
context=context,
|
||||
jinja2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
)
|
||||
|
||||
# Get memory text for completion model
|
||||
memory_text = _handle_memory_completion_mode(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
# Insert histories into the prompt
|
||||
prompt_content = prompt_messages[0].content
|
||||
# For issue #11247 - Check if prompt content is a string or a list
|
||||
if isinstance(prompt_content, str):
|
||||
prompt_content = str(prompt_content)
|
||||
if "#histories#" in prompt_content:
|
||||
prompt_content = prompt_content.replace("#histories#", memory_text)
|
||||
else:
|
||||
prompt_content = memory_text + "\n" + prompt_content
|
||||
prompt_messages[0].content = prompt_content
|
||||
elif isinstance(prompt_content, list):
|
||||
for content_item in prompt_content:
|
||||
if isinstance(content_item, TextPromptMessageContent):
|
||||
if "#histories#" in content_item.data:
|
||||
content_item.data = content_item.data.replace("#histories#", memory_text)
|
||||
else:
|
||||
content_item.data = memory_text + "\n" + content_item.data
|
||||
else:
|
||||
raise ValueError("Invalid prompt content type")
|
||||
|
||||
# Add current query to the prompt message
|
||||
if sys_query:
|
||||
if isinstance(prompt_content, str):
|
||||
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
|
||||
prompt_messages[0].content = prompt_content
|
||||
elif isinstance(prompt_content, list):
|
||||
for content_item in prompt_content:
|
||||
if isinstance(content_item, TextPromptMessageContent):
|
||||
content_item.data = sys_query + "\n" + content_item.data
|
||||
else:
|
||||
raise ValueError("Invalid prompt content type")
|
||||
else:
|
||||
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
|
||||
|
||||
# The sys_files will be deprecated later
|
||||
if vision_enabled and sys_files:
|
||||
file_prompts = []
|
||||
for file in sys_files:
|
||||
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
||||
file_prompts.append(file_prompt)
|
||||
# If last prompt is a user prompt, add files into its contents,
|
||||
# otherwise append a new user prompt
|
||||
if (
|
||||
len(prompt_messages) > 0
|
||||
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||
and isinstance(prompt_messages[-1].content, list)
|
||||
):
|
||||
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
# The context_files
|
||||
if vision_enabled and context_files:
|
||||
file_prompts = []
|
||||
for file in context_files:
|
||||
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
||||
file_prompts.append(file_prompt)
|
||||
# If last prompt is a user prompt, add files into its contents,
|
||||
# otherwise append a new user prompt
|
||||
if (
|
||||
len(prompt_messages) > 0
|
||||
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||
and isinstance(prompt_messages[-1].content, list)
|
||||
):
|
||||
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
# Remove empty messages and filter unsupported content
|
||||
filtered_prompt_messages = []
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message_content: list[PromptMessageContentUnionTypes] = []
|
||||
for content_item in prompt_message.content:
|
||||
# Skip content if features are not defined
|
||||
if not model_schema.features:
|
||||
if content_item.type != PromptMessageContentType.TEXT:
|
||||
continue
|
||||
prompt_message_content.append(content_item)
|
||||
continue
|
||||
|
||||
# Skip content if corresponding feature is not supported
|
||||
if (
|
||||
(
|
||||
content_item.type == PromptMessageContentType.IMAGE
|
||||
and ModelFeature.VISION not in model_schema.features
|
||||
)
|
||||
or (
|
||||
content_item.type == PromptMessageContentType.DOCUMENT
|
||||
and ModelFeature.DOCUMENT not in model_schema.features
|
||||
)
|
||||
or (
|
||||
content_item.type == PromptMessageContentType.VIDEO
|
||||
and ModelFeature.VIDEO not in model_schema.features
|
||||
)
|
||||
or (
|
||||
content_item.type == PromptMessageContentType.AUDIO
|
||||
and ModelFeature.AUDIO not in model_schema.features
|
||||
)
|
||||
):
|
||||
continue
|
||||
prompt_message_content.append(content_item)
|
||||
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
|
||||
prompt_message.content = prompt_message_content[0].data
|
||||
else:
|
||||
prompt_message.content = prompt_message_content
|
||||
if prompt_message.is_empty():
|
||||
continue
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
|
||||
if len(filtered_prompt_messages) == 0:
|
||||
raise NoPromptFoundError(
|
||||
"No prompt found in the LLM configuration. "
|
||||
"Please ensure a prompt is properly configured before proceeding."
|
||||
)
|
||||
|
||||
return filtered_prompt_messages, stop
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
@@ -881,16 +1048,59 @@ class LLMNode(Node[LLMNodeData]):
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||
template_renderer: TemplateRenderer | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
return llm_utils.handle_list_messages(
|
||||
messages=messages,
|
||||
context=context,
|
||||
jinja2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
vision_detail_config=vision_detail_config,
|
||||
template_renderer=template_renderer,
|
||||
)
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
for message in messages:
|
||||
if message.edition_type == "jinja2":
|
||||
result_text = _render_jinja2_message(
|
||||
template=message.jinja2_text or "",
|
||||
jinja2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
prompt_message = _combine_message_content_with_role(
|
||||
contents=[TextPromptMessageContent(data=result_text)], role=message.role
|
||||
)
|
||||
prompt_messages.append(prompt_message)
|
||||
else:
|
||||
# Get segment group from basic message
|
||||
if context:
|
||||
template = message.text.replace("{#context#}", context)
|
||||
else:
|
||||
template = message.text
|
||||
segment_group = variable_pool.convert_template(template)
|
||||
|
||||
# Process segments for images
|
||||
file_contents = []
|
||||
for segment in segment_group.value:
|
||||
if isinstance(segment, ArrayFileSegment):
|
||||
for file in segment.value:
|
||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||
file_content = file_manager.to_prompt_message_content(
|
||||
file, image_detail_config=vision_detail_config
|
||||
)
|
||||
file_contents.append(file_content)
|
||||
elif isinstance(segment, FileSegment):
|
||||
file = segment.value
|
||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
|
||||
file_content = file_manager.to_prompt_message_content(
|
||||
file, image_detail_config=vision_detail_config
|
||||
)
|
||||
file_contents.append(file_content)
|
||||
|
||||
# Create message with text from all segments
|
||||
plain_text = segment_group.text
|
||||
if plain_text:
|
||||
prompt_message = _combine_message_content_with_role(
|
||||
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
|
||||
)
|
||||
prompt_messages.append(prompt_message)
|
||||
|
||||
if file_contents:
|
||||
# Create message with image contents
|
||||
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
|
||||
prompt_messages.append(prompt_message)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
@staticmethod
|
||||
def handle_blocking_result(
|
||||
@@ -1029,3 +1239,152 @@ class LLMNode(Node[LLMNodeData]):
|
||||
@property
|
||||
def model_instance(self) -> ModelInstance:
|
||||
return self._model_instance
|
||||
|
||||
|
||||
def _combine_message_content_with_role(
|
||||
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
|
||||
):
|
||||
match role:
|
||||
case PromptMessageRole.USER:
|
||||
return UserPromptMessage(content=contents)
|
||||
case PromptMessageRole.ASSISTANT:
|
||||
return AssistantPromptMessage(content=contents)
|
||||
case PromptMessageRole.SYSTEM:
|
||||
return SystemPromptMessage(content=contents)
|
||||
case _:
|
||||
raise NotImplementedError(f"Role {role} is not supported")
|
||||
|
||||
|
||||
def _render_jinja2_message(
|
||||
*,
|
||||
template: str,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
):
|
||||
if not template:
|
||||
return ""
|
||||
|
||||
jinja2_inputs = {}
|
||||
for jinja2_variable in jinja2_variables:
|
||||
variable = variable_pool.get(jinja2_variable.value_selector)
|
||||
jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
|
||||
code_execute_resp = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2,
|
||||
code=template,
|
||||
inputs=jinja2_inputs,
|
||||
)
|
||||
result_text = code_execute_resp["result"]
|
||||
return result_text
|
||||
|
||||
|
||||
def _calculate_rest_token(
|
||||
*,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_instance: ModelInstance,
|
||||
) -> int:
|
||||
rest_tokens = 2000
|
||||
runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
||||
runtime_model_parameters = model_instance.parameters
|
||||
|
||||
model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
if model_context_tokens:
|
||||
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in runtime_model_schema.parameter_rules:
|
||||
if parameter_rule.name == "max_tokens" or (
|
||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||
):
|
||||
max_tokens = (
|
||||
runtime_model_parameters.get(parameter_rule.name)
|
||||
or runtime_model_parameters.get(str(parameter_rule.use_template))
|
||||
or 0
|
||||
)
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
|
||||
return rest_tokens
|
||||
|
||||
|
||||
def _handle_memory_chat_mode(
|
||||
*,
|
||||
memory: PromptMessageMemory | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
model_instance: ModelInstance,
|
||||
) -> Sequence[PromptMessage]:
|
||||
memory_messages: Sequence[PromptMessage] = []
|
||||
# Get messages from memory for chat model
|
||||
if memory and memory_config:
|
||||
rest_tokens = _calculate_rest_token(
|
||||
prompt_messages=[],
|
||||
model_instance=model_instance,
|
||||
)
|
||||
memory_messages = memory.get_history_prompt_messages(
|
||||
max_token_limit=rest_tokens,
|
||||
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||
)
|
||||
return memory_messages
|
||||
|
||||
|
||||
def _handle_memory_completion_mode(
|
||||
*,
|
||||
memory: PromptMessageMemory | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
model_instance: ModelInstance,
|
||||
) -> str:
|
||||
memory_text = ""
|
||||
# Get history text from memory for completion model
|
||||
if memory and memory_config:
|
||||
rest_tokens = _calculate_rest_token(
|
||||
prompt_messages=[],
|
||||
model_instance=model_instance,
|
||||
)
|
||||
if not memory_config.role_prefix:
|
||||
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
|
||||
memory_text = llm_utils.fetch_memory_text(
|
||||
memory=memory,
|
||||
max_token_limit=rest_tokens,
|
||||
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||
human_prefix=memory_config.role_prefix.user,
|
||||
ai_prefix=memory_config.role_prefix.assistant,
|
||||
)
|
||||
return memory_text
|
||||
|
||||
|
||||
def _handle_completion_template(
|
||||
*,
|
||||
template: LLMNodeCompletionModelPromptTemplate,
|
||||
context: str | None,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""Handle completion template processing outside of LLMNode class.
|
||||
|
||||
Args:
|
||||
template: The completion model prompt template
|
||||
context: Optional context string
|
||||
jinja2_variables: Variables for jinja2 template rendering
|
||||
variable_pool: Variable pool for template conversion
|
||||
|
||||
Returns:
|
||||
Sequence of prompt messages
|
||||
"""
|
||||
prompt_messages = []
|
||||
if template.edition_type == "jinja2":
|
||||
result_text = _render_jinja2_message(
|
||||
template=template.jinja2_text or "",
|
||||
jinja2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
else:
|
||||
if context:
|
||||
template_text = template.text.replace("{#context#}", context)
|
||||
else:
|
||||
template_text = template.text
|
||||
result_text = variable_pool.convert_template(template_text).text
|
||||
prompt_message = _combine_message_content_with_role(
|
||||
contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER
|
||||
)
|
||||
prompt_messages.append(prompt_message)
|
||||
return prompt_messages
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
@@ -20,11 +19,3 @@ class ModelFactory(Protocol):
|
||||
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
|
||||
"""Create a model instance that is ready for schema lookup and invocation."""
|
||||
...
|
||||
|
||||
|
||||
class TemplateRenderer(Protocol):
|
||||
"""Port for rendering prompt templates used by LLM-compatible nodes."""
|
||||
|
||||
def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
|
||||
"""Render the given Jinja2 template into plain text."""
|
||||
...
|
||||
|
||||
@@ -28,7 +28,7 @@ from dify_graph.nodes.llm import (
|
||||
llm_utils,
|
||||
)
|
||||
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
|
||||
@@ -59,7 +59,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
_model_factory: "ModelFactory"
|
||||
_model_instance: ModelInstance
|
||||
_memory: PromptMessageMemory | None
|
||||
_template_renderer: TemplateRenderer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -72,7 +71,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
model_factory: "ModelFactory",
|
||||
model_instance: ModelInstance,
|
||||
http_client: HttpClientProtocol,
|
||||
template_renderer: TemplateRenderer,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
@@ -89,7 +87,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
self._model_factory = model_factory
|
||||
self._model_instance = model_instance
|
||||
self._memory = memory
|
||||
self._template_renderer = template_renderer
|
||||
|
||||
if llm_file_saver is None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
@@ -145,7 +142,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
|
||||
# two consecutive user prompts will be generated, causing model's error.
|
||||
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
|
||||
prompt_messages, stop = llm_utils.fetch_prompt_messages(
|
||||
prompt_messages, stop = LLMNode.fetch_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
sys_query="",
|
||||
memory=memory,
|
||||
@@ -156,7 +153,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
vision_detail=node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=[],
|
||||
template_renderer=self._template_renderer,
|
||||
)
|
||||
|
||||
result_text = ""
|
||||
@@ -291,7 +287,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
||||
|
||||
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
|
||||
prompt_messages, _ = llm_utils.fetch_prompt_messages(
|
||||
prompt_messages, _ = LLMNode.fetch_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
sys_query="",
|
||||
sys_files=[],
|
||||
@@ -304,7 +300,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
vision_detail=node_data.vision.configs.detail,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
jinja2_variables=[],
|
||||
template_renderer=self._template_renderer,
|
||||
)
|
||||
rest_tokens = 2000
|
||||
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
"""add indexes for human_input_forms query patterns
|
||||
|
||||
Revision ID: 0ec65df55790
|
||||
Revises: e288952f2994
|
||||
Create Date: 2026-03-02 18:05:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0ec65df55790"
|
||||
down_revision = "e288952f2994"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
with op.batch_alter_table("human_input_forms", schema=None) as batch_op:
|
||||
batch_op.create_index(
|
||||
"human_input_forms_workflow_run_id_node_id_idx",
|
||||
["workflow_run_id", "node_id"],
|
||||
unique=False,
|
||||
)
|
||||
batch_op.create_index(
|
||||
"human_input_forms_status_created_at_idx",
|
||||
["status", "created_at"],
|
||||
unique=False,
|
||||
)
|
||||
batch_op.create_index(
|
||||
"human_input_forms_status_expiration_time_idx",
|
||||
["status", "expiration_time"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
with op.batch_alter_table("human_input_form_deliveries", schema=None) as batch_op:
|
||||
batch_op.create_index(
|
||||
batch_op.f("human_input_form_deliveries_form_id_idx"),
|
||||
["form_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
with op.batch_alter_table("human_input_form_recipients", schema=None) as batch_op:
|
||||
batch_op.create_index(
|
||||
batch_op.f("human_input_form_recipients_delivery_id_idx"),
|
||||
["delivery_id"],
|
||||
unique=False,
|
||||
)
|
||||
batch_op.create_index(
|
||||
batch_op.f("human_input_form_recipients_form_id_idx"),
|
||||
["form_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table("human_input_forms", schema=None) as batch_op:
|
||||
batch_op.drop_index("human_input_forms_workflow_run_id_node_id_idx")
|
||||
batch_op.drop_index("human_input_forms_status_expiration_time_idx")
|
||||
batch_op.drop_index("human_input_forms_status_created_at_idx")
|
||||
|
||||
with op.batch_alter_table("human_input_form_recipients", schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f("human_input_form_recipients_form_id_idx"))
|
||||
batch_op.drop_index(batch_op.f("human_input_form_recipients_delivery_id_idx"))
|
||||
|
||||
with op.batch_alter_table("human_input_form_deliveries", schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f("human_input_form_deliveries_form_id_idx"))
|
||||
@@ -30,15 +30,6 @@ def _generate_token() -> str:
|
||||
|
||||
class HumanInputForm(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_forms"
|
||||
__table_args__ = (
|
||||
sa.Index(
|
||||
"human_input_forms_workflow_run_id_node_id_idx",
|
||||
"workflow_run_id",
|
||||
"node_id",
|
||||
),
|
||||
sa.Index("human_input_forms_status_expiration_time_idx", "status", "expiration_time"),
|
||||
sa.Index("human_input_forms_status_created_at_idx", "status", "created_at"),
|
||||
)
|
||||
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
@@ -93,12 +84,6 @@ class HumanInputForm(DefaultFieldsMixin, Base):
|
||||
|
||||
class HumanInputDelivery(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_form_deliveries"
|
||||
__table_args__ = (
|
||||
sa.Index(
|
||||
None,
|
||||
"form_id",
|
||||
),
|
||||
)
|
||||
|
||||
form_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
@@ -196,10 +181,6 @@ RecipientPayload = Annotated[
|
||||
|
||||
class HumanInputFormRecipient(DefaultFieldsMixin, Base):
|
||||
__tablename__ = "human_input_form_recipients"
|
||||
__table_args__ = (
|
||||
sa.Index(None, "form_id"),
|
||||
sa.Index(None, "delivery_id"),
|
||||
)
|
||||
|
||||
form_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
|
||||
@@ -6,12 +6,12 @@ requires-python = ">=3.11,<3.13"
|
||||
dependencies = [
|
||||
"aliyun-log-python-sdk~=0.9.37",
|
||||
"arize-phoenix-otel~=0.15.0",
|
||||
"azure-identity==1.25.3",
|
||||
"beautifulsoup4==4.14.3",
|
||||
"boto3==1.42.68",
|
||||
"azure-identity==1.25.2",
|
||||
"beautifulsoup4==4.12.2",
|
||||
"boto3==1.42.65",
|
||||
"bs4~=0.0.1",
|
||||
"cachetools~=5.3.0",
|
||||
"celery~=5.6.2",
|
||||
"celery~=5.5.2",
|
||||
"charset-normalizer>=3.4.4",
|
||||
"flask~=3.1.2",
|
||||
"flask-compress>=1.17,<1.24",
|
||||
@@ -35,12 +35,12 @@ dependencies = [
|
||||
"jsonschema>=4.25.1",
|
||||
"langfuse~=2.51.3",
|
||||
"langsmith~=0.7.16",
|
||||
"markdown~=3.10.2",
|
||||
"markdown~=3.8.1",
|
||||
"mlflow-skinny>=3.0.0",
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.10.37",
|
||||
"litellm==1.82.2", # Pinned to avoid madoka dependency issue
|
||||
"litellm==1.82.1", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.28.0",
|
||||
"opentelemetry-distro==0.49b0",
|
||||
"opentelemetry-exporter-otlp==1.28.0",
|
||||
@@ -58,7 +58,7 @@ dependencies = [
|
||||
"opentelemetry-sdk==1.28.0",
|
||||
"opentelemetry-semantic-conventions==0.49b0",
|
||||
"opentelemetry-util-http==0.49b0",
|
||||
"pandas[excel,output-formatting,performance]~=3.0.1",
|
||||
"pandas[excel,output-formatting,performance]~=2.2.2",
|
||||
"psycogreen~=1.0.2",
|
||||
"psycopg2-binary~=2.9.6",
|
||||
"pycryptodome==3.23.0",
|
||||
@@ -66,22 +66,22 @@ dependencies = [
|
||||
"pydantic-extra-types~=2.11.0",
|
||||
"pydantic-settings~=2.13.1",
|
||||
"pyjwt~=2.12.0",
|
||||
"pypdfium2==5.6.0",
|
||||
"pypdfium2==5.2.0",
|
||||
"python-docx~=1.2.0",
|
||||
"python-dotenv==1.2.2",
|
||||
"python-dotenv==1.0.1",
|
||||
"pyyaml~=6.0.1",
|
||||
"readabilipy~=0.3.0",
|
||||
"redis[hiredis]~=7.3.0",
|
||||
"resend~=2.23.0",
|
||||
"sentry-sdk[flask]~=2.54.0",
|
||||
"resend~=2.9.0",
|
||||
"sentry-sdk[flask]~=2.28.0",
|
||||
"sqlalchemy~=2.0.29",
|
||||
"starlette==0.52.1",
|
||||
"starlette==0.49.1",
|
||||
"tiktoken~=0.12.0",
|
||||
"transformers~=5.3.0",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
|
||||
"yarl~=1.23.0",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
|
||||
"yarl~=1.18.3",
|
||||
"webvtt-py~=0.5.1",
|
||||
"sseclient-py~=1.9.0",
|
||||
"sseclient-py~=1.8.0",
|
||||
"httpx-sse~=0.4.0",
|
||||
"sendgrid~=6.12.3",
|
||||
"flask-restx~=1.3.2",
|
||||
@@ -111,7 +111,7 @@ package = false
|
||||
dev = [
|
||||
"coverage~=7.13.4",
|
||||
"dotenv-linter~=0.7.0",
|
||||
"faker~=40.11.0",
|
||||
"faker~=40.8.0",
|
||||
"lxml-stubs~=0.5.1",
|
||||
"basedpyright~=1.38.2",
|
||||
"ruff~=0.15.5",
|
||||
@@ -120,7 +120,7 @@ dev = [
|
||||
"pytest-cov~=7.0.0",
|
||||
"pytest-env~=1.1.3",
|
||||
"pytest-mock~=3.15.1",
|
||||
"testcontainers~=4.14.1",
|
||||
"testcontainers~=4.13.2",
|
||||
"types-aiofiles~=25.1.0",
|
||||
"types-beautifulsoup4~=4.12.0",
|
||||
"types-cachetools~=6.2.0",
|
||||
|
||||
@@ -6,13 +6,6 @@ from typing import Any
|
||||
import httpx
|
||||
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
from services.errors.enterprise import (
|
||||
EnterpriseAPIBadRequestError,
|
||||
EnterpriseAPIError,
|
||||
EnterpriseAPIForbiddenError,
|
||||
EnterpriseAPINotFoundError,
|
||||
EnterpriseAPIUnauthorizedError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -71,51 +64,10 @@ class BaseRequest:
|
||||
request_kwargs["timeout"] = timeout
|
||||
|
||||
response = client.request(method, url, **request_kwargs)
|
||||
|
||||
# Validate HTTP status and raise domain-specific errors
|
||||
if not response.is_success:
|
||||
cls._handle_error_response(response)
|
||||
if raise_for_status:
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@classmethod
|
||||
def _handle_error_response(cls, response: httpx.Response) -> None:
|
||||
"""
|
||||
Handle non-2xx HTTP responses by raising appropriate domain errors.
|
||||
|
||||
Attempts to extract error message from JSON response body,
|
||||
falls back to status text if parsing fails.
|
||||
"""
|
||||
error_message = f"Enterprise API request failed: {response.status_code} {response.reason_phrase}"
|
||||
|
||||
# Try to extract error message from JSON response
|
||||
try:
|
||||
error_data = response.json()
|
||||
if isinstance(error_data, dict):
|
||||
# Common error response formats:
|
||||
# {"error": "...", "message": "..."}
|
||||
# {"message": "..."}
|
||||
# {"detail": "..."}
|
||||
error_message = (
|
||||
error_data.get("message") or error_data.get("error") or error_data.get("detail") or error_message
|
||||
)
|
||||
except Exception:
|
||||
# If JSON parsing fails, use the default message
|
||||
logger.debug(
|
||||
"Failed to parse error response from enterprise API (status=%s)", response.status_code, exc_info=True
|
||||
)
|
||||
|
||||
# Raise specific error based on status code
|
||||
if response.status_code == 400:
|
||||
raise EnterpriseAPIBadRequestError(error_message)
|
||||
elif response.status_code == 401:
|
||||
raise EnterpriseAPIUnauthorizedError(error_message)
|
||||
elif response.status_code == 403:
|
||||
raise EnterpriseAPIForbiddenError(error_message)
|
||||
elif response.status_code == 404:
|
||||
raise EnterpriseAPINotFoundError(error_message)
|
||||
else:
|
||||
raise EnterpriseAPIError(error_message, status_code=response.status_code)
|
||||
|
||||
|
||||
class EnterpriseRequest(BaseRequest):
|
||||
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
|
||||
|
||||
@@ -1,26 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
|
||||
# License status cache configuration
|
||||
LICENSE_STATUS_CACHE_KEY = "enterprise:license:status"
|
||||
VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable
|
||||
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
|
||||
|
||||
|
||||
class WebAppSettings(BaseModel):
|
||||
@@ -63,7 +52,7 @@ class DefaultWorkspaceJoinResult(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult:
|
||||
def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult":
|
||||
if self.joined and not self.workspace_id:
|
||||
raise ValueError("workspace_id must be non-empty when joined is True")
|
||||
return self
|
||||
@@ -126,6 +115,7 @@ class EnterpriseService:
|
||||
"/default-workspace/members",
|
||||
json={"account_id": account_id},
|
||||
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
|
||||
raise_for_status=True,
|
||||
)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("Invalid response format from enterprise default workspace API")
|
||||
@@ -233,64 +223,3 @@ class EnterpriseService:
|
||||
|
||||
params = {"appId": app_id}
|
||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
||||
|
||||
@classmethod
|
||||
def get_cached_license_status(cls) -> LicenseStatus | None:
|
||||
"""Get enterprise license status with Redis caching to reduce HTTP calls.
|
||||
|
||||
Caches valid statuses (active/expiring) for 10 minutes and invalid statuses
|
||||
(inactive/expired/lost) for 30 seconds. The shorter TTL for invalid statuses
|
||||
balances prompt license-fix detection against DoS mitigation — without
|
||||
caching, every request on an expired license would hit the enterprise API.
|
||||
|
||||
Returns:
|
||||
LicenseStatus enum value, or None if enterprise is disabled / unreachable.
|
||||
"""
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
return None
|
||||
|
||||
cached = cls._read_cached_license_status()
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
return cls._fetch_and_cache_license_status()
|
||||
|
||||
@classmethod
|
||||
def _read_cached_license_status(cls) -> LicenseStatus | None:
|
||||
"""Read license status from Redis cache, returning None on miss or failure."""
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
try:
|
||||
raw = redis_client.get(LICENSE_STATUS_CACHE_KEY)
|
||||
if raw:
|
||||
value = raw.decode("utf-8") if isinstance(raw, bytes) else raw
|
||||
return LicenseStatus(value)
|
||||
except Exception:
|
||||
logger.debug("Failed to read license status from cache", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _fetch_and_cache_license_status(cls) -> LicenseStatus | None:
|
||||
"""Fetch license status from enterprise API and cache the result."""
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
try:
|
||||
info = cls.get_info()
|
||||
license_info = info.get("License")
|
||||
if not license_info:
|
||||
return None
|
||||
|
||||
status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
|
||||
ttl = (
|
||||
VALID_LICENSE_CACHE_TTL
|
||||
if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING)
|
||||
else INVALID_LICENSE_CACHE_TTL
|
||||
)
|
||||
try:
|
||||
redis_client.setex(LICENSE_STATUS_CACHE_KEY, ttl, status)
|
||||
except Exception:
|
||||
logger.debug("Failed to cache license status", exc_info=True)
|
||||
return status
|
||||
except Exception:
|
||||
logger.debug("Failed to fetch enterprise license status", exc_info=True)
|
||||
return None
|
||||
|
||||
@@ -70,6 +70,7 @@ class PluginManagerService:
|
||||
"POST",
|
||||
"/pre-uninstall-plugin",
|
||||
json=body.model_dump(),
|
||||
raise_for_status=True,
|
||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -7,7 +7,6 @@ from . import (
|
||||
conversation,
|
||||
dataset,
|
||||
document,
|
||||
enterprise,
|
||||
file,
|
||||
index,
|
||||
message,
|
||||
@@ -22,7 +21,6 @@ __all__ = [
|
||||
"conversation",
|
||||
"dataset",
|
||||
"document",
|
||||
"enterprise",
|
||||
"file",
|
||||
"index",
|
||||
"message",
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
"""Enterprise service errors."""
|
||||
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class EnterpriseServiceError(BaseServiceError):
|
||||
"""Base exception for enterprise service errors."""
|
||||
|
||||
def __init__(self, description: str | None = None, status_code: int | None = None):
|
||||
super().__init__(description)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class EnterpriseAPIError(EnterpriseServiceError):
|
||||
"""Generic enterprise API error (non-2xx response)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class EnterpriseAPINotFoundError(EnterpriseServiceError):
|
||||
"""Enterprise API returned 404 Not Found."""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
super().__init__(description, status_code=404)
|
||||
|
||||
|
||||
class EnterpriseAPIForbiddenError(EnterpriseServiceError):
|
||||
"""Enterprise API returned 403 Forbidden."""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
super().__init__(description, status_code=403)
|
||||
|
||||
|
||||
class EnterpriseAPIUnauthorizedError(EnterpriseServiceError):
|
||||
"""Enterprise API returned 401 Unauthorized."""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
super().__init__(description, status_code=401)
|
||||
|
||||
|
||||
class EnterpriseAPIBadRequestError(EnterpriseServiceError):
|
||||
"""Enterprise API returned 400 Bad Request."""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
super().__init__(description, status_code=400)
|
||||
@@ -379,19 +379,14 @@ class FeatureService:
|
||||
)
|
||||
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
|
||||
|
||||
# SECURITY NOTE: Only license *status* is exposed to unauthenticated callers
|
||||
# so the login page can detect an expired/inactive license after force-logout.
|
||||
# All other license details (expiry date, workspace usage) remain auth-gated.
|
||||
# This behavior reflects prior internal review of information-leakage risks.
|
||||
if license_info := enterprise_info.get("License"):
|
||||
if is_authenticated and (license_info := enterprise_info.get("License")):
|
||||
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
|
||||
features.license.expired_at = license_info.get("expiredAt", "")
|
||||
|
||||
if is_authenticated:
|
||||
features.license.expired_at = license_info.get("expiredAt", "")
|
||||
if workspaces_info := license_info.get("workspaces"):
|
||||
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
|
||||
features.license.workspaces.limit = workspaces_info.get("limit", 0)
|
||||
features.license.workspaces.size = workspaces_info.get("used", 0)
|
||||
if workspaces_info := license_info.get("workspaces"):
|
||||
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
|
||||
features.license.workspaces.limit = workspaces_info.get("limit", 0)
|
||||
features.license.workspaces.size = workspaces_info.get("used", 0)
|
||||
|
||||
if "PluginInstallationPermission" in enterprise_info:
|
||||
plugin_installation_info = enterprise_info["PluginInstallationPermission"]
|
||||
|
||||
@@ -10,7 +10,7 @@ from core.model_manager import ModelInstance
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import StreamCompletedEvent
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
@@ -75,7 +75,6 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
model_factory=MagicMock(spec=ModelFactory),
|
||||
model_instance=MagicMock(spec=ModelInstance),
|
||||
template_renderer=MagicMock(spec=TemplateRenderer),
|
||||
http_client=MagicMock(spec=HttpClientProtocol),
|
||||
)
|
||||
|
||||
@@ -159,7 +158,7 @@ def test_execute_llm():
|
||||
return mock_model_instance
|
||||
|
||||
# Mock fetch_prompt_messages to avoid database calls
|
||||
def mock_fetch_prompt_messages_1(*_args, **_kwargs):
|
||||
def mock_fetch_prompt_messages_1(**_kwargs):
|
||||
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
|
||||
|
||||
return [
|
||||
|
||||
@@ -358,9 +358,10 @@ class TestFeatureService:
|
||||
assert result is not None
|
||||
assert isinstance(result, SystemFeatureModel)
|
||||
|
||||
# --- 1. Verify only license *status* is exposed to unauthenticated clients ---
|
||||
# Detailed license info (expiry, workspaces) remains auth-gated.
|
||||
assert result.license.status == LicenseStatus.ACTIVE
|
||||
# --- 1. Verify Response Payload Optimization (Data Minimization) ---
|
||||
# Ensure only essential UI flags are returned to unauthenticated clients
|
||||
# to keep the payload lightweight and adhere to architectural boundaries.
|
||||
assert result.license.status == LicenseStatus.NONE
|
||||
assert result.license.expired_at == ""
|
||||
assert result.license.workspaces.enabled is False
|
||||
assert result.license.workspaces.limit == 0
|
||||
|
||||
@@ -1,34 +1,32 @@
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.entities.provider_entities import ModelSettings
|
||||
from core.provider_manager import ProviderManager
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import LoadBalancingModelConfig, ProviderModelSetting
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_entity():
|
||||
mock_entity = Mock()
|
||||
def mock_provider_entity(mocker: MockerFixture):
|
||||
mock_entity = mocker.Mock()
|
||||
mock_entity.provider = "openai"
|
||||
mock_entity.configurate_methods = ["predefined-model"]
|
||||
mock_entity.supported_model_types = [ModelType.LLM]
|
||||
|
||||
# Use PropertyMock to ensure credential_form_schemas is iterable
|
||||
provider_credential_schema = Mock()
|
||||
type(provider_credential_schema).credential_form_schemas = PropertyMock(return_value=[])
|
||||
provider_credential_schema = mocker.Mock()
|
||||
type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
|
||||
mock_entity.provider_credential_schema = provider_credential_schema
|
||||
|
||||
model_credential_schema = Mock()
|
||||
type(model_credential_schema).credential_form_schemas = PropertyMock(return_value=[])
|
||||
model_credential_schema = mocker.Mock()
|
||||
type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
|
||||
mock_entity.model_credential_schema = model_credential_schema
|
||||
|
||||
return mock_entity
|
||||
|
||||
|
||||
def test__to_model_settings(mock_provider_entity):
|
||||
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
ps = ProviderModelSetting(
|
||||
tenant_id="tenant_id",
|
||||
@@ -65,18 +63,18 @@ def test__to_model_settings(mock_provider_entity):
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
load_balancing_model_configs[1].id = "id2"
|
||||
|
||||
with patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
|
||||
return_value={"openai_api_key": "fake_key"},
|
||||
):
|
||||
provider_manager = ProviderManager()
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
@@ -89,7 +87,7 @@ def test__to_model_settings(mock_provider_entity):
|
||||
assert result[0].load_balancing_configs[1].name == "first"
|
||||
|
||||
|
||||
def test__to_model_settings_only_one_lb(mock_provider_entity):
|
||||
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
|
||||
ps = ProviderModelSetting(
|
||||
@@ -115,18 +113,18 @@ def test__to_model_settings_only_one_lb(mock_provider_entity):
|
||||
]
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
|
||||
with patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
|
||||
return_value={"openai_api_key": "fake_key"},
|
||||
):
|
||||
provider_manager = ProviderManager()
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
@@ -137,7 +135,7 @@ def test__to_model_settings_only_one_lb(mock_provider_entity):
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
|
||||
def test__to_model_settings_lb_disabled(mock_provider_entity):
|
||||
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
ps = ProviderModelSetting(
|
||||
tenant_id="tenant_id",
|
||||
@@ -172,18 +170,18 @@ def test__to_model_settings_lb_disabled(mock_provider_entity):
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
load_balancing_model_configs[1].id = "id2"
|
||||
|
||||
with patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
|
||||
return_value={"openai_api_key": "fake_key"},
|
||||
):
|
||||
provider_manager = ProviderManager()
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
@@ -192,39 +190,3 @@ def test__to_model_settings_lb_disabled(mock_provider_entity):
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
|
||||
def test_get_default_model_uses_first_available_active_model():
|
||||
mock_session = Mock()
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
provider_configurations = Mock()
|
||||
provider_configurations.get_models.return_value = [
|
||||
Mock(model="gpt-3.5-turbo", provider=Mock(provider="openai")),
|
||||
Mock(model="gpt-4", provider=Mock(provider="openai")),
|
||||
]
|
||||
|
||||
manager = ProviderManager()
|
||||
with (
|
||||
patch("core.provider_manager.db.session", mock_session),
|
||||
patch.object(manager, "get_configurations", return_value=provider_configurations),
|
||||
patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls,
|
||||
):
|
||||
mock_factory_cls.return_value.get_provider_schema.return_value = Mock(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
|
||||
icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"),
|
||||
supported_model_types=[ModelType.LLM],
|
||||
)
|
||||
|
||||
result = manager.get_default_model("tenant-id", ModelType.LLM)
|
||||
|
||||
assert result is not None
|
||||
assert result.model == "gpt-3.5-turbo"
|
||||
assert result.provider.provider == "openai"
|
||||
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
|
||||
mock_session.add.assert_called_once()
|
||||
saved_default_model = mock_session.add.call_args.args[0]
|
||||
assert saved_default_model.model_name == "gpt-3.5-turbo"
|
||||
assert saved_default_model.provider_name == "openai"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@@ -20,7 +20,7 @@ from dify_graph.nodes.code import CodeNode
|
||||
from dify_graph.nodes.document_extractor import DocumentExtractorNode
|
||||
from dify_graph.nodes.http_request import HttpRequestNode
|
||||
from dify_graph.nodes.llm import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
|
||||
from dify_graph.nodes.question_classifier import QuestionClassifierNode
|
||||
@@ -68,8 +68,6 @@ class MockNodeMixin:
|
||||
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
|
||||
# LLM-like nodes now require an http_client; provide a mock by default for tests.
|
||||
kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
|
||||
if isinstance(self, (LLMNode, QuestionClassifierNode)):
|
||||
kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer))
|
||||
|
||||
# Ensure TemplateTransformNode receives a renderer now required by constructor
|
||||
if isinstance(self, TemplateTransformNode):
|
||||
|
||||
@@ -34,8 +34,8 @@ from dify_graph.nodes.llm.entities import (
|
||||
VisionConfigOptions,
|
||||
)
|
||||
from dify_graph.nodes.llm.file_saver import LLMFileSaver
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||
from dify_graph.nodes.llm.node import LLMNode, _handle_memory_completion_mode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
||||
@@ -107,7 +107,6 @@ def llm_node(
|
||||
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
|
||||
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||
mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
@@ -122,7 +121,6 @@ def llm_node(
|
||||
model_factory=mock_model_factory,
|
||||
model_instance=mock.MagicMock(spec=ModelInstance),
|
||||
llm_file_saver=mock_file_saver,
|
||||
template_renderer=mock_template_renderer,
|
||||
http_client=http_client,
|
||||
)
|
||||
return node
|
||||
@@ -592,33 +590,6 @@ def test_handle_list_messages_basic(llm_node):
|
||||
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
|
||||
|
||||
|
||||
def test_handle_list_messages_jinja2_uses_template_renderer(llm_node):
|
||||
llm_node._template_renderer.render_jinja2.return_value = "Hello, world"
|
||||
messages = [
|
||||
LLMNodeChatModelMessage(
|
||||
text="",
|
||||
jinja2_text="Hello, {{ name }}",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="jinja2",
|
||||
)
|
||||
]
|
||||
|
||||
result = llm_node.handle_list_messages(
|
||||
messages=messages,
|
||||
context=None,
|
||||
jinja2_variables=[],
|
||||
variable_pool=llm_node.graph_runtime_state.variable_pool,
|
||||
vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
template_renderer=llm_node._template_renderer,
|
||||
)
|
||||
|
||||
assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])]
|
||||
llm_node._template_renderer.render_jinja2.assert_called_once_with(
|
||||
template="Hello, {{ name }}",
|
||||
inputs={},
|
||||
)
|
||||
|
||||
|
||||
def test_handle_memory_completion_mode_uses_prompt_message_interface():
|
||||
memory = mock.MagicMock(spec=MockTokenBufferMemory)
|
||||
memory.get_history_prompt_messages.return_value = [
|
||||
@@ -642,8 +613,8 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface():
|
||||
window=MemoryConfig.WindowConfig(enabled=True, size=3),
|
||||
)
|
||||
|
||||
with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token:
|
||||
memory_text = llm_utils.handle_memory_completion_mode(
|
||||
with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token:
|
||||
memory_text = _handle_memory_completion_mode(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
model_instance=model_instance,
|
||||
@@ -659,7 +630,6 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
|
||||
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||
mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
@@ -674,7 +644,6 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
model_factory=mock_model_factory,
|
||||
model_instance=mock.MagicMock(spec=ModelInstance),
|
||||
llm_file_saver=mock_file_saver,
|
||||
template_renderer=mock_template_renderer,
|
||||
http_client=http_client,
|
||||
)
|
||||
return node, mock_file_saver
|
||||
|
||||
@@ -1,14 +1,5 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.nodes.question_classifier import (
|
||||
QuestionClassifierNode,
|
||||
QuestionClassifierNodeData,
|
||||
)
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
from dify_graph.nodes.question_classifier import QuestionClassifierNodeData
|
||||
|
||||
|
||||
def test_init_question_classifier_node_data():
|
||||
@@ -74,52 +65,3 @@ def test_init_question_classifier_node_data_without_vision_config():
|
||||
assert node_data.vision.enabled == False
|
||||
assert node_data.vision.configs.variable_selector == ["sys", "files"]
|
||||
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH
|
||||
|
||||
|
||||
def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch):
|
||||
node_data = QuestionClassifierNodeData.model_validate(
|
||||
{
|
||||
"title": "test classifier node",
|
||||
"query_variable_selector": ["id", "name"],
|
||||
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
|
||||
"classes": [{"id": "1", "name": "class 1"}],
|
||||
"instruction": "This is a test instruction",
|
||||
}
|
||||
)
|
||||
template_renderer = MagicMock(spec=TemplateRenderer)
|
||||
node = QuestionClassifierNode(
|
||||
id="node-id",
|
||||
config={"id": "node-id", "data": node_data.model_dump(mode="json")},
|
||||
graph_init_params=build_test_graph_init_params(
|
||||
workflow_id="workflow-id",
|
||||
graph_config={},
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
user_id="user-id",
|
||||
),
|
||||
graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()),
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
model_factory=MagicMock(spec=ModelFactory),
|
||||
model_instance=MagicMock(),
|
||||
http_client=MagicMock(spec=HttpClientProtocol),
|
||||
llm_file_saver=MagicMock(),
|
||||
template_renderer=template_renderer,
|
||||
)
|
||||
fetch_prompt_messages = MagicMock(return_value=([], None))
|
||||
monkeypatch.setattr(
|
||||
"dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages",
|
||||
fetch_prompt_messages,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema",
|
||||
MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])),
|
||||
)
|
||||
|
||||
node._calculate_rest_token(
|
||||
node_data=node_data,
|
||||
query="hello",
|
||||
model_instance=MagicMock(stop=(), parameters={}),
|
||||
context="",
|
||||
)
|
||||
|
||||
assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer
|
||||
|
||||
@@ -140,29 +140,6 @@ class TestDefaultWorkflowCodeExecutor:
|
||||
assert executor.is_execution_error(RuntimeError("boom")) is False
|
||||
|
||||
|
||||
class TestDefaultLLMTemplateRenderer:
|
||||
def test_render_jinja2_delegates_to_code_executor(self, monkeypatch):
|
||||
renderer = node_factory.DefaultLLMTemplateRenderer()
|
||||
execute_workflow_code_template = MagicMock(return_value={"result": "hello world"})
|
||||
monkeypatch.setattr(
|
||||
node_factory.CodeExecutor,
|
||||
"execute_workflow_code_template",
|
||||
execute_workflow_code_template,
|
||||
)
|
||||
|
||||
result = renderer.render_jinja2(
|
||||
template="Hello {{ name }}",
|
||||
inputs={"name": "world"},
|
||||
)
|
||||
|
||||
assert result == "hello world"
|
||||
execute_workflow_code_template.assert_called_once_with(
|
||||
language=CodeLanguage.JINJA2,
|
||||
code="Hello {{ name }}",
|
||||
inputs={"name": "world"},
|
||||
)
|
||||
|
||||
|
||||
class TestDifyNodeFactoryInit:
|
||||
def test_init_builds_default_dependencies(self):
|
||||
graph_init_params = SimpleNamespace(run_context={"context": "value"})
|
||||
@@ -173,7 +150,6 @@ class TestDifyNodeFactoryInit:
|
||||
http_request_config = sentinel.http_request_config
|
||||
credentials_provider = sentinel.credentials_provider
|
||||
model_factory = sentinel.model_factory
|
||||
llm_template_renderer = sentinel.llm_template_renderer
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
@@ -196,11 +172,6 @@ class TestDifyNodeFactoryInit:
|
||||
"build_http_request_config",
|
||||
return_value=http_request_config,
|
||||
),
|
||||
patch.object(
|
||||
node_factory,
|
||||
"DefaultLLMTemplateRenderer",
|
||||
return_value=llm_template_renderer,
|
||||
) as llm_renderer_factory,
|
||||
patch.object(
|
||||
node_factory,
|
||||
"build_dify_model_access",
|
||||
@@ -215,14 +186,11 @@ class TestDifyNodeFactoryInit:
|
||||
resolve_dify_context.assert_called_once_with(graph_init_params.run_context)
|
||||
build_dify_model_access.assert_called_once_with("tenant-id")
|
||||
renderer_factory.assert_called_once()
|
||||
llm_renderer_factory.assert_called_once()
|
||||
assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor
|
||||
assert factory.graph_init_params is graph_init_params
|
||||
assert factory.graph_runtime_state is graph_runtime_state
|
||||
assert factory._dify_context is dify_context
|
||||
assert factory._template_renderer is template_renderer
|
||||
|
||||
assert factory._llm_template_renderer is llm_template_renderer
|
||||
assert factory._document_extractor_unstructured_api_config is unstructured_api_config
|
||||
assert factory._http_request_config is http_request_config
|
||||
assert factory._llm_credentials_provider is credentials_provider
|
||||
@@ -274,7 +242,6 @@ class TestDifyNodeFactoryCreateNode:
|
||||
factory._code_executor = sentinel.code_executor
|
||||
factory._code_limits = sentinel.code_limits
|
||||
factory._template_renderer = sentinel.template_renderer
|
||||
factory._llm_template_renderer = sentinel.llm_template_renderer
|
||||
factory._template_transform_max_output_length = 2048
|
||||
factory._http_request_http_client = sentinel.http_client
|
||||
factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory
|
||||
@@ -411,22 +378,8 @@ class TestDifyNodeFactoryCreateNode:
|
||||
@pytest.mark.parametrize(
|
||||
("node_type", "constructor_name", "expected_extra_kwargs"),
|
||||
[
|
||||
(
|
||||
BuiltinNodeTypes.LLM,
|
||||
"LLMNode",
|
||||
{
|
||||
"http_client": sentinel.http_client,
|
||||
"template_renderer": sentinel.llm_template_renderer,
|
||||
},
|
||||
),
|
||||
(
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER,
|
||||
"QuestionClassifierNode",
|
||||
{
|
||||
"http_client": sentinel.http_client,
|
||||
"template_renderer": sentinel.llm_template_renderer,
|
||||
},
|
||||
),
|
||||
(BuiltinNodeTypes.LLM, "LLMNode", {"http_client": sentinel.http_client}),
|
||||
(BuiltinNodeTypes.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}),
|
||||
(BuiltinNodeTypes.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Unit tests for enterprise service integrations.
|
||||
|
||||
Covers:
|
||||
- Default workspace auto-join behavior
|
||||
- License status caching (get_cached_license_status)
|
||||
This module covers the enterprise-only default workspace auto-join behavior:
|
||||
- Enterprise mode disabled: no external calls
|
||||
- Successful join / skipped join: no errors
|
||||
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
@@ -10,9 +11,6 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from services.enterprise.enterprise_service import (
|
||||
INVALID_LICENSE_CACHE_TTL,
|
||||
LICENSE_STATUS_CACHE_KEY,
|
||||
VALID_LICENSE_CACHE_TTL,
|
||||
DefaultWorkspaceJoinResult,
|
||||
EnterpriseService,
|
||||
try_join_default_workspace,
|
||||
@@ -39,6 +37,7 @@ class TestJoinDefaultWorkspace:
|
||||
"/default-workspace/members",
|
||||
json={"account_id": account_id},
|
||||
timeout=1.0,
|
||||
raise_for_status=True,
|
||||
)
|
||||
|
||||
def test_join_default_workspace_invalid_response_format_raises(self):
|
||||
@@ -140,134 +139,3 @@ class TestTryJoinDefaultWorkspace:
|
||||
|
||||
# Should not raise even though UUID parsing fails inside join_default_workspace
|
||||
try_join_default_workspace("not-a-uuid")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_cached_license_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EE_SVC = "services.enterprise.enterprise_service"
|
||||
|
||||
|
||||
class TestGetCachedLicenseStatus:
|
||||
"""Tests for EnterpriseService.get_cached_license_status."""
|
||||
|
||||
def test_returns_none_when_enterprise_disabled(self):
|
||||
with patch(f"{_EE_SVC}.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = False
|
||||
|
||||
assert EnterpriseService.get_cached_license_status() is None
|
||||
|
||||
def test_cache_hit_returns_license_status_enum(self):
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
with (
|
||||
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_redis.get.return_value = b"active"
|
||||
|
||||
result = EnterpriseService.get_cached_license_status()
|
||||
|
||||
assert result == LicenseStatus.ACTIVE
|
||||
assert isinstance(result, LicenseStatus)
|
||||
mock_get_info.assert_not_called()
|
||||
|
||||
def test_cache_miss_fetches_api_and_caches_valid_status(self):
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
with (
|
||||
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_redis.get.return_value = None
|
||||
mock_get_info.return_value = {"License": {"status": "active"}}
|
||||
|
||||
result = EnterpriseService.get_cached_license_status()
|
||||
|
||||
assert result == LicenseStatus.ACTIVE
|
||||
mock_redis.setex.assert_called_once_with(
|
||||
LICENSE_STATUS_CACHE_KEY, VALID_LICENSE_CACHE_TTL, LicenseStatus.ACTIVE
|
||||
)
|
||||
|
||||
def test_cache_miss_fetches_api_and_caches_invalid_status_with_short_ttl(self):
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
with (
|
||||
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_redis.get.return_value = None
|
||||
mock_get_info.return_value = {"License": {"status": "expired"}}
|
||||
|
||||
result = EnterpriseService.get_cached_license_status()
|
||||
|
||||
assert result == LicenseStatus.EXPIRED
|
||||
mock_redis.setex.assert_called_once_with(
|
||||
LICENSE_STATUS_CACHE_KEY, INVALID_LICENSE_CACHE_TTL, LicenseStatus.EXPIRED
|
||||
)
|
||||
|
||||
def test_redis_read_failure_falls_through_to_api(self):
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
with (
|
||||
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_redis.get.side_effect = ConnectionError("redis down")
|
||||
mock_get_info.return_value = {"License": {"status": "active"}}
|
||||
|
||||
result = EnterpriseService.get_cached_license_status()
|
||||
|
||||
assert result == LicenseStatus.ACTIVE
|
||||
mock_get_info.assert_called_once()
|
||||
|
||||
def test_redis_write_failure_still_returns_status(self):
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
with (
|
||||
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.setex.side_effect = ConnectionError("redis down")
|
||||
mock_get_info.return_value = {"License": {"status": "expiring"}}
|
||||
|
||||
result = EnterpriseService.get_cached_license_status()
|
||||
|
||||
assert result == LicenseStatus.EXPIRING
|
||||
|
||||
def test_api_failure_returns_none(self):
|
||||
with (
|
||||
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_redis.get.return_value = None
|
||||
mock_get_info.side_effect = Exception("network failure")
|
||||
|
||||
assert EnterpriseService.get_cached_license_status() is None
|
||||
|
||||
def test_api_returns_no_license_info(self):
|
||||
with (
|
||||
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_redis.get.return_value = None
|
||||
mock_get_info.return_value = {} # no "License" key
|
||||
|
||||
assert EnterpriseService.get_cached_license_status() is None
|
||||
mock_redis.setex.assert_not_called()
|
||||
|
||||
@@ -34,6 +34,7 @@ class TestTryPreUninstallPlugin:
|
||||
"POST",
|
||||
"/pre-uninstall-plugin",
|
||||
json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"},
|
||||
raise_for_status=True,
|
||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||
)
|
||||
|
||||
@@ -61,6 +62,7 @@ class TestTryPreUninstallPlugin:
|
||||
"POST",
|
||||
"/pre-uninstall-plugin",
|
||||
json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"},
|
||||
raise_for_status=True,
|
||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||
)
|
||||
mock_logger.exception.assert_called_once()
|
||||
@@ -85,6 +87,7 @@ class TestTryPreUninstallPlugin:
|
||||
"POST",
|
||||
"/pre-uninstall-plugin",
|
||||
json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"},
|
||||
raise_for_status=True,
|
||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||
)
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
824
api/uv.lock
generated
824
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -8,6 +8,10 @@
|
||||
- In new or modified code, use only overlay primitives from `@/app/components/base/ui/*`.
|
||||
- Do not introduce deprecated overlay imports from `@/app/components/base/*`; when touching legacy callers, prefer migrating them and keep the allowlist shrinking (never expanding).
|
||||
|
||||
## Query & Mutation (Mandatory)
|
||||
|
||||
- `frontend-query-mutation` is the source of truth for Dify frontend contracts, query and mutation call-site patterns, conditional queries, invalidation, and mutation error handling.
|
||||
|
||||
## Automated Test Generation
|
||||
|
||||
- Use `./docs/test.md` as the canonical instruction set for generating frontend automated tests.
|
||||
|
||||
@@ -1,31 +1,29 @@
|
||||
import type { QueryKey } from '@tanstack/react-query'
|
||||
import {
|
||||
|
||||
useQueryClient,
|
||||
} from '@tanstack/react-query'
|
||||
import { useQueryClient } from '@tanstack/react-query'
|
||||
import { useCallback } from 'react'
|
||||
|
||||
/**
|
||||
* @deprecated Convenience wrapper scheduled for removal.
|
||||
* Prefer binding invalidation in `useMutation` callbacks at the service layer.
|
||||
*/
|
||||
export const useInvalid = (key?: QueryKey) => {
|
||||
const queryClient = useQueryClient()
|
||||
return () => {
|
||||
return useCallback(() => {
|
||||
if (!key)
|
||||
return
|
||||
queryClient.invalidateQueries(
|
||||
{
|
||||
queryKey: key,
|
||||
},
|
||||
)
|
||||
}
|
||||
queryClient.invalidateQueries({ queryKey: key })
|
||||
}, [queryClient, key])
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Convenience wrapper scheduled for removal.
|
||||
* Prefer binding reset in `useMutation` callbacks at the service layer.
|
||||
*/
|
||||
export const useReset = (key?: QueryKey) => {
|
||||
const queryClient = useQueryClient()
|
||||
return () => {
|
||||
return useCallback(() => {
|
||||
if (!key)
|
||||
return
|
||||
queryClient.resetQueries(
|
||||
{
|
||||
queryKey: key,
|
||||
},
|
||||
)
|
||||
}
|
||||
queryClient.resetQueries({ queryKey: key })
|
||||
}, [queryClient, key])
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user