fix(rag-pipeline-dsl): dsl import session error
@@ -1,5 +1,6 @@
 from flask_login import current_user  # type: ignore
 from flask_restx import Resource, marshal, reqparse  # type: ignore
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden

 import services
@@ -10,6 +11,7 @@ from controllers.console.wraps import (
     cloud_edition_billing_rate_limit_check,
     setup_required,
 )
+from extensions.ext_database import db
 from fields.dataset_fields import dataset_detail_fields
 from libs.login import login_required
 from models.dataset import DatasetPermissionEnum
@@ -64,10 +66,12 @@ class CreateRagPipelineDatasetApi(Resource):
             yaml_content=args["yaml_content"],
         )
         try:
-            import_info = RagPipelineDslService.create_rag_pipeline_dataset(
-                tenant_id=current_user.current_tenant_id,
-                rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
-            )
+            with Session(db.engine) as session:
+                rag_pipeline_dsl_service = RagPipelineDslService(session)
+                import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
+                    tenant_id=current_user.current_tenant_id,
+                    rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
+                )
             if rag_pipeline_dataset_create_entity.permission == "partial_members":
                 DatasetPermissionService.update_partial_member_list(
                     current_user.current_tenant_id,
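The controller hunks above set the pattern the rest of the commit repeats: instead of calling a static method that reads and writes through the ambient `db.session`, the caller opens an explicit `Session(db.engine)` and constructs `RagPipelineDslService` around it. The service constructor is not shown in this diff, so the sketch below is an assumption about its shape, inferred from the `self._session` usages later in the commit.

```python
from sqlalchemy.orm import Session

from extensions.ext_database import db


class RagPipelineDslService:
    # Assumed constructor: the service keeps the session it was given and
    # runs every query and commit on it instead of the global db.session.
    def __init__(self, session: Session):
        self._session = session


# The caller owns the session lifetime, so the whole DSL import is one unit of work.
with Session(db.engine) as session:
    service = RagPipelineDslService(session)
    # ... service.create_rag_pipeline_dataset(...), service.import_rag_pipeline(...), etc.
```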
@@ -110,9 +110,11 @@ class PipelineGenerator(BaseAppGenerator):
         workflow_thread_pool_id: Optional[str] = None,
     ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
-        # Add null check for dataset
-        dataset = pipeline.dataset
-        if not dataset:
-            raise ValueError("Pipeline dataset is required")
+
+        with Session(db.engine) as session:
+            dataset = pipeline.retrieve_dataset(session)
+            if not dataset:
+                raise ValueError("Pipeline dataset is required")
         inputs: Mapping[str, Any] = args["inputs"]
         start_node_id: str = args["start_node_id"]
         datasource_type: str = args["datasource_type"]
@@ -360,9 +362,10 @@ class PipelineGenerator(BaseAppGenerator):
             pipeline=pipeline, workflow=workflow, start_node_id=args.get("start_node_id", "shared")
         )

-        dataset = pipeline.dataset
-        if not dataset:
-            raise ValueError("Pipeline dataset is required")
+        with Session(db.engine) as session:
+            dataset = pipeline.retrieve_dataset(session)
+            if not dataset:
+                raise ValueError("Pipeline dataset is required")

         # init application generate entity - use RagPipelineGenerateEntity instead
         application_generate_entity = RagPipelineGenerateEntity(
@@ -446,9 +449,10 @@ class PipelineGenerator(BaseAppGenerator):
         if args.get("inputs") is None:
             raise ValueError("inputs is required")

-        dataset = pipeline.dataset
-        if not dataset:
-            raise ValueError("Pipeline dataset is required")
+        with Session(db.engine) as session:
+            dataset = pipeline.retrieve_dataset(session)
+            if not dataset:
+                raise ValueError("Pipeline dataset is required")

         # convert to app config
         pipeline_config = PipelineConfigManager.get_pipeline_config(
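In `PipelineGenerator`, each read of the old `pipeline.dataset` property (which resolved through the global `db.session`) becomes an explicit, short-lived session. One general SQLAlchemy caveat worth noting here, not something the diff itself states: an object loaded inside `with Session(...)` is detached once the block exits, so anything that would lazy-load must be read while the session is still open. A minimal sketch of that usage, with a hypothetical helper name:

```python
from sqlalchemy.orm import Session

from extensions.ext_database import db


def load_dataset_or_fail(pipeline):
    # Hypothetical helper mirroring the generator's pattern: open a session
    # just long enough to resolve the dataset, and fail fast if none is bound.
    with Session(db.engine) as session:
        dataset = pipeline.retrieve_dataset(session)
        if not dataset:
            raise ValueError("Pipeline dataset is required")
        # Read everything needed while the session is open; lazy-loaded
        # relationships are unavailable once the block exits.
        return dataset
```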
@@ -15,7 +15,7 @@ from typing import Any, Optional, cast
 import sqlalchemy as sa
 from sqlalchemy import DateTime, String, func, select
 from sqlalchemy.dialects.postgresql import JSONB
-from sqlalchemy.orm import Mapped, mapped_column
+from sqlalchemy.orm import Mapped, Session, mapped_column

 from configs import dify_config
 from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
@@ -1286,9 +1286,8 @@ class Pipeline(Base):  # type: ignore[name-defined]
     updated_by = db.Column(StringUUID, nullable=True)
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

-    @property
-    def dataset(self):
-        return db.session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
+    def retrieve_dataset(self, session: Session):
+        return session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()


 class DocumentPipelineExecutionLog(Base):
@@ -1308,6 +1307,7 @@ class DocumentPipelineExecutionLog(Base):
     created_by = db.Column(StringUUID, nullable=True)
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

+
 class PipelineRecommendedPlugin(Base):
     __tablename__ = "pipeline_recommended_plugins"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)
@@ -1318,4 +1318,4 @@ class PipelineRecommendedPlugin(Base):
     position = db.Column(db.Integer, nullable=False, default=0)
     active = db.Column(db.Boolean, nullable=False, default=True)
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
-    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
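The model change is the root of the fix: the old `dataset` property always resolved through the request-scoped `db.session`, which is not the session the DSL import runs on, while the new `retrieve_dataset` makes the querying session an explicit argument. In simplified before/after form (`pipeline` is assumed to be an already-loaded `Pipeline` instance):

```python
from sqlalchemy.orm import Session

from extensions.ext_database import db

# Before: implicit session, resolved through the global scoped db.session.
# dataset = pipeline.dataset

# After: the caller decides which session performs the query.
with Session(db.engine) as session:
    dataset = pipeline.retrieve_dataset(session)
```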
@@ -352,9 +352,10 @@ class RagPipelineService:
             knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)

             # update dataset
-            dataset = pipeline.dataset
-            if not dataset:
-                raise ValueError("Dataset not found")
+            with Session(db.engine) as session:
+                dataset = pipeline.retrieve_dataset(session=session)
+                if not dataset:
+                    raise ValueError("Dataset not found")
                 DatasetService.update_rag_pipeline_dataset_settings(
                     session=session,
                     dataset=dataset,
@@ -1110,9 +1111,10 @@ class RagPipelineService:
         workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
         if not workflow:
             raise ValueError("Workflow not found")
-        dataset = pipeline.dataset
-        if not dataset:
-            raise ValueError("Dataset not found")
+        with Session(db.engine) as session:
+            dataset = pipeline.retrieve_dataset(session=session)
+            if not dataset:
+                raise ValueError("Dataset not found")

         # check template name is exist
         template_name = args.get("name")
@@ -1136,7 +1138,9 @@ class RagPipelineService:

         from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

-        dsl = RagPipelineDslService.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
+        with Session(db.engine) as session:
+            rag_pipeline_dsl_service = RagPipelineDslService(session)
+            dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)

         pipeline_customized_template = PipelineCustomizedTemplate(
             name=args.get("name"),
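In `RagPipelineService`, the session that resolves the dataset is the same one handed to `DatasetService.update_rag_pipeline_dataset_settings` and to the session-bound `RagPipelineDslService`, so the dataset lookup and the follow-up writes share one unit of work. A hedged sketch of that calling convention; the wrapper name and the `knowledge_configuration` keyword are assumptions, only `session=` and `dataset=` appear in the diff:

```python
from sqlalchemy.orm import Session

from extensions.ext_database import db
from services.dataset_service import DatasetService  # assumed import path


def publish_knowledge_settings(pipeline, knowledge_configuration):
    # Hypothetical wrapper: one session is opened, the dataset is loaded on it,
    # and the same session is passed downstream so the update lands in the same
    # transaction as the lookup.
    with Session(db.engine) as session:
        dataset = pipeline.retrieve_dataset(session=session)
        if not dataset:
            raise ValueError("Dataset not found")
        DatasetService.update_rag_pipeline_dataset_settings(
            session=session,
            dataset=dataset,
            knowledge_configuration=knowledge_configuration,
        )
```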
@@ -30,7 +30,6 @@ from core.workflow.nodes.llm.entities import LLMNodeData
 from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
 from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
 from core.workflow.nodes.tool.entities import ToolNodeData
-from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from factories import variable_factory
 from models import Account
@@ -235,10 +234,7 @@ class RagPipelineDslService:
                     status=ImportStatus.FAILED,
                     error="Pipeline not found",
                 )
-            dataset = pipeline.dataset
-            if dataset:
-                self._session.merge(dataset)
-                dataset_name = dataset.name
+            dataset = pipeline.retrieve_dataset(session=self._session)

             # If major version mismatch, store import info in Redis
             if status == ImportStatus.PENDING:
@@ -300,7 +296,7 @@ class RagPipelineDslService:
                 ):
                     raise ValueError("Chunk structure is not compatible with the published pipeline")
             if not dataset:
-                datasets = db.session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
+                datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
                 names = [dataset.name for dataset in datasets]
                 generate_name = generate_incremental_name(names, name)
                 dataset = Dataset(
@@ -321,7 +317,7 @@ class RagPipelineDslService:
                 )
                 if knowledge_configuration.indexing_technique == "high_quality":
                     dataset_collection_binding = (
-                        db.session.query(DatasetCollectionBinding)
+                        self._session.query(DatasetCollectionBinding)
                         .filter(
                             DatasetCollectionBinding.provider_name
                             == knowledge_configuration.embedding_model_provider,
@@ -339,8 +335,8 @@ class RagPipelineDslService:
                             collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
                             type="dataset",
                         )
-                        db.session.add(dataset_collection_binding)
-                        db.session.commit()
+                        self._session.add(dataset_collection_binding)
+                        self._session.commit()
                     dataset_collection_binding_id = dataset_collection_binding.id
                     dataset.collection_binding_id = dataset_collection_binding_id
                     dataset.embedding_model = knowledge_configuration.embedding_model
@@ -454,7 +450,7 @@ class RagPipelineDslService:
                 dataset.chunk_structure = knowledge_configuration.chunk_structure
                 if knowledge_configuration.indexing_technique == "high_quality":
                     dataset_collection_binding = (
-                        db.session.query(DatasetCollectionBinding)
+                        self._session.query(DatasetCollectionBinding)
                         .filter(
                             DatasetCollectionBinding.provider_name
                             == knowledge_configuration.embedding_model_provider,
@@ -472,8 +468,8 @@ class RagPipelineDslService:
                             collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
                             type="dataset",
                         )
-                        db.session.add(dataset_collection_binding)
-                        db.session.commit()
+                        self._session.add(dataset_collection_binding)
+                        self._session.commit()
                     dataset_collection_binding_id = dataset_collection_binding.id
                     dataset.collection_binding_id = dataset_collection_binding_id
                     dataset.embedding_model = knowledge_configuration.embedding_model
@@ -538,18 +534,10 @@ class RagPipelineDslService:
         account: Account,
         dependencies: Optional[list[PluginDependency]] = None,
     ) -> Pipeline:
         """Create a new app or update an existing one."""
-        if not account.current_tenant_id:
-            raise ValueError("Tenant id is required")
         pipeline_data = data.get("rag_pipeline", {})
-        # Set icon type
-        icon_type_value = pipeline_data.get("icon_type")
-        if icon_type_value in ["emoji", "link"]:
-            icon_type = icon_type_value
-        else:
-            icon_type = "emoji"
         icon = str(pipeline_data.get("icon", ""))

         # Initialize pipeline based on mode
         workflow_data = data.get("workflow")
         if not workflow_data or not isinstance(workflow_data, dict):
@@ -609,7 +597,7 @@ class RagPipelineDslService:
                 CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
             )
         workflow = (
-            db.session.query(Workflow)
+            self._session.query(Workflow)
             .filter(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
@@ -632,8 +620,8 @@ class RagPipelineDslService:
                 conversation_variables=conversation_variables,
                 rag_pipeline_variables=rag_pipeline_variables_list,
             )
-            db.session.add(workflow)
-            db.session.flush()
+            self._session.add(workflow)
+            self._session.flush()
             pipeline.workflow_id = workflow.id
         else:
             workflow.graph = json.dumps(graph)
@@ -643,19 +631,18 @@ class RagPipelineDslService:
             workflow.conversation_variables = conversation_variables
             workflow.rag_pipeline_variables = rag_pipeline_variables_list
         # commit db session changes
-        db.session.commit()
+        self._session.commit()

         return pipeline

-    @classmethod
-    def export_rag_pipeline_dsl(cls, pipeline: Pipeline, include_secret: bool = False) -> str:
+    def export_rag_pipeline_dsl(self, pipeline: Pipeline, include_secret: bool = False) -> str:
         """
         Export pipeline
         :param pipeline: Pipeline instance
         :param include_secret: Whether include secret variable
         :return:
         """
-        dataset = pipeline.dataset
+        dataset = pipeline.retrieve_dataset(session=self._session)
         if not dataset:
             raise ValueError("Missing dataset for rag pipeline")
         icon_info = dataset.icon_info
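With `export_rag_pipeline_dsl` now an instance method, callers that used to invoke it on the class must construct the service around a session first, exactly as the `RagPipelineService` hunk earlier does. A minimal sketch of the new call shape (`pipeline` is assumed to be an already-loaded `Pipeline` instance):

```python
from sqlalchemy.orm import Session

from extensions.ext_database import db
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

with Session(db.engine) as session:
    # The service reads the pipeline's dataset and workflow through this session.
    dsl = RagPipelineDslService(session).export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
```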
@@ -672,12 +659,11 @@ class RagPipelineDslService:
             },
         }

-        cls._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret)
+        self._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret)

         return yaml.dump(export_data, allow_unicode=True)  # type: ignore

-    @classmethod
-    def _append_workflow_export_data(cls, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
+    def _append_workflow_export_data(self, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
         """
         Append workflow export data
         :param export_data: export data
@@ -685,7 +671,7 @@ class RagPipelineDslService:
         """

         workflow = (
-            db.session.query(Workflow)
+            self._session.query(Workflow)
             .filter(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
@@ -701,11 +687,11 @@ class RagPipelineDslService:
             if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
                 dataset_ids = node["data"].get("dataset_ids", [])
                 node["data"]["dataset_ids"] = [
-                    cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
+                    self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
                     for dataset_id in dataset_ids
                 ]
         export_data["workflow"] = workflow_dict
-        dependencies = cls._extract_dependencies_from_workflow(workflow)
+        dependencies = self._extract_dependencies_from_workflow(workflow)
         export_data["dependencies"] = [
             jsonable_encoder(d.model_dump())
             for d in DependenciesAnalysisService.generate_dependencies(
@@ -713,19 +699,17 @@ class RagPipelineDslService:
             )
         ]

-    @classmethod
-    def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
+    def _extract_dependencies_from_workflow(self, workflow: Workflow) -> list[str]:
         """
         Extract dependencies from workflow
         :param workflow: Workflow instance
         :return: dependencies list format like ["langgenius/google"]
         """
         graph = workflow.graph_dict
-        dependencies = cls._extract_dependencies_from_workflow_graph(graph)
+        dependencies = self._extract_dependencies_from_workflow_graph(graph)
         return dependencies

-    @classmethod
-    def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]:
+    def _extract_dependencies_from_workflow_graph(self, graph: Mapping) -> list[str]:
         """
         Extract dependencies from workflow graph
         :param graph: Workflow graph
@@ -882,25 +866,22 @@ class RagPipelineDslService:

         return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)

-    @staticmethod
-    def _generate_aes_key(tenant_id: str) -> bytes:
+    def _generate_aes_key(self, tenant_id: str) -> bytes:
         """Generate AES key based on tenant_id"""
         return hashlib.sha256(tenant_id.encode()).digest()

-    @classmethod
-    def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
+    def encrypt_dataset_id(self, dataset_id: str, tenant_id: str) -> str:
         """Encrypt dataset_id using AES-CBC mode"""
-        key = cls._generate_aes_key(tenant_id)
+        key = self._generate_aes_key(tenant_id)
         iv = key[:16]
         cipher = AES.new(key, AES.MODE_CBC, iv)
         ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
         return base64.b64encode(ct_bytes).decode()

-    @classmethod
-    def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
+    def decrypt_dataset_id(self, encrypted_data: str, tenant_id: str) -> str | None:
         """AES decryption"""
         try:
-            key = cls._generate_aes_key(tenant_id)
+            key = self._generate_aes_key(tenant_id)
             iv = key[:16]
             cipher = AES.new(key, AES.MODE_CBC, iv)
             pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
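The AES helpers only change from class methods to instance methods; the cryptography itself is untouched: the SHA-256 digest of the tenant id is the key, its first 16 bytes double as the IV, and the dataset id is padded, encrypted with AES-CBC, and base64-encoded. A standalone round-trip sketch of that scheme, written against pycryptodome and mirroring the calls visible in the diff (the sample ids are placeholders):

```python
import base64
import hashlib

from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad


def _generate_aes_key(tenant_id: str) -> bytes:
    # Key is derived deterministically from the tenant id.
    return hashlib.sha256(tenant_id.encode()).digest()


def encrypt_dataset_id(dataset_id: str, tenant_id: str) -> str:
    key = _generate_aes_key(tenant_id)
    iv = key[:16]  # first 16 bytes of the key reused as the IV
    cipher = AES.new(key, AES.MODE_CBC, iv)
    return base64.b64encode(cipher.encrypt(pad(dataset_id.encode(), AES.block_size))).decode()


def decrypt_dataset_id(encrypted_data: str, tenant_id: str) -> str | None:
    try:
        key = _generate_aes_key(tenant_id)
        cipher = AES.new(key, AES.MODE_CBC, key[:16])
        return unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size).decode()
    except Exception:
        return None


# Round trip: decrypting with the same tenant id returns the original dataset id.
token = encrypt_dataset_id("9d2074ba-0000-0000-0000-000000000000", "tenant-123")
assert decrypt_dataset_id(token, "tenant-123") == "9d2074ba-0000-0000-0000-000000000000"
```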
@@ -908,39 +889,37 @@ class RagPipelineDslService:
         except Exception:
             return None

-    @staticmethod
     def create_rag_pipeline_dataset(
+        self,
         tenant_id: str,
         rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
     ):
         if rag_pipeline_dataset_create_entity.name:
             # check if dataset name already exists
             if (
-                db.session.query(Dataset)
+                self._session.query(Dataset)
                 .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
                 .first()
             ):
                 raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
         else:
             # generate a random name as Untitled 1 2 3 ...
-            datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()
+            datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all()
             names = [dataset.name for dataset in datasets]
             rag_pipeline_dataset_create_entity.name = generate_incremental_name(
                 names,
                 "Untitled",
             )

-        with Session(db.engine) as session:
-            rag_pipeline_dsl_service = RagPipelineDslService(session)
-            account = cast(Account, current_user)
-            rag_pipeline_import_info: RagPipelineImportInfo = rag_pipeline_dsl_service.import_rag_pipeline(
-                account=account,
-                import_mode=ImportMode.YAML_CONTENT.value,
-                yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
-                dataset=None,
-                dataset_name=rag_pipeline_dataset_create_entity.name,
-                icon_info=rag_pipeline_dataset_create_entity.icon_info,
-            )
+        account = cast(Account, current_user)
+        rag_pipeline_import_info: RagPipelineImportInfo = self.import_rag_pipeline(
+            account=account,
+            import_mode=ImportMode.YAML_CONTENT.value,
+            yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
+            dataset=None,
+            dataset_name=rag_pipeline_dataset_create_entity.name,
+            icon_info=rag_pipeline_dataset_create_entity.icon_info,
+        )
         return {
             "id": rag_pipeline_import_info.id,
             "dataset_id": rag_pipeline_import_info.dataset_id,
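The last hunk is where the import session error named in the commit title was most visible: the old `create_rag_pipeline_dataset` checked dataset names on `db.session`, then built a second `RagPipelineDslService` on a brand-new session just to call `import_rag_pipeline`, so the name check and the import ran on different sessions. After the change everything runs on the one session the service already holds. A condensed before/after of the call site, with placeholder values for `tenant_id` and `entity`:

```python
from sqlalchemy.orm import Session

from extensions.ext_database import db

# Before: static call; the method opened its own Session(db.engine) internally
# while its other queries went through the global db.session.
# import_info = RagPipelineDslService.create_rag_pipeline_dataset(
#     tenant_id=tenant_id,
#     rag_pipeline_dataset_create_entity=entity,
# )

# After: one caller-owned session covers the duplicate-name check, the import,
# and the final commit.
with Session(db.engine) as session:
    service = RagPipelineDslService(session)
    import_info = service.create_rag_pipeline_dataset(
        tenant_id=tenant_id,
        rag_pipeline_dataset_create_entity=entity,
    )
```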