This commit is contained in:
jyong
2025-06-06 14:22:00 +08:00
parent 70432952fd
commit d2750f1a02
3 changed files with 79 additions and 25 deletions

View File

@@ -1,3 +1,4 @@
from core.datasource.__base import datasource_provider
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
@@ -43,6 +44,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return WebsiteCrawlDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),

View File

@@ -120,6 +120,47 @@ class DatasourceProviderService:
return copy_credentials_list
def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
"""
get datasource credentials.
:param tenant_id: workspace id
:param provider_id: provider id
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.filter(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
)
if not datasource_providers:
return []
copy_credentials_list = []
for datasource_provider in datasource_providers:
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}")
# Obfuscate provider credentials
copy_credentials = encrypted_credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
copy_credentials_list.append(
{
"credentials": copy_credentials,
"type": datasource_provider.auth_type,
}
)
return copy_credentials_list
def update_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict) -> None:
"""
update datasource credentials.

View File

@@ -43,6 +43,7 @@ from models.account import Account
from models.dataset import Pipeline, PipelineCustomizedTemplate # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.oauth import DatasourceProvider
from models.workflow import (
Workflow,
WorkflowNodeExecutionTriggeredFrom,
@@ -50,6 +51,7 @@ from models.workflow import (
WorkflowType,
)
from services.dataset_service import DatasetService
from services.datasource_provider_service import DatasourceProviderService
from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeConfiguration,
PipelineTemplateInfoEntity,
@@ -442,6 +444,7 @@ class RagPipelineService:
for key, value in datasource_parameters.items():
if not user_inputs.get(key):
user_inputs[key] = value["value"]
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
@@ -450,32 +453,40 @@ class RagPipelineService:
tenant_id=pipeline.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_real_datasource_credentials(
tenant_id=pipeline.tenant_id,
provider=datasource_node_data.get('provider_name'),
plugin_id=datasource_node_data.get('plugin_id'),
)
if credentials:
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
return {
"result": [page.model_dump() for page in online_document_result.result],
"provider_type": datasource_node_data.get("provider_type"),
}
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: GetOnlineDocumentPagesResponse = datasource_runtime._get_online_document_pages(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
return {
"result": [page.model_dump() for page in online_document_result.result],
"provider_type": datasource_node_data.get("provider_type"),
}
elif datasource_runtime.datasource_provider_type == DatasourceProviderType.WEBSITE_CRAWL:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
return {
"result": [result.model_dump() for result in website_crawl_result.result],
"provider_type": datasource_node_data.get("provider_type"),
}
else:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
case DatasourceProviderType.WEBSITE_CRAWL:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
user_id=account.id,
datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(),
)
return {
"result": [result.model_dump() for result in website_crawl_result.result],
"provider_type": datasource_node_data.get("provider_type"),
}
case _:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]