mirror of https://github.com/langgenius/dify.git (synced 2026-01-08 15:24:14 +00:00)

Compare commits: fix/templa ... feat/crend (12 commits)

| SHA1 |
|---|
| 820157558c |
| 6453fc4973 |
| f62f926537 |
| b3dafd913b |
| b2d8a7eaf1 |
| 3e54414191 |
| a173546c8d |
| aa69d90489 |
| 4ba1292455 |
| bb01c31f30 |
| cd90b2ca9e |
| 9a65350cf7 |
@@ -1,15 +1,16 @@
 #!/bin/bash
+WORKSPACE_ROOT=$(pwd)

 corepack enable
 cd web && pnpm install
 pipx install uv

-echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
-echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc
-echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
-echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
-echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
-echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
+echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
+echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
+echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
+echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
+echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
+echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc

 source /home/vscode/.bashrc
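Note on the hunk above: the original `echo` lines were single-quoted, so the literal path `/workspaces/dify` was written into `~/.bashrc`; the replacements are double-quoted, so `$WORKSPACE_ROOT` (captured via `pwd` at the top of the script) is expanded while the post-create script runs, and the aliases point at whatever directory the repository was actually cloned into.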
Makefile
@@ -62,7 +62,7 @@ check:

 lint:
 	@echo "🔧 Running ruff format and check with fixes..."
-	@uv run --directory api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
+	@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
 	@echo "✅ Linting complete"

 type-check:
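The `--directory api` to `--project api` switch is presumably about path resolution: `uv run --directory` changes the working directory before running the command, which would make the `./api` arguments resolve relative to `api/`, while `--project` selects the `api` project's environment without changing the working directory, so `ruff format ./api && ruff check --fix ./api` still runs against `api/` from the repository root. Treat this reading of the uv flags as an assumption; the diff itself does not state the motivation.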
@@ -30,6 +30,7 @@ select = [
     "RUF022", # unsorted-dunder-all
     "S506", # unsafe-yaml-load
     "SIM", # flake8-simplify rules
+    "T201", # print-found
     "TRY400", # error-instead-of-exception
     "TRY401", # verbose-log-message
     "UP", # pyupgrade rules
@@ -91,11 +92,18 @@ ignore = [
 "configs/*" = [
     "N802", # invalid-function-name
 ]
+"core/model_runtime/callbacks/base_callback.py" = [
+    "T201",
+]
+"core/workflow/callbacks/workflow_logging_callback.py" = [
+    "T201",
+]
 "libs/gmpy2_pkcs10aep_cipher.py" = [
     "N803", # invalid-argument-name
 ]
 "tests/*" = [
     "F811", # redefined-while-unused
+    "T201", # allow print in tests
 ]

 [lint.pyflakes]
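The newly selected `T201` (print-found) rule is what drives the `print` to `logger` replacements in the hunks below, with per-file ignores added for the two callback modules and for tests. A minimal sketch of the pattern the converted call sites follow — a module-level logger, lazy %-style arguments, and `logger.exception` inside `except` blocks when the traceback should be attached; the function below is illustrative and not taken from the repository:

```python
import json
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def parse_metadata(metadata_str: str) -> dict:
    # Lazy %-style arguments: interpolation only happens if the record is emitted.
    logger.debug("Parsing metadata %s", metadata_str)
    try:
        return json.loads(metadata_str)
    except json.JSONDecodeError:
        # logger.exception logs at ERROR level and appends the current traceback,
        # so the exception no longer needs to be bound to a name.
        logger.exception("Invalid JSON metadata: %s", metadata_str)
        return {}


if __name__ == "__main__":
    parse_metadata("{not json}")
```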
@@ -7,7 +7,7 @@ _logger = logging.getLogger(__name__)


 def _log(message: str):
-    print(message, flush=True)
+    _logger.debug(message)


 # grpc gevent
@@ -739,18 +739,18 @@ where sites.id is null limit 1000"""
                try:
                    app = db.session.query(App).where(App.id == app_id).first()
                    if not app:
-                       print(f"App {app_id} not found")
+                       logger.info("App %s not found", app_id)
                        continue

                    tenant = app.tenant
                    if tenant:
                        accounts = tenant.get_accounts()
                        if not accounts:
-                           print(f"Fix failed for app {app.id}")
+                           logger.info("Fix failed for app %s", app.id)
                            continue

                        account = accounts[0]
-                       print(f"Fixing missing site for app {app.id}")
+                       logger.info("Fixing missing site for app %s", app.id)
                        app_was_created.send(app, account=account)
                except Exception:
                    failed_app_ids.append(app_id)
@@ -1544,7 +1544,7 @@ def transform_datasource_credentials():
                if jina_plugin_id not in installed_plugins_ids:
                    if jina_plugin_unique_identifier:
                        # install jina plugin
-                       print(jina_plugin_unique_identifier)
+                       logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
                        PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])

                auth_count = 0
@@ -62,6 +62,9 @@ class ChatMessageListApi(Resource):
     @account_initialization_required
     @marshal_with(message_infinite_scroll_pagination_fields)
     def get(self, app_model):
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
+            raise Forbidden()
+
         parser = reqparse.RequestParser()
         parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
         parser.add_argument("first_id", type=uuid_value, location="args")
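The added guard restricts listing chat messages to accounts with edit permission; the parametrized test added near the end of this diff pins down the expectation: owner, admin, and editor receive 200, while normal members and dataset operators receive 403.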
@@ -118,12 +118,14 @@ class RagPipelineExportApi(Resource):

         # Add include_secret params
         parser = reqparse.RequestParser()
-        parser.add_argument("include_secret", type=bool, default=False, location="args")
+        parser.add_argument("include_secret", type=str, default="false", location="args")
         args = parser.parse_args()

         with Session(db.engine) as session:
             export_service = RagPipelineDslService(session)
-            result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
+            result = export_service.export_rag_pipeline_dsl(
+                pipeline=pipeline, include_secret=args["include_secret"] == "true"
+            )

         return {"data": result}, 200
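The `type=bool` to `type=str` change looks cosmetic but is a behavior fix, assuming standard reqparse semantics: reqparse applies the `type` callable to the raw query-string value, and `bool` on any non-empty string (including "false") returns `True`, so `?include_secret=false` would still have exported secrets. Comparing the raw string against "true" sidesteps that. A small standalone illustration, not code from the repository:

```python
# bool() on any non-empty string is True, which is why type=bool is unreliable
# for query-string parameters.
assert bool("false") is True
assert bool("0") is True


def parse_flag(raw: str) -> bool:
    # The approach the new controller code takes: compare the raw string explicitly.
    return raw == "true"


assert parse_flag("false") is False
assert parse_flag("true") is True
```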
@@ -417,7 +417,7 @@ class WeaveDataTrace(BaseTraceInstance):
             if not login_status:
                 raise ValueError("Weave login failed")
             else:
-                print("Weave login successful")
+                logger.info("Weave login successful")
             return True
         except Exception as e:
             logger.debug("Weave API check failed: %s", str(e))
@@ -229,7 +229,7 @@ class OceanBaseVector(BaseVector):
             try:
                 metadata = json.loads(metadata_str)
             except json.JSONDecodeError:
-                print(f"Invalid JSON metadata: {metadata_str}")
+                logger.warning("Invalid JSON metadata: %s", metadata_str)
                 metadata = {}
             metadata["score"] = score
             docs.append(Document(page_content=_text, metadata=metadata))
@@ -1,5 +1,6 @@
 import array
 import json
+import logging
 import re
 import uuid
 from typing import Any
@@ -19,6 +20,8 @@ from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+logger = logging.getLogger(__name__)
+
 oracledb.defaults.fetch_lobs = False

@@ -180,8 +183,8 @@ class OracleVector(BaseVector):
                         value,
                     )
                 conn.commit()
-            except Exception as e:
-                print(e)
+            except Exception:
+                logger.exception("Failed to insert record %s into %s", value[0], self.table_name)
             conn.close()
             return pks
@@ -1,4 +1,5 @@
 import json
+import logging
 import uuid
 from typing import Any

@@ -23,6 +24,8 @@ from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client

+logger = logging.getLogger(__name__)
+
 Base = declarative_base() # type: Any

@@ -187,8 +190,8 @@ class RelytVector(BaseVector):
                 delete_condition = chunks_table.c.id.in_(ids)
                 conn.execute(chunks_table.delete().where(delete_condition))
                 return True
-        except Exception as e:
-            print("Delete operation failed:", str(e))
+        except Exception:
+            logger.exception("Delete operation failed for collection %s", self._collection_name)
             return False

     def delete_by_metadata_field(self, key: str, value: str):
@@ -164,8 +164,8 @@ class TiDBVector(BaseVector):
                 delete_condition = table.c.id.in_(ids)
                 conn.execute(table.delete().where(delete_condition))
                 return True
-        except Exception as e:
-            print("Delete operation failed:", str(e))
+        except Exception:
+            logger.exception("Delete operation failed for collection %s", self._collection_name)
             return False

     def get_ids_by_metadata_field(self, key: str, value: str):
@@ -417,12 +417,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
        if db_model is not None:
            offload_data = db_model.offload_data

        else:
            db_model = self._to_db_model(domain_model)
            offload_data = []
            offload_data = db_model.offload_data

        offload_data = db_model.offload_data
        if domain_model.inputs is not None:
            result = self._truncate_and_upload(
                domain_model.inputs,
@@ -1,4 +1,5 @@
 import json
+import logging
 import threading
 from collections.abc import Mapping, MutableMapping
 from pathlib import Path
@@ -8,6 +9,8 @@ from typing import Any, ClassVar, Optional
 class SchemaRegistry:
     """Schema registry manages JSON schemas with version support"""

+    logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
+
     _default_instance: ClassVar[Optional["SchemaRegistry"]] = None
     _lock: ClassVar[threading.Lock] = threading.Lock()
@@ -83,7 +86,7 @@ class SchemaRegistry:
                 self.metadata[uri] = metadata

             except (OSError, json.JSONDecodeError) as e:
-                print(f"Warning: failed to load schema {version}/{schema_name}: {e}")
+                self.logger.warning("Failed to load schema %s/%s: %s", version, schema_name, e)

     def get_schema(self, uri: str) -> Any | None:
         """Retrieves a schema by URI with version support"""
@@ -147,4 +147,4 @@ class ExecutionLimitsLayer(GraphEngineLayer):
             self.logger.debug("Abort command sent to engine")

         except Exception:
-            self.logger.exception("Failed to send abort command: %s")
+            self.logger.exception("Failed to send abort command")
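For context on the dropped `%s`: the call passes no argument for the placeholder, and the logging module only applies %-interpolation when arguments are present, so the literal characters `%s` would simply appear at the end of the logged message. A small standalone check, not repository code:

```python
import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

try:
    raise RuntimeError("boom")
except Exception:
    # No argument is supplied for %s; logging skips interpolation when args are
    # empty, so the emitted line ends with a literal "%s".
    logger.exception("Failed to send abort command: %s")
    # The corrected call drops the unused placeholder.
    logger.exception("Failed to send abort command")
```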
@@ -46,7 +46,11 @@ limit 1000"""
             record_id = str(i.id)
             provider_name = str(i.provider_name)
             retrieval_model = i.retrieval_model
-            print(type(retrieval_model))
+            logger.debug(
+                "Processing dataset %s with retrieval model of type %s",
+                record_id,
+                type(retrieval_model),
+            )

             if record_id in failed_ids:
                 continue
@@ -1327,14 +1327,14 @@ class RagPipelineService:
         """
         Retry error document
         """
-        document_pipeline_excution_log = (
+        document_pipeline_execution_log = (
             db.session.query(DocumentPipelineExecutionLog)
             .where(DocumentPipelineExecutionLog.document_id == document.id)
             .first()
         )
-        if not document_pipeline_excution_log:
+        if not document_pipeline_execution_log:
             raise ValueError("Document pipeline execution log not found")
-        pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
+        pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
         if not pipeline:
             raise ValueError("Pipeline not found")
         # convert to app config
@@ -1346,10 +1346,10 @@ class RagPipelineService:
             workflow=workflow,
             user=user,
             args={
-                "inputs": document_pipeline_excution_log.input_data,
-                "start_node_id": document_pipeline_excution_log.datasource_node_id,
-                "datasource_type": document_pipeline_excution_log.datasource_type,
-                "datasource_info_list": [json.loads(document_pipeline_excution_log.datasource_info)],
+                "inputs": document_pipeline_execution_log.input_data,
+                "start_node_id": document_pipeline_execution_log.datasource_node_id,
+                "datasource_type": document_pipeline_execution_log.datasource_type,
+                "datasource_info_list": [json.loads(document_pipeline_execution_log.datasource_info)],
                 "original_document_id": document.id,
             },
             invoke_from=InvokeFrom.PUBLISHED,
@@ -685,12 +685,24 @@ class RagPipelineDslService:

         workflow_dict = workflow.to_dict(include_secret=include_secret)
         for node in workflow_dict.get("graph", {}).get("nodes", []):
-            if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
-                dataset_ids = node["data"].get("dataset_ids", [])
+            node_data = node.get("data", {})
+            if not node_data:
+                continue
+            data_type = node_data.get("type", "")
+            if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
+                dataset_ids = node_data.get("dataset_ids", [])
                 node["data"]["dataset_ids"] = [
                     self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
                     for dataset_id in dataset_ids
                 ]
+            # filter credential id from tool node
+            if not include_secret and data_type == NodeType.TOOL.value:
+                node_data.pop("credential_id", None)
+            # filter credential id from agent node
+            if not include_secret and data_type == NodeType.AGENT.value:
+                for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
+                    tool.pop("credential_id", None)

         export_data["workflow"] = workflow_dict
         dependencies = self._extract_dependencies_from_workflow(workflow)
         export_data["dependencies"] = [
@@ -1,4 +1,5 @@
 import json
+import logging
 from datetime import UTC, datetime
 from pathlib import Path
 from uuid import uuid4
@@ -17,6 +18,8 @@ from services.entities.knowledge_entities.rag_pipeline_entities import Knowledge
 from services.plugin.plugin_migration import PluginMigration
 from services.plugin.plugin_service import PluginService

+logger = logging.getLogger(__name__)
+

 class RagPipelineTransformService:
     def transform_dataset(self, dataset_id: str):
@@ -35,11 +38,11 @@ class RagPipelineTransformService:
         indexing_technique = dataset.indexing_technique

         if not datasource_type and not indexing_technique:
-            return self._transfrom_to_empty_pipeline(dataset)
+            return self._transform_to_empty_pipeline(dataset)

         doc_form = dataset.doc_form
         if not doc_form:
-            return self._transfrom_to_empty_pipeline(dataset)
+            return self._transform_to_empty_pipeline(dataset)
         retrieval_model = dataset.retrieval_model
         pipeline_yaml = self._get_transform_yaml(doc_form, datasource_type, indexing_technique)
         # deal dependencies
@@ -257,10 +260,10 @@ class RagPipelineTransformService:
             if plugin_unique_identifier:
                 need_install_plugin_unique_identifiers.append(plugin_unique_identifier)
         if need_install_plugin_unique_identifiers:
-            print(need_install_plugin_unique_identifiers)
+            logger.debug("Installing missing pipeline plugins %s", need_install_plugin_unique_identifiers)
             PluginService.install_from_marketplace_pkg(tenant_id, need_install_plugin_unique_identifiers)

-    def _transfrom_to_empty_pipeline(self, dataset: Dataset):
+    def _transform_to_empty_pipeline(self, dataset: Dataset):
         pipeline = Pipeline(
             tenant_id=dataset.tenant_id,
             name=dataset.name,
@@ -450,7 +450,8 @@ class WorkflowService:
         )

         if not default_provider:
-            raise ValueError("No default credential found")
+            # plugin does not require credentials, skip
+            return

         # Check credential policy compliance using the default credential ID
         from core.helper.credential_utils import check_credential_policy_compliance
@@ -1,12 +1,14 @@
"""Integration tests for ChatMessageApi permission verification."""

import uuid
from types import SimpleNamespace
from unittest import mock

import pytest
from flask.testing import FlaskClient

from controllers.console.app import completion as completion_api
from controllers.console.app import message as message_api
from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import Account, App, Tenant
@@ -99,3 +101,106 @@ class TestChatMessageApiPermissions:
         )

         assert response.status_code == status
+
+    @pytest.mark.parametrize(
+        ("role", "status"),
+        [
+            (TenantAccountRole.OWNER, 200),
+            (TenantAccountRole.ADMIN, 200),
+            (TenantAccountRole.EDITOR, 200),
+            (TenantAccountRole.NORMAL, 403),
+            (TenantAccountRole.DATASET_OPERATOR, 403),
+        ],
+    )
+    def test_get_requires_edit_permission(
+        self,
+        test_client: FlaskClient,
+        auth_header,
+        monkeypatch,
+        mock_app_model,
+        mock_account,
+        role: TenantAccountRole,
+        status: int,
+    ):
+        """Ensure GET chat-messages endpoint enforces edit permissions."""
+
+        mock_load_app_model = mock.Mock(return_value=mock_app_model)
+        monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
+
+        conversation_id = uuid.uuid4()
+        created_at = naive_utc_now()
+
+        mock_conversation = SimpleNamespace(id=str(conversation_id), app_id=str(mock_app_model.id))
+        mock_message = SimpleNamespace(
+            id=str(uuid.uuid4()),
+            conversation_id=str(conversation_id),
+            inputs=[],
+            query="hello",
+            message=[{"text": "hello"}],
+            message_tokens=0,
+            re_sign_file_url_answer="",
+            answer_tokens=0,
+            provider_response_latency=0.0,
+            from_source="console",
+            from_end_user_id=None,
+            from_account_id=mock_account.id,
+            feedbacks=[],
+            workflow_run_id=None,
+            annotation=None,
+            annotation_hit_history=None,
+            created_at=created_at,
+            agent_thoughts=[],
+            message_files=[],
+            message_metadata_dict={},
+            status="success",
+            error="",
+            parent_message_id=None,
+        )
+
+        class MockQuery:
+            def __init__(self, model):
+                self.model = model
+
+            def where(self, *args, **kwargs):
+                return self
+
+            def first(self):
+                if getattr(self.model, "__name__", "") == "Conversation":
+                    return mock_conversation
+                return None
+
+            def order_by(self, *args, **kwargs):
+                return self
+
+            def limit(self, *_):
+                return self
+
+            def all(self):
+                if getattr(self.model, "__name__", "") == "Message":
+                    return [mock_message]
+                return []
+
+        mock_session = mock.Mock()
+        mock_session.query.side_effect = MockQuery
+        mock_session.scalar.return_value = False
+
+        monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session))
+        monkeypatch.setattr(message_api, "current_user", mock_account)
+
+        class DummyPagination:
+            def __init__(self, data, limit, has_more):
+                self.data = data
+                self.limit = limit
+                self.has_more = has_more
+
+        monkeypatch.setattr(message_api, "InfiniteScrollPagination", DummyPagination)
+
+        mock_account.role = role
+
+        response = test_client.get(
+            f"/console/api/apps/{mock_app_model.id}/chat-messages",
+            headers=auth_header,
+            query_string={"conversation_id": str(conversation_id)},
+        )
+
+        assert response.status_code == status
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -9,8 +9,7 @@ import Button from '@/app/components/base/button'
 import { useTranslation } from 'react-i18next'
 import Toast from '@/app/components/base/toast'
 import type { PipelineTemplate } from '@/models/pipeline'
-import { PipelineTemplateListQueryKeyPrefix, useUpdateTemplateInfo } from '@/service/use-pipeline'
-import { useInvalid } from '@/service/use-base'
+import { useInvalidCustomizedTemplateList, useUpdateTemplateInfo } from '@/service/use-pipeline'

 type EditPipelineInfoProps = {
   onClose: () => void
@@ -63,7 +62,7 @@ const EditPipelineInfo = ({
   }, [])

   const { mutateAsync: updatePipeline } = useUpdateTemplateInfo()
-  const invalidCustomizedTemplateList = useInvalid([...PipelineTemplateListQueryKeyPrefix, 'customized'])
+  const invalidCustomizedTemplateList = useInvalidCustomizedTemplateList()

   const handleSave = useCallback(async () => {
     if (!name) {
@@ -5,9 +5,9 @@ import EditPipelineInfo from './edit-pipeline-info'
 import type { PipelineTemplate } from '@/models/pipeline'
 import Confirm from '@/app/components/base/confirm'
 import {
-  PipelineTemplateListQueryKeyPrefix,
   useDeleteTemplate,
   useExportTemplateDSL,
+  useInvalidCustomizedTemplateList,
   usePipelineTemplateById,
 } from '@/service/use-pipeline'
 import { downloadFile } from '@/utils/format'
@@ -18,7 +18,6 @@ import Details from './details'
 import Content from './content'
 import Actions from './actions'
 import { useCreatePipelineDatasetFromCustomized } from '@/service/knowledge/use-create-dataset'
-import { useInvalid } from '@/service/use-base'
 import { useInvalidDatasetList } from '@/service/knowledge/use-dataset'

 type TemplateCardProps = {
@@ -128,7 +127,7 @@ const TemplateCard = ({
   }, [])

   const { mutateAsync: deletePipeline } = useDeleteTemplate()
-  const invalidCustomizedTemplateList = useInvalid([...PipelineTemplateListQueryKeyPrefix, 'customized'])
+  const invalidCustomizedTemplateList = useInvalidCustomizedTemplateList()

   const onConfirmDelete = useCallback(async () => {
     await deletePipeline(pipeline.id, {
@@ -321,7 +321,7 @@ const GotoAnything: FC<Props> = ({
             autoFocus
           />
           {searchMode !== 'general' && (
-            <div className='flex items-center gap-1 rounded bg-blue-50 px-2 py-[2px] text-xs font-medium text-blue-600 dark:bg-blue-900/40 dark:text-blue-300'>
+            <div className='flex items-center gap-1 rounded bg-gray-100 px-2 py-[2px] text-xs font-medium text-gray-700 dark:bg-gray-800 dark:text-gray-300'>
               <span>{(() => {
                 if (searchMode === 'scopes')
                   return 'SCOPES'
@@ -6,6 +6,7 @@ import { useInvalidateAllBuiltInTools, useInvalidateAllToolProviders } from '@/s
 import { useInvalidateStrategyProviders } from '@/service/use-strategy'
 import type { Plugin, PluginDeclaration, PluginManifestInMarket } from '../../types'
 import { PluginType } from '../../types'
+import { useInvalidDataSourceList } from '@/service/use-pipeline'

 const useRefreshPluginList = () => {
   const invalidateInstalledPluginList = useInvalidateInstalledPluginList()
@@ -16,6 +17,7 @@ const useRefreshPluginList = () => {

   const invalidateAllToolProviders = useInvalidateAllToolProviders()
   const invalidateAllBuiltInTools = useInvalidateAllBuiltInTools()
+  const invalidateAllDataSources = useInvalidDataSourceList()

   const invalidateStrategyProviders = useInvalidateStrategyProviders()
   return {
@@ -30,6 +32,9 @@ const useRefreshPluginList = () => {
       // TODO: update suggested tools. It's a function in hook useMarketplacePlugins,handleUpdatePlugins
     }

+    if ((manifest && PluginType.datasource.includes(manifest.category)) || refreshAllType)
+      invalidateAllDataSources()
+
     // model select
     if ((manifest && PluginType.model.includes(manifest.category)) || refreshAllType) {
       refreshModelProviders()
@@ -33,6 +33,7 @@ import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
 import { useInvalid } from '@/service/use-base'
 import {
   publishedPipelineInfoQueryKeyPrefix,
+  useInvalidCustomizedTemplateList,
   usePublishAsCustomizedPipeline,
 } from '@/service/use-pipeline'
 import Confirm from '@/app/components/base/confirm'
@@ -158,6 +159,8 @@ const Popup = () => {
     push(`/datasets/${datasetId}/documents/create-from-pipeline`)
   }, [datasetId, push])

+  const invalidCustomizedTemplateList = useInvalidCustomizedTemplateList()
+
   const handlePublishAsKnowledgePipeline = useCallback(async (
     name: string,
     icon: IconInfo,
@@ -189,6 +192,7 @@ const Popup = () => {
          </div>
        ),
      })
+     invalidCustomizedTemplateList()
    }
    catch {
      notify({ type: 'error', message: t('datasetPipeline.publishTemplate.error.message') })
@@ -48,6 +48,10 @@ export const usePipelineTemplateList = (params: PipelineTemplateListParams) => {
   })
 }

+export const useInvalidCustomizedTemplateList = () => {
+  return useInvalid([...PipelineTemplateListQueryKeyPrefix, 'customized'])
+}
+
 export const usePipelineTemplateById = (params: PipelineTemplateByIdRequest, enabled: boolean) => {
   const { template_id, type } = params
   return useQuery<PipelineTemplateByIdResponse>({
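The new `useInvalidCustomizedTemplateList` hook wraps the query-key invalidation that the edit-pipeline-info modal and the template card previously built inline from `PipelineTemplateListQueryKeyPrefix` and `useInvalid` (hence the dropped imports in those files), and it is also what the publish popup now calls after publishing a customized template.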