Compare commits

..

6 Commits

Author SHA1 Message Date
-LAN-
a30a64d51b Scope tool configuration test patches 2026-03-14 19:29:05 +08:00
-LAN-
34ef10c818 Keep direct provider_id consumers unchanged 2026-03-14 19:14:30 +08:00
-LAN-
26fedca865 Keep trigger provider handling in the node 2026-03-14 19:14:30 +08:00
-LAN-
2fd4e9e259 Restore trigger provider metadata on start events 2026-03-14 19:14:30 +08:00
-LAN-
9e8a4c8a71 Keep dify_graph node base generic 2026-03-14 19:14:30 +08:00
-LAN-
238497b7ab Move trigger workflow nodes into core workflow 2026-03-14 19:14:29 +08:00
278 changed files with 2586 additions and 4197 deletions

View File

@@ -187,13 +187,53 @@ const Template = useMemo(() => {
**When**: Component directly handles API calls, data transformation, or complex async operations.
**Dify Convention**:
- This skill is for component decomposition, not query/mutation design.
- When refactoring data fetching, follow `web/AGENTS.md`.
- Use `orpc-contract-first` for contracts, query shape, data-fetching wrappers, and query/mutation call-site patterns.
- Use `web/docs/query-mutation.md` for Dify-specific conditional query, invalidation, and mutation error-handling rules.
- Do not introduce deprecated `useInvalid` / `useReset`.
- Do not add thin passthrough `useQuery` wrappers during refactoring; only extract a custom hook when it truly orchestrates multiple queries/mutations or shared derived state.
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.
```typescript
// ❌ Before: API logic in component
const MCPServiceCard = () => {
const [basicAppConfig, setBasicAppConfig] = useState({})
useEffect(() => {
if (isBasicApp && appId) {
(async () => {
const res = await fetchAppDetail({ url: '/apps', id: appId })
setBasicAppConfig(res?.model_config || {})
})()
}
}, [appId, isBasicApp])
// More API-related logic...
}
// ✅ After: Extract to data hook using React Query
// use-app-config.ts
import { useQuery } from '@tanstack/react-query'
import { get } from '@/service/base'
const NAME_SPACE = 'appConfig'
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
return useQuery({
enabled: isBasicApp && !!appId,
queryKey: [NAME_SPACE, 'detail', appId],
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || {},
})
}
// Component becomes cleaner
const MCPServiceCard = () => {
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
// UI only
}
```
**React Query Best Practices in Dify**:
- Define `NAME_SPACE` for query key organization
- Use `enabled` option for conditional fetching
- Use `select` for data transformation
- Export invalidation hooks: `useInvalidXxx`
**Dify Examples**:
- `web/service/use-workflow.ts`

View File

@@ -155,15 +155,48 @@ const Configuration: FC = () => {
## Common Hook Patterns in Dify
### 1. Data Fetching / Mutation Hooks
### 1. Data Fetching Hook (React Query)
When hook extraction touches query or mutation code, do not use this reference as the source of truth for data-layer patterns.
```typescript
// Pattern: Use @tanstack/react-query for data fetching
import { useQuery, useQueryClient } from '@tanstack/react-query'
import { get } from '@/service/base'
import { useInvalid } from '@/service/use-base'
- Follow `web/AGENTS.md` first.
- Use `orpc-contract-first` for contracts, query shape, data-fetching wrappers, and query/mutation call-site patterns.
- Use `web/docs/query-mutation.md` for conditional query, invalidation, and mutation error-handling rules.
- Do not introduce deprecated `useInvalid` / `useReset`.
- Do not extract thin passthrough `useQuery` hooks; only extract orchestration hooks.
const NAME_SPACE = 'appConfig'
// Query keys for cache management
export const appConfigQueryKeys = {
detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
}
// Main data hook
export const useAppConfig = (appId: string) => {
return useQuery({
enabled: !!appId,
queryKey: appConfigQueryKeys.detail(appId),
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || null,
})
}
// Invalidation hook for refreshing data
export const useInvalidAppConfig = () => {
return useInvalid([NAME_SPACE])
}
// Usage in component
const Component = () => {
const { data: config, isLoading, error, refetch } = useAppConfig(appId)
const invalidAppConfig = useInvalidAppConfig()
const handleRefresh = () => {
invalidAppConfig() // Invalidates cache and triggers refetch
}
return <div>...</div>
}
```
### 2. Form State Hook

View File

@@ -1,44 +0,0 @@
---
name: frontend-query-mutation
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions() directly or extract a helper or use-* hook, handling conditional queries, cache invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
---
# Frontend Query & Mutation
## Intent
- Keep contract as the single source of truth in `web/contract/*`.
- Prefer contract-shaped `queryOptions()` and `mutationOptions()`.
- Keep invalidation and mutation flow knowledge in the service layer.
- Keep abstractions minimal to preserve TypeScript inference.
## Workflow
1. Identify the change surface.
- Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape.
- Read `references/runtime-rules.md` for conditional queries, invalidation, error handling, and legacy migrations.
- Read both references when a task spans contract shape and runtime behavior.
2. Implement the smallest abstraction that fits the task.
- Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site.
- Extract a small shared query helper only when multiple call sites share the same extra options.
- Create `web/service/use-{domain}.ts` only for orchestration or shared domain behavior.
3. Preserve Dify conventions.
- Keep contract inputs in `{ params, query?, body? }` shape.
- Bind invalidation in the service-layer mutation definition.
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required.
## Files Commonly Touched
- `web/contract/console/*.ts`
- `web/contract/marketplace.ts`
- `web/contract/router.ts`
- `web/service/client.ts`
- `web/service/use-*.ts`
- component and hook call sites using `consoleQuery` or `marketplaceQuery`
## References
- Use `references/contract-patterns.md` for contract shape, router registration, query and mutation helpers, and anti-patterns that degrade inference.
- Use `references/runtime-rules.md` for conditional queries, invalidation, `mutate` versus `mutateAsync`, and legacy migration rules.
Treat this skill as the single query and mutation entry point for Dify frontend work. Keep detailed rules in the reference files instead of duplicating them in project docs.

View File

@@ -1,4 +0,0 @@
interface:
display_name: "Frontend Query & Mutation"
short_description: "Dify TanStack Query and oRPC patterns"
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, conditional queries, invalidation, or legacy query/mutation migrations."

View File

@@ -1,98 +0,0 @@
# Contract Patterns
## Table of Contents
- Intent
- Minimal structure
- Core workflow
- Query usage decision rule
- Mutation usage decision rule
- Anti-patterns
- Contract rules
- Type export
## Intent
- Keep contract as the single source of truth in `web/contract/*`.
- Default query usage to call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
- Keep abstractions minimal and preserve TypeScript inference.
## Minimal Structure
```text
web/contract/
├── base.ts
├── router.ts
├── marketplace.ts
└── console/
├── billing.ts
└── ...other domains
web/service/client.ts
```
## Core Workflow
1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`.
- Use `base.route({...}).output(type<...>())` as the baseline.
- Add `.input(type<...>())` only when the request has `params`, `query`, or `body`.
- For `GET` without input, omit `.input(...)`; do not use `.input(type<unknown>())`.
2. Register contract in `web/contract/router.ts`.
- Import directly from domain files and nest by API prefix.
3. Consume from UI call sites via oRPC query utilities.
```typescript
import { useQuery } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
staleTime: 5 * 60 * 1000,
throwOnError: true,
select: invoice => invoice.url,
}))
```
## Query Usage Decision Rule
1. Default to direct `*.queryOptions(...)` usage at the call site.
2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook.
3. Create `web/service/use-{domain}.ts` only for orchestration.
- Combine multiple queries or mutations.
- Share domain-level derived state or invalidation helpers.
```typescript
const invoicesBaseQueryOptions = () =>
consoleQuery.billing.invoices.queryOptions({ retry: false })
const invoiceQuery = useQuery({
...invoicesBaseQueryOptions(),
throwOnError: true,
})
```
## Mutation Usage Decision Rule
1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic.
## Anti-Patterns
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case.
- Do not create thin `use-*` passthrough hooks for a single endpoint.
- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection.
## Contract Rules
- Input structure: always use `{ params, query?, body? }`.
- No-input `GET`: omit `.input(...)`; do not use `.input(type<unknown>())`.
- Path params: use `{paramName}` in the path and match it in the `params` object.
- Router nesting: group by API prefix, for example `/billing/*` becomes `billing: {}`.
- No barrel files: import directly from specific files.
- Types: import from `@/types/` and use the `type<T>()` helper.
- Mutations: prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults, filtering, and devtools.
## Type Export
```typescript
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
```

View File

@@ -1,133 +0,0 @@
# Runtime Rules
## Table of Contents
- Conditional queries
- Cache invalidation
- Key API guide
- `mutate` vs `mutateAsync`
- Legacy migration
## Conditional Queries
Prefer contract-shaped `queryOptions(...)`.
When required input is missing, prefer `input: skipToken` instead of placeholder params or non-null assertions.
Use `enabled` only for extra business gating after the input itself is already valid.
```typescript
import { skipToken, useQuery } from '@tanstack/react-query'
// Disable the query by skipping input construction.
function useAccessMode(appId: string | undefined) {
return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
input: appId
? { params: { appId } }
: skipToken,
}))
}
// Avoid runtime-only guards that bypass type checking.
function useBadAccessMode(appId: string | undefined) {
return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
input: { params: { appId: appId! } },
enabled: !!appId,
}))
}
```
## Cache Invalidation
Bind invalidation in the service-layer mutation definition.
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate.
Use:
- `.key()` for namespace or prefix invalidation
- `.queryKey(...)` only for exact cache reads or writes such as `getQueryData` and `setQueryData`
- `queryClient.invalidateQueries(...)` in mutation `onSuccess`
Do not use deprecated `useInvalid` from `use-base.ts`.
```typescript
// Service layer owns cache invalidation.
export const useUpdateAccessMode = () => {
const queryClient = useQueryClient()
return useMutation(consoleQuery.accessControl.updateAccessMode.mutationOptions({
onSuccess: () => {
queryClient.invalidateQueries({
queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
})
},
}))
}
// Component only adds UI behavior.
updateAccessMode({ appId, mode }, {
onSuccess: () => Toast.notify({ type: 'success', message: '...' }),
})
// Avoid putting invalidation knowledge in the component.
mutate({ appId, mode }, {
onSuccess: () => {
queryClient.invalidateQueries({
queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
})
},
})
```
## Key API Guide
- `.key(...)`
- Use for partial matching operations.
- Prefer it for invalidation, refetch, and cancel patterns.
- Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
- `.queryKey(...)`
- Use for a specific query's full key.
- Prefer it for exact cache addressing and direct reads or writes.
- `.mutationKey(...)`
- Use for a specific mutation's full key.
- Prefer it for mutation defaults registration, mutation-status filtering, and devtools grouping.
## `mutate` vs `mutateAsync`
Prefer `mutate` by default.
Use `mutateAsync` only when Promise semantics are truly required, such as parallel mutations or sequential steps with result dependencies.
Rules:
- Event handlers should usually call `mutate(...)` with `onSuccess` or `onError`.
- Every `await mutateAsync(...)` must be wrapped in `try/catch`.
- Do not use `mutateAsync` when callbacks already express the flow clearly.
```typescript
// Default case.
mutation.mutate(data, {
onSuccess: result => router.push(result.url),
})
// Promise semantics are required.
try {
const order = await createOrder.mutateAsync(orderData)
await confirmPayment.mutateAsync({ orderId: order.id, token })
router.push(`/orders/${order.id}`)
}
catch (error) {
Toast.notify({
type: 'error',
message: error instanceof Error ? error.message : 'Unknown error',
})
}
```
## Legacy Migration
When touching old code, migrate it toward these rules:
| Old pattern | New pattern |
|---|---|
| `useInvalid(key)` in service layer | `queryClient.invalidateQueries(...)` inside mutation `onSuccess` |
| component-triggered invalidation after mutation | move invalidation into the service-layer mutation definition |
| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` |
| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` |

View File

@@ -0,0 +1,103 @@
---
name: orpc-contract-first
description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Trigger when creating or updating contracts in web/contract, wiring router composition, integrating TanStack Query with typed contracts, migrating legacy service calls to oRPC, or deciding whether to call queryOptions directly vs extracting a helper or use-* hook in web/service.
---
# oRPC Contract-First Development
## Intent
- Keep contract as single source of truth in `web/contract/*`.
- Default query usage: call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
- Keep abstractions minimal and preserve TypeScript inference.
## Minimal Structure
```text
web/contract/
├── base.ts
├── router.ts
├── marketplace.ts
└── console/
├── billing.ts
└── ...other domains
web/service/client.ts
```
## Core Workflow
1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`
- Use `base.route({...}).output(type<...>())` as baseline.
- Add `.input(type<...>())` only when request has `params/query/body`.
- For `GET` without input, omit `.input(...)` (do not use `.input(type<unknown>())`).
2. Register contract in `web/contract/router.ts`
- Import directly from domain files and nest by API prefix.
3. Consume from UI call sites via oRPC query utils.
```typescript
import { useQuery } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
staleTime: 5 * 60 * 1000,
throwOnError: true,
select: invoice => invoice.url,
}))
```
## Query Usage Decision Rule
1. Default: call site directly uses `*.queryOptions(...)`.
2. If 3+ call sites share the same extra options (for example `retry: false`), extract a small queryOptions helper, not a `use-*` passthrough hook.
3. Create `web/service/use-{domain}.ts` only for orchestration:
- Combine multiple queries/mutations.
- Share domain-level derived state or invalidation helpers.
```typescript
const invoicesBaseQueryOptions = () =>
consoleQuery.billing.invoices.queryOptions({ retry: false })
const invoiceQuery = useQuery({
...invoicesBaseQueryOptions(),
throwOnError: true,
})
```
## Mutation Usage Decision Rule
1. Default: call mutation helpers from `consoleQuery` / `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
2. If mutation flow is heavily custom, use oRPC clients as `mutationFn` (for example `consoleClient.xxx` / `marketplaceClient.xxx`), instead of generic handwritten non-oRPC mutation logic.
## Key API Guide (`.key` vs `.queryKey` vs `.mutationKey`)
- `.key(...)`:
- Use for partial matching operations (recommended for invalidation/refetch/cancel patterns).
- Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
- `.queryKey(...)`:
- Use for a specific query's full key (exact query identity / direct cache addressing).
- `.mutationKey(...)`:
- Use for a specific mutation's full key.
- Typical use cases: mutation defaults registration, mutation-status filtering (`useIsMutating`, `queryClient.isMutating`), or explicit devtools grouping.
## Anti-Patterns
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
- Do not split local `queryKey/queryFn` when oRPC `queryOptions` already exists and fits the use case.
- Do not create thin `use-*` passthrough hooks for a single endpoint.
- Reason: these patterns can degrade inference (`data` may become `unknown`, especially around `throwOnError`/`select`) and add unnecessary indirection.
## Contract Rules
- **Input structure**: Always use `{ params, query?, body? }` format
- **No-input GET**: Omit `.input(...)`; do not use `.input(type<unknown>())`
- **Path params**: Use `{paramName}` in path, match in `params` object
- **Router nesting**: Group by API prefix (e.g., `/billing/*` -> `billing: {}`)
- **No barrel files**: Import directly from specific files
- **Types**: Import from `@/types/`, use `type<T>()` helper
- **Mutations**: Prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults/filtering/devtools
## Type Export
```typescript
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
```

View File

@@ -1 +0,0 @@
../../.agents/skills/frontend-query-mutation

View File

@@ -0,0 +1 @@
../../.agents/skills/orpc-contract-first

3
.gitignore vendored
View File

@@ -237,6 +237,3 @@ scripts/stress-test/reports/
# settings
*.local.json
*.local.md
# Code Agent Folder
.qoder/*

View File

@@ -22,10 +22,10 @@ APP_WEB_URL=http://localhost:3000
# Files URL
FILES_URL=http://localhost:5001
# INTERNAL_FILES_URL is used by services running in Docker to reach the API file endpoints.
# For Docker Desktop (Mac/Windows), use http://host.docker.internal:5001 when the API runs on the host.
# For Docker Compose on Linux, use http://api:5001 when the API runs inside the Docker network.
INTERNAL_FILES_URL=http://host.docker.internal:5001
# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
# Set this to the internal Docker service URL for proper plugin file access.
# Example: INTERNAL_FILES_URL=http://api:5001
INTERNAL_FILES_URL=http://127.0.0.1:5001
# TRIGGER URL
TRIGGER_URL=http://localhost:5001
@@ -180,7 +180,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
COOKIE_DOMAIN=
# Vector database configuration
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `hologres`.
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
@@ -217,20 +217,6 @@ COUCHBASE_PASSWORD=password
COUCHBASE_BUCKET_NAME=Embeddings
COUCHBASE_SCOPE_NAME=_default
# Hologres configuration
# access_key_id is used as the PG username, access_key_secret is used as the PG password
HOLOGRES_HOST=
HOLOGRES_PORT=80
HOLOGRES_DATABASE=
HOLOGRES_ACCESS_KEY_ID=
HOLOGRES_ACCESS_KEY_SECRET=
HOLOGRES_SCHEMA=public
HOLOGRES_TOKENIZER=jieba
HOLOGRES_DISTANCE_METHOD=Cosine
HOLOGRES_BASE_QUANTIZATION_TYPE=rabitq
HOLOGRES_MAX_DEGREE=64
HOLOGRES_EF_CONSTRUCTION=400
# Milvus configuration
MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN=

View File

@@ -43,6 +43,7 @@ forbidden_modules =
extensions.ext_redis
allow_indirect_imports = True
ignore_imports =
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
@@ -89,6 +90,9 @@ forbidden_modules =
core.trigger
core.variables
ignore_imports =
dify_graph.nodes.agent.agent_node -> core.model_manager
dify_graph.nodes.agent.agent_node -> core.provider_manager
dify_graph.nodes.agent.agent_node -> core.tools.tool_manager
dify_graph.nodes.llm.llm_utils -> core.model_manager
dify_graph.nodes.llm.protocols -> core.model_manager
dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
@@ -96,6 +100,9 @@ ignore_imports =
dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
dify_graph.nodes.agent.agent_node -> core.agent.entities
dify_graph.nodes.agent.agent_node -> core.agent.plugin_entities
dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model
@@ -103,10 +110,12 @@ ignore_imports =
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.agent.agent_node -> models.model
dify_graph.nodes.llm.node -> core.helper.code_executor
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
dify_graph.nodes.llm.node -> core.model_manager
dify_graph.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util
@@ -115,12 +124,17 @@ ignore_imports =
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
dify_graph.nodes.llm.node -> models.dataset
dify_graph.nodes.agent.agent_node -> core.tools.utils.message_transformer
dify_graph.nodes.llm.file_saver -> core.tools.signature
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
dify_graph.nodes.tool.tool_node -> core.tools.errors
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.agent.agent_node -> models
dify_graph.nodes.llm.node -> models.model
dify_graph.nodes.agent.agent_node -> services
dify_graph.nodes.tool.tool_node -> services
dify_graph.model_runtime.model_providers.__base.ai_model -> configs
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis

View File

@@ -160,7 +160,6 @@ def migrate_knowledge_vector_database():
}
lower_collection_vector_types = {
VectorType.ANALYTICDB,
VectorType.HOLOGRES,
VectorType.CHROMA,
VectorType.MYSCALE,
VectorType.PGVECTO_RS,

View File

@@ -26,7 +26,6 @@ from .vdb.chroma_config import ChromaConfig
from .vdb.clickzetta_config import ClickzettaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.hologres_config import HologresConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.iris_config import IrisVectorConfig
from .vdb.lindorm_config import LindormConfig
@@ -348,7 +347,6 @@ class MiddlewareConfig(
AnalyticdbConfig,
ChromaConfig,
ClickzettaConfig,
HologresConfig,
HuaweiCloudConfig,
IrisVectorConfig,
MilvusConfig,

View File

@@ -1,68 +0,0 @@
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from pydantic import Field
from pydantic_settings import BaseSettings
class HologresConfig(BaseSettings):
"""
Configuration settings for Hologres vector database.
Hologres is compatible with PostgreSQL protocol.
access_key_id is used as the PostgreSQL username,
and access_key_secret is used as the PostgreSQL password.
"""
HOLOGRES_HOST: str | None = Field(
description="Hostname or IP address of the Hologres instance.",
default=None,
)
HOLOGRES_PORT: int = Field(
description="Port number for connecting to the Hologres instance.",
default=80,
)
HOLOGRES_DATABASE: str | None = Field(
description="Name of the Hologres database to connect to.",
default=None,
)
HOLOGRES_ACCESS_KEY_ID: str | None = Field(
description="Alibaba Cloud AccessKey ID, also used as the PostgreSQL username.",
default=None,
)
HOLOGRES_ACCESS_KEY_SECRET: str | None = Field(
description="Alibaba Cloud AccessKey Secret, also used as the PostgreSQL password.",
default=None,
)
HOLOGRES_SCHEMA: str = Field(
description="Schema name in the Hologres database.",
default="public",
)
HOLOGRES_TOKENIZER: TokenizerType = Field(
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
default="jieba",
)
HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
default="Cosine",
)
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
default="rabitq",
)
HOLOGRES_MAX_DEGREE: int = Field(
description="Max degree (M) parameter for HNSW vector index.",
default=64,
)
HOLOGRES_EF_CONSTRUCTION: int = Field(
description="ef_construction parameter for HNSW vector index.",
default=400,
)

View File

@@ -25,8 +25,7 @@ from controllers.console.wraps import (
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.enums import NodeType, WorkflowExecutionStatus
from dify_graph.file import helpers as file_helpers
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
@@ -509,7 +508,11 @@ class AppListApi(Resource):
.scalars()
.all()
)
trigger_node_types = TRIGGER_NODE_TYPES
trigger_node_types = {
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
node_id = None
try:

View File

@@ -1,4 +1,5 @@
import json
from enum import StrEnum
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
@@ -10,7 +11,6 @@ from controllers.console.wraps import account_initialization_required, edit_perm
from extensions.ext_database import db
from fields.app_fields import app_server_fields
from libs.login import current_account_with_tenant, login_required
from models.enums import AppMCPServerStatus
from models.model import AppMCPServer
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -19,6 +19,11 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
app_server_model = console_ns.model("AppServer", app_server_fields)
class AppMCPServerStatus(StrEnum):
ACTIVE = "active"
INACTIVE = "inactive"
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
@@ -112,10 +117,9 @@ class AppMCPServerController(Resource):
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
if payload.status:
try:
server.status = AppMCPServerStatus(payload.status)
except ValueError:
if payload.status not in [status.value for status in AppMCPServerStatus]:
raise ValueError("Invalid status")
server.status = payload.status
db.session.commit()
return server

View File

@@ -22,7 +22,6 @@ from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.trace_id_helper import get_external_trace_id
from core.plugin.impl.exc import PluginInvokeError
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
from core.trigger.debug.event_selectors import (
TriggerDebugEvent,
TriggerDebugEventPoller,
@@ -1210,7 +1209,7 @@ class DraftWorkflowTriggerNodeApi(Resource):
node_type: NodeType = draft_workflow.get_node_type_from_node_config(node_config)
event: TriggerDebugEvent | None = None
# for schedule trigger, when run single node, just execute directly
if node_type == TRIGGER_SCHEDULE_NODE_TYPE:
if node_type == NodeType.TRIGGER_SCHEDULE:
event = TriggerDebugEvent(
workflow_args={},
node_id=node_id,

View File

@@ -263,7 +263,6 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
VectorType.BAIDU,
VectorType.ALIBABACLOUD_MYSQL,
VectorType.IRIS,
VectorType.HOLOGRES,
}
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}

View File

@@ -43,7 +43,7 @@ from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
from models.account import AccountStatus
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -216,7 +216,7 @@ class AccountInitApi(Resource):
db.session.query(InvitationCode)
.where(
InvitationCode.code == args.invitation_code,
InvitationCode.status == InvitationCodeStatus.UNUSED,
InvitationCode.status == "unused",
)
.first()
)
@@ -224,7 +224,7 @@ class AccountInitApi(Resource):
if not invitation_code:
raise InvalidInvitationCodeError()
invitation_code.status = InvitationCodeStatus.USED
invitation_code.status = "used"
invitation_code.used_at = naive_utc_now()
invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id

View File

@@ -5,7 +5,6 @@ from typing import Any, Literal
from flask import request, send_file
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden
from configs import dify_config
@@ -170,20 +169,6 @@ register_enum_models(
)
def _read_upload_content(file: FileStorage, max_size: int) -> bytes:
"""
Read the uploaded file and validate its actual size before delegating to the plugin service.
FileStorage.content_length is not reliable for multipart test uploads and may be zero even when
content exists, so the controllers validate against the loaded bytes instead.
"""
content = file.read()
if len(content) > max_size:
raise ValueError("File size exceeds the maximum allowed size")
return content
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@setup_required
@@ -299,7 +284,12 @@ class PluginUploadFromPkgApi(Resource):
_, tenant_id = current_account_with_tenant()
file = request.files["pkg"]
content = _read_upload_content(file, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
# check file size
if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE:
raise ValueError("File size exceeds the maximum allowed size")
content = file.read()
try:
response = PluginService.upload_pkg(tenant_id, content)
except PluginDaemonClientSideError as e:
@@ -338,7 +328,12 @@ class PluginUploadFromBundleApi(Resource):
_, tenant_id = current_account_with_tenant()
file = request.files["bundle"]
content = _read_upload_content(file, dify_config.PLUGIN_MAX_BUNDLE_SIZE)
# check file size
if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE:
raise ValueError("File size exceeds the maximum allowed size")
content = file.read()
try:
response = PluginService.upload_bundle(tenant_id, content)
except PluginDaemonClientSideError as e:

View File

@@ -6,13 +6,13 @@ from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.mcp import types as mcp_types
from core.mcp.server.streamable_http import handle_mcp_request
from dify_graph.variables.input_entities import VariableEntity
from extensions.ext_database import db
from libs import helper
from models.enums import AppMCPServerStatus
from models.model import App, AppMCPServer, AppMode, EndUser

View File

@@ -6,7 +6,6 @@ from typing import Any
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
from core.agent.errors import AgentMaxIterationError
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
@@ -23,6 +22,7 @@ from dify_graph.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from dify_graph.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)

View File

@@ -1,9 +0,0 @@
class AgentMaxIterationError(Exception):
"""Raised when an agent runner exceeds the configured max iteration count."""
def __init__(self, max_iteration: int):
self.max_iteration = max_iteration
super().__init__(
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
f"The agent was unable to complete the task within the allowed number of iterations."
)

View File

@@ -5,7 +5,6 @@ from copy import deepcopy
from typing import Any, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.errors import AgentMaxIterationError
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
@@ -26,6 +25,7 @@ from dify_graph.model_runtime.entities import (
UserPromptMessage,
)
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from dify_graph.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)

View File

@@ -69,7 +69,7 @@ from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.nodes import BuiltinNodeTypes
from dify_graph.nodes import NodeType
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
from dify_graph.runtime import GraphRuntimeState
from dify_graph.system_variable import SystemVariable
@@ -357,7 +357,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
) -> Generator[StreamResponse, None, None]:
"""Handle node succeeded events."""
# Record files if it's an answer node or end node
if event.node_type in [BuiltinNodeTypes.ANSWER, BuiltinNodeTypes.END, BuiltinNodeTypes.LLM]:
if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]:
self._recorded_files.extend(
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)

View File

@@ -48,13 +48,12 @@ from core.app.entities.task_entities import (
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import (
BuiltinNodeTypes,
NodeType,
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
@@ -443,7 +442,7 @@ class WorkflowResponseConverter:
event: QueueNodeStartedEvent,
task_id: str,
) -> NodeStartStreamResponse | None:
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._store_snapshot(event)
@@ -465,13 +464,13 @@ class WorkflowResponseConverter:
)
try:
if event.node_type == BuiltinNodeTypes.TOOL:
if event.node_type == NodeType.TOOL:
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=ToolProviderType(event.provider_type),
provider_id=event.provider_id,
)
elif event.node_type == BuiltinNodeTypes.DATASOURCE:
elif event.node_type == NodeType.DATASOURCE:
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(
self._application_generate_entity.app_config.tenant_id,
@@ -480,7 +479,7 @@ class WorkflowResponseConverter:
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
self._application_generate_entity.app_config.tenant_id
)
elif event.node_type == TRIGGER_PLUGIN_NODE_TYPE:
elif event.node_type == NodeType.TRIGGER_PLUGIN:
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
self._application_generate_entity.app_config.tenant_id,
event.provider_id,
@@ -497,7 +496,7 @@ class WorkflowResponseConverter:
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
task_id: str,
) -> NodeFinishStreamResponse | None:
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._pop_snapshot(event.node_execution_id)
@@ -555,7 +554,7 @@ class WorkflowResponseConverter:
event: QueueNodeRetryEvent,
task_id: str,
) -> NodeRetryStreamResponse | None:
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
@@ -613,7 +612,7 @@ class WorkflowResponseConverter:
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type,
node_type=event.node_type.value,
title=event.node_title,
created_at=int(time.time()),
extras={},
@@ -636,7 +635,7 @@ class WorkflowResponseConverter:
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type,
node_type=event.node_type.value,
title=event.node_title,
index=event.index,
created_at=int(time.time()),
@@ -663,7 +662,7 @@ class WorkflowResponseConverter:
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type,
node_type=event.node_type.value,
title=event.node_title,
outputs=new_outputs,
outputs_truncated=outputs_truncated,
@@ -693,7 +692,7 @@ class WorkflowResponseConverter:
data=LoopNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type,
node_type=event.node_type.value,
title=event.node_title,
created_at=int(time.time()),
extras={},
@@ -716,7 +715,7 @@ class WorkflowResponseConverter:
data=LoopNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type,
node_type=event.node_type.value,
title=event.node_title,
index=event.index,
# The `pre_loop_output` field is not utilized by the frontend.
@@ -745,7 +744,7 @@ class WorkflowResponseConverter:
data=LoopNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type,
node_type=event.node_type.value,
title=event.node_title,
outputs=new_outputs,
outputs_truncated=outputs_truncated,

View File

@@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import (
build_dify_run_context,
)
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities.graph_init_params import GraphInitParams
from dify_graph.enums import WorkflowType
@@ -274,8 +274,6 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
if start_node_id is None:
start_node_id = get_default_root_node_id(graph_config)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)
if not graph:

View File

@@ -3,10 +3,7 @@ import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
from pydantic import ValidationError
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.entities.queue_entities import (
AppQueueEvent,
@@ -32,8 +29,8 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter
@@ -140,9 +137,6 @@ class WorkflowBasedAppRunner:
graph_runtime_state=graph_runtime_state,
)
if root_node_id is None:
root_node_id = get_default_root_node_id(graph_config)
# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
@@ -314,7 +308,7 @@ class WorkflowBasedAppRunner:
# Get node class
node_type = target_node_config["data"].type
node_version = str(target_node_config["data"].version)
node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version)
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# Use the variable pool from graph_runtime_state instead of creating a new one
variable_pool = graph_runtime_state.variable_pool
@@ -342,18 +336,6 @@ class WorkflowBasedAppRunner:
return graph, variable_pool
@staticmethod
def _build_agent_strategy_info(event: NodeRunStartedEvent) -> AgentStrategyInfo | None:
raw_agent_strategy = event.extras.get("agent_strategy")
if raw_agent_strategy is None:
return None
try:
return AgentStrategyInfo.model_validate(raw_agent_strategy)
except ValidationError:
logger.warning("Invalid agent strategy payload for node %s", event.node_id, exc_info=True)
return None
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event
@@ -439,7 +421,7 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
agent_strategy=self._build_agent_strategy_info(event),
agent_strategy=event.agent_strategy,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
@@ -508,9 +490,7 @@ class WorkflowBasedAppRunner:
elif isinstance(event, NodeRunRetrieverResourceEvent):
self._publish_event(
QueueRetrieverResourcesEvent(
retriever_resources=[
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
],
retriever_resources=event.retriever_resources,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)

View File

@@ -1,3 +0,0 @@
from .agent_strategy import AgentStrategyInfo
__all__ = ["AgentStrategyInfo"]

View File

@@ -1,8 +0,0 @@
from pydantic import BaseModel, ConfigDict
class AgentStrategyInfo(BaseModel):
name: str
icon: str | None = None
model_config = ConfigDict(extra="forbid")

View File

@@ -5,12 +5,13 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities import AgentNodeStrategyInit
from dify_graph.entities.pause_reason import PauseReason
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.nodes import NodeType
class QueueEvent(StrEnum):
@@ -313,9 +314,9 @@ class QueueNodeStartedEvent(AppQueueEvent):
in_iteration_id: str | None = None
in_loop_id: str | None = None
start_at: datetime
agent_strategy: AgentStrategyInfo | None = None
agent_strategy: AgentNodeStrategyInit | None = None
# FIXME(-LAN-): only for ToolNode, need to refactor
# Legacy provider fields kept for existing start-event consumers.
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
provider_id: str

View File

@@ -4,8 +4,8 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities import AgentNodeStrategyInit
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@@ -349,7 +349,7 @@ class NodeStartStreamResponse(StreamResponse):
extras: dict[str, object] = Field(default_factory=dict)
iteration_id: str | None = None
loop_id: str | None = None
agent_strategy: AgentStrategyInfo | None = None
agent_strategy: AgentNodeStrategyInit | None = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str

View File

@@ -2,7 +2,7 @@ import logging
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.conversation_variable_updater import ConversationVariableUpdater
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import NodeType
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
@@ -22,7 +22,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
def on_event(self, event: GraphEngineEvent) -> None:
if not isinstance(event, NodeRunSucceededEvent):
return
if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER:
if event.node_type != NodeType.VARIABLE_ASSIGNER:
return
if self.graph_runtime_state is None:
return

View File

@@ -12,7 +12,7 @@ from typing_extensions import override
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import NodeType
from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase
@@ -113,11 +113,11 @@ class LLMQuotaLayer(GraphEngineLayer):
def _extract_model_instance(node: Node) -> ModelInstance | None:
try:
match node.node_type:
case BuiltinNodeTypes.LLM:
case NodeType.LLM:
return cast("LLMNode", node).model_instance
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
case NodeType.PARAMETER_EXTRACTOR:
return cast("ParameterExtractorNode", node).model_instance
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
case NodeType.QUESTION_CLASSIFIER:
return cast("QuestionClassifierNode", node).model_instance
case _:
return None

View File

@@ -16,7 +16,7 @@ from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_
from typing_extensions import override
from configs import dify_config
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.enums import NodeType
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphNodeEventBase
from dify_graph.nodes.base.node import Node
@@ -74,13 +74,16 @@ class ObservabilityLayer(GraphEngineLayer):
def _build_parser_registry(self) -> None:
"""Initialize parser registry for node types."""
self._parsers = {
BuiltinNodeTypes.TOOL: ToolNodeOTelParser(),
BuiltinNodeTypes.LLM: LLMNodeOTelParser(),
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(),
NodeType.TOOL: ToolNodeOTelParser(),
NodeType.LLM: LLMNodeOTelParser(),
NodeType.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(),
}
def _get_parser(self, node: Node) -> NodeOTelParser:
return self._parsers.get(node.node_type, self._default_parser)
node_type = getattr(node, "node_type", None)
if isinstance(node_type, NodeType):
return self._parsers.get(node_type, self._default_parser)
return self._default_parser
@override
def on_graph_start(self) -> None:

View File

@@ -24,12 +24,12 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.db.session_factory import session_factory
from core.plugin.impl.datasource import PluginDatasourceManager
from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
from dify_graph.file import File
from dify_graph.file.enums import FileTransferMethod, FileType
from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from dify_graph.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
from factories import file_factory
from models.model import UploadFile
from models.tools import ToolFile

View File

@@ -58,7 +58,7 @@ from core.ops.entities.trace_entity import (
)
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import WorkflowNodeExecutionTriggeredFrom
@@ -302,11 +302,11 @@ class AliyunDataTrace(BaseTraceInstance):
self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata
):
try:
if node_execution.node_type == BuiltinNodeTypes.LLM:
if node_execution.node_type == NodeType.LLM:
node_span = self.build_workflow_llm_span(trace_info, node_execution, trace_metadata)
elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
node_span = self.build_workflow_retrieval_span(trace_info, node_execution, trace_metadata)
elif node_execution.node_type == BuiltinNodeTypes.TOOL:
elif node_execution.node_type == NodeType.TOOL:
node_span = self.build_workflow_tool_span(trace_info, node_execution, trace_metadata)
else:
node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata)

View File

@@ -155,8 +155,8 @@ def wrap_span_metadata(metadata, **kwargs):
return metadata
# Mapping from built-in node type strings to OpenInference span kinds.
# Node types not listed here default to CHAIN.
# Mapping from NodeType string values to OpenInference span kinds.
# NodeType values not listed here default to CHAIN.
_NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
"llm": OpenInferenceSpanKindValues.LLM,
"knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER,
@@ -168,7 +168,7 @@ _NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
"""Return the OpenInference span kind for a given workflow node type.
Covers every built-in node type string. Nodes that do not have a
Covers every ``NodeType`` enum value. Nodes that do not have a
specialised span kind (e.g. ``start``, ``end``, ``if-else``,
``code``, ``loop``, ``iteration``, etc.) are mapped to ``CHAIN``.
"""

View File

@@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
)
from core.ops.utils import filter_none_values
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import NodeType
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
from models.enums import MessageStatus
@@ -141,7 +141,7 @@ class LangFuseDataTrace(BaseTraceInstance):
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == BuiltinNodeTypes.LLM:
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}

View File

@@ -28,7 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@@ -163,7 +163,7 @@ class LangSmithDataTrace(BaseTraceInstance):
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == BuiltinNodeTypes.LLM:
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}
@@ -197,7 +197,7 @@ class LangSmithDataTrace(BaseTraceInstance):
"ls_model_name": process_data.get("model_name", ""),
}
)
elif node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
run_type = LangSmithRunType.retriever
else:
run_type = LangSmithRunType.tool

View File

@@ -23,7 +23,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import NodeType
from extensions.ext_database import db
from models import EndUser
from models.workflow import WorkflowNodeExecutionModel
@@ -145,10 +145,10 @@ class MLflowDataTrace(BaseTraceInstance):
"app_name": node.title,
}
if node.node_type in (BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER):
if node.node_type in (NodeType.LLM, NodeType.QUESTION_CLASSIFIER):
inputs, llm_attributes = self._parse_llm_inputs_and_attributes(node)
attributes.update(llm_attributes)
elif node.node_type == BuiltinNodeTypes.HTTP_REQUEST:
elif node.node_type == NodeType.HTTP_REQUEST:
inputs = node.process_data # contains request URL
if not inputs:
@@ -180,9 +180,9 @@ class MLflowDataTrace(BaseTraceInstance):
# End node span
finished_at = node.created_at + timedelta(seconds=node.elapsed_time)
outputs = json.loads(node.outputs) if node.outputs else {}
if node.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
if node.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
outputs = self._parse_knowledge_retrieval_outputs(outputs)
elif node.node_type == BuiltinNodeTypes.LLM:
elif node.node_type == NodeType.LLM:
outputs = outputs.get("text", outputs)
node_span.end(
outputs=outputs,
@@ -471,13 +471,13 @@ class MLflowDataTrace(BaseTraceInstance):
def _get_node_span_type(self, node_type: str) -> str:
"""Map Dify node types to MLflow span types"""
node_type_mapping = {
BuiltinNodeTypes.LLM: SpanType.LLM,
BuiltinNodeTypes.QUESTION_CLASSIFIER: SpanType.LLM,
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER,
BuiltinNodeTypes.TOOL: SpanType.TOOL,
BuiltinNodeTypes.CODE: SpanType.TOOL,
BuiltinNodeTypes.HTTP_REQUEST: SpanType.TOOL,
BuiltinNodeTypes.AGENT: SpanType.AGENT,
NodeType.LLM: SpanType.LLM,
NodeType.QUESTION_CLASSIFIER: SpanType.LLM,
NodeType.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER,
NodeType.TOOL: SpanType.TOOL,
NodeType.CODE: SpanType.TOOL,
NodeType.HTTP_REQUEST: SpanType.TOOL,
NodeType.AGENT: SpanType.AGENT,
}
return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload]

View File

@@ -23,7 +23,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@@ -187,7 +187,7 @@ class OpikDataTrace(BaseTraceInstance):
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == BuiltinNodeTypes.LLM:
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}

View File

@@ -27,7 +27,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from dify_graph.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from dify_graph.nodes import BuiltinNodeTypes
from dify_graph.nodes import NodeType
from extensions.ext_database import db
from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
@@ -179,7 +179,7 @@ class TencentDataTrace(BaseTraceInstance):
if node_span:
self.trace_client.add_span(node_span)
if node_execution.node_type == BuiltinNodeTypes.LLM:
if node_execution.node_type == NodeType.LLM:
self._record_llm_metrics(node_execution)
except Exception:
logger.exception("[Tencent APM] Failed to process node execution: %s", node_execution.id)
@@ -192,15 +192,15 @@ class TencentDataTrace(BaseTraceInstance):
) -> SpanData | None:
"""Build span for different node types"""
try:
if node_execution.node_type == BuiltinNodeTypes.LLM:
if node_execution.node_type == NodeType.LLM:
return TencentSpanBuilder.build_workflow_llm_span(
trace_id, workflow_span_id, trace_info, node_execution
)
elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return TencentSpanBuilder.build_workflow_retrieval_span(
trace_id, workflow_span_id, trace_info, node_execution
)
elif node_execution.node_type == BuiltinNodeTypes.TOOL:
elif node_execution.node_type == NodeType.TOOL:
return TencentSpanBuilder.build_workflow_tool_span(
trace_id, workflow_span_id, trace_info, node_execution
)

View File

@@ -31,7 +31,7 @@ from core.ops.entities.trace_entity import (
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@@ -175,7 +175,7 @@ class WeaveDataTrace(BaseTraceInstance):
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == BuiltinNodeTypes.LLM:
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}

View File

@@ -1,5 +1,5 @@
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import NodeType
from dify_graph.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig,
)
@@ -52,7 +52,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
instruction=instruction, # instruct with variables are not supported
)
node_data_dict = node_data.model_dump()
node_data_dict["type"] = BuiltinNodeTypes.PARAMETER_EXTRACTOR
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR
execution = workflow_service.run_free_workflow_node(
node_data_dict,
tenant_id=tenant_id,

View File

@@ -1,361 +0,0 @@
import json
import logging
import time
from typing import Any
import holo_search_sdk as holo # type: ignore
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from psycopg import sql as psql
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class HologresVectorConfig(BaseModel):
"""
Configuration for Hologres vector database connection.
In Hologres, access_key_id is used as the PostgreSQL username,
and access_key_secret is used as the PostgreSQL password.
"""
host: str
port: int = 80
database: str
access_key_id: str
access_key_secret: str
schema_name: str = "public"
tokenizer: TokenizerType = "jieba"
distance_method: DistanceType = "Cosine"
base_quantization_type: BaseQuantizationType = "rabitq"
max_degree: int = 64
ef_construction: int = 400
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values.get("host"):
raise ValueError("config HOLOGRES_HOST is required")
if not values.get("database"):
raise ValueError("config HOLOGRES_DATABASE is required")
if not values.get("access_key_id"):
raise ValueError("config HOLOGRES_ACCESS_KEY_ID is required")
if not values.get("access_key_secret"):
raise ValueError("config HOLOGRES_ACCESS_KEY_SECRET is required")
return values
class HologresVector(BaseVector):
"""
Hologres vector storage implementation using holo-search-sdk.
Supports semantic search (vector), full-text search, and hybrid search.
"""
def __init__(self, collection_name: str, config: HologresVectorConfig):
super().__init__(collection_name)
self._config = config
self._client = self._init_client(config)
self.table_name = f"embedding_{collection_name}".lower()
def _init_client(self, config: HologresVectorConfig):
"""Initialize and return a holo-search-sdk client."""
client = holo.connect(
host=config.host,
port=config.port,
database=config.database,
access_key_id=config.access_key_id,
access_key_secret=config.access_key_secret,
schema=config.schema_name,
)
client.connect()
return client
def get_type(self) -> str:
return VectorType.HOLOGRES
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""Create collection table with vector and full-text indexes, then add texts."""
dimension = len(embeddings[0])
self._create_collection(dimension)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""Add texts with embeddings to the collection using batch upsert."""
if not documents:
return []
pks: list[str] = []
batch_size = 100
for i in range(0, len(documents), batch_size):
batch_docs = documents[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size]
values = []
column_names = ["id", "text", "meta", "embedding"]
for j, doc in enumerate(batch_docs):
doc_id = doc.metadata.get("doc_id", "") if doc.metadata else ""
pks.append(doc_id)
values.append(
[
doc_id,
doc.page_content,
json.dumps(doc.metadata or {}),
batch_embeddings[j],
]
)
table = self._client.open_table(self.table_name)
table.upsert_multi(
index_column="id",
values=values,
column_names=column_names,
update=True,
update_columns=["text", "meta", "embedding"],
)
return pks
def text_exists(self, id: str) -> bool:
"""Check if a text with the given doc_id exists in the collection."""
if not self._client.check_table_exist(self.table_name):
return False
result = self._client.execute(
psql.SQL("SELECT 1 FROM {} WHERE id = {} LIMIT 1").format(
psql.Identifier(self.table_name), psql.Literal(id)
),
fetch_result=True,
)
return bool(result)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str] | None:
"""Get document IDs by metadata field key and value."""
result = self._client.execute(
psql.SQL("SELECT id FROM {} WHERE meta->>{} = {}").format(
psql.Identifier(self.table_name), psql.Literal(key), psql.Literal(value)
),
fetch_result=True,
)
if result:
return [row[0] for row in result]
return None
def delete_by_ids(self, ids: list[str]):
"""Delete documents by their doc_id list."""
if not ids:
return
if not self._client.check_table_exist(self.table_name):
return
self._client.execute(
psql.SQL("DELETE FROM {} WHERE id IN ({})").format(
psql.Identifier(self.table_name),
psql.SQL(", ").join(psql.Literal(id) for id in ids),
)
)
def delete_by_metadata_field(self, key: str, value: str):
"""Delete documents by metadata field key and value."""
if not self._client.check_table_exist(self.table_name):
return
self._client.execute(
psql.SQL("DELETE FROM {} WHERE meta->>{} = {}").format(
psql.Identifier(self.table_name), psql.Literal(key), psql.Literal(value)
)
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Search for documents by vector similarity."""
if not self._client.check_table_exist(self.table_name):
return []
top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
table = self._client.open_table(self.table_name)
query = (
table.search_vector(
vector=query_vector,
column="embedding",
distance_method=self._config.distance_method,
output_name="distance",
)
.select(["id", "text", "meta"])
.limit(top_k)
)
# Apply document_ids_filter if provided
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filter_sql = psql.SQL("meta->>'document_id' IN ({})").format(
psql.SQL(", ").join(psql.Literal(id) for id in document_ids_filter)
)
query = query.where(filter_sql)
results = query.fetchall()
return self._process_vector_results(results, score_threshold)
def _process_vector_results(self, results: list, score_threshold: float) -> list[Document]:
"""Process vector search results into Document objects."""
docs = []
for row in results:
# row format: (distance, id, text, meta)
# distance is first because search_vector() adds the computed column before selected columns
distance = row[0]
text = row[2]
meta = row[3]
if isinstance(meta, str):
meta = json.loads(meta)
# Convert distance to similarity score (consistent with pgvector)
score = 1 - distance
meta["score"] = score
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=meta))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Search for documents by full-text search."""
if not self._client.check_table_exist(self.table_name):
return []
top_k = kwargs.get("top_k", 4)
table = self._client.open_table(self.table_name)
search_query = table.search_text(
column="text",
expression=query,
return_score=True,
return_score_name="score",
return_all_columns=True,
).limit(top_k)
# Apply document_ids_filter if provided
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filter_sql = psql.SQL("meta->>'document_id' IN ({})").format(
psql.SQL(", ").join(psql.Literal(id) for id in document_ids_filter)
)
search_query = search_query.where(filter_sql)
results = search_query.fetchall()
return self._process_full_text_results(results)
def _process_full_text_results(self, results: list) -> list[Document]:
"""Process full-text search results into Document objects."""
docs = []
for row in results:
# row format: (id, text, meta, embedding, score)
text = row[1]
meta = row[2]
score = row[-1] # score is the last column from return_score
if isinstance(meta, str):
meta = json.loads(meta)
meta["score"] = score
docs.append(Document(page_content=text, metadata=meta))
return docs
def delete(self):
"""Delete the entire collection table."""
if self._client.check_table_exist(self.table_name):
self._client.drop_table(self.table_name)
def _create_collection(self, dimension: int):
"""Create the collection table with vector and full-text indexes."""
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
if not self._client.check_table_exist(self.table_name):
# Create table via SQL with CHECK constraint for vector dimension
create_table_sql = psql.SQL("""
CREATE TABLE IF NOT EXISTS {} (
id TEXT PRIMARY KEY,
text TEXT NOT NULL,
meta JSONB NOT NULL,
embedding float4[] NOT NULL
CHECK (array_ndims(embedding) = 1
AND array_length(embedding, 1) = {})
);
""").format(psql.Identifier(self.table_name), psql.Literal(dimension))
self._client.execute(create_table_sql)
# Wait for table to be fully ready before creating indexes
max_wait_seconds = 30
poll_interval = 2
for _ in range(max_wait_seconds // poll_interval):
if self._client.check_table_exist(self.table_name):
break
time.sleep(poll_interval)
else:
raise RuntimeError(f"Table {self.table_name} was not ready after {max_wait_seconds}s")
# Open table and set vector index
table = self._client.open_table(self.table_name)
table.set_vector_index(
column="embedding",
distance_method=self._config.distance_method,
base_quantization_type=self._config.base_quantization_type,
max_degree=self._config.max_degree,
ef_construction=self._config.ef_construction,
use_reorder=self._config.base_quantization_type == "rabitq",
)
# Create full-text search index
table.create_text_index(
index_name=f"ft_idx_{self._collection_name}",
column="text",
tokenizer=self._config.tokenizer,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class HologresVectorFactory(AbstractVectorFactory):
"""Factory class for creating HologresVector instances."""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> HologresVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.HOLOGRES, collection_name))
return HologresVector(
collection_name=collection_name,
config=HologresVectorConfig(
host=dify_config.HOLOGRES_HOST or "",
port=dify_config.HOLOGRES_PORT,
database=dify_config.HOLOGRES_DATABASE or "",
access_key_id=dify_config.HOLOGRES_ACCESS_KEY_ID or "",
access_key_secret=dify_config.HOLOGRES_ACCESS_KEY_SECRET or "",
schema_name=dify_config.HOLOGRES_SCHEMA,
tokenizer=dify_config.HOLOGRES_TOKENIZER,
distance_method=dify_config.HOLOGRES_DISTANCE_METHOD,
base_quantization_type=dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE,
max_degree=dify_config.HOLOGRES_MAX_DEGREE,
ef_construction=dify_config.HOLOGRES_EF_CONSTRUCTION,
),
)

View File

@@ -38,7 +38,7 @@ class AbstractVectorFactory(ABC):
class Vector:
def __init__(self, dataset: Dataset, attributes: list | None = None):
if attributes is None:
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes
@@ -191,10 +191,6 @@ class Vector:
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
return IrisVectorFactory
case VectorType.HOLOGRES:
from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory
return HologresVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@@ -34,4 +34,3 @@ class VectorType(StrEnum):
MATRIXONE = "matrixone"
CLICKZETTA = "clickzetta"
IRIS = "iris"
HOLOGRES = "hologres"

View File

@@ -196,7 +196,6 @@ class WeaviateVector(BaseVector):
),
wc.Property(name="document_id", data_type=wc.DataType.TEXT),
wc.Property(name="doc_id", data_type=wc.DataType.TEXT),
wc.Property(name="doc_type", data_type=wc.DataType.TEXT),
wc.Property(name="chunk_index", data_type=wc.DataType.INT),
],
vector_config=wc.Configure.Vectors.self_provided(),
@@ -226,8 +225,6 @@ class WeaviateVector(BaseVector):
to_add.append(wc.Property(name="document_id", data_type=wc.DataType.TEXT))
if "doc_id" not in existing:
to_add.append(wc.Property(name="doc_id", data_type=wc.DataType.TEXT))
if "doc_type" not in existing:
to_add.append(wc.Property(name="doc_type", data_type=wc.DataType.TEXT))
if "chunk_index" not in existing:
to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT))

View File

@@ -9,8 +9,8 @@ from flask import current_app
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
from dify_graph.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from dify_graph.repositories.index_processor_protocol import Preview, PreviewItem, QaPreview
from models.dataset import Dataset, Document, DocumentSegment
from .index_processor_factory import IndexProcessorFactory

View File

@@ -56,18 +56,18 @@ from core.rag.retrieval.template_prompts import (
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.nodes.knowledge_retrieval.retrieval import (
KnowledgeRetrievalRequest,
Source,
SourceChildChunk,
SourceMetadata,
)
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.knowledge_retrieval import exc
from dify_graph.repositories.rag_retrieval_protocol import (
KnowledgeRetrievalRequest,
Source,
SourceChildChunk,
SourceMetadata,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown

View File

@@ -18,7 +18,7 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att
from configs import dify_config
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
@@ -146,7 +146,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
index=db_model.index,
predecessor_node_id=db_model.predecessor_node_id,
node_id=db_model.node_id,
node_type=db_model.node_type,
node_type=NodeType(db_model.node_type),
title=db_model.title,
inputs=inputs,
process_data=process_data,

View File

@@ -116,7 +116,6 @@ class ToolParameterConfigurationManager:
return a deep copy of parameters with decrypted values
"""
parameters = self._deep_copy(parameters)
cache = ToolParameterCache(
tenant_id=self.tenant_id,

View File

@@ -3,7 +3,7 @@ from typing import Any
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import NodeType
from dify_graph.nodes.base.entities import OutputVariableEntity
from dify_graph.variables.input_entities import VariableEntity
@@ -51,7 +51,7 @@ class WorkflowToolConfigurationUtils:
def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None:
nodes = graph.get("nodes", [])
for node in nodes:
if node.get("data", {}).get("type") == BuiltinNodeTypes.HUMAN_INPUT:
if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT:
raise WorkflowToolHumanInputNotSupportedError()
@classmethod

View File

@@ -1,18 +0,0 @@
from typing import Final
TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook"
TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule"
TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin"
TRIGGER_INFO_METADATA_KEY: Final[str] = "trigger_info"
TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset(
{
TRIGGER_WEBHOOK_NODE_TYPE,
TRIGGER_SCHEDULE_NODE_TYPE,
TRIGGER_PLUGIN_NODE_TYPE,
}
)
def is_trigger_node_type(node_type: str) -> bool:
return node_type in TRIGGER_NODE_TYPES

View File

@@ -11,11 +11,6 @@ from typing import Any
from pydantic import BaseModel
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.trigger.constants import (
TRIGGER_PLUGIN_NODE_TYPE,
TRIGGER_SCHEDULE_NODE_TYPE,
TRIGGER_WEBHOOK_NODE_TYPE,
)
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import (
PluginTriggerDebugEvent,
@@ -27,6 +22,7 @@ from core.trigger.debug.events import (
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType
from extensions.ext_redis import redis_client
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
from libs.schedule_utils import calculate_next_run_at
@@ -210,19 +206,21 @@ def create_event_poller(
if not node_config:
raise ValueError("Node data not found for node %s", node_id)
node_type = draft_workflow.get_node_type_from_node_config(node_config)
if node_type == TRIGGER_PLUGIN_NODE_TYPE:
return PluginTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
if node_type == TRIGGER_WEBHOOK_NODE_TYPE:
return WebhookTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
if node_type == TRIGGER_SCHEDULE_NODE_TYPE:
return ScheduleTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
raise ValueError("unable to create event poller for node type %s", node_type)
match node_type:
case NodeType.TRIGGER_PLUGIN:
return PluginTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
case NodeType.TRIGGER_WEBHOOK:
return WebhookTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
case NodeType.TRIGGER_SCHEDULE:
return ScheduleTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
case _:
raise ValueError("unable to create event poller for node type %s", node_type)
def select_trigger_debug_events(

View File

@@ -1 +1,4 @@
"""Core workflow package."""
from .node_factory import DifyNodeFactory
from .workflow_entry import WorkflowEntry
__all__ = ["DifyNodeFactory", "WorkflowEntry"]

View File

@@ -1,7 +1,4 @@
import importlib
import pkgutil
from collections.abc import Callable, Iterator, Mapping, MutableMapping
from functools import lru_cache
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, TypeAlias, cast, final
from sqlalchemy import select
@@ -11,6 +8,7 @@ from typing_extensions import override
from configs import dify_config
from core.app.entities.app_invoke_entities import DifyRunContext
from core.app.llm.model_access import build_dify_model_access
from core.datasource.datasource_manager import DatasourceManager
from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
@@ -19,19 +17,16 @@ from core.helper.ssrf_proxy import ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.index_processor.index_processor import IndexProcessor
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.summary_index.summary_index import SummaryIndex
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.tools.tool_file_manager import ToolFileManager
from core.trigger.constants import TRIGGER_NODE_TYPES
from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer
from core.workflow.nodes.agent.plugin_strategy_adapter import (
PluginAgentStrategyPresentationProvider,
PluginAgentStrategyResolver,
)
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey
from dify_graph.enums import NodeType, SystemVariableKey
from dify_graph.file.file_manager import file_manager
from dify_graph.graph.graph import NodeFactory
from dify_graph.model_runtime.entities.model_entities import ModelType
@@ -58,135 +53,6 @@ if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState
LATEST_VERSION = "latest"
_START_NODE_TYPES: frozenset[NodeType] = frozenset(
(BuiltinNodeTypes.START, BuiltinNodeTypes.DATASOURCE, *TRIGGER_NODE_TYPES)
)
def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] = frozenset()) -> None:
package = importlib.import_module(package_name)
for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."):
if module_name in excluded_modules:
continue
importlib.import_module(module_name)
@lru_cache(maxsize=1)
def register_nodes() -> None:
"""Import production node modules so they self-register with ``Node``."""
_import_node_package("dify_graph.nodes")
_import_node_package("core.workflow.nodes")
def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]:
"""Return a read-only snapshot of the current production node registry.
The workflow layer owns node bootstrap because it must compose built-in
`dify_graph.nodes.*` implementations with workflow-local nodes under
`core.workflow.nodes.*`. Keeping this import side effect here avoids
reintroducing registry bootstrapping into lower-level graph primitives.
"""
register_nodes()
return Node.get_node_type_classes_mapping()
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
node_mapping = get_node_type_classes_mapping().get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
latest_node_class = node_mapping.get(LATEST_VERSION)
matched_node_class = node_mapping.get(node_version)
node_class = matched_node_class or latest_node_class
if not node_class:
raise ValueError(f"No latest version class found for node type: {node_type}")
return node_class
def is_start_node_type(node_type: NodeType) -> bool:
"""Return True when the node type can serve as a workflow entry point."""
return node_type in _START_NODE_TYPES
def get_default_root_node_id(graph_config: Mapping[str, Any]) -> str:
"""Resolve the default entry node for a persisted top-level workflow graph.
This workflow-layer helper depends on start-node semantics defined by
`is_start_node_type`, so it intentionally lives next to the node registry
instead of in the raw `dify_graph.entities.graph_config` schema module.
"""
nodes = graph_config.get("nodes")
if not isinstance(nodes, list):
raise ValueError("nodes in workflow graph must be a list")
for node in nodes:
if not isinstance(node, Mapping):
continue
if node.get("type") == "custom-note":
continue
node_id = node.get("id")
data = node.get("data")
if not isinstance(node_id, str) or not isinstance(data, Mapping):
continue
node_type = data.get("type")
if isinstance(node_type, str) and is_start_node_type(node_type):
return node_id
raise ValueError("Unable to determine default root node ID from workflow graph")
class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Node]]]):
"""Mutable dict-like view over the current node registry."""
def __init__(self) -> None:
self._cached_snapshot: dict[NodeType, Mapping[str, type[Node]]] = {}
self._cached_version = -1
self._deleted: set[NodeType] = set()
self._overrides: dict[NodeType, Mapping[str, type[Node]]] = {}
def _snapshot(self) -> dict[NodeType, Mapping[str, type[Node]]]:
current_version = Node.get_registry_version()
if self._cached_version != current_version:
self._cached_snapshot = dict(get_node_type_classes_mapping())
self._cached_version = current_version
if not self._deleted and not self._overrides:
return self._cached_snapshot
snapshot = {key: value for key, value in self._cached_snapshot.items() if key not in self._deleted}
snapshot.update(self._overrides)
return snapshot
def __getitem__(self, key: NodeType) -> Mapping[str, type[Node]]:
return self._snapshot()[key]
def __setitem__(self, key: NodeType, value: Mapping[str, type[Node]]) -> None:
self._deleted.discard(key)
self._overrides[key] = value
def __delitem__(self, key: NodeType) -> None:
if key in self._overrides:
del self._overrides[key]
return
if key in self._cached_snapshot:
self._deleted.add(key)
return
raise KeyError(key)
def __iter__(self) -> Iterator[NodeType]:
return iter(self._snapshot())
def __len__(self) -> int:
return len(self._snapshot())
# Keep the canonical node-class mapping in the workflow layer that also bootstraps
# legacy `core.workflow.nodes.*` registrations.
NODE_TYPE_CLASSES_MAPPING: MutableMapping[NodeType, Mapping[str, type[Node]]] = _LazyNodeTypeClassesMapping()
LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
@@ -231,7 +97,10 @@ class DefaultWorkflowCodeExecutor:
@final
class DifyNodeFactory(NodeFactory):
"""
Default implementation of NodeFactory that resolves node classes from the live registry.
Default implementation of NodeFactory that uses the traditional node mapping.
This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
and instantiating the appropriate node class.
"""
def __init__(
@@ -258,6 +127,7 @@ class DifyNodeFactory(NodeFactory):
self._http_request_http_client = ssrf_proxy
self._http_request_tool_file_manager_factory = ToolFileManager
self._http_request_file_manager = file_manager
self._rag_retrieval = DatasetRetrieval()
self._document_extractor_unstructured_api_config = UnstructuredApiConfig(
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY or "",
@@ -273,10 +143,6 @@ class DifyNodeFactory(NodeFactory):
)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)
self._agent_strategy_resolver = PluginAgentStrategyResolver()
self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider()
self._agent_runtime_support = AgentRuntimeSupport()
self._agent_message_transformer = AgentMessageTransformer()
@staticmethod
def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext:
@@ -304,51 +170,55 @@ class DifyNodeFactory(NodeFactory):
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
NodeType.CODE: lambda: {
"code_executor": self._code_executor,
"code_limits": self._code_limits,
},
BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: {
NodeType.TEMPLATE_TRANSFORM: lambda: {
"template_renderer": self._template_renderer,
"max_output_length": self._template_transform_max_output_length,
},
BuiltinNodeTypes.HTTP_REQUEST: lambda: {
NodeType.HTTP_REQUEST: lambda: {
"http_request_config": self._http_request_config,
"http_client": self._http_request_http_client,
"tool_file_manager_factory": self._http_request_tool_file_manager_factory,
"file_manager": self._http_request_file_manager,
},
BuiltinNodeTypes.HUMAN_INPUT: lambda: {
NodeType.HUMAN_INPUT: lambda: {
"form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
},
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
NodeType.KNOWLEDGE_INDEX: lambda: {
"index_processor": IndexProcessor(),
"summary_index_service": SummaryIndex(),
},
NodeType.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=True,
),
BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: {
NodeType.DATASOURCE: lambda: {
"datasource_manager": DatasourceManager,
},
NodeType.KNOWLEDGE_RETRIEVAL: lambda: {
"rag_retrieval": self._rag_retrieval,
},
NodeType.DOCUMENT_EXTRACTOR: lambda: {
"unstructured_api_config": self._document_extractor_unstructured_api_config,
"http_client": self._http_request_http_client,
},
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
NodeType.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=True,
),
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
NodeType.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=False,
),
BuiltinNodeTypes.TOOL: lambda: {
NodeType.TOOL: lambda: {
"tool_file_manager_factory": self._http_request_tool_file_manager_factory(),
},
BuiltinNodeTypes.AGENT: lambda: {
"strategy_resolver": self._agent_strategy_resolver,
"presentation_provider": self._agent_strategy_presentation_provider,
"runtime_support": self._agent_runtime_support,
"message_transformer": self._agent_message_transformer,
},
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
@@ -368,7 +238,16 @@ class DifyNodeFactory(NodeFactory):
@staticmethod
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
return resolve_workflow_node_class(node_type=node_type, node_version=node_version)
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
latest_node_class = node_mapping.get(LATEST_VERSION)
matched_node_class = node_mapping.get(node_version)
node_class = matched_node_class or latest_node_class
if not node_class:
raise ValueError(f"No latest version class found for node type: {node_type}")
return node_class
def _build_llm_compatible_node_init_kwargs(
self,

View File

@@ -1 +1 @@
"""Workflow node implementations that remain under the legacy core.workflow namespace."""
"""Core-owned workflow node packages."""

View File

@@ -1,4 +0,0 @@
from .agent_node import AgentNode
from .entities import AgentNodeData
__all__ = ["AgentNode", "AgentNodeData"]

View File

@@ -1,188 +0,0 @@
from __future__ import annotations
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from .entities import AgentNodeData
from .exceptions import (
AgentInvocationError,
AgentMessageTransformError,
)
from .message_transformer import AgentMessageTransformer
from .runtime_support import AgentRuntimeSupport
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState
class AgentNode(Node[AgentNodeData]):
node_type = BuiltinNodeTypes.AGENT
_strategy_resolver: AgentStrategyResolver
_presentation_provider: AgentStrategyPresentationProvider
_runtime_support: AgentRuntimeSupport
_message_transformer: AgentMessageTransformer
def __init__(
self,
id: str,
config: NodeConfigDict,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
strategy_resolver: AgentStrategyResolver,
presentation_provider: AgentStrategyPresentationProvider,
runtime_support: AgentRuntimeSupport,
message_transformer: AgentMessageTransformer,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._strategy_resolver = strategy_resolver
self._presentation_provider = presentation_provider
self._runtime_support = runtime_support
self._message_transformer = message_transformer
@classmethod
def version(cls) -> str:
return "1"
def populate_start_event(self, event) -> None:
dify_ctx = self.require_dify_context()
event.extras["agent_strategy"] = {
"name": self.node_data.agent_strategy_name,
"icon": self._presentation_provider.get_icon(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
),
}
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
dify_ctx = self.require_dify_context()
try:
strategy = self._strategy_resolver.resolve(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
agent_strategy_name=self.node_data.agent_strategy_name,
)
except Exception as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error=f"Failed to get agent strategy: {str(e)}",
),
)
return
agent_parameters = strategy.get_parameters()
parameters = self._runtime_support.build_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
strategy=strategy,
tenant_id=dify_ctx.tenant_id,
app_id=dify_ctx.app_id,
invoke_from=dify_ctx.invoke_from,
)
parameters_for_log = self._runtime_support.build_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
strategy=strategy,
tenant_id=dify_ctx.tenant_id,
app_id=dify_ctx.app_id,
invoke_from=dify_ctx.invoke_from,
for_log=True,
)
credentials = self._runtime_support.build_credentials(parameters=parameters)
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
message_stream = strategy.invoke(
params=parameters,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
)
except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(error),
)
)
return
try:
yield from self._message_transformer.transform(
messages=message_stream,
tool_info={
"icon": self._presentation_provider.get_icon(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
),
"agent_strategy": self.node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
f"Failed to transform agent message: {str(e)}", original_error=e
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(transform_error),
)
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AgentNodeData,
) -> Mapping[str, Sequence[str]]:
_ = graph_config # Explicitly mark as unused
result: dict[str, Any] = {}
typed_node_data = node_data
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
match input.type:
case "mixed" | "constant":
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
result = {node_id + "." + key: value for key, value in result.items()}
return result

View File

@@ -1,292 +0,0 @@
from __future__ import annotations
from collections.abc import Generator, Mapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.file import File, FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
AgentLogEvent,
NodeEventBase,
NodeRunResult,
StreamChunkEvent,
StreamCompletedEvent,
)
from dify_graph.variables.segments import ArrayFileSegment
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError
class AgentMessageTransformer:
def transform(
self,
*,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
node_type: NodeType,
node_id: str,
node_execution_id: str,
) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json_list: list[dict | list] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage = LLMUsage.empty_usage()
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == BuiltinNodeTypes.AGENT:
if isinstance(message.message.json_object, dict):
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
else:
llm_usage = LLMUsage.empty_usage()
agent_execution_metadata = {}
if message.message.json_object:
json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise AgentVariableTypeError(
"When 'stream' is True, 'variable_value' must be a string.",
variable_name=variable_name,
expected_type="str",
actual_type=type(variable_value).__name__,
)
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, dict)
if "file" not in message.meta:
raise AgentNodeError("File message is missing 'file' key in meta")
if not isinstance(message.meta["file"], File):
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
message_id=message.message.id,
node_execution_id=node_execution_id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=node_id,
)
for log in agent_logs:
if log.message_id == agent_log.message_id:
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
json_output: list[dict[str, Any] | list[Any]] = []
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.message_id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
if json_list:
json_output.extend(json_list)
else:
json_output.append({"data": []})
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk="",
is_final=True,
)
for var_name in variables:
yield StreamChunkEvent(
selector=[node_id, var_name],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"text": text,
"usage": jsonable_encoder(llm_usage),
"files": ArrayFileSegment(value=files),
"json": json_output,
**variables,
},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
llm_usage=llm_usage,
)
)

View File

@@ -1,40 +0,0 @@
from __future__ import annotations
from factories.agent_factory import get_plugin_agent_strategy
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver, ResolvedAgentStrategy
class PluginAgentStrategyResolver(AgentStrategyResolver):
def resolve(
self,
*,
tenant_id: str,
agent_strategy_provider_name: str,
agent_strategy_name: str,
) -> ResolvedAgentStrategy:
return get_plugin_agent_strategy(
tenant_id=tenant_id,
agent_strategy_provider_name=agent_strategy_provider_name,
agent_strategy_name=agent_strategy_name,
)
class PluginAgentStrategyPresentationProvider(AgentStrategyPresentationProvider):
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None:
from core.plugin.impl.plugin import PluginInstaller
manager = PluginInstaller()
try:
plugins = manager.list_plugins(tenant_id)
except Exception:
return None
try:
current_plugin = next(
plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == agent_strategy_provider_name
)
except StopIteration:
return None
return current_plugin.declaration.icon

View File

@@ -1,276 +0,0 @@
from __future__ import annotations
import json
from collections.abc import Sequence
from typing import Any, cast
from packaging.version import Version
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.plugin.entities.request import InvokeCredentials
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager
from dify_graph.enums import SystemVariableKey
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.runtime import VariablePool
from dify_graph.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
from .strategy_protocols import ResolvedAgentStrategy
class AgentRuntimeSupport:
def build_parameters(
self,
*,
agent_parameters: Sequence[AgentStrategyParameter],
variable_pool: VariablePool,
node_data: AgentNodeData,
strategy: ResolvedAgentStrategy,
tenant_id: str,
app_id: str,
invoke_from: Any,
for_log: bool = False,
) -> dict[str, Any]:
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
agent_input = node_data.agent_parameters[parameter_name]
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
case "mixed" | "constant":
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
case _:
raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
value = self._filter_mcp_type_tool(strategy, value)
for tool in value:
if "schemas" in tool:
tool.pop("schemas")
parameters = tool.get("parameters", {})
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
tool["parameters"] = parameters
if not for_log:
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
parameters = {**parameters, **setting_params}
entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""),
provider_type=provider_type,
tool_name=tool.get("tool_name", ""),
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
)
extra = tool.get("extra", {})
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
runtime_variable_pool = variable_pool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id,
app_id,
entity,
invoke_from,
runtime_variable_pool,
)
if tool_runtime.entity.description:
tool_runtime.entity.description.llm = (
extra.get("description", "") or tool_runtime.entity.description.llm
)
for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = (
ToolParameter.ToolParameterForm.FORM
if tool_runtime_params.name in manual_input_params
else tool_runtime_params.form
)
manual_input_value = {}
if tool_runtime.entity.parameters:
manual_input_value = {
key: value for key, value in parameters.items() if key in manual_input_params
}
runtime_parameters = {
**tool_runtime.runtime.runtime_parameters,
**manual_input_value,
}
tool_value.append(
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"provider_type": provider_type.value,
}
)
value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
history_prompt_messages = []
if node_data.memory:
memory = self.fetch_memory(
variable_pool=variable_pool,
app_id=app_id,
model_instance=model_instance,
)
if memory:
prompt_messages = memory.get_history_prompt_messages(
message_limit=node_data.memory.window.size or None
)
history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
]
value["history_prompt_messages"] = history_prompt_messages
if model_schema:
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
value["entity"] = model_schema.model_dump(mode="json")
else:
value["entity"] = None
result[parameter_name] = value
return result
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
credentials = InvokeCredentials()
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
if not tool.get("credential_id"):
continue
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
except ValidationError:
continue
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
return credentials
def fetch_memory(
self,
*,
variable_pool: VariablePool,
app_id: str,
model_instance: ModelInstance,
) -> TokenBufferMemory | None:
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=value.get("provider", ""),
model_type=ModelType.LLM,
)
model_name = value.get("model", "")
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=model_name,
)
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model=model_name,
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_instance, model_schema
@staticmethod
def _remove_unsupported_model_features_for_old_version(model_schema: AIModelEntity) -> AIModelEntity:
if model_schema.features:
for feature in model_schema.features[:]:
try:
AgentOldVersionModelFeatures(feature.value)
except ValueError:
model_schema.features.remove(feature)
return model_schema
@staticmethod
def _filter_mcp_type_tool(
strategy: ResolvedAgentStrategy,
tools: list[dict[str, Any]],
) -> list[dict[str, Any]]:
meta_version = strategy.meta_version
if meta_version and Version(meta_version) > Version("0.0.1"):
return tools
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]

View File

@@ -1,39 +0,0 @@
from __future__ import annotations
from collections.abc import Generator, Sequence
from typing import Any, Protocol
from core.agent.plugin_entities import AgentStrategyParameter
from core.plugin.entities.request import InvokeCredentials
from core.tools.entities.tool_entities import ToolInvokeMessage
class ResolvedAgentStrategy(Protocol):
meta_version: str | None
def get_parameters(self) -> Sequence[AgentStrategyParameter]: ...
def invoke(
self,
*,
params: dict[str, Any],
user_id: str,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
credentials: InvokeCredentials | None = None,
) -> Generator[ToolInvokeMessage, None, None]: ...
class AgentStrategyResolver(Protocol):
def resolve(
self,
*,
tenant_id: str,
agent_strategy_provider_name: str,
agent_strategy_name: str,
) -> ResolvedAgentStrategy: ...
class AgentStrategyPresentationProvider(Protocol):
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None: ...

View File

@@ -1 +0,0 @@
"""Datasource workflow node package."""

View File

@@ -1,5 +0,0 @@
"""Knowledge index workflow node package."""
KNOWLEDGE_INDEX_NODE_TYPE = "knowledge-index"
__all__ = ["KNOWLEDGE_INDEX_NODE_TYPE"]

View File

@@ -1 +0,0 @@
"""Knowledge retrieval workflow node package."""

View File

@@ -0,0 +1,30 @@
"""Node mapping for workflow execution.
`core.workflow` owns the trigger node implementations, while the remaining node
implementations still live under `dify_graph`. This module imports the
core-owned node packages first, then asks the shared `Node` registry to load the
rest of the workflow nodes from `dify_graph`.
"""
import importlib
import pkgutil
from collections.abc import Mapping
from dify_graph.enums import NodeType
from dify_graph.nodes.base.node import Node
LATEST_VERSION = "latest"
def _register_core_workflow_nodes() -> None:
import core.workflow.nodes as workflow_nodes_pkg
for _, modname, _ in pkgutil.walk_packages(workflow_nodes_pkg.__path__, workflow_nodes_pkg.__name__ + "."):
if modname == "core.workflow.nodes.node_mapping":
continue
importlib.import_module(modname)
_register_core_workflow_nodes()
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()

View File

@@ -3,7 +3,6 @@ from typing import Any, Literal, Union
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.entities.entities import EventParameter
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
@@ -14,7 +13,7 @@ from .exc import TriggerEventParameterError
class TriggerEventNodeData(BaseNodeData):
"""Plugin trigger node data"""
type: NodeType = TRIGGER_PLUGIN_NODE_TYPE
type: NodeType = NodeType.TRIGGER_PLUGIN
class TriggerEventInput(BaseModel):
value: Union[Any, list[str]]

View File

@@ -1,10 +1,9 @@
from collections.abc import Mapping
from typing import Any, cast
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType
from dify_graph.graph_events import NodeRunStartedEvent
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
@@ -12,7 +11,7 @@ from .entities import TriggerEventNodeData
class TriggerEventNode(Node[TriggerEventNodeData]):
node_type = TRIGGER_PLUGIN_NODE_TYPE
node_type = NodeType.TRIGGER_PLUGIN
execution_type = NodeExecutionType.ROOT
@classmethod
@@ -34,8 +33,10 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
def version(cls) -> str:
return "1"
def populate_start_event(self, event) -> None:
event.provider_id = self.node_data.provider_id
def customize_start_event(self, event: NodeRunStartedEvent) -> None:
provider_id = self.node_data.provider_id
event.provider_id = provider_id
event.extras["provider_id"] = provider_id
def _run(self) -> NodeRunResult:
"""
@@ -46,8 +47,8 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
"""
# Get trigger data passed when workflow was triggered
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
cast(WorkflowNodeExecutionMetadataKey, TRIGGER_INFO_METADATA_KEY): {
metadata = {
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
"provider_id": self.node_data.provider_id,
"event_name": self.node_data.event_name,
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,

View File

@@ -2,7 +2,6 @@ from typing import Literal, Union
from pydantic import BaseModel, Field
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
@@ -12,7 +11,7 @@ class TriggerScheduleNodeData(BaseNodeData):
Trigger Schedule Node Data
"""
type: NodeType = TRIGGER_SCHEDULE_NODE_TYPE
type: NodeType = NodeType.TRIGGER_SCHEDULE
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")

View File

@@ -1,9 +1,8 @@
from collections.abc import Mapping
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType
from dify_graph.enums import NodeExecutionType, NodeType
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
@@ -11,7 +10,7 @@ from .entities import TriggerScheduleNodeData
class TriggerScheduleNode(Node[TriggerScheduleNodeData]):
node_type = TRIGGER_SCHEDULE_NODE_TYPE
node_type = NodeType.TRIGGER_SCHEDULE
execution_type = NodeExecutionType.ROOT
@classmethod
@@ -21,7 +20,7 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]):
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": TRIGGER_SCHEDULE_NODE_TYPE,
"type": "trigger-schedule",
"config": {
"mode": "visual",
"frequency": "daily",

View File

@@ -3,7 +3,6 @@ from enum import StrEnum
from pydantic import BaseModel, Field, field_validator
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import NodeType
from dify_graph.variables.types import SegmentType
@@ -94,7 +93,7 @@ class WebhookData(BaseNodeData):
class SyncMode(StrEnum):
SYNC = "async" # only support
type: NodeType = TRIGGER_WEBHOOK_NODE_TYPE
type: NodeType = NodeType.TRIGGER_WEBHOOK
method: Method = Method.GET
content_type: ContentType = Field(default=ContentType.JSON)
headers: Sequence[WebhookParameter] = Field(default_factory=list)

View File

@@ -2,10 +2,9 @@ import logging
from collections.abc import Mapping
from typing import Any
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType
from dify_graph.enums import NodeExecutionType, NodeType
from dify_graph.file import FileTransferMethod
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
@@ -20,7 +19,7 @@ logger = logging.getLogger(__name__)
class TriggerWebhookNode(Node[WebhookData]):
node_type = TRIGGER_WEBHOOK_NODE_TYPE
node_type = NodeType.TRIGGER_WEBHOOK
execution_type = NodeExecutionType.ROOT
@classmethod

View File

@@ -8,7 +8,8 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.workflow.node_factory import DifyNodeFactory, resolve_workflow_node_class
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter
@@ -21,7 +22,7 @@ from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLay
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_engine.protocols.command_channel import CommandChannel
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from dify_graph.nodes import BuiltinNodeTypes
from dify_graph.nodes import NodeType
from dify_graph.nodes.base.node import Node
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
@@ -252,7 +253,7 @@ class WorkflowEntry:
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
if node_type != BuiltinNodeTypes.DATASOURCE:
if node_type != NodeType.DATASOURCE:
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
@@ -302,7 +303,7 @@ class WorkflowEntry:
"height": node_height,
"type": "custom",
"data": {
"type": BuiltinNodeTypes.START,
"type": NodeType.START,
"title": "Start",
"desc": "Start",
},
@@ -338,11 +339,11 @@ class WorkflowEntry:
# Create a minimal graph for single node execution
graph_dict = cls._create_single_node_graph(node_id, node_data)
node_type = node_data.get("type", "")
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
node_type = NodeType(node_data.get("type", ""))
if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
raise ValueError(f"Node type {node_type} not supported")
node_cls = resolve_workflow_node_class(node_type=node_type, node_version="1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"]
if not node_cls:
raise ValueError(f"Node class not found for node type {node_type}")

View File

@@ -113,7 +113,7 @@ The codebase enforces strict layering via import-linter:
1. Create node class in `nodes/<node_type>/`
1. Inherit from `BaseNode` or appropriate base class
1. Implement `_run()` method
1. Ensure the node module is importable under `nodes/<node_type>/`
1. Register in `nodes/node_mapping.py`
1. Add tests in `tests/unit_tests/dify_graph/nodes/`
### Implementing a Custom Layer

View File

@@ -1,9 +1,11 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_start_reason import WorkflowStartReason
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"WorkflowExecution",
"WorkflowNodeExecution",

View File

@@ -0,0 +1,8 @@
from pydantic import BaseModel
class AgentNodeStrategyInit(BaseModel):
"""Agent node strategy initialization data."""
name: str
icon: str | None = None

View File

@@ -121,8 +121,6 @@ class DefaultValue(BaseModel):
class BaseNodeData(ABC, BaseModel):
# Raw graph payloads are first validated through `NodeConfigDictAdapter`, where
# `node["data"]` is typed as `BaseNodeData` before the concrete node class is known.
# `type` therefore accepts downstream string node kinds; unknown node implementations
# are rejected later when the node factory resolves the node registry.
# At that boundary, node-specific fields are still "extra" relative to this shared DTO,
# and persisted templates/workflows also carry undeclared compatibility keys such as
# `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive

View File

@@ -48,7 +48,7 @@ class WorkflowNodeExecution(BaseModel):
index: int # Sequence number for ordering in trace visualization
predecessor_node_id: str | None = None # ID of the node that executed before this one
node_id: str # ID of the node being executed
node_type: NodeType # Type of node (e.g., start, llm, downstream response node)
node_type: NodeType # Type of node (e.g., start, llm, knowledge)
title: str # Display title of the node
# Execution data

View File

@@ -1,5 +1,4 @@
from enum import StrEnum
from typing import ClassVar, TypeAlias
class NodeState(StrEnum):
@@ -34,71 +33,56 @@ class SystemVariableKey(StrEnum):
INVOKE_FROM = "invoke_from"
NodeType: TypeAlias = str
class NodeType(StrEnum):
START = "start"
END = "end"
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
KNOWLEDGE_INDEX = "knowledge-index"
IF_ELSE = "if-else"
CODE = "code"
TEMPLATE_TRANSFORM = "template-transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
DATASOURCE = "datasource"
VARIABLE_AGGREGATOR = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
LOOP_START = "loop-start"
LOOP_END = "loop-end"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
AGENT = "agent"
TRIGGER_WEBHOOK = "trigger-webhook"
TRIGGER_SCHEDULE = "trigger-schedule"
TRIGGER_PLUGIN = "trigger-plugin"
HUMAN_INPUT = "human-input"
@property
def is_trigger_node(self) -> bool:
"""Check if this node type is a trigger node."""
return self in [
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
]
class BuiltinNodeTypes:
"""Built-in node type string constants.
`node_type` values are plain strings throughout the graph runtime. This namespace
only exposes the built-in values shipped by `dify_graph`; downstream packages can
use additional strings without extending this class.
"""
START: ClassVar[NodeType] = "start"
END: ClassVar[NodeType] = "end"
ANSWER: ClassVar[NodeType] = "answer"
LLM: ClassVar[NodeType] = "llm"
KNOWLEDGE_RETRIEVAL: ClassVar[NodeType] = "knowledge-retrieval"
IF_ELSE: ClassVar[NodeType] = "if-else"
CODE: ClassVar[NodeType] = "code"
TEMPLATE_TRANSFORM: ClassVar[NodeType] = "template-transform"
QUESTION_CLASSIFIER: ClassVar[NodeType] = "question-classifier"
HTTP_REQUEST: ClassVar[NodeType] = "http-request"
TOOL: ClassVar[NodeType] = "tool"
DATASOURCE: ClassVar[NodeType] = "datasource"
VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-assigner"
LOOP: ClassVar[NodeType] = "loop"
LOOP_START: ClassVar[NodeType] = "loop-start"
LOOP_END: ClassVar[NodeType] = "loop-end"
ITERATION: ClassVar[NodeType] = "iteration"
ITERATION_START: ClassVar[NodeType] = "iteration-start"
PARAMETER_EXTRACTOR: ClassVar[NodeType] = "parameter-extractor"
VARIABLE_ASSIGNER: ClassVar[NodeType] = "assigner"
DOCUMENT_EXTRACTOR: ClassVar[NodeType] = "document-extractor"
LIST_OPERATOR: ClassVar[NodeType] = "list-operator"
AGENT: ClassVar[NodeType] = "agent"
HUMAN_INPUT: ClassVar[NodeType] = "human-input"
BUILT_IN_NODE_TYPES: tuple[NodeType, ...] = (
BuiltinNodeTypes.START,
BuiltinNodeTypes.END,
BuiltinNodeTypes.ANSWER,
BuiltinNodeTypes.LLM,
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL,
BuiltinNodeTypes.IF_ELSE,
BuiltinNodeTypes.CODE,
BuiltinNodeTypes.TEMPLATE_TRANSFORM,
BuiltinNodeTypes.QUESTION_CLASSIFIER,
BuiltinNodeTypes.HTTP_REQUEST,
BuiltinNodeTypes.TOOL,
BuiltinNodeTypes.DATASOURCE,
BuiltinNodeTypes.VARIABLE_AGGREGATOR,
BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR,
BuiltinNodeTypes.LOOP,
BuiltinNodeTypes.LOOP_START,
BuiltinNodeTypes.LOOP_END,
BuiltinNodeTypes.ITERATION,
BuiltinNodeTypes.ITERATION_START,
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
BuiltinNodeTypes.VARIABLE_ASSIGNER,
BuiltinNodeTypes.DOCUMENT_EXTRACTOR,
BuiltinNodeTypes.LIST_OPERATOR,
BuiltinNodeTypes.AGENT,
BuiltinNodeTypes.HUMAN_INPUT,
)
@property
def is_start_node(self) -> bool:
"""Check if this node type can serve as a workflow entry point."""
return self in [
NodeType.START,
NodeType.DATASOURCE,
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
]
class NodeExecutionType(StrEnum):
@@ -252,6 +236,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
CURRENCY = "currency"
TOOL_INFO = "tool_info"
AGENT_LOG = "agent_log"
TRIGGER_INFO = "trigger_info"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
LOOP_ID = "loop_id"

View File

@@ -83,6 +83,50 @@ class Graph:
return node_configs_map
@classmethod
def _find_root_node_id(
cls,
node_configs_map: Mapping[str, NodeConfigDict],
edge_configs: Sequence[Mapping[str, object]],
root_node_id: str | None = None,
) -> str:
"""
Find the root node ID if not specified.
:param node_configs_map: mapping of node ID to node config
:param edge_configs: list of edge configurations
:param root_node_id: explicitly specified root node ID
:return: determined root node ID
"""
if root_node_id:
if root_node_id not in node_configs_map:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
return root_node_id
# Find nodes with no incoming edges
nodes_with_incoming: set[str] = set()
for edge_config in edge_configs:
target = edge_config.get("target")
if isinstance(target, str):
nodes_with_incoming.add(target)
root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]
# Prefer START node if available
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid]["data"]
if node_data.type.is_start_node:
start_node_id = nid
break
root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)
if not root_node_id:
raise ValueError("Unable to determine root node ID")
return root_node_id
@classmethod
def _build_edges(
cls, edge_configs: list[dict[str, object]]
@@ -257,15 +301,15 @@ class Graph:
*,
graph_config: Mapping[str, object],
node_factory: NodeFactory,
root_node_id: str,
root_node_id: str | None = None,
skip_validation: bool = False,
) -> Graph:
"""
Initialize a graph with an explicit execution entry point.
Initialize graph
:param graph_config: graph config containing nodes and edges
:param node_factory: factory for creating node instances from config data
:param root_node_id: active root node id
:param root_node_id: root node id
:return: graph instance
"""
# Parse configs
@@ -283,8 +327,8 @@ class Graph:
# Parse node configurations
node_configs_map = cls._parse_node_configs(node_configs)
if root_node_id not in node_configs_map:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
# Find root node
root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)
# Build edges
edges, in_edges, out_edges = cls._build_edges(edge_configs)

View File

@@ -4,7 +4,7 @@ from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeType
from dify_graph.enums import NodeExecutionType, NodeType
if TYPE_CHECKING:
from .graph import Graph
@@ -71,7 +71,7 @@ class _RootNodeValidator:
"""Validates root node invariants."""
invalid_root_code: str = "INVALID_ROOT"
container_entry_types: tuple[NodeType, ...] = (BuiltinNodeTypes.ITERATION_START, BuiltinNodeTypes.LOOP_START)
container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
root_node = graph.root_node
@@ -86,7 +86,7 @@ class _RootNodeValidator:
)
return issues
node_type = root_node.node_type
node_type = getattr(root_node, "node_type", None)
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
issues.append(
GraphValidationIssue(
@@ -114,9 +114,45 @@ class GraphValidator:
raise GraphValidationError(issues)
@dataclass(frozen=True, slots=True)
class _TriggerStartExclusivityValidator:
"""Ensures trigger nodes do not coexist with UserInput (start) nodes."""
conflict_code: str = "TRIGGER_START_NODE_CONFLICT"
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
start_node_id: str | None = None
trigger_node_ids: list[str] = []
for node in graph.nodes.values():
node_type = getattr(node, "node_type", None)
if not isinstance(node_type, NodeType):
continue
if node_type == NodeType.START:
start_node_id = node.id
elif node_type.is_trigger_node:
trigger_node_ids.append(node.id)
if start_node_id and trigger_node_ids:
trigger_list = ", ".join(trigger_node_ids)
return [
GraphValidationIssue(
code=self.conflict_code,
message=(
f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}."
),
node_id=start_node_id,
)
]
return []
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
_EdgeEndpointValidator(),
_RootNodeValidator(),
_TriggerStartExclusivityValidator(),
)

View File

@@ -6,6 +6,5 @@ of responses based on upstream node outputs and constants.
"""
from .coordinator import ResponseStreamCoordinator
from .session import RESPONSE_SESSION_NODE_TYPES
__all__ = ["RESPONSE_SESSION_NODE_TYPES", "ResponseStreamCoordinator"]
__all__ = ["ResponseStreamCoordinator"]

View File

@@ -3,34 +3,19 @@ Internal response session management for response coordinator.
This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
`RESPONSE_SESSION_NODE_TYPES` is intentionally mutable so downstream applications
can opt additional response-capable node types into session creation without
patching the coordinator.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol, cast
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.answer.answer_node import AnswerNode
from dify_graph.nodes.base.template import Template
from dify_graph.nodes.end.end_node import EndNode
from dify_graph.nodes.knowledge_index import KnowledgeIndexNode
from dify_graph.runtime.graph_runtime_state import NodeProtocol
class _ResponseSessionNodeProtocol(NodeProtocol, Protocol):
"""Structural contract required from nodes that can open a response session."""
def get_streaming_template(self) -> Template: ...
RESPONSE_SESSION_NODE_TYPES: list[NodeType] = [
BuiltinNodeTypes.ANSWER,
BuiltinNodeTypes.END,
]
@dataclass
class ResponseSession:
"""
@@ -48,9 +33,10 @@ class ResponseSession:
"""
Create a ResponseSession from a response-capable node.
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer.
At runtime this must be a node whose `node_type` is listed in `RESPONSE_SESSION_NODE_TYPES`
and which implements `get_streaming_template()`.
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer,
but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides:
- `id: str`
- `get_streaming_template() -> Template`
Args:
node: Node from the materialized workflow graph.
@@ -61,22 +47,11 @@ class ResponseSession:
Raises:
TypeError: If node is not a supported response node type.
"""
if node.node_type not in RESPONSE_SESSION_NODE_TYPES:
supported_node_types = ", ".join(RESPONSE_SESSION_NODE_TYPES)
raise TypeError(
"ResponseSession.from_node only supports node types in "
f"RESPONSE_SESSION_NODE_TYPES: {supported_node_types}"
)
response_node = cast(_ResponseSessionNodeProtocol, node)
try:
template = response_node.get_streaming_template()
except AttributeError as exc:
raise TypeError("ResponseSession.from_node requires get_streaming_template() on response nodes") from exc
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode")
return cls(
node_id=node.id,
template=template,
template=node.get_streaming_template(),
)
def is_complete(self) -> bool:

View File

@@ -4,6 +4,7 @@ from datetime import datetime
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities import AgentNodeStrategyInit
from dify_graph.entities.pause_reason import PauseReason
from .base import GraphNodeEventBase
@@ -12,10 +13,11 @@ from .base import GraphNodeEventBase
class NodeRunStartedEvent(GraphNodeEventBase):
node_title: str
predecessor_node_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
start_at: datetime = Field(..., description="node start time")
extras: dict[str, object] = Field(default_factory=dict)
# FIXME(-LAN-): only for ToolNode
# Legacy provider fields kept for existing start-event consumers.
provider_type: str = ""
provider_id: str = ""

View File

@@ -1,9 +1,9 @@
from collections.abc import Mapping, Sequence
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities.pause_reason import PauseReason
from dify_graph.file import File
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
@@ -13,7 +13,7 @@ from .base import NodeEventBase
class RunRetrieverResourceEvent(NodeEventBase):
retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources")
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
context_files: list[File] | None = Field(default=None, description="context files")

View File

@@ -1,3 +1,3 @@
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import NodeType
__all__ = ["BuiltinNodeTypes"]
__all__ = ["NodeType"]

View File

@@ -0,0 +1,3 @@
from .agent_node import AgentNode
__all__ = ["AgentNode"]

View File

@@ -0,0 +1,761 @@
from __future__ import annotations
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from packaging.version import Version
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.enums import (
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File, FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
AgentLogEvent,
NodeEventBase,
NodeRunResult,
StreamChunkEvent,
StreamCompletedEvent,
)
from dify_graph.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.runtime import VariablePool
from dify_graph.variables.segments import ArrayFileSegment, StringSegment
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exc import (
AgentInputTypeError,
AgentInvocationError,
AgentMessageTransformError,
AgentNodeError,
AgentVariableNotFoundError,
AgentVariableTypeError,
ToolFileNotFoundError,
)
if TYPE_CHECKING:
from core.agent.strategy.plugin import PluginAgentStrategy
from core.plugin.entities.request import InvokeCredentials
class AgentNode(Node[AgentNodeData]):
"""
Agent Node
"""
node_type = NodeType.AGENT
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
dify_ctx = self.require_dify_context()
try:
strategy = get_plugin_agent_strategy(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
agent_strategy_name=self.node_data.agent_strategy_name,
)
except Exception as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error=f"Failed to get agent strategy: {str(e)}",
),
)
return
agent_parameters = strategy.get_parameters()
# get parameters
parameters = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
strategy=strategy,
)
parameters_for_log = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
for_log=True,
strategy=strategy,
)
credentials = self._generate_credentials(parameters=parameters)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
message_stream = strategy.invoke(
params=parameters,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
)
except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(error),
)
)
return
try:
yield from self._transform_message(
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": self.node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
f"Failed to transform agent message: {str(e)}", original_error=e
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(transform_error),
)
)
def _generate_agent_parameters(
self,
*,
agent_parameters: Sequence[AgentStrategyParameter],
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
strategy: PluginAgentStrategy,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
Args:
agent_parameters (Sequence[AgentParameter]): The list of agent parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (AgentNodeData): The data associated with the agent node.
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
agent_input = node_data.agent_parameters[parameter_name]
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
case "mixed" | "constant":
# variable_pool.convert_template expects a string template,
# but if passing a dict, convert to JSON string first before rendering
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
# variable_pool.convert_template returns a string,
# so we need to convert it back to a dictionary
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
case _:
raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
value = self._filter_mcp_type_tool(strategy, value)
for tool in value:
if "schemas" in tool:
tool.pop("schemas")
parameters = tool.get("parameters", {})
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
tool["parameters"] = parameters
if not for_log:
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
parameters = {**parameters, **setting_params}
entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""),
provider_type=provider_type,
tool_name=tool.get("tool_name", ""),
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
)
extra = tool.get("extra", {})
# This is an issue that caused problems before.
# Logically, we shouldn't use the node_data.version field for judgment
# But for backward compatibility with historical data
# this version field judgment is still preserved here.
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
runtime_variable_pool = variable_pool
dify_ctx = self.require_dify_context()
tool_runtime = ToolManager.get_agent_tool_runtime(
dify_ctx.tenant_id,
dify_ctx.app_id,
entity,
dify_ctx.invoke_from,
runtime_variable_pool,
)
if tool_runtime.entity.description:
tool_runtime.entity.description.llm = (
extra.get("description", "") or tool_runtime.entity.description.llm
)
for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = (
ToolParameter.ToolParameterForm.FORM
if tool_runtime_params.name in manual_input_params
else tool_runtime_params.form
)
manual_input_value = {}
if tool_runtime.entity.parameters:
manual_input_value = {
key: value for key, value in parameters.items() if key in manual_input_params
}
runtime_parameters = {
**tool_runtime.runtime.runtime_parameters,
**manual_input_value,
}
tool_value.append(
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"provider_type": provider_type.value,
}
)
value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
model_instance, model_schema = self._fetch_model(value)
# memory config
history_prompt_messages = []
if node_data.memory:
memory = self._fetch_memory(model_instance)
if memory:
prompt_messages = memory.get_history_prompt_messages(
message_limit=node_data.memory.window.size or None
)
history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
]
value["history_prompt_messages"] = history_prompt_messages
if model_schema:
# remove structured output feature to support old version agent plugin
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
value["entity"] = model_schema.model_dump(mode="json")
else:
value["entity"] = None
result[parameter_name] = value
return result
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> InvokeCredentials:
"""
Generate credentials based on the given agent parameters.
"""
from core.plugin.entities.request import InvokeCredentials
credentials = InvokeCredentials()
# generate credentials for tools selector
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
if tool.get("credential_id"):
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
except ValidationError:
continue
return credentials
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AgentNodeData,
) -> Mapping[str, Sequence[str]]:
_ = graph_config # Explicitly mark as unused
result: dict[str, Any] = {}
typed_node_data = node_data
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
match input.type:
case "mixed" | "constant":
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
result = {node_id + "." + key: value for key, value in result.items()}
return result
@property
def agent_strategy_icon(self) -> str | None:
"""
Get agent strategy icon
:return:
"""
from core.plugin.impl.plugin import PluginInstaller
manager = PluginInstaller()
dify_ctx = self.require_dify_context()
plugins = manager.list_plugins(dify_ctx.tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:
icon = None
return icon
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
dify_ctx = self.require_dify_context()
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(
Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
)
conversation = session.scalar(stmt)
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
dify_ctx = self.require_dify_context()
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
)
model_name = value.get("model", "")
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM, model=model_name
)
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager().get_model_instance(
tenant_id=dify_ctx.tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model=model_name,
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_instance, model_schema
def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
if model_schema.features:
for feature in model_schema.features[:]: # Create a copy to safely modify during iteration
try:
AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value
except ValueError:
model_schema.features.remove(feature)
return model_schema
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Filter MCP type tool
:param strategy: plugin agent strategy
:param tool: tool
:return: filtered tool dict
"""
meta_version = strategy.meta_version
if meta_version and Version(meta_version) > Version("0.0.1"):
return tools
else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
node_type: NodeType,
node_id: str,
node_execution_id: str,
) -> Generator[NodeEventBase, None, None]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json_list: list[dict | list] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage = LLMUsage.empty_usage()
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
if isinstance(message.message.json_object, dict):
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
else:
msg_metadata = {}
llm_usage = LLMUsage.empty_usage()
agent_execution_metadata = {}
if message.message.json_object:
json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise AgentVariableTypeError(
"When 'stream' is True, 'variable_value' must be a string.",
variable_name=variable_name,
expected_type="str",
actual_type=type(variable_value).__name__,
)
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, dict)
# Validate that meta contains a 'file' key
if "file" not in message.meta:
raise AgentNodeError("File message is missing 'file' key in meta")
# Validate that the file is an instance of File
if not isinstance(message.meta["file"], File):
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
message_id=message.message.id,
node_execution_id=node_execution_id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=node_id,
)
# check if the agent log is already in the list
for log in agent_logs:
if log.message_id == agent_log.message_id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any] | list[Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.message_id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json_list:
json_output.extend(json_list)
else:
json_output.append({"data": []})
# Send final chunk events for all streamed outputs
# Final chunk for text stream
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk="",
is_final=True,
)
# Final chunks for any streamed variables
for var_name in variables:
yield StreamChunkEvent(
selector=[node_id, var_name],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"text": text,
"usage": jsonable_encoder(llm_usage),
"files": ArrayFileSegment(value=files),
"json": json_output,
**variables,
},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
llm_usage=llm_usage,
)
)

View File

@@ -6,14 +6,14 @@ from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.enums import NodeType
class AgentNodeData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.AGENT
agent_strategy_provider_name: str
type: NodeType = NodeType.AGENT
agent_strategy_provider_name: str # redundancy
agent_strategy_name: str
agent_strategy_label: str
agent_strategy_label: str # redundancy
memory: MemoryConfig | None = None
# The version of the tool parameter.
# If this value is None, it indicates this is a previous version

View File

@@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)
class AgentMaxIterationError(AgentNodeError):
"""Exception raised when the agent exceeds the maximum iteration limit."""
def __init__(self, max_iteration: int):
self.max_iteration = max_iteration
super().__init__(
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
f"The agent was unable to complete the task within the allowed number of iterations."
)

View File

@@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Any
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.answer.entities import AnswerNodeData
from dify_graph.nodes.base.node import Node
@@ -11,7 +11,7 @@ from dify_graph.variables import ArrayFileSegment, FileSegment, Segment
class AnswerNode(Node[AnswerNodeData]):
node_type = BuiltinNodeTypes.ANSWER
node_type = NodeType.ANSWER
execution_type = NodeExecutionType.RESPONSE
@classmethod

View File

@@ -4,7 +4,7 @@ from enum import StrEnum, auto
from pydantic import BaseModel, Field
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.enums import NodeType
class AnswerNodeData(BaseNodeData):
@@ -12,7 +12,7 @@ class AnswerNodeData(BaseNodeData):
Answer Node Data.
"""
type: NodeType = BuiltinNodeTypes.ANSWER
type: NodeType = NodeType.ANSWER
answer: str = Field(..., description="answer template string")

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import importlib
import logging
import operator
import pkgutil
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
@@ -9,7 +11,7 @@ from types import MappingProxyType
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
from uuid import uuid4
from dify_graph.entities import GraphInitParams
from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
@@ -159,7 +161,7 @@ class Node(Generic[NodeDataT]):
Example:
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
node_type = BuiltinNodeTypes.CODE
node_type = NodeType.CODE
# No need to implement _get_title, _get_error_strategy, etc.
"""
super().__init_subclass__(**kwargs)
@@ -177,8 +179,7 @@ class Node(Generic[NodeDataT]):
# Skip base class itself
if cls is Node:
return
# Only register production node implementations defined under the
# canonical workflow namespaces.
# Only register production node implementations defined under dify_graph.nodes.*.
# This prevents test helper subclasses from polluting the global registry and
# accidentally overriding real node types (e.g., a test Answer node).
module_name = getattr(cls, "__module__", "")
@@ -186,7 +187,7 @@ class Node(Generic[NodeDataT]):
node_type = cls.node_type
version = cls.version()
bucket = Node._registry.setdefault(node_type, {})
if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")):
if module_name.startswith("dify_graph.nodes."):
# Production node definitions take precedence and may override
bucket[version] = cls # type: ignore[index]
else:
@@ -202,7 +203,6 @@ class Node(Generic[NodeDataT]):
else:
latest_key = max(version_keys) if version_keys else version
bucket["latest"] = bucket[latest_key]
Node._registry_version += 1
@classmethod
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
@@ -237,11 +237,6 @@ class Node(Generic[NodeDataT]):
# Global registry populated via __init_subclass__
_registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}
_registry_version: ClassVar[int] = 0
@classmethod
def get_registry_version(cls) -> int:
return cls._registry_version
def __init__(
self,
@@ -274,14 +269,14 @@ class Node(Generic[NodeDataT]):
"""Validate shared graph node payloads against the subclass-declared NodeData model."""
return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True))
def init_node_data(self, data: BaseNodeData | Mapping[str, Any]) -> None:
"""Hydrate `_node_data` for legacy callers that bypass `__init__`."""
self._node_data = self.validate_node_data(cast(BaseNodeData, data))
def post_init(self) -> None:
"""Optional hook for subclasses requiring extra initialization."""
return
def customize_start_event(self, event: NodeRunStartedEvent) -> None:
"""Optional hook for subclasses to attach start-event metadata or extras."""
return
@property
def graph_init_params(self) -> GraphInitParams:
return self._graph_init_params
@@ -358,10 +353,6 @@ class Node(Generic[NodeDataT]):
"""
raise NotImplementedError
def populate_start_event(self, event: NodeRunStartedEvent) -> None:
"""Allow subclasses to enrich the started event without cross-node imports in the base class."""
_ = event
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
@@ -375,10 +366,35 @@ class Node(Generic[NodeDataT]):
in_iteration_id=None,
start_at=self._start_at,
)
try:
self.populate_start_event(start_event)
except Exception:
logger.warning("Failed to populate start event for node %s", self._node_id, exc_info=True)
# === FIXME(-LAN-): Needs to refactor.
from dify_graph.nodes.tool.tool_node import ToolNode
if isinstance(self, ToolNode):
start_event.provider_id = getattr(self.node_data, "provider_id", "")
start_event.provider_type = getattr(self.node_data, "provider_type", "")
from dify_graph.nodes.datasource.datasource_node import DatasourceNode
if isinstance(self, DatasourceNode):
plugin_id = getattr(self.node_data, "plugin_id", "")
provider_name = getattr(self.node_data, "provider_name", "")
start_event.provider_id = f"{plugin_id}/{provider_name}"
start_event.provider_type = getattr(self.node_data, "provider_type", "")
from dify_graph.nodes.agent.agent_node import AgentNode
from dify_graph.nodes.agent.entities import AgentNodeData
if isinstance(self, AgentNode):
start_event.agent_strategy = AgentNodeStrategyInit(
name=cast(AgentNodeData, self.node_data).agent_strategy_name,
icon=self.agent_strategy_icon,
)
self.customize_start_event(start_event)
# ===
yield start_event
try:
@@ -497,20 +513,31 @@ class Node(Generic[NodeDataT]):
@abstractmethod
def version(cls) -> str:
"""`node_version` returns the version of current node type."""
# NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so
# registry lookups can resolve numeric versions and `latest`.
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
#
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
# in `api/dify_graph/nodes/__init__.py`.
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@classmethod
def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
"""Return a read-only view of the currently registered node classes.
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
This accessor intentionally performs no imports. The embedding layer that
owns bootstrap (for example `core.workflow.node_factory`) must import any
extension node packages before calling it so their subclasses register via
`__init_subclass__`.
Import all modules under dify_graph.nodes so subclasses register themselves on import.
Higher-level packages may register additional nodes before calling this helper.
Then we return a readonly view of the registry to avoid accidental mutation.
"""
return {node_type: MappingProxyType(version_map) for node_type, version_map in cls._registry.items()}
# Import all node modules to ensure they are loaded (thus registered)
import dify_graph.nodes as _nodes_pkg
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
# Avoid importing modules that depend on the registry to prevent circular imports.
if _modname == "dify_graph.nodes.node_mapping":
continue
importlib.import_module(_modname)
# Return a readonly view so callers can't mutate the registry by accident
return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()}
@property
def retry(self) -> bool:
@@ -785,16 +812,11 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
retriever_resources = [
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
]
return NodeRunRetrieverResourceEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
retriever_resources=retriever_resources,
retriever_resources=event.retriever_resources,
context=event.context,
node_version=self.version(),
)

View File

@@ -4,7 +4,7 @@ from textwrap import dedent
from typing import TYPE_CHECKING, Any, Protocol, cast
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData
@@ -72,7 +72,7 @@ _DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = {
class CodeNode(Node[CodeNodeData]):
node_type = BuiltinNodeTypes.CODE
node_type = NodeType.CODE
_limits: CodeNodeLimits
def __init__(

View File

@@ -4,7 +4,7 @@ from typing import Annotated, Literal
from pydantic import AfterValidator, BaseModel
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.enums import NodeType
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.variables.types import SegmentType
@@ -40,7 +40,7 @@ class CodeNodeData(BaseNodeData):
Code Node Data.
"""
type: NodeType = BuiltinNodeTypes.CODE
type: NodeType = NodeType.CODE
class Output(BaseModel):
type: Annotated[SegmentType, AfterValidator(_validate_type)]

Some files were not shown because too many files have changed in this diff Show More