mirror of
https://github.com/langgenius/dify.git
synced 2026-03-01 21:15:10 +00:00
Compare commits
9 Commits
deploy/age
...
8c92-updat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e3c1edc394 | ||
|
|
49c7501cc8 | ||
|
|
b64725b733 | ||
|
|
df8f762159 | ||
|
|
eb5522ff29 | ||
|
|
7e33faecfe | ||
|
|
a015cad8b8 | ||
|
|
c7a3a4fc0e | ||
|
|
27932cb669 |
@@ -9,6 +9,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAp
|
|||||||
from core.entities.provider_configuration import ProviderModelBundle
|
from core.entities.provider_configuration import ProviderModelBundle
|
||||||
from core.file import File, FileUploadConfig
|
from core.file import File, FileUploadConfig
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||||
|
from models.enums import CreatorUserRole
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
@@ -80,6 +81,11 @@ class InvokeFrom(StrEnum):
|
|||||||
|
|
||||||
return "dev"
|
return "dev"
|
||||||
|
|
||||||
|
def to_creator_user_role(self) -> CreatorUserRole:
|
||||||
|
if self in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}:
|
||||||
|
return CreatorUserRole.ACCOUNT
|
||||||
|
return CreatorUserRole.END_USER
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigWithCredentialsEntity(BaseModel):
|
class ModelConfigWithCredentialsEntity(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -37,9 +37,7 @@ class DatasetIndexToolCallbackHandler:
|
|||||||
content=query,
|
content=query,
|
||||||
source="app",
|
source="app",
|
||||||
source_app_id=self._app_id,
|
source_app_id=self._app_id,
|
||||||
created_by_role=(
|
created_by_role=self._invoke_from.to_creator_user_role(),
|
||||||
"account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
|
|
||||||
),
|
|
||||||
created_by=self._user_id,
|
created_by=self._user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
|
|||||||
from models import UploadFile
|
from models import UploadFile
|
||||||
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
|
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
|
from models.enums import CreatorUserRole
|
||||||
from services.external_knowledge_service import ExternalDatasetService
|
from services.external_knowledge_service import ExternalDatasetService
|
||||||
|
|
||||||
default_retrieval_model: dict[str, Any] = {
|
default_retrieval_model: dict[str, Any] = {
|
||||||
@@ -176,13 +177,13 @@ class DatasetRetrieval:
|
|||||||
)
|
)
|
||||||
|
|
||||||
all_documents = []
|
all_documents = []
|
||||||
user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
|
creator_user_role = invoke_from.to_creator_user_role()
|
||||||
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||||
all_documents = self.single_retrieve(
|
all_documents = self.single_retrieve(
|
||||||
app_id,
|
app_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
user_id,
|
user_id,
|
||||||
user_from,
|
creator_user_role,
|
||||||
query,
|
query,
|
||||||
available_datasets,
|
available_datasets,
|
||||||
model_instance,
|
model_instance,
|
||||||
@@ -197,7 +198,7 @@ class DatasetRetrieval:
|
|||||||
app_id,
|
app_id,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
user_id,
|
user_id,
|
||||||
user_from,
|
creator_user_role,
|
||||||
available_datasets,
|
available_datasets,
|
||||||
query,
|
query,
|
||||||
retrieve_config.top_k or 0,
|
retrieve_config.top_k or 0,
|
||||||
@@ -334,7 +335,7 @@ class DatasetRetrieval:
|
|||||||
app_id: str,
|
app_id: str,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
user_from: str,
|
creator_user_role: CreatorUserRole,
|
||||||
query: str,
|
query: str,
|
||||||
available_datasets: list,
|
available_datasets: list,
|
||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
@@ -444,7 +445,7 @@ class DatasetRetrieval:
|
|||||||
weights=retrieval_model_config.get("weights", None),
|
weights=retrieval_model_config.get("weights", None),
|
||||||
document_ids_filter=document_ids_filter,
|
document_ids_filter=document_ids_filter,
|
||||||
)
|
)
|
||||||
self._on_query(query, None, [dataset_id], app_id, user_from, user_id)
|
self._on_query(query, None, [dataset_id], app_id, creator_user_role, user_id)
|
||||||
|
|
||||||
if results:
|
if results:
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
@@ -466,7 +467,7 @@ class DatasetRetrieval:
|
|||||||
app_id: str,
|
app_id: str,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
user_from: str,
|
creator_user_role: CreatorUserRole,
|
||||||
available_datasets: list,
|
available_datasets: list,
|
||||||
query: str | None,
|
query: str | None,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
@@ -584,7 +585,7 @@ class DatasetRetrieval:
|
|||||||
|
|
||||||
if thread_exceptions:
|
if thread_exceptions:
|
||||||
raise thread_exceptions[0]
|
raise thread_exceptions[0]
|
||||||
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
|
self._on_query(query, attachment_ids, dataset_ids, app_id, creator_user_role, user_id)
|
||||||
|
|
||||||
if all_documents:
|
if all_documents:
|
||||||
# add thread to call _on_retrieval_end
|
# add thread to call _on_retrieval_end
|
||||||
@@ -733,7 +734,7 @@ class DatasetRetrieval:
|
|||||||
attachment_ids: list[str] | None,
|
attachment_ids: list[str] | None,
|
||||||
dataset_ids: list[str],
|
dataset_ids: list[str],
|
||||||
app_id: str,
|
app_id: str,
|
||||||
user_from: str,
|
creator_user_role: CreatorUserRole,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -755,7 +756,7 @@ class DatasetRetrieval:
|
|||||||
content=json.dumps(contents),
|
content=json.dumps(contents),
|
||||||
source="app",
|
source="app",
|
||||||
source_app_id=app_id,
|
source_app_id=app_id,
|
||||||
created_by_role=user_from,
|
created_by_role=creator_user_role,
|
||||||
created_by=user_id,
|
created_by=user_id,
|
||||||
)
|
)
|
||||||
dataset_queries.append(dataset_query)
|
dataset_queries.append(dataset_query)
|
||||||
|
|||||||
@@ -268,6 +268,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||||||
usage = self._merge_usage(usage, metadata_usage)
|
usage = self._merge_usage(usage, metadata_usage)
|
||||||
all_documents = []
|
all_documents = []
|
||||||
dataset_retrieval = DatasetRetrieval()
|
dataset_retrieval = DatasetRetrieval()
|
||||||
|
creator_user_role = self.user_from.to_creator_user_role()
|
||||||
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
|
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
|
||||||
# fetch model config
|
# fetch model config
|
||||||
if node_data.single_retrieval_config is None:
|
if node_data.single_retrieval_config is None:
|
||||||
@@ -292,7 +293,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
app_id=self.app_id,
|
app_id=self.app_id,
|
||||||
user_from=self.user_from.value,
|
creator_user_role=creator_user_role,
|
||||||
query=query,
|
query=query,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
@@ -334,7 +335,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||||||
app_id=self.app_id,
|
app_id=self.app_id,
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
user_from=self.user_from.value,
|
creator_user_role=creator_user_role,
|
||||||
available_datasets=available_datasets,
|
available_datasets=available_datasets,
|
||||||
query=query,
|
query=query,
|
||||||
top_k=node_data.multiple_retrieval_config.top_k,
|
top_k=node_data.multiple_retrieval_config.top_k,
|
||||||
|
|||||||
@@ -12,6 +12,11 @@ class UserFrom(StrEnum):
|
|||||||
ACCOUNT = "account"
|
ACCOUNT = "account"
|
||||||
END_USER = "end-user"
|
END_USER = "end-user"
|
||||||
|
|
||||||
|
def to_creator_user_role(self) -> "CreatorUserRole":
|
||||||
|
if self == UserFrom.ACCOUNT:
|
||||||
|
return CreatorUserRole.ACCOUNT
|
||||||
|
return CreatorUserRole.END_USER
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunTriggeredFrom(StrEnum):
|
class WorkflowRunTriggeredFrom(StrEnum):
|
||||||
DEBUGGING = "debugging"
|
DEBUGGING = "debugging"
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models import Account
|
from models import Account
|
||||||
from models.dataset import Dataset, DatasetQuery
|
from models.dataset import Dataset, DatasetQuery
|
||||||
|
from models.enums import CreatorUserRole
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -98,7 +99,7 @@ class HitTestingService:
|
|||||||
content=json.dumps(dataset_queries),
|
content=json.dumps(dataset_queries),
|
||||||
source="hit_testing",
|
source="hit_testing",
|
||||||
source_app_id=None,
|
source_app_id=None,
|
||||||
created_by_role="account",
|
created_by_role=CreatorUserRole.ACCOUNT,
|
||||||
created_by=account.id,
|
created_by=account.id,
|
||||||
)
|
)
|
||||||
db.session.add(dataset_query)
|
db.session.add(dataset_query)
|
||||||
@@ -138,7 +139,7 @@ class HitTestingService:
|
|||||||
content=query,
|
content=query,
|
||||||
source="hit_testing",
|
source="hit_testing",
|
||||||
source_app_id=None,
|
source_app_id=None,
|
||||||
created_by_role="account",
|
created_by_role=CreatorUserRole.ACCOUNT,
|
||||||
created_by=account.id,
|
created_by=account.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user