Compare commits

...

9 Commits

Author SHA1 Message Date
-LAN-
e3c1edc394 chore: lint
Signed-off-by: -LAN- <laipz8200@outlook.com>
2026-01-09 16:23:43 +08:00
-LAN-
49c7501cc8 refactor(api): add UserFrom creator role conversion and reuse in knowledge retrieval
Tests not run (not requested).
2026-01-09 16:23:43 +08:00
-LAN-
b64725b733 refactor(api): use InvokeFrom.to_creator_user_role in dataset query callback
CreatorUserRole.END_USER maps to `end_user` (underscore), and the callback now uses the shared `to_creator_user_role` conversion.

Tests not run (not requested).
2026-01-09 16:23:42 +08:00
-LAN-
df8f762159 Rename user_from to creator_user_role for better readability. (vibe-kanban 224d800e)
In api/core/rag/retrieval/dataset_retrieval.py#L763 we pass `user_from` to `created_by_role`. The parameter name suggests the value should be a UserFrom enum, but what is actually required here is a CreatorUserRole. The parameter is therefore misnamed, and every upstream parameter along its propagation path should be renamed to match.
2026-01-09 16:23:42 +08:00
-LAN-
eb5522ff29 revert: api/core/rag/retrieval/dataset_retrieval.py, separate it into another PR
Signed-off-by: -LAN- <laipz8200@outlook.com>
2026-01-09 16:23:42 +08:00
-LAN-
7e33faecfe fix(api): switch dataset query created_by_role to CreatorUserRole enums
Note: `CreatorUserRole.END_USER` is `"end_user"` (underscore), matching the prior value.

Tests not run (not requested).
2026-01-09 16:23:42 +08:00
-LAN-
a015cad8b8 chore: run make lint (2 files reformatted, all checks passed)
chore: run make type-check (0 errors, 0 warnings, 0 notes)
2026-01-09 16:23:42 +08:00
-LAN-
c7a3a4fc0e Updated dataset query attribution to use the UserFrom enum in the callback and hit-testing paths, matching the retrieval flow and removing literal role strings. Adjusted api/core/callback_handler/index_tool_callback_handler.py and api/services/hit_testing_service.py where DatasetQuery records are created.
fix(api): use UserFrom for dataset query created_by_role

Tests not run (not requested).
2026-01-09 16:23:42 +08:00
-LAN-
27932cb669 Updated api/core/rag/retrieval/dataset_retrieval.py to set user_from via UserFrom, so dataset query attribution aligns with the enum used elsewhere in the codepath.
fix(api): use UserFrom enum for dataset retrieval user_from

Tests not run (not requested).

1) `make lint`
2) `make type-check`
3) `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`
2026-01-09 16:23:41 +08:00
6 changed files with 28 additions and 16 deletions

View File

@@ -9,6 +9,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAp
from core.entities.provider_configuration import ProviderModelBundle from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.model_entities import AIModelEntity
from models.enums import CreatorUserRole
if TYPE_CHECKING: if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
@@ -80,6 +81,11 @@ class InvokeFrom(StrEnum):
return "dev" return "dev"
def to_creator_user_role(self) -> CreatorUserRole:
if self in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}:
return CreatorUserRole.ACCOUNT
return CreatorUserRole.END_USER
class ModelConfigWithCredentialsEntity(BaseModel): class ModelConfigWithCredentialsEntity(BaseModel):
""" """

View File

@@ -37,9 +37,7 @@ class DatasetIndexToolCallbackHandler:
content=query, content=query,
source="app", source="app",
source_app_id=self._app_id, source_app_id=self._app_id,
created_by_role=( created_by_role=self._invoke_from.to_creator_user_role(),
"account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
),
created_by=self._user_id, created_by=self._user_id,
) )

View File

@@ -63,6 +63,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from models.enums import CreatorUserRole
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model: dict[str, Any] = { default_retrieval_model: dict[str, Any] = {
@@ -176,13 +177,13 @@ class DatasetRetrieval:
) )
all_documents = [] all_documents = []
user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" creator_user_role = invoke_from.to_creator_user_role()
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
all_documents = self.single_retrieve( all_documents = self.single_retrieve(
app_id, app_id,
tenant_id, tenant_id,
user_id, user_id,
user_from, creator_user_role,
query, query,
available_datasets, available_datasets,
model_instance, model_instance,
@@ -197,7 +198,7 @@ class DatasetRetrieval:
app_id, app_id,
tenant_id, tenant_id,
user_id, user_id,
user_from, creator_user_role,
available_datasets, available_datasets,
query, query,
retrieve_config.top_k or 0, retrieve_config.top_k or 0,
@@ -334,7 +335,7 @@ class DatasetRetrieval:
app_id: str, app_id: str,
tenant_id: str, tenant_id: str,
user_id: str, user_id: str,
user_from: str, creator_user_role: CreatorUserRole,
query: str, query: str,
available_datasets: list, available_datasets: list,
model_instance: ModelInstance, model_instance: ModelInstance,
@@ -444,7 +445,7 @@ class DatasetRetrieval:
weights=retrieval_model_config.get("weights", None), weights=retrieval_model_config.get("weights", None),
document_ids_filter=document_ids_filter, document_ids_filter=document_ids_filter,
) )
self._on_query(query, None, [dataset_id], app_id, user_from, user_id) self._on_query(query, None, [dataset_id], app_id, creator_user_role, user_id)
if results: if results:
thread = threading.Thread( thread = threading.Thread(
@@ -466,7 +467,7 @@ class DatasetRetrieval:
app_id: str, app_id: str,
tenant_id: str, tenant_id: str,
user_id: str, user_id: str,
user_from: str, creator_user_role: CreatorUserRole,
available_datasets: list, available_datasets: list,
query: str | None, query: str | None,
top_k: int, top_k: int,
@@ -584,7 +585,7 @@ class DatasetRetrieval:
if thread_exceptions: if thread_exceptions:
raise thread_exceptions[0] raise thread_exceptions[0]
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id) self._on_query(query, attachment_ids, dataset_ids, app_id, creator_user_role, user_id)
if all_documents: if all_documents:
# add thread to call _on_retrieval_end # add thread to call _on_retrieval_end
@@ -733,7 +734,7 @@ class DatasetRetrieval:
attachment_ids: list[str] | None, attachment_ids: list[str] | None,
dataset_ids: list[str], dataset_ids: list[str],
app_id: str, app_id: str,
user_from: str, creator_user_role: CreatorUserRole,
user_id: str, user_id: str,
): ):
""" """
@@ -755,7 +756,7 @@ class DatasetRetrieval:
content=json.dumps(contents), content=json.dumps(contents),
source="app", source="app",
source_app_id=app_id, source_app_id=app_id,
created_by_role=user_from, created_by_role=creator_user_role,
created_by=user_id, created_by=user_id,
) )
dataset_queries.append(dataset_query) dataset_queries.append(dataset_query)

View File

@@ -268,6 +268,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
usage = self._merge_usage(usage, metadata_usage) usage = self._merge_usage(usage, metadata_usage)
all_documents = [] all_documents = []
dataset_retrieval = DatasetRetrieval() dataset_retrieval = DatasetRetrieval()
creator_user_role = self.user_from.to_creator_user_role()
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query: if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
# fetch model config # fetch model config
if node_data.single_retrieval_config is None: if node_data.single_retrieval_config is None:
@@ -292,7 +293,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
user_id=self.user_id, user_id=self.user_id,
app_id=self.app_id, app_id=self.app_id,
user_from=self.user_from.value, creator_user_role=creator_user_role,
query=query, query=query,
model_config=model_config, model_config=model_config,
model_instance=model_instance, model_instance=model_instance,
@@ -334,7 +335,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
app_id=self.app_id, app_id=self.app_id,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
user_id=self.user_id, user_id=self.user_id,
user_from=self.user_from.value, creator_user_role=creator_user_role,
available_datasets=available_datasets, available_datasets=available_datasets,
query=query, query=query,
top_k=node_data.multiple_retrieval_config.top_k, top_k=node_data.multiple_retrieval_config.top_k,

View File

@@ -12,6 +12,11 @@ class UserFrom(StrEnum):
ACCOUNT = "account" ACCOUNT = "account"
END_USER = "end-user" END_USER = "end-user"
def to_creator_user_role(self) -> "CreatorUserRole":
if self == UserFrom.ACCOUNT:
return CreatorUserRole.ACCOUNT
return CreatorUserRole.END_USER
class WorkflowRunTriggeredFrom(StrEnum): class WorkflowRunTriggeredFrom(StrEnum):
DEBUGGING = "debugging" DEBUGGING = "debugging"

View File

@@ -13,6 +13,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
from models import Account from models import Account
from models.dataset import Dataset, DatasetQuery from models.dataset import Dataset, DatasetQuery
from models.enums import CreatorUserRole
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -98,7 +99,7 @@ class HitTestingService:
content=json.dumps(dataset_queries), content=json.dumps(dataset_queries),
source="hit_testing", source="hit_testing",
source_app_id=None, source_app_id=None,
created_by_role="account", created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id, created_by=account.id,
) )
db.session.add(dataset_query) db.session.add(dataset_query)
@@ -138,7 +139,7 @@ class HitTestingService:
content=query, content=query,
source="hit_testing", source="hit_testing",
source_app_id=None, source_app_id=None,
created_by_role="account", created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id, created_by=account.id,
) )