mirror of
https://github.com/langgenius/dify.git
synced 2026-02-10 10:19:00 +00:00
Compare commits
163 Commits
deploy/tri
...
feat/trigg
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
037cdb3d7d | ||
|
|
7b9d01bfca | ||
|
|
bd1fcd3525 | ||
|
|
0cb0cea167 | ||
|
|
ee68a685a7 | ||
|
|
c78bd492af | ||
|
|
6857bb4406 | ||
|
|
dcf3ee6982 | ||
|
|
76850749e4 | ||
|
|
91e5e33440 | ||
|
|
11e55088c9 | ||
|
|
57c0bc9fb6 | ||
|
|
c3ebb22a4b | ||
|
|
1562d00037 | ||
|
|
e9e843b27d | ||
|
|
ec33b9908e | ||
|
|
67004368d9 | ||
|
|
50bff270b6 | ||
|
|
bd5cf1c272 | ||
|
|
d22404994a | ||
|
|
9898730cc5 | ||
|
|
b0f1e55a87 | ||
|
|
6566824807 | ||
|
|
9249a2af0d | ||
|
|
112fc3b1d1 | ||
|
|
37299b3bd7 | ||
|
|
8f65ce995a | ||
|
|
4a743e6dc1 | ||
|
|
07dda61929 | ||
|
|
0d8438ef40 | ||
|
|
96bb638969 | ||
|
|
e74962272e | ||
|
|
5a15419baf | ||
|
|
e8403977b9 | ||
|
|
add2ca85f2 | ||
|
|
fbb7b02e90 | ||
|
|
249b62c9de | ||
|
|
b433322e8d | ||
|
|
1c8850fc95 | ||
|
|
dc16f1b65a | ||
|
|
ff30395dc1 | ||
|
|
8e600f3302 | ||
|
|
5a1e0a8379 | ||
|
|
2a3ce6baa9 | ||
|
|
01b2f9cff6 | ||
|
|
ac38614171 | ||
|
|
eb95c5cd07 | ||
|
|
a799b54b9e | ||
|
|
98ba0236e6 | ||
|
|
b6c552df07 | ||
|
|
e2827e475d | ||
|
|
58cbd337b5 | ||
|
|
a91e59d544 | ||
|
|
814787677a | ||
|
|
85caa5bd0c | ||
|
|
e04083fc0e | ||
|
|
cf532e5e0d | ||
|
|
c097fc2c48 | ||
|
|
0371d71409 | ||
|
|
81ef7343d4 | ||
|
|
8e4b59c90c | ||
|
|
68f73410fc | ||
|
|
88af8ed374 | ||
|
|
015f82878e | ||
|
|
3874e58dc2 | ||
|
|
9f8c159583 | ||
|
|
d8f6f9ce19 | ||
|
|
eab03e63d4 | ||
|
|
461829274a | ||
|
|
e751c0c535 | ||
|
|
1fffc79c32 | ||
|
|
83fab4bc19 | ||
|
|
f60e28d2f5 | ||
|
|
a62d7aa3ee | ||
|
|
cc84a45244 | ||
|
|
5cf3d24018 | ||
|
|
4bdbe617fe | ||
|
|
33c867fd8c | ||
|
|
2013ceb9d2 | ||
|
|
7120c6414c | ||
|
|
5ce7b2d98d | ||
|
|
cb82198271 | ||
|
|
5e5ffaa416 | ||
|
|
4b253e1f73 | ||
|
|
dd929dbf0e | ||
|
|
97a9d34e96 | ||
|
|
602070ec9c | ||
|
|
afd8989150 | ||
|
|
694197a701 | ||
|
|
2f08306695 | ||
|
|
6acc77d86d | ||
|
|
5ddd5e49ee | ||
|
|
72f9e77368 | ||
|
|
a46c9238fa | ||
|
|
87120ad4ac | ||
|
|
7544b5ec9a | ||
|
|
ff4a62d1e7 | ||
|
|
41daa51988 | ||
|
|
d522350c99 | ||
|
|
1d1bb9451e | ||
|
|
1fce1a61d4 | ||
|
|
883a6caf96 | ||
|
|
a239c39f09 | ||
|
|
e925a8ab99 | ||
|
|
bccaf939e6 | ||
|
|
676648e0b3 | ||
|
|
4ae19e6dde | ||
|
|
4d0ff5c281 | ||
|
|
327b354cc2 | ||
|
|
6d307cc9fc | ||
|
|
adc7134af5 | ||
|
|
10f19cd0c2 | ||
|
|
9ed45594c6 | ||
|
|
c138f4c3a6 | ||
|
|
a35be05790 | ||
|
|
60b5ed8e5d | ||
|
|
d8ddbc4d87 | ||
|
|
19c0fc85e2 | ||
|
|
a58df35ead | ||
|
|
9789bd02d8 | ||
|
|
d94e54923f | ||
|
|
64c7be59b7 | ||
|
|
89ad6ad902 | ||
|
|
4f73bc9693 | ||
|
|
add6b79231 | ||
|
|
c90dad566f | ||
|
|
5cbe6bf8f8 | ||
|
|
4ef6ff217e | ||
|
|
87abfbf515 | ||
|
|
73e65fd838 | ||
|
|
e53edb0fc2 | ||
|
|
17908fbf6b | ||
|
|
3dae108f84 | ||
|
|
5bbf685035 | ||
|
|
a63d1e87b1 | ||
|
|
7129de98cd | ||
|
|
2984dbc0df | ||
|
|
392db7f611 | ||
|
|
5a427b8daa | ||
|
|
18f2e6f166 | ||
|
|
e78903302f | ||
|
|
4084ade86c | ||
|
|
6b0d919dbd | ||
|
|
a7b558b38b | ||
|
|
6aed7e3ff4 | ||
|
|
8e93a8a2e2 | ||
|
|
e38a86e37b | ||
|
|
392e3530bf | ||
|
|
833c902b2b | ||
|
|
6eaea64b3f | ||
|
|
5303b50737 | ||
|
|
6acbcfe679 | ||
|
|
16ef5ebb97 | ||
|
|
acfb95f9c2 | ||
|
|
aacea166d7 | ||
|
|
f7bb3b852a | ||
|
|
d4ff1e031a | ||
|
|
6a3d135d49 | ||
|
|
5c4bf7aabd | ||
|
|
e9c7dc7464 | ||
|
|
74ad21b145 | ||
|
|
f214eeb7b1 | ||
|
|
ae25f90f34 |
@@ -1,5 +1,6 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
npm add -g pnpm@10.15.0
|
||||||
corepack enable
|
corepack enable
|
||||||
cd web && pnpm install
|
cd web && pnpm install
|
||||||
pipx install uv
|
pipx install uv
|
||||||
|
|||||||
2
.github/workflows/autofix.yml
vendored
2
.github/workflows/autofix.yml
vendored
@@ -2,6 +2,8 @@ name: autofix.ci
|
|||||||
on:
|
on:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: ["main"]
|
branches: ["main"]
|
||||||
|
push:
|
||||||
|
branches: ["main"]
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
|
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -218,3 +218,6 @@ mise.toml
|
|||||||
.roo/
|
.roo/
|
||||||
api/.env.backup
|
api/.env.backup
|
||||||
/clickzetta
|
/clickzetta
|
||||||
|
|
||||||
|
# mcp
|
||||||
|
.serena
|
||||||
@@ -59,6 +59,7 @@ pnpm test # Run Jest tests
|
|||||||
- Use type hints for all functions and class attributes
|
- Use type hints for all functions and class attributes
|
||||||
- No `Any` types unless absolutely necessary
|
- No `Any` types unless absolutely necessary
|
||||||
- Implement special methods (`__repr__`, `__str__`) appropriately
|
- Implement special methods (`__repr__`, `__str__`) appropriately
|
||||||
|
- **Logging**: Never use `str(e)` in `logger.exception()` calls. Use `logger.exception("message", exc_info=e)` instead
|
||||||
|
|
||||||
### TypeScript/JavaScript
|
### TypeScript/JavaScript
|
||||||
|
|
||||||
|
|||||||
@@ -434,6 +434,9 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
|
|||||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
|
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
|
||||||
HTTP_REQUEST_NODE_SSL_VERIFY=True
|
HTTP_REQUEST_NODE_SSL_VERIFY=True
|
||||||
|
|
||||||
|
# Webhook request configuration
|
||||||
|
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
|
||||||
|
|
||||||
# Respect X-* headers to redirect clients
|
# Respect X-* headers to redirect clients
|
||||||
RESPECT_XFORWARD_HEADERS_ENABLED=false
|
RESPECT_XFORWARD_HEADERS_ENABLED=false
|
||||||
|
|
||||||
@@ -502,6 +505,12 @@ ENABLE_CLEAN_MESSAGES=false
|
|||||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
|
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
|
||||||
ENABLE_DATASETS_QUEUE_MONITOR=false
|
ENABLE_DATASETS_QUEUE_MONITOR=false
|
||||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
|
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
|
||||||
|
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
|
||||||
|
# Interval time in minutes for polling scheduled workflows(default: 1 min)
|
||||||
|
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
|
||||||
|
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
|
||||||
|
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
|
||||||
|
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
|
||||||
|
|
||||||
# Position configuration
|
# Position configuration
|
||||||
POSITION_TOOL_PINS=
|
POSITION_TOOL_PINS=
|
||||||
|
|||||||
2
api/.vscode/launch.json.example
vendored
2
api/.vscode/launch.json.example
vendored
@@ -54,7 +54,7 @@
|
|||||||
"--loglevel",
|
"--loglevel",
|
||||||
"DEBUG",
|
"DEBUG",
|
||||||
"-Q",
|
"-Q",
|
||||||
"dataset,generation,mail,ops_trace,app_deletion"
|
"dataset,generation,mail,ops_trace,app_deletion,workflow"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1207,6 +1207,55 @@ def setup_system_tool_oauth_client(provider, client_params):
|
|||||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||||
|
|
||||||
|
|
||||||
|
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
|
||||||
|
@click.option("--provider", prompt=True, help="Provider name")
|
||||||
|
@click.option("--client-params", prompt=True, help="Client Params")
|
||||||
|
def setup_system_trigger_oauth_client(provider, client_params):
|
||||||
|
"""
|
||||||
|
Setup system trigger oauth client
|
||||||
|
"""
|
||||||
|
from core.plugin.entities.plugin import TriggerProviderID
|
||||||
|
from models.trigger import TriggerOAuthSystemClient
|
||||||
|
|
||||||
|
provider_id = TriggerProviderID(provider)
|
||||||
|
provider_name = provider_id.provider_name
|
||||||
|
plugin_id = provider_id.plugin_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
# json validate
|
||||||
|
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||||
|
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||||
|
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||||
|
|
||||||
|
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||||
|
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||||
|
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||||
|
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||||
|
except Exception as e:
|
||||||
|
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||||
|
return
|
||||||
|
|
||||||
|
deleted_count = (
|
||||||
|
db.session.query(TriggerOAuthSystemClient)
|
||||||
|
.filter_by(
|
||||||
|
provider=provider_name,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
)
|
||||||
|
.delete()
|
||||||
|
)
|
||||||
|
if deleted_count > 0:
|
||||||
|
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||||
|
|
||||||
|
oauth_client = TriggerOAuthSystemClient(
|
||||||
|
provider=provider_name,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
encrypted_oauth_params=oauth_client_params,
|
||||||
|
)
|
||||||
|
db.session.add(oauth_client)
|
||||||
|
db.session.commit()
|
||||||
|
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||||
|
|
||||||
|
|
||||||
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Find draft variables that reference non-existent apps.
|
Find draft variables that reference non-existent apps.
|
||||||
|
|||||||
@@ -147,6 +147,17 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerConfig(BaseSettings):
|
||||||
|
"""
|
||||||
|
Configuration for trigger
|
||||||
|
"""
|
||||||
|
|
||||||
|
WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
|
||||||
|
description="Maximum allowed size for webhook request bodies in bytes",
|
||||||
|
default=10485760,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PluginConfig(BaseSettings):
|
class PluginConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
Plugin configs
|
Plugin configs
|
||||||
@@ -871,6 +882,22 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
|||||||
description="Enable check upgradable plugin task",
|
description="Enable check upgradable plugin task",
|
||||||
default=True,
|
default=True,
|
||||||
)
|
)
|
||||||
|
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
|
||||||
|
description="Enable workflow schedule poller task",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
|
||||||
|
description="Workflow schedule poller interval in minutes",
|
||||||
|
default=1,
|
||||||
|
)
|
||||||
|
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
|
||||||
|
description="Maximum number of schedules to process in each poll batch",
|
||||||
|
default=100,
|
||||||
|
)
|
||||||
|
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
|
||||||
|
description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PositionConfig(BaseSettings):
|
class PositionConfig(BaseSettings):
|
||||||
@@ -994,6 +1021,7 @@ class FeatureConfig(
|
|||||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||||
BillingConfig,
|
BillingConfig,
|
||||||
CodeExecutionSandboxConfig,
|
CodeExecutionSandboxConfig,
|
||||||
|
TriggerConfig,
|
||||||
PluginConfig,
|
PluginConfig,
|
||||||
MarketplaceConfig,
|
MarketplaceConfig,
|
||||||
DataSetConfig,
|
DataSetConfig,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ if TYPE_CHECKING:
|
|||||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
|
from core.trigger.provider import PluginTriggerProviderController
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
|
||||||
|
|
||||||
@@ -33,3 +34,11 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
|
|||||||
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
|
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
|
||||||
ContextVar("plugin_model_schemas")
|
ContextVar("plugin_model_schemas")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
|
||||||
|
ContextVar("plugin_trigger_providers")
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||||
|
ContextVar("plugin_trigger_providers_lock")
|
||||||
|
)
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ from .app import (
|
|||||||
workflow_draft_variable,
|
workflow_draft_variable,
|
||||||
workflow_run,
|
workflow_run,
|
||||||
workflow_statistic,
|
workflow_statistic,
|
||||||
|
workflow_trigger,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import auth controllers
|
# Import auth controllers
|
||||||
@@ -180,5 +181,6 @@ from .workspace import (
|
|||||||
models,
|
models,
|
||||||
plugin,
|
plugin,
|
||||||
tool_providers,
|
tool_providers,
|
||||||
|
trigger_providers,
|
||||||
workspace,
|
workspace,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from controllers.console.app.error import (
|
|||||||
)
|
)
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||||
|
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||||
from core.llm_generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
@@ -125,13 +126,11 @@ class InstructionGenerateApi(Resource):
|
|||||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
code_template = (
|
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||||
Python3CodeProvider.get_default_code()
|
code_provider: type[CodeNodeProvider] | None = next(
|
||||||
if args["language"] == "python"
|
(p for p in providers if p.is_accept_language(args["language"])), None
|
||||||
else (JavascriptCodeProvider.get_default_code())
|
|
||||||
if args["language"] == "javascript"
|
|
||||||
else ""
|
|
||||||
)
|
)
|
||||||
|
code_template = code_provider.get_default_code() if code_provider else ""
|
||||||
try:
|
try:
|
||||||
# Generate from nothing for a workflow node
|
# Generate from nothing for a workflow node
|
||||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
|||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
from core.helper.trace_id_helper import get_external_trace_id
|
from core.helper.trace_id_helper import get_external_trace_id
|
||||||
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory, variable_factory
|
from factories import file_factory, variable_factory
|
||||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||||
@@ -38,6 +39,7 @@ from models.workflow import Workflow
|
|||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.app import WorkflowHashNotEqualError
|
from services.errors.app import WorkflowHashNotEqualError
|
||||||
from services.errors.llm import InvokeRateLimitError
|
from services.errors.llm import InvokeRateLimitError
|
||||||
|
from services.trigger_debug_service import TriggerDebugService
|
||||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -806,6 +808,132 @@ class DraftWorkflowNodeLastRunApi(Resource):
|
|||||||
return node_exec
|
return node_exec
|
||||||
|
|
||||||
|
|
||||||
|
class DraftWorkflowTriggerNodeApi(Resource):
|
||||||
|
"""
|
||||||
|
Single node debug - Polling API for trigger events
|
||||||
|
Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger
|
||||||
|
"""
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||||
|
def post(self, app_model: App, node_id: str):
|
||||||
|
"""
|
||||||
|
Poll for trigger events and execute single node when event arrives
|
||||||
|
"""
|
||||||
|
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("trigger_name", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("subscription_id", type=str, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
trigger_name = args["trigger_name"]
|
||||||
|
subscription_id = args["subscription_id"]
|
||||||
|
event = TriggerDebugService.poll_event(
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
app_id=app_model.id,
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
node_id=node_id,
|
||||||
|
trigger_name=trigger_name,
|
||||||
|
)
|
||||||
|
if not event:
|
||||||
|
return jsonable_encoder({"status": "waiting"})
|
||||||
|
|
||||||
|
try:
|
||||||
|
workflow_service = WorkflowService()
|
||||||
|
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||||
|
if not draft_workflow:
|
||||||
|
raise ValueError("Workflow not found")
|
||||||
|
|
||||||
|
user_inputs = event.model_dump()
|
||||||
|
node_execution = workflow_service.run_draft_workflow_node(
|
||||||
|
app_model=app_model,
|
||||||
|
draft_workflow=draft_workflow,
|
||||||
|
node_id=node_id,
|
||||||
|
user_inputs=user_inputs,
|
||||||
|
account=current_user,
|
||||||
|
query="",
|
||||||
|
files=[],
|
||||||
|
)
|
||||||
|
return jsonable_encoder(node_execution)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error running draft workflow trigger node")
|
||||||
|
return jsonable_encoder(
|
||||||
|
{
|
||||||
|
"status": "error",
|
||||||
|
}
|
||||||
|
), 500
|
||||||
|
|
||||||
|
|
||||||
|
class DraftWorkflowTriggerRunApi(Resource):
|
||||||
|
"""
|
||||||
|
Full workflow debug - Polling API for trigger events
|
||||||
|
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
|
||||||
|
"""
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||||
|
def post(self, app_model: App):
|
||||||
|
"""
|
||||||
|
Poll for trigger events and execute full workflow when event arrives
|
||||||
|
"""
|
||||||
|
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
|
||||||
|
parser.add_argument("trigger_name", type=str, required=True, location="json", nullable=False)
|
||||||
|
parser.add_argument("subscription_id", type=str, required=True, location="json", nullable=False)
|
||||||
|
args = parser.parse_args()
|
||||||
|
node_id = args["node_id"]
|
||||||
|
trigger_name = args["trigger_name"]
|
||||||
|
subscription_id = args["subscription_id"]
|
||||||
|
|
||||||
|
event = TriggerDebugService.poll_event(
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
app_id=app_model.id,
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
node_id=node_id,
|
||||||
|
trigger_name=trigger_name,
|
||||||
|
)
|
||||||
|
if not event:
|
||||||
|
return jsonable_encoder({"status": "waiting"})
|
||||||
|
|
||||||
|
workflow_args = {
|
||||||
|
"inputs": event.model_dump(),
|
||||||
|
"query": "",
|
||||||
|
"files": [],
|
||||||
|
}
|
||||||
|
external_trace_id = get_external_trace_id(request)
|
||||||
|
if external_trace_id:
|
||||||
|
workflow_args["external_trace_id"] = external_trace_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = AppGenerateService.generate(
|
||||||
|
app_model=app_model,
|
||||||
|
user=current_user,
|
||||||
|
args=workflow_args,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
streaming=True,
|
||||||
|
)
|
||||||
|
return helper.compact_generate_response(response)
|
||||||
|
except InvokeRateLimitError as ex:
|
||||||
|
raise InvokeRateLimitHttpError(ex.description)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error running draft workflow trigger run")
|
||||||
|
return jsonable_encoder(
|
||||||
|
{
|
||||||
|
"status": "error",
|
||||||
|
}
|
||||||
|
), 500
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(
|
api.add_resource(
|
||||||
DraftWorkflowApi,
|
DraftWorkflowApi,
|
||||||
"/apps/<uuid:app_id>/workflows/draft",
|
"/apps/<uuid:app_id>/workflows/draft",
|
||||||
@@ -830,6 +958,14 @@ api.add_resource(
|
|||||||
DraftWorkflowNodeRunApi,
|
DraftWorkflowNodeRunApi,
|
||||||
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
|
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
|
||||||
)
|
)
|
||||||
|
api.add_resource(
|
||||||
|
DraftWorkflowTriggerNodeApi,
|
||||||
|
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
DraftWorkflowTriggerRunApi,
|
||||||
|
"/apps/<uuid:app_id>/workflows/draft/trigger/run",
|
||||||
|
)
|
||||||
api.add_resource(
|
api.add_resource(
|
||||||
AdvancedChatDraftRunIterationNodeApi,
|
AdvancedChatDraftRunIterationNodeApi,
|
||||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
|
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||||
|
|||||||
@@ -27,7 +27,9 @@ class WorkflowAppLogApi(Resource):
|
|||||||
"""
|
"""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("keyword", type=str, location="args")
|
parser.add_argument("keyword", type=str, location="args")
|
||||||
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
parser.add_argument(
|
||||||
|
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||||
)
|
)
|
||||||
|
|||||||
249
api/controllers/console/app/workflow_trigger.py
Normal file
249
api/controllers/console/app/workflow_trigger.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.app.wraps import get_app_model
|
||||||
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
|
||||||
|
from libs.login import current_user, login_required
|
||||||
|
from models.model import Account, AppMode
|
||||||
|
from models.workflow import AppTrigger, AppTriggerStatus, WorkflowWebhookTrigger
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
|
||||||
|
|
||||||
|
|
||||||
|
class PluginTriggerApi(Resource):
|
||||||
|
"""Workflow Plugin Trigger API"""
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=AppMode.WORKFLOW)
|
||||||
|
def post(self, app_model):
|
||||||
|
"""Create plugin trigger"""
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("node_id", type=str, required=False, location="json")
|
||||||
|
parser.add_argument("provider_id", type=str, required=False, location="json")
|
||||||
|
parser.add_argument("trigger_name", type=str, required=False, location="json")
|
||||||
|
parser.add_argument("subscription_id", type=str, required=False, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
plugin_trigger = WorkflowPluginTriggerService.create_plugin_trigger(
|
||||||
|
app_id=app_model.id,
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
node_id=args["node_id"],
|
||||||
|
provider_id=args["provider_id"],
|
||||||
|
trigger_name=args["trigger_name"],
|
||||||
|
subscription_id=args["subscription_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return jsonable_encoder(plugin_trigger)
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=AppMode.WORKFLOW)
|
||||||
|
def get(self, app_model):
|
||||||
|
"""Get plugin trigger"""
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
plugin_trigger = WorkflowPluginTriggerService.get_plugin_trigger(
|
||||||
|
app_id=app_model.id,
|
||||||
|
node_id=args["node_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return jsonable_encoder(plugin_trigger)
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=AppMode.WORKFLOW)
|
||||||
|
def put(self, app_model):
|
||||||
|
"""Update plugin trigger"""
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||||
|
parser.add_argument("subscription_id", type=str, required=True, location="json", help="Subscription ID")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
plugin_trigger = WorkflowPluginTriggerService.update_plugin_trigger(
|
||||||
|
app_id=app_model.id,
|
||||||
|
node_id=args["node_id"],
|
||||||
|
subscription_id=args["subscription_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return jsonable_encoder(plugin_trigger)
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=AppMode.WORKFLOW)
|
||||||
|
def delete(self, app_model):
|
||||||
|
"""Delete plugin trigger"""
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
WorkflowPluginTriggerService.delete_plugin_trigger(
|
||||||
|
app_id=app_model.id,
|
||||||
|
node_id=args["node_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookTriggerApi(Resource):
|
||||||
|
"""Webhook Trigger API"""
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=AppMode.WORKFLOW)
|
||||||
|
@marshal_with(webhook_trigger_fields)
|
||||||
|
def get(self, app_model):
|
||||||
|
"""Get webhook trigger for a node"""
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
node_id = args["node_id"]
|
||||||
|
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
# Get webhook trigger for this app and node
|
||||||
|
webhook_trigger = (
|
||||||
|
session.query(WorkflowWebhookTrigger)
|
||||||
|
.filter(
|
||||||
|
WorkflowWebhookTrigger.app_id == app_model.id,
|
||||||
|
WorkflowWebhookTrigger.node_id == node_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not webhook_trigger:
|
||||||
|
raise NotFound("Webhook trigger not found for this node")
|
||||||
|
|
||||||
|
# Add computed fields for marshal_with
|
||||||
|
base_url = dify_config.SERVICE_API_URL
|
||||||
|
webhook_trigger.webhook_url = f"{base_url}/triggers/webhook/{webhook_trigger.webhook_id}" # type: ignore
|
||||||
|
webhook_trigger.webhook_debug_url = f"{base_url}/triggers/webhook-debug/{webhook_trigger.webhook_id}" # type: ignore
|
||||||
|
|
||||||
|
return webhook_trigger
|
||||||
|
|
||||||
|
|
||||||
|
class AppTriggersApi(Resource):
|
||||||
|
"""App Triggers list API"""
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=AppMode.WORKFLOW)
|
||||||
|
@marshal_with(triggers_list_fields)
|
||||||
|
def get(self, app_model):
|
||||||
|
"""Get app triggers list"""
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
|
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
# Get all triggers for this app using select API
|
||||||
|
triggers = (
|
||||||
|
session.execute(
|
||||||
|
select(AppTrigger)
|
||||||
|
.where(
|
||||||
|
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||||
|
AppTrigger.app_id == app_model.id,
|
||||||
|
)
|
||||||
|
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add computed icon field for each trigger
|
||||||
|
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||||
|
for trigger in triggers:
|
||||||
|
if trigger.trigger_type == "trigger-plugin":
|
||||||
|
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||||
|
else:
|
||||||
|
trigger.icon = "" # type: ignore
|
||||||
|
|
||||||
|
return {"data": triggers}
|
||||||
|
|
||||||
|
|
||||||
|
class AppTriggerEnableApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=AppMode.WORKFLOW)
|
||||||
|
@marshal_with(trigger_fields)
|
||||||
|
def post(self, app_model):
|
||||||
|
"""Update app trigger (enable/disable)"""
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
|
||||||
|
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
|
assert current_user.current_tenant_id is not None
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
trigger_id = args["trigger_id"]
|
||||||
|
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
# Find the trigger using select
|
||||||
|
trigger = session.execute(
|
||||||
|
select(AppTrigger).where(
|
||||||
|
AppTrigger.id == trigger_id,
|
||||||
|
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||||
|
AppTrigger.app_id == app_model.id,
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if not trigger:
|
||||||
|
raise NotFound("Trigger not found")
|
||||||
|
|
||||||
|
# Update status based on enable_trigger boolean
|
||||||
|
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
session.refresh(trigger)
|
||||||
|
|
||||||
|
# Add computed icon field
|
||||||
|
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||||
|
if trigger.trigger_type == "trigger-plugin":
|
||||||
|
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||||
|
else:
|
||||||
|
trigger.icon = "" # type: ignore
|
||||||
|
|
||||||
|
return trigger
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
|
||||||
|
api.add_resource(PluginTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/plugin")
|
||||||
|
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
|
||||||
|
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")
|
||||||
@@ -516,18 +516,20 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
|||||||
parser.add_argument("provider", type=str, required=True, location="args")
|
parser.add_argument("provider", type=str, required=True, location="args")
|
||||||
parser.add_argument("action", type=str, required=True, location="args")
|
parser.add_argument("action", type=str, required=True, location="args")
|
||||||
parser.add_argument("parameter", type=str, required=True, location="args")
|
parser.add_argument("parameter", type=str, required=True, location="args")
|
||||||
|
parser.add_argument("credential_id", type=str, required=False, location="args")
|
||||||
parser.add_argument("provider_type", type=str, required=True, location="args")
|
parser.add_argument("provider_type", type=str, required=True, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
options = PluginParameterService.get_dynamic_select_options(
|
options = PluginParameterService.get_dynamic_select_options(
|
||||||
tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id,
|
user_id=user_id,
|
||||||
args["plugin_id"],
|
plugin_id=args["plugin_id"],
|
||||||
args["provider"],
|
provider=args["provider"],
|
||||||
args["action"],
|
action=args["action"],
|
||||||
args["parameter"],
|
parameter=args["parameter"],
|
||||||
args["provider_type"],
|
credential_id=args["credential_id"],
|
||||||
|
provider_type=args["provider_type"],
|
||||||
)
|
)
|
||||||
except PluginDaemonClientSideError as e:
|
except PluginDaemonClientSideError as e:
|
||||||
raise ValueError(e)
|
raise ValueError(e)
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ from core.mcp.error import MCPAuthError, MCPError
|
|||||||
from core.mcp.mcp_client import MCPClient
|
from core.mcp.mcp_client import MCPClient
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.entities.plugin import ToolProviderID
|
from core.plugin.entities.plugin import ToolProviderID
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
from core.plugin.impl.oauth import OAuthHandler
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
from core.tools.entities.tool_entities import CredentialType
|
|
||||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.plugin.oauth_service import OAuthProxyService
|
from services.plugin.oauth_service import OAuthProxyService
|
||||||
|
|||||||
589
api/controllers/console/workspace/trigger_providers.py
Normal file
589
api/controllers/console/workspace/trigger_providers.py
Normal file
@@ -0,0 +1,589 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from flask import make_response, redirect, request
|
||||||
|
from flask_restx import Resource, reqparse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from werkzeug.exceptions import BadRequest, Forbidden
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from core.plugin.entities.plugin import TriggerProviderID
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
|
from core.trigger.entities.entities import SubscriptionBuilderUpdater
|
||||||
|
from core.trigger.trigger_manager import TriggerManager
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.login import current_user, login_required
|
||||||
|
from models.account import Account
|
||||||
|
from services.plugin.oauth_service import OAuthProxyService
|
||||||
|
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||||
|
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||||
|
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerProviderListApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self):
|
||||||
|
"""List all trigger providers for the current tenant"""
|
||||||
|
user = current_user
|
||||||
|
assert isinstance(user, Account)
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerProviderInfoApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self, provider):
|
||||||
|
"""Get info for a trigger provider"""
|
||||||
|
user = current_user
|
||||||
|
assert isinstance(user, Account)
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
return jsonable_encoder(
|
||||||
|
TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerSubscriptionListApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self, provider):
|
||||||
|
"""List all trigger subscriptions for the current tenant's provider"""
|
||||||
|
user = current_user
|
||||||
|
assert isinstance(user, Account)
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return jsonable_encoder(
|
||||||
|
TriggerProviderService.list_trigger_provider_subscriptions(
|
||||||
|
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error listing trigger providers", exc_info=e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, provider):
|
||||||
|
"""Add a new subscription instance for a trigger provider"""
|
||||||
|
user = current_user
|
||||||
|
assert isinstance(user, Account)
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
|
||||||
|
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
provider_id=TriggerProviderID(provider),
|
||||||
|
credential_type=credential_type,
|
||||||
|
)
|
||||||
|
return jsonable_encoder({"subscription_builder": subscription_builder})
|
||||||
|
except ValueError as e:
|
||||||
|
raise BadRequest(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error adding provider credential", exc_info=e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerSubscriptionBuilderGetApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self, provider, subscription_builder_id):
|
||||||
|
"""Get a subscription instance for a trigger provider"""
|
||||||
|
return jsonable_encoder(
|
||||||
|
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, provider, subscription_builder_id):
|
||||||
|
"""Verify a subscription instance for a trigger provider"""
|
||||||
|
user = current_user
|
||||||
|
assert isinstance(user, Account)
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
# The credentials of the subscription builder
|
||||||
|
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
provider_id=TriggerProviderID(provider),
|
||||||
|
subscription_builder_id=subscription_builder_id,
|
||||||
|
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||||
|
credentials=args.get("credentials", None),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return TriggerSubscriptionBuilderService.verify_trigger_subscription_builder(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
provider_id=TriggerProviderID(provider),
|
||||||
|
subscription_builder_id=subscription_builder_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error verifying provider credential", exc_info=e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, provider, subscription_builder_id):
|
||||||
|
"""Update a subscription instance for a trigger provider"""
|
||||||
|
user = current_user
|
||||||
|
assert isinstance(user, Account)
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
# The name of the subscription builder
|
||||||
|
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||||
|
# The parameters of the subscription builder
|
||||||
|
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||||
|
# The properties of the subscription builder
|
||||||
|
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||||
|
# The credentials of the subscription builder
|
||||||
|
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
try:
|
||||||
|
return jsonable_encoder(
|
||||||
|
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
provider_id=TriggerProviderID(provider),
|
||||||
|
subscription_builder_id=subscription_builder_id,
|
||||||
|
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||||
|
name=args.get("name", None),
|
||||||
|
parameters=args.get("parameters", None),
|
||||||
|
properties=args.get("properties", None),
|
||||||
|
credentials=args.get("credentials", None),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error updating provider credential", exc_info=e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self, provider, subscription_builder_id):
|
||||||
|
"""Get the request logs for a subscription instance for a trigger provider"""
|
||||||
|
user = current_user
|
||||||
|
assert isinstance(user, Account)
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
|
||||||
|
try:
|
||||||
|
logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
|
||||||
|
return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error getting request logs for subscription builder", exc_info=e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, provider, subscription_builder_id):
|
||||||
|
"""Build a subscription instance for a trigger provider"""
|
||||||
|
user = current_user
|
||||||
|
assert isinstance(user, Account)
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
# The name of the subscription builder
|
||||||
|
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||||
|
# The parameters of the subscription builder
|
||||||
|
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||||
|
# The properties of the subscription builder
|
||||||
|
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||||
|
# The credentials of the subscription builder
|
||||||
|
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
try:
|
||||||
|
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
provider_id=TriggerProviderID(provider),
|
||||||
|
subscription_builder_id=subscription_builder_id,
|
||||||
|
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||||
|
name=args.get("name", None),
|
||||||
|
parameters=args.get("parameters", None),
|
||||||
|
properties=args.get("properties", None),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
TriggerSubscriptionBuilderService.build_trigger_subscription_builder(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
provider_id=TriggerProviderID(provider),
|
||||||
|
subscription_builder_id=subscription_builder_id,
|
||||||
|
)
|
||||||
|
return 200
|
||||||
|
except ValueError as e:
|
||||||
|
raise BadRequest(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error building provider credential", exc_info=e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerSubscriptionDeleteApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, subscription_id):
|
||||||
|
"""Delete a subscription instance"""
|
||||||
|
user = current_user
|
||||||
|
assert isinstance(user, Account)
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
# Delete trigger provider subscription
|
||||||
|
TriggerProviderService.delete_trigger_provider(
|
||||||
|
session=session,
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
)
|
||||||
|
# Delete plugin triggers
|
||||||
|
WorkflowPluginTriggerService.delete_plugin_trigger_by_subscription(
|
||||||
|
session=session,
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
return {"result": "success"}
|
||||||
|
except ValueError as e:
|
||||||
|
raise BadRequest(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error deleting provider credential", exc_info=e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerOAuthAuthorizeApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self, provider):
|
||||||
|
"""Initiate OAuth authorization flow for a trigger provider"""
|
||||||
|
user = current_user
|
||||||
|
assert isinstance(user, Account)
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider_id = TriggerProviderID(provider)
|
||||||
|
plugin_id = provider_id.plugin_id
|
||||||
|
provider_name = provider_id.provider_name
|
||||||
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
|
# Get OAuth client configuration
|
||||||
|
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if oauth_client_params is None:
|
||||||
|
raise Forbidden("No OAuth client configuration found for this trigger provider")
|
||||||
|
|
||||||
|
# Create subscription builder
|
||||||
|
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
credential_type=CredentialType.OAUTH2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create OAuth handler and proxy context
|
||||||
|
oauth_handler = OAuthHandler()
|
||||||
|
context_id = OAuthProxyService.create_proxy_context(
|
||||||
|
user_id=user.id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
provider=provider_name,
|
||||||
|
extra_data={
|
||||||
|
"subscription_builder_id": subscription_builder.id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build redirect URI for callback
|
||||||
|
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||||
|
|
||||||
|
# Get authorization URL
|
||||||
|
authorization_url_response = oauth_handler.get_authorization_url(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
provider=provider_name,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
system_credentials=oauth_client_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create response with cookie
|
||||||
|
response = make_response(
|
||||||
|
jsonable_encoder(
|
||||||
|
{
|
||||||
|
"authorization_url": authorization_url_response.authorization_url,
|
||||||
|
"subscription_builder_id": subscription_builder.id,
|
||||||
|
"subscription_builder": subscription_builder,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
response.set_cookie(
|
||||||
|
"context_id",
|
||||||
|
context_id,
|
||||||
|
httponly=True,
|
||||||
|
samesite="Lax",
|
||||||
|
max_age=OAuthProxyService.__MAX_AGE__,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error initiating OAuth flow", exc_info=e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerOAuthCallbackApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
def get(self, provider):
|
||||||
|
"""Handle OAuth callback for trigger provider"""
|
||||||
|
context_id = request.cookies.get("context_id")
|
||||||
|
if not context_id:
|
||||||
|
raise Forbidden("context_id not found")
|
||||||
|
|
||||||
|
# Use and validate proxy context
|
||||||
|
context = OAuthProxyService.use_proxy_context(context_id)
|
||||||
|
if context is None:
|
||||||
|
raise Forbidden("Invalid context_id")
|
||||||
|
|
||||||
|
# Parse provider ID
|
||||||
|
provider_id = TriggerProviderID(provider)
|
||||||
|
plugin_id = provider_id.plugin_id
|
||||||
|
provider_name = provider_id.provider_name
|
||||||
|
user_id = context.get("user_id")
|
||||||
|
tenant_id = context.get("tenant_id")
|
||||||
|
subscription_builder_id = context.get("subscription_builder_id")
|
||||||
|
|
||||||
|
# Get OAuth client configuration
|
||||||
|
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if oauth_client_params is None:
|
||||||
|
raise Forbidden("No OAuth client configuration found for this trigger provider")
|
||||||
|
|
||||||
|
# Get OAuth credentials from callback
|
||||||
|
oauth_handler = OAuthHandler()
|
||||||
|
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||||
|
|
||||||
|
credentials_response = oauth_handler.get_credentials(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
provider=provider_name,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
system_credentials=oauth_client_params,
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
|
||||||
|
credentials = credentials_response.credentials
|
||||||
|
expires_at = credentials_response.expires_at
|
||||||
|
|
||||||
|
if not credentials:
|
||||||
|
raise Exception("Failed to get OAuth credentials")
|
||||||
|
|
||||||
|
# Update subscription builder
|
||||||
|
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
subscription_builder_id=subscription_builder_id,
|
||||||
|
subscription_builder_updater=SubscriptionBuilderUpdater(
|
||||||
|
credentials=credentials,
|
||||||
|
credential_expires_at=expires_at,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# Redirect to OAuth callback page
|
||||||
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerOAuthClientManageApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self, provider):
|
||||||
|
"""Get OAuth client configuration for a provider"""
|
||||||
|
user = current_user
|
||||||
|
assert isinstance(user, Account)
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider_id = TriggerProviderID(provider)
|
||||||
|
|
||||||
|
# Get custom OAuth client params if exists
|
||||||
|
custom_params = TriggerProviderService.get_custom_oauth_client_params(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if custom client is enabled
|
||||||
|
is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if there's a system OAuth client
|
||||||
|
system_client = TriggerProviderService.get_oauth_client(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
)
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id)
|
||||||
|
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||||
|
return jsonable_encoder(
|
||||||
|
{
|
||||||
|
"configured": bool(custom_params or system_client),
|
||||||
|
"oauth_client_schema": provider_controller.get_oauth_client_schema(),
|
||||||
|
"custom_configured": bool(custom_params),
|
||||||
|
"custom_enabled": is_custom_enabled,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"params": custom_params if custom_params else {},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error getting OAuth client", exc_info=e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self, provider):
|
||||||
|
"""Configure custom OAuth client for a provider"""
|
||||||
|
user = current_user
|
||||||
|
assert isinstance(user, Account)
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||||
|
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider_id = TriggerProviderID(provider)
|
||||||
|
return TriggerProviderService.save_custom_oauth_client_params(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
client_params=args.get("client_params"),
|
||||||
|
enabled=args.get("enabled"),
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise BadRequest(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error configuring OAuth client", exc_info=e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def delete(self, provider):
|
||||||
|
"""Remove custom OAuth client configuration"""
|
||||||
|
user = current_user
|
||||||
|
assert isinstance(user, Account)
|
||||||
|
assert user.current_tenant_id is not None
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider_id = TriggerProviderID(provider)
|
||||||
|
|
||||||
|
return TriggerProviderService.delete_custom_oauth_client_params(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
raise BadRequest(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error removing OAuth client", exc_info=e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Trigger Subscription
|
||||||
|
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
|
||||||
|
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
|
||||||
|
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
|
||||||
|
api.add_resource(
|
||||||
|
TriggerSubscriptionDeleteApi,
|
||||||
|
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Trigger Subscription Builder
|
||||||
|
api.add_resource(
|
||||||
|
TriggerSubscriptionBuilderCreateApi,
|
||||||
|
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
TriggerSubscriptionBuilderGetApi,
|
||||||
|
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
TriggerSubscriptionBuilderUpdateApi,
|
||||||
|
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
TriggerSubscriptionBuilderVerifyApi,
|
||||||
|
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
TriggerSubscriptionBuilderBuildApi,
|
||||||
|
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
TriggerSubscriptionBuilderLogsApi,
|
||||||
|
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# OAuth
|
||||||
|
api.add_resource(
|
||||||
|
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
|
||||||
|
)
|
||||||
|
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
|
||||||
|
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")
|
||||||
@@ -9,10 +9,9 @@ from controllers.console.app.mcp_server import AppMCPServerStatus
|
|||||||
from controllers.mcp import mcp_ns
|
from controllers.mcp import mcp_ns
|
||||||
from core.app.app_config.entities import VariableEntity
|
from core.app.app_config.entities import VariableEntity
|
||||||
from core.mcp import types as mcp_types
|
from core.mcp import types as mcp_types
|
||||||
from core.mcp.server.streamable_http import handle_mcp_request
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
from models.model import App, AppMCPServer, AppMode
|
||||||
|
|
||||||
|
|
||||||
class MCPRequestError(Exception):
|
class MCPRequestError(Exception):
|
||||||
@@ -195,50 +194,6 @@ class MCPAppApi(Resource):
|
|||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||||
|
|
||||||
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str, session: Session) -> EndUser | None:
|
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
|
||||||
"""Get end user from existing session - optimized query"""
|
response = mcp_server_handler.handle()
|
||||||
return (
|
return helper.compact_generate_response(response)
|
||||||
session.query(EndUser)
|
|
||||||
.where(EndUser.tenant_id == tenant_id)
|
|
||||||
.where(EndUser.session_id == mcp_server_id)
|
|
||||||
.where(EndUser.type == "mcp")
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_end_user(
|
|
||||||
self, client_name: str, tenant_id: str, app_id: str, mcp_server_id: str, session: Session
|
|
||||||
) -> EndUser:
|
|
||||||
"""Create end user in existing session"""
|
|
||||||
end_user = EndUser(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
app_id=app_id,
|
|
||||||
type="mcp",
|
|
||||||
name=client_name,
|
|
||||||
session_id=mcp_server_id,
|
|
||||||
)
|
|
||||||
session.add(end_user)
|
|
||||||
session.flush() # Use flush instead of commit to keep transaction open
|
|
||||||
session.refresh(end_user)
|
|
||||||
return end_user
|
|
||||||
|
|
||||||
def _handle_mcp_request(
|
|
||||||
self,
|
|
||||||
app: App,
|
|
||||||
mcp_server: AppMCPServer,
|
|
||||||
mcp_request: mcp_types.ClientRequest,
|
|
||||||
user_input_form: list[VariableEntity],
|
|
||||||
session: Session,
|
|
||||||
request_id: Union[int, str],
|
|
||||||
) -> mcp_types.JSONRPCResponse | mcp_types.JSONRPCError | None:
|
|
||||||
"""Handle MCP request and return response"""
|
|
||||||
end_user = self._retrieve_end_user(mcp_server.tenant_id, mcp_server.id, session)
|
|
||||||
|
|
||||||
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
|
|
||||||
client_info = mcp_request.root.params.clientInfo
|
|
||||||
client_name = f"{client_info.name}@{client_info.version}"
|
|
||||||
# Commit the session before creating end user to avoid transaction conflicts
|
|
||||||
session.commit()
|
|
||||||
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
|
|
||||||
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
|
|
||||||
|
|
||||||
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)
|
|
||||||
|
|||||||
7
api/controllers/trigger/__init__.py
Normal file
7
api/controllers/trigger/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from flask import Blueprint
|
||||||
|
|
||||||
|
# Create trigger blueprint
|
||||||
|
bp = Blueprint("trigger", __name__, url_prefix="/triggers")
|
||||||
|
|
||||||
|
# Import routes after blueprint creation to avoid circular imports
|
||||||
|
from . import trigger, webhook
|
||||||
41
api/controllers/trigger/trigger.py
Normal file
41
api/controllers/trigger/trigger.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from flask import jsonify, request
|
||||||
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.trigger import bp
|
||||||
|
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||||
|
from services.trigger_service import TriggerService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
|
||||||
|
UUID_MATCHER = re.compile(UUID_PATTERN)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/plugin/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||||
|
def trigger_endpoint(endpoint_id: str):
|
||||||
|
"""
|
||||||
|
Handle endpoint trigger calls.
|
||||||
|
"""
|
||||||
|
# endpoint_id must be UUID
|
||||||
|
if not UUID_MATCHER.match(endpoint_id):
|
||||||
|
raise NotFound("Invalid endpoint ID")
|
||||||
|
handling_chain = [
|
||||||
|
TriggerService.process_endpoint,
|
||||||
|
TriggerSubscriptionBuilderService.process_builder_validation_endpoint,
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
for handler in handling_chain:
|
||||||
|
response = handler(endpoint_id, request)
|
||||||
|
if response:
|
||||||
|
break
|
||||||
|
if not response:
|
||||||
|
raise NotFound("Endpoint not found")
|
||||||
|
return response
|
||||||
|
except ValueError as e:
|
||||||
|
raise NotFound(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Webhook processing failed for {endpoint_id}")
|
||||||
|
return jsonify({"error": "Internal server error", "message": str(e)}), 500
|
||||||
46
api/controllers/trigger/webhook.py
Normal file
46
api/controllers/trigger/webhook.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from flask import jsonify
|
||||||
|
from werkzeug.exceptions import NotFound, RequestEntityTooLarge
|
||||||
|
|
||||||
|
from controllers.trigger import bp
|
||||||
|
from services.webhook_service import WebhookService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/webhook/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||||
|
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||||
|
def handle_webhook(webhook_id: str):
|
||||||
|
"""
|
||||||
|
Handle webhook trigger calls.
|
||||||
|
|
||||||
|
This endpoint receives webhook calls and processes them according to the
|
||||||
|
configured webhook trigger settings.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get webhook trigger, workflow, and node configuration
|
||||||
|
webhook_trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(webhook_id)
|
||||||
|
|
||||||
|
# Extract request data
|
||||||
|
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||||
|
|
||||||
|
# Validate request against node configuration
|
||||||
|
validation_result = WebhookService.validate_webhook_request(webhook_data, node_config)
|
||||||
|
if not validation_result["valid"]:
|
||||||
|
return jsonify({"error": "Bad Request", "message": validation_result["error"]}), 400
|
||||||
|
|
||||||
|
# Process webhook call (send to Celery)
|
||||||
|
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||||
|
|
||||||
|
# Return configured response
|
||||||
|
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||||
|
return jsonify(response_data), status_code
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise NotFound(str(e))
|
||||||
|
except RequestEntityTooLarge:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Webhook processing failed for %s", webhook_id)
|
||||||
|
return jsonify({"error": "Internal server error", "message": str(e)}), 500
|
||||||
@@ -54,6 +54,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
streaming: Literal[True],
|
streaming: Literal[True],
|
||||||
call_depth: int,
|
call_depth: int,
|
||||||
workflow_thread_pool_id: Optional[str],
|
workflow_thread_pool_id: Optional[str],
|
||||||
|
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||||
|
root_node_id: Optional[str] = None,
|
||||||
) -> Generator[Mapping | str, None, None]: ...
|
) -> Generator[Mapping | str, None, None]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@@ -68,6 +70,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
streaming: Literal[False],
|
streaming: Literal[False],
|
||||||
call_depth: int,
|
call_depth: int,
|
||||||
workflow_thread_pool_id: Optional[str],
|
workflow_thread_pool_id: Optional[str],
|
||||||
|
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||||
|
root_node_id: Optional[str] = None,
|
||||||
) -> Mapping[str, Any]: ...
|
) -> Mapping[str, Any]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@@ -82,6 +86,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
streaming: bool,
|
streaming: bool,
|
||||||
call_depth: int,
|
call_depth: int,
|
||||||
workflow_thread_pool_id: Optional[str],
|
workflow_thread_pool_id: Optional[str],
|
||||||
|
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||||
|
root_node_id: Optional[str] = None,
|
||||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
@@ -95,6 +101,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
call_depth: int = 0,
|
call_depth: int = 0,
|
||||||
workflow_thread_pool_id: Optional[str] = None,
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
|
triggered_from: Optional[WorkflowRunTriggeredFrom] = None,
|
||||||
|
root_node_id: Optional[str] = None,
|
||||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
|
||||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||||
|
|
||||||
@@ -130,17 +138,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
**extract_external_trace_id_from_args(args),
|
**extract_external_trace_id_from_args(args),
|
||||||
}
|
}
|
||||||
workflow_run_id = str(uuid.uuid4())
|
workflow_run_id = str(uuid.uuid4())
|
||||||
|
if triggered_from in (WorkflowRunTriggeredFrom.DEBUGGING, WorkflowRunTriggeredFrom.APP_RUN):
|
||||||
|
# start node get inputs
|
||||||
|
inputs = self._prepare_user_inputs(
|
||||||
|
user_inputs=inputs,
|
||||||
|
variables=app_config.variables,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||||
|
)
|
||||||
# init application generate entity
|
# init application generate entity
|
||||||
application_generate_entity = WorkflowAppGenerateEntity(
|
application_generate_entity = WorkflowAppGenerateEntity(
|
||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
file_upload_config=file_extra_config,
|
file_upload_config=file_extra_config,
|
||||||
inputs=self._prepare_user_inputs(
|
inputs=inputs,
|
||||||
user_inputs=inputs,
|
|
||||||
variables=app_config.variables,
|
|
||||||
tenant_id=app_model.tenant_id,
|
|
||||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
|
||||||
),
|
|
||||||
files=list(system_files),
|
files=list(system_files),
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
stream=streaming,
|
stream=streaming,
|
||||||
@@ -159,7 +170,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
# Create session factory
|
# Create session factory
|
||||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
# Create workflow execution(aka workflow run) repository
|
# Create workflow execution(aka workflow run) repository
|
||||||
if invoke_from == InvokeFrom.DEBUGGER:
|
if triggered_from is not None:
|
||||||
|
# Use explicitly provided triggered_from (for async triggers)
|
||||||
|
workflow_triggered_from = triggered_from
|
||||||
|
elif invoke_from == InvokeFrom.DEBUGGER:
|
||||||
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
|
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
else:
|
else:
|
||||||
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||||
@@ -187,6 +201,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||||
|
root_node_id=root_node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
@@ -202,6 +217,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
workflow_thread_pool_id: Optional[str] = None,
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||||
|
root_node_id: Optional[str] = None,
|
||||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
@@ -239,6 +255,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
"context": context,
|
"context": context,
|
||||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||||
"variable_loader": variable_loader,
|
"variable_loader": variable_loader,
|
||||||
|
"root_node_id": root_node_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -435,6 +452,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
context: contextvars.Context,
|
context: contextvars.Context,
|
||||||
variable_loader: VariableLoader,
|
variable_loader: VariableLoader,
|
||||||
workflow_thread_pool_id: Optional[str] = None,
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
|
root_node_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Generate worker in a new thread.
|
Generate worker in a new thread.
|
||||||
@@ -478,6 +496,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
variable_loader=variable_loader,
|
variable_loader=variable_loader,
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
system_user_id=system_user_id,
|
system_user_id=system_user_id,
|
||||||
|
root_node_id=root_node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||||||
workflow_thread_pool_id: Optional[str] = None,
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
system_user_id: str,
|
system_user_id: str,
|
||||||
|
root_node_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
@@ -44,6 +45,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||||
self._workflow = workflow
|
self._workflow = workflow
|
||||||
self._sys_user_id = system_user_id
|
self._sys_user_id = system_user_id
|
||||||
|
self._root_node_id = root_node_id
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -93,7 +95,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# init graph
|
# init graph
|
||||||
graph = self._init_graph(graph_config=self._workflow.graph_dict)
|
graph = self._init_graph(graph_config=self._workflow.graph_dict, root_node_id=self._root_node_id)
|
||||||
|
|
||||||
# RUN WORKFLOW
|
# RUN WORKFLOW
|
||||||
workflow_entry = WorkflowEntry(
|
workflow_entry = WorkflowEntry(
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
@@ -79,7 +79,7 @@ class WorkflowBasedAppRunner:
|
|||||||
self._variable_loader = variable_loader
|
self._variable_loader = variable_loader
|
||||||
self._app_id = app_id
|
self._app_id = app_id
|
||||||
|
|
||||||
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
|
def _init_graph(self, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> Graph:
|
||||||
"""
|
"""
|
||||||
Init graph
|
Init graph
|
||||||
"""
|
"""
|
||||||
@@ -92,7 +92,7 @@ class WorkflowBasedAppRunner:
|
|||||||
if not isinstance(graph_config.get("edges"), list):
|
if not isinstance(graph_config.get("edges"), list):
|
||||||
raise ValueError("edges in workflow graph must be a list")
|
raise ValueError("edges in workflow graph must be a list")
|
||||||
# init graph
|
# init graph
|
||||||
graph = Graph.init(graph_config=graph_config)
|
graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
|
||||||
|
|
||||||
if not graph:
|
if not graph:
|
||||||
raise ValueError("graph not found in workflow")
|
raise ValueError("graph not found in workflow")
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ class ProviderConfig(BasicProviderConfig):
|
|||||||
|
|
||||||
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
|
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
|
||||||
required: bool = False
|
required: bool = False
|
||||||
default: Optional[Union[int, str, float, bool]] = None
|
default: Optional[Union[int, str, float, bool, list]] = None
|
||||||
options: Optional[list[Option]] = None
|
options: Optional[list[Option]] = None
|
||||||
label: Optional[I18nObject] = None
|
label: Optional[I18nObject] = None
|
||||||
help: Optional[I18nObject] = None
|
help: Optional[I18nObject] = None
|
||||||
|
|||||||
128
api/core/helper/provider_encryption.py
Normal file
128
api/core/helper/provider_encryption.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
import contextlib
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any, Optional, Protocol
|
||||||
|
|
||||||
|
from core.entities.provider_entities import BasicProviderConfig
|
||||||
|
from core.helper import encrypter
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderConfigCache(Protocol):
|
||||||
|
"""
|
||||||
|
Interface for provider configuration cache operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get(self) -> Optional[dict]:
|
||||||
|
"""Get cached provider configuration"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def set(self, config: dict[str, Any]) -> None:
|
||||||
|
"""Cache provider configuration"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
"""Delete cached provider configuration"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderConfigEncrypter:
|
||||||
|
tenant_id: str
|
||||||
|
config: list[BasicProviderConfig]
|
||||||
|
provider_config_cache: ProviderConfigCache
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
config: list[BasicProviderConfig],
|
||||||
|
provider_config_cache: ProviderConfigCache,
|
||||||
|
):
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.config = config
|
||||||
|
self.provider_config_cache = provider_config_cache
|
||||||
|
|
||||||
|
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
deep copy data
|
||||||
|
"""
|
||||||
|
return deepcopy(data)
|
||||||
|
|
||||||
|
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
encrypt tool credentials with tenant id
|
||||||
|
|
||||||
|
return a deep copy of credentials with encrypted values
|
||||||
|
"""
|
||||||
|
data = self._deep_copy(data)
|
||||||
|
|
||||||
|
# get fields need to be decrypted
|
||||||
|
fields = dict[str, BasicProviderConfig]()
|
||||||
|
for credential in self.config:
|
||||||
|
fields[credential.name] = credential
|
||||||
|
|
||||||
|
for field_name, field in fields.items():
|
||||||
|
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||||
|
if field_name in data:
|
||||||
|
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||||
|
data[field_name] = encrypted
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def mask_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
mask credentials
|
||||||
|
|
||||||
|
return a deep copy of credentials with masked values
|
||||||
|
"""
|
||||||
|
data = self._deep_copy(data)
|
||||||
|
|
||||||
|
# get fields need to be decrypted
|
||||||
|
fields = dict[str, BasicProviderConfig]()
|
||||||
|
for credential in self.config:
|
||||||
|
fields[credential.name] = credential
|
||||||
|
|
||||||
|
for field_name, field in fields.items():
|
||||||
|
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||||
|
if field_name in data:
|
||||||
|
if len(data[field_name]) > 6:
|
||||||
|
data[field_name] = (
|
||||||
|
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data[field_name] = "*" * len(data[field_name])
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return self.mask_credentials(data)
|
||||||
|
|
||||||
|
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
decrypt tool credentials with tenant id
|
||||||
|
|
||||||
|
return a deep copy of credentials with decrypted values
|
||||||
|
"""
|
||||||
|
cached_credentials = self.provider_config_cache.get()
|
||||||
|
if cached_credentials:
|
||||||
|
return cached_credentials
|
||||||
|
|
||||||
|
data = self._deep_copy(data)
|
||||||
|
# get fields need to be decrypted
|
||||||
|
fields = dict[str, BasicProviderConfig]()
|
||||||
|
for credential in self.config:
|
||||||
|
fields[credential.name] = credential
|
||||||
|
|
||||||
|
for field_name, field in fields.items():
|
||||||
|
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||||
|
if field_name in data:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
# if the value is None or empty string, skip decrypt
|
||||||
|
if not data[field_name]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||||
|
|
||||||
|
self.provider_config_cache.set(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||||
|
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||||
@@ -13,6 +13,7 @@ from core.plugin.entities.base import BasePluginEntity
|
|||||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import ToolProviderEntity
|
from core.tools.entities.tool_entities import ToolProviderEntity
|
||||||
|
from core.trigger.entities.entities import TriggerProviderEntity
|
||||||
|
|
||||||
|
|
||||||
class PluginInstallationSource(enum.StrEnum):
|
class PluginInstallationSource(enum.StrEnum):
|
||||||
@@ -62,6 +63,7 @@ class PluginCategory(enum.StrEnum):
|
|||||||
Model = "model"
|
Model = "model"
|
||||||
Extension = "extension"
|
Extension = "extension"
|
||||||
AgentStrategy = "agent-strategy"
|
AgentStrategy = "agent-strategy"
|
||||||
|
Trigger = "trigger"
|
||||||
|
|
||||||
|
|
||||||
class PluginDeclaration(BaseModel):
|
class PluginDeclaration(BaseModel):
|
||||||
@@ -69,6 +71,7 @@ class PluginDeclaration(BaseModel):
|
|||||||
tools: Optional[list[str]] = Field(default_factory=list[str])
|
tools: Optional[list[str]] = Field(default_factory=list[str])
|
||||||
models: Optional[list[str]] = Field(default_factory=list[str])
|
models: Optional[list[str]] = Field(default_factory=list[str])
|
||||||
endpoints: Optional[list[str]] = Field(default_factory=list[str])
|
endpoints: Optional[list[str]] = Field(default_factory=list[str])
|
||||||
|
triggers: Optional[list[str]] = Field(default_factory=list[str])
|
||||||
|
|
||||||
class Meta(BaseModel):
|
class Meta(BaseModel):
|
||||||
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
|
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
|
||||||
@@ -89,6 +92,7 @@ class PluginDeclaration(BaseModel):
|
|||||||
repo: Optional[str] = Field(default=None)
|
repo: Optional[str] = Field(default=None)
|
||||||
verified: bool = Field(default=False)
|
verified: bool = Field(default=False)
|
||||||
tool: Optional[ToolProviderEntity] = None
|
tool: Optional[ToolProviderEntity] = None
|
||||||
|
trigger: Optional[TriggerProviderEntity] = None
|
||||||
model: Optional[ProviderEntity] = None
|
model: Optional[ProviderEntity] = None
|
||||||
endpoint: Optional[EndpointProviderDeclaration] = None
|
endpoint: Optional[EndpointProviderDeclaration] = None
|
||||||
agent_strategy: Optional[AgentStrategyProviderEntity] = None
|
agent_strategy: Optional[AgentStrategyProviderEntity] = None
|
||||||
@@ -104,6 +108,8 @@ class PluginDeclaration(BaseModel):
|
|||||||
values["category"] = PluginCategory.Model
|
values["category"] = PluginCategory.Model
|
||||||
elif values.get("agent_strategy"):
|
elif values.get("agent_strategy"):
|
||||||
values["category"] = PluginCategory.AgentStrategy
|
values["category"] = PluginCategory.AgentStrategy
|
||||||
|
elif values.get("trigger"):
|
||||||
|
values["category"] = PluginCategory.Trigger
|
||||||
else:
|
else:
|
||||||
values["category"] = PluginCategory.Extension
|
values["category"] = PluginCategory.Extension
|
||||||
return values
|
return values
|
||||||
@@ -184,6 +190,10 @@ class ToolProviderID(GenericProviderID):
|
|||||||
self.plugin_name = f"{self.provider_name}_tool"
|
self.plugin_name = f"{self.provider_name}_tool"
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerProviderID(GenericProviderID):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PluginDependency(BaseModel):
|
class PluginDependency(BaseModel):
|
||||||
class Type(enum.StrEnum):
|
class Type(enum.StrEnum):
|
||||||
Github = PluginInstallationSource.Github.value
|
Github = PluginInstallationSource.Github.value
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import enum
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
@@ -13,6 +14,7 @@ from core.plugin.entities.parameters import PluginParameterOption
|
|||||||
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
|
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
|
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
|
||||||
|
from core.trigger.entities.entities import TriggerProviderEntity
|
||||||
|
|
||||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||||
|
|
||||||
@@ -196,3 +198,48 @@ class PluginListResponse(BaseModel):
|
|||||||
|
|
||||||
class PluginDynamicSelectOptionsResponse(BaseModel):
|
class PluginDynamicSelectOptionsResponse(BaseModel):
|
||||||
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
|
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
|
||||||
|
|
||||||
|
|
||||||
|
class PluginTriggerProviderEntity(BaseModel):
|
||||||
|
provider: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
plugin_id: str
|
||||||
|
declaration: TriggerProviderEntity
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialType(enum.StrEnum):
|
||||||
|
API_KEY = "api-key"
|
||||||
|
OAUTH2 = "oauth2"
|
||||||
|
UNAUTHORIZED = "unauthorized"
|
||||||
|
|
||||||
|
def get_name(self):
|
||||||
|
if self == CredentialType.API_KEY:
|
||||||
|
return "API KEY"
|
||||||
|
elif self == CredentialType.OAUTH2:
|
||||||
|
return "AUTH"
|
||||||
|
elif self == CredentialType.UNAUTHORIZED:
|
||||||
|
return "UNAUTHORIZED"
|
||||||
|
else:
|
||||||
|
return self.value.replace("-", " ").upper()
|
||||||
|
|
||||||
|
def is_editable(self):
|
||||||
|
return self == CredentialType.API_KEY
|
||||||
|
|
||||||
|
def is_validate_allowed(self):
|
||||||
|
return self == CredentialType.API_KEY
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def values(cls):
|
||||||
|
return [item.value for item in cls]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def of(cls, credential_type: str) -> "CredentialType":
|
||||||
|
type_name = credential_type.lower()
|
||||||
|
if type_name == "api-key":
|
||||||
|
return cls.API_KEY
|
||||||
|
elif type_name == "oauth2":
|
||||||
|
return cls.OAUTH2
|
||||||
|
elif type_name == "unauthorized":
|
||||||
|
return cls.UNAUTHORIZED
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
|
from flask import Response
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
from core.entities.provider_entities import BasicProviderConfig
|
from core.entities.provider_entities import BasicProviderConfig
|
||||||
@@ -237,3 +239,33 @@ class RequestFetchAppInfo(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
app_id: str
|
app_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class Event(BaseModel):
|
||||||
|
variables: Mapping[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerInvokeResponse(BaseModel):
|
||||||
|
event: Event
|
||||||
|
|
||||||
|
|
||||||
|
class PluginTriggerDispatchResponse(BaseModel):
|
||||||
|
triggers: list[str]
|
||||||
|
raw_http_response: str
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerSubscriptionResponse(BaseModel):
|
||||||
|
subscription: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerValidateProviderCredentialsResponse(BaseModel):
|
||||||
|
result: bool
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerDispatchResponse:
|
||||||
|
triggers: list[str]
|
||||||
|
response: Response
|
||||||
|
|
||||||
|
def __init__(self, triggers: list[str], response: Response):
|
||||||
|
self.triggers = triggers
|
||||||
|
self.response = response
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ class DynamicSelectClient(BasePluginClient):
|
|||||||
provider: str,
|
provider: str,
|
||||||
action: str,
|
action: str,
|
||||||
credentials: Mapping[str, Any],
|
credentials: Mapping[str, Any],
|
||||||
|
credential_type: str,
|
||||||
parameter: str,
|
parameter: str,
|
||||||
) -> PluginDynamicSelectOptionsResponse:
|
) -> PluginDynamicSelectOptionsResponse:
|
||||||
"""
|
"""
|
||||||
@@ -29,6 +30,7 @@ class DynamicSelectClient(BasePluginClient):
|
|||||||
"data": {
|
"data": {
|
||||||
"provider": GenericProviderID(provider).provider_name,
|
"provider": GenericProviderID(provider).provider_name,
|
||||||
"credentials": credentials,
|
"credentials": credentials,
|
||||||
|
"credential_type": credential_type,
|
||||||
"provider_action": action,
|
"provider_action": action,
|
||||||
"parameter": parameter,
|
"parameter": parameter,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -4,10 +4,10 @@ from typing import Any, Optional
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||||
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
|
from core.plugin.entities.plugin_daemon import CredentialType, PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||||
from core.plugin.impl.base import BasePluginClient
|
from core.plugin.impl.base import BasePluginClient
|
||||||
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
||||||
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||||
|
|
||||||
|
|
||||||
class PluginToolManager(BasePluginClient):
|
class PluginToolManager(BasePluginClient):
|
||||||
|
|||||||
301
api/core/plugin/impl/trigger.py
Normal file
301
api/core/plugin/impl/trigger.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
import binascii
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from flask import Request
|
||||||
|
|
||||||
|
from core.plugin.entities.plugin import GenericProviderID, TriggerProviderID
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity
|
||||||
|
from core.plugin.entities.request import (
|
||||||
|
PluginTriggerDispatchResponse,
|
||||||
|
TriggerDispatchResponse,
|
||||||
|
TriggerInvokeResponse,
|
||||||
|
TriggerSubscriptionResponse,
|
||||||
|
TriggerValidateProviderCredentialsResponse,
|
||||||
|
)
|
||||||
|
from core.plugin.impl.base import BasePluginClient
|
||||||
|
from core.plugin.utils.http_parser import deserialize_response, serialize_request
|
||||||
|
from core.trigger.entities.entities import Subscription
|
||||||
|
|
||||||
|
|
||||||
|
class PluginTriggerManager(BasePluginClient):
|
||||||
|
def fetch_trigger_providers(self, tenant_id: str) -> list[PluginTriggerProviderEntity]:
|
||||||
|
"""
|
||||||
|
Fetch trigger providers for the given tenant.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transformer(json_response: dict[str, Any]) -> dict:
|
||||||
|
for provider in json_response.get("data", []):
|
||||||
|
declaration = provider.get("declaration", {}) or {}
|
||||||
|
provider_id = provider.get("plugin_id") + "/" + provider.get("provider")
|
||||||
|
for trigger in declaration.get("triggers", []):
|
||||||
|
trigger["identity"]["provider"] = provider_id
|
||||||
|
|
||||||
|
return json_response
|
||||||
|
|
||||||
|
response = self._request_with_plugin_daemon_response(
|
||||||
|
"GET",
|
||||||
|
f"plugin/{tenant_id}/management/triggers",
|
||||||
|
list[PluginTriggerProviderEntity],
|
||||||
|
params={"page": 1, "page_size": 256},
|
||||||
|
transformer=transformer,
|
||||||
|
)
|
||||||
|
|
||||||
|
for provider in response:
|
||||||
|
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||||
|
|
||||||
|
# override the provider name for each trigger to plugin_id/provider_name
|
||||||
|
for trigger in provider.declaration.triggers:
|
||||||
|
trigger.identity.provider = provider.declaration.identity.name
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def fetch_trigger_provider(self, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderEntity:
|
||||||
|
"""
|
||||||
|
Fetch trigger provider for the given tenant and plugin.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transformer(json_response: dict[str, Any]) -> dict:
|
||||||
|
data = json_response.get("data")
|
||||||
|
if data:
|
||||||
|
for trigger in data.get("declaration", {}).get("triggers", []):
|
||||||
|
trigger["identity"]["provider"] = str(provider_id)
|
||||||
|
|
||||||
|
return json_response
|
||||||
|
|
||||||
|
response = self._request_with_plugin_daemon_response(
|
||||||
|
"GET",
|
||||||
|
f"plugin/{tenant_id}/management/trigger",
|
||||||
|
PluginTriggerProviderEntity,
|
||||||
|
params={"provider": provider_id.provider_name, "plugin_id": provider_id.plugin_id},
|
||||||
|
transformer=transformer,
|
||||||
|
)
|
||||||
|
|
||||||
|
response.declaration.identity.name = str(provider_id)
|
||||||
|
|
||||||
|
# override the provider name for each trigger to plugin_id/provider_name
|
||||||
|
for trigger in response.declaration.triggers:
|
||||||
|
trigger.identity.provider = str(provider_id)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def invoke_trigger(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
provider: str,
|
||||||
|
trigger: str,
|
||||||
|
credentials: Mapping[str, str],
|
||||||
|
credential_type: CredentialType,
|
||||||
|
request: Request,
|
||||||
|
parameters: Mapping[str, Any],
|
||||||
|
) -> TriggerInvokeResponse:
|
||||||
|
"""
|
||||||
|
Invoke a trigger with the given parameters.
|
||||||
|
"""
|
||||||
|
trigger_provider_id = GenericProviderID(provider)
|
||||||
|
|
||||||
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/trigger/invoke",
|
||||||
|
TriggerInvokeResponse,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": trigger_provider_id.provider_name,
|
||||||
|
"trigger": trigger,
|
||||||
|
"credentials": credentials,
|
||||||
|
"credential_type": credential_type,
|
||||||
|
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
|
||||||
|
"parameters": parameters,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
for resp in response:
|
||||||
|
return TriggerInvokeResponse(event=resp.event)
|
||||||
|
|
||||||
|
raise ValueError("No response received from plugin daemon for invoke trigger")
|
||||||
|
|
||||||
|
def validate_provider_credentials(
|
||||||
|
self, tenant_id: str, user_id: str, provider: str, credentials: Mapping[str, str]
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Validate the credentials of the trigger provider.
|
||||||
|
"""
|
||||||
|
trigger_provider_id = GenericProviderID(provider)
|
||||||
|
|
||||||
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/trigger/validate_credentials",
|
||||||
|
TriggerValidateProviderCredentialsResponse,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": trigger_provider_id.provider_name,
|
||||||
|
"credentials": credentials,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
for resp in response:
|
||||||
|
return resp.result
|
||||||
|
|
||||||
|
raise ValueError("No response received from plugin daemon for validate provider credentials")
|
||||||
|
|
||||||
|
def dispatch_event(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
provider: str,
|
||||||
|
subscription: Mapping[str, Any],
|
||||||
|
request: Request,
|
||||||
|
) -> TriggerDispatchResponse:
|
||||||
|
"""
|
||||||
|
Dispatch an event to triggers.
|
||||||
|
"""
|
||||||
|
trigger_provider_id = GenericProviderID(provider)
|
||||||
|
|
||||||
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/trigger/dispatch_event",
|
||||||
|
PluginTriggerDispatchResponse,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": trigger_provider_id.provider_name,
|
||||||
|
"subscription": subscription,
|
||||||
|
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
for resp in response:
|
||||||
|
return TriggerDispatchResponse(
|
||||||
|
triggers=resp.triggers,
|
||||||
|
response=deserialize_response(binascii.unhexlify(resp.raw_http_response.encode())),
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError("No response received from plugin daemon for dispatch event")
|
||||||
|
|
||||||
|
def subscribe(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
provider: str,
|
||||||
|
credentials: Mapping[str, str],
|
||||||
|
endpoint: str,
|
||||||
|
parameters: Mapping[str, Any],
|
||||||
|
) -> TriggerSubscriptionResponse:
|
||||||
|
"""
|
||||||
|
Subscribe to a trigger.
|
||||||
|
"""
|
||||||
|
trigger_provider_id = GenericProviderID(provider)
|
||||||
|
|
||||||
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/trigger/subscribe",
|
||||||
|
TriggerSubscriptionResponse,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": trigger_provider_id.provider_name,
|
||||||
|
"credentials": credentials,
|
||||||
|
"endpoint": endpoint,
|
||||||
|
"parameters": parameters,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
for resp in response:
|
||||||
|
return resp
|
||||||
|
|
||||||
|
raise ValueError("No response received from plugin daemon for subscribe")
|
||||||
|
|
||||||
|
def unsubscribe(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
provider: str,
|
||||||
|
subscription: Subscription,
|
||||||
|
credentials: Mapping[str, str],
|
||||||
|
) -> TriggerSubscriptionResponse:
|
||||||
|
"""
|
||||||
|
Unsubscribe from a trigger.
|
||||||
|
"""
|
||||||
|
trigger_provider_id = GenericProviderID(provider)
|
||||||
|
|
||||||
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/trigger/unsubscribe",
|
||||||
|
TriggerSubscriptionResponse,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": trigger_provider_id.provider_name,
|
||||||
|
"subscription": subscription.model_dump(),
|
||||||
|
"credentials": credentials,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
for resp in response:
|
||||||
|
return resp
|
||||||
|
|
||||||
|
raise ValueError("No response received from plugin daemon for unsubscribe")
|
||||||
|
|
||||||
|
def refresh(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
provider: str,
|
||||||
|
subscription: Subscription,
|
||||||
|
credentials: Mapping[str, str],
|
||||||
|
) -> TriggerSubscriptionResponse:
|
||||||
|
"""
|
||||||
|
Refresh a trigger subscription.
|
||||||
|
"""
|
||||||
|
trigger_provider_id = GenericProviderID(provider)
|
||||||
|
|
||||||
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/trigger/refresh",
|
||||||
|
TriggerSubscriptionResponse,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": trigger_provider_id.provider_name,
|
||||||
|
"subscription": subscription.model_dump(),
|
||||||
|
"credentials": credentials,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
for resp in response:
|
||||||
|
return resp
|
||||||
|
|
||||||
|
raise ValueError("No response received from plugin daemon for refresh")
|
||||||
159
api/core/plugin/utils/http_parser.py
Normal file
159
api/core/plugin/utils/http_parser.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from flask import Request, Response
|
||||||
|
from werkzeug.datastructures import Headers
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_request(request: Request) -> bytes:
|
||||||
|
method = request.method
|
||||||
|
path = request.full_path.rstrip("?")
|
||||||
|
raw = f"{method} {path} HTTP/1.1\r\n".encode()
|
||||||
|
|
||||||
|
for name, value in request.headers.items():
|
||||||
|
raw += f"{name}: {value}\r\n".encode()
|
||||||
|
|
||||||
|
raw += b"\r\n"
|
||||||
|
|
||||||
|
body = request.get_data(as_text=False)
|
||||||
|
if body:
|
||||||
|
raw += body
|
||||||
|
|
||||||
|
return raw
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_request(raw_data: bytes) -> Request:
|
||||||
|
header_end = raw_data.find(b"\r\n\r\n")
|
||||||
|
if header_end == -1:
|
||||||
|
header_end = raw_data.find(b"\n\n")
|
||||||
|
if header_end == -1:
|
||||||
|
header_data = raw_data
|
||||||
|
body = b""
|
||||||
|
else:
|
||||||
|
header_data = raw_data[:header_end]
|
||||||
|
body = raw_data[header_end + 2 :]
|
||||||
|
else:
|
||||||
|
header_data = raw_data[:header_end]
|
||||||
|
body = raw_data[header_end + 4 :]
|
||||||
|
|
||||||
|
lines = header_data.split(b"\r\n")
|
||||||
|
if len(lines) == 1 and b"\n" in lines[0]:
|
||||||
|
lines = header_data.split(b"\n")
|
||||||
|
|
||||||
|
if not lines or not lines[0]:
|
||||||
|
raise ValueError("Empty HTTP request")
|
||||||
|
|
||||||
|
request_line = lines[0].decode("utf-8", errors="ignore")
|
||||||
|
parts = request_line.split(" ", 2)
|
||||||
|
if len(parts) < 2:
|
||||||
|
raise ValueError(f"Invalid request line: {request_line}")
|
||||||
|
|
||||||
|
method = parts[0]
|
||||||
|
full_path = parts[1]
|
||||||
|
protocol = parts[2] if len(parts) > 2 else "HTTP/1.1"
|
||||||
|
|
||||||
|
if "?" in full_path:
|
||||||
|
path, query_string = full_path.split("?", 1)
|
||||||
|
else:
|
||||||
|
path = full_path
|
||||||
|
query_string = ""
|
||||||
|
|
||||||
|
headers = Headers()
|
||||||
|
for line in lines[1:]:
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
line_str = line.decode("utf-8", errors="ignore")
|
||||||
|
if ":" not in line_str:
|
||||||
|
continue
|
||||||
|
name, value = line_str.split(":", 1)
|
||||||
|
headers.add(name, value.strip())
|
||||||
|
|
||||||
|
host = headers.get("Host", "localhost")
|
||||||
|
if ":" in host:
|
||||||
|
server_name, server_port = host.rsplit(":", 1)
|
||||||
|
else:
|
||||||
|
server_name = host
|
||||||
|
server_port = "80"
|
||||||
|
|
||||||
|
environ = {
|
||||||
|
"REQUEST_METHOD": method,
|
||||||
|
"PATH_INFO": path,
|
||||||
|
"QUERY_STRING": query_string,
|
||||||
|
"SERVER_NAME": server_name,
|
||||||
|
"SERVER_PORT": server_port,
|
||||||
|
"SERVER_PROTOCOL": protocol,
|
||||||
|
"wsgi.input": BytesIO(body),
|
||||||
|
"wsgi.url_scheme": "http",
|
||||||
|
}
|
||||||
|
|
||||||
|
if "Content-Type" in headers:
|
||||||
|
environ["CONTENT_TYPE"] = headers.get("Content-Type")
|
||||||
|
|
||||||
|
if "Content-Length" in headers:
|
||||||
|
environ["CONTENT_LENGTH"] = headers.get("Content-Length")
|
||||||
|
elif body:
|
||||||
|
environ["CONTENT_LENGTH"] = str(len(body))
|
||||||
|
|
||||||
|
for name, value in headers.items():
|
||||||
|
if name.upper() in ("CONTENT-TYPE", "CONTENT-LENGTH"):
|
||||||
|
continue
|
||||||
|
env_name = f"HTTP_{name.upper().replace('-', '_')}"
|
||||||
|
environ[env_name] = value
|
||||||
|
|
||||||
|
return Request(environ)
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_response(response: Response) -> bytes:
|
||||||
|
raw = f"HTTP/1.1 {response.status}\r\n".encode()
|
||||||
|
|
||||||
|
for name, value in response.headers.items():
|
||||||
|
raw += f"{name}: {value}\r\n".encode()
|
||||||
|
|
||||||
|
raw += b"\r\n"
|
||||||
|
|
||||||
|
body = response.get_data(as_text=False)
|
||||||
|
if body:
|
||||||
|
raw += body
|
||||||
|
|
||||||
|
return raw
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_response(raw_data: bytes) -> Response:
|
||||||
|
header_end = raw_data.find(b"\r\n\r\n")
|
||||||
|
if header_end == -1:
|
||||||
|
header_end = raw_data.find(b"\n\n")
|
||||||
|
if header_end == -1:
|
||||||
|
header_data = raw_data
|
||||||
|
body = b""
|
||||||
|
else:
|
||||||
|
header_data = raw_data[:header_end]
|
||||||
|
body = raw_data[header_end + 2 :]
|
||||||
|
else:
|
||||||
|
header_data = raw_data[:header_end]
|
||||||
|
body = raw_data[header_end + 4 :]
|
||||||
|
|
||||||
|
lines = header_data.split(b"\r\n")
|
||||||
|
if len(lines) == 1 and b"\n" in lines[0]:
|
||||||
|
lines = header_data.split(b"\n")
|
||||||
|
|
||||||
|
if not lines or not lines[0]:
|
||||||
|
raise ValueError("Empty HTTP response")
|
||||||
|
|
||||||
|
status_line = lines[0].decode("utf-8", errors="ignore")
|
||||||
|
parts = status_line.split(" ", 2)
|
||||||
|
if len(parts) < 2:
|
||||||
|
raise ValueError(f"Invalid status line: {status_line}")
|
||||||
|
|
||||||
|
status_code = int(parts[1])
|
||||||
|
|
||||||
|
response = Response(response=body, status=status_code)
|
||||||
|
|
||||||
|
for line in lines[1:]:
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
line_str = line.decode("utf-8", errors="ignore")
|
||||||
|
if ":" not in line_str:
|
||||||
|
continue
|
||||||
|
name, value = line_str.split(":", 1)
|
||||||
|
response.headers[name] = value.strip()
|
||||||
|
|
||||||
|
return response
|
||||||
@@ -4,7 +4,8 @@ from openai import BaseModel
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeFrom
|
||||||
|
|
||||||
|
|
||||||
class ToolRuntime(BaseModel):
|
class ToolRuntime(BaseModel):
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ from typing import Any
|
|||||||
|
|
||||||
from core.entities.provider_entities import ProviderConfig
|
from core.entities.provider_entities import ProviderConfig
|
||||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
from core.tools.builtin_tool.tool import BuiltinTool
|
from core.tools.builtin_tool.tool import BuiltinTool
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
CredentialType,
|
|
||||||
OAuthSchema,
|
OAuthSchema,
|
||||||
ToolEntity,
|
ToolEntity,
|
||||||
ToolProviderEntity,
|
ToolProviderEntity,
|
||||||
|
|||||||
@@ -4,9 +4,10 @@ from typing import Any, Literal, Optional
|
|||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
from core.tools.__base.tool import ToolParameter
|
from core.tools.__base.tool import ToolParameter
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import CredentialType, ToolProviderType
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
|
|
||||||
|
|
||||||
class ToolApiEntity(BaseModel):
|
class ToolApiEntity(BaseModel):
|
||||||
|
|||||||
@@ -476,36 +476,3 @@ class ToolSelector(BaseModel):
|
|||||||
|
|
||||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||||
return self.model_dump()
|
return self.model_dump()
|
||||||
|
|
||||||
|
|
||||||
class CredentialType(enum.StrEnum):
|
|
||||||
API_KEY = "api-key"
|
|
||||||
OAUTH2 = "oauth2"
|
|
||||||
|
|
||||||
def get_name(self):
|
|
||||||
if self == CredentialType.API_KEY:
|
|
||||||
return "API KEY"
|
|
||||||
elif self == CredentialType.OAUTH2:
|
|
||||||
return "AUTH"
|
|
||||||
else:
|
|
||||||
return self.value.replace("-", " ").upper()
|
|
||||||
|
|
||||||
def is_editable(self):
|
|
||||||
return self == CredentialType.API_KEY
|
|
||||||
|
|
||||||
def is_validate_allowed(self):
|
|
||||||
return self == CredentialType.API_KEY
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def values(cls):
|
|
||||||
return [item.value for item in cls]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def of(cls, credential_type: str) -> "CredentialType":
|
|
||||||
type_name = credential_type.lower()
|
|
||||||
if type_name == "api-key":
|
|
||||||
return cls.API_KEY
|
|
||||||
elif type_name == "oauth2":
|
|
||||||
return cls.OAUTH2
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||||
from core.helper.position_helper import is_filtered
|
from core.helper.position_helper import is_filtered
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||||
@@ -47,7 +48,6 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
|
|||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
ApiProviderAuthType,
|
ApiProviderAuthType,
|
||||||
CredentialType,
|
|
||||||
ToolInvokeFrom,
|
ToolInvokeFrom,
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
|
|||||||
@@ -1,132 +1,23 @@
|
|||||||
import contextlib
|
# Import generic components from provider_encryption module
|
||||||
from copy import deepcopy
|
from core.helper.provider_encryption import (
|
||||||
from typing import Any, Optional, Protocol
|
ProviderConfigCache,
|
||||||
|
ProviderConfigEncrypter,
|
||||||
|
create_provider_encrypter,
|
||||||
|
)
|
||||||
|
|
||||||
from core.entities.provider_entities import BasicProviderConfig
|
# Re-export for backward compatibility
|
||||||
from core.helper import encrypter
|
__all__ = [
|
||||||
|
"ProviderConfigCache",
|
||||||
|
"ProviderConfigEncrypter",
|
||||||
|
"create_provider_encrypter",
|
||||||
|
"create_tool_provider_encrypter",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Tool-specific imports
|
||||||
from core.helper.provider_cache import SingletonProviderCredentialsCache
|
from core.helper.provider_cache import SingletonProviderCredentialsCache
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
from core.tools.__base.tool_provider import ToolProviderController
|
||||||
|
|
||||||
|
|
||||||
class ProviderConfigCache(Protocol):
|
|
||||||
"""
|
|
||||||
Interface for provider configuration cache operations
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get(self) -> Optional[dict]:
|
|
||||||
"""Get cached provider configuration"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def set(self, config: dict[str, Any]) -> None:
|
|
||||||
"""Cache provider configuration"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def delete(self) -> None:
|
|
||||||
"""Delete cached provider configuration"""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderConfigEncrypter:
|
|
||||||
tenant_id: str
|
|
||||||
config: list[BasicProviderConfig]
|
|
||||||
provider_config_cache: ProviderConfigCache
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
tenant_id: str,
|
|
||||||
config: list[BasicProviderConfig],
|
|
||||||
provider_config_cache: ProviderConfigCache,
|
|
||||||
):
|
|
||||||
self.tenant_id = tenant_id
|
|
||||||
self.config = config
|
|
||||||
self.provider_config_cache = provider_config_cache
|
|
||||||
|
|
||||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
|
||||||
"""
|
|
||||||
deep copy data
|
|
||||||
"""
|
|
||||||
return deepcopy(data)
|
|
||||||
|
|
||||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
|
||||||
"""
|
|
||||||
encrypt tool credentials with tenant id
|
|
||||||
|
|
||||||
return a deep copy of credentials with encrypted values
|
|
||||||
"""
|
|
||||||
data = self._deep_copy(data)
|
|
||||||
|
|
||||||
# get fields need to be decrypted
|
|
||||||
fields = dict[str, BasicProviderConfig]()
|
|
||||||
for credential in self.config:
|
|
||||||
fields[credential.name] = credential
|
|
||||||
|
|
||||||
for field_name, field in fields.items():
|
|
||||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
|
||||||
if field_name in data:
|
|
||||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
|
||||||
data[field_name] = encrypted
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
mask tool credentials
|
|
||||||
|
|
||||||
return a deep copy of credentials with masked values
|
|
||||||
"""
|
|
||||||
data = self._deep_copy(data)
|
|
||||||
|
|
||||||
# get fields need to be decrypted
|
|
||||||
fields = dict[str, BasicProviderConfig]()
|
|
||||||
for credential in self.config:
|
|
||||||
fields[credential.name] = credential
|
|
||||||
|
|
||||||
for field_name, field in fields.items():
|
|
||||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
|
||||||
if field_name in data:
|
|
||||||
if len(data[field_name]) > 6:
|
|
||||||
data[field_name] = (
|
|
||||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
data[field_name] = "*" * len(data[field_name])
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
decrypt tool credentials with tenant id
|
|
||||||
|
|
||||||
return a deep copy of credentials with decrypted values
|
|
||||||
"""
|
|
||||||
cached_credentials = self.provider_config_cache.get()
|
|
||||||
if cached_credentials:
|
|
||||||
return cached_credentials
|
|
||||||
|
|
||||||
data = self._deep_copy(data)
|
|
||||||
# get fields need to be decrypted
|
|
||||||
fields = dict[str, BasicProviderConfig]()
|
|
||||||
for credential in self.config:
|
|
||||||
fields[credential.name] = credential
|
|
||||||
|
|
||||||
for field_name, field in fields.items():
|
|
||||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
|
||||||
if field_name in data:
|
|
||||||
with contextlib.suppress(Exception):
|
|
||||||
# if the value is None or empty string, skip decrypt
|
|
||||||
if not data[field_name]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
|
||||||
|
|
||||||
self.provider_config_cache.set(data)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
|
||||||
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
|
||||||
|
|
||||||
|
|
||||||
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
|
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
|
||||||
cache = SingletonProviderCredentialsCache(
|
cache = SingletonProviderCredentialsCache(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
|
|||||||
1
api/core/trigger/__init__.py
Normal file
1
api/core/trigger/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Core trigger module initialization
|
||||||
76
api/core/trigger/entities/api_entities.py
Normal file
76
api/core/trigger/entities/api_entities.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.entities.provider_entities import ProviderConfig
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
from core.trigger.entities.entities import (
|
||||||
|
SubscriptionSchema,
|
||||||
|
TriggerCreationMethod,
|
||||||
|
TriggerDescription,
|
||||||
|
TriggerIdentity,
|
||||||
|
TriggerParameter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerProviderSubscriptionApiEntity(BaseModel):
|
||||||
|
id: str = Field(description="The unique id of the subscription")
|
||||||
|
name: str = Field(description="The name of the subscription")
|
||||||
|
provider: str = Field(description="The provider id of the subscription")
|
||||||
|
credential_type: CredentialType = Field(description="The type of the credential")
|
||||||
|
credentials: dict = Field(description="The credentials of the subscription")
|
||||||
|
endpoint: str = Field(description="The endpoint of the subscription")
|
||||||
|
parameters: dict = Field(description="The parameters of the subscription")
|
||||||
|
properties: dict = Field(description="The properties of the subscription")
|
||||||
|
workflows_in_use: int = Field(description="The number of workflows using this subscription")
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerApiEntity(BaseModel):
|
||||||
|
name: str = Field(description="The name of the trigger")
|
||||||
|
identity: TriggerIdentity = Field(description="The identity of the trigger")
|
||||||
|
description: TriggerDescription = Field(description="The description of the trigger")
|
||||||
|
parameters: list[TriggerParameter] = Field(description="The parameters of the trigger")
|
||||||
|
output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger")
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerProviderApiEntity(BaseModel):
|
||||||
|
author: str = Field(..., description="The author of the trigger provider")
|
||||||
|
name: str = Field(..., description="The name of the trigger provider")
|
||||||
|
label: I18nObject = Field(..., description="The label of the trigger provider")
|
||||||
|
description: I18nObject = Field(..., description="The description of the trigger provider")
|
||||||
|
icon: Optional[str] = Field(default=None, description="The icon of the trigger provider")
|
||||||
|
icon_dark: Optional[str] = Field(default=None, description="The dark icon of the trigger provider")
|
||||||
|
tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
|
||||||
|
|
||||||
|
plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool")
|
||||||
|
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
|
||||||
|
|
||||||
|
supported_creation_methods: list[TriggerCreationMethod] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Supported creation methods for the trigger provider. Possible values: 'OAUTH', 'APIKEY', 'MANUAL'."
|
||||||
|
)
|
||||||
|
|
||||||
|
credentials_schema: list[ProviderConfig] = Field(description="The credentials schema of the trigger provider")
|
||||||
|
oauth_client_schema: list[ProviderConfig] = Field(
|
||||||
|
default_factory=list, description="The schema of the OAuth client"
|
||||||
|
)
|
||||||
|
subscription_schema: Optional[SubscriptionSchema] = Field(
|
||||||
|
description="The subscription schema of the trigger provider"
|
||||||
|
)
|
||||||
|
triggers: list[TriggerApiEntity] = Field(description="The triggers of the trigger provider")
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionBuilderApiEntity(BaseModel):
|
||||||
|
id: str = Field(description="The id of the subscription builder")
|
||||||
|
name: str = Field(description="The name of the subscription builder")
|
||||||
|
provider: str = Field(description="The provider id of the subscription builder")
|
||||||
|
endpoint: str = Field(description="The endpoint id of the subscription builder")
|
||||||
|
parameters: Mapping[str, Any] = Field(description="The parameters of the subscription builder")
|
||||||
|
properties: Mapping[str, Any] = Field(description="The properties of the subscription builder")
|
||||||
|
credentials: Mapping[str, str] = Field(description="The credentials of the subscription builder")
|
||||||
|
credential_type: CredentialType = Field(description="The credential type of the subscription builder")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["TriggerApiEntity", "TriggerProviderApiEntity", "TriggerProviderSubscriptionApiEntity"]
|
||||||
307
api/core/trigger/entities/entities.py
Normal file
307
api/core/trigger/entities/entities.py
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
from core.entities.provider_entities import ProviderConfig
|
||||||
|
from core.plugin.entities.parameters import PluginParameterAutoGenerate, PluginParameterOption, PluginParameterTemplate
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerParameterType(StrEnum):
|
||||||
|
"""The type of the parameter"""
|
||||||
|
|
||||||
|
STRING = "string"
|
||||||
|
NUMBER = "number"
|
||||||
|
BOOLEAN = "boolean"
|
||||||
|
SELECT = "select"
|
||||||
|
FILE = "file"
|
||||||
|
FILES = "files"
|
||||||
|
MODEL_SELECTOR = "model-selector"
|
||||||
|
APP_SELECTOR = "app-selector"
|
||||||
|
OBJECT = "object"
|
||||||
|
ARRAY = "array"
|
||||||
|
DYNAMIC_SELECT = "dynamic-select"
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerParameter(BaseModel):
|
||||||
|
"""
|
||||||
|
The parameter of the trigger
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = Field(..., description="The name of the parameter")
|
||||||
|
label: I18nObject = Field(..., description="The label presented to the user")
|
||||||
|
type: TriggerParameterType = Field(..., description="The type of the parameter")
|
||||||
|
auto_generate: Optional[PluginParameterAutoGenerate] = Field(
|
||||||
|
default=None, description="The auto generate of the parameter"
|
||||||
|
)
|
||||||
|
template: Optional[PluginParameterTemplate] = Field(default=None, description="The template of the parameter")
|
||||||
|
scope: Optional[str] = None
|
||||||
|
required: Optional[bool] = False
|
||||||
|
multiple: bool | None = Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether the parameter is multiple select, only valid for select or dynamic-select type",
|
||||||
|
)
|
||||||
|
default: Union[int, float, str, list, None] = None
|
||||||
|
min: Union[float, int, None] = None
|
||||||
|
max: Union[float, int, None] = None
|
||||||
|
precision: Optional[int] = None
|
||||||
|
options: Optional[list[PluginParameterOption]] = None
|
||||||
|
description: Optional[I18nObject] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerProviderIdentity(BaseModel):
|
||||||
|
"""
|
||||||
|
The identity of the trigger provider
|
||||||
|
"""
|
||||||
|
|
||||||
|
author: str = Field(..., description="The author of the trigger provider")
|
||||||
|
name: str = Field(..., description="The name of the trigger provider")
|
||||||
|
label: I18nObject = Field(..., description="The label of the trigger provider")
|
||||||
|
description: I18nObject = Field(..., description="The description of the trigger provider")
|
||||||
|
icon: Optional[str] = Field(default=None, description="The icon of the trigger provider")
|
||||||
|
icon_dark: Optional[str] = Field(default=None, description="The dark icon of the trigger provider")
|
||||||
|
tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerIdentity(BaseModel):
|
||||||
|
"""
|
||||||
|
The identity of the trigger
|
||||||
|
"""
|
||||||
|
|
||||||
|
author: str = Field(..., description="The author of the trigger")
|
||||||
|
name: str = Field(..., description="The name of the trigger")
|
||||||
|
label: I18nObject = Field(..., description="The label of the trigger")
|
||||||
|
provider: Optional[str] = Field(default=None, description="The provider of the trigger")
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerDescription(BaseModel):
|
||||||
|
"""
|
||||||
|
The description of the trigger
|
||||||
|
"""
|
||||||
|
|
||||||
|
human: I18nObject = Field(..., description="Human readable description")
|
||||||
|
llm: I18nObject = Field(..., description="LLM readable description")
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerEntity(BaseModel):
|
||||||
|
"""
|
||||||
|
The configuration of a trigger
|
||||||
|
"""
|
||||||
|
|
||||||
|
identity: TriggerIdentity = Field(..., description="The identity of the trigger")
|
||||||
|
parameters: list[TriggerParameter] = Field(default=[], description="The parameters of the trigger")
|
||||||
|
description: TriggerDescription = Field(..., description="The description of the trigger")
|
||||||
|
output_schema: Optional[Mapping[str, Any]] = Field(
|
||||||
|
default=None, description="The output schema that this trigger produces"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthSchema(BaseModel):
|
||||||
|
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
||||||
|
credentials_schema: list[ProviderConfig] = Field(
|
||||||
|
default_factory=list, description="The schema of the OAuth credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionSchema(BaseModel):
|
||||||
|
"""
|
||||||
|
The subscription schema of the trigger provider
|
||||||
|
"""
|
||||||
|
|
||||||
|
parameters_schema: list[TriggerParameter] | None = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="The parameters schema required to create a subscription",
|
||||||
|
)
|
||||||
|
|
||||||
|
properties_schema: list[ProviderConfig] | None = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="The configuration schema stored in the subscription entity",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_default_parameters(self) -> Mapping[str, Any]:
|
||||||
|
"""Get the default parameters from the parameters schema"""
|
||||||
|
if not self.parameters_schema:
|
||||||
|
return {}
|
||||||
|
return {param.name: param.default for param in self.parameters_schema if param.default}
|
||||||
|
|
||||||
|
def get_default_properties(self) -> Mapping[str, Any]:
|
||||||
|
"""Get the default properties from the properties schema"""
|
||||||
|
if not self.properties_schema:
|
||||||
|
return {}
|
||||||
|
return {prop.name: prop.default for prop in self.properties_schema if prop.default}
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerProviderEntity(BaseModel):
|
||||||
|
"""
|
||||||
|
The configuration of a trigger provider
|
||||||
|
"""
|
||||||
|
|
||||||
|
identity: TriggerProviderIdentity = Field(..., description="The identity of the trigger provider")
|
||||||
|
credentials_schema: list[ProviderConfig] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="The credentials schema of the trigger provider",
|
||||||
|
)
|
||||||
|
oauth_schema: Optional[OAuthSchema] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The OAuth schema of the trigger provider if OAuth is supported",
|
||||||
|
)
|
||||||
|
subscription_schema: SubscriptionSchema = Field(
|
||||||
|
description="The subscription schema for trigger(webhook, polling, etc.) subscription parameters",
|
||||||
|
)
|
||||||
|
triggers: list[TriggerEntity] = Field(default=[], description="The triggers of the trigger provider")
|
||||||
|
|
||||||
|
|
||||||
|
class Subscription(BaseModel):
|
||||||
|
"""
|
||||||
|
Result of a successful trigger subscription operation.
|
||||||
|
|
||||||
|
Contains all information needed to manage the subscription lifecycle.
|
||||||
|
"""
|
||||||
|
|
||||||
|
expires_at: int = Field(
|
||||||
|
..., description="The timestamp when the subscription will expire, this for refresh the subscription"
|
||||||
|
)
|
||||||
|
|
||||||
|
endpoint: str = Field(..., description="The webhook endpoint URL allocated by Dify for receiving events")
|
||||||
|
properties: Mapping[str, Any] = Field(
|
||||||
|
..., description="Subscription data containing all properties and provider-specific information"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Unsubscription(BaseModel):
|
||||||
|
"""
|
||||||
|
Result of a trigger unsubscription operation.
|
||||||
|
|
||||||
|
Provides detailed information about the unsubscription attempt,
|
||||||
|
including success status and error details if failed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
success: bool = Field(..., description="Whether the unsubscription was successful")
|
||||||
|
|
||||||
|
message: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="Human-readable message about the operation result. "
|
||||||
|
"Success message for successful operations, "
|
||||||
|
"detailed error information for failures.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RequestLog(BaseModel):
|
||||||
|
id: str = Field(..., description="The id of the request log")
|
||||||
|
endpoint: str = Field(..., description="The endpoint of the request log")
|
||||||
|
request: dict = Field(..., description="The request of the request log")
|
||||||
|
response: dict = Field(..., description="The response of the request log")
|
||||||
|
created_at: datetime = Field(..., description="The created at of the request log")
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionBuilder(BaseModel):
|
||||||
|
id: str = Field(..., description="The id of the subscription builder")
|
||||||
|
name: str | None = Field(default=None, description="The name of the subscription builder")
|
||||||
|
tenant_id: str = Field(..., description="The tenant id of the subscription builder")
|
||||||
|
user_id: str = Field(..., description="The user id of the subscription builder")
|
||||||
|
provider_id: str = Field(..., description="The provider id of the subscription builder")
|
||||||
|
endpoint_id: str = Field(..., description="The endpoint id of the subscription builder")
|
||||||
|
parameters: Mapping[str, Any] = Field(..., description="The parameters of the subscription builder")
|
||||||
|
properties: Mapping[str, Any] = Field(..., description="The properties of the subscription builder")
|
||||||
|
credentials: Mapping[str, str] = Field(..., description="The credentials of the subscription builder")
|
||||||
|
credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")
|
||||||
|
credential_expires_at: int | None = Field(
|
||||||
|
default=None, description="The credential expires at of the subscription builder"
|
||||||
|
)
|
||||||
|
expires_at: int = Field(..., description="The expires at of the subscription builder")
|
||||||
|
|
||||||
|
def to_subscription(self) -> Subscription:
|
||||||
|
return Subscription(
|
||||||
|
expires_at=self.expires_at,
|
||||||
|
endpoint=self.endpoint_id,
|
||||||
|
properties=self.properties,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionBuilderUpdater(BaseModel):
|
||||||
|
name: str | None = Field(default=None, description="The name of the subscription builder")
|
||||||
|
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters of the subscription builder")
|
||||||
|
properties: Mapping[str, Any] | None = Field(default=None, description="The properties of the subscription builder")
|
||||||
|
credentials: Mapping[str, str] | None = Field(
|
||||||
|
default=None, description="The credentials of the subscription builder"
|
||||||
|
)
|
||||||
|
credential_type: str | None = Field(default=None, description="The credential type of the subscription builder")
|
||||||
|
credential_expires_at: int | None = Field(
|
||||||
|
default=None, description="The credential expires at of the subscription builder"
|
||||||
|
)
|
||||||
|
expires_at: int | None = Field(default=None, description="The expires at of the subscription builder")
|
||||||
|
|
||||||
|
def update(self, subscription_builder: SubscriptionBuilder) -> None:
|
||||||
|
if self.name:
|
||||||
|
subscription_builder.name = self.name
|
||||||
|
if self.parameters:
|
||||||
|
subscription_builder.parameters = self.parameters
|
||||||
|
if self.properties:
|
||||||
|
subscription_builder.properties = self.properties
|
||||||
|
if self.credentials:
|
||||||
|
subscription_builder.credentials = self.credentials
|
||||||
|
if self.credential_type:
|
||||||
|
subscription_builder.credential_type = self.credential_type
|
||||||
|
if self.credential_expires_at:
|
||||||
|
subscription_builder.credential_expires_at = self.credential_expires_at
|
||||||
|
if self.expires_at:
|
||||||
|
subscription_builder.expires_at = self.expires_at
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerEventData(BaseModel):
|
||||||
|
"""Event data dispatched to trigger sessions."""
|
||||||
|
|
||||||
|
subscription_id: str
|
||||||
|
triggers: list[str]
|
||||||
|
request_id: str
|
||||||
|
timestamp: float
|
||||||
|
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerInputs(BaseModel):
|
||||||
|
"""Standard inputs for trigger nodes."""
|
||||||
|
|
||||||
|
request_id: str
|
||||||
|
trigger_name: str
|
||||||
|
subscription_id: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_trigger_entity(cls, request_id: str, subscription_id: str, trigger: TriggerEntity) -> "TriggerInputs":
|
||||||
|
"""Create from trigger entity (for production)."""
|
||||||
|
return cls(request_id=request_id, trigger_name=trigger.identity.name, subscription_id=subscription_id)
|
||||||
|
|
||||||
|
def to_workflow_args(self) -> dict[str, Any]:
|
||||||
|
"""Convert to workflow arguments format."""
|
||||||
|
return {"inputs": self.model_dump(), "files": []}
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dict (alias for model_dump)."""
|
||||||
|
return self.model_dump()
|
||||||
|
|
||||||
|
class TriggerCreationMethod(StrEnum):
|
||||||
|
OAUTH = "OAUTH"
|
||||||
|
APIKEY = "APIKEY"
|
||||||
|
MANUAL = "MANUAL"
|
||||||
|
|
||||||
|
# Export all entities
|
||||||
|
__all__ = [
|
||||||
|
"OAuthSchema",
|
||||||
|
"RequestLog",
|
||||||
|
"Subscription",
|
||||||
|
"SubscriptionBuilder",
|
||||||
|
"TriggerCreationMethod",
|
||||||
|
"TriggerDescription",
|
||||||
|
"TriggerEntity",
|
||||||
|
"TriggerEventData",
|
||||||
|
"TriggerIdentity",
|
||||||
|
"TriggerInputs",
|
||||||
|
"TriggerParameter",
|
||||||
|
"TriggerParameterType",
|
||||||
|
"TriggerProviderEntity",
|
||||||
|
"TriggerProviderIdentity",
|
||||||
|
"Unsubscription",
|
||||||
|
]
|
||||||
2
api/core/trigger/errors.py
Normal file
2
api/core/trigger/errors.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
class TriggerProviderCredentialValidationError(ValueError):
|
||||||
|
pass
|
||||||
358
api/core/trigger/provider.py
Normal file
358
api/core/trigger/provider.py
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
"""
|
||||||
|
Trigger Provider Controller for managing trigger providers
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from flask import Request
|
||||||
|
|
||||||
|
from core.entities.provider_entities import BasicProviderConfig
|
||||||
|
from core.plugin.entities.plugin import TriggerProviderID
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
|
from core.plugin.entities.request import (
|
||||||
|
TriggerDispatchResponse,
|
||||||
|
TriggerInvokeResponse,
|
||||||
|
)
|
||||||
|
from core.plugin.impl.trigger import PluginTriggerManager
|
||||||
|
from core.trigger.entities.api_entities import TriggerApiEntity, TriggerProviderApiEntity
|
||||||
|
from core.trigger.entities.entities import (
|
||||||
|
ProviderConfig,
|
||||||
|
Subscription,
|
||||||
|
SubscriptionSchema,
|
||||||
|
TriggerCreationMethod,
|
||||||
|
TriggerEntity,
|
||||||
|
TriggerProviderEntity,
|
||||||
|
TriggerProviderIdentity,
|
||||||
|
Unsubscription,
|
||||||
|
)
|
||||||
|
from core.trigger.errors import TriggerProviderCredentialValidationError
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginTriggerProviderController:
|
||||||
|
"""
|
||||||
|
Controller for plugin trigger providers
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
entity: TriggerProviderEntity,
|
||||||
|
plugin_id: str,
|
||||||
|
plugin_unique_identifier: str,
|
||||||
|
provider_id: TriggerProviderID,
|
||||||
|
tenant_id: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize plugin trigger provider controller
|
||||||
|
|
||||||
|
:param entity: Trigger provider entity
|
||||||
|
:param plugin_id: Plugin ID
|
||||||
|
:param plugin_unique_identifier: Plugin unique identifier
|
||||||
|
:param provider_id: Provider ID
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
"""
|
||||||
|
self.entity = entity
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.plugin_id = plugin_id
|
||||||
|
self.provider_id = provider_id
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
def get_provider_id(self) -> TriggerProviderID:
|
||||||
|
"""
|
||||||
|
Get provider ID
|
||||||
|
"""
|
||||||
|
return self.provider_id
|
||||||
|
|
||||||
|
def to_api_entity(self) -> TriggerProviderApiEntity:
|
||||||
|
"""
|
||||||
|
Convert to API entity
|
||||||
|
"""
|
||||||
|
icon = (
|
||||||
|
PluginService.get_plugin_icon_url(self.tenant_id, self.entity.identity.icon)
|
||||||
|
if self.entity.identity.icon
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
icon_dark = (
|
||||||
|
PluginService.get_plugin_icon_url(self.tenant_id, self.entity.identity.icon_dark)
|
||||||
|
if self.entity.identity.icon_dark
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
supported_creation_methods = []
|
||||||
|
if self.entity.oauth_schema:
|
||||||
|
supported_creation_methods.append(TriggerCreationMethod.OAUTH)
|
||||||
|
if self.entity.credentials_schema:
|
||||||
|
supported_creation_methods.append(TriggerCreationMethod.APIKEY)
|
||||||
|
if self.entity.subscription_schema:
|
||||||
|
supported_creation_methods.append(TriggerCreationMethod.MANUAL)
|
||||||
|
return TriggerProviderApiEntity(
|
||||||
|
author=self.entity.identity.author,
|
||||||
|
name=self.entity.identity.name,
|
||||||
|
label=self.entity.identity.label,
|
||||||
|
description=self.entity.identity.description,
|
||||||
|
icon=icon,
|
||||||
|
icon_dark=icon_dark,
|
||||||
|
tags=self.entity.identity.tags,
|
||||||
|
plugin_id=self.plugin_id,
|
||||||
|
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||||
|
credentials_schema=self.entity.credentials_schema,
|
||||||
|
oauth_client_schema=self.entity.oauth_schema.client_schema if self.entity.oauth_schema else [],
|
||||||
|
subscription_schema=self.entity.subscription_schema,
|
||||||
|
supported_creation_methods=supported_creation_methods,
|
||||||
|
triggers=[
|
||||||
|
TriggerApiEntity(
|
||||||
|
name=trigger.identity.name,
|
||||||
|
identity=trigger.identity,
|
||||||
|
description=trigger.description,
|
||||||
|
parameters=trigger.parameters,
|
||||||
|
output_schema=trigger.output_schema,
|
||||||
|
)
|
||||||
|
for trigger in self.entity.triggers
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def identity(self) -> TriggerProviderIdentity:
|
||||||
|
"""Get provider identity"""
|
||||||
|
return self.entity.identity
|
||||||
|
|
||||||
|
def get_triggers(self) -> list[TriggerEntity]:
|
||||||
|
"""
|
||||||
|
Get all triggers for this provider
|
||||||
|
|
||||||
|
:return: List of trigger entities
|
||||||
|
"""
|
||||||
|
return self.entity.triggers
|
||||||
|
|
||||||
|
def get_trigger(self, trigger_name: str) -> Optional[TriggerEntity]:
|
||||||
|
"""
|
||||||
|
Get a specific trigger by name
|
||||||
|
|
||||||
|
:param trigger_name: Trigger name
|
||||||
|
:return: Trigger entity or None
|
||||||
|
"""
|
||||||
|
for trigger in self.entity.triggers:
|
||||||
|
if trigger.identity.name == trigger_name:
|
||||||
|
return trigger
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_subscription_schema(self) -> SubscriptionSchema:
|
||||||
|
"""
|
||||||
|
Get subscription schema for this provider
|
||||||
|
|
||||||
|
:return: List of subscription config schemas
|
||||||
|
"""
|
||||||
|
return self.entity.subscription_schema
|
||||||
|
|
||||||
|
def validate_credentials(self, user_id: str, credentials: Mapping[str, str]) -> None:
|
||||||
|
"""
|
||||||
|
Validate credentials against schema
|
||||||
|
|
||||||
|
:param credentials: Credentials to validate
|
||||||
|
:return: Validation response
|
||||||
|
"""
|
||||||
|
# First validate against schema
|
||||||
|
for config in self.entity.credentials_schema:
|
||||||
|
if config.required and config.name not in credentials:
|
||||||
|
raise TriggerProviderCredentialValidationError(f"Missing required credential field: {config.name}")
|
||||||
|
|
||||||
|
# Then validate with the plugin daemon
|
||||||
|
manager = PluginTriggerManager()
|
||||||
|
provider_id = self.get_provider_id()
|
||||||
|
response = manager.validate_provider_credentials(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider=str(provider_id),
|
||||||
|
credentials=credentials,
|
||||||
|
)
|
||||||
|
if not response:
|
||||||
|
raise TriggerProviderCredentialValidationError(
|
||||||
|
"Invalid credentials",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_supported_credential_types(self) -> list[CredentialType]:
|
||||||
|
"""
|
||||||
|
Get supported credential types for this provider.
|
||||||
|
|
||||||
|
:return: List of supported credential types
|
||||||
|
"""
|
||||||
|
types = []
|
||||||
|
if self.entity.oauth_schema:
|
||||||
|
types.append(CredentialType.OAUTH2)
|
||||||
|
if self.entity.credentials_schema:
|
||||||
|
types.append(CredentialType.API_KEY)
|
||||||
|
return types
|
||||||
|
|
||||||
|
def get_credentials_schema(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
|
||||||
|
"""
|
||||||
|
Get credentials schema by credential type
|
||||||
|
|
||||||
|
:param credential_type: The type of credential (oauth or api_key)
|
||||||
|
:return: List of provider config schemas
|
||||||
|
"""
|
||||||
|
credential_type = CredentialType.of(credential_type) if isinstance(credential_type, str) else credential_type
|
||||||
|
if credential_type == CredentialType.OAUTH2:
|
||||||
|
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||||
|
if credential_type == CredentialType.API_KEY:
|
||||||
|
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||||
|
if credential_type == CredentialType.UNAUTHORIZED:
|
||||||
|
return []
|
||||||
|
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||||
|
|
||||||
|
def get_credential_schema_config(self, credential_type: CredentialType | str) -> list[BasicProviderConfig]:
|
||||||
|
"""
|
||||||
|
Get credential schema config by credential type
|
||||||
|
"""
|
||||||
|
return [x.to_basic_provider_config() for x in self.get_credentials_schema(credential_type)]
|
||||||
|
|
||||||
|
def get_oauth_client_schema(self) -> list[ProviderConfig]:
|
||||||
|
"""
|
||||||
|
Get OAuth client schema for this provider
|
||||||
|
|
||||||
|
:return: List of OAuth client config schemas
|
||||||
|
"""
|
||||||
|
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
|
||||||
|
|
||||||
|
def get_properties_schema(self) -> list[BasicProviderConfig]:
|
||||||
|
"""
|
||||||
|
Get properties schema for this provider
|
||||||
|
|
||||||
|
:return: List of properties config schemas
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
[x.to_basic_provider_config() for x in self.entity.subscription_schema.properties_schema.copy()]
|
||||||
|
if self.entity.subscription_schema.properties_schema
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
def dispatch(self, user_id: str, request: Request, subscription: Subscription) -> TriggerDispatchResponse:
|
||||||
|
"""
|
||||||
|
Dispatch a trigger through plugin runtime
|
||||||
|
|
||||||
|
:param user_id: User ID
|
||||||
|
:param request: Flask request object
|
||||||
|
:param subscription: Subscription
|
||||||
|
:return: Dispatch response with triggers and raw HTTP response
|
||||||
|
"""
|
||||||
|
manager = PluginTriggerManager()
|
||||||
|
provider_id = self.get_provider_id()
|
||||||
|
|
||||||
|
response = manager.dispatch_event(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider=str(provider_id),
|
||||||
|
subscription=subscription.model_dump(),
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def invoke_trigger(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
trigger_name: str,
|
||||||
|
parameters: Mapping[str, Any],
|
||||||
|
credentials: Mapping[str, str],
|
||||||
|
credential_type: CredentialType,
|
||||||
|
request: Request,
|
||||||
|
) -> TriggerInvokeResponse:
|
||||||
|
"""
|
||||||
|
Execute a trigger through plugin runtime
|
||||||
|
|
||||||
|
:param user_id: User ID
|
||||||
|
:param trigger_name: Trigger name
|
||||||
|
:param parameters: Trigger parameters
|
||||||
|
:param credentials: Provider credentials
|
||||||
|
:param credential_type: Credential type
|
||||||
|
:param request: Request
|
||||||
|
:return: Trigger execution result
|
||||||
|
"""
|
||||||
|
manager = PluginTriggerManager()
|
||||||
|
provider_id = self.get_provider_id()
|
||||||
|
|
||||||
|
return manager.invoke_trigger(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider=str(provider_id),
|
||||||
|
trigger=trigger_name,
|
||||||
|
credentials=credentials,
|
||||||
|
credential_type=credential_type,
|
||||||
|
request=request,
|
||||||
|
parameters=parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
def subscribe_trigger(
|
||||||
|
self, user_id: str, endpoint: str, parameters: Mapping[str, Any], credentials: Mapping[str, str]
|
||||||
|
) -> Subscription:
|
||||||
|
"""
|
||||||
|
Subscribe to a trigger through plugin runtime
|
||||||
|
|
||||||
|
:param user_id: User ID
|
||||||
|
:param endpoint: Subscription endpoint
|
||||||
|
:param subscription_params: Subscription parameters
|
||||||
|
:param credentials: Provider credentials
|
||||||
|
:return: Subscription result
|
||||||
|
"""
|
||||||
|
manager = PluginTriggerManager()
|
||||||
|
provider_id = self.get_provider_id()
|
||||||
|
|
||||||
|
response = manager.subscribe(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider=str(provider_id),
|
||||||
|
credentials=credentials,
|
||||||
|
endpoint=endpoint,
|
||||||
|
parameters=parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Subscription.model_validate(response.subscription)
|
||||||
|
|
||||||
|
def unsubscribe_trigger(
|
||||||
|
self, user_id: str, subscription: Subscription, credentials: Mapping[str, str]
|
||||||
|
) -> Unsubscription:
|
||||||
|
"""
|
||||||
|
Unsubscribe from a trigger through plugin runtime
|
||||||
|
|
||||||
|
:param user_id: User ID
|
||||||
|
:param subscription: Subscription metadata
|
||||||
|
:param credentials: Provider credentials
|
||||||
|
:return: Unsubscription result
|
||||||
|
"""
|
||||||
|
manager = PluginTriggerManager()
|
||||||
|
provider_id = self.get_provider_id()
|
||||||
|
|
||||||
|
response = manager.unsubscribe(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider=str(provider_id),
|
||||||
|
subscription=subscription,
|
||||||
|
credentials=credentials,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Unsubscription.model_validate(response.subscription)
|
||||||
|
|
||||||
|
def refresh_trigger(self, subscription: Subscription, credentials: Mapping[str, str]) -> Subscription:
|
||||||
|
"""
|
||||||
|
Refresh a trigger subscription through plugin runtime
|
||||||
|
|
||||||
|
:param subscription: Subscription metadata
|
||||||
|
:param credentials: Provider credentials
|
||||||
|
:return: Refreshed subscription result
|
||||||
|
"""
|
||||||
|
manager = PluginTriggerManager()
|
||||||
|
provider_id = self.get_provider_id()
|
||||||
|
|
||||||
|
response = manager.refresh(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id="system", # System refresh
|
||||||
|
provider=str(provider_id),
|
||||||
|
subscription=subscription,
|
||||||
|
credentials=credentials,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Subscription.model_validate(response.subscription)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["PluginTriggerProviderController"]
|
||||||
254
api/core/trigger/trigger_manager.py
Normal file
254
api/core/trigger/trigger_manager.py
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
"""
|
||||||
|
Trigger Manager for loading and managing trigger providers and triggers
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from flask import Request
|
||||||
|
|
||||||
|
import contexts
|
||||||
|
from core.plugin.entities.plugin import TriggerProviderID
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
|
from core.plugin.entities.request import TriggerInvokeResponse
|
||||||
|
from core.plugin.impl.trigger import PluginTriggerManager
|
||||||
|
from core.trigger.entities.entities import (
|
||||||
|
Subscription,
|
||||||
|
SubscriptionSchema,
|
||||||
|
TriggerEntity,
|
||||||
|
Unsubscription,
|
||||||
|
)
|
||||||
|
from core.trigger.provider import PluginTriggerProviderController
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerManager:
|
||||||
|
"""
|
||||||
|
Manager for trigger providers and triggers
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_plugin_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
|
||||||
|
"""
|
||||||
|
List all plugin trigger providers for a tenant
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:return: List of trigger provider controllers
|
||||||
|
"""
|
||||||
|
manager = PluginTriggerManager()
|
||||||
|
provider_entities = manager.fetch_trigger_providers(tenant_id)
|
||||||
|
|
||||||
|
controllers = []
|
||||||
|
for provider in provider_entities:
|
||||||
|
try:
|
||||||
|
controller = PluginTriggerProviderController(
|
||||||
|
entity=provider.declaration,
|
||||||
|
plugin_id=provider.plugin_id,
|
||||||
|
plugin_unique_identifier=provider.plugin_unique_identifier,
|
||||||
|
provider_id=TriggerProviderID(provider.provider),
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
controllers.append(controller)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to load trigger provider %s", provider.plugin_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return controllers
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_trigger_provider(cls, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderController:
|
||||||
|
"""
|
||||||
|
Get a specific plugin trigger provider
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param provider_id: Provider ID
|
||||||
|
:return: Trigger provider controller or None
|
||||||
|
"""
|
||||||
|
# check if context is set
|
||||||
|
try:
|
||||||
|
contexts.plugin_trigger_providers.get()
|
||||||
|
except LookupError:
|
||||||
|
contexts.plugin_trigger_providers.set({})
|
||||||
|
contexts.plugin_trigger_providers_lock.set(Lock())
|
||||||
|
|
||||||
|
plugin_trigger_providers = contexts.plugin_trigger_providers.get()
|
||||||
|
provider_id_str = str(provider_id)
|
||||||
|
if provider_id_str in plugin_trigger_providers:
|
||||||
|
return plugin_trigger_providers[provider_id_str]
|
||||||
|
|
||||||
|
with contexts.plugin_trigger_providers_lock.get():
|
||||||
|
# double check
|
||||||
|
plugin_trigger_providers = contexts.plugin_trigger_providers.get()
|
||||||
|
if provider_id_str in plugin_trigger_providers:
|
||||||
|
return plugin_trigger_providers[provider_id_str]
|
||||||
|
|
||||||
|
manager = PluginTriggerManager()
|
||||||
|
provider = manager.fetch_trigger_provider(tenant_id, provider_id)
|
||||||
|
|
||||||
|
if not provider:
|
||||||
|
raise ValueError(f"Trigger provider {provider_id} not found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
controller = PluginTriggerProviderController(
|
||||||
|
entity=provider.declaration,
|
||||||
|
plugin_id=provider.plugin_id,
|
||||||
|
plugin_unique_identifier=provider.plugin_unique_identifier,
|
||||||
|
provider_id=provider_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
plugin_trigger_providers[provider_id_str] = controller
|
||||||
|
return controller
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to load trigger provider")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_all_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
|
||||||
|
"""
|
||||||
|
List all trigger providers (plugin)
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:return: List of all trigger provider controllers
|
||||||
|
"""
|
||||||
|
return cls.list_plugin_trigger_providers(tenant_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_triggers_by_provider(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[TriggerEntity]:
|
||||||
|
"""
|
||||||
|
List all triggers for a specific provider
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param provider_id: Provider ID
|
||||||
|
:return: List of trigger entities
|
||||||
|
"""
|
||||||
|
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
return provider.get_triggers()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_trigger(cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str) -> Optional[TriggerEntity]:
|
||||||
|
"""
|
||||||
|
Get a specific trigger
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param provider_id: Provider ID
|
||||||
|
:param trigger_name: Trigger name
|
||||||
|
:return: Trigger entity or None
|
||||||
|
"""
|
||||||
|
return cls.get_trigger_provider(tenant_id, provider_id).get_trigger(trigger_name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def invoke_trigger(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
provider_id: TriggerProviderID,
|
||||||
|
trigger_name: str,
|
||||||
|
parameters: Mapping[str, Any],
|
||||||
|
credentials: Mapping[str, str],
|
||||||
|
credential_type: CredentialType,
|
||||||
|
request: Request,
|
||||||
|
) -> TriggerInvokeResponse:
|
||||||
|
"""
|
||||||
|
Execute a trigger
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param user_id: User ID
|
||||||
|
:param provider_id: Provider ID
|
||||||
|
:param trigger_name: Trigger name
|
||||||
|
:param parameters: Trigger parameters
|
||||||
|
:param credentials: Provider credentials
|
||||||
|
:param credential_type: Credential type
|
||||||
|
:param request: Request
|
||||||
|
:return: Trigger execution result
|
||||||
|
"""
|
||||||
|
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
trigger = provider.get_trigger(trigger_name)
|
||||||
|
if not trigger:
|
||||||
|
raise ValueError(f"Trigger {trigger_name} not found in provider {provider_id}")
|
||||||
|
return provider.invoke_trigger(user_id, trigger_name, parameters, credentials, credential_type, request)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def subscribe_trigger(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
provider_id: TriggerProviderID,
|
||||||
|
endpoint: str,
|
||||||
|
parameters: Mapping[str, Any],
|
||||||
|
credentials: Mapping[str, str],
|
||||||
|
) -> Subscription:
|
||||||
|
"""
|
||||||
|
Subscribe to a trigger (e.g., register webhook)
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param user_id: User ID
|
||||||
|
:param provider_id: Provider ID
|
||||||
|
:param endpoint: Subscription endpoint
|
||||||
|
:param parameters: Subscription parameters
|
||||||
|
:param credentials: Provider credentials
|
||||||
|
:return: Subscription result
|
||||||
|
"""
|
||||||
|
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
return provider.subscribe_trigger(
|
||||||
|
user_id=user_id, endpoint=endpoint, parameters=parameters, credentials=credentials
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unsubscribe_trigger(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
provider_id: TriggerProviderID,
|
||||||
|
subscription: Subscription,
|
||||||
|
credentials: Mapping[str, str],
|
||||||
|
) -> Unsubscription:
|
||||||
|
"""
|
||||||
|
Unsubscribe from a trigger
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param user_id: User ID
|
||||||
|
:param provider_id: Provider ID
|
||||||
|
:param subscription: Subscription metadata from subscribe operation
|
||||||
|
:param credentials: Provider credentials
|
||||||
|
:return: Unsubscription result
|
||||||
|
"""
|
||||||
|
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
return provider.unsubscribe_trigger(user_id=user_id, subscription=subscription, credentials=credentials)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_provider_subscription_schema(cls, tenant_id: str, provider_id: TriggerProviderID) -> SubscriptionSchema:
|
||||||
|
"""
|
||||||
|
Get provider subscription schema
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param provider_id: Provider ID
|
||||||
|
:return: List of subscription config schemas
|
||||||
|
"""
|
||||||
|
return cls.get_trigger_provider(tenant_id, provider_id).get_subscription_schema()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def refresh_trigger(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
provider_id: TriggerProviderID,
|
||||||
|
subscription: Subscription,
|
||||||
|
credentials: Mapping[str, str],
|
||||||
|
) -> Subscription:
|
||||||
|
"""
|
||||||
|
Refresh a trigger subscription
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param provider_id: Provider ID
|
||||||
|
:param trigger_name: Trigger name
|
||||||
|
:param subscription: Subscription metadata from subscribe operation
|
||||||
|
:param credentials: Provider credentials
|
||||||
|
:return: Refreshed subscription result
|
||||||
|
"""
|
||||||
|
return cls.get_trigger_provider(tenant_id, provider_id).refresh_trigger(subscription, credentials)
|
||||||
|
|
||||||
|
|
||||||
|
# Export
|
||||||
|
__all__ = ["TriggerManager"]
|
||||||
145
api/core/trigger/utils/encryption.py
Normal file
145
api/core/trigger/utils/encryption.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from core.entities.provider_entities import BasicProviderConfig, ProviderConfig
|
||||||
|
from core.helper.provider_cache import ProviderCredentialsCache
|
||||||
|
from core.helper.provider_encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
|
from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
|
||||||
|
from core.trigger.provider import PluginTriggerProviderController
|
||||||
|
from models.trigger import TriggerSubscription
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerProviderCredentialsCache(ProviderCredentialsCache):
|
||||||
|
"""Cache for trigger provider credentials"""
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, provider_id: str, credential_id: str):
|
||||||
|
super().__init__(tenant_id=tenant_id, provider_id=provider_id, credential_id=credential_id)
|
||||||
|
|
||||||
|
def _generate_cache_key(self, **kwargs) -> str:
|
||||||
|
tenant_id = kwargs["tenant_id"]
|
||||||
|
provider_id = kwargs["provider_id"]
|
||||||
|
credential_id = kwargs["credential_id"]
|
||||||
|
return f"trigger_credentials:tenant_id:{tenant_id}:provider_id:{provider_id}:credential_id:{credential_id}"
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerProviderOAuthClientParamsCache(ProviderCredentialsCache):
|
||||||
|
"""Cache for trigger provider OAuth client"""
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, provider_id: str):
|
||||||
|
super().__init__(tenant_id=tenant_id, provider_id=provider_id)
|
||||||
|
|
||||||
|
def _generate_cache_key(self, **kwargs) -> str:
|
||||||
|
tenant_id = kwargs["tenant_id"]
|
||||||
|
provider_id = kwargs["provider_id"]
|
||||||
|
return f"trigger_oauth_client:tenant_id:{tenant_id}:provider_id:{provider_id}"
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerProviderPropertiesCache(ProviderCredentialsCache):
|
||||||
|
"""Cache for trigger provider properties"""
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, provider_id: str, subscription_id: str):
|
||||||
|
super().__init__(tenant_id=tenant_id, provider_id=provider_id, subscription_id=subscription_id)
|
||||||
|
|
||||||
|
def _generate_cache_key(self, **kwargs) -> str:
|
||||||
|
tenant_id = kwargs["tenant_id"]
|
||||||
|
provider_id = kwargs["provider_id"]
|
||||||
|
subscription_id = kwargs["subscription_id"]
|
||||||
|
return f"trigger_properties:tenant_id:{tenant_id}:provider_id:{provider_id}:subscription_id:{subscription_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def create_trigger_provider_encrypter_for_subscription(
|
||||||
|
tenant_id: str,
|
||||||
|
controller: PluginTriggerProviderController,
|
||||||
|
subscription: Union[TriggerSubscription, TriggerProviderSubscriptionApiEntity],
|
||||||
|
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||||
|
cache = TriggerProviderCredentialsCache(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=str(controller.get_provider_id()),
|
||||||
|
credential_id=subscription.id,
|
||||||
|
)
|
||||||
|
encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=controller.get_credential_schema_config(subscription.credential_type),
|
||||||
|
cache=cache,
|
||||||
|
)
|
||||||
|
return encrypter, cache
|
||||||
|
|
||||||
|
|
||||||
|
def delete_cache_for_subscription(tenant_id: str, provider_id: str, subscription_id: str):
|
||||||
|
cache = TriggerProviderCredentialsCache(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
credential_id=subscription_id,
|
||||||
|
)
|
||||||
|
cache.delete()
|
||||||
|
|
||||||
|
|
||||||
|
def create_trigger_provider_encrypter_for_properties(
|
||||||
|
tenant_id: str,
|
||||||
|
controller: PluginTriggerProviderController,
|
||||||
|
subscription: Union[TriggerSubscription, TriggerProviderSubscriptionApiEntity],
|
||||||
|
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||||
|
cache = TriggerProviderPropertiesCache(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=str(controller.get_provider_id()),
|
||||||
|
subscription_id=subscription.id,
|
||||||
|
)
|
||||||
|
encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=controller.get_properties_schema(),
|
||||||
|
cache=cache,
|
||||||
|
)
|
||||||
|
return encrypter, cache
|
||||||
|
|
||||||
|
|
||||||
|
def create_trigger_provider_encrypter(
|
||||||
|
tenant_id: str, controller: PluginTriggerProviderController, credential_id: str, credential_type: CredentialType
|
||||||
|
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||||
|
cache = TriggerProviderCredentialsCache(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=str(controller.get_provider_id()),
|
||||||
|
credential_id=credential_id,
|
||||||
|
)
|
||||||
|
encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=controller.get_credential_schema_config(credential_type),
|
||||||
|
cache=cache,
|
||||||
|
)
|
||||||
|
return encrypter, cache
|
||||||
|
|
||||||
|
|
||||||
|
def create_trigger_provider_oauth_encrypter(
|
||||||
|
tenant_id: str, controller: PluginTriggerProviderController
|
||||||
|
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||||
|
cache = TriggerProviderOAuthClientParamsCache(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=str(controller.get_provider_id()),
|
||||||
|
)
|
||||||
|
encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=[x.to_basic_provider_config() for x in controller.get_oauth_client_schema()],
|
||||||
|
cache=cache,
|
||||||
|
)
|
||||||
|
return encrypter, cache
|
||||||
|
|
||||||
|
|
||||||
|
def masked_credentials(
|
||||||
|
schemas: list[ProviderConfig],
|
||||||
|
credentials: Mapping[str, str],
|
||||||
|
) -> Mapping[str, str]:
|
||||||
|
masked_credentials = {}
|
||||||
|
configs = {x.name: x.to_basic_provider_config() for x in schemas}
|
||||||
|
for key, value in credentials.items():
|
||||||
|
config = configs.get(key)
|
||||||
|
if not config:
|
||||||
|
masked_credentials[key] = value
|
||||||
|
continue
|
||||||
|
if config.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||||
|
if len(value) <= 4:
|
||||||
|
masked_credentials[key] = "*" * len(value)
|
||||||
|
else:
|
||||||
|
masked_credentials[key] = value[:2] + "*" * (len(value) - 4) + value[-2:]
|
||||||
|
else:
|
||||||
|
masked_credentials[key] = value
|
||||||
|
return masked_credentials
|
||||||
5
api/core/trigger/utils/endpoint.py
Normal file
5
api/core/trigger/utils/endpoint.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
|
||||||
|
def parse_endpoint_id(endpoint_id: str) -> str:
|
||||||
|
return f"{dify_config.CONSOLE_API_URL}/triggers/plugin/{endpoint_id}"
|
||||||
@@ -25,6 +25,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
|||||||
TOTAL_PRICE = "total_price"
|
TOTAL_PRICE = "total_price"
|
||||||
CURRENCY = "currency"
|
CURRENCY = "currency"
|
||||||
TOOL_INFO = "tool_info"
|
TOOL_INFO = "tool_info"
|
||||||
|
TRIGGER_INFO = "trigger_info"
|
||||||
AGENT_LOG = "agent_log"
|
AGENT_LOG = "agent_log"
|
||||||
ITERATION_ID = "iteration_id"
|
ITERATION_ID = "iteration_id"
|
||||||
ITERATION_INDEX = "iteration_index"
|
ITERATION_INDEX = "iteration_index"
|
||||||
|
|||||||
@@ -25,6 +25,18 @@ class NodeType(StrEnum):
|
|||||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||||
LIST_OPERATOR = "list-operator"
|
LIST_OPERATOR = "list-operator"
|
||||||
AGENT = "agent"
|
AGENT = "agent"
|
||||||
|
TRIGGER_WEBHOOK = "trigger-webhook"
|
||||||
|
TRIGGER_SCHEDULE = "trigger-schedule"
|
||||||
|
TRIGGER_PLUGIN = "trigger-plugin"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_start_node(self) -> bool:
|
||||||
|
return self in [
|
||||||
|
NodeType.START,
|
||||||
|
NodeType.TRIGGER_WEBHOOK,
|
||||||
|
NodeType.TRIGGER_SCHEDULE,
|
||||||
|
NodeType.TRIGGER_PLUGIN,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ErrorStrategy(StrEnum):
|
class ErrorStrategy(StrEnum):
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ from core.workflow.nodes.question_classifier import QuestionClassifierNode
|
|||||||
from core.workflow.nodes.start import StartNode
|
from core.workflow.nodes.start import StartNode
|
||||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||||
from core.workflow.nodes.tool import ToolNode
|
from core.workflow.nodes.tool import ToolNode
|
||||||
|
from core.workflow.nodes.trigger_plugin import TriggerPluginNode
|
||||||
|
from core.workflow.nodes.trigger_schedule import TriggerScheduleNode
|
||||||
|
from core.workflow.nodes.trigger_webhook import TriggerWebhookNode
|
||||||
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
|
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
|
||||||
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
|
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
|
||||||
@@ -132,4 +135,16 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
|||||||
"2": AgentNode,
|
"2": AgentNode,
|
||||||
"1": AgentNode,
|
"1": AgentNode,
|
||||||
},
|
},
|
||||||
|
NodeType.TRIGGER_WEBHOOK: {
|
||||||
|
LATEST_VERSION: TriggerWebhookNode,
|
||||||
|
"1": TriggerWebhookNode,
|
||||||
|
},
|
||||||
|
NodeType.TRIGGER_PLUGIN: {
|
||||||
|
LATEST_VERSION: TriggerPluginNode,
|
||||||
|
"1": TriggerPluginNode,
|
||||||
|
},
|
||||||
|
NodeType.TRIGGER_SCHEDULE: {
|
||||||
|
LATEST_VERSION: TriggerScheduleNode,
|
||||||
|
"1": TriggerScheduleNode,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
3
api/core/workflow/nodes/trigger_plugin/__init__.py
Normal file
3
api/core/workflow/nodes/trigger_plugin/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .trigger_plugin_node import TriggerPluginNode
|
||||||
|
|
||||||
|
__all__ = ["TriggerPluginNode"]
|
||||||
28
api/core/workflow/nodes/trigger_plugin/entities.py
Normal file
28
api/core/workflow/nodes/trigger_plugin/entities.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||||
|
from core.workflow.nodes.enums import ErrorStrategy
|
||||||
|
|
||||||
|
|
||||||
|
class PluginTriggerData(BaseNodeData):
|
||||||
|
"""Plugin trigger node data"""
|
||||||
|
|
||||||
|
title: str
|
||||||
|
desc: Optional[str] = None
|
||||||
|
plugin_id: str = Field(..., description="Plugin ID")
|
||||||
|
provider_id: str = Field(..., description="Provider ID")
|
||||||
|
trigger_name: str = Field(..., description="Trigger name")
|
||||||
|
subscription_id: str = Field(..., description="Subscription ID")
|
||||||
|
plugin_unique_identifier: str = Field(..., description="Plugin unique identifier")
|
||||||
|
parameters: dict[str, Any] = Field(default_factory=dict, description="Trigger parameters")
|
||||||
|
|
||||||
|
# Error handling
|
||||||
|
error_strategy: Optional[ErrorStrategy] = Field(
|
||||||
|
default=ErrorStrategy.FAIL_BRANCH, description="Error handling strategy"
|
||||||
|
)
|
||||||
|
retry_config: RetryConfig = Field(default_factory=lambda: RetryConfig(), description="Retry configuration")
|
||||||
|
default_value_dict: dict[str, Any] = Field(
|
||||||
|
default_factory=dict, description="Default values for outputs when error occurs"
|
||||||
|
)
|
||||||
151
api/core/workflow/nodes/trigger_plugin/trigger_plugin_node.py
Normal file
151
api/core/workflow/nodes/trigger_plugin/trigger_plugin_node.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from core.plugin.entities.plugin import TriggerProviderID
|
||||||
|
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
|
||||||
|
from core.plugin.utils.http_parser import deserialize_request
|
||||||
|
from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
|
||||||
|
from core.trigger.trigger_manager import TriggerManager
|
||||||
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||||
|
from core.workflow.nodes.base import BaseNode
|
||||||
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||||
|
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||||
|
from extensions.ext_storage import storage
|
||||||
|
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||||
|
|
||||||
|
from .entities import PluginTriggerData
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerPluginNode(BaseNode):
|
||||||
|
_node_type = NodeType.TRIGGER_PLUGIN
|
||||||
|
|
||||||
|
_node_data: PluginTriggerData
|
||||||
|
|
||||||
|
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||||
|
self._node_data = PluginTriggerData.model_validate(data)
|
||||||
|
|
||||||
|
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||||
|
return self._node_data.error_strategy
|
||||||
|
|
||||||
|
def _get_retry_config(self) -> RetryConfig:
|
||||||
|
return self._node_data.retry_config
|
||||||
|
|
||||||
|
def _get_title(self) -> str:
|
||||||
|
return self._node_data.title
|
||||||
|
|
||||||
|
def _get_description(self) -> Optional[str]:
|
||||||
|
return self._node_data.desc
|
||||||
|
|
||||||
|
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||||
|
return self._node_data.default_value_dict
|
||||||
|
|
||||||
|
def get_base_node_data(self) -> BaseNodeData:
|
||||||
|
return self._node_data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
|
||||||
|
return {
|
||||||
|
"type": "plugin",
|
||||||
|
"config": {
|
||||||
|
"plugin_id": "",
|
||||||
|
"provider_id": "",
|
||||||
|
"trigger_name": "",
|
||||||
|
"subscription_id": "",
|
||||||
|
"parameters": {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
|
def _run(self) -> NodeRunResult:
|
||||||
|
"""
|
||||||
|
Run the plugin trigger node.
|
||||||
|
|
||||||
|
This node invokes the trigger to convert request data into events
|
||||||
|
and makes them available to downstream nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Get trigger data passed when workflow was triggered
|
||||||
|
trigger_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||||
|
metadata = {
|
||||||
|
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
|
||||||
|
**trigger_inputs,
|
||||||
|
"provider_id": self._node_data.provider_id,
|
||||||
|
"trigger_name": self._node_data.trigger_name,
|
||||||
|
"plugin_unique_identifier": self._node_data.plugin_unique_identifier,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
request_id = trigger_inputs.get("request_id")
|
||||||
|
trigger_name = trigger_inputs.get("trigger_name", "")
|
||||||
|
subscription_id = trigger_inputs.get("subscription_id", "")
|
||||||
|
|
||||||
|
if not request_id or not subscription_id:
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs=trigger_inputs,
|
||||||
|
outputs={"error": "No request ID or subscription ID available"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
subscription: TriggerProviderSubscriptionApiEntity | None = TriggerProviderService.get_subscription_by_id(
|
||||||
|
tenant_id=self.tenant_id, subscription_id=subscription_id
|
||||||
|
)
|
||||||
|
if not subscription:
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs=trigger_inputs,
|
||||||
|
outputs={"error": f"Invalid subscription {subscription_id} not found"},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs=trigger_inputs,
|
||||||
|
outputs={"error": f"Failed to get subscription: {str(e)}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
request = deserialize_request(storage.load_once(f"triggers/{request_id}"))
|
||||||
|
parameters = self._node_data.parameters if hasattr(self, "_node_data") and self._node_data else {}
|
||||||
|
invoke_response = TriggerManager.invoke_trigger(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=self.user_id,
|
||||||
|
provider_id=TriggerProviderID(subscription.provider),
|
||||||
|
trigger_name=trigger_name,
|
||||||
|
parameters=parameters,
|
||||||
|
credentials=subscription.credentials,
|
||||||
|
credential_type=subscription.credential_type,
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
outputs = invoke_response.event.variables or {}
|
||||||
|
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=trigger_inputs, outputs=outputs)
|
||||||
|
except PluginInvokeError as e:
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs=trigger_inputs,
|
||||||
|
metadata=metadata,
|
||||||
|
error="An error occurred in the plugin, "
|
||||||
|
f"please contact the author of {subscription.provider} for help, "
|
||||||
|
f"error type: {e.get_error_type()}, "
|
||||||
|
f"error details: {e.get_error_message()}",
|
||||||
|
error_type=type(e).__name__,
|
||||||
|
)
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs=trigger_inputs,
|
||||||
|
metadata=metadata,
|
||||||
|
error=f"Failed to invoke trigger, error: {e.description}",
|
||||||
|
error_type=type(e).__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs=trigger_inputs,
|
||||||
|
metadata=metadata,
|
||||||
|
error=f"Failed to invoke trigger: {str(e)}",
|
||||||
|
error_type=type(e).__name__,
|
||||||
|
)
|
||||||
3
api/core/workflow/nodes/trigger_schedule/__init__.py
Normal file
3
api/core/workflow/nodes/trigger_schedule/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode
|
||||||
|
|
||||||
|
__all__ = ["TriggerScheduleNode"]
|
||||||
51
api/core/workflow/nodes/trigger_schedule/entities.py
Normal file
51
api/core/workflow/nodes/trigger_schedule/entities.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.workflow.nodes.base import BaseNodeData
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerScheduleNodeData(BaseNodeData):
|
||||||
|
"""
|
||||||
|
Trigger Schedule Node Data
|
||||||
|
"""
|
||||||
|
|
||||||
|
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
|
||||||
|
frequency: Optional[str] = Field(
|
||||||
|
default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly"
|
||||||
|
)
|
||||||
|
cron_expression: Optional[str] = Field(default=None, description="Cron expression for cron mode")
|
||||||
|
visual_config: Optional[dict] = Field(default=None, description="Visual configuration details")
|
||||||
|
timezone: str = Field(default="UTC", description="Timezone for schedule execution")
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleConfig(BaseModel):
|
||||||
|
node_id: str
|
||||||
|
cron_expression: str
|
||||||
|
timezone: str = "UTC"
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulePlanUpdate(BaseModel):
|
||||||
|
node_id: Optional[str] = None
|
||||||
|
cron_expression: Optional[str] = None
|
||||||
|
timezone: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class VisualConfig(BaseModel):
|
||||||
|
"""Visual configuration for schedule trigger"""
|
||||||
|
|
||||||
|
# For hourly frequency
|
||||||
|
on_minute: Optional[int] = Field(default=0, ge=0, le=59, description="Minute of the hour (0-59)")
|
||||||
|
|
||||||
|
# For daily, weekly, monthly frequencies
|
||||||
|
time: Optional[str] = Field(default="12:00 PM", description="Time in 12-hour format (e.g., '2:30 PM')")
|
||||||
|
|
||||||
|
# For weekly frequency
|
||||||
|
weekdays: Optional[list[Literal["sun", "mon", "tue", "wed", "thu", "fri", "sat"]]] = Field(
|
||||||
|
default=None, description="List of weekdays to run on"
|
||||||
|
)
|
||||||
|
|
||||||
|
# For monthly frequency
|
||||||
|
monthly_days: Optional[list[Union[int, Literal["last"]]]] = Field(
|
||||||
|
default=None, description="Days of month to run on (1-31 or 'last')"
|
||||||
|
)
|
||||||
31
api/core/workflow/nodes/trigger_schedule/exc.py
Normal file
31
api/core/workflow/nodes/trigger_schedule/exc.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from core.workflow.nodes.base.exc import BaseNodeError
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleNodeError(BaseNodeError):
|
||||||
|
"""Base schedule node error."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleNotFoundError(ScheduleNodeError):
|
||||||
|
"""Schedule not found error."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleConfigError(ScheduleNodeError):
|
||||||
|
"""Schedule configuration error."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleExecutionError(ScheduleNodeError):
|
||||||
|
"""Schedule execution error."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TenantOwnerNotFoundError(ScheduleExecutionError):
|
||||||
|
"""Tenant owner not found error for schedule execution."""
|
||||||
|
|
||||||
|
pass
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
|
from core.workflow.nodes.base import BaseNode
|
||||||
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||||
|
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||||
|
from core.workflow.nodes.trigger_schedule.entities import TriggerScheduleNodeData
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerScheduleNode(BaseNode):
|
||||||
|
_node_type = NodeType.TRIGGER_SCHEDULE
|
||||||
|
|
||||||
|
_node_data: TriggerScheduleNodeData
|
||||||
|
|
||||||
|
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||||
|
self._node_data = TriggerScheduleNodeData(**data)
|
||||||
|
|
||||||
|
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||||
|
return self._node_data.error_strategy
|
||||||
|
|
||||||
|
def _get_retry_config(self) -> RetryConfig:
|
||||||
|
return self._node_data.retry_config
|
||||||
|
|
||||||
|
def _get_title(self) -> str:
|
||||||
|
return self._node_data.title
|
||||||
|
|
||||||
|
def _get_description(self) -> Optional[str]:
|
||||||
|
return self._node_data.desc
|
||||||
|
|
||||||
|
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||||
|
return self._node_data.default_value_dict
|
||||||
|
|
||||||
|
def get_base_node_data(self) -> BaseNodeData:
|
||||||
|
return self._node_data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||||
|
return {
|
||||||
|
"type": "trigger-schedule",
|
||||||
|
"config": {
|
||||||
|
"mode": "visual",
|
||||||
|
"frequency": "weekly",
|
||||||
|
"visual_config": {"time": "11:30 AM", "on_minute": 0, "weekdays": ["sun"], "monthly_days": [1]},
|
||||||
|
"timezone": "UTC",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _run(self) -> NodeRunResult:
|
||||||
|
current_time = datetime.now(UTC)
|
||||||
|
node_outputs = {"current_time": current_time.isoformat()}
|
||||||
|
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
outputs=node_outputs,
|
||||||
|
)
|
||||||
3
api/core/workflow/nodes/trigger_webhook/__init__.py
Normal file
3
api/core/workflow/nodes/trigger_webhook/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .node import TriggerWebhookNode
|
||||||
|
|
||||||
|
__all__ = ["TriggerWebhookNode"]
|
||||||
79
api/core/workflow/nodes/trigger_webhook/entities.py
Normal file
79
api/core/workflow/nodes/trigger_webhook/entities.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from core.workflow.nodes.base import BaseNodeData
|
||||||
|
|
||||||
|
|
||||||
|
class Method(StrEnum):
|
||||||
|
GET = "get"
|
||||||
|
POST = "post"
|
||||||
|
HEAD = "head"
|
||||||
|
PATCH = "patch"
|
||||||
|
PUT = "put"
|
||||||
|
DELETE = "delete"
|
||||||
|
|
||||||
|
|
||||||
|
class ContentType(StrEnum):
|
||||||
|
JSON = "application/json"
|
||||||
|
FORM_DATA = "multipart/form-data"
|
||||||
|
FORM_URLENCODED = "application/x-www-form-urlencoded"
|
||||||
|
TEXT = "text/plain"
|
||||||
|
BINARY = "application/octet-stream"
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookParameter(BaseModel):
|
||||||
|
"""Parameter definition for headers, query params, or body."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
required: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookBodyParameter(BaseModel):
|
||||||
|
"""Body parameter with type information."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
type: Literal[
|
||||||
|
"string",
|
||||||
|
"number",
|
||||||
|
"boolean",
|
||||||
|
"object",
|
||||||
|
"array[string]",
|
||||||
|
"array[number]",
|
||||||
|
"array[boolean]",
|
||||||
|
"array[object]",
|
||||||
|
"file",
|
||||||
|
] = "string"
|
||||||
|
required: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookData(BaseNodeData):
|
||||||
|
"""
|
||||||
|
Webhook Node Data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class SyncMode(StrEnum):
|
||||||
|
SYNC = "async" # only support
|
||||||
|
|
||||||
|
method: Method = Method.GET
|
||||||
|
content_type: ContentType = Field(default=ContentType.JSON)
|
||||||
|
headers: Sequence[WebhookParameter] = Field(default_factory=list)
|
||||||
|
params: Sequence[WebhookParameter] = Field(default_factory=list) # query parameters
|
||||||
|
body: Sequence[WebhookBodyParameter] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@field_validator("method", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def normalize_method(cls, v) -> str:
|
||||||
|
"""Normalize HTTP method to lowercase to support both uppercase and lowercase input."""
|
||||||
|
if isinstance(v, str):
|
||||||
|
return v.lower()
|
||||||
|
return v
|
||||||
|
|
||||||
|
status_code: int = 200 # Expected status code for response
|
||||||
|
response_body: str = "" # Template for response body
|
||||||
|
|
||||||
|
# Webhook specific fields (not from client data, set internally)
|
||||||
|
webhook_id: Optional[str] = None # Set when webhook trigger is created
|
||||||
|
timeout: int = 30 # Timeout in seconds to wait for webhook response
|
||||||
25
api/core/workflow/nodes/trigger_webhook/exc.py
Normal file
25
api/core/workflow/nodes/trigger_webhook/exc.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from core.workflow.nodes.base.exc import BaseNodeError
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookNodeError(BaseNodeError):
|
||||||
|
"""Base webhook node error."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookTimeoutError(WebhookNodeError):
|
||||||
|
"""Webhook timeout error."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookNotFoundError(WebhookNodeError):
|
||||||
|
"""Webhook not found error."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookConfigError(WebhookNodeError):
|
||||||
|
"""Webhook configuration error."""
|
||||||
|
|
||||||
|
pass
|
||||||
126
api/core/workflow/nodes/trigger_webhook/node.py
Normal file
126
api/core/workflow/nodes/trigger_webhook/node.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
|
from core.workflow.nodes.base import BaseNode
|
||||||
|
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||||
|
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||||
|
|
||||||
|
from .entities import ContentType, WebhookData
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerWebhookNode(BaseNode):
|
||||||
|
_node_type = NodeType.TRIGGER_WEBHOOK
|
||||||
|
|
||||||
|
_node_data: WebhookData
|
||||||
|
|
||||||
|
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||||
|
self._node_data = WebhookData.model_validate(data)
|
||||||
|
|
||||||
|
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||||
|
return self._node_data.error_strategy
|
||||||
|
|
||||||
|
def _get_retry_config(self) -> RetryConfig:
|
||||||
|
return self._node_data.retry_config
|
||||||
|
|
||||||
|
def _get_title(self) -> str:
|
||||||
|
return self._node_data.title
|
||||||
|
|
||||||
|
def _get_description(self) -> Optional[str]:
|
||||||
|
return self._node_data.desc
|
||||||
|
|
||||||
|
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||||
|
return self._node_data.default_value_dict
|
||||||
|
|
||||||
|
def get_base_node_data(self) -> BaseNodeData:
|
||||||
|
return self._node_data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
|
||||||
|
return {
|
||||||
|
"type": "webhook",
|
||||||
|
"config": {
|
||||||
|
"method": "get",
|
||||||
|
"content_type": "application/json",
|
||||||
|
"headers": [],
|
||||||
|
"params": [],
|
||||||
|
"body": [],
|
||||||
|
"async_mode": True,
|
||||||
|
"status_code": 200,
|
||||||
|
"response_body": "",
|
||||||
|
"timeout": 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
|
def _run(self) -> NodeRunResult:
|
||||||
|
"""
|
||||||
|
Run the webhook node.
|
||||||
|
|
||||||
|
Like the start node, this simply takes the webhook data from the variable pool
|
||||||
|
and makes it available to downstream nodes. The actual webhook handling
|
||||||
|
happens in the trigger controller.
|
||||||
|
"""
|
||||||
|
# Get webhook data from variable pool (injected by Celery task)
|
||||||
|
webhook_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||||
|
|
||||||
|
# Extract webhook-specific outputs based on node configuration
|
||||||
|
outputs = self._extract_configured_outputs(webhook_inputs)
|
||||||
|
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
inputs=webhook_inputs,
|
||||||
|
outputs=outputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_configured_outputs(self, webhook_inputs: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Extract outputs based on node configuration from webhook inputs."""
|
||||||
|
outputs = {}
|
||||||
|
|
||||||
|
# Get the raw webhook data (should be injected by Celery task)
|
||||||
|
webhook_data = webhook_inputs.get("webhook_data", {})
|
||||||
|
|
||||||
|
# Extract configured headers (case-insensitive)
|
||||||
|
webhook_headers = webhook_data.get("headers", {})
|
||||||
|
webhook_headers_lower = {k.lower(): v for k, v in webhook_headers.items()}
|
||||||
|
|
||||||
|
for header in self._node_data.headers:
|
||||||
|
header_name = header.name
|
||||||
|
# Try exact match first, then case-insensitive match
|
||||||
|
value = webhook_headers.get(header_name) or webhook_headers_lower.get(header_name.lower())
|
||||||
|
outputs[header_name] = value
|
||||||
|
|
||||||
|
# Extract configured query parameters
|
||||||
|
for param in self._node_data.params:
|
||||||
|
param_name = param.name
|
||||||
|
outputs[param_name] = webhook_data.get("query_params", {}).get(param_name)
|
||||||
|
|
||||||
|
# Extract configured body parameters
|
||||||
|
for body_param in self._node_data.body:
|
||||||
|
param_name = body_param.name
|
||||||
|
param_type = body_param.type
|
||||||
|
|
||||||
|
if self._node_data.content_type == ContentType.TEXT:
|
||||||
|
# For text/plain, the entire body is a single string parameter
|
||||||
|
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
|
||||||
|
continue
|
||||||
|
elif self._node_data.content_type == ContentType.BINARY:
|
||||||
|
outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if param_type == "file":
|
||||||
|
# Get File object (already processed by webhook controller)
|
||||||
|
file_obj = webhook_data.get("files", {}).get(param_name)
|
||||||
|
outputs[param_name] = file_obj
|
||||||
|
else:
|
||||||
|
# Get regular body parameter
|
||||||
|
outputs[param_name] = webhook_data.get("body", {}).get(param_name)
|
||||||
|
|
||||||
|
# Include raw webhook data for debugging/advanced use
|
||||||
|
outputs["_webhook_raw"] = webhook_data
|
||||||
|
|
||||||
|
return outputs
|
||||||
@@ -30,9 +30,41 @@ if [[ "${MODE}" == "worker" ]]; then
|
|||||||
CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}"
|
CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
|
# Configure queues based on edition if not explicitly set
|
||||||
--max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
|
if [[ -z "${CELERY_QUEUES}" ]]; then
|
||||||
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation}
|
if [[ "${EDITION}" == "CLOUD" ]]; then
|
||||||
|
# Cloud edition: separate queues for dataset and trigger tasks
|
||||||
|
DEFAULT_QUEUES="dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor"
|
||||||
|
else
|
||||||
|
# Community edition (SELF_HOSTED): dataset and workflow have separate queues
|
||||||
|
DEFAULT_QUEUES="dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
DEFAULT_QUEUES="${CELERY_QUEUES}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Support for Kubernetes deployment with specific queue workers
|
||||||
|
# Environment variables that can be set:
|
||||||
|
# - CELERY_WORKER_QUEUES: Comma-separated list of queues (overrides CELERY_QUEUES)
|
||||||
|
# - CELERY_WORKER_CONCURRENCY: Number of worker processes (overrides CELERY_WORKER_AMOUNT)
|
||||||
|
# - CELERY_WORKER_POOL: Pool implementation (overrides CELERY_WORKER_CLASS)
|
||||||
|
|
||||||
|
if [[ -n "${CELERY_WORKER_QUEUES}" ]]; then
|
||||||
|
DEFAULT_QUEUES="${CELERY_WORKER_QUEUES}"
|
||||||
|
echo "Using CELERY_WORKER_QUEUES: ${DEFAULT_QUEUES}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -n "${CELERY_WORKER_CONCURRENCY}" ]]; then
|
||||||
|
CONCURRENCY_OPTION="-c ${CELERY_WORKER_CONCURRENCY}"
|
||||||
|
echo "Using CELERY_WORKER_CONCURRENCY: ${CELERY_WORKER_CONCURRENCY}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
WORKER_POOL="${CELERY_WORKER_POOL:-${CELERY_WORKER_CLASS:-gevent}}"
|
||||||
|
echo "Starting Celery worker with queues: ${DEFAULT_QUEUES}"
|
||||||
|
|
||||||
|
exec celery -A app.celery worker -P ${WORKER_POOL} $CONCURRENCY_OPTION \
|
||||||
|
--max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
|
||||||
|
-Q ${DEFAULT_QUEUES}
|
||||||
|
|
||||||
elif [[ "${MODE}" == "beat" ]]; then
|
elif [[ "${MODE}" == "beat" ]]; then
|
||||||
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
|
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
|
||||||
|
|||||||
@@ -4,8 +4,12 @@ from .create_document_index import handle
|
|||||||
from .create_installed_app_when_app_created import handle
|
from .create_installed_app_when_app_created import handle
|
||||||
from .create_site_record_when_app_created import handle
|
from .create_site_record_when_app_created import handle
|
||||||
from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
|
from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
|
||||||
|
from .sync_plugin_trigger_when_app_created import handle
|
||||||
|
from .sync_webhook_when_app_created import handle
|
||||||
|
from .sync_workflow_schedule_when_app_published import handle
|
||||||
from .update_app_dataset_join_when_app_model_config_updated import handle
|
from .update_app_dataset_join_when_app_model_config_updated import handle
|
||||||
from .update_app_dataset_join_when_app_published_workflow_updated import handle
|
from .update_app_dataset_join_when_app_published_workflow_updated import handle
|
||||||
|
from .update_app_triggers_when_app_published_workflow_updated import handle
|
||||||
|
|
||||||
# Consolidated handler replaces both deduct_quota_when_message_created and
|
# Consolidated handler replaces both deduct_quota_when_message_created and
|
||||||
# update_provider_last_used_at_when_message_created
|
# update_provider_last_used_at_when_message_created
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from events.app_event import app_draft_workflow_was_synced
|
||||||
|
from models.model import App, AppMode
|
||||||
|
from models.workflow import Workflow
|
||||||
|
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@app_draft_workflow_was_synced.connect
|
||||||
|
def handle(sender, synced_draft_workflow: Workflow, **kwargs):
|
||||||
|
"""
|
||||||
|
While creating a workflow or updating a workflow, we may need to sync
|
||||||
|
its plugin trigger relationships in DB.
|
||||||
|
"""
|
||||||
|
app: App = sender
|
||||||
|
if app.mode != AppMode.WORKFLOW.value:
|
||||||
|
# only handle workflow app, chatflow is not supported yet
|
||||||
|
return
|
||||||
|
|
||||||
|
WorkflowPluginTriggerService.sync_plugin_trigger_relationships(app, synced_draft_workflow)
|
||||||
22
api/events/event_handlers/sync_webhook_when_app_created.py
Normal file
22
api/events/event_handlers/sync_webhook_when_app_created.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from events.app_event import app_draft_workflow_was_synced
|
||||||
|
from models.model import App, AppMode
|
||||||
|
from models.workflow import Workflow
|
||||||
|
from services.webhook_service import WebhookService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@app_draft_workflow_was_synced.connect
|
||||||
|
def handle(sender, synced_draft_workflow: Workflow, **kwargs):
|
||||||
|
"""
|
||||||
|
While creating a workflow or updating a workflow, we may need to sync
|
||||||
|
its webhook relationships in DB.
|
||||||
|
"""
|
||||||
|
app: App = sender
|
||||||
|
if app.mode != AppMode.WORKFLOW.value:
|
||||||
|
# only handle workflow app, chatflow is not supported yet
|
||||||
|
return
|
||||||
|
|
||||||
|
WebhookService.sync_webhook_relationships(app, synced_draft_workflow)
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional, cast
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.workflow.nodes.trigger_schedule.entities import SchedulePlanUpdate
|
||||||
|
from events.app_event import app_published_workflow_was_updated
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models import AppMode, Workflow, WorkflowSchedulePlan
|
||||||
|
from services.schedule_service import ScheduleService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@app_published_workflow_was_updated.connect
|
||||||
|
def handle(sender, **kwargs):
|
||||||
|
"""
|
||||||
|
Handle app published workflow update event to sync workflow_schedule_plans table.
|
||||||
|
|
||||||
|
When a workflow is published, this handler will:
|
||||||
|
1. Extract schedule trigger nodes from the workflow graph
|
||||||
|
2. Compare with existing workflow_schedule_plans records
|
||||||
|
3. Create/update/delete schedule plans as needed
|
||||||
|
"""
|
||||||
|
app = sender
|
||||||
|
if app.mode != AppMode.WORKFLOW.value:
|
||||||
|
return
|
||||||
|
|
||||||
|
published_workflow = kwargs.get("published_workflow")
|
||||||
|
published_workflow = cast(Workflow, published_workflow)
|
||||||
|
|
||||||
|
sync_schedule_from_workflow(tenant_id=app.tenant_id, app_id=app.id, workflow=published_workflow)
|
||||||
|
|
||||||
|
|
||||||
|
def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow) -> Optional[WorkflowSchedulePlan]:
|
||||||
|
"""
|
||||||
|
Sync schedule plan from workflow graph configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant ID
|
||||||
|
app_id: App ID
|
||||||
|
workflow: Published workflow instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated or created WorkflowSchedulePlan, or None if no schedule node
|
||||||
|
"""
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
schedule_config = ScheduleService.extract_schedule_config(workflow)
|
||||||
|
|
||||||
|
existing_plan = session.scalar(
|
||||||
|
select(WorkflowSchedulePlan).where(
|
||||||
|
WorkflowSchedulePlan.tenant_id == tenant_id,
|
||||||
|
WorkflowSchedulePlan.app_id == app_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not schedule_config:
|
||||||
|
if existing_plan:
|
||||||
|
logger.info("No schedule node in workflow for app %s, removing schedule plan", app_id)
|
||||||
|
ScheduleService.delete_schedule(session=session, schedule_id=existing_plan.id)
|
||||||
|
session.commit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
if existing_plan:
|
||||||
|
updates = SchedulePlanUpdate(
|
||||||
|
node_id=schedule_config.node_id,
|
||||||
|
cron_expression=schedule_config.cron_expression,
|
||||||
|
timezone=schedule_config.timezone,
|
||||||
|
)
|
||||||
|
updated_plan = ScheduleService.update_schedule(
|
||||||
|
session=session,
|
||||||
|
schedule_id=existing_plan.id,
|
||||||
|
updates=updates,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
return updated_plan
|
||||||
|
else:
|
||||||
|
new_plan = ScheduleService.create_schedule(
|
||||||
|
session=session,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
app_id=app_id,
|
||||||
|
config=schedule_config,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
return new_plan
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.workflow.nodes import NodeType
|
||||||
|
from events.app_event import app_published_workflow_was_updated
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models import AppMode, AppTrigger, AppTriggerStatus, Workflow
|
||||||
|
|
||||||
|
|
||||||
|
@app_published_workflow_was_updated.connect
|
||||||
|
def handle(sender, **kwargs):
|
||||||
|
"""
|
||||||
|
Handle app published workflow update event to sync app_triggers table.
|
||||||
|
|
||||||
|
When a workflow is published, this handler will:
|
||||||
|
1. Extract trigger nodes from the workflow graph
|
||||||
|
2. Compare with existing app_triggers records
|
||||||
|
3. Add new triggers and remove obsolete ones
|
||||||
|
"""
|
||||||
|
app = sender
|
||||||
|
if app.mode != AppMode.WORKFLOW.value:
|
||||||
|
return
|
||||||
|
|
||||||
|
published_workflow = kwargs.get("published_workflow")
|
||||||
|
published_workflow = cast(Workflow, published_workflow)
|
||||||
|
# Extract trigger info from workflow
|
||||||
|
trigger_infos = get_trigger_infos_from_workflow(published_workflow)
|
||||||
|
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
# Get existing app triggers
|
||||||
|
existing_triggers = (
|
||||||
|
session.execute(
|
||||||
|
select(AppTrigger).where(AppTrigger.tenant_id == app.tenant_id, AppTrigger.app_id == app.id)
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert existing triggers to dict for easy lookup
|
||||||
|
existing_triggers_map = {trigger.node_id: trigger for trigger in existing_triggers}
|
||||||
|
|
||||||
|
# Get current and new node IDs
|
||||||
|
existing_node_ids = set(existing_triggers_map.keys())
|
||||||
|
new_node_ids = {info["node_id"] for info in trigger_infos}
|
||||||
|
|
||||||
|
# Calculate changes
|
||||||
|
added_node_ids = new_node_ids - existing_node_ids
|
||||||
|
removed_node_ids = existing_node_ids - new_node_ids
|
||||||
|
|
||||||
|
# Remove obsolete triggers
|
||||||
|
for node_id in removed_node_ids:
|
||||||
|
session.delete(existing_triggers_map[node_id])
|
||||||
|
|
||||||
|
for trigger_info in trigger_infos:
|
||||||
|
node_id = trigger_info["node_id"]
|
||||||
|
|
||||||
|
if node_id in added_node_ids:
|
||||||
|
# Create new trigger
|
||||||
|
app_trigger = AppTrigger(
|
||||||
|
tenant_id=app.tenant_id,
|
||||||
|
app_id=app.id,
|
||||||
|
trigger_type=trigger_info["node_type"],
|
||||||
|
title=trigger_info["node_title"],
|
||||||
|
node_id=node_id,
|
||||||
|
provider_name=trigger_info.get("node_provider_name", ""),
|
||||||
|
status=AppTriggerStatus.DISABLED,
|
||||||
|
)
|
||||||
|
session.add(app_trigger)
|
||||||
|
elif node_id in existing_node_ids:
|
||||||
|
# Update existing trigger if needed
|
||||||
|
existing_trigger = existing_triggers_map[node_id]
|
||||||
|
new_title = trigger_info["node_title"]
|
||||||
|
if new_title and existing_trigger.title != new_title:
|
||||||
|
existing_trigger.title = new_title
|
||||||
|
session.add(existing_trigger)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def get_trigger_infos_from_workflow(published_workflow: Workflow) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Extract trigger node information from the workflow graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of trigger info dictionaries containing:
|
||||||
|
- node_type: The type of the trigger node ('trigger-webhook', 'trigger-schedule', 'trigger-plugin')
|
||||||
|
- node_id: The node ID in the workflow
|
||||||
|
- node_title: The title of the node
|
||||||
|
- node_provider_name: The name of the node's provider, only for plugin
|
||||||
|
"""
|
||||||
|
graph = published_workflow.graph_dict
|
||||||
|
if not graph:
|
||||||
|
return []
|
||||||
|
|
||||||
|
nodes = graph.get("nodes", [])
|
||||||
|
trigger_types = {NodeType.TRIGGER_WEBHOOK.value, NodeType.TRIGGER_SCHEDULE.value, NodeType.TRIGGER_PLUGIN.value}
|
||||||
|
|
||||||
|
trigger_infos = [
|
||||||
|
{
|
||||||
|
"node_type": node.get("data", {}).get("type"),
|
||||||
|
"node_id": node.get("id"),
|
||||||
|
"node_title": node.get("data", {}).get("title"),
|
||||||
|
"node_provider_name": node.get("data", {}).get("provider_name"),
|
||||||
|
}
|
||||||
|
for node in nodes
|
||||||
|
if node.get("data", {}).get("type") in trigger_types
|
||||||
|
]
|
||||||
|
|
||||||
|
return trigger_infos
|
||||||
@@ -12,6 +12,7 @@ def init_app(app: DifyApp):
|
|||||||
from controllers.inner_api import bp as inner_api_bp
|
from controllers.inner_api import bp as inner_api_bp
|
||||||
from controllers.mcp import bp as mcp_bp
|
from controllers.mcp import bp as mcp_bp
|
||||||
from controllers.service_api import bp as service_api_bp
|
from controllers.service_api import bp as service_api_bp
|
||||||
|
from controllers.trigger import bp as trigger_bp
|
||||||
from controllers.web import bp as web_bp
|
from controllers.web import bp as web_bp
|
||||||
|
|
||||||
CORS(
|
CORS(
|
||||||
@@ -50,3 +51,11 @@ def init_app(app: DifyApp):
|
|||||||
|
|
||||||
app.register_blueprint(inner_api_bp)
|
app.register_blueprint(inner_api_bp)
|
||||||
app.register_blueprint(mcp_bp)
|
app.register_blueprint(mcp_bp)
|
||||||
|
|
||||||
|
# Register trigger blueprint with CORS for webhook calls
|
||||||
|
CORS(
|
||||||
|
trigger_bp,
|
||||||
|
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||||
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH", "HEAD"],
|
||||||
|
)
|
||||||
|
app.register_blueprint(trigger_bp)
|
||||||
|
|||||||
@@ -96,7 +96,9 @@ def init_app(app: DifyApp) -> Celery:
|
|||||||
celery_app.set_default()
|
celery_app.set_default()
|
||||||
app.extensions["celery"] = celery_app
|
app.extensions["celery"] = celery_app
|
||||||
|
|
||||||
imports = []
|
imports = [
|
||||||
|
"tasks.async_workflow_tasks", # trigger workers
|
||||||
|
]
|
||||||
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
|
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
|
||||||
|
|
||||||
# if you add a new task, please add the switch to CeleryScheduleTasksConfig
|
# if you add a new task, please add the switch to CeleryScheduleTasksConfig
|
||||||
@@ -158,6 +160,12 @@ def init_app(app: DifyApp) -> Celery:
|
|||||||
"task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
|
"task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
|
||||||
"schedule": crontab(minute="0", hour="2"),
|
"schedule": crontab(minute="0", hour="2"),
|
||||||
}
|
}
|
||||||
|
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
|
||||||
|
imports.append("schedule.workflow_schedule_task")
|
||||||
|
beat_schedule["workflow_schedule_task"] = {
|
||||||
|
"task": "schedule.workflow_schedule_task.poll_workflow_schedules",
|
||||||
|
"schedule": timedelta(minutes=dify_config.WORKFLOW_SCHEDULE_POLLER_INTERVAL),
|
||||||
|
}
|
||||||
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
|
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
|
||||||
|
|
||||||
return celery_app
|
return celery_app
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ def init_app(app: DifyApp):
|
|||||||
reset_encrypt_key_pair,
|
reset_encrypt_key_pair,
|
||||||
reset_password,
|
reset_password,
|
||||||
setup_system_tool_oauth_client,
|
setup_system_tool_oauth_client,
|
||||||
|
setup_system_trigger_oauth_client,
|
||||||
upgrade_db,
|
upgrade_db,
|
||||||
vdb_migrate,
|
vdb_migrate,
|
||||||
)
|
)
|
||||||
@@ -43,6 +44,7 @@ def init_app(app: DifyApp):
|
|||||||
clear_orphaned_file_records,
|
clear_orphaned_file_records,
|
||||||
remove_orphaned_files_on_storage,
|
remove_orphaned_files_on_storage,
|
||||||
setup_system_tool_oauth_client,
|
setup_system_tool_oauth_client,
|
||||||
|
setup_system_trigger_oauth_client,
|
||||||
cleanup_orphaned_draft_variables,
|
cleanup_orphaned_draft_variables,
|
||||||
]
|
]
|
||||||
for cmd in cmds_to_register:
|
for cmd in cmds_to_register:
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ workflow_run_for_log_fields = {
|
|||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
"version": fields.String,
|
"version": fields.String,
|
||||||
"status": fields.String,
|
"status": fields.String,
|
||||||
|
"triggered_from": fields.String,
|
||||||
"error": fields.String,
|
"error": fields.String,
|
||||||
"elapsed_time": fields.Float,
|
"elapsed_time": fields.Float,
|
||||||
"total_tokens": fields.Integer,
|
"total_tokens": fields.Integer,
|
||||||
|
|||||||
25
api/fields/workflow_trigger_fields.py
Normal file
25
api/fields/workflow_trigger_fields.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from flask_restx import fields
|
||||||
|
|
||||||
|
trigger_fields = {
|
||||||
|
"id": fields.String,
|
||||||
|
"trigger_type": fields.String,
|
||||||
|
"title": fields.String,
|
||||||
|
"node_id": fields.String,
|
||||||
|
"provider_name": fields.String,
|
||||||
|
"icon": fields.String,
|
||||||
|
"status": fields.String,
|
||||||
|
"created_at": fields.DateTime(dt_format="iso8601"),
|
||||||
|
"updated_at": fields.DateTime(dt_format="iso8601"),
|
||||||
|
}
|
||||||
|
|
||||||
|
triggers_list_fields = {"data": fields.List(fields.Nested(trigger_fields))}
|
||||||
|
|
||||||
|
|
||||||
|
webhook_trigger_fields = {
|
||||||
|
"id": fields.String,
|
||||||
|
"webhook_id": fields.String,
|
||||||
|
"webhook_url": fields.String,
|
||||||
|
"webhook_debug_url": fields.String,
|
||||||
|
"node_id": fields.String,
|
||||||
|
"created_at": fields.DateTime(dt_format="iso8601"),
|
||||||
|
}
|
||||||
97
api/libs/schedule_utils.py
Normal file
97
api/libs/schedule_utils.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pytz
|
||||||
|
from croniter import croniter
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_next_run_at(
|
||||||
|
cron_expression: str,
|
||||||
|
timezone: str,
|
||||||
|
base_time: Optional[datetime] = None,
|
||||||
|
) -> datetime:
|
||||||
|
"""
|
||||||
|
Calculate the next run time for a cron expression in a specific timezone.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cron_expression: Cron expression string (supports croniter extensions like 'L')
|
||||||
|
timezone: Timezone string (e.g., 'UTC', 'America/New_York')
|
||||||
|
base_time: Base time to calculate from (defaults to current UTC time)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Next run time in UTC
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Supports croniter's extended syntax including:
|
||||||
|
- 'L' for last day of month
|
||||||
|
- Standard 5-field cron expressions
|
||||||
|
"""
|
||||||
|
|
||||||
|
tz = pytz.timezone(timezone)
|
||||||
|
|
||||||
|
if base_time is None:
|
||||||
|
base_time = datetime.now(UTC)
|
||||||
|
|
||||||
|
base_time_tz = base_time.astimezone(tz)
|
||||||
|
cron = croniter(cron_expression, base_time_tz)
|
||||||
|
next_run_tz = cron.get_next(datetime)
|
||||||
|
next_run_utc = next_run_tz.astimezone(UTC)
|
||||||
|
|
||||||
|
return next_run_utc
|
||||||
|
|
||||||
|
|
||||||
|
def convert_12h_to_24h(time_str: str) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Parse 12-hour time format to 24-hour format for cron compatibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time_str: Time string in format "HH:MM AM/PM" (e.g., "12:30 PM")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (hour, minute) in 24-hour format
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If time string format is invalid or values are out of range
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- "12:00 AM" -> (0, 0) # Midnight
|
||||||
|
- "12:00 PM" -> (12, 0) # Noon
|
||||||
|
- "1:30 PM" -> (13, 30)
|
||||||
|
- "11:59 PM" -> (23, 59)
|
||||||
|
"""
|
||||||
|
if not time_str or not time_str.strip():
|
||||||
|
raise ValueError("Time string cannot be empty")
|
||||||
|
|
||||||
|
parts = time_str.strip().split()
|
||||||
|
if len(parts) != 2:
|
||||||
|
raise ValueError(f"Invalid time format: '{time_str}'. Expected 'HH:MM AM/PM'")
|
||||||
|
|
||||||
|
time_part, period = parts
|
||||||
|
period = period.upper()
|
||||||
|
|
||||||
|
if period not in ["AM", "PM"]:
|
||||||
|
raise ValueError(f"Invalid period: '{period}'. Must be 'AM' or 'PM'")
|
||||||
|
|
||||||
|
time_parts = time_part.split(":")
|
||||||
|
if len(time_parts) != 2:
|
||||||
|
raise ValueError(f"Invalid time format: '{time_part}'. Expected 'HH:MM'")
|
||||||
|
|
||||||
|
try:
|
||||||
|
hour = int(time_parts[0])
|
||||||
|
minute = int(time_parts[1])
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(f"Invalid time values: {e}")
|
||||||
|
|
||||||
|
if hour < 1 or hour > 12:
|
||||||
|
raise ValueError(f"Invalid hour: {hour}. Must be between 1 and 12")
|
||||||
|
|
||||||
|
if minute < 0 or minute > 59:
|
||||||
|
raise ValueError(f"Invalid minute: {minute}. Must be between 0 and 59")
|
||||||
|
|
||||||
|
# Handle 12-hour to 24-hour edge cases
|
||||||
|
if period == "PM" and hour != 12:
|
||||||
|
hour += 12
|
||||||
|
elif period == "AM" and hour == 12:
|
||||||
|
hour = 0
|
||||||
|
|
||||||
|
return hour, minute
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
"""Add workflow trigger logs table
|
||||||
|
|
||||||
|
Revision ID: 4558cfabe44e
|
||||||
|
Revises: 0e154742a5fa
|
||||||
|
Create Date: 2025-08-23 20:38:20.059323
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '4558cfabe44e'
|
||||||
|
down_revision = '8d289573e1da'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('workflow_trigger_logs',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('workflow_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('workflow_run_id', models.types.StringUUID(), nullable=True),
|
||||||
|
sa.Column('root_node_id', sa.String(length=255), nullable=True),
|
||||||
|
sa.Column('trigger_type', sa.String(length=50), nullable=False),
|
||||||
|
sa.Column('trigger_data', sa.Text(), nullable=False),
|
||||||
|
sa.Column('inputs', sa.Text(), nullable=False),
|
||||||
|
sa.Column('outputs', sa.Text(), nullable=True),
|
||||||
|
sa.Column('status', sa.String(length=50), nullable=False),
|
||||||
|
sa.Column('error', sa.Text(), nullable=True),
|
||||||
|
sa.Column('queue_name', sa.String(length=100), nullable=False),
|
||||||
|
sa.Column('celery_task_id', sa.String(length=255), nullable=True),
|
||||||
|
sa.Column('retry_count', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('elapsed_time', sa.Float(), nullable=True),
|
||||||
|
sa.Column('total_tokens', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('created_by_role', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('created_by', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('triggered_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('finished_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='workflow_trigger_log_pkey')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op:
|
||||||
|
batch_op.create_index('workflow_trigger_log_created_at_idx', ['created_at'], unique=False)
|
||||||
|
batch_op.create_index('workflow_trigger_log_status_idx', ['status'], unique=False)
|
||||||
|
batch_op.create_index('workflow_trigger_log_tenant_app_idx', ['tenant_id', 'app_id'], unique=False)
|
||||||
|
batch_op.create_index('workflow_trigger_log_workflow_id_idx', ['workflow_id'], unique=False)
|
||||||
|
batch_op.create_index('workflow_trigger_log_workflow_run_idx', ['workflow_run_id'], unique=False)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index('workflow_trigger_log_workflow_run_idx')
|
||||||
|
batch_op.drop_index('workflow_trigger_log_workflow_id_idx')
|
||||||
|
batch_op.drop_index('workflow_trigger_log_tenant_app_idx')
|
||||||
|
batch_op.drop_index('workflow_trigger_log_status_idx')
|
||||||
|
batch_op.drop_index('workflow_trigger_log_created_at_idx')
|
||||||
|
|
||||||
|
op.drop_table('workflow_trigger_logs')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
"""Add workflow webhook table
|
||||||
|
|
||||||
|
Revision ID: 5871f634954d
|
||||||
|
Revises: fa8b0fa6f407
|
||||||
|
Create Date: 2025-08-23 20:39:20.704501
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '5871f634954d'
|
||||||
|
down_revision = '4558cfabe44e'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('workflow_webhook_triggers',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||||
|
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('node_id', sa.String(length=64), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('webhook_id', sa.String(length=24), nullable=False),
|
||||||
|
sa.Column('created_by', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='workflow_webhook_trigger_pkey'),
|
||||||
|
sa.UniqueConstraint('app_id', 'node_id', name='uniq_node'),
|
||||||
|
sa.UniqueConstraint('webhook_id', name='uniq_webhook_id')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('workflow_webhook_triggers', schema=None) as batch_op:
|
||||||
|
batch_op.create_index('workflow_webhook_trigger_tenant_idx', ['tenant_id'], unique=False)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_webhook_triggers', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index('workflow_webhook_trigger_tenant_idx')
|
||||||
|
|
||||||
|
op.drop_table('workflow_webhook_triggers')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
"""Add app triggers table
|
||||||
|
|
||||||
|
Revision ID: 9ee7d347f4c1
|
||||||
|
Revises: 5871f634954d
|
||||||
|
Create Date: 2025-08-27 17:33:30.082812
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '9ee7d347f4c1'
|
||||||
|
down_revision = '5871f634954d'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('app_triggers',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('node_id', sa.String(length=64), nullable=False),
|
||||||
|
sa.Column('trigger_type', sa.String(length=50), nullable=False),
|
||||||
|
sa.Column('title', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('provider_name', sa.String(length=255), server_default='', nullable=True),
|
||||||
|
sa.Column('status', sa.String(length=50), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='app_trigger_pkey')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('app_triggers', schema=None) as batch_op:
|
||||||
|
batch_op.create_index('app_trigger_tenant_app_idx', ['tenant_id', 'app_id'], unique=False)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('app_triggers', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index('app_trigger_tenant_app_idx')
|
||||||
|
|
||||||
|
op.drop_table('app_triggers')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
"""Add workflow schedule plan table
|
||||||
|
|
||||||
|
Revision ID: c19938f630b6
|
||||||
|
Revises: 9ee7d347f4c1
|
||||||
|
Create Date: 2025-08-28 20:52:41.300028
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'c19938f630b6'
|
||||||
|
down_revision = '875c659da2f8'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('workflow_schedule_plans',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||||
|
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('node_id', sa.String(length=64), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('cron_expression', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('timezone', sa.String(length=64), nullable=False),
|
||||||
|
sa.Column('next_run_at', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='workflow_schedule_plan_pkey'),
|
||||||
|
sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('workflow_schedule_plans', schema=None) as batch_op:
|
||||||
|
batch_op.create_index('workflow_schedule_plan_next_idx', ['next_run_at'], unique=False)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_schedule_plans', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index('workflow_schedule_plan_next_idx')
|
||||||
|
|
||||||
|
op.drop_table('workflow_schedule_plans')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
"""plugin_trigger
|
||||||
|
|
||||||
|
Revision ID: 132392a2635f
|
||||||
|
Revises: 9ee7d347f4c1
|
||||||
|
Create Date: 2025-09-03 15:00:57.326868
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '132392a2635f'
|
||||||
|
down_revision = '9ee7d347f4c1'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('trigger_oauth_system_clients',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||||
|
sa.Column('plugin_id', sa.String(length=512), nullable=False),
|
||||||
|
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='trigger_oauth_system_client_pkey'),
|
||||||
|
sa.UniqueConstraint('plugin_id', 'provider', name='trigger_oauth_system_client_plugin_id_provider_idx')
|
||||||
|
)
|
||||||
|
op.create_table('trigger_oauth_tenant_clients',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('plugin_id', sa.String(length=512), nullable=False),
|
||||||
|
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
|
||||||
|
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='trigger_oauth_tenant_client_pkey'),
|
||||||
|
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_trigger_oauth_tenant_client')
|
||||||
|
)
|
||||||
|
op.create_table('trigger_subscriptions',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||||
|
sa.Column('name', sa.String(length=255), nullable=False, comment='Subscription instance name'),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('user_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('provider_id', sa.String(length=255), nullable=False, comment='Provider identifier (e.g., plugin_id/provider_name)'),
|
||||||
|
sa.Column('endpoint_id', sa.String(length=255), nullable=False, comment='Subscription endpoint'),
|
||||||
|
sa.Column('parameters', sa.JSON(), nullable=False, comment='Subscription parameters JSON'),
|
||||||
|
sa.Column('properties', sa.JSON(), nullable=False, comment='Subscription properties JSON'),
|
||||||
|
sa.Column('credentials', sa.JSON(), nullable=False, comment='Subscription credentials JSON'),
|
||||||
|
sa.Column('credential_type', sa.String(length=50), nullable=False, comment='oauth or api_key'),
|
||||||
|
sa.Column('credential_expires_at', sa.Integer(), nullable=False, comment='OAuth token expiration timestamp, -1 for never'),
|
||||||
|
sa.Column('expires_at', sa.Integer(), nullable=False, comment='Subscription instance expiration timestamp, -1 for never'),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='trigger_provider_pkey'),
|
||||||
|
sa.UniqueConstraint('tenant_id', 'provider_id', 'name', name='unique_trigger_provider')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('trigger_subscriptions', schema=None) as batch_op:
|
||||||
|
batch_op.create_index('idx_trigger_providers_endpoint', ['endpoint_id'], unique=True)
|
||||||
|
batch_op.create_index('idx_trigger_providers_tenant_endpoint', ['tenant_id', 'endpoint_id'], unique=False)
|
||||||
|
batch_op.create_index('idx_trigger_providers_tenant_provider', ['tenant_id', 'provider_id'], unique=False)
|
||||||
|
|
||||||
|
op.create_table('workflow_plugin_triggers',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||||
|
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('node_id', sa.String(length=64), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('provider_id', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('trigger_id', sa.String(length=510), nullable=False),
|
||||||
|
sa.Column('triggered_by', sa.String(length=16), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='workflow_plugin_trigger_pkey'),
|
||||||
|
sa.UniqueConstraint('app_id', 'node_id', 'triggered_by', name='uniq_plugin_node'),
|
||||||
|
sa.UniqueConstraint('trigger_id', 'node_id', name='uniq_trigger_node')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op:
|
||||||
|
batch_op.create_index('workflow_plugin_trigger_tenant_idx', ['tenant_id'], unique=False)
|
||||||
|
batch_op.create_index('workflow_plugin_trigger_trigger_idx', ['trigger_id'], unique=False)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index('workflow_plugin_trigger_trigger_idx')
|
||||||
|
batch_op.drop_index('workflow_plugin_trigger_tenant_idx')
|
||||||
|
|
||||||
|
op.drop_table('workflow_plugin_triggers')
|
||||||
|
with op.batch_alter_table('trigger_subscriptions', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index('idx_trigger_providers_tenant_provider')
|
||||||
|
batch_op.drop_index('idx_trigger_providers_tenant_endpoint')
|
||||||
|
batch_op.drop_index('idx_trigger_providers_endpoint')
|
||||||
|
|
||||||
|
op.drop_table('trigger_subscriptions')
|
||||||
|
op.drop_table('trigger_oauth_tenant_clients')
|
||||||
|
op.drop_table('trigger_oauth_system_clients')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
"""plugin_trigger_workflow
|
||||||
|
|
||||||
|
Revision ID: 86f068bf56fb
|
||||||
|
Revises: 132392a2635f
|
||||||
|
Create Date: 2025-09-04 12:12:44.661875
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '86f068bf56fb'
|
||||||
|
down_revision = '132392a2635f'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('subscription_id', sa.String(length=255), nullable=False))
|
||||||
|
batch_op.alter_column('provider_id',
|
||||||
|
existing_type=sa.VARCHAR(length=255),
|
||||||
|
type_=sa.String(length=512),
|
||||||
|
existing_nullable=False)
|
||||||
|
batch_op.alter_column('trigger_id',
|
||||||
|
existing_type=sa.VARCHAR(length=510),
|
||||||
|
type_=sa.String(length=255),
|
||||||
|
existing_nullable=False)
|
||||||
|
batch_op.drop_constraint(batch_op.f('uniq_plugin_node'), type_='unique')
|
||||||
|
batch_op.drop_constraint(batch_op.f('uniq_trigger_node'), type_='unique')
|
||||||
|
batch_op.drop_index(batch_op.f('workflow_plugin_trigger_tenant_idx'))
|
||||||
|
batch_op.drop_index(batch_op.f('workflow_plugin_trigger_trigger_idx'))
|
||||||
|
batch_op.create_unique_constraint('uniq_app_node_subscription', ['app_id', 'node_id'])
|
||||||
|
batch_op.create_index('workflow_plugin_trigger_tenant_subscription_idx', ['tenant_id', 'subscription_id'], unique=False)
|
||||||
|
batch_op.drop_column('triggered_by')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('triggered_by', sa.VARCHAR(length=16), autoincrement=False, nullable=False))
|
||||||
|
batch_op.drop_index('workflow_plugin_trigger_tenant_subscription_idx')
|
||||||
|
batch_op.drop_constraint('uniq_app_node_subscription', type_='unique')
|
||||||
|
batch_op.create_index(batch_op.f('workflow_plugin_trigger_trigger_idx'), ['trigger_id'], unique=False)
|
||||||
|
batch_op.create_index(batch_op.f('workflow_plugin_trigger_tenant_idx'), ['tenant_id'], unique=False)
|
||||||
|
batch_op.create_unique_constraint(batch_op.f('uniq_trigger_node'), ['trigger_id', 'node_id'], postgresql_nulls_not_distinct=False)
|
||||||
|
batch_op.create_unique_constraint(batch_op.f('uniq_plugin_node'), ['app_id', 'node_id', 'triggered_by'], postgresql_nulls_not_distinct=False)
|
||||||
|
batch_op.alter_column('trigger_id',
|
||||||
|
existing_type=sa.String(length=255),
|
||||||
|
type_=sa.VARCHAR(length=510),
|
||||||
|
existing_nullable=False)
|
||||||
|
batch_op.alter_column('provider_id',
|
||||||
|
existing_type=sa.String(length=512),
|
||||||
|
type_=sa.VARCHAR(length=255),
|
||||||
|
existing_nullable=False)
|
||||||
|
batch_op.drop_column('subscription_id')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
"""plugin_trigger_idx
|
||||||
|
|
||||||
|
Revision ID: 875c659da2f8
|
||||||
|
Revises: 86f068bf56fb
|
||||||
|
Create Date: 2025-09-05 15:51:08.635283
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '875c659da2f8'
|
||||||
|
down_revision = '86f068bf56fb'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('trigger_name', sa.String(length=255), nullable=False))
|
||||||
|
batch_op.drop_index(batch_op.f('workflow_plugin_trigger_tenant_subscription_idx'))
|
||||||
|
batch_op.create_index('workflow_plugin_trigger_tenant_subscription_idx', ['tenant_id', 'subscription_id', 'trigger_name'], unique=False)
|
||||||
|
batch_op.drop_column('trigger_id')
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('trigger_id', sa.VARCHAR(length=255), autoincrement=False, nullable=False))
|
||||||
|
batch_op.drop_index('workflow_plugin_trigger_tenant_subscription_idx')
|
||||||
|
batch_op.create_index(batch_op.f('workflow_plugin_trigger_tenant_subscription_idx'), ['tenant_id', 'subscription_id'], unique=False)
|
||||||
|
batch_op.drop_column('trigger_name')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -79,8 +79,12 @@ from .tools import (
|
|||||||
ToolModelInvoke,
|
ToolModelInvoke,
|
||||||
WorkflowToolProvider,
|
WorkflowToolProvider,
|
||||||
)
|
)
|
||||||
|
from .trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerSubscription
|
||||||
from .web import PinnedConversation, SavedMessage
|
from .web import PinnedConversation, SavedMessage
|
||||||
from .workflow import (
|
from .workflow import (
|
||||||
|
AppTrigger,
|
||||||
|
AppTriggerStatus,
|
||||||
|
AppTriggerType,
|
||||||
ConversationVariable,
|
ConversationVariable,
|
||||||
Workflow,
|
Workflow,
|
||||||
WorkflowAppLog,
|
WorkflowAppLog,
|
||||||
@@ -88,6 +92,7 @@ from .workflow import (
|
|||||||
WorkflowNodeExecutionModel,
|
WorkflowNodeExecutionModel,
|
||||||
WorkflowNodeExecutionTriggeredFrom,
|
WorkflowNodeExecutionTriggeredFrom,
|
||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
|
WorkflowSchedulePlan,
|
||||||
WorkflowType,
|
WorkflowType,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -104,9 +109,12 @@ __all__ = [
|
|||||||
"AppAnnotationHitHistory",
|
"AppAnnotationHitHistory",
|
||||||
"AppAnnotationSetting",
|
"AppAnnotationSetting",
|
||||||
"AppDatasetJoin",
|
"AppDatasetJoin",
|
||||||
"AppMCPServer", # Added
|
"AppMCPServer",
|
||||||
"AppMode",
|
"AppMode",
|
||||||
"AppModelConfig",
|
"AppModelConfig",
|
||||||
|
"AppTrigger",
|
||||||
|
"AppTriggerStatus",
|
||||||
|
"AppTriggerType",
|
||||||
"BuiltinToolProvider",
|
"BuiltinToolProvider",
|
||||||
"CeleryTask",
|
"CeleryTask",
|
||||||
"CeleryTaskSet",
|
"CeleryTaskSet",
|
||||||
@@ -165,6 +173,9 @@ __all__ = [
|
|||||||
"ToolLabelBinding",
|
"ToolLabelBinding",
|
||||||
"ToolModelInvoke",
|
"ToolModelInvoke",
|
||||||
"TraceAppConfig",
|
"TraceAppConfig",
|
||||||
|
"TriggerOAuthSystemClient",
|
||||||
|
"TriggerOAuthTenantClient",
|
||||||
|
"TriggerSubscription",
|
||||||
"UploadFile",
|
"UploadFile",
|
||||||
"UserFrom",
|
"UserFrom",
|
||||||
"Whitelist",
|
"Whitelist",
|
||||||
@@ -175,6 +186,7 @@ __all__ = [
|
|||||||
"WorkflowNodeExecutionTriggeredFrom",
|
"WorkflowNodeExecutionTriggeredFrom",
|
||||||
"WorkflowRun",
|
"WorkflowRun",
|
||||||
"WorkflowRunTriggeredFrom",
|
"WorkflowRunTriggeredFrom",
|
||||||
|
"WorkflowSchedulePlan",
|
||||||
"WorkflowToolProvider",
|
"WorkflowToolProvider",
|
||||||
"WorkflowType",
|
"WorkflowType",
|
||||||
"db",
|
"db",
|
||||||
|
|||||||
@@ -13,7 +13,10 @@ class UserFrom(StrEnum):
|
|||||||
|
|
||||||
class WorkflowRunTriggeredFrom(StrEnum):
|
class WorkflowRunTriggeredFrom(StrEnum):
|
||||||
DEBUGGING = "debugging"
|
DEBUGGING = "debugging"
|
||||||
APP_RUN = "app-run"
|
APP_RUN = "app-run" # webapp / service api
|
||||||
|
WEBHOOK = "webhook"
|
||||||
|
SCHEDULE = "schedule"
|
||||||
|
PLUGIN = "plugin"
|
||||||
|
|
||||||
|
|
||||||
class DraftVariableType(StrEnum):
|
class DraftVariableType(StrEnum):
|
||||||
|
|||||||
139
api/models/trigger.py
Normal file
139
api/models/trigger.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import DateTime, Index, Integer, String, UniqueConstraint, func
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
|
from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
|
||||||
|
from core.trigger.entities.entities import Subscription
|
||||||
|
from core.trigger.utils.endpoint import parse_endpoint_id
|
||||||
|
from models.base import Base
|
||||||
|
from models.types import StringUUID
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerSubscription(Base):
|
||||||
|
"""
|
||||||
|
Trigger provider model for managing credentials
|
||||||
|
Supports multiple credential instances per provider
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "trigger_subscriptions"
|
||||||
|
__table_args__ = (
|
||||||
|
sa.PrimaryKeyConstraint("id", name="trigger_provider_pkey"),
|
||||||
|
Index("idx_trigger_providers_tenant_provider", "tenant_id", "provider_id"),
|
||||||
|
# Primary index for O(1) lookup by endpoint
|
||||||
|
Index("idx_trigger_providers_endpoint", "endpoint_id", unique=True),
|
||||||
|
# Composite index for tenant-specific queries (optional, kept for compatibility)
|
||||||
|
Index("idx_trigger_providers_tenant_endpoint", "tenant_id", "endpoint_id"),
|
||||||
|
UniqueConstraint("tenant_id", "provider_id", "name", name="unique_trigger_provider"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription instance name")
|
||||||
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
provider_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
|
||||||
|
)
|
||||||
|
endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
|
||||||
|
parameters: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
|
||||||
|
properties: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
|
||||||
|
|
||||||
|
credentials: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription credentials JSON")
|
||||||
|
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
|
||||||
|
credential_expires_at: Mapped[int] = mapped_column(
|
||||||
|
Integer, default=-1, comment="OAuth token expiration timestamp, -1 for never"
|
||||||
|
)
|
||||||
|
expires_at: Mapped[int] = mapped_column(
|
||||||
|
Integer, default=-1, comment="Subscription instance expiration timestamp, -1 for never"
|
||||||
|
)
|
||||||
|
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime,
|
||||||
|
nullable=False,
|
||||||
|
server_default=func.current_timestamp(),
|
||||||
|
server_onupdate=func.current_timestamp(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_credential_expired(self) -> bool:
|
||||||
|
"""Check if credential is expired"""
|
||||||
|
if self.credential_expires_at == -1:
|
||||||
|
return False
|
||||||
|
# Check if token expires in next 3 minutes
|
||||||
|
return (self.credential_expires_at - 180) < int(time.time())
|
||||||
|
|
||||||
|
def to_entity(self) -> Subscription:
|
||||||
|
return Subscription(
|
||||||
|
expires_at=self.expires_at,
|
||||||
|
endpoint=parse_endpoint_id(self.endpoint_id),
|
||||||
|
properties=self.properties,
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_api_entity(self) -> TriggerProviderSubscriptionApiEntity:
|
||||||
|
return TriggerProviderSubscriptionApiEntity(
|
||||||
|
id=self.id,
|
||||||
|
name=self.name,
|
||||||
|
provider=self.provider_id,
|
||||||
|
endpoint=parse_endpoint_id(self.endpoint_id),
|
||||||
|
parameters=self.parameters,
|
||||||
|
properties=self.properties,
|
||||||
|
credential_type=CredentialType(self.credential_type),
|
||||||
|
credentials=self.credentials,
|
||||||
|
workflows_in_use=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# system level trigger oauth client params
|
||||||
|
class TriggerOAuthSystemClient(Base):
|
||||||
|
__tablename__ = "trigger_oauth_system_clients"
|
||||||
|
__table_args__ = (
|
||||||
|
sa.PrimaryKeyConstraint("id", name="trigger_oauth_system_client_pkey"),
|
||||||
|
sa.UniqueConstraint("plugin_id", "provider", name="trigger_oauth_system_client_plugin_id_provider_idx"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||||
|
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||||
|
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
# oauth params of the trigger provider
|
||||||
|
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime,
|
||||||
|
nullable=False,
|
||||||
|
server_default=func.current_timestamp(),
|
||||||
|
server_onupdate=func.current_timestamp(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# tenant level trigger oauth client params (client_id, client_secret, etc.)
|
||||||
|
class TriggerOAuthTenantClient(Base):
|
||||||
|
__tablename__ = "trigger_oauth_tenant_clients"
|
||||||
|
__table_args__ = (
|
||||||
|
sa.PrimaryKeyConstraint("id", name="trigger_oauth_tenant_client_pkey"),
|
||||||
|
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_trigger_oauth_tenant_client"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||||
|
# tenant id
|
||||||
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||||
|
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||||
|
# oauth params of the trigger provider
|
||||||
|
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime,
|
||||||
|
nullable=False,
|
||||||
|
server_default=func.current_timestamp(),
|
||||||
|
server_onupdate=func.current_timestamp(),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def oauth_params(self) -> dict:
|
||||||
|
return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum, StrEnum
|
from enum import Enum, StrEnum
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
@@ -289,6 +289,54 @@ class Workflow(Base):
|
|||||||
def features_dict(self) -> dict[str, Any]:
|
def features_dict(self) -> dict[str, Any]:
|
||||||
return json.loads(self.features) if self.features else {}
|
return json.loads(self.features) if self.features else {}
|
||||||
|
|
||||||
|
def walk_nodes(
|
||||||
|
self, specific_node_type: NodeType | None = None
|
||||||
|
) -> Generator[tuple[str, Mapping[str, Any]], None, None]:
|
||||||
|
"""
|
||||||
|
Walk through the workflow nodes, yield each node configuration.
|
||||||
|
|
||||||
|
Each node configuration is a tuple containing the node's id and the node's properties.
|
||||||
|
|
||||||
|
Node properties example:
|
||||||
|
{
|
||||||
|
"type": "llm",
|
||||||
|
"title": "LLM",
|
||||||
|
"desc": "",
|
||||||
|
"variables": [],
|
||||||
|
"model":
|
||||||
|
{
|
||||||
|
"provider": "langgenius/openai/openai",
|
||||||
|
"name": "gpt-4",
|
||||||
|
"mode": "chat",
|
||||||
|
"completion_params": { "temperature": 0.7 },
|
||||||
|
},
|
||||||
|
"prompt_template": [{ "role": "system", "text": "" }],
|
||||||
|
"context": { "enabled": false, "variable_selector": [] },
|
||||||
|
"vision": { "enabled": false },
|
||||||
|
"memory":
|
||||||
|
{
|
||||||
|
"window": { "enabled": false, "size": 10 },
|
||||||
|
"query_prompt_template": "{{#sys.query#}}\n\n{{#sys.files#}}",
|
||||||
|
"role_prefix": { "user": "", "assistant": "" },
|
||||||
|
},
|
||||||
|
"selected": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
For specific node type, refer to `core.workflow.nodes`
|
||||||
|
"""
|
||||||
|
graph_dict = self.graph_dict
|
||||||
|
if "nodes" not in graph_dict:
|
||||||
|
raise WorkflowDataError("nodes not found in workflow graph")
|
||||||
|
|
||||||
|
if specific_node_type:
|
||||||
|
yield from (
|
||||||
|
(node["id"], node["data"])
|
||||||
|
for node in graph_dict["nodes"]
|
||||||
|
if node["data"]["type"] == specific_node_type.value
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield from ((node["id"], node["data"]) for node in graph_dict["nodes"])
|
||||||
|
|
||||||
def user_input_form(self, to_old_structure: bool = False) -> list:
|
def user_input_form(self, to_old_structure: bool = False) -> list:
|
||||||
# get start node from graph
|
# get start node from graph
|
||||||
if not self.graph:
|
if not self.graph:
|
||||||
@@ -1263,3 +1311,320 @@ class WorkflowDraftVariable(Base):
|
|||||||
|
|
||||||
def is_system_variable_editable(name: str) -> bool:
|
def is_system_variable_editable(name: str) -> bool:
|
||||||
return name in _EDITABLE_SYSTEM_VARIABLE
|
return name in _EDITABLE_SYSTEM_VARIABLE
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowTriggerStatus(StrEnum):
|
||||||
|
"""Workflow Trigger Execution Status"""
|
||||||
|
|
||||||
|
PENDING = "pending"
|
||||||
|
QUEUED = "queued"
|
||||||
|
RUNNING = "running"
|
||||||
|
SUCCEEDED = "succeeded"
|
||||||
|
FAILED = "failed"
|
||||||
|
RATE_LIMITED = "rate_limited"
|
||||||
|
RETRYING = "retrying"
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowTriggerLog(Base):
|
||||||
|
"""
|
||||||
|
Workflow Trigger Log
|
||||||
|
|
||||||
|
Track async trigger workflow runs with re-invocation capability
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
- id (uuid) Trigger Log ID (used as workflow_trigger_log_id)
|
||||||
|
- tenant_id (uuid) Workspace ID
|
||||||
|
- app_id (uuid) App ID
|
||||||
|
- workflow_id (uuid) Workflow ID
|
||||||
|
- workflow_run_id (uuid) Optional - Associated workflow run ID when execution starts
|
||||||
|
- root_node_id (string) Optional - Custom starting node ID for workflow execution
|
||||||
|
- trigger_type (string) Type of trigger: webhook, schedule, plugin
|
||||||
|
- trigger_data (text) Full trigger data including inputs (JSON)
|
||||||
|
- inputs (text) Input parameters (JSON)
|
||||||
|
- outputs (text) Optional - Output content (JSON)
|
||||||
|
- status (string) Execution status
|
||||||
|
- error (text) Optional - Error message if failed
|
||||||
|
- queue_name (string) Celery queue used
|
||||||
|
- celery_task_id (string) Optional - Celery task ID for tracking
|
||||||
|
- retry_count (int) Number of retry attempts
|
||||||
|
- elapsed_time (float) Optional - Time consumption in seconds
|
||||||
|
- total_tokens (int) Optional - Total tokens used
|
||||||
|
- created_by_role (string) Creator role: account, end_user
|
||||||
|
- created_by (string) Creator ID
|
||||||
|
- created_at (timestamp) Creation time
|
||||||
|
- triggered_at (timestamp) Optional - When actually triggered
|
||||||
|
- finished_at (timestamp) Optional - Completion time
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "workflow_trigger_logs"
|
||||||
|
__table_args__ = (
|
||||||
|
sa.PrimaryKeyConstraint("id", name="workflow_trigger_log_pkey"),
|
||||||
|
sa.Index("workflow_trigger_log_tenant_app_idx", "tenant_id", "app_id"),
|
||||||
|
sa.Index("workflow_trigger_log_status_idx", "status"),
|
||||||
|
sa.Index("workflow_trigger_log_created_at_idx", "created_at"),
|
||||||
|
sa.Index("workflow_trigger_log_workflow_run_idx", "workflow_run_id"),
|
||||||
|
sa.Index("workflow_trigger_log_workflow_id_idx", "workflow_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||||
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
|
||||||
|
root_node_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||||
|
|
||||||
|
trigger_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
|
trigger_data: Mapped[str] = mapped_column(sa.Text, nullable=False) # Full TriggerData as JSON
|
||||||
|
inputs: Mapped[str] = mapped_column(sa.Text, nullable=False) # Just inputs for easy viewing
|
||||||
|
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||||
|
|
||||||
|
status: Mapped[str] = mapped_column(String(50), nullable=False, default=WorkflowTriggerStatus.PENDING)
|
||||||
|
error: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
|
||||||
|
|
||||||
|
queue_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
celery_task_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||||
|
retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||||
|
|
||||||
|
elapsed_time: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True)
|
||||||
|
total_tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True)
|
||||||
|
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
created_by: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
|
||||||
|
triggered_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||||
|
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def created_by_account(self):
|
||||||
|
created_by_role = CreatorUserRole(self.created_by_role)
|
||||||
|
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def created_by_end_user(self):
|
||||||
|
from models.model import EndUser
|
||||||
|
|
||||||
|
created_by_role = CreatorUserRole(self.created_by_role)
|
||||||
|
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""Convert to dictionary for API responses"""
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"tenant_id": self.tenant_id,
|
||||||
|
"app_id": self.app_id,
|
||||||
|
"workflow_id": self.workflow_id,
|
||||||
|
"workflow_run_id": self.workflow_run_id,
|
||||||
|
"trigger_type": self.trigger_type,
|
||||||
|
"trigger_data": json.loads(self.trigger_data),
|
||||||
|
"inputs": json.loads(self.inputs),
|
||||||
|
"outputs": json.loads(self.outputs) if self.outputs else None,
|
||||||
|
"status": self.status,
|
||||||
|
"error": self.error,
|
||||||
|
"queue_name": self.queue_name,
|
||||||
|
"celery_task_id": self.celery_task_id,
|
||||||
|
"retry_count": self.retry_count,
|
||||||
|
"elapsed_time": self.elapsed_time,
|
||||||
|
"total_tokens": self.total_tokens,
|
||||||
|
"created_by_role": self.created_by_role,
|
||||||
|
"created_by": self.created_by,
|
||||||
|
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||||
|
"triggered_at": self.triggered_at.isoformat() if self.triggered_at else None,
|
||||||
|
"finished_at": self.finished_at.isoformat() if self.finished_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowWebhookTrigger(Base):
|
||||||
|
"""
|
||||||
|
Workflow Webhook Trigger
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
- id (uuid) Primary key
|
||||||
|
- app_id (uuid) App ID to bind to a specific app
|
||||||
|
- node_id (varchar) Node ID which node in the workflow
|
||||||
|
- tenant_id (uuid) Workspace ID
|
||||||
|
- webhook_id (varchar) Webhook ID for URL: https://api.dify.ai/triggers/webhook/:webhook_id
|
||||||
|
- created_by (varchar) User ID of the creator
|
||||||
|
- created_at (timestamp) Creation time
|
||||||
|
- updated_at (timestamp) Last update time
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "workflow_webhook_triggers"
|
||||||
|
__table_args__ = (
|
||||||
|
sa.PrimaryKeyConstraint("id", name="workflow_webhook_trigger_pkey"),
|
||||||
|
sa.Index("workflow_webhook_trigger_tenant_idx", "tenant_id"),
|
||||||
|
sa.UniqueConstraint("app_id", "node_id", name="uniq_node"),
|
||||||
|
sa.UniqueConstraint("webhook_id", name="uniq_webhook_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||||
|
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
node_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
webhook_id: Mapped[str] = mapped_column(String(24), nullable=False)
|
||||||
|
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime,
|
||||||
|
nullable=False,
|
||||||
|
server_default=func.current_timestamp(),
|
||||||
|
server_onupdate=func.current_timestamp(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowPluginTrigger(Base):
|
||||||
|
"""
|
||||||
|
Workflow Plugin Trigger
|
||||||
|
|
||||||
|
Maps plugin triggers to workflow nodes, similar to WorkflowWebhookTrigger
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
- id (uuid) Primary key
|
||||||
|
- app_id (uuid) App ID to bind to a specific app
|
||||||
|
- node_id (varchar) Node ID which node in the workflow
|
||||||
|
- tenant_id (uuid) Workspace ID
|
||||||
|
- provider_id (varchar) Plugin provider ID
|
||||||
|
- trigger_name (varchar) trigger name
|
||||||
|
- subscription_id (varchar) Subscription ID
|
||||||
|
- created_at (timestamp) Creation time
|
||||||
|
- updated_at (timestamp) Last update time
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "workflow_plugin_triggers"
|
||||||
|
__table_args__ = (
|
||||||
|
sa.PrimaryKeyConstraint("id", name="workflow_plugin_trigger_pkey"),
|
||||||
|
sa.Index("workflow_plugin_trigger_tenant_subscription_idx", "tenant_id", "subscription_id", "trigger_name"),
|
||||||
|
sa.UniqueConstraint("app_id", "node_id", name="uniq_app_node_subscription"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||||
|
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
node_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
provider_id: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||||
|
trigger_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
subscription_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime,
|
||||||
|
nullable=False,
|
||||||
|
server_default=func.current_timestamp(),
|
||||||
|
server_onupdate=func.current_timestamp(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AppTriggerType(StrEnum):
|
||||||
|
"""App Trigger Type Enum"""
|
||||||
|
|
||||||
|
TRIGGER_WEBHOOK = "trigger-webhook"
|
||||||
|
TRIGGER_SCHEDULE = "trigger-schedule"
|
||||||
|
TRIGGER_PLUGIN = "trigger-plugin"
|
||||||
|
|
||||||
|
|
||||||
|
class AppTriggerStatus(StrEnum):
|
||||||
|
"""App Trigger Status Enum"""
|
||||||
|
|
||||||
|
ENABLED = "enabled"
|
||||||
|
DISABLED = "disabled"
|
||||||
|
UNAUTHORIZED = "unauthorized"
|
||||||
|
|
||||||
|
|
||||||
|
class AppTrigger(Base):
|
||||||
|
"""
|
||||||
|
App Trigger
|
||||||
|
|
||||||
|
Manages multiple triggers for an app with enable/disable and authorization states.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
- id (uuid) Primary key
|
||||||
|
- tenant_id (uuid) Workspace ID
|
||||||
|
- app_id (uuid) App ID
|
||||||
|
- trigger_type (string) Type: webhook, schedule, plugin
|
||||||
|
- title (string) Trigger title
|
||||||
|
|
||||||
|
- status (string) Status: enabled, disabled, unauthorized, error
|
||||||
|
- node_id (string) Optional workflow node ID
|
||||||
|
- created_at (timestamp) Creation time
|
||||||
|
- updated_at (timestamp) Last update time
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "app_triggers"
|
||||||
|
__table_args__ = (
|
||||||
|
sa.PrimaryKeyConstraint("id", name="app_trigger_pkey"),
|
||||||
|
sa.Index("app_trigger_tenant_app_idx", "tenant_id", "app_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||||
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
node_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=False)
|
||||||
|
trigger_type: Mapped[str] = mapped_column(EnumText(AppTriggerType, length=50), nullable=False)
|
||||||
|
title: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
provider_name: Mapped[str] = mapped_column(String(255), server_default="", nullable=True)
|
||||||
|
status: Mapped[str] = mapped_column(
|
||||||
|
EnumText(AppTriggerStatus, length=50), nullable=False, default=AppTriggerStatus.DISABLED
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime,
|
||||||
|
nullable=False,
|
||||||
|
default=naive_utc_now(),
|
||||||
|
server_onupdate=func.current_timestamp(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowSchedulePlan(Base):
|
||||||
|
"""
|
||||||
|
Workflow Schedule Configuration
|
||||||
|
|
||||||
|
Store schedule configurations for time-based workflow triggers.
|
||||||
|
Uses cron expressions with timezone support for flexible scheduling.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
- id (uuid) Primary key
|
||||||
|
- app_id (uuid) App ID to bind to a specific app
|
||||||
|
- node_id (varchar) Starting node ID for workflow execution
|
||||||
|
- tenant_id (uuid) Workspace ID for multi-tenancy
|
||||||
|
- cron_expression (varchar) Cron expression defining schedule pattern
|
||||||
|
- timezone (varchar) Timezone for cron evaluation (e.g., 'Asia/Shanghai')
|
||||||
|
- next_run_at (timestamp) Next scheduled execution time
|
||||||
|
- created_at (timestamp) Creation timestamp
|
||||||
|
- updated_at (timestamp) Last update timestamp
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "workflow_schedule_plans"
|
||||||
|
__table_args__ = (
|
||||||
|
sa.PrimaryKeyConstraint("id", name="workflow_schedule_plan_pkey"),
|
||||||
|
sa.UniqueConstraint("app_id", "node_id", name="uniq_app_node"),
|
||||||
|
sa.Index("workflow_schedule_plan_next_idx", "next_run_at"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
|
||||||
|
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
node_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
|
|
||||||
|
# Schedule configuration
|
||||||
|
cron_expression: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
timezone: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
|
||||||
|
# Schedule control
|
||||||
|
next_run_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""Convert to dictionary representation"""
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"app_id": self.app_id,
|
||||||
|
"node_id": self.node_id,
|
||||||
|
"tenant_id": self.tenant_id,
|
||||||
|
"cron_expression": self.cron_expression,
|
||||||
|
"timezone": self.timezone,
|
||||||
|
"next_run_at": self.next_run_at.isoformat() if self.next_run_at else None,
|
||||||
|
"created_at": self.created_at.isoformat(),
|
||||||
|
"updated_at": self.updated_at.isoformat(),
|
||||||
|
}
|
||||||
|
|||||||
@@ -88,6 +88,7 @@ dependencies = [
|
|||||||
"httpx-sse>=0.4.0",
|
"httpx-sse>=0.4.0",
|
||||||
"sendgrid~=6.12.3",
|
"sendgrid~=6.12.3",
|
||||||
"flask-restx>=1.3.0",
|
"flask-restx>=1.3.0",
|
||||||
|
"croniter>=6.0.0",
|
||||||
]
|
]
|
||||||
# Before adding new dependency, consider place it in
|
# Before adding new dependency, consider place it in
|
||||||
# alphabet order (a-z) and suitable group.
|
# alphabet order (a-z) and suitable group.
|
||||||
|
|||||||
198
api/repositories/sqlalchemy_workflow_trigger_log_repository.py
Normal file
198
api/repositories/sqlalchemy_workflow_trigger_log_repository.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
"""
|
||||||
|
SQLAlchemy implementation of WorkflowTriggerLogRepository.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import and_, delete, func, select, update
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from models.workflow import WorkflowTriggerLog, WorkflowTriggerStatus
|
||||||
|
from repositories.workflow_trigger_log_repository import TriggerLogOrderBy, WorkflowTriggerLogRepository
|
||||||
|
|
||||||
|
|
||||||
|
class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
|
||||||
|
"""
|
||||||
|
SQLAlchemy implementation of WorkflowTriggerLogRepository.
|
||||||
|
|
||||||
|
Optimized for large table operations with proper indexing and batch processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, session: Session):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
def create(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
|
||||||
|
"""Create a new trigger log entry."""
|
||||||
|
self.session.add(trigger_log)
|
||||||
|
self.session.flush()
|
||||||
|
return trigger_log
|
||||||
|
|
||||||
|
def update(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
|
||||||
|
"""Update an existing trigger log entry."""
|
||||||
|
self.session.merge(trigger_log)
|
||||||
|
self.session.flush()
|
||||||
|
return trigger_log
|
||||||
|
|
||||||
|
def get_by_id(self, trigger_log_id: str, tenant_id: Optional[str] = None) -> Optional[WorkflowTriggerLog]:
|
||||||
|
"""Get a trigger log by its ID."""
|
||||||
|
query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.id == trigger_log_id)
|
||||||
|
|
||||||
|
if tenant_id:
|
||||||
|
query = query.where(WorkflowTriggerLog.tenant_id == tenant_id)
|
||||||
|
|
||||||
|
return self.session.scalar(query)
|
||||||
|
|
||||||
|
def get_by_status(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
status: WorkflowTriggerStatus,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
order_by: TriggerLogOrderBy = TriggerLogOrderBy.CREATED_AT,
|
||||||
|
order_desc: bool = True,
|
||||||
|
) -> Sequence[WorkflowTriggerLog]:
|
||||||
|
"""Get trigger logs by status with pagination."""
|
||||||
|
query = select(WorkflowTriggerLog).where(
|
||||||
|
and_(
|
||||||
|
WorkflowTriggerLog.tenant_id == tenant_id,
|
||||||
|
WorkflowTriggerLog.app_id == app_id,
|
||||||
|
WorkflowTriggerLog.status == status,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply ordering
|
||||||
|
order_column = getattr(WorkflowTriggerLog, order_by.value)
|
||||||
|
if order_desc:
|
||||||
|
query = query.order_by(order_column.desc())
|
||||||
|
else:
|
||||||
|
query = query.order_by(order_column.asc())
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
query = query.limit(limit).offset(offset)
|
||||||
|
|
||||||
|
return list(self.session.scalars(query).all())
|
||||||
|
|
||||||
|
def get_failed_for_retry(
|
||||||
|
self, tenant_id: str, max_retry_count: int = 3, limit: int = 100
|
||||||
|
) -> Sequence[WorkflowTriggerLog]:
|
||||||
|
"""Get failed trigger logs eligible for retry."""
|
||||||
|
query = (
|
||||||
|
select(WorkflowTriggerLog)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
WorkflowTriggerLog.tenant_id == tenant_id,
|
||||||
|
WorkflowTriggerLog.status.in_([WorkflowTriggerStatus.FAILED, WorkflowTriggerStatus.RATE_LIMITED]),
|
||||||
|
WorkflowTriggerLog.retry_count < max_retry_count,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.order_by(WorkflowTriggerLog.created_at.asc())
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
|
||||||
|
return list(self.session.scalars(query).all())
|
||||||
|
|
||||||
|
def get_recent_logs(
|
||||||
|
self, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
|
||||||
|
) -> Sequence[WorkflowTriggerLog]:
|
||||||
|
"""Get recent trigger logs within specified hours."""
|
||||||
|
since = datetime.utcnow() - timedelta(hours=hours)
|
||||||
|
|
||||||
|
query = (
|
||||||
|
select(WorkflowTriggerLog)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
WorkflowTriggerLog.tenant_id == tenant_id,
|
||||||
|
WorkflowTriggerLog.app_id == app_id,
|
||||||
|
WorkflowTriggerLog.created_at >= since,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.order_by(WorkflowTriggerLog.created_at.desc())
|
||||||
|
.limit(limit)
|
||||||
|
.offset(offset)
|
||||||
|
)
|
||||||
|
|
||||||
|
return list(self.session.scalars(query).all())
|
||||||
|
|
||||||
|
def count_by_status(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
status: Optional[WorkflowTriggerStatus] = None,
|
||||||
|
since: Optional[datetime] = None,
|
||||||
|
) -> int:
|
||||||
|
"""Count trigger logs by status."""
|
||||||
|
query = select(func.count(WorkflowTriggerLog.id)).where(
|
||||||
|
and_(WorkflowTriggerLog.tenant_id == tenant_id, WorkflowTriggerLog.app_id == app_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
if status:
|
||||||
|
query = query.where(WorkflowTriggerLog.status == status)
|
||||||
|
|
||||||
|
if since:
|
||||||
|
query = query.where(WorkflowTriggerLog.created_at >= since)
|
||||||
|
|
||||||
|
return self.session.scalar(query) or 0
|
||||||
|
|
||||||
|
def delete_expired_logs(self, tenant_id: str, before_date: datetime, batch_size: int = 1000) -> int:
|
||||||
|
"""Delete expired trigger logs in batches."""
|
||||||
|
total_deleted = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# Get batch of IDs to delete
|
||||||
|
subquery = (
|
||||||
|
select(WorkflowTriggerLog.id)
|
||||||
|
.where(and_(WorkflowTriggerLog.tenant_id == tenant_id, WorkflowTriggerLog.created_at < before_date))
|
||||||
|
.limit(batch_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete the batch
|
||||||
|
result = self.session.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.id.in_(subquery)))
|
||||||
|
|
||||||
|
deleted = result.rowcount
|
||||||
|
total_deleted += deleted
|
||||||
|
|
||||||
|
if deleted < batch_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
return total_deleted
|
||||||
|
|
||||||
|
def archive_completed_logs(
|
||||||
|
self, tenant_id: str, before_date: datetime, batch_size: int = 1000
|
||||||
|
) -> Sequence[WorkflowTriggerLog]:
|
||||||
|
"""Get completed logs for archival."""
|
||||||
|
query = (
|
||||||
|
select(WorkflowTriggerLog)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
WorkflowTriggerLog.tenant_id == tenant_id,
|
||||||
|
WorkflowTriggerLog.status == WorkflowTriggerStatus.SUCCEEDED,
|
||||||
|
WorkflowTriggerLog.finished_at < before_date,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.limit(batch_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
return list(self.session.scalars(query).all())
|
||||||
|
|
||||||
|
def update_status_batch(
|
||||||
|
self, trigger_log_ids: Sequence[str], new_status: WorkflowTriggerStatus, error_message: Optional[str] = None
|
||||||
|
) -> int:
|
||||||
|
"""Update status for multiple trigger logs."""
|
||||||
|
update_data: dict[str, Any] = {"status": new_status}
|
||||||
|
|
||||||
|
if error_message is not None:
|
||||||
|
update_data["error"] = error_message
|
||||||
|
|
||||||
|
if new_status in [WorkflowTriggerStatus.SUCCEEDED, WorkflowTriggerStatus.FAILED]:
|
||||||
|
update_data["finished_at"] = datetime.utcnow()
|
||||||
|
|
||||||
|
result = self.session.execute(
|
||||||
|
update(WorkflowTriggerLog).where(WorkflowTriggerLog.id.in_(trigger_log_ids)).values(**update_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
return result.rowcount
|
||||||
206
api/repositories/workflow_trigger_log_repository.py
Normal file
206
api/repositories/workflow_trigger_log_repository.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
"""
|
||||||
|
Repository protocol for WorkflowTriggerLog operations.
|
||||||
|
|
||||||
|
This module provides a protocol interface for operations on WorkflowTriggerLog,
|
||||||
|
designed to efficiently handle a potentially large volume of trigger logs with
|
||||||
|
proper indexing and batch operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Optional, Protocol
|
||||||
|
|
||||||
|
from models.workflow import WorkflowTriggerLog, WorkflowTriggerStatus
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerLogOrderBy(StrEnum):
|
||||||
|
"""Fields available for ordering trigger logs"""
|
||||||
|
|
||||||
|
CREATED_AT = "created_at"
|
||||||
|
TRIGGERED_AT = "triggered_at"
|
||||||
|
FINISHED_AT = "finished_at"
|
||||||
|
STATUS = "status"
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowTriggerLogRepository(Protocol):
|
||||||
|
"""
|
||||||
|
Protocol for operations on WorkflowTriggerLog.
|
||||||
|
|
||||||
|
This repository provides efficient access patterns for the trigger log table,
|
||||||
|
which is expected to grow large over time. It includes:
|
||||||
|
- Batch operations for cleanup
|
||||||
|
- Efficient queries with proper indexing
|
||||||
|
- Pagination support
|
||||||
|
- Status-based filtering
|
||||||
|
|
||||||
|
Implementation notes:
|
||||||
|
- Leverage database indexes on (tenant_id, app_id), status, and created_at
|
||||||
|
- Use batch operations for deletions to avoid locking
|
||||||
|
- Support pagination for large result sets
|
||||||
|
"""
|
||||||
|
|
||||||
|
def create(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
|
||||||
|
"""
|
||||||
|
Create a new trigger log entry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trigger_log: The WorkflowTriggerLog instance to create
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created WorkflowTriggerLog with generated ID
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def update(self, trigger_log: WorkflowTriggerLog) -> WorkflowTriggerLog:
|
||||||
|
"""
|
||||||
|
Update an existing trigger log entry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trigger_log: The WorkflowTriggerLog instance to update
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated WorkflowTriggerLog
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_by_id(self, trigger_log_id: str, tenant_id: Optional[str] = None) -> Optional[WorkflowTriggerLog]:
|
||||||
|
"""
|
||||||
|
Get a trigger log by its ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trigger_log_id: The trigger log identifier
|
||||||
|
tenant_id: Optional tenant identifier for additional security
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The WorkflowTriggerLog if found, None otherwise
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_by_status(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
status: WorkflowTriggerStatus,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
order_by: TriggerLogOrderBy = TriggerLogOrderBy.CREATED_AT,
|
||||||
|
order_desc: bool = True,
|
||||||
|
) -> Sequence[WorkflowTriggerLog]:
|
||||||
|
"""
|
||||||
|
Get trigger logs by status with pagination.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: The tenant identifier
|
||||||
|
app_id: The application identifier
|
||||||
|
status: The workflow trigger status to filter by
|
||||||
|
limit: Maximum number of results
|
||||||
|
offset: Number of results to skip
|
||||||
|
order_by: Field to order results by
|
||||||
|
order_desc: Whether to order descending (True) or ascending (False)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sequence of WorkflowTriggerLog instances
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_failed_for_retry(
|
||||||
|
self, tenant_id: str, max_retry_count: int = 3, limit: int = 100
|
||||||
|
) -> Sequence[WorkflowTriggerLog]:
|
||||||
|
"""
|
||||||
|
Get failed trigger logs that are eligible for retry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: The tenant identifier
|
||||||
|
max_retry_count: Maximum retry count to consider
|
||||||
|
limit: Maximum number of results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sequence of WorkflowTriggerLog instances eligible for retry
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_recent_logs(
|
||||||
|
self, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
|
||||||
|
) -> Sequence[WorkflowTriggerLog]:
|
||||||
|
"""
|
||||||
|
Get recent trigger logs within specified hours.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: The tenant identifier
|
||||||
|
app_id: The application identifier
|
||||||
|
hours: Number of hours to look back
|
||||||
|
limit: Maximum number of results
|
||||||
|
offset: Number of results to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sequence of recent WorkflowTriggerLog instances
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def count_by_status(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
status: Optional[WorkflowTriggerStatus] = None,
|
||||||
|
since: Optional[datetime] = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Count trigger logs by status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: The tenant identifier
|
||||||
|
app_id: The application identifier
|
||||||
|
status: Optional status filter
|
||||||
|
since: Optional datetime to count from
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Count of matching trigger logs
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def delete_expired_logs(self, tenant_id: str, before_date: datetime, batch_size: int = 1000) -> int:
|
||||||
|
"""
|
||||||
|
Delete expired trigger logs in batches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: The tenant identifier
|
||||||
|
before_date: Delete logs created before this date
|
||||||
|
batch_size: Number of logs to delete per batch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total number of logs deleted
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def archive_completed_logs(
|
||||||
|
self, tenant_id: str, before_date: datetime, batch_size: int = 1000
|
||||||
|
) -> Sequence[WorkflowTriggerLog]:
|
||||||
|
"""
|
||||||
|
Get completed logs for archival before deletion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: The tenant identifier
|
||||||
|
before_date: Get logs completed before this date
|
||||||
|
batch_size: Number of logs to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sequence of WorkflowTriggerLog instances for archival
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def update_status_batch(
|
||||||
|
self, trigger_log_ids: Sequence[str], new_status: WorkflowTriggerStatus, error_message: Optional[str] = None
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Update status for multiple trigger logs at once.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trigger_log_ids: List of trigger log IDs to update
|
||||||
|
new_status: The new status to set
|
||||||
|
error_message: Optional error message to set
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of logs updated
|
||||||
|
"""
|
||||||
|
...
|
||||||
127
api/schedule/workflow_schedule_task.py
Normal file
127
api/schedule/workflow_schedule_task.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from celery import group, shared_task
|
||||||
|
from sqlalchemy import and_, select
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.datetime_utils import naive_utc_now
|
||||||
|
from libs.schedule_utils import calculate_next_run_at
|
||||||
|
from models.workflow import AppTrigger, AppTriggerStatus, AppTriggerType, WorkflowSchedulePlan
|
||||||
|
from services.workflow.queue_dispatcher import QueueDispatcherManager
|
||||||
|
from tasks.workflow_schedule_tasks import run_schedule_trigger
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task(queue="schedule_poller")
|
||||||
|
def poll_workflow_schedules() -> None:
|
||||||
|
"""
|
||||||
|
Poll and process due workflow schedules.
|
||||||
|
|
||||||
|
Streaming flow:
|
||||||
|
1. Fetch due schedules in batches
|
||||||
|
2. Process each batch until all due schedules are handled
|
||||||
|
3. Optional: Limit total dispatches per tick as a circuit breaker
|
||||||
|
"""
|
||||||
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
with session_factory() as session:
|
||||||
|
total_dispatched = 0
|
||||||
|
total_rate_limited = 0
|
||||||
|
|
||||||
|
# Process in batches until we've handled all due schedules or hit the limit
|
||||||
|
while True:
|
||||||
|
due_schedules = _fetch_due_schedules(session)
|
||||||
|
|
||||||
|
if not due_schedules:
|
||||||
|
break
|
||||||
|
|
||||||
|
dispatched_count, rate_limited_count = _process_schedules(session, due_schedules)
|
||||||
|
total_dispatched += dispatched_count
|
||||||
|
total_rate_limited += rate_limited_count
|
||||||
|
|
||||||
|
logger.debug("Batch processed: %d dispatched, %d rate limited", dispatched_count, rate_limited_count)
|
||||||
|
|
||||||
|
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
|
||||||
|
if (
|
||||||
|
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK > 0
|
||||||
|
and total_dispatched >= dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
|
||||||
|
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
if total_dispatched > 0 or total_rate_limited > 0:
|
||||||
|
logger.info("Total processed: %d dispatched, %d rate limited", total_dispatched, total_rate_limited)
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
|
||||||
|
"""
|
||||||
|
Fetch a batch of due schedules, sorted by most overdue first.
|
||||||
|
|
||||||
|
Returns up to WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE schedules per call.
|
||||||
|
Used in a loop to progressively process all due schedules.
|
||||||
|
"""
|
||||||
|
now = naive_utc_now()
|
||||||
|
|
||||||
|
due_schedules = session.scalars(
|
||||||
|
(
|
||||||
|
select(WorkflowSchedulePlan)
|
||||||
|
.join(
|
||||||
|
AppTrigger,
|
||||||
|
and_(
|
||||||
|
AppTrigger.app_id == WorkflowSchedulePlan.app_id,
|
||||||
|
AppTrigger.node_id == WorkflowSchedulePlan.node_id,
|
||||||
|
AppTrigger.trigger_type == AppTriggerType.TRIGGER_SCHEDULE,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
WorkflowSchedulePlan.next_run_at <= now,
|
||||||
|
WorkflowSchedulePlan.next_run_at.isnot(None),
|
||||||
|
AppTrigger.status == AppTriggerStatus.ENABLED,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.order_by(WorkflowSchedulePlan.next_run_at.asc())
|
||||||
|
.with_for_update(skip_locked=True)
|
||||||
|
.limit(dify_config.WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE)
|
||||||
|
)
|
||||||
|
|
||||||
|
return list(due_schedules)
|
||||||
|
|
||||||
|
|
||||||
|
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) -> tuple[int, int]:
|
||||||
|
"""Process schedules: check quota, update next run time and dispatch to Celery in parallel."""
|
||||||
|
if not schedules:
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
|
dispatcher_manager = QueueDispatcherManager()
|
||||||
|
tasks_to_dispatch = []
|
||||||
|
rate_limited_count = 0
|
||||||
|
|
||||||
|
for schedule in schedules:
|
||||||
|
next_run_at = calculate_next_run_at(
|
||||||
|
schedule.cron_expression,
|
||||||
|
schedule.timezone,
|
||||||
|
)
|
||||||
|
schedule.next_run_at = next_run_at
|
||||||
|
|
||||||
|
dispatcher = dispatcher_manager.get_dispatcher(schedule.tenant_id)
|
||||||
|
if not dispatcher.check_daily_quota(schedule.tenant_id):
|
||||||
|
logger.info("Tenant %s rate limited, skipping schedule_plan %s", schedule.tenant_id, schedule.id)
|
||||||
|
rate_limited_count += 1
|
||||||
|
else:
|
||||||
|
tasks_to_dispatch.append(schedule.id)
|
||||||
|
|
||||||
|
if tasks_to_dispatch:
|
||||||
|
job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch)
|
||||||
|
job.apply_async()
|
||||||
|
|
||||||
|
logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch))
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return len(tasks_to_dispatch), rate_limited_count
|
||||||
@@ -26,6 +26,7 @@ from core.workflow.nodes.llm.entities import LLMNodeData
|
|||||||
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||||
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
|
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||||
|
from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode
|
||||||
from events.app_event import app_model_config_was_updated, app_was_created
|
from events.app_event import app_model_config_was_updated, app_was_created
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from factories import variable_factory
|
from factories import variable_factory
|
||||||
@@ -595,6 +596,9 @@ class AppDslService:
|
|||||||
if not include_secret and data_type == NodeType.AGENT.value:
|
if not include_secret and data_type == NodeType.AGENT.value:
|
||||||
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
|
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
|
||||||
tool.pop("credential_id", None)
|
tool.pop("credential_id", None)
|
||||||
|
if data_type == NodeType.TRIGGER_SCHEDULE.value:
|
||||||
|
# override the config with the default config
|
||||||
|
node_data["config"] = TriggerScheduleNode.get_default_config()["config"]
|
||||||
|
|
||||||
export_data["workflow"] = workflow_dict
|
export_data["workflow"] = workflow_dict
|
||||||
dependencies = cls._extract_dependencies_from_workflow(workflow)
|
dependencies = cls._extract_dependencies_from_workflow(workflow)
|
||||||
|
|||||||
320
api/services/async_workflow_service.py
Normal file
320
api/services/async_workflow_service.py
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
"""
|
||||||
|
Universal async workflow execution service.
|
||||||
|
|
||||||
|
This service provides a centralized entry point for triggering workflows asynchronously
|
||||||
|
with support for different subscription tiers, rate limiting, and execution tracking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from celery.result import AsyncResult
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from models.account import Account
|
||||||
|
from models.enums import CreatorUserRole
|
||||||
|
from models.model import App, EndUser
|
||||||
|
from models.workflow import Workflow, WorkflowTriggerLog, WorkflowTriggerStatus
|
||||||
|
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||||
|
from services.errors.app import InvokeDailyRateLimitError, WorkflowNotFoundError
|
||||||
|
from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData
|
||||||
|
from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority
|
||||||
|
from services.workflow.rate_limiter import TenantDailyRateLimiter
|
||||||
|
from services.workflow_service import WorkflowService
|
||||||
|
from tasks.async_workflow_tasks import (
|
||||||
|
execute_workflow_professional,
|
||||||
|
execute_workflow_sandbox,
|
||||||
|
execute_workflow_team,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncWorkflowService:
|
||||||
|
"""
|
||||||
|
Universal entry point for async workflow execution - ALL METHODS ARE NON-BLOCKING
|
||||||
|
|
||||||
|
This service handles:
|
||||||
|
- Trigger data validation and processing
|
||||||
|
- Queue routing based on subscription tier
|
||||||
|
- Daily rate limiting with timezone support
|
||||||
|
- Execution tracking and logging
|
||||||
|
- Retry mechanisms for failed executions
|
||||||
|
|
||||||
|
Important: All trigger methods return immediately after queuing tasks.
|
||||||
|
Actual workflow execution happens asynchronously in background Celery workers.
|
||||||
|
Use trigger log IDs to monitor execution status and results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def trigger_workflow_async(
|
||||||
|
cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData
|
||||||
|
) -> AsyncTriggerResponse:
|
||||||
|
"""
|
||||||
|
Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK
|
||||||
|
|
||||||
|
Creates a trigger log and dispatches to appropriate queue based on subscription tier.
|
||||||
|
The workflow execution happens asynchronously in the background via Celery workers.
|
||||||
|
This method returns immediately after queuing the task, not after execution completion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session to use for operations
|
||||||
|
user: User (Account or EndUser) who initiated the workflow trigger
|
||||||
|
trigger_data: Validated Pydantic model containing trigger information
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncTriggerResponse with workflow_trigger_log_id, task_id, status="queued", and queue
|
||||||
|
Note: The actual workflow execution status must be checked separately via workflow_trigger_log_id
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
WorkflowNotFoundError: If app or workflow not found
|
||||||
|
InvokeDailyRateLimitError: If daily rate limit exceeded
|
||||||
|
|
||||||
|
Behavior:
|
||||||
|
- Non-blocking: Returns immediately after queuing
|
||||||
|
- Asynchronous: Actual execution happens in background Celery workers
|
||||||
|
- Status tracking: Use workflow_trigger_log_id to monitor progress
|
||||||
|
- Queue-based: Routes to different queues based on subscription tier
|
||||||
|
"""
|
||||||
|
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||||
|
dispatcher_manager = QueueDispatcherManager()
|
||||||
|
workflow_service = WorkflowService()
|
||||||
|
rate_limiter = TenantDailyRateLimiter(redis_client)
|
||||||
|
|
||||||
|
# 1. Validate app exists
|
||||||
|
app_model = session.scalar(select(App).where(App.id == trigger_data.app_id))
|
||||||
|
if not app_model:
|
||||||
|
raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}")
|
||||||
|
|
||||||
|
# 2. Get workflow
|
||||||
|
workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id)
|
||||||
|
|
||||||
|
# 3. Get dispatcher based on tenant subscription
|
||||||
|
dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id)
|
||||||
|
|
||||||
|
# 4. Rate limiting check will be done without timezone first
|
||||||
|
|
||||||
|
# 5. Determine user role and ID
|
||||||
|
if isinstance(user, Account):
|
||||||
|
created_by_role = CreatorUserRole.ACCOUNT
|
||||||
|
created_by = user.id
|
||||||
|
else: # EndUser
|
||||||
|
created_by_role = CreatorUserRole.END_USER
|
||||||
|
created_by = user.id
|
||||||
|
|
||||||
|
# 6. Create trigger log entry first (for tracking)
|
||||||
|
trigger_log = WorkflowTriggerLog(
|
||||||
|
tenant_id=trigger_data.tenant_id,
|
||||||
|
app_id=trigger_data.app_id,
|
||||||
|
workflow_id=workflow.id,
|
||||||
|
root_node_id=trigger_data.root_node_id,
|
||||||
|
trigger_type=trigger_data.trigger_type,
|
||||||
|
trigger_data=trigger_data.model_dump_json(),
|
||||||
|
inputs=json.dumps(dict(trigger_data.inputs)),
|
||||||
|
status=WorkflowTriggerStatus.PENDING,
|
||||||
|
queue_name=dispatcher.get_queue_name(),
|
||||||
|
retry_count=0,
|
||||||
|
created_by_role=created_by_role,
|
||||||
|
created_by=created_by,
|
||||||
|
)
|
||||||
|
|
||||||
|
trigger_log = trigger_log_repo.create(trigger_log)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# 7. Check and consume daily quota
|
||||||
|
if not dispatcher.consume_quota(trigger_data.tenant_id):
|
||||||
|
# Update trigger log status
|
||||||
|
trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED
|
||||||
|
trigger_log.error = f"Daily limit reached for {dispatcher.get_queue_name()}"
|
||||||
|
trigger_log_repo.update(trigger_log)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
tenant_owner_tz = rate_limiter._get_tenant_owner_timezone(trigger_data.tenant_id)
|
||||||
|
|
||||||
|
remaining = rate_limiter.get_remaining_quota(trigger_data.tenant_id, dispatcher.get_daily_limit())
|
||||||
|
|
||||||
|
reset_time = rate_limiter.get_quota_reset_time(trigger_data.tenant_id, tenant_owner_tz)
|
||||||
|
|
||||||
|
raise InvokeDailyRateLimitError(
|
||||||
|
f"Daily workflow execution limit reached. "
|
||||||
|
f"Limit resets at {reset_time.strftime('%Y-%m-%d %H:%M:%S %Z')}. "
|
||||||
|
f"Remaining quota: {remaining}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. Create task data
|
||||||
|
queue_name = dispatcher.get_queue_name()
|
||||||
|
|
||||||
|
task_data = WorkflowTaskData(workflow_trigger_log_id=trigger_log.id)
|
||||||
|
|
||||||
|
# 9. Dispatch to appropriate queue
|
||||||
|
task_data_dict = task_data.model_dump(mode="json")
|
||||||
|
|
||||||
|
task: AsyncResult | None = None
|
||||||
|
if queue_name == QueuePriority.PROFESSIONAL:
|
||||||
|
task = execute_workflow_professional.delay(task_data_dict) # type: ignore
|
||||||
|
elif queue_name == QueuePriority.TEAM:
|
||||||
|
task = execute_workflow_team.delay(task_data_dict) # type: ignore
|
||||||
|
else: # SANDBOX
|
||||||
|
task = execute_workflow_sandbox.delay(task_data_dict) # type: ignore
|
||||||
|
|
||||||
|
if not task:
|
||||||
|
raise ValueError(f"Failed to queue task for queue: {queue_name}")
|
||||||
|
|
||||||
|
# 10. Update trigger log with task info
|
||||||
|
trigger_log.status = WorkflowTriggerStatus.QUEUED
|
||||||
|
trigger_log.celery_task_id = task.id
|
||||||
|
trigger_log.triggered_at = datetime.now(UTC)
|
||||||
|
trigger_log_repo.update(trigger_log)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return AsyncTriggerResponse(
|
||||||
|
workflow_trigger_log_id=trigger_log.id,
|
||||||
|
task_id=task.id, # type: ignore
|
||||||
|
status="queued",
|
||||||
|
queue=queue_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reinvoke_trigger(
|
||||||
|
cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str
|
||||||
|
) -> AsyncTriggerResponse:
|
||||||
|
"""
|
||||||
|
Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK
|
||||||
|
|
||||||
|
Updates the existing trigger log to retry status and creates a new async execution.
|
||||||
|
Returns immediately after queuing the retry, not after execution completion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session to use for operations
|
||||||
|
user: User (Account or EndUser) who initiated the retry
|
||||||
|
workflow_trigger_log_id: ID of the trigger log to re-invoke
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncTriggerResponse with new execution information (status="queued")
|
||||||
|
Note: This creates a new trigger log entry for the retry attempt
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If trigger log not found
|
||||||
|
|
||||||
|
Behavior:
|
||||||
|
- Non-blocking: Returns immediately after queuing retry
|
||||||
|
- Creates new trigger log: Original log marked as retrying, new log for execution
|
||||||
|
- Preserves original trigger data: Uses same inputs and configuration
|
||||||
|
"""
|
||||||
|
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||||
|
|
||||||
|
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id)
|
||||||
|
|
||||||
|
if not trigger_log:
|
||||||
|
raise ValueError(f"Trigger log not found: {workflow_trigger_log_id}")
|
||||||
|
|
||||||
|
# Reconstruct trigger data from log
|
||||||
|
trigger_data = TriggerData.model_validate_json(trigger_log.trigger_data)
|
||||||
|
|
||||||
|
# Reset log for retry
|
||||||
|
trigger_log.status = WorkflowTriggerStatus.RETRYING
|
||||||
|
trigger_log.retry_count += 1
|
||||||
|
trigger_log.error = None
|
||||||
|
trigger_log.triggered_at = datetime.now(UTC)
|
||||||
|
trigger_log_repo.update(trigger_log)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Re-trigger workflow (this will create a new trigger log)
|
||||||
|
return cls.trigger_workflow_async(session, user, trigger_data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_trigger_log(cls, workflow_trigger_log_id: str, tenant_id: Optional[str] = None) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Get trigger log by ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_trigger_log_id: ID of the trigger log
|
||||||
|
tenant_id: Optional tenant ID for security check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Trigger log as dictionary or None if not found
|
||||||
|
"""
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||||
|
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id)
|
||||||
|
|
||||||
|
if not trigger_log:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return trigger_log.to_dict()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_recent_logs(
|
||||||
|
cls, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
|
||||||
|
) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Get recent trigger logs
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant ID
|
||||||
|
app_id: Application ID
|
||||||
|
hours: Number of hours to look back
|
||||||
|
limit: Maximum number of results
|
||||||
|
offset: Number of results to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of trigger logs as dictionaries
|
||||||
|
"""
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||||
|
logs = trigger_log_repo.get_recent_logs(
|
||||||
|
tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset
|
||||||
|
)
|
||||||
|
|
||||||
|
return [log.to_dict() for log in logs]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_failed_logs_for_retry(cls, tenant_id: str, max_retry_count: int = 3, limit: int = 100) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Get failed logs eligible for retry
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: Tenant ID
|
||||||
|
max_retry_count: Maximum retry count
|
||||||
|
limit: Maximum number of results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of failed trigger logs as dictionaries
|
||||||
|
"""
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||||
|
logs = trigger_log_repo.get_failed_for_retry(
|
||||||
|
tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit
|
||||||
|
)
|
||||||
|
|
||||||
|
return [log.to_dict() for log in logs]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: Optional[str] = None) -> Workflow:
|
||||||
|
"""
|
||||||
|
Get workflow for the app
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_model: App model instance
|
||||||
|
workflow_id: Optional specific workflow ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Workflow instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
WorkflowNotFoundError: If workflow not found
|
||||||
|
"""
|
||||||
|
if workflow_id:
|
||||||
|
# Get specific published workflow
|
||||||
|
workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id)
|
||||||
|
if not workflow:
|
||||||
|
raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}")
|
||||||
|
else:
|
||||||
|
# Get default published workflow
|
||||||
|
workflow = workflow_service.get_published_workflow(app_model)
|
||||||
|
if not workflow:
|
||||||
|
raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}")
|
||||||
|
|
||||||
|
return workflow
|
||||||
@@ -16,3 +16,9 @@ class WorkflowNotFoundError(Exception):
|
|||||||
|
|
||||||
class WorkflowIdFormatError(Exception):
|
class WorkflowIdFormatError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvokeDailyRateLimitError(Exception):
|
||||||
|
"""Raised when daily rate limit is exceeded for workflow invocations."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class OAuthProxyService(BasePluginClient):
|
|||||||
__KEY_PREFIX__ = "oauth_proxy_context:"
|
__KEY_PREFIX__ = "oauth_proxy_context:"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str):
|
def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str, extra_data: dict = {}):
|
||||||
"""
|
"""
|
||||||
Create a proxy context for an OAuth 2.0 authorization request.
|
Create a proxy context for an OAuth 2.0 authorization request.
|
||||||
|
|
||||||
@@ -26,6 +26,7 @@ class OAuthProxyService(BasePluginClient):
|
|||||||
"""
|
"""
|
||||||
context_id = str(uuid.uuid4())
|
context_id = str(uuid.uuid4())
|
||||||
data = {
|
data = {
|
||||||
|
**extra_data,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"plugin_id": plugin_id,
|
"plugin_id": plugin_id,
|
||||||
"tenant_id": tenant_id,
|
"tenant_id": tenant_id,
|
||||||
|
|||||||
@@ -4,11 +4,18 @@ from typing import Any, Literal
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.plugin.entities.parameters import PluginParameterOption
|
from core.plugin.entities.parameters import PluginParameterOption
|
||||||
|
from core.plugin.entities.plugin import TriggerProviderID
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
from core.plugin.impl.dynamic_select import DynamicSelectClient
|
from core.plugin.impl.dynamic_select import DynamicSelectClient
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.tools.utils.encryption import create_tool_provider_encrypter
|
from core.tools.utils.encryption import create_tool_provider_encrypter
|
||||||
|
from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
|
||||||
|
from core.trigger.entities.entities import SubscriptionBuilder
|
||||||
|
from core.trigger.trigger_manager import TriggerManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import BuiltinToolProvider
|
from models.tools import BuiltinToolProvider
|
||||||
|
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||||
|
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||||
|
|
||||||
|
|
||||||
class PluginParameterService:
|
class PluginParameterService:
|
||||||
@@ -20,7 +27,8 @@ class PluginParameterService:
|
|||||||
provider: str,
|
provider: str,
|
||||||
action: str,
|
action: str,
|
||||||
parameter: str,
|
parameter: str,
|
||||||
provider_type: Literal["tool"],
|
credential_id: str | None,
|
||||||
|
provider_type: Literal["tool", "trigger"],
|
||||||
) -> Sequence[PluginParameterOption]:
|
) -> Sequence[PluginParameterOption]:
|
||||||
"""
|
"""
|
||||||
Get dynamic select options for a plugin parameter.
|
Get dynamic select options for a plugin parameter.
|
||||||
@@ -33,7 +41,7 @@ class PluginParameterService:
|
|||||||
parameter: The parameter name.
|
parameter: The parameter name.
|
||||||
"""
|
"""
|
||||||
credentials: Mapping[str, Any] = {}
|
credentials: Mapping[str, Any] = {}
|
||||||
|
credential_type: str = CredentialType.UNAUTHORIZED.value
|
||||||
match provider_type:
|
match provider_type:
|
||||||
case "tool":
|
case "tool":
|
||||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
@@ -49,24 +57,56 @@ class PluginParameterService:
|
|||||||
else:
|
else:
|
||||||
# fetch credentials from db
|
# fetch credentials from db
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
db_record = (
|
if credential_id:
|
||||||
session.query(BuiltinToolProvider)
|
db_record = (
|
||||||
.where(
|
session.query(BuiltinToolProvider)
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
.where(
|
||||||
BuiltinToolProvider.provider == provider,
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
|
BuiltinToolProvider.provider == provider,
|
||||||
|
BuiltinToolProvider.id == credential_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
db_record = (
|
||||||
|
session.query(BuiltinToolProvider)
|
||||||
|
.where(
|
||||||
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
|
BuiltinToolProvider.provider == provider,
|
||||||
|
)
|
||||||
|
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||||
|
.first()
|
||||||
)
|
)
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if db_record is None:
|
if db_record is None:
|
||||||
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
|
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
|
||||||
|
|
||||||
credentials = encrypter.decrypt(db_record.credentials)
|
credentials = encrypter.decrypt(db_record.credentials)
|
||||||
|
credential_type = db_record.credential_type
|
||||||
|
case "trigger":
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, TriggerProviderID(provider))
|
||||||
|
if credential_id:
|
||||||
|
subscription: TriggerProviderSubscriptionApiEntity | SubscriptionBuilder | None = (
|
||||||
|
TriggerSubscriptionBuilderService.get_subscription_builder(credential_id)
|
||||||
|
or TriggerProviderService.get_subscription_by_id(tenant_id, credential_id)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
subscription: TriggerProviderSubscriptionApiEntity | SubscriptionBuilder | None = (
|
||||||
|
TriggerProviderService.get_subscription_by_id(tenant_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
if subscription is None:
|
||||||
|
raise ValueError(f"Subscription {credential_id} not found")
|
||||||
|
|
||||||
|
credentials = subscription.credentials
|
||||||
|
credential_type = subscription.credential_type or CredentialType.UNAUTHORIZED
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Invalid provider type: {provider_type}")
|
raise ValueError(f"Invalid provider type: {provider_type}")
|
||||||
|
|
||||||
return (
|
return (
|
||||||
DynamicSelectClient()
|
DynamicSelectClient()
|
||||||
.fetch_dynamic_select_options(tenant_id, user_id, plugin_id, provider, action, credentials, parameter)
|
.fetch_dynamic_select_options(
|
||||||
|
tenant_id, user_id, plugin_id, provider, action, credentials, credential_type, parameter
|
||||||
|
)
|
||||||
.options
|
.options
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from mimetypes import guess_type
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.helper import marketplace
|
from core.helper import marketplace
|
||||||
@@ -176,6 +177,13 @@ class PluginService:
|
|||||||
manager = PluginInstaller()
|
manager = PluginInstaller()
|
||||||
return manager.fetch_plugin_installation_by_ids(tenant_id, ids)
|
return manager.fetch_plugin_installation_by_ids(tenant_id, ids)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
|
||||||
|
url_prefix = (
|
||||||
|
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
|
||||||
|
)
|
||||||
|
return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_asset(tenant_id: str, asset_file: str) -> tuple[bytes, str]:
|
def get_asset(tenant_id: str, asset_file: str) -> tuple[bytes, str]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
274
api/services/schedule_service.py
Normal file
274
api/services/schedule_service.py
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.workflow.nodes import NodeType
|
||||||
|
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
|
||||||
|
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError
|
||||||
|
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
|
||||||
|
from models.account import Account, TenantAccountJoin
|
||||||
|
from models.workflow import Workflow, WorkflowSchedulePlan
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleService:
|
||||||
|
@staticmethod
|
||||||
|
def create_schedule(
|
||||||
|
session: Session,
|
||||||
|
tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
config: ScheduleConfig,
|
||||||
|
) -> WorkflowSchedulePlan:
|
||||||
|
"""
|
||||||
|
Create a new schedule with validated configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
tenant_id: Tenant ID
|
||||||
|
app_id: Application ID
|
||||||
|
config: Validated schedule configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created WorkflowSchedulePlan instance
|
||||||
|
"""
|
||||||
|
next_run_at = calculate_next_run_at(
|
||||||
|
config.cron_expression,
|
||||||
|
config.timezone,
|
||||||
|
)
|
||||||
|
|
||||||
|
schedule = WorkflowSchedulePlan(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
app_id=app_id,
|
||||||
|
node_id=config.node_id,
|
||||||
|
cron_expression=config.cron_expression,
|
||||||
|
timezone=config.timezone,
|
||||||
|
next_run_at=next_run_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(schedule)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
return schedule
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_schedule(
|
||||||
|
session: Session,
|
||||||
|
schedule_id: str,
|
||||||
|
updates: SchedulePlanUpdate,
|
||||||
|
) -> WorkflowSchedulePlan:
|
||||||
|
"""
|
||||||
|
Update an existing schedule with validated configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
schedule_id: Schedule ID to update
|
||||||
|
updates: Validated update configuration
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ScheduleNotFoundError: If schedule not found
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated WorkflowSchedulePlan instance
|
||||||
|
"""
|
||||||
|
schedule = session.get(WorkflowSchedulePlan, schedule_id)
|
||||||
|
if not schedule:
|
||||||
|
raise ScheduleNotFoundError(f"Schedule not found: {schedule_id}")
|
||||||
|
|
||||||
|
# If time-related fields are updated, synchronously update the next_run_at.
|
||||||
|
time_fields_updated = False
|
||||||
|
|
||||||
|
if updates.node_id is not None:
|
||||||
|
schedule.node_id = updates.node_id
|
||||||
|
|
||||||
|
if updates.cron_expression is not None:
|
||||||
|
schedule.cron_expression = updates.cron_expression
|
||||||
|
time_fields_updated = True
|
||||||
|
|
||||||
|
if updates.timezone is not None:
|
||||||
|
schedule.timezone = updates.timezone
|
||||||
|
time_fields_updated = True
|
||||||
|
|
||||||
|
if time_fields_updated:
|
||||||
|
schedule.next_run_at = calculate_next_run_at(
|
||||||
|
schedule.cron_expression,
|
||||||
|
schedule.timezone,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.flush()
|
||||||
|
return schedule
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def delete_schedule(
|
||||||
|
session: Session,
|
||||||
|
schedule_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Delete a schedule plan.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
schedule_id: Schedule ID to delete
|
||||||
|
"""
|
||||||
|
schedule = session.get(WorkflowSchedulePlan, schedule_id)
|
||||||
|
if not schedule:
|
||||||
|
raise ScheduleNotFoundError(f"Schedule not found: {schedule_id}")
|
||||||
|
|
||||||
|
session.delete(schedule)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tenant_owner(session: Session, tenant_id: str) -> Optional[Account]:
|
||||||
|
"""
|
||||||
|
Returns an account to execute scheduled workflows on behalf of the tenant.
|
||||||
|
Prioritizes owner over admin to ensure proper authorization hierarchy.
|
||||||
|
"""
|
||||||
|
result = session.execute(
|
||||||
|
select(TenantAccountJoin)
|
||||||
|
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.role == "owner")
|
||||||
|
.limit(1)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
# Owner may not exist in some tenant configurations, fallback to admin
|
||||||
|
result = session.execute(
|
||||||
|
select(TenantAccountJoin)
|
||||||
|
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.role == "admin")
|
||||||
|
.limit(1)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if result:
|
||||||
|
return session.get(Account, result.account_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_next_run_at(
|
||||||
|
session: Session,
|
||||||
|
schedule_id: str,
|
||||||
|
) -> datetime:
|
||||||
|
"""
|
||||||
|
Advances the schedule to its next execution time after a successful trigger.
|
||||||
|
Uses current time as base to prevent missing executions during delays.
|
||||||
|
"""
|
||||||
|
schedule = session.get(WorkflowSchedulePlan, schedule_id)
|
||||||
|
if not schedule:
|
||||||
|
raise ScheduleNotFoundError(f"Schedule not found: {schedule_id}")
|
||||||
|
|
||||||
|
# Base on current time to handle execution delays gracefully
|
||||||
|
next_run_at = calculate_next_run_at(
|
||||||
|
schedule.cron_expression,
|
||||||
|
schedule.timezone,
|
||||||
|
)
|
||||||
|
|
||||||
|
schedule.next_run_at = next_run_at
|
||||||
|
session.flush()
|
||||||
|
return next_run_at
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_schedule_config(workflow: Workflow) -> Optional[ScheduleConfig]:
|
||||||
|
"""
|
||||||
|
Extracts schedule configuration from workflow graph.
|
||||||
|
|
||||||
|
Searches for the first schedule trigger node in the workflow and converts
|
||||||
|
its configuration (either visual or cron mode) into a unified ScheduleConfig.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow: The workflow containing the graph definition
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ScheduleConfig if a valid schedule node is found, None if no schedule node exists
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ScheduleConfigError: If graph parsing fails or schedule configuration is invalid
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Currently only returns the first schedule node found.
|
||||||
|
Multiple schedule nodes in the same workflow are not supported.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
graph_data = workflow.graph_dict
|
||||||
|
except (json.JSONDecodeError, TypeError, AttributeError) as e:
|
||||||
|
raise ScheduleConfigError(f"Failed to parse workflow graph: {e}")
|
||||||
|
|
||||||
|
if not graph_data:
|
||||||
|
raise ScheduleConfigError("Workflow graph is empty")
|
||||||
|
|
||||||
|
nodes = graph_data.get("nodes", [])
|
||||||
|
for node in nodes:
|
||||||
|
node_data = node.get("data", {})
|
||||||
|
|
||||||
|
if node_data.get("type") != NodeType.TRIGGER_SCHEDULE.value:
|
||||||
|
continue
|
||||||
|
|
||||||
|
mode = node_data.get("mode", "visual")
|
||||||
|
timezone = node_data.get("timezone", "UTC")
|
||||||
|
node_id = node.get("id", "start")
|
||||||
|
|
||||||
|
cron_expression = None
|
||||||
|
if mode == "cron":
|
||||||
|
cron_expression = node_data.get("cron_expression")
|
||||||
|
if not cron_expression:
|
||||||
|
raise ScheduleConfigError("Cron expression is required for cron mode")
|
||||||
|
elif mode == "visual":
|
||||||
|
frequency = node_data.get("frequency")
|
||||||
|
visual_config_dict = node_data.get("visual_config", {})
|
||||||
|
visual_config = VisualConfig(**visual_config_dict)
|
||||||
|
cron_expression = ScheduleService.visual_to_cron(frequency, visual_config)
|
||||||
|
else:
|
||||||
|
raise ScheduleConfigError(f"Invalid schedule mode: {mode}")
|
||||||
|
|
||||||
|
return ScheduleConfig(node_id=node_id, cron_expression=cron_expression, timezone=timezone)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def visual_to_cron(frequency: str, visual_config: VisualConfig) -> str:
|
||||||
|
"""
|
||||||
|
Converts user-friendly visual schedule settings to cron expression.
|
||||||
|
Maintains consistency with frontend UI expectations while supporting croniter's extended syntax.
|
||||||
|
"""
|
||||||
|
if frequency == "hourly":
|
||||||
|
if visual_config.on_minute is None:
|
||||||
|
raise ScheduleConfigError("on_minute is required for hourly schedules")
|
||||||
|
return f"{visual_config.on_minute} * * * *"
|
||||||
|
|
||||||
|
elif frequency == "daily":
|
||||||
|
if not visual_config.time:
|
||||||
|
raise ScheduleConfigError("time is required for daily schedules")
|
||||||
|
hour, minute = convert_12h_to_24h(visual_config.time)
|
||||||
|
return f"{minute} {hour} * * *"
|
||||||
|
|
||||||
|
elif frequency == "weekly":
|
||||||
|
if not visual_config.time:
|
||||||
|
raise ScheduleConfigError("time is required for weekly schedules")
|
||||||
|
if not visual_config.weekdays:
|
||||||
|
raise ScheduleConfigError("Weekdays are required for weekly schedules")
|
||||||
|
hour, minute = convert_12h_to_24h(visual_config.time)
|
||||||
|
weekday_map = {"sun": "0", "mon": "1", "tue": "2", "wed": "3", "thu": "4", "fri": "5", "sat": "6"}
|
||||||
|
cron_weekdays = [weekday_map[day] for day in visual_config.weekdays]
|
||||||
|
return f"{minute} {hour} * * {','.join(sorted(cron_weekdays))}"
|
||||||
|
|
||||||
|
elif frequency == "monthly":
|
||||||
|
if not visual_config.time:
|
||||||
|
raise ScheduleConfigError("time is required for monthly schedules")
|
||||||
|
if not visual_config.monthly_days:
|
||||||
|
raise ScheduleConfigError("Monthly days are required for monthly schedules")
|
||||||
|
hour, minute = convert_12h_to_24h(visual_config.time)
|
||||||
|
|
||||||
|
numeric_days = []
|
||||||
|
has_last = False
|
||||||
|
for day in visual_config.monthly_days:
|
||||||
|
if day == "last":
|
||||||
|
has_last = True
|
||||||
|
else:
|
||||||
|
numeric_days.append(day)
|
||||||
|
|
||||||
|
result_days = [str(d) for d in sorted(set(numeric_days))]
|
||||||
|
if has_last:
|
||||||
|
result_days.append("L")
|
||||||
|
|
||||||
|
return f"{minute} {hour} {','.join(result_days)} * *"
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ScheduleConfigError(f"Unsupported frequency: {frequency}")
|
||||||
@@ -13,6 +13,7 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
|||||||
from core.helper.position_helper import is_filtered
|
from core.helper.position_helper import is_filtered
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
||||||
from core.plugin.entities.plugin import ToolProviderID
|
from core.plugin.entities.plugin import ToolProviderID
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||||
from core.tools.entities.api_entities import (
|
from core.tools.entities.api_entities import (
|
||||||
@@ -21,7 +22,6 @@ from core.tools.entities.api_entities import (
|
|||||||
ToolProviderCredentialApiEntity,
|
ToolProviderCredentialApiEntity,
|
||||||
ToolProviderCredentialInfoApiEntity,
|
ToolProviderCredentialInfoApiEntity,
|
||||||
)
|
)
|
||||||
from core.tools.entities.tool_entities import CredentialType
|
|
||||||
from core.tools.errors import ToolProviderNotFoundError
|
from core.tools.errors import ToolProviderNotFoundError
|
||||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
@@ -39,7 +39,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class BuiltinToolManageService:
|
class BuiltinToolManageService:
|
||||||
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
|
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
|
||||||
__DEFAULT_EXPIRES_AT__ = 2147483647
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
|
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
|
||||||
@@ -278,9 +277,7 @@ class BuiltinToolManageService:
|
|||||||
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
|
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
|
||||||
credential_type=api_type.value,
|
credential_type=api_type.value,
|
||||||
name=name,
|
name=name,
|
||||||
expires_at=expires_at
|
expires_at=expires_at if expires_at is not None else -1,
|
||||||
if expires_at is not None
|
|
||||||
else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
session.add(db_provider)
|
session.add(db_provider)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from yarl import URL
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||||
from core.mcp.types import Tool as MCPTool
|
from core.mcp.types import Tool as MCPTool
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||||
@@ -16,7 +17,6 @@ from core.tools.entities.common_entities import I18nObject
|
|||||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
ApiProviderAuthType,
|
ApiProviderAuthType,
|
||||||
CredentialType,
|
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
)
|
)
|
||||||
@@ -25,18 +25,12 @@ from core.tools.utils.encryption import create_provider_encrypter, create_tool_p
|
|||||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ToolTransformService:
|
class ToolTransformService:
|
||||||
@classmethod
|
|
||||||
def get_plugin_icon_url(cls, tenant_id: str, filename: str) -> str:
|
|
||||||
url_prefix = (
|
|
||||||
URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "plugin" / "icon"
|
|
||||||
)
|
|
||||||
return str(url_prefix % {"tenant_id": tenant_id, "filename": filename})
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
|
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
|
||||||
"""
|
"""
|
||||||
@@ -74,11 +68,9 @@ class ToolTransformService:
|
|||||||
elif isinstance(provider, ToolProviderApiEntity):
|
elif isinstance(provider, ToolProviderApiEntity):
|
||||||
if provider.plugin_id:
|
if provider.plugin_id:
|
||||||
if isinstance(provider.icon, str):
|
if isinstance(provider.icon, str):
|
||||||
provider.icon = ToolTransformService.get_plugin_icon_url(
|
provider.icon = PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=provider.icon)
|
||||||
tenant_id=tenant_id, filename=provider.icon
|
|
||||||
)
|
|
||||||
if isinstance(provider.icon_dark, str) and provider.icon_dark:
|
if isinstance(provider.icon_dark, str) and provider.icon_dark:
|
||||||
provider.icon_dark = ToolTransformService.get_plugin_icon_url(
|
provider.icon_dark = PluginService.get_plugin_icon_url(
|
||||||
tenant_id=tenant_id, filename=provider.icon_dark
|
tenant_id=tenant_id, filename=provider.icon_dark
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
526
api/services/trigger/trigger_provider_service.py
Normal file
526
api/services/trigger/trigger_provider_service.py
Normal file
@@ -0,0 +1,526 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import desc, func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||||
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||||
|
from core.helper.provider_encryption import create_provider_encrypter
|
||||||
|
from core.plugin.entities.plugin import TriggerProviderID
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
|
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||||
|
from core.trigger.entities.api_entities import (
|
||||||
|
TriggerProviderApiEntity,
|
||||||
|
TriggerProviderSubscriptionApiEntity,
|
||||||
|
)
|
||||||
|
from core.trigger.trigger_manager import TriggerManager
|
||||||
|
from core.trigger.utils.encryption import (
|
||||||
|
create_trigger_provider_encrypter_for_properties,
|
||||||
|
create_trigger_provider_encrypter_for_subscription,
|
||||||
|
create_trigger_provider_oauth_encrypter,
|
||||||
|
delete_cache_for_subscription,
|
||||||
|
)
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from models.trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerSubscription
|
||||||
|
from models.workflow import WorkflowPluginTrigger
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerProviderService:
|
||||||
|
"""Service for managing trigger providers and credentials"""
|
||||||
|
|
||||||
|
##########################
|
||||||
|
# Trigger provider
|
||||||
|
##########################
|
||||||
|
__MAX_TRIGGER_PROVIDER_COUNT__ = 10
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_trigger_provider(cls, tenant_id: str, provider: TriggerProviderID) -> TriggerProviderApiEntity:
|
||||||
|
"""Get info for a trigger provider"""
|
||||||
|
return TriggerManager.get_trigger_provider(tenant_id, provider).to_api_entity()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_trigger_providers(cls, tenant_id: str) -> list[TriggerProviderApiEntity]:
|
||||||
|
"""List all trigger providers for the current tenant"""
|
||||||
|
return [provider.to_api_entity() for provider in TriggerManager.list_all_trigger_providers(tenant_id)]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_trigger_provider_subscriptions(
|
||||||
|
cls, tenant_id: str, provider_id: TriggerProviderID
|
||||||
|
) -> list[TriggerProviderSubscriptionApiEntity]:
|
||||||
|
"""List all trigger subscriptions for the current tenant"""
|
||||||
|
subscriptions: list[TriggerProviderSubscriptionApiEntity] = []
|
||||||
|
workflows_in_use_map: dict[str, int] = {}
|
||||||
|
with Session(db.engine, autoflush=False) as session:
|
||||||
|
# Get all subscriptions
|
||||||
|
subscriptions_db = (
|
||||||
|
session.query(TriggerSubscription)
|
||||||
|
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
|
||||||
|
.order_by(desc(TriggerSubscription.created_at))
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db]
|
||||||
|
if not subscriptions:
|
||||||
|
return []
|
||||||
|
usage_counts = (
|
||||||
|
session.query(
|
||||||
|
WorkflowPluginTrigger.subscription_id,
|
||||||
|
func.count(func.distinct(WorkflowPluginTrigger.app_id)).label("app_count"),
|
||||||
|
)
|
||||||
|
.filter(
|
||||||
|
WorkflowPluginTrigger.tenant_id == tenant_id,
|
||||||
|
WorkflowPluginTrigger.subscription_id.in_([s.id for s in subscriptions]),
|
||||||
|
)
|
||||||
|
.group_by(WorkflowPluginTrigger.subscription_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
workflows_in_use_map = {str(row.subscription_id): int(row.app_count) for row in usage_counts}
|
||||||
|
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
for subscription in subscriptions:
|
||||||
|
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
controller=provider_controller,
|
||||||
|
subscription=subscription,
|
||||||
|
)
|
||||||
|
subscription.credentials = encrypter.mask_credentials(subscription.credentials)
|
||||||
|
count = workflows_in_use_map.get(subscription.id)
|
||||||
|
subscription.workflows_in_use = count if count is not None else 0
|
||||||
|
|
||||||
|
return subscriptions
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_trigger_subscription(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
name: str,
|
||||||
|
provider_id: TriggerProviderID,
|
||||||
|
endpoint_id: str,
|
||||||
|
credential_type: CredentialType,
|
||||||
|
parameters: Mapping[str, Any],
|
||||||
|
properties: Mapping[str, Any],
|
||||||
|
credentials: Mapping[str, str],
|
||||||
|
subscription_id: Optional[str] = None,
|
||||||
|
credential_expires_at: int = -1,
|
||||||
|
expires_at: int = -1,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Add a new trigger provider with credentials.
|
||||||
|
Supports multiple credential instances per provider.
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param provider_id: Provider identifier (e.g., "plugin_id/provider_name")
|
||||||
|
:param credential_type: Type of credential (oauth or api_key)
|
||||||
|
:param credentials: Credential data to encrypt and store
|
||||||
|
:param name: Optional name for this credential instance
|
||||||
|
:param expires_at: OAuth token expiration timestamp
|
||||||
|
:return: Success response
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
with Session(db.engine, autoflush=False) as session:
|
||||||
|
# Use distributed lock to prevent race conditions
|
||||||
|
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
|
||||||
|
with redis_client.lock(lock_key, timeout=20):
|
||||||
|
# Check provider count limit
|
||||||
|
provider_count = (
|
||||||
|
session.query(TriggerSubscription)
|
||||||
|
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
|
if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__:
|
||||||
|
raise ValueError(
|
||||||
|
f"Maximum number of providers ({cls.__MAX_TRIGGER_PROVIDER_COUNT__}) "
|
||||||
|
f"reached for {provider_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if name already exists
|
||||||
|
existing = (
|
||||||
|
session.query(TriggerSubscription)
|
||||||
|
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
raise ValueError(f"Credential name '{name}' already exists for this provider")
|
||||||
|
|
||||||
|
credential_encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=provider_controller.get_credential_schema_config(credential_type),
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
|
||||||
|
properties_encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=provider_controller.get_properties_schema(),
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create provider record
|
||||||
|
db_provider = TriggerSubscription(
|
||||||
|
id=subscription_id or str(uuid.uuid4()),
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
name=name,
|
||||||
|
endpoint_id=endpoint_id,
|
||||||
|
provider_id=str(provider_id),
|
||||||
|
parameters=parameters,
|
||||||
|
properties=properties_encrypter.encrypt(dict(properties)),
|
||||||
|
credentials=credential_encrypter.encrypt(dict(credentials)),
|
||||||
|
credential_type=credential_type.value,
|
||||||
|
credential_expires_at=credential_expires_at,
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(db_provider)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return {"result": "success", "id": str(db_provider.id)}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to add trigger provider")
|
||||||
|
raise ValueError(str(e))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_subscription_by_id(
|
||||||
|
cls, tenant_id: str, subscription_id: str | None = None
|
||||||
|
) -> TriggerProviderSubscriptionApiEntity | None:
|
||||||
|
"""
|
||||||
|
Get a trigger subscription by the ID.
|
||||||
|
"""
|
||||||
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
|
subscription: TriggerSubscription | None = None
|
||||||
|
if subscription_id:
|
||||||
|
subscription = (
|
||||||
|
session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id).first()
|
||||||
|
if subscription:
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(
|
||||||
|
tenant_id, TriggerProviderID(subscription.provider_id)
|
||||||
|
)
|
||||||
|
encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
controller=provider_controller,
|
||||||
|
subscription=subscription,
|
||||||
|
)
|
||||||
|
subscription.credentials = encrypter.decrypt(subscription.credentials)
|
||||||
|
return subscription.to_api_entity()
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_trigger_provider(cls, session: Session, tenant_id: str, subscription_id: str):
|
||||||
|
"""
|
||||||
|
Delete a trigger provider subscription within an existing session.
|
||||||
|
|
||||||
|
:param session: Database session
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param subscription_id: Subscription instance ID
|
||||||
|
:return: Success response
|
||||||
|
"""
|
||||||
|
db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||||
|
if not db_provider:
|
||||||
|
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
||||||
|
|
||||||
|
# Clear cache
|
||||||
|
session.delete(db_provider)
|
||||||
|
delete_cache_for_subscription(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=db_provider.provider_id,
|
||||||
|
subscription_id=db_provider.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def refresh_oauth_token(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
subscription_id: str,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Refresh OAuth token for a trigger provider.
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param subscription_id: Subscription instance ID
|
||||||
|
:return: New token info
|
||||||
|
"""
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||||
|
|
||||||
|
if not db_provider:
|
||||||
|
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
||||||
|
|
||||||
|
if db_provider.credential_type != CredentialType.OAUTH2.value:
|
||||||
|
raise ValueError("Only OAuth credentials can be refreshed")
|
||||||
|
|
||||||
|
provider_id = TriggerProviderID(db_provider.provider_id)
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
# Create encrypter
|
||||||
|
encrypter, cache = create_trigger_provider_encrypter_for_subscription(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
controller=provider_controller,
|
||||||
|
subscription=db_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decrypt current credentials
|
||||||
|
current_credentials = encrypter.decrypt(db_provider.credentials)
|
||||||
|
|
||||||
|
# Get OAuth client configuration
|
||||||
|
redirect_uri = (
|
||||||
|
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{db_provider.provider_id}/trigger/callback"
|
||||||
|
)
|
||||||
|
system_credentials = cls.get_oauth_client(tenant_id, provider_id)
|
||||||
|
|
||||||
|
# Refresh token
|
||||||
|
oauth_handler = OAuthHandler()
|
||||||
|
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=db_provider.user_id,
|
||||||
|
plugin_id=provider_id.plugin_id,
|
||||||
|
provider=provider_id.provider_name,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
system_credentials=system_credentials or {},
|
||||||
|
credentials=current_credentials,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update credentials
|
||||||
|
db_provider.credentials = encrypter.encrypt(dict(refreshed_credentials.credentials))
|
||||||
|
db_provider.expires_at = refreshed_credentials.expires_at
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Clear cache
|
||||||
|
cache.delete()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"result": "success",
|
||||||
|
"expires_at": refreshed_credentials.expires_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_oauth_client(cls, tenant_id: str, provider_id: TriggerProviderID) -> Optional[Mapping[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get OAuth client configuration for a provider.
|
||||||
|
First tries tenant-level OAuth, then falls back to system OAuth.
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param provider_id: Provider identifier
|
||||||
|
:return: OAuth client configuration or None
|
||||||
|
"""
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
with Session(db.engine, autoflush=False) as session:
|
||||||
|
tenant_client: TriggerOAuthTenantClient | None = (
|
||||||
|
session.query(TriggerOAuthTenantClient)
|
||||||
|
.filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider=provider_id.provider_name,
|
||||||
|
plugin_id=provider_id.plugin_id,
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
oauth_params: Mapping[str, Any] | None = None
|
||||||
|
if tenant_client:
|
||||||
|
encrypter, _ = create_trigger_provider_oauth_encrypter(tenant_id, provider_controller)
|
||||||
|
oauth_params = encrypter.decrypt(tenant_client.oauth_params)
|
||||||
|
return oauth_params
|
||||||
|
|
||||||
|
is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
|
||||||
|
if not is_verified:
|
||||||
|
return oauth_params
|
||||||
|
|
||||||
|
# Check for system-level OAuth client
|
||||||
|
system_client: TriggerOAuthSystemClient | None = (
|
||||||
|
session.query(TriggerOAuthSystemClient)
|
||||||
|
.filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if system_client:
|
||||||
|
try:
|
||||||
|
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error decrypting system oauth params: {e}")
|
||||||
|
|
||||||
|
return oauth_params
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def save_custom_oauth_client_params(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
provider_id: TriggerProviderID,
|
||||||
|
client_params: Optional[dict] = None,
|
||||||
|
enabled: Optional[bool] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Save or update custom OAuth client parameters for a trigger provider.
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param provider_id: Provider identifier
|
||||||
|
:param client_params: OAuth client parameters (client_id, client_secret, etc.)
|
||||||
|
:param enabled: Enable/disable the custom OAuth client
|
||||||
|
:return: Success response
|
||||||
|
"""
|
||||||
|
if client_params is None and enabled is None:
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
# Get provider controller to access schema
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
# Find existing custom client params
|
||||||
|
custom_client = (
|
||||||
|
session.query(TriggerOAuthTenantClient)
|
||||||
|
.filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
plugin_id=provider_id.plugin_id,
|
||||||
|
provider=provider_id.provider_name,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create new record if doesn't exist
|
||||||
|
if custom_client is None:
|
||||||
|
custom_client = TriggerOAuthTenantClient(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
plugin_id=provider_id.plugin_id,
|
||||||
|
provider=provider_id.provider_name,
|
||||||
|
)
|
||||||
|
session.add(custom_client)
|
||||||
|
|
||||||
|
# Update client params if provided
|
||||||
|
if client_params is not None:
|
||||||
|
encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle hidden values
|
||||||
|
original_params = encrypter.decrypt(custom_client.oauth_params)
|
||||||
|
new_params: dict = {
|
||||||
|
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
|
||||||
|
for key, value in client_params.items()
|
||||||
|
}
|
||||||
|
custom_client.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
|
||||||
|
|
||||||
|
# Update enabled status if provided
|
||||||
|
if enabled is not None:
|
||||||
|
custom_client.enabled = enabled
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict:
|
||||||
|
"""
|
||||||
|
Get custom OAuth client parameters for a trigger provider.
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param provider_id: Provider identifier
|
||||||
|
:return: Masked OAuth client parameters
|
||||||
|
"""
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
custom_client = (
|
||||||
|
session.query(TriggerOAuthTenantClient)
|
||||||
|
.filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
plugin_id=provider_id.plugin_id,
|
||||||
|
provider=provider_id.provider_name,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if custom_client is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Get provider controller to access schema
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
|
||||||
|
# Create encrypter to decrypt and mask values
|
||||||
|
encrypter, _ = create_provider_encrypter(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||||
|
cache=NoOpProviderCredentialCache(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_client.oauth_params))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict:
|
||||||
|
"""
|
||||||
|
Delete custom OAuth client parameters for a trigger provider.
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param provider_id: Provider identifier
|
||||||
|
:return: Success response
|
||||||
|
"""
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
session.query(TriggerOAuthTenantClient).filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider=provider_id.provider_name,
|
||||||
|
plugin_id=provider_id.plugin_id,
|
||||||
|
).delete()
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_oauth_custom_client_enabled(cls, tenant_id: str, provider_id: TriggerProviderID) -> bool:
|
||||||
|
"""
|
||||||
|
Check if custom OAuth client is enabled for a trigger provider.
|
||||||
|
|
||||||
|
:param tenant_id: Tenant ID
|
||||||
|
:param provider_id: Provider identifier
|
||||||
|
:return: True if enabled, False otherwise
|
||||||
|
"""
|
||||||
|
with Session(db.engine, autoflush=False) as session:
|
||||||
|
custom_client = (
|
||||||
|
session.query(TriggerOAuthTenantClient)
|
||||||
|
.filter_by(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
plugin_id=provider_id.plugin_id,
|
||||||
|
provider=provider_id.provider_name,
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
return custom_client is not None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_subscription_by_endpoint(cls, endpoint_id: str) -> TriggerSubscription | None:
|
||||||
|
"""
|
||||||
|
Get a trigger subscription by the endpoint ID.
|
||||||
|
"""
|
||||||
|
with Session(db.engine, autoflush=False) as session:
|
||||||
|
subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first()
|
||||||
|
if not subscription:
|
||||||
|
return None
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(
|
||||||
|
subscription.tenant_id, TriggerProviderID(subscription.provider_id)
|
||||||
|
)
|
||||||
|
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||||
|
tenant_id=subscription.tenant_id,
|
||||||
|
controller=provider_controller,
|
||||||
|
subscription=subscription,
|
||||||
|
)
|
||||||
|
subscription.credentials = credential_encrypter.decrypt(subscription.credentials)
|
||||||
|
|
||||||
|
properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
|
||||||
|
tenant_id=subscription.tenant_id,
|
||||||
|
controller=provider_controller,
|
||||||
|
subscription=subscription,
|
||||||
|
)
|
||||||
|
subscription.properties = properties_encrypter.decrypt(subscription.properties)
|
||||||
|
return subscription
|
||||||
317
api/services/trigger/trigger_subscription_builder_service.py
Normal file
317
api/services/trigger/trigger_subscription_builder_service.py
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from flask import Request, Response
|
||||||
|
|
||||||
|
from core.plugin.entities.plugin import TriggerProviderID
|
||||||
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
|
from core.trigger.entities.api_entities import SubscriptionBuilderApiEntity
|
||||||
|
from core.trigger.entities.entities import (
|
||||||
|
RequestLog,
|
||||||
|
SubscriptionBuilder,
|
||||||
|
SubscriptionBuilderUpdater,
|
||||||
|
)
|
||||||
|
from core.trigger.provider import PluginTriggerProviderController
|
||||||
|
from core.trigger.trigger_manager import TriggerManager
|
||||||
|
from core.trigger.utils.encryption import masked_credentials
|
||||||
|
from core.trigger.utils.endpoint import parse_endpoint_id
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerSubscriptionBuilderService:
|
||||||
|
"""Service for managing trigger providers and credentials"""
|
||||||
|
|
||||||
|
##########################
|
||||||
|
# Trigger provider
|
||||||
|
##########################
|
||||||
|
__MAX_TRIGGER_PROVIDER_COUNT__ = 10
|
||||||
|
|
||||||
|
##########################
|
||||||
|
# Validation endpoint
|
||||||
|
##########################
|
||||||
|
__VALIDATION_REQUEST_CACHE_COUNT__ = 10
|
||||||
|
__VALIDATION_REQUEST_CACHE_EXPIRE_MS__ = 30 * 60 * 1000
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def encode_cache_key(cls, subscription_id: str) -> str:
|
||||||
|
return f"trigger:subscription:validation:{subscription_id}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def verify_trigger_subscription_builder(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
provider_id: TriggerProviderID,
|
||||||
|
subscription_builder_id: str,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
"""Verify a trigger subscription builder"""
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
if not provider_controller:
|
||||||
|
raise ValueError(f"Provider {provider_id} not found")
|
||||||
|
|
||||||
|
subscription_builder = cls.get_subscription_builder(subscription_builder_id)
|
||||||
|
if not subscription_builder:
|
||||||
|
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
|
||||||
|
|
||||||
|
if subscription_builder.credential_type == CredentialType.OAUTH2:
|
||||||
|
return {"verified": bool(subscription_builder.credentials)}
|
||||||
|
|
||||||
|
if subscription_builder.credential_type == CredentialType.API_KEY:
|
||||||
|
credentials_to_validate = subscription_builder.credentials
|
||||||
|
try:
|
||||||
|
provider_controller.validate_credentials(user_id, credentials_to_validate)
|
||||||
|
except ToolProviderCredentialValidationError as e:
|
||||||
|
raise ValueError(f"Invalid credentials: {e}")
|
||||||
|
return {"verified": True}
|
||||||
|
|
||||||
|
return {"verified": True}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_trigger_subscription_builder(
|
||||||
|
cls, tenant_id: str, user_id: str, provider_id: TriggerProviderID, subscription_builder_id: str
|
||||||
|
) -> None:
|
||||||
|
"""Build a trigger subscription builder"""
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
if not provider_controller:
|
||||||
|
raise ValueError(f"Provider {provider_id} not found")
|
||||||
|
|
||||||
|
subscription_builder = cls.get_subscription_builder(subscription_builder_id)
|
||||||
|
if not subscription_builder:
|
||||||
|
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
|
||||||
|
|
||||||
|
if not subscription_builder.name:
|
||||||
|
raise ValueError("Subscription builder name is required")
|
||||||
|
|
||||||
|
credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value)
|
||||||
|
if credential_type == CredentialType.UNAUTHORIZED:
|
||||||
|
# manually create
|
||||||
|
TriggerProviderService.add_trigger_subscription(
|
||||||
|
subscription_id=subscription_builder.id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
name=subscription_builder.name,
|
||||||
|
provider_id=provider_id,
|
||||||
|
endpoint_id=subscription_builder.endpoint_id,
|
||||||
|
parameters=subscription_builder.parameters,
|
||||||
|
properties=subscription_builder.properties,
|
||||||
|
credential_expires_at=subscription_builder.credential_expires_at or -1,
|
||||||
|
expires_at=subscription_builder.expires_at,
|
||||||
|
credentials=subscription_builder.credentials,
|
||||||
|
credential_type=credential_type,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# automatically create
|
||||||
|
subscription = TriggerManager.subscribe_trigger(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
endpoint=parse_endpoint_id(subscription_builder.endpoint_id),
|
||||||
|
parameters=subscription_builder.parameters,
|
||||||
|
credentials=subscription_builder.credentials,
|
||||||
|
)
|
||||||
|
|
||||||
|
TriggerProviderService.add_trigger_subscription(
|
||||||
|
subscription_id=subscription_builder.id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
name=subscription_builder.name,
|
||||||
|
provider_id=provider_id,
|
||||||
|
endpoint_id=subscription_builder.endpoint_id,
|
||||||
|
parameters=subscription_builder.parameters,
|
||||||
|
properties=subscription.properties,
|
||||||
|
credentials=subscription_builder.credentials,
|
||||||
|
credential_type=credential_type,
|
||||||
|
credential_expires_at=subscription_builder.credential_expires_at or -1,
|
||||||
|
expires_at=subscription_builder.expires_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
cls.delete_trigger_subscription_builder(subscription_builder_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_trigger_subscription_builder(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
provider_id: TriggerProviderID,
|
||||||
|
credential_type: CredentialType,
|
||||||
|
) -> SubscriptionBuilderApiEntity:
|
||||||
|
"""
|
||||||
|
Add a new trigger subscription validation.
|
||||||
|
"""
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
if not provider_controller:
|
||||||
|
raise ValueError(f"Provider {provider_id} not found")
|
||||||
|
|
||||||
|
subscription_schema = provider_controller.get_subscription_schema()
|
||||||
|
subscription_id = str(uuid.uuid4())
|
||||||
|
subscription_builder = SubscriptionBuilder(
|
||||||
|
id=subscription_id,
|
||||||
|
name=None,
|
||||||
|
endpoint_id=subscription_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider_id=str(provider_id),
|
||||||
|
parameters=subscription_schema.get_default_parameters(),
|
||||||
|
properties=subscription_schema.get_default_properties(),
|
||||||
|
credentials={},
|
||||||
|
credential_type=credential_type,
|
||||||
|
credential_expires_at=-1,
|
||||||
|
expires_at=-1,
|
||||||
|
)
|
||||||
|
cache_key = cls.encode_cache_key(subscription_id)
|
||||||
|
redis_client.setex(
|
||||||
|
cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder.model_dump_json()
|
||||||
|
)
|
||||||
|
return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_trigger_subscription_builder(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
provider_id: TriggerProviderID,
|
||||||
|
subscription_builder_id: str,
|
||||||
|
subscription_builder_updater: SubscriptionBuilderUpdater,
|
||||||
|
) -> SubscriptionBuilderApiEntity:
|
||||||
|
"""
|
||||||
|
Update a trigger subscription validation.
|
||||||
|
"""
|
||||||
|
subscription_id = subscription_builder_id
|
||||||
|
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||||
|
if not provider_controller:
|
||||||
|
raise ValueError(f"Provider {provider_id} not found")
|
||||||
|
|
||||||
|
cache_key = cls.encode_cache_key(subscription_id)
|
||||||
|
subscription_builder_cache = cls.get_subscription_builder(subscription_builder_id)
|
||||||
|
if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
|
||||||
|
raise ValueError(f"Subscription {subscription_id} expired or not found")
|
||||||
|
|
||||||
|
subscription_builder_updater.update(subscription_builder_cache)
|
||||||
|
|
||||||
|
redis_client.setex(
|
||||||
|
cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder_cache.model_dump_json()
|
||||||
|
)
|
||||||
|
return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder_cache)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def builder_to_api_entity(
|
||||||
|
cls, controller: PluginTriggerProviderController, entity: SubscriptionBuilder
|
||||||
|
) -> SubscriptionBuilderApiEntity:
|
||||||
|
credential_type = CredentialType.of(entity.credential_type or CredentialType.UNAUTHORIZED.value)
|
||||||
|
return SubscriptionBuilderApiEntity(
|
||||||
|
id=entity.id,
|
||||||
|
name=entity.name or "",
|
||||||
|
provider=entity.provider_id,
|
||||||
|
endpoint=parse_endpoint_id(entity.endpoint_id),
|
||||||
|
parameters=entity.parameters,
|
||||||
|
properties=entity.properties,
|
||||||
|
credential_type=credential_type,
|
||||||
|
credentials=masked_credentials(
|
||||||
|
schemas=controller.get_credentials_schema(credential_type),
|
||||||
|
credentials=entity.credentials,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_trigger_subscription_builder(cls, subscription_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Delete a trigger subscription validation.
|
||||||
|
"""
|
||||||
|
cache_key = cls.encode_cache_key(subscription_id)
|
||||||
|
redis_client.delete(cache_key)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_subscription_builder(cls, endpoint_id: str) -> SubscriptionBuilder | None:
|
||||||
|
"""
|
||||||
|
Get a trigger subscription by the endpoint ID.
|
||||||
|
"""
|
||||||
|
cache_key = cls.encode_cache_key(endpoint_id)
|
||||||
|
subscription_cache = redis_client.get(cache_key)
|
||||||
|
if subscription_cache:
|
||||||
|
return SubscriptionBuilder.model_validate(json.loads(subscription_cache))
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def append_log(cls, endpoint_id: str, request: Request, response: Response) -> None:
|
||||||
|
"""Append validation request log to Redis."""
|
||||||
|
log = RequestLog(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
endpoint=endpoint_id,
|
||||||
|
request={
|
||||||
|
"method": request.method,
|
||||||
|
"url": request.url,
|
||||||
|
"headers": dict(request.headers),
|
||||||
|
"data": request.get_data(as_text=True),
|
||||||
|
},
|
||||||
|
response={
|
||||||
|
"status_code": response.status_code,
|
||||||
|
"headers": dict(response.headers),
|
||||||
|
"data": response.get_data(as_text=True),
|
||||||
|
},
|
||||||
|
created_at=datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
key = f"trigger:subscription:validation:logs:{endpoint_id}"
|
||||||
|
logs = json.loads(redis_client.get(key) or "[]")
|
||||||
|
logs.append(log.model_dump(mode="json"))
|
||||||
|
|
||||||
|
# Keep last N logs
|
||||||
|
logs = logs[-cls.__VALIDATION_REQUEST_CACHE_COUNT__ :]
|
||||||
|
redis_client.setex(key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, json.dumps(logs, default=str))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_logs(cls, endpoint_id: str) -> list[RequestLog]:
|
||||||
|
"""List request logs for validation endpoint."""
|
||||||
|
key = f"trigger:subscription:validation:logs:{endpoint_id}"
|
||||||
|
logs_json = redis_client.get(key)
|
||||||
|
if not logs_json:
|
||||||
|
return []
|
||||||
|
return [RequestLog.model_validate(log) for log in json.loads(logs_json)]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
|
||||||
|
"""
|
||||||
|
Process a temporary endpoint request.
|
||||||
|
|
||||||
|
:param endpoint_id: The endpoint identifier
|
||||||
|
:param request: The Flask request object
|
||||||
|
:return: The Flask response object
|
||||||
|
"""
|
||||||
|
# check if validation endpoint exists
|
||||||
|
subscription_builder = cls.get_subscription_builder(endpoint_id)
|
||||||
|
if not subscription_builder:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# response to validation endpoint
|
||||||
|
controller = TriggerManager.get_trigger_provider(
|
||||||
|
subscription_builder.tenant_id, TriggerProviderID(subscription_builder.provider_id)
|
||||||
|
)
|
||||||
|
response = controller.dispatch(
|
||||||
|
user_id=subscription_builder.user_id,
|
||||||
|
request=request,
|
||||||
|
subscription=subscription_builder.to_subscription(),
|
||||||
|
)
|
||||||
|
# append the request log
|
||||||
|
cls.append_log(endpoint_id, request, response.response)
|
||||||
|
return response.response
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_subscription_builder_by_id(cls, subscription_builder_id: str) -> SubscriptionBuilderApiEntity:
|
||||||
|
"""Get a trigger subscription builder API entity."""
|
||||||
|
subscription_builder = cls.get_subscription_builder(subscription_builder_id)
|
||||||
|
if not subscription_builder:
|
||||||
|
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
|
||||||
|
return cls.builder_to_api_entity(
|
||||||
|
controller=TriggerManager.get_trigger_provider(
|
||||||
|
subscription_builder.tenant_id, TriggerProviderID(subscription_builder.provider_id)
|
||||||
|
),
|
||||||
|
entity=subscription_builder,
|
||||||
|
)
|
||||||
126
api/services/trigger_debug_service.py
Normal file
126
api/services/trigger_debug_service.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
"""Trigger debug service for webhook debugging in draft workflows."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from redis import RedisError
|
||||||
|
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
TRIGGER_DEBUG_EVENT_TTL = 300
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerDebugEvent(BaseModel):
|
||||||
|
subscription_id: str
|
||||||
|
request_id: str
|
||||||
|
timestamp: int
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerDebugService:
|
||||||
|
"""
|
||||||
|
Redis-based trigger debug service with polling support.
|
||||||
|
Uses {tenant_id} hash tags for Redis Cluster compatibility.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# LUA_SELECT: Atomic poll or register for event
|
||||||
|
# KEYS[1] = trigger_debug_inbox:{tenant_id}:{address_id}
|
||||||
|
# KEYS[2] = trigger_debug_waiting_pool:{tenant_id}:{subscription_id}:{trigger}
|
||||||
|
# ARGV[1] = address_id
|
||||||
|
# compressed lua code, you can use LLM to uncompress it
|
||||||
|
LUA_SELECT = (
|
||||||
|
"local v=redis.call('GET',KEYS[1]);"
|
||||||
|
"if v then redis.call('DEL',KEYS[1]);return v end;"
|
||||||
|
"redis.call('SADD',KEYS[2],ARGV[1]);"
|
||||||
|
f"redis.call('EXPIRE',KEYS[2],{TRIGGER_DEBUG_EVENT_TTL});"
|
||||||
|
"return false"
|
||||||
|
)
|
||||||
|
|
||||||
|
# LUA_DISPATCH: Dispatch event to all waiting addresses
|
||||||
|
# KEYS[1] = trigger_debug_waiting_pool:{tenant_id}:{subscription_id}:{trigger}
|
||||||
|
# ARGV[1] = tenant_id
|
||||||
|
# ARGV[2] = event_json
|
||||||
|
# compressed lua code, you can use LLM to uncompress it
|
||||||
|
LUA_DISPATCH = (
|
||||||
|
"local a=redis.call('SMEMBERS',KEYS[1]);"
|
||||||
|
"if #a==0 then return 0 end;"
|
||||||
|
"redis.call('DEL',KEYS[1]);"
|
||||||
|
"for i=1,#a do "
|
||||||
|
f"redis.call('SET','trigger_debug_inbox:{{'..ARGV[1]..'}}'..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});"
|
||||||
|
"end;"
|
||||||
|
"return #a"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def address(cls, tenant_id: str, user_id: str, app_id: str, node_id: str) -> str:
|
||||||
|
address_id = hashlib.sha1(f"{user_id}|{app_id}|{node_id}".encode()).hexdigest()
|
||||||
|
return f"trigger_debug_inbox:{{{tenant_id}}}:{address_id}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def waiting_pool(cls, tenant_id: str, subscription_id: str, trigger_name: str) -> str:
|
||||||
|
return f"trigger_debug_waiting_pool:{{{tenant_id}}}:{subscription_id}:{trigger_name}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dispatch_debug_event(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
subscription_id: str,
|
||||||
|
triggers: list[str],
|
||||||
|
request_id: str,
|
||||||
|
timestamp: int,
|
||||||
|
) -> int:
|
||||||
|
event_json = TriggerDebugEvent(
|
||||||
|
subscription_id=subscription_id,
|
||||||
|
request_id=request_id,
|
||||||
|
timestamp=timestamp,
|
||||||
|
).model_dump_json()
|
||||||
|
|
||||||
|
dispatched = 0
|
||||||
|
if len(triggers) > 10:
|
||||||
|
logger.warning(
|
||||||
|
"Too many triggers to dispatch at once: %d triggers tenant: %s subscription: %s",
|
||||||
|
len(triggers),
|
||||||
|
tenant_id,
|
||||||
|
subscription_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
for trigger_name in triggers:
|
||||||
|
try:
|
||||||
|
dispatched += redis_client.eval(
|
||||||
|
cls.LUA_DISPATCH,
|
||||||
|
1,
|
||||||
|
cls.waiting_pool(tenant_id, subscription_id, trigger_name),
|
||||||
|
tenant_id,
|
||||||
|
event_json,
|
||||||
|
)
|
||||||
|
except RedisError:
|
||||||
|
logger.exception("Failed to dispatch for trigger: %s", trigger_name)
|
||||||
|
return dispatched
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def poll_event(
|
||||||
|
cls,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
app_id: str,
|
||||||
|
subscription_id: str,
|
||||||
|
node_id: str,
|
||||||
|
trigger_name: str,
|
||||||
|
) -> Optional[TriggerDebugEvent]:
|
||||||
|
address_id = hashlib.sha1(f"{user_id}|{app_id}|{node_id}".encode()).hexdigest()
|
||||||
|
|
||||||
|
try:
|
||||||
|
event = redis_client.eval(
|
||||||
|
cls.LUA_SELECT,
|
||||||
|
2,
|
||||||
|
cls.address(tenant_id, user_id, app_id, node_id),
|
||||||
|
cls.waiting_pool(tenant_id, subscription_id, trigger_name),
|
||||||
|
address_id,
|
||||||
|
)
|
||||||
|
return TriggerDebugEvent.model_validate_json(event) if event else None
|
||||||
|
except RedisError:
|
||||||
|
logger.exception("Failed to poll debug event")
|
||||||
|
return None
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user