This commit is contained in:
jyong
2025-05-15 15:14:52 +08:00
parent 3f1363503b
commit 818eb46a8b
21 changed files with 2117 additions and 149 deletions

View File

@@ -0,0 +1,170 @@
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
import services
from controllers.console import api
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class CreateRagPipelineDatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
nullable=True,
required=False,
default={},
)
parser.add_argument(
"permission",
type=str,
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
nullable=True,
required=False,
default=DatasetPermissionEnum.ONLY_ME,
)
parser.add_argument(
"partial_member_list",
type=list,
nullable=True,
required=False,
default=[],
)
parser.add_argument(
"yaml_content",
type=str,
nullable=False,
required=True,
help="yaml_content is required.",
)
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
try:
import_info = DatasetService.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list(
current_user.current_tenant_id,
import_info["dataset_id"],
rag_pipeline_dataset_create_entity.partial_member_list,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return import_info, 201
class CreateEmptyRagPipelineDatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
nullable=True,
required=False,
default={},
)
parser.add_argument(
"permission",
type=str,
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
nullable=True,
required=False,
default=DatasetPermissionEnum.ONLY_ME,
)
parser.add_argument(
"partial_member_list",
type=list,
nullable=True,
required=False,
default=[],
)
args = parser.parse_args()
dataset = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=args,
)
return marshal(dataset, dataset_detail_fields), 201
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
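For reference, a minimal client-side sketch of calling the create endpoint registered above. The console API prefix, the bearer-token auth header, and the permission string value are illustrative assumptions and are not part of this diff.

import requests  # hypothetical client; any HTTP client works

CONSOLE_API = "http://localhost:5001/console/api"  # assumed console API prefix
HEADERS = {"Authorization": "Bearer <console-access-token>"}  # assumed auth scheme

resp = requests.post(
    f"{CONSOLE_API}/rag/pipeline/dataset",
    headers=HEADERS,
    json={
        "name": "my-knowledge-base",  # 1-40 characters, enforced by _validate_name
        "description": "demo dataset",
        "permission": "only_me",  # assumed string value of DatasetPermissionEnum.ONLY_ME
        "yaml_content": "<pipeline DSL yaml>",
    },
)
print(resp.status_code, resp.json())  # 201 with the import info on success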

View File

@@ -0,0 +1,147 @@
from typing import cast
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required
from models import Account
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(pipeline_import_fields)
def post(self):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("pipeline_id", type=str, location="json")
args = parser.parse_args()
# Create service with session
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Import rag pipeline
account = cast(Account, current_user)
result = import_service.import_rag_pipeline(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
pipeline_id=args.get("pipeline_id"),
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
class RagPipelineImportConfirmApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(pipeline_import_fields)
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Confirm import
account = cast(Account, current_user)
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
class RagPipelineImportCheckDependenciesApi(Resource):
@setup_required
@login_required
@get_rag_pipeline
@account_initialization_required
@marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline)
return result.model_dump(mode="json"), 200
class RagPipelineExportApi(Resource):
@setup_required
@login_required
@get_rag_pipeline
@account_initialization_required
@marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
raise Forbidden()
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=bool, default=False, location="args")
args = parser.parse_args()
with Session(db.engine) as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
return {"data": result}, 200
# Import Rag Pipeline
api.add_resource(
RagPipelineImportApi,
"/rag/pipelines/imports",
)
api.add_resource(
RagPipelineImportConfirmApi,
"/rag/pipelines/imports/<string:import_id>/confirm",
)
api.add_resource(
RagPipelineImportCheckDependenciesApi,
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
)
api.add_resource(
RagPipelineExportApi,
"/rag/pipelines/<string:pipeline_id>/exports",
)
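A hedged sketch of the two-step flow these resources expose: POST the DSL, then confirm if the service answers with a pending status. The URL prefix and auth header are assumptions for illustration.

import requests

CONSOLE_API = "http://localhost:5001/console/api"  # assumed prefix
HEADERS = {"Authorization": "Bearer <console-access-token>"}  # assumed auth

resp = requests.post(
    f"{CONSOLE_API}/rag/pipelines/imports",
    headers=HEADERS,
    json={"mode": "yaml-content", "yaml_content": "<pipeline DSL yaml>"},
)
result = resp.json()
if resp.status_code == 202:  # ImportStatus.PENDING, e.g. a newer DSL version
    resp = requests.post(
        f"{CONSOLE_API}/rag/pipelines/imports/{result['id']}/confirm",
        headers=HEADERS,
    )
    result = resp.json()
print(result["status"], result.get("pipeline_id"))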

View File

@@ -4,6 +4,7 @@ from typing import cast
from flask import abort, request
from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -23,12 +24,18 @@ from controllers.console.wraps import (
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from factories import variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from fields.workflow_run_fields import (
workflow_run_detail_fields,
workflow_run_node_execution_fields,
workflow_run_node_execution_list_fields,
workflow_run_pagination_fields,
)
from libs import helper
from libs.helper import TimestampField
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models.account import Account
from models.dataset import Pipeline
@@ -36,6 +43,7 @@ from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
logger = logging.getLogger(__name__)
@@ -461,45 +469,6 @@ class DefaultRagPipelineBlockConfigApi(Resource):
rag_pipeline_service = RagPipelineService()
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
class ConvertToRagPipelineApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
def post(self, pipeline: Pipeline):
"""
Convert basic mode of chatbot app to workflow mode
Convert expert mode of chatbot app to workflow mode
Convert Completion App to Workflow App
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
if request.data:
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon", type=str, required=False, nullable=True, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
else:
args = {}
# convert to workflow mode
rag_pipeline_service = RagPipelineService()
new_app_model = rag_pipeline_service.convert_to_workflow(pipeline=pipeline, account=current_user, args=args)
# return app id
return {
"new_app_id": new_app_model.id,
}
class RagPipelineConfigApi(Resource):
"""Resource for rag pipeline configuration."""
@@ -674,6 +643,85 @@ class RagPipelineSecondStepApi(Resource):
)
class RagPipelineWorkflowRunListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_pagination_fields)
def get(self, pipeline: Pipeline):
"""
Get workflow run list
"""
parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args()
rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
return result
class RagPipelineWorkflowRunDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_detail_fields)
def get(self, pipeline: Pipeline, run_id):
"""
Get workflow run detail
"""
run_id = str(run_id)
rag_pipeline_service = RagPipelineService()
workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id)
return workflow_run
class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_list_fields)
def get(self, pipeline: Pipeline, run_id):
"""
Get workflow run node execution list
"""
run_id = str(run_id)
rag_pipeline_service = RagPipelineService()
node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions(
pipeline=pipeline,
run_id=run_id,
)
return {"data": node_executions}
class DatasourceListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user = current_user
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
provider.to_dict()
for provider in BuiltinToolManageService.list_rag_pipeline_datasources(
tenant_id,
)
]
)
api.add_resource(
DraftRagPipelineApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft",
@@ -694,10 +742,10 @@ api.add_resource(
RagPipelineDraftNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource(
RagPipelinePublishedNodeRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/nodes/<string:node_id>/run",
)
# api.add_resource(
# RagPipelinePublishedNodeRunApi,
# "/rag/pipelines/<uuid:pipeline_id>/workflows/published/nodes/<string:node_id>/run",
# )
api.add_resource(
RagPipelineDraftRunIterationNodeApi,
@@ -724,11 +772,24 @@ api.add_resource(
DefaultRagPipelineBlockConfigApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>",
)
api.add_resource(
ConvertToRagPipelineApi,
"/rag/pipelines/<uuid:pipeline_id>/convert-to-workflow",
)
api.add_resource(
RagPipelineByIdApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>",
)
api.add_resource(
RagPipelineWorkflowRunListApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs",
)
api.add_resource(
RagPipelineWorkflowRunDetailApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>",
)
api.add_resource(
RagPipelineWorkflowRunNodeExecutionListApi,
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions",
)
api.add_resource(
DatasourceListApi,
"/rag/pipelines/datasources",
)
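The workflow-run list endpoint above paginates with a last_id keyset cursor rather than page numbers. A minimal sketch of paging through it from a client, with the prefix and auth header assumed:

import requests

CONSOLE_API = "http://localhost:5001/console/api"  # assumed prefix
HEADERS = {"Authorization": "Bearer <console-access-token>"}  # assumed auth
pipeline_id = "<pipeline uuid>"

params = {"limit": 20}
while True:
    page = requests.get(
        f"{CONSOLE_API}/rag/pipelines/{pipeline_id}/workflow-runs",
        headers=HEADERS,
        params=params,
    ).json()
    for run in page["data"]:
        print(run["id"])
    if not page.get("has_more"):
        break
    params["last_id"] = page["data"][-1]["id"]  # keyset cursor consumed as last_id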

View File

@@ -2,13 +2,13 @@ from collections.abc import Generator
from typing import Any, Optional
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.datasource_manager import DatasourceManager
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceInvokeMessage,
DatasourceParameter,
DatasourceProviderType,
)
from core.plugin.manager.datasource import PluginDatasourceManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format
@@ -44,7 +44,7 @@ class DatasourcePlugin:
datasource_parameters: dict[str, Any],
rag_pipeline_id: Optional[str] = None,
) -> Generator[DatasourceInvokeMessage, None, None]:
manager = PluginDatasourceManager()
manager = DatasourceManager()
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)
@@ -64,7 +64,7 @@ class DatasourcePlugin:
datasource_parameters: dict[str, Any],
rag_pipeline_id: Optional[str] = None,
) -> Generator[DatasourceInvokeMessage, None, None]:
manager = PluginDatasourceManager()
manager = DatasourceManager()
datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters)

View File

@@ -4,12 +4,11 @@ from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.entities.provider_entities import ProviderConfig
from core.plugin.manager.tool import PluginToolManager
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.plugin.impl.tool import PluginToolManager
from core.tools.errors import ToolProviderCredentialValidationError
class DatasourcePluginProviderController(BuiltinToolProviderController):
class DatasourcePluginProviderController:
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
plugin_id: str
@@ -32,12 +31,21 @@ class DatasourcePluginProviderController(BuiltinToolProviderController):
"""
return DatasourceProviderType.RAG_PIPELINE
@property
def need_credentials(self) -> bool:
"""
returns whether the provider needs credentials
:return: whether the provider needs credentials
"""
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
"""
manager = PluginToolManager()
if not manager.validate_provider_credentials(
if not manager.validate_datasource_credentials(
tenant_id=self.tenant_id,
user_id=user_id,
provider=self.entity.identity.name,
@@ -69,7 +77,7 @@ class DatasourcePluginProviderController(BuiltinToolProviderController):
plugin_unique_identifier=self.plugin_unique_identifier,
)
def get_datasources(self) -> list[DatasourceTool]: # type: ignore
def get_datasources(self) -> list[DatasourcePlugin]: # type: ignore
"""
get all datasources
"""

View File

@@ -0,0 +1,73 @@
from typing import Literal, Optional
from pydantic import BaseModel, Field, field_validator
from core.datasource.entities.datasource_entities import DatasourceParameter
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
class DatasourceApiEntity(BaseModel):
author: str
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: Optional[list[ToolParameter]] = None
labels: list[str] = Field(default_factory=list)
output_schema: Optional[dict] = None
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
class DatasourceProviderApiEntity(BaseModel):
id: str
author: str
name: str # identifier
description: I18nObject
icon: str | dict
label: I18nObject # label
type: ToolProviderType
masked_credentials: Optional[dict] = None
original_credentials: Optional[dict] = None
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource provider")
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource provider")
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list)
@field_validator("datasources", mode="before")
@classmethod
def convert_none_to_empty_list(cls, v):
return v if v is not None else []
def to_dict(self) -> dict:
# -------------
# overwrite datasource parameter types for temp fix
datasources = jsonable_encoder(self.datasources)
for datasource in datasources:
if datasource.get("parameters"):
for parameter in datasource.get("parameters"):
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
parameter["type"] = "files"
# -------------
return {
"id": self.id,
"author": self.author,
"name": self.name,
"plugin_id": self.plugin_id,
"plugin_unique_identifier": self.plugin_unique_identifier,
"description": self.description.to_dict(),
"icon": self.icon,
"label": self.label.to_dict(),
"type": self.type.value,
"team_credentials": self.masked_credentials,
"is_team_authorization": self.is_team_authorization,
"allow_delete": self.allow_delete,
"datasources": datasources,
"labels": self.labels,
}

View File

@@ -5,6 +5,7 @@ from typing import Generic, Optional, TypeVar
from pydantic import BaseModel, ConfigDict, Field
from core.agent.plugin_entities import AgentProviderEntityWithPlugin
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.plugin.entities.base import BasePluginEntity
@@ -46,6 +47,13 @@ class PluginToolProviderEntity(BaseModel):
declaration: ToolProviderEntityWithPlugin
class PluginDatasourceProviderEntity(BaseModel):
provider: str
plugin_unique_identifier: str
plugin_id: str
declaration: DatasourceProviderEntityWithPlugin
class PluginAgentProviderEntity(BaseModel):
provider: str
plugin_unique_identifier: str

View File

@@ -4,7 +4,11 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
PluginDatasourceProviderEntity,
PluginToolProviderEntity,
)
from core.plugin.impl.base import BasePluginClient
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@@ -41,6 +45,37 @@ class PluginToolManager(BasePluginClient):
return response
def fetch_datasources(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
"""
Fetch datasources for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict:
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name")
for tool in declaration.get("tools", []):
tool["identity"]["provider"] = provider_name
return json_response
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/datasources",
list[PluginToolProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
)
for provider in response:
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for tool in provider.declaration.tools:
tool.identity.provider = provider.declaration.identity.name
return response
def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
"""
Fetch tool provider for the given tenant and plugin.
@@ -197,6 +232,36 @@ class PluginToolManager(BasePluginClient):
return False
def validate_datasource_credentials(
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
) -> bool:
"""
validate the credentials of the datasource
"""
tool_provider_id = GenericProviderID(provider)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
PluginBasicBooleanResponse,
data={
"user_id": user_id,
"data": {
"provider": tool_provider_id.provider_name,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": tool_provider_id.plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp.result
return False
def get_runtime_parameters(
self,
tenant_id: str,

View File

@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast
from yarl import URL
import contexts
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
@@ -495,6 +496,31 @@ class ToolManager:
# get plugin providers
yield from cls.list_plugin_providers(tenant_id)
@classmethod
def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]:
"""
list all the datasource providers
"""
manager = PluginToolManager()
provider_entities = manager.fetch_datasources(tenant_id)
return [
DatasourcePluginProviderController(
entity=provider.declaration,
plugin_id=provider.plugin_id,
plugin_unique_identifier=provider.plugin_unique_identifier,
tenant_id=tenant_id,
)
for provider in provider_entities
]
@classmethod
def list_builtin_datasources(cls, tenant_id: str) -> Generator[DatasourcePluginProviderController, None, None]:
"""
list all the builtin datasources
"""
# get builtin datasources
yield from cls.list_datasource_providers(tenant_id)
@classmethod
def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
"""

View File

@@ -10,8 +10,8 @@ class RerankingModelConfig(BaseModel):
Reranking Model Config.
"""
provider: str
model: str
reranking_provider_name: str
reranking_model_name: str
class VectorSetting(BaseModel):

View File

@@ -56,6 +56,8 @@ external_knowledge_info_fields = {
doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
icon_info_fields = {"icon_type": fields.String, "icon": fields.String, "icon_background": fields.String}
dataset_detail_fields = {
"id": fields.String,
"name": fields.String,
@@ -81,6 +83,10 @@ dataset_detail_fields = {
"external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
"doc_metadata": fields.List(fields.Nested(doc_metadata_fields)),
"built_in_field_enabled": fields.Boolean,
"pipeline_id": fields.String,
"runtime_mode": fields.String,
"chunk_structure": fields.String,
"icon_info": fields.Nested(icon_info_fields),
}
dataset_query_detail_fields = {

View File

@@ -0,0 +1,163 @@
from flask_restful import fields # type: ignore
from fields.workflow_fields import workflow_partial_fields
from libs.helper import AppIconUrlField, TimestampField
pipeline_detail_kernel_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
}
related_app_list = {
"data": fields.List(fields.Nested(pipeline_detail_kernel_fields)),
"total": fields.Integer,
}
app_detail_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon": fields.String,
"icon_background": fields.String,
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
"tracing": fields.Raw,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
}
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
app_partial_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String(attribute="desc_or_prompt"),
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"tags": fields.List(fields.Nested(tag_fields)),
}
app_pagination_fields = {
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(app_partial_fields), attribute="items"),
}
template_fields = {
"name": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"description": fields.String,
"mode": fields.String,
}
template_list_fields = {
"data": fields.List(fields.Nested(template_fields)),
}
site_fields = {
"access_token": fields.String(attribute="code"),
"code": fields.String,
"title": fields.String,
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"description": fields.String,
"default_language": fields.String,
"chat_color_theme": fields.String,
"chat_color_theme_inverted": fields.Boolean,
"customize_domain": fields.String,
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"customize_token_strategy": fields.String,
"prompt_public": fields.Boolean,
"app_base_url": fields.String,
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
}
deleted_tool_fields = {
"type": fields.String,
"tool_name": fields.String,
"provider_id": fields.String,
}
app_detail_fields_with_site = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"enable_site": fields.Boolean,
"enable_api": fields.Boolean,
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
"site": fields.Nested(site_fields),
"api_base_url": fields.String,
"use_icon_as_answer_icon": fields.Boolean,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
}
app_site_fields = {
"app_id": fields.String,
"access_token": fields.String(attribute="code"),
"code": fields.String,
"title": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"description": fields.String,
"default_language": fields.String,
"customize_domain": fields.String,
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"customize_token_strategy": fields.String,
"prompt_public": fields.Boolean,
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
}
leaked_dependency_fields = {"type": fields.String, "value": fields.Raw, "current_identifier": fields.String}
pipeline_import_fields = {
"id": fields.String,
"status": fields.String,
"pipeline_id": fields.String,
"current_dsl_version": fields.String,
"imported_dsl_version": fields.String,
"error": fields.String,
}
pipeline_import_check_dependencies_fields = {
"leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)),
}

View File

@@ -63,6 +63,10 @@ class Dataset(db.Model): # type: ignore[name-defined]
collection_binding_id = db.Column(StringUUID, nullable=True)
retrieval_model = db.Column(JSONB, nullable=True)
built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
icon_info = db.Column(JSONB, nullable=True)
runtime_mode = db.Column(db.String(255), nullable=True, server_default=db.text("'general'::character varying"))
pipeline_id = db.Column(StringUUID, nullable=True)
chunk_structure = db.Column(db.String(255), nullable=True)
@property
def dataset_keyword_table(self):
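The four new Dataset columns above imply a schema migration; a hypothetical Alembic sketch follows. The "datasets" table name, the revision identifiers, and the use of postgresql.UUID in place of the project's StringUUID type are assumptions, and the commit's real migration file is not shown in this excerpt.

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

revision = "add_rag_pipeline_dataset_columns"  # placeholder identifiers
down_revision = None


def upgrade():
    op.add_column("datasets", sa.Column("icon_info", postgresql.JSONB(), nullable=True))
    op.add_column(
        "datasets",
        sa.Column(
            "runtime_mode",
            sa.String(length=255),
            nullable=True,
            server_default=sa.text("'general'::character varying"),
        ),
    )
    op.add_column("datasets", sa.Column("pipeline_id", postgresql.UUID(), nullable=True))
    op.add_column("datasets", sa.Column("chunk_structure", sa.String(length=255), nullable=True))


def downgrade():
    for column in ("chunk_structure", "pipeline_id", "runtime_mode", "icon_info"):
        op.drop_column("datasets", column)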

View File

@@ -51,6 +51,40 @@ class BuiltinToolProvider(Base):
return cast(dict, json.loads(self.encrypted_credentials))
class BuiltinDatasourceProvider(Base):
"""
This table stores the datasource provider information for built-in datasources for each tenant.
"""
__tablename__ = "tool_builtin_datasource_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_builtin_datasource_provider_pkey"),
# one tenant can only have one datasource provider with the same name
db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_datasource_provider"),
)
# id of the datasource provider
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# id of the tenant
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
# who created this datasource provider
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# name of the datasource provider
provider: Mapped[str] = mapped_column(db.String(256), nullable=False)
# credentials of the datasource provider
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
@property
def credentials(self) -> dict:
return cast(dict, json.loads(self.encrypted_credentials))
class ApiToolProvider(Base):
"""
The table stores the api providers.

View File

@@ -52,6 +52,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
RetrievalModel,
SegmentUpdateArgs,
)
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
from services.errors.account import InvalidActionError, NoPermissionError
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.dataset import DatasetNameDuplicateError
@@ -59,6 +60,7 @@ from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
from services.rag_pipeline.rag_pipeline_dsl_service import ImportMode, RagPipelineDslService, RagPipelineImportInfo
from services.tag_service import TagService
from services.vector_service import VectorService
from tasks.batch_clean_document_task import batch_clean_document_task
@@ -235,6 +237,63 @@ class DatasetService:
db.session.commit()
return dataset
@staticmethod
def create_empty_rag_pipeline_dataset(
tenant_id: str,
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
):
# check if dataset name already exists
if Dataset.query.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
dataset = Dataset(
tenant_id=tenant_id,
name=rag_pipeline_dataset_create_entity.name,
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag_pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info,
)
db.session.add(dataset)
db.session.commit()
return dataset
@staticmethod
def create_rag_pipeline_dataset(
tenant_id: str,
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
):
# check if dataset name already exists
if Dataset.query.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
dataset = Dataset(
tenant_id=tenant_id,
name=rag_pipeline_dataset_create_entity.name,
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag_pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info,
)
if rag_pipeline_dataset_create_entity.yaml_content:
rag_pipeline_import_info: RagPipelineImportInfo = RagPipelineDslService.import_rag_pipeline(
current_user, ImportMode.YAML_CONTENT, rag_pipeline_dataset_create_entity.yaml_content, dataset
)
return {
"id": rag_pipeline_import_info.id,
"dataset_id": dataset.id,
"pipeline_id": rag_pipeline_import_info.pipeline_id,
"status": rag_pipeline_import_info.status,
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
"current_dsl_version": rag_pipeline_import_info.current_dsl_version,
"error": rag_pipeline_import_info.error,
}
@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first()

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Literal, Optional
from pydantic import BaseModel
@@ -14,3 +14,100 @@ class PipelineTemplateInfoEntity(BaseModel):
name: str
description: str
icon_info: IconInfo
class RagPipelineDatasetCreateEntity(BaseModel):
name: str
description: str
icon_info: IconInfo
permission: str
partial_member_list: list[str]
yaml_content: str
class RerankingModelConfig(BaseModel):
"""
Reranking Model Config.
"""
reranking_provider_name: str
reranking_model_name: str
class VectorSetting(BaseModel):
"""
Vector Setting.
"""
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""
keyword_weight: float
class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
"""
vector_setting: VectorSetting
keyword_setting: KeywordSetting
class EmbeddingSetting(BaseModel):
"""
Embedding Setting.
"""
embedding_provider_name: str
embedding_model_name: str
class EconomySetting(BaseModel):
"""
Economy Setting.
"""
keyword_number: int
class RetrievalSetting(BaseModel):
"""
Retrieval Setting.
"""
search_method: Literal["semantic_search", "keyword_search", "hybrid_search"]
top_k: int
score_threshold: Optional[float] = 0.5
score_threshold_enabled: bool = False
reranking_mode: str = "reranking_model"
reranking_enable: bool = True
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None
class IndexMethod(BaseModel):
"""
Knowledge Index Setting.
"""
indexing_technique: Literal["high_quality", "economy"]
embedding_setting: EmbeddingSetting
economy_setting: EconomySetting
class KnowledgeConfiguration(BaseModel):
"""
Knowledge Configuration.
"""
chunk_structure: str
index_method: IndexMethod
retrieval_setting: RetrievalSetting
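A minimal sketch of validating a knowledge_index node's configuration with these models, mirroring what the DSL import service does; the concrete field values (provider names, chunk_structure) are illustrative assumptions.

from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration

node_config = {
    "chunk_structure": "text_model",  # illustrative value
    "index_method": {
        "indexing_technique": "high_quality",
        "embedding_setting": {
            "embedding_provider_name": "openai",
            "embedding_model_name": "text-embedding-3-small",
        },
        "economy_setting": {"keyword_number": 10},
    },
    "retrieval_setting": {"search_method": "semantic_search", "top_k": 3},
}

knowledge_configuration = KnowledgeConfiguration(**node_config)
print(knowledge_configuration.retrieval_setting.score_threshold)  # 0.5 by default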

View File

@@ -1,28 +1,45 @@
import json
import threading
import time
from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime
from typing import Any, Literal, Optional
from uuid import uuid4
from flask_login import current_user
from sqlalchemy import select
from sqlalchemy.orm import Session
import contexts
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repository.repository_factory import RepositoryFactory
from core.repository.workflow_node_execution_repository import OrderConfig
from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account
from models.dataset import Pipeline, PipelineBuiltInTemplate, PipelineCustomizedTemplate # type: ignore
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowType
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
WorkflowType,
)
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.errors.app import WorkflowHashNotEqualError
from services.errors.workflow_service import DraftWorkflowDeletionError
from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
@@ -180,7 +197,6 @@ class RagPipelineService:
*,
pipeline: Pipeline,
graph: dict,
features: dict,
unique_hash: Optional[str],
account: Account,
environment_variables: Sequence[Variable],
@@ -197,9 +213,6 @@ class RagPipelineService:
if workflow and workflow.unique_hash != unique_hash:
raise WorkflowHashNotEqualError()
# validate features structure
self.validate_features_structure(pipeline=pipeline, features=features)
# create draft workflow if not found
if not workflow:
workflow = Workflow(
@@ -208,7 +221,6 @@ class RagPipelineService:
type=WorkflowType.RAG_PIPELINE.value,
version="draft",
graph=json.dumps(graph),
features=json.dumps(features),
created_by=account.id,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
@@ -218,7 +230,6 @@ class RagPipelineService:
# update draft workflow if found
else:
workflow.graph = json.dumps(graph)
workflow.features = json.dumps(features)
workflow.updated_by = account.id
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow.environment_variables = environment_variables
@@ -227,8 +238,8 @@ class RagPipelineService:
# commit db session changes
db.session.commit()
# trigger app workflow events
app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)
# trigger workflow events TODO
# app_draft_workflow_was_synced.send(pipeline, synced_draft_workflow=workflow)
# return draft workflow
return workflow
@@ -269,8 +280,8 @@ class RagPipelineService:
# commit db session changes
session.add(workflow)
# trigger app workflow events
app_published_workflow_was_updated.send(pipeline, published_workflow=workflow)
# trigger app workflow events TODO
# app_published_workflow_was_updated.send(pipeline, published_workflow=workflow)
# return new workflow
return workflow
@@ -508,46 +519,6 @@ class RagPipelineService:
return workflow_node_execution
def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
"""
Basic mode of chatbot app(expert mode) to workflow
Completion App to Workflow App
:param app_model: App instance
:param account: Account instance
:param args: dict
:return:
"""
# chatbot convert to workflow mode
workflow_converter = WorkflowConverter()
if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}:
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
# convert to workflow
new_app: App = workflow_converter.convert_to_workflow(
app_model=app_model,
account=account,
name=args.get("name", "Default Name"),
icon_type=args.get("icon_type", "emoji"),
icon=args.get("icon", "🤖"),
icon_background=args.get("icon_background", "#FFEAD5"),
)
return new_app
def validate_features_structure(self, app_model: App, features: dict) -> dict:
if app_model.mode == AppMode.ADVANCED_CHAT.value:
return AdvancedChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
elif app_model.mode == AppMode.WORKFLOW.value:
return WorkflowAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
else:
raise ValueError(f"Invalid app mode: {app_model.mode}")
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
) -> Optional[Workflow]:
@@ -578,38 +549,6 @@ class RagPipelineService:
return workflow
def delete_workflow(self, *, session: Session, workflow_id: str, tenant_id: str) -> bool:
"""
Delete a workflow
:param session: SQLAlchemy database session
:param workflow_id: Workflow ID
:param tenant_id: Tenant ID
:return: True if successful
:raises: ValueError if workflow not found
:raises: WorkflowInUseError if workflow is in use
:raises: DraftWorkflowDeletionError if workflow is a draft version
"""
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
workflow = session.scalar(stmt)
if not workflow:
raise ValueError(f"Workflow with ID {workflow_id} not found")
# Check if workflow is a draft version
if workflow.version == "draft":
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
# Check if this workflow is currently referenced by an app
stmt = select(App).where(App.workflow_id == workflow_id)
app = session.scalar(stmt)
if app:
# Cannot delete a workflow that's currently in use by an app
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
session.delete(workflow)
return True
def get_second_step_parameters(self, pipeline: Pipeline, datasource_provider: str) -> dict:
"""
Get second step parameters of rag pipeline
@@ -627,3 +566,101 @@ class RagPipelineService:
datasource_provider_variables = pipeline_variables.get(datasource_provider, [])
shared_variables = pipeline_variables.get("shared", [])
return datasource_provider_variables + shared_variables
def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
"""
Get debug workflow run list
Only return triggered_from == debugging
:param pipeline: pipeline model
:param args: request args
"""
limit = int(args.get("limit", 20))
base_query = db.session.query(WorkflowRun).filter(
WorkflowRun.tenant_id == pipeline.tenant_id,
WorkflowRun.app_id == pipeline.id,
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
)
if args.get("last_id"):
last_workflow_run = base_query.filter(
WorkflowRun.id == args.get("last_id"),
).first()
if not last_workflow_run:
raise ValueError("Last workflow run not exists")
workflow_runs = (
base_query.filter(
WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
)
.order_by(WorkflowRun.created_at.desc())
.limit(limit)
.all()
)
else:
workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
has_more = False
if len(workflow_runs) == limit:
current_page_first_workflow_run = workflow_runs[-1]
rest_count = base_query.filter(
WorkflowRun.created_at < current_page_first_workflow_run.created_at,
WorkflowRun.id != current_page_first_workflow_run.id,
).count()
if rest_count > 0:
has_more = True
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
"""
Get workflow run detail
:param pipeline: pipeline model
:param run_id: workflow run id
"""
workflow_run = (
db.session.query(WorkflowRun)
.filter(
WorkflowRun.tenant_id == pipeline.tenant_id,
WorkflowRun.app_id == pipeline.id,
WorkflowRun.id == run_id,
)
.first()
)
return workflow_run
def get_rag_pipeline_workflow_run_node_executions(
self,
pipeline: Pipeline,
run_id: str,
) -> list[WorkflowNodeExecution]:
"""
Get workflow run node execution list
"""
workflow_run = self.get_rag_pipeline_workflow_run(pipeline, run_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
if not workflow_run:
return []
# Use the repository to get the node executions
repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": pipeline.tenant_id,
"app_id": pipeline.id,
"session_factory": db.session.get_bind(),
}
)
# Use the repository to get the node executions with ordering
order_config = OrderConfig(order_by=["index"], order_direction="desc")
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
return list(node_executions)
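A hedged usage sketch of the pagination helpers added above, run inside a Flask application context; the pipeline lookup is illustrative only.

from extensions.ext_database import db
from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService

pipeline = db.session.query(Pipeline).first()  # illustrative lookup, requires app context
service = RagPipelineService()

page = service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args={"limit": 20})
for run in page.data:
    print(run.id, run.created_at)

if page.has_more:
    # pass the last id back as the keyset cursor to fetch the next page
    service.get_rag_pipeline_paginate_workflow_runs(
        pipeline=pipeline, args={"limit": 20, "last_id": page.data[-1].id}
    )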

View File

@@ -0,0 +1,841 @@
import base64
import hashlib
import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Optional
from urllib.parse import urlparse
from uuid import uuid4
import yaml # type: ignore
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from packaging import version
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.helper import ssrf_proxy
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import PluginDependency
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.tool.entities import ToolNodeData
from extensions.ext_redis import redis_client
from factories import variable_factory
from models import Account
from models.dataset import Dataset, Pipeline
from models.workflow import Workflow
from services.dataset_service import DatasetCollectionBindingService
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.1.0"
class ImportMode(StrEnum):
YAML_CONTENT = "yaml-content"
YAML_URL = "yaml-url"
class ImportStatus(StrEnum):
COMPLETED = "completed"
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
PENDING = "pending"
FAILED = "failed"
class RagPipelineImportInfo(BaseModel):
id: str
status: ImportStatus
pipeline_id: Optional[str] = None
current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = ""
error: str = ""
dataset_id: Optional[str] = None
class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
def _check_version_compatibility(imported_version: str) -> ImportStatus:
"""Determine import status based on version comparison"""
try:
current_ver = version.parse(CURRENT_DSL_VERSION)
imported_ver = version.parse(imported_version)
except version.InvalidVersion:
return ImportStatus.FAILED
# If imported version is newer than current, always return PENDING
if imported_ver > current_ver:
return ImportStatus.PENDING
# If imported version is older than current's major, return PENDING
if imported_ver.major < current_ver.major:
return ImportStatus.PENDING
# If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS
if imported_ver.minor < current_ver.minor:
return ImportStatus.COMPLETED_WITH_WARNINGS
# Otherwise (same major and minor), return COMPLETED
return ImportStatus.COMPLETED
class RagPipelinePendingData(BaseModel):
import_mode: str
yaml_content: str
name: str | None
description: str | None
icon_type: str | None
icon: str | None
icon_background: str | None
pipeline_id: str | None
class CheckDependenciesPendingData(BaseModel):
dependencies: list[PluginDependency]
pipeline_id: str | None
class RagPipelineDslService:
def __init__(self, session: Session):
self._session = session
def import_rag_pipeline(
self,
*,
account: Account,
import_mode: str,
yaml_content: Optional[str] = None,
yaml_url: Optional[str] = None,
pipeline_id: Optional[str] = None,
dataset: Optional[Dataset] = None,
) -> RagPipelineImportInfo:
"""Import an app from YAML content or URL."""
import_id = str(uuid.uuid4())
# Validate import mode
try:
mode = ImportMode(import_mode)
except ValueError:
raise ValueError(f"Invalid import_mode: {import_mode}")
# Get YAML content
content: str = ""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="yaml_url is required when import_mode is yaml-url",
)
try:
parsed_url = urlparse(yaml_url)
if (
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status()
content = response.content.decode()
if len(content) > DSL_MAX_SIZE:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="File size exceeds the limit of 10MB",
)
if not content:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Empty content from url",
)
except Exception as e:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Error fetching YAML from URL: {str(e)}",
)
elif mode == ImportMode.YAML_CONTENT:
if not yaml_content:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="yaml_content is required when import_mode is yaml-content",
)
content = yaml_content
# Process YAML content
try:
# Parse YAML to validate format
data = yaml.safe_load(content)
if not isinstance(data, dict):
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid YAML format: content must be a mapping",
)
# Validate and fix DSL version
if not data.get("version"):
data["version"] = "0.1.0"
if not data.get("kind") or data.get("kind") != "rag-pipeline":
data["kind"] = "rag-pipeline"
imported_version = data.get("version", "0.1.0")
# ensure imported_version is a string (YAML may parse bare versions like 0.1 as floats)
if not isinstance(imported_version, str):
raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}")
status = _check_version_compatibility(imported_version)
# Extract pipeline data
pipeline_data = data.get("pipeline")
if not pipeline_data:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Missing pipeline data in YAML content",
)
# If pipeline_id is provided, check if it exists
pipeline = None
if pipeline_id:
stmt = select(Pipeline).where(
Pipeline.id == pipeline_id,
Pipeline.tenant_id == account.current_tenant_id,
)
pipeline = self._session.scalar(stmt)
if not pipeline:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Pipeline not found",
)
# If major version mismatch, store import info in Redis
if status == ImportStatus.PENDING:
pending_data = RagPipelinePendingData(
import_mode=import_mode,
yaml_content=content,
pipeline_id=pipeline_id,
)
redis_client.setex(
f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}",
IMPORT_INFO_REDIS_EXPIRY,
pending_data.model_dump_json(),
)
return RagPipelineImportInfo(
id=import_id,
status=status,
pipeline_id=pipeline_id,
imported_dsl_version=imported_version,
)
# Extract dependencies
dependencies = data.get("dependencies", [])
check_dependencies_pending_data = None
if dependencies:
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
# Create or update pipeline
pipeline = self._create_or_update_pipeline(
pipeline=pipeline,
data=data,
account=account,
dependencies=check_dependencies_pending_data,
)
# create dataset
name = pipeline.name
description = pipeline.description
icon_type = data.get("rag_pipeline", {}).get("icon_type")
icon = data.get("rag_pipeline", {}).get("icon")
icon_background = data.get("rag_pipeline", {}).get("icon_background")
icon_url = data.get("rag_pipeline", {}).get("icon_url")
workflow = data.get("workflow", {})
graph = workflow.get("graph", {})
nodes = graph.get("nodes", [])
dataset_id = None
for node in nodes:
if node.get("data", {}).get("type") == "knowledge_index":
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
if not dataset:
dataset = Dataset(
tenant_id=account.current_tenant_id,
name=name,
description=description,
icon_info={
"type": icon_type,
"icon": icon,
"background": icon_background,
"url": icon_url,
},
indexing_technique=knowledge_configuration.index_method.indexing_technique,
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_setting.model_dump(),
runtime_mode="rag_pipeline",
chunk_structure=knowledge_configuration.chunk_structure,
)
else:
dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique
dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
dataset.runtime_mode = "rag_pipeline"
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.index_method.indexing_technique == "high_quality":
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore
knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore
)
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = (
knowledge_configuration.index_method.embedding_setting.embedding_model_name
)
dataset.embedding_model_provider = (
knowledge_configuration.index_method.embedding_setting.embedding_provider_name
)
elif knowledge_configuration.index_method.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
dataset.pipeline_id = pipeline.id
self._session.add(dataset)
self._session.commit()
dataset_id = dataset.id
if not dataset_id:
raise ValueError("DSL is not valid, please check the Knowledge Index node.")
return RagPipelineImportInfo(
id=import_id,
status=status,
pipeline_id=pipeline.id,
dataset_id=dataset_id,
imported_dsl_version=imported_version,
)
except yaml.YAMLError as e:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Invalid YAML format: {str(e)}",
)
except Exception as e:
logger.exception("Failed to import app")
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)
def confirm_import(self, *, import_id: str, account: Account) -> RagPipelineImportInfo:
"""
        Confirm a pending import whose data was parked in Redis for user confirmation.
"""
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
pending_data = redis_client.get(redis_key)
if not pending_data:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Import information expired or does not exist",
)
try:
if not isinstance(pending_data, str | bytes):
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid import information",
)
pending_data = RagPipelinePendingData.model_validate_json(pending_data)
data = yaml.safe_load(pending_data.yaml_content)
pipeline = None
if pending_data.pipeline_id:
stmt = select(Pipeline).where(
Pipeline.id == pending_data.pipeline_id,
Pipeline.tenant_id == account.current_tenant_id,
)
pipeline = self._session.scalar(stmt)
            # Create or update pipeline
pipeline = self._create_or_update_pipeline(
pipeline=pipeline,
data=data,
account=account,
)
# create dataset
name = pipeline.name
description = pipeline.description
icon_type = data.get("rag_pipeline", {}).get("icon_type")
icon = data.get("rag_pipeline", {}).get("icon")
icon_background = data.get("rag_pipeline", {}).get("icon_background")
icon_url = data.get("rag_pipeline", {}).get("icon_url")
workflow = data.get("workflow", {})
graph = workflow.get("graph", {})
nodes = graph.get("nodes", [])
            dataset = None
            dataset_id = None
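            # Locate the Knowledge Index node and create or update the Dataset that backs this pipeline.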
for node in nodes:
if node.get("data", {}).get("type") == "knowledge_index":
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
if not dataset:
dataset = Dataset(
tenant_id=account.current_tenant_id,
name=name,
description=description,
icon_info={
"type": icon_type,
"icon": icon,
"background": icon_background,
"url": icon_url,
},
indexing_technique=knowledge_configuration.index_method.indexing_technique,
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_setting.model_dump(),
runtime_mode="rag_pipeline",
chunk_structure=knowledge_configuration.chunk_structure,
)
else:
dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique
dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
dataset.runtime_mode = "rag_pipeline"
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.index_method.indexing_technique == "high_quality":
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore
knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore
)
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = (
knowledge_configuration.index_method.embedding_setting.embedding_model_name
)
dataset.embedding_model_provider = (
knowledge_configuration.index_method.embedding_setting.embedding_provider_name
)
elif knowledge_configuration.index_method.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
dataset.pipeline_id = pipeline.id
self._session.add(dataset)
self._session.commit()
dataset_id = dataset.id
if not dataset_id:
raise ValueError("DSL is not valid, please check the Knowledge Index node.")
# Delete import info from Redis
redis_client.delete(redis_key)
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.COMPLETED,
pipeline_id=pipeline.id,
dataset_id=dataset_id,
current_dsl_version=CURRENT_DSL_VERSION,
imported_dsl_version=data.get("version", "0.1.0"),
)
except Exception as e:
logger.exception("Error confirming import")
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)
def check_dependencies(
self,
*,
pipeline: Pipeline,
) -> CheckDependenciesResult:
"""Check dependencies"""
# Get dependencies from Redis
redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}"
dependencies = redis_client.get(redis_key)
if not dependencies:
return CheckDependenciesResult()
# Extract dependencies
dependencies = CheckDependenciesPendingData.model_validate_json(dependencies)
# Get leaked dependencies
leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies(
tenant_id=pipeline.tenant_id, dependencies=dependencies.dependencies
)
return CheckDependenciesResult(
leaked_dependencies=leaked_dependencies,
)
def _create_or_update_pipeline(
self,
*,
pipeline: Optional[Pipeline],
data: dict,
account: Account,
dependencies: Optional[list[PluginDependency]] = None,
) -> Pipeline:
"""Create a new app or update an existing one."""
pipeline_data = data.get("pipeline", {})
pipeline_mode = pipeline_data.get("mode")
if not pipeline_mode:
raise ValueError("loss pipeline mode")
# Set icon type
icon_type_value = icon_type or pipeline_data.get("icon_type")
if icon_type_value in ["emoji", "link"]:
icon_type = icon_type_value
else:
icon_type = "emoji"
icon = icon or str(pipeline_data.get("icon", ""))
if pipeline:
# Update existing pipeline
pipeline.name = pipeline_data.get("name", pipeline.name)
pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.icon_type = icon_type
pipeline.icon = icon
pipeline.icon_background = pipeline_data.get("icon_background", pipeline.icon_background)
pipeline.updated_by = account.id
else:
if account.current_tenant_id is None:
raise ValueError("Current tenant is not set")
            # Create new pipeline
pipeline = Pipeline()
pipeline.id = str(uuid4())
pipeline.tenant_id = account.current_tenant_id
            pipeline.mode = pipeline_mode  # mode arrives from the YAML DSL as a plain string
pipeline.name = pipeline_data.get("name", "")
pipeline.description = pipeline_data.get("description", "")
pipeline.icon_type = icon_type
pipeline.icon = icon
pipeline.icon_background = pipeline_data.get("icon_background", "#FFFFFF")
pipeline.enable_site = True
pipeline.enable_api = True
pipeline.use_icon_as_answer_icon = pipeline_data.get("use_icon_as_answer_icon", False)
pipeline.created_by = account.id
pipeline.updated_by = account.id
self._session.add(pipeline)
self._session.commit()
# save dependencies
if dependencies:
redis_client.setex(
f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}",
IMPORT_INFO_REDIS_EXPIRY,
CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
)
# Initialize pipeline based on mode
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
raise ValueError("Missing workflow data for rag pipeline")
environment_variables_list = workflow_data.get("environment_variables", [])
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = workflow_data.get("conversation_variables", [])
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
rag_pipeline_variables = [
variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list
]
rag_pipeline_service = RagPipelineService()
current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
if current_draft_workflow:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None
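        # Knowledge-retrieval nodes in exported DSL carry encrypted dataset_ids;
        # decrypt them for this tenant and drop any IDs that fail to decrypt.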
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (
decrypted_id := self.decrypt_dataset_id(
encrypted_data=dataset_id,
tenant_id=pipeline.tenant_id,
)
)
]
        rag_pipeline_service.sync_draft_workflow(
            pipeline=pipeline,
            graph=workflow_data.get("graph", {}),
            features=workflow_data.get("features", {}),
            unique_hash=unique_hash,
            account=account,
            environment_variables=environment_variables,
            conversation_variables=conversation_variables,
            rag_pipeline_variables=rag_pipeline_variables,
        )
return pipeline
@classmethod
def export_rag_pipeline_dsl(cls, pipeline: Pipeline, include_secret: bool = False) -> str:
"""
Export pipeline
:param pipeline: Pipeline instance
:param include_secret: Whether include secret variable
:return:
"""
export_data = {
"version": CURRENT_DSL_VERSION,
"kind": "rag_pipeline",
"pipeline": {
"name": pipeline.name,
"mode": pipeline.mode,
"icon": "🤖" if pipeline.icon_type == "image" else pipeline.icon,
"icon_background": "#FFEAD5" if pipeline.icon_type == "image" else pipeline.icon_background,
"description": pipeline.description,
"use_icon_as_answer_icon": pipeline.use_icon_as_answer_icon,
},
}
cls._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret)
return yaml.dump(export_data, allow_unicode=True) # type: ignore
@classmethod
def _append_workflow_export_data(cls, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
"""
Append workflow export data
:param export_data: export data
:param pipeline: Pipeline instance
"""
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")
workflow_dict = workflow.to_dict(include_secret=include_secret)
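        # Encrypt dataset_ids inside knowledge-retrieval nodes so the exported DSL does not expose raw dataset IDs.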
for node in workflow_dict.get("graph", {}).get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
for dataset_id in dataset_ids
]
export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
tenant_id=pipeline.tenant_id, dependencies=dependencies
)
]
@classmethod
def _append_model_config_export_data(cls, export_data: dict, pipeline: Pipeline) -> None:
"""
Append model config export data
:param export_data: export data
:param pipeline: Pipeline instance
"""
app_model_config = pipeline.app_model_config
if not app_model_config:
raise ValueError("Missing app configuration, please check.")
export_data["model_config"] = app_model_config.to_dict()
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
tenant_id=pipeline.tenant_id, dependencies=dependencies
)
]
@classmethod
def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
"""
Extract dependencies from workflow
:param workflow: Workflow instance
:return: dependencies list format like ["langgenius/google"]
"""
graph = workflow.graph_dict
dependencies = cls._extract_dependencies_from_workflow_graph(graph)
return dependencies
@classmethod
def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]:
"""
Extract dependencies from workflow graph
:param graph: Workflow graph
:return: dependencies list format like ["langgenius/google"]
"""
dependencies = []
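        # Map each node type to the plugin or model-provider dependency it relies on.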
for node in graph.get("nodes", []):
try:
typ = node.get("data", {}).get("type")
match typ:
case NodeType.TOOL.value:
tool_entity = ToolNodeData(**node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
)
case NodeType.LLM.value:
llm_entity = LLMNodeData(**node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
)
case NodeType.QUESTION_CLASSIFIER.value:
question_classifier_entity = QuestionClassifierNodeData(**node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
question_classifier_entity.model.provider
),
)
case NodeType.PARAMETER_EXTRACTOR.value:
parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
parameter_extractor_entity.model.provider
),
)
case NodeType.KNOWLEDGE_RETRIEVAL.value:
knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
if knowledge_retrieval_entity.retrieval_mode == "multiple":
if knowledge_retrieval_entity.multiple_retrieval_config:
if (
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
== "reranking_model"
):
if knowledge_retrieval_entity.multiple_retrieval_config.reranking_model:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
knowledge_retrieval_entity.multiple_retrieval_config.reranking_model.provider
),
)
elif (
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
== "weighted_score"
):
if knowledge_retrieval_entity.multiple_retrieval_config.weights:
vector_setting = (
knowledge_retrieval_entity.multiple_retrieval_config.weights.vector_setting
)
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
vector_setting.embedding_provider_name
),
)
elif knowledge_retrieval_entity.retrieval_mode == "single":
model_config = knowledge_retrieval_entity.single_retrieval_config
if model_config:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
model_config.model.provider
),
)
case _:
# TODO: Handle default case or unknown node types
pass
except Exception as e:
logger.exception("Error extracting node dependency", exc_info=e)
return dependencies
@classmethod
def _extract_dependencies_from_model_config(cls, model_config: Mapping) -> list[str]:
"""
Extract dependencies from model config
:param model_config: model config dict
:return: dependencies list format like ["langgenius/google"]
"""
dependencies = []
try:
# completion model
model_dict = model_config.get("model", {})
if model_dict:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(model_dict.get("provider", ""))
)
# reranking model
dataset_configs = model_config.get("dataset_configs", {})
if dataset_configs:
for dataset_config in dataset_configs.get("datasets", {}).get("datasets", []):
if dataset_config.get("reranking_model"):
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
dataset_config.get("reranking_model", {})
.get("reranking_provider_name", {})
.get("provider")
)
)
# tools
agent_configs = model_config.get("agent_mode", {})
if agent_configs:
for agent_config in agent_configs.get("tools", []):
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(agent_config.get("provider_id"))
)
except Exception as e:
logger.exception("Error extracting model config dependency", exc_info=e)
return dependencies
@classmethod
def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]:
"""
Returns the leaked dependencies in current workspace
"""
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
if not dependencies:
return []
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
@staticmethod
def _generate_aes_key(tenant_id: str) -> bytes:
"""Generate AES key based on tenant_id"""
return hashlib.sha256(tenant_id.encode()).digest()
@classmethod
def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
"""Encrypt dataset_id using AES-CBC mode"""
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
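        # The IV is derived from the tenant key, so the same dataset_id always encrypts to the same ciphertext per tenant.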
cipher = AES.new(key, AES.MODE_CBC, iv)
ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
return base64.b64encode(ct_bytes).decode()
@classmethod
def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
"""AES decryption"""
try:
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
return pt.decode()
except Exception:
return None

View File

@@ -5,6 +5,7 @@ from pathlib import Path
from sqlalchemy.orm import Session
from configs import dify_config
from core.datasource.entities.api_entities import DatasourceProviderApiEntity
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
@@ -16,7 +17,7 @@ from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import BuiltinToolProvider
from models.tools import BuiltinDatasourceProvider, BuiltinToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
@@ -286,6 +287,67 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def list_rag_pipeline_datasources(tenant_id: str) -> list[DatasourceProviderApiEntity]:
"""
list rag pipeline datasources
"""
# get all builtin providers
datasource_provider_controllers = ToolManager.list_datasource_providers(tenant_id)
with db.session.no_autoflush:
# get all user added providers
db_providers: list[BuiltinDatasourceProvider] = (
db.session.query(BuiltinDatasourceProvider)
.filter(BuiltinDatasourceProvider.tenant_id == tenant_id)
.all()
or []
)
# find provider
def find_provider(provider):
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
result: list[DatasourceProviderApiEntity] = []
for provider_controller in datasource_provider_controllers:
try:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider_controller,
name_func=lambda x: x.identity.name,
):
continue
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_datasource_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=find_provider(provider_controller.entity.identity.name),
decrypt_credentials=True,
)
# add icon
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
datasources = provider_controller.get_datasources()
for datasource in datasources or []:
user_builtin_provider.datasources.append(
ToolTransformService.convert_datasource_entity_to_api_entity(
tenant_id=tenant_id,
datasource=datasource,
credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
result.append(user_builtin_provider)
except Exception as e:
raise e
return result
@staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
try:

View File

@@ -5,6 +5,11 @@ from typing import Optional, Union, cast
from yarl import URL
from configs import dify_config
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.api_entities import DatasourceApiEntity, DatasourceProviderApiEntity
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
@@ -21,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from models.tools import ApiToolProvider, BuiltinDatasourceProvider, BuiltinToolProvider, WorkflowToolProvider
logger = logging.getLogger(__name__)
@@ -140,6 +145,64 @@ class ToolTransformService:
return result
@classmethod
def builtin_datasource_provider_to_user_provider(
cls,
provider_controller: DatasourcePluginProviderController,
db_provider: Optional[BuiltinDatasourceProvider],
decrypt_credentials: bool = True,
) -> DatasourceProviderApiEntity:
"""
convert provider controller to user provider
"""
result = DatasourceProviderApiEntity(
id=provider_controller.entity.identity.name,
author=provider_controller.entity.identity.author,
name=provider_controller.entity.identity.name,
description=provider_controller.entity.identity.description,
icon=provider_controller.entity.identity.icon,
label=provider_controller.entity.identity.label,
type=DatasourceProviderType.RAG_PIPELINE,
masked_credentials={},
is_team_authorization=False,
plugin_id=provider_controller.plugin_id,
plugin_unique_identifier=provider_controller.plugin_unique_identifier,
datasources=[],
)
# get credentials schema
schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
for name, value in schema.items():
if result.masked_credentials:
result.masked_credentials[name] = ""
# check if the provider need credentials
if not provider_controller.need_credentials:
result.is_team_authorization = True
result.allow_delete = False
elif db_provider:
result.is_team_authorization = True
if decrypt_credentials:
credentials = db_provider.credentials
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
tenant_id=db_provider.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt(data=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
result.original_credentials = decrypted_credentials
return result
@staticmethod
def api_provider_to_controller(
db_provider: ApiToolProvider,
@@ -304,3 +367,48 @@ class ToolTransformService:
parameters=tool.parameters,
labels=labels or [],
)
@staticmethod
def convert_datasource_entity_to_api_entity(
datasource: DatasourcePlugin,
tenant_id: str,
credentials: dict | None = None,
labels: list[str] | None = None,
) -> DatasourceApiEntity:
"""
        convert a datasource entity to an api entity
"""
        # fork datasource runtime
datasource = datasource.fork_datasource_runtime(
runtime=DatasourceRuntime(
credentials=credentials or {},
tenant_id=tenant_id,
)
)
# get datasource parameters
parameters = datasource.entity.parameters or []
# get datasource runtime parameters
runtime_parameters = datasource.get_runtime_parameters()
# override parameters
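        # runtime parameters replace static ones with the same name and form; extra FORM parameters are appended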
current_parameters = parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
return DatasourceApiEntity(
author=datasource.entity.identity.author,
name=datasource.entity.identity.name,
label=datasource.entity.identity.label,
description=datasource.entity.description.human if datasource.entity.description else I18nObject(en_US=""),
output_schema=datasource.entity.output_schema,
parameters=current_parameters,
labels=labels or [],
)

View File

@@ -203,7 +203,6 @@ class WorkflowService:
type=draft_workflow.type,
version=str(datetime.now(UTC).replace(tzinfo=None)),
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id,
environment_variables=draft_workflow.environment_variables,
conversation_variables=draft_workflow.conversation_variables,