Mirror of https://github.com/langgenius/dify.git, synced 2026-02-03 22:44:12 +00:00

Compare commits: 1 commit on fix/api-to...inject-con

Commit 5eaf535a0d
@@ -112,7 +112,6 @@ ignore_imports =
     core.workflow.nodes.datasource.datasource_node -> models.model
     core.workflow.nodes.datasource.datasource_node -> models.tools
     core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
-    core.workflow.nodes.document_extractor.node -> configs
     core.workflow.nodes.document_extractor.node -> core.file.file_manager
     core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
     core.workflow.nodes.http_request.entities -> configs
@@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden
 
 from extensions.ext_database import db
-from libs.api_token_cache import ApiTokenCache
 from libs.helper import TimestampField
 from libs.login import current_account_with_tenant, login_required
 from models.dataset import Dataset
@@ -131,10 +130,6 @@ class BaseApiKeyResource(Resource):
 
         if key is None:
             flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")
-            return  # Type checker hint: abort() raises exception
-
-        # Invalidate cache before deleting from database
-        ApiTokenCache.delete(key.token, key.type)
 
         db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
         db.session.commit()
@@ -51,7 +51,6 @@ from fields.dataset_fields import (
     weighted_score_fields,
 )
 from fields.document_fields import document_status_fields
-from libs.api_token_cache import ApiTokenCache
 from libs.login import current_account_with_tenant, login_required
 from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
 from models.dataset import DatasetPermissionEnum
@@ -820,10 +819,6 @@ class DatasetApiDeleteApi(Resource):
 
         if key is None:
             console_ns.abort(404, message="API key not found")
-            return  # Type checker hint: abort() raises exception
-
-        # Invalidate cache before deleting from database
-        ApiTokenCache.delete(key.token, key.type)
 
         db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
         db.session.commit()
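Both hunks above delete the same invalidate-then-delete sequence. While the cache existed, the ordering was the point: evicting the Redis entry before deleting the row closes most of the window in which a concurrent request could still be served a stale cached token. A minimal sketch of the removed pattern, reusing the ApiTokenCache and ApiToken names from this diff (no longer runnable once the cache module is deleted):

    # Sketch only: the pattern this commit removes.
    def delete_api_key(api_key_id: str, key) -> None:
        ApiTokenCache.delete(key.token, key.type)  # evict the cached copy first
        # Only then drop the row; a cache miss now falls through to a
        # database that no longer contains the token.
        db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
        db.session.commit()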
@@ -17,7 +17,6 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
 from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
-from libs.api_token_cache import ApiTokenCache
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_user
 from models import Account, Tenant, TenantAccountJoin, TenantStatus
@@ -297,14 +296,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
 
 def validate_and_get_api_token(scope: str | None = None):
     """
-    Validate and get API token with Redis caching.
-
-    This function uses a two-tier approach:
-    1. First checks Redis cache for the token
-    2. If not cached, queries database and caches the result
-
-    The last_used_at field is updated asynchronously via Celery task
-    to avoid blocking the request.
+    Validate and get API token.
     """
     auth_header = request.headers.get("Authorization")
     if auth_header is None or " " not in auth_header:
@@ -316,20 +308,8 @@ def validate_and_get_api_token(scope: str | None = None):
     if auth_scheme != "bearer":
         raise Unauthorized("Authorization scheme must be 'Bearer'")
-
-    # Try to get token from cache first
-    # Returns a CachedApiToken (plain Python object), not a SQLAlchemy model
-    cached_token = ApiTokenCache.get(auth_token, scope)
-    if cached_token is not None:
-        logger.debug("Token validation served from cache for scope: %s", scope)
-        # Asynchronously update last_used_at (non-blocking)
-        _async_update_token_last_used_at(auth_token, scope)
-        return cached_token
-
-    # Cache miss - query database
-    logger.debug("Token cache miss, querying database for scope: %s", scope)
     current_time = naive_utc_now()
     cutoff_time = current_time - timedelta(minutes=1)
 
     with Session(db.engine, expire_on_commit=False) as session:
         update_stmt = (
             update(ApiToken)
@@ -349,35 +329,10 @@ def validate_and_get_api_token(scope: str | None = None):
 
         if not api_token:
             raise Unauthorized("Access token is invalid")
-        # Cache the valid token
-        ApiTokenCache.set(auth_token, scope, api_token)
 
         return api_token
 
 
-def _async_update_token_last_used_at(auth_token: str, scope: str | None):
-    """
-    Asynchronously update the last_used_at timestamp for a token.
-
-    This schedules a Celery task to update the database without blocking
-    the current request. The start time is passed to ensure only older
-    records are updated, providing natural concurrency control.
-    """
-    try:
-        from tasks.update_api_token_last_used_task import update_api_token_last_used_task
-
-        # Record the request start time for concurrency control
-        start_time = naive_utc_now()
-        start_time_iso = start_time.isoformat()
-
-        # Fire and forget - don't wait for result
-        update_api_token_last_used_task.delay(auth_token, scope, start_time_iso)
-        logger.debug("Scheduled async update for last_used_at (scope: %s, start_time: %s)", scope, start_time_iso)
-    except Exception as e:
-        # Don't fail the request if task scheduling fails
-        logger.warning("Failed to schedule last_used_at update task: %s", e)
-
-
 class DatasetApiResource(Resource):
     method_decorators = [validate_dataset_token]
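The deleted docstring and helper above describe the read path being reverted: Redis first, database on a miss, with last_used_at bookkeeping pushed to Celery so the request never blocks on it. A hedged sketch of that flow, where query_api_token_from_db is a hypothetical stand-in for the real database lookup (every other name comes from this diff):

    # Sketch of the reverted two-tier validation, not the verbatim implementation.
    def validate_token_two_tier(auth_token: str, scope: str | None):
        cached = ApiTokenCache.get(auth_token, scope)  # tier 1: Redis
        if cached is not None:
            _async_update_token_last_used_at(auth_token, scope)  # fire-and-forget bookkeeping
            return cached
        api_token = query_api_token_from_db(auth_token, scope)  # tier 2: database (hypothetical helper)
        if api_token is None:
            raise Unauthorized("Access token is invalid")
        ApiTokenCache.set(auth_token, scope, api_token)  # populate cache for the next request
        return api_token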
@@ -1,5 +1,5 @@
 from collections.abc import Callable, Sequence
-from typing import TYPE_CHECKING, final
+from typing import TYPE_CHECKING, cast, final
 
 from typing_extensions import override
@@ -15,6 +15,7 @@ from core.workflow.graph.graph import NodeFactory
 from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.code.code_node import CodeNode
 from core.workflow.nodes.code.limits import CodeNodeLimits
+from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
 from core.workflow.nodes.http_request.node import HttpRequestNode
 from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
 from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
@@ -50,6 +51,7 @@ class DifyNodeFactory(NodeFactory):
         http_request_http_client: HttpClientProtocol | None = None,
         http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
         http_request_file_manager: FileManagerProtocol | None = None,
+        document_extractor_unstructured_api_config: UnstructuredApiConfig | None = None,
     ) -> None:
         self.graph_init_params = graph_init_params
         self.graph_runtime_state = graph_runtime_state
@@ -71,6 +73,13 @@ class DifyNodeFactory(NodeFactory):
         self._http_request_http_client = http_request_http_client or ssrf_proxy
         self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
         self._http_request_file_manager = http_request_file_manager or file_manager
+        self._document_extractor_unstructured_api_config = (
+            document_extractor_unstructured_api_config
+            or UnstructuredApiConfig(
+                api_url=dify_config.UNSTRUCTURED_API_URL,
+                api_key=dify_config.UNSTRUCTURED_API_KEY or "",
+            )
+        )
 
     @override
     def create_node(self, node_config: NodeConfigDict) -> Node:
@@ -135,6 +144,16 @@ class DifyNodeFactory(NodeFactory):
                 file_manager=self._http_request_file_manager,
             )
 
+        if node_type == NodeType.DOCUMENT_EXTRACTOR:
+            document_extractor_class = cast(type[DocumentExtractorNode], node_class)
+            return document_extractor_class(
+                id=node_id,
+                config=node_config,
+                graph_init_params=self.graph_init_params,
+                graph_runtime_state=self.graph_runtime_state,
+                unstructured_api_config=self._document_extractor_unstructured_api_config,
+            )
+
         return node_class(
             id=node_id,
             config=node_config,
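Taken together, the factory hunks add one constructor knob and one dispatch branch. A hedged usage sketch; graph_init_params and graph_runtime_state are assumed to already exist in the caller, and the URL and key are placeholders:

    # Injecting explicit Unstructured API settings into one factory instance.
    config = UnstructuredApiConfig(api_url="https://unstructured.example.internal", api_key="example-key")
    factory = DifyNodeFactory(
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
        document_extractor_unstructured_api_config=config,
    )
    # Omitting the argument falls back to dify_config.UNSTRUCTURED_API_URL / UNSTRUCTURED_API_KEY.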
@@ -1,4 +1,4 @@
-from .entities import DocumentExtractorNodeData
+from .entities import DocumentExtractorNodeData, UnstructuredApiConfig
 from .node import DocumentExtractorNode
 
-__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"]
+__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData", "UnstructuredApiConfig"]
@@ -1,7 +1,14 @@
 from collections.abc import Sequence
+from dataclasses import dataclass
 
 from core.workflow.nodes.base import BaseNodeData
 
 
 class DocumentExtractorNodeData(BaseNodeData):
     variable_selector: Sequence[str]
+
+
+@dataclass(frozen=True)
+class UnstructuredApiConfig:
+    api_url: str | None = None
+    api_key: str = ""
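Because UnstructuredApiConfig is a frozen dataclass, a single instance can be shared by the factory and every node it creates without mutation risk. A small standard-library illustration of that property (not code from this commit):

    from dataclasses import FrozenInstanceError

    cfg = UnstructuredApiConfig(api_url="https://unstructured.example.internal")
    try:
        cfg.api_key = "oops"  # frozen dataclasses reject attribute assignment
    except FrozenInstanceError:
        pass  # immutability makes the shared default safe to reuse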
@@ -5,7 +5,7 @@ import logging
 import os
 import tempfile
 from collections.abc import Mapping, Sequence
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import charset_normalizer
 import docx
@@ -20,7 +20,6 @@ from docx.oxml.text.paragraph import CT_P
 from docx.table import Table
 from docx.text.paragraph import Paragraph
 
-from configs import dify_config
 from core.file import File, FileTransferMethod, file_manager
 from core.helper import ssrf_proxy
 from core.variables import ArrayFileSegment
@@ -29,11 +28,15 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
 from core.workflow.node_events import NodeRunResult
 from core.workflow.nodes.base.node import Node
 
-from .entities import DocumentExtractorNodeData
+from .entities import DocumentExtractorNodeData, UnstructuredApiConfig
 from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from core.workflow.entities import GraphInitParams
+    from core.workflow.runtime import GraphRuntimeState
+
 
 class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
     """
@@ -47,6 +50,23 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
     def version(cls) -> str:
         return "1"
 
+    def __init__(
+        self,
+        id: str,
+        config: Mapping[str, Any],
+        graph_init_params: "GraphInitParams",
+        graph_runtime_state: "GraphRuntimeState",
+        *,
+        unstructured_api_config: UnstructuredApiConfig | None = None,
+    ) -> None:
+        super().__init__(
+            id=id,
+            config=config,
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+        self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig()
+
     def _run(self):
         variable_selector = self.node_data.variable_selector
         variable = self.graph_runtime_state.variable_pool.get(variable_selector)
@@ -64,7 +84,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
 
         try:
             if isinstance(value, list):
-                extracted_text_list = list(map(_extract_text_from_file, value))
+                extracted_text_list = [
+                    _extract_text_from_file(file, unstructured_api_config=self._unstructured_api_config)
+                    for file in value
+                ]
                 return NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
                     inputs=inputs,
@@ -72,7 +95,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
                     outputs={"text": ArrayStringSegment(value=extracted_text_list)},
                 )
             elif isinstance(value, File):
-                extracted_text = _extract_text_from_file(value)
+                extracted_text = _extract_text_from_file(value, unstructured_api_config=self._unstructured_api_config)
                 return NodeRunResult(
                     status=WorkflowNodeExecutionStatus.SUCCEEDED,
                     inputs=inputs,
@@ -103,7 +126,12 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
         return {node_id + ".files": typed_node_data.variable_selector}
 
 
-def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
+def _extract_text_by_mime_type(
+    *,
+    file_content: bytes,
+    mime_type: str,
+    unstructured_api_config: UnstructuredApiConfig,
+) -> str:
     """Extract text from a file based on its MIME type."""
     match mime_type:
         case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml":
@@ -111,7 +139,7 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
         case "application/pdf":
             return _extract_text_from_pdf(file_content)
         case "application/msword":
-            return _extract_text_from_doc(file_content)
+            return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config)
         case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
             return _extract_text_from_docx(file_content)
         case "text/csv":
@@ -119,11 +147,11 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
         case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel":
             return _extract_text_from_excel(file_content)
         case "application/vnd.ms-powerpoint":
-            return _extract_text_from_ppt(file_content)
+            return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config)
         case "application/vnd.openxmlformats-officedocument.presentationml.presentation":
-            return _extract_text_from_pptx(file_content)
+            return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config)
         case "application/epub+zip":
-            return _extract_text_from_epub(file_content)
+            return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config)
         case "message/rfc822":
             return _extract_text_from_eml(file_content)
         case "application/vnd.ms-outlook":
@@ -140,7 +168,12 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
             raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
 
 
-def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
+def _extract_text_by_file_extension(
+    *,
+    file_content: bytes,
+    file_extension: str,
+    unstructured_api_config: UnstructuredApiConfig,
+) -> str:
     """Extract text from a file based on its file extension."""
     match file_extension:
         case (
@@ -203,7 +236,7 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
         case ".pdf":
             return _extract_text_from_pdf(file_content)
         case ".doc":
-            return _extract_text_from_doc(file_content)
+            return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config)
         case ".docx":
             return _extract_text_from_docx(file_content)
         case ".csv":
@@ -211,11 +244,11 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
         case ".xls" | ".xlsx":
             return _extract_text_from_excel(file_content)
         case ".ppt":
-            return _extract_text_from_ppt(file_content)
+            return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config)
         case ".pptx":
-            return _extract_text_from_pptx(file_content)
+            return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config)
         case ".epub":
-            return _extract_text_from_epub(file_content)
+            return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config)
         case ".eml":
             return _extract_text_from_eml(file_content)
         case ".msg":
@@ -312,13 +345,13 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
         raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e
 
 
-def _extract_text_from_doc(file_content: bytes) -> str:
+def _extract_text_from_doc(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
     """
     Extract text from a DOC file.
     """
     from unstructured.partition.api import partition_via_api
 
-    if not dify_config.UNSTRUCTURED_API_URL:
+    if not unstructured_api_config.api_url:
         raise TextExtractionError("UNSTRUCTURED_API_URL must be set")
 
     try:
@@ -329,8 +362,8 @@ def _extract_text_from_doc(file_content: bytes) -> str:
             elements = partition_via_api(
                 file=file,
                 metadata_filename=temp_file.name,
-                api_url=dify_config.UNSTRUCTURED_API_URL,
-                api_key=dify_config.UNSTRUCTURED_API_KEY,  # type: ignore
+                api_url=unstructured_api_config.api_url,
+                api_key=unstructured_api_config.api_key,
             )
         os.unlink(temp_file.name)
         return "\n".join([getattr(element, "text", "") for element in elements])
@@ -420,12 +453,20 @@ def _download_file_content(file: File) -> bytes:
         raise FileDownloadError(f"Error downloading file: {str(e)}") from e
 
 
-def _extract_text_from_file(file: File):
+def _extract_text_from_file(file: File, *, unstructured_api_config: UnstructuredApiConfig) -> str:
     file_content = _download_file_content(file)
     if file.extension:
-        extracted_text = _extract_text_by_file_extension(file_content=file_content, file_extension=file.extension)
+        extracted_text = _extract_text_by_file_extension(
+            file_content=file_content,
+            file_extension=file.extension,
+            unstructured_api_config=unstructured_api_config,
+        )
     elif file.mime_type:
-        extracted_text = _extract_text_by_mime_type(file_content=file_content, mime_type=file.mime_type)
+        extracted_text = _extract_text_by_mime_type(
+            file_content=file_content,
+            mime_type=file.mime_type,
+            unstructured_api_config=unstructured_api_config,
+        )
     else:
         raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing")
     return extracted_text
@@ -517,12 +558,12 @@ def _extract_text_from_excel(file_content: bytes) -> str:
         raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e
 
 
-def _extract_text_from_ppt(file_content: bytes) -> str:
+def _extract_text_from_ppt(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
     from unstructured.partition.api import partition_via_api
     from unstructured.partition.ppt import partition_ppt
 
     try:
-        if dify_config.UNSTRUCTURED_API_URL:
+        if unstructured_api_config.api_url:
             with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file:
                 temp_file.write(file_content)
                 temp_file.flush()
@@ -530,8 +571,8 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
                 elements = partition_via_api(
                     file=file,
                     metadata_filename=temp_file.name,
-                    api_url=dify_config.UNSTRUCTURED_API_URL,
-                    api_key=dify_config.UNSTRUCTURED_API_KEY,  # type: ignore
+                    api_url=unstructured_api_config.api_url,
+                    api_key=unstructured_api_config.api_key,
                 )
             os.unlink(temp_file.name)
         else:
@@ -543,12 +584,12 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
         raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
 
 
-def _extract_text_from_pptx(file_content: bytes) -> str:
+def _extract_text_from_pptx(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
    from unstructured.partition.api import partition_via_api
    from unstructured.partition.pptx import partition_pptx
 
     try:
-        if dify_config.UNSTRUCTURED_API_URL:
+        if unstructured_api_config.api_url:
             with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
                 temp_file.write(file_content)
                 temp_file.flush()
@@ -556,8 +597,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
                 elements = partition_via_api(
                     file=file,
                     metadata_filename=temp_file.name,
-                    api_url=dify_config.UNSTRUCTURED_API_URL,
-                    api_key=dify_config.UNSTRUCTURED_API_KEY,  # type: ignore
+                    api_url=unstructured_api_config.api_url,
+                    api_key=unstructured_api_config.api_key,
                 )
             os.unlink(temp_file.name)
         else:
@@ -568,12 +609,12 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
         raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
 
 
-def _extract_text_from_epub(file_content: bytes) -> str:
+def _extract_text_from_epub(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
     from unstructured.partition.api import partition_via_api
     from unstructured.partition.epub import partition_epub
 
     try:
-        if dify_config.UNSTRUCTURED_API_URL:
+        if unstructured_api_config.api_url:
             with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file:
                 temp_file.write(file_content)
                 temp_file.flush()
@@ -581,8 +622,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:
                 elements = partition_via_api(
                     file=file,
                     metadata_filename=temp_file.name,
-                    api_url=dify_config.UNSTRUCTURED_API_URL,
-                    api_key=dify_config.UNSTRUCTURED_API_KEY,  # type: ignore
+                    api_url=unstructured_api_config.api_url,
+                    api_key=unstructured_api_config.api_key,
                 )
             os.unlink(temp_file.name)
         else:
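With dify_config no longer read inside these helpers, the Unstructured settings travel purely as arguments, so the extractors can be exercised in tests without patching module globals. A hedged sketch against the new signatures (the File value is assumed to come from the workflow's variable pool, and the URL and key are placeholders):

    cfg = UnstructuredApiConfig(api_url="https://unstructured.example.internal", api_key="example-key")
    text = _extract_text_from_file(file, unstructured_api_config=cfg)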
@@ -104,7 +104,6 @@ def init_app(app: DifyApp) -> Celery:
         "tasks.trigger_processing_tasks",  # async trigger processing
         "tasks.generate_summary_index_task",  # summary index generation
         "tasks.regenerate_summary_index_task",  # summary index regeneration
-        "tasks.update_api_token_last_used_task",  # async API token last_used_at update
     ]
     day = dify_config.CELERY_BEAT_SCHEDULER_TIME
@@ -1,380 +0,0 @@
"""
API Token Cache Module

Provides Redis-based caching for API token validation to reduce database load.
"""

import json
import logging
from datetime import datetime
from typing import Any

from extensions.ext_redis import redis_client, redis_fallback

logger = logging.getLogger(__name__)


class CachedApiToken:
    """
    Simple data class to represent a cached API token.

    This is NOT a SQLAlchemy model instance, but a plain Python object
    that mimics the ApiToken model interface for read-only access.
    """

    def __init__(
        self,
        id: str,
        app_id: str | None,
        tenant_id: str | None,
        type: str,
        token: str,
        last_used_at: datetime | None,
        created_at: datetime | None,
    ):
        self.id = id
        self.app_id = app_id
        self.tenant_id = tenant_id
        self.type = type
        self.token = token
        self.last_used_at = last_used_at
        self.created_at = created_at

    def __repr__(self) -> str:
        return f"<CachedApiToken id={self.id} type={self.type}>"


# Cache configuration
CACHE_KEY_PREFIX = "api_token"
CACHE_TTL_SECONDS = 600  # 10 minutes
CACHE_NULL_TTL_SECONDS = 60  # 1 minute for non-existent tokens (prevents cache penetration)


class ApiTokenCache:
    """
    Redis cache wrapper for API tokens.
    Handles serialization, deserialization, and cache invalidation.
    """

    @staticmethod
    def _make_cache_key(token: str, scope: str | None = None) -> str:
        """
        Generate cache key for the given token and scope.

        Args:
            token: The API token string
            scope: The token type/scope (e.g., 'app', 'dataset')

        Returns:
            Cache key string
        """
        scope_str = scope or "any"
        return f"{CACHE_KEY_PREFIX}:{scope_str}:{token}"

    @staticmethod
    def _serialize_token(api_token: Any) -> str:
        """
        Serialize ApiToken object to JSON string.

        Args:
            api_token: ApiToken model instance

        Returns:
            JSON string representation
        """
        data = {
            "id": str(api_token.id),
            "app_id": str(api_token.app_id) if api_token.app_id else None,
            "tenant_id": str(api_token.tenant_id) if api_token.tenant_id else None,
            "type": api_token.type,
            "token": api_token.token,
            "last_used_at": api_token.last_used_at.isoformat() if api_token.last_used_at else None,
            "created_at": api_token.created_at.isoformat() if api_token.created_at else None,
        }
        return json.dumps(data)

    @staticmethod
    def _deserialize_token(cached_data: str) -> Any:
        """
        Deserialize JSON string back to a CachedApiToken object.

        Args:
            cached_data: JSON string from cache

        Returns:
            CachedApiToken instance or None
        """
        if cached_data == "null":
            # Cached null value (token doesn't exist)
            return None

        try:
            data = json.loads(cached_data)

            # Create a simple data object (NOT a SQLAlchemy model instance)
            # This is safe because it's just a plain Python object with attributes
            token_obj = CachedApiToken(
                id=data["id"],
                app_id=data["app_id"],
                tenant_id=data["tenant_id"],
                type=data["type"],
                token=data["token"],
                last_used_at=datetime.fromisoformat(data["last_used_at"]) if data["last_used_at"] else None,
                created_at=datetime.fromisoformat(data["created_at"]) if data["created_at"] else None,
            )

            return token_obj
        except (json.JSONDecodeError, KeyError, ValueError) as e:
            logger.warning("Failed to deserialize token from cache: %s", e)
            return None

    @staticmethod
    @redis_fallback(default_return=None)
    def get(token: str, scope: str | None) -> Any | None:
        """
        Get API token from cache.

        Args:
            token: The API token string
            scope: The token type/scope

        Returns:
            CachedApiToken instance if found in cache, None if not cached or cache miss
        """
        cache_key = ApiTokenCache._make_cache_key(token, scope)
        cached_data = redis_client.get(cache_key)

        if cached_data is None:
            logger.debug("Cache miss for token key: %s", cache_key)
            return None

        # Decode bytes to string
        if isinstance(cached_data, bytes):
            cached_data = cached_data.decode("utf-8")

        logger.debug("Cache hit for token key: %s", cache_key)
        return ApiTokenCache._deserialize_token(cached_data)

    @staticmethod
    def _add_to_tenant_index(tenant_id: str | None, cache_key: str) -> None:
        """
        Add cache key to tenant index for efficient invalidation.

        Maintains a Redis SET: tenant_tokens:{tenant_id} containing all cache keys
        for that tenant. This allows O(1) tenant-wide invalidation.

        Args:
            tenant_id: The tenant ID
            cache_key: The cache key to add to the index
        """
        if not tenant_id:
            return

        try:
            index_key = f"tenant_tokens:{tenant_id}"
            redis_client.sadd(index_key, cache_key)
            # Set TTL on the index itself (slightly longer than cache TTL)
            redis_client.expire(index_key, CACHE_TTL_SECONDS + 60)
        except Exception as e:
            # Don't fail if index update fails
            logger.warning("Failed to update tenant index: %s", e)

    @staticmethod
    def _remove_from_tenant_index(tenant_id: str | None, cache_key: str) -> None:
        """
        Remove cache key from tenant index.

        Args:
            tenant_id: The tenant ID
            cache_key: The cache key to remove from the index
        """
        if not tenant_id:
            return

        try:
            index_key = f"tenant_tokens:{tenant_id}"
            redis_client.srem(index_key, cache_key)
        except Exception as e:
            # Don't fail if index update fails
            logger.warning("Failed to remove from tenant index: %s", e)

    @staticmethod
    @redis_fallback(default_return=False)
    def set(token: str, scope: str | None, api_token: Any | None, ttl: int = CACHE_TTL_SECONDS) -> bool:
        """
        Set API token in cache.

        Args:
            token: The API token string
            scope: The token type/scope
            api_token: ApiToken instance to cache (None for non-existent tokens)
            ttl: Time to live in seconds

        Returns:
            True if successful, False otherwise
        """
        cache_key = ApiTokenCache._make_cache_key(token, scope)

        if api_token is None:
            # Cache null value to prevent cache penetration
            cached_value = "null"
            ttl = CACHE_NULL_TTL_SECONDS
        else:
            cached_value = ApiTokenCache._serialize_token(api_token)

        try:
            redis_client.setex(cache_key, ttl, cached_value)

            # Add to tenant index for efficient tenant-wide invalidation
            if api_token is not None and hasattr(api_token, "tenant_id"):
                ApiTokenCache._add_to_tenant_index(api_token.tenant_id, cache_key)

            logger.debug("Cached token with key: %s, ttl: %ss", cache_key, ttl)
            return True
        except Exception as e:
            logger.warning("Failed to cache token: %s", e)
            return False

    @staticmethod
    @redis_fallback(default_return=False)
    def delete(token: str, scope: str | None = None) -> bool:
        """
        Delete API token from cache.

        Args:
            token: The API token string
            scope: The token type/scope (None to delete all scopes)

        Returns:
            True if successful, False otherwise
        """
        if scope is None:
            # Delete all possible scopes for this token
            # This is a safer approach when scope is unknown
            pattern = f"{CACHE_KEY_PREFIX}:*:{token}"
            try:
                keys_to_delete = list(redis_client.scan_iter(match=pattern))
                if keys_to_delete:
                    redis_client.delete(*keys_to_delete)
                    logger.info("Deleted %d cache entries for token", len(keys_to_delete))
                return True
            except Exception as e:
                logger.warning("Failed to delete token cache with pattern: %s", e)
                return False
        else:
            cache_key = ApiTokenCache._make_cache_key(token, scope)
            try:
                # Try to get tenant_id before deleting (for index cleanup)
                tenant_id = None
                try:
                    cached_data = redis_client.get(cache_key)
                    if cached_data and cached_data != b"null":
                        if isinstance(cached_data, bytes):
                            cached_data = cached_data.decode("utf-8")
                        data = json.loads(cached_data)
                        tenant_id = data.get("tenant_id")
                except Exception as e:
                    # If we can't get tenant_id, just delete the key without index cleanup
                    logger.debug("Failed to get tenant_id for cache cleanup: %s", e)

                # Delete the cache key
                redis_client.delete(cache_key)

                # Remove from tenant index
                if tenant_id:
                    ApiTokenCache._remove_from_tenant_index(tenant_id, cache_key)

                logger.info("Deleted cache for key: %s", cache_key)
                return True
            except Exception as e:
                logger.warning("Failed to delete token cache: %s", e)
                return False

    @staticmethod
    @redis_fallback(default_return=False)
    def invalidate_by_tenant(tenant_id: str) -> bool:
        """
        Invalidate all API token caches for a specific tenant.
        Use this when tenant status changes or tokens are batch updated.

        Uses a two-tier approach:
        1. Try to use tenant index (fast, O(n) where n = tenant's tokens)
        2. Fallback to full scan if index doesn't exist (slow, O(N) where N = all tokens)

        Args:
            tenant_id: The tenant ID

        Returns:
            True if successful, False otherwise
        """
        try:
            # Try using tenant index first (efficient approach)
            index_key = f"tenant_tokens:{tenant_id}"
            cache_keys = redis_client.smembers(index_key)

            if cache_keys:
                # Index exists - use it (fast path)
                deleted_count = 0
                for cache_key in cache_keys:
                    if isinstance(cache_key, bytes):
                        cache_key = cache_key.decode("utf-8")
                    redis_client.delete(cache_key)
                    deleted_count += 1

                # Delete the index itself
                redis_client.delete(index_key)

                logger.info(
                    "Invalidated %d token cache entries for tenant: %s (via index)",
                    deleted_count,
                    tenant_id,
                )
                return True

            # Index doesn't exist - fallback to scanning (slow path)
            logger.info("Tenant index not found, falling back to full scan for tenant: %s", tenant_id)

            pattern = f"{CACHE_KEY_PREFIX}:*"
            cursor = 0
            deleted_count = 0
            checked_count = 0

            while True:
                cursor, keys = redis_client.scan(cursor, match=pattern, count=100)
                if keys:
                    for key in keys:
                        checked_count += 1
                        try:
                            # Fetch and check if this token belongs to the tenant
                            cached_data = redis_client.get(key)
                            if cached_data:
                                # Decode if bytes
                                if isinstance(cached_data, bytes):
                                    cached_data = cached_data.decode("utf-8")

                                # Skip null values
                                if cached_data == "null":
                                    continue

                                # Deserialize and check tenant_id
                                data = json.loads(cached_data)
                                if data.get("tenant_id") == tenant_id:
                                    redis_client.delete(key)
                                    deleted_count += 1
                        except (json.JSONDecodeError, Exception) as e:
                            logger.warning("Failed to check cache key %s: %s", key, e)
                            continue

                if cursor == 0:
                    break

            logger.info(
                "Invalidated %d token cache entries for tenant: %s (checked %d keys via scan)",
                deleted_count,
                tenant_id,
                checked_count,
            )
            return True
        except Exception as e:
            logger.warning("Failed to invalidate tenant token cache: %s", e)
            return False
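The deleted module above pays a small write-time cost (a SADD plus EXPIRE per cached token) to maintain a per-tenant index so that tenant-wide invalidation can avoid scanning the whole keyspace. Roughly, using the module's own API:

    # Fast path: one SMEMBERS on tenant_tokens:{tenant_id}, then targeted DELETEs.
    ApiTokenCache.invalidate_by_tenant("tenant-789")
    # Without the index, the fallback must SCAN every api_token:* key and
    # deserialize each value just to compare tenant_id.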
@@ -14,7 +14,6 @@ from sqlalchemy.orm import sessionmaker
 from configs import dify_config
 from core.db.session_factory import session_factory
 from extensions.ext_database import db
-from libs.api_token_cache import ApiTokenCache
 from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
 from models import (
     ApiToken,
@@ -135,12 +134,6 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
 
 def _delete_app_api_tokens(tenant_id: str, app_id: str):
     def del_api_token(session, api_token_id: str):
-        # Fetch token details for cache invalidation
-        token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
-        if token_obj:
-            # Invalidate cache before deletion
-            ApiTokenCache.delete(token_obj.token, token_obj.type)
-
         session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
 
     _delete_records(
@@ -1,63 +0,0 @@
"""
Celery task for updating API token last_used_at timestamp asynchronously.
"""

import logging
from datetime import datetime

from celery import shared_task
from sqlalchemy import update
from sqlalchemy.orm import Session

from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import ApiToken

logger = logging.getLogger(__name__)


@shared_task(queue="dataset", bind=True)
def update_api_token_last_used_task(self, token: str, scope: str | None, start_time_iso: str):
    """
    Asynchronously update the last_used_at timestamp for an API token.

    Uses timestamp comparison to ensure only updates when last_used_at is older
    than the request start time, providing natural concurrency control.

    Args:
        token: The API token string
        scope: The token type/scope (e.g., 'app', 'dataset')
        start_time_iso: ISO format timestamp of when the request started
    """
    try:
        # Parse start_time from ISO format
        start_time = datetime.fromisoformat(start_time_iso)
        # Update database
        current_time = naive_utc_now()

        with Session(db.engine, expire_on_commit=False) as session:
            update_stmt = (
                update(ApiToken)
                .where(
                    ApiToken.token == token,
                    ApiToken.type == scope,
                    (ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < start_time)),
                )
                .values(last_used_at=current_time)
            )
            result = session.execute(update_stmt)

            # Check if any rows were updated
            rowcount = getattr(result, "rowcount", 0)
            if rowcount > 0:
                session.commit()
                logger.info("Updated last_used_at for token (async): %s... (scope: %s)", token[:10], scope)
                return {"status": "updated", "rowcount": rowcount, "start_time": start_time_iso}
            else:
                logger.debug("No update needed for token: %s... (already up-to-date)", token[:10])
                return {"status": "no_update_needed", "reason": "last_used_at >= start_time"}

    except Exception as e:
        logger.warning("Failed to update last_used_at for token (async): %s", e)
        # Don't retry on failure to avoid blocking the queue
        return {"status": "failed", "error": str(e)}
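The guard in the deleted task's WHERE clause is what made the fire-and-forget updates safe: tasks can arrive in any order, but one whose request started earlier than the stored last_used_at matches zero rows and degrades to a no-op. The clause in isolation (SQLAlchemy, names as in the deleted task):

    # Only rows whose last_used_at is NULL or older than this request's start
    # time match; a late-arriving task for an older request updates nothing.
    stmt = (
        update(ApiToken)
        .where(
            ApiToken.token == token,
            ApiToken.type == scope,
            ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < start_time),
        )
        .values(last_used_at=naive_utc_now())
    )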
@@ -1,364 +0,0 @@
"""Endpoint tests for controllers.console.workspace.tool_providers."""

from __future__ import annotations

import builtins
import importlib
from contextlib import contextmanager
from types import ModuleType, SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
from flask import Flask
from flask.views import MethodView

if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]


_CONTROLLER_MODULE: ModuleType | None = None
_WRAPS_MODULE: ModuleType | None = None
_CONTROLLER_PATCHERS: list[patch] = []


@contextmanager
def _mock_db():
    mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True))
    with patch("extensions.ext_database.db.session", mock_session):
        yield


@pytest.fixture
def app() -> Flask:
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    return flask_app


@pytest.fixture
def controller_module(monkeypatch: pytest.MonkeyPatch):
    module_name = "controllers.console.workspace.tool_providers"
    global _CONTROLLER_MODULE
    if _CONTROLLER_MODULE is None:

        def _noop(func):
            return func

        patch_targets = [
            ("libs.login.login_required", _noop),
            ("controllers.console.wraps.setup_required", _noop),
            ("controllers.console.wraps.account_initialization_required", _noop),
            ("controllers.console.wraps.is_admin_or_owner_required", _noop),
            ("controllers.console.wraps.enterprise_license_required", _noop),
        ]
        for target, value in patch_targets:
            patcher = patch(target, value)
            patcher.start()
            _CONTROLLER_PATCHERS.append(patcher)
        monkeypatch.setenv("DIFY_SETUP_READY", "true")
        with _mock_db():
            _CONTROLLER_MODULE = importlib.import_module(module_name)

    module = _CONTROLLER_MODULE
    monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload)

    # Ensure decorators that consult deployment edition do not reach the database.
    global _WRAPS_MODULE
    wraps_module = importlib.import_module("controllers.console.wraps")
    _WRAPS_MODULE = wraps_module
    monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD")
    monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")

    login_module = importlib.import_module("libs.login")
    monkeypatch.setattr(login_module, "check_csrf_token", lambda *args, **kwargs: None)
    return module


def _mock_account(user_id: str = "user-123") -> SimpleNamespace:
    return SimpleNamespace(id=user_id, status="active", is_authenticated=True, current_tenant_id=None)


def _set_current_account(
    monkeypatch: pytest.MonkeyPatch,
    controller_module: ModuleType,
    user: SimpleNamespace,
    tenant_id: str,
) -> None:
    def _getter():
        return user, tenant_id

    user.current_tenant_id = tenant_id

    monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter)
    if _WRAPS_MODULE is not None:
        monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter)

    login_module = importlib.import_module("libs.login")
    monkeypatch.setattr(login_module, "_get_user", lambda: user)


def test_tool_provider_list_calls_service_with_query(
    app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
):
    user = _mock_account()
    _set_current_account(monkeypatch, controller_module, user, "tenant-456")

    service_mock = MagicMock(return_value=[{"provider": "builtin"}])
    monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock)

    with app.test_request_context("/workspaces/current/tool-providers?type=builtin"):
        response = controller_module.ToolProviderListApi().get()

    assert response == [{"provider": "builtin"}]
    service_mock.assert_called_once_with(user.id, "tenant-456", "builtin")


def test_builtin_provider_add_passes_payload(
    app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
):
    user = _mock_account()
    _set_current_account(monkeypatch, controller_module, user, "tenant-456")

    service_mock = MagicMock(return_value={"status": "ok"})
    monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock)

    payload = {
        "credentials": {"api_key": "sk-test"},
        "name": "MyTool",
        "type": controller_module.CredentialType.API_KEY,
    }

    with app.test_request_context(
        "/workspaces/current/tool-provider/builtin/openai/add",
        method="POST",
        json=payload,
    ):
        response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai")

    assert response == {"status": "ok"}
    service_mock.assert_called_once_with(
        user_id="user-123",
        tenant_id="tenant-456",
        provider="openai",
        credentials={"api_key": "sk-test"},
        name="MyTool",
        api_type=controller_module.CredentialType.API_KEY,
    )


def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
    user = _mock_account("user-tenant-789")
    _set_current_account(monkeypatch, controller_module, user, "tenant-789")

    service_mock = MagicMock(return_value=[{"name": "tool-a"}])
    monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock)
    monkeypatch.setattr(controller_module, "jsonable_encoder", lambda payload: payload)

    with app.test_request_context(
        "/workspaces/current/tool-provider/builtin/my-provider/tools",
        method="GET",
    ):
        response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider")

    assert response == [{"name": "tool-a"}]
    service_mock.assert_called_once_with("tenant-789", "my-provider")


def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
    user = _mock_account("user-tenant-9")
    _set_current_account(monkeypatch, controller_module, user, "tenant-9")
    service_mock = MagicMock(return_value={"info": True})
    monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock)

    with app.test_request_context("/info", method="GET"):
        resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo")

    assert resp == {"info": True}
    service_mock.assert_called_once_with("tenant-9", "demo")


def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
    user = _mock_account("user-tenant-cred")
    _set_current_account(monkeypatch, controller_module, user, "tenant-cred")
    service_mock = MagicMock(return_value=[{"cred": 1}])
    monkeypatch.setattr(
        controller_module.BuiltinToolManageService,
        "get_builtin_tool_provider_credentials",
        service_mock,
    )

    with app.test_request_context("/creds", method="GET"):
        resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo")

    assert resp == [{"cred": 1}]
    service_mock.assert_called_once_with(tenant_id="tenant-cred", provider_name="demo")


def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
    user = _mock_account()
    _set_current_account(monkeypatch, controller_module, user, "tenant-10")
    service_mock = MagicMock(return_value={"schema": "ok"})
    monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock)

    with app.test_request_context("/remote?url=https://example.com/"):
        resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get()

    assert resp == {"schema": "ok"}
    service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/")


def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
    user = _mock_account()
    _set_current_account(monkeypatch, controller_module, user, "tenant-11")
    service_mock = MagicMock(return_value=[{"tool": "t"}])
    monkeypatch.setattr(controller_module.ApiToolManageService, "list_api_tool_provider_tools", service_mock)

    with app.test_request_context("/tools?provider=foo"):
        resp = controller_module.ToolApiProviderListToolsApi().get()

    assert resp == [{"tool": "t"}]
    service_mock.assert_called_once_with(user.id, "tenant-11", "foo")


def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
    user = _mock_account()
    _set_current_account(monkeypatch, controller_module, user, "tenant-12")
    service_mock = MagicMock(return_value={"provider": "foo"})
    monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock)

    with app.test_request_context("/get?provider=foo"):
        resp = controller_module.ToolApiProviderGetApi().get()

    assert resp == {"provider": "foo"}
    service_mock.assert_called_once_with(user.id, "tenant-12", "foo")


def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
    user = _mock_account("user-tenant-13")
    _set_current_account(monkeypatch, controller_module, user, "tenant-13")
    service_mock = MagicMock(return_value={"schema": True})
    monkeypatch.setattr(
        controller_module.BuiltinToolManageService,
        "list_builtin_provider_credentials_schema",
        service_mock,
    )

    with app.test_request_context("/schema", method="GET"):
        resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get(
            provider="demo", credential_type="api-key"
        )

    assert resp == {"schema": True}
    service_mock.assert_called_once()


def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
    user = _mock_account()
    _set_current_account(monkeypatch, controller_module, user, "tenant-wf")
    tool_service = MagicMock(return_value={"wf": 1})
    monkeypatch.setattr(
        controller_module.WorkflowToolManageService,
        "get_workflow_tool_by_tool_id",
        tool_service,
    )

    tool_id = "00000000-0000-0000-0000-000000000001"
    with app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"):
        resp = controller_module.ToolWorkflowProviderGetApi().get()

    assert resp == {"wf": 1}
    tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id)


def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
    user = _mock_account()
    _set_current_account(monkeypatch, controller_module, user, "tenant-wf2")
    service_mock = MagicMock(return_value={"app": 1})
    monkeypatch.setattr(
        controller_module.WorkflowToolManageService,
        "get_workflow_tool_by_app_id",
        service_mock,
    )

    app_id = "00000000-0000-0000-0000-000000000002"
    with app.test_request_context(f"/workflow?workflow_app_id={app_id}"):
        resp = controller_module.ToolWorkflowProviderGetApi().get()

    assert resp == {"app": 1}
    service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id)


def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
    user = _mock_account()
    _set_current_account(monkeypatch, controller_module, user, "tenant-wf3")
    service_mock = MagicMock(return_value=[{"id": 1}])
    monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock)

    tool_id = "00000000-0000-0000-0000-000000000003"
    with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"):
        resp = controller_module.ToolWorkflowProviderListToolApi().get()

    assert resp == [{"id": 1}]
    service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id)


def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
    user = _mock_account()
    _set_current_account(monkeypatch, controller_module, user, "tenant-bt")

    provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"})
    monkeypatch.setattr(
        controller_module.BuiltinToolManageService,
        "list_builtin_tools",
        MagicMock(return_value=[provider]),
    )

    with app.test_request_context("/tools/builtin"):
        resp = controller_module.ToolBuiltinListApi().get()

    assert resp == [{"name": "builtin"}]


def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
    user = _mock_account("user-tenant-api")
    _set_current_account(monkeypatch, controller_module, user, "tenant-api")

    provider = SimpleNamespace(to_dict=lambda: {"name": "api"})
    monkeypatch.setattr(
        controller_module.ApiToolManageService,
        "list_api_tools",
        MagicMock(return_value=[provider]),
    )

    with app.test_request_context("/tools/api"):
        resp = controller_module.ToolApiListApi().get()

    assert resp == [{"name": "api"}]


def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
    user = _mock_account()
    _set_current_account(monkeypatch, controller_module, user, "tenant-wf4")

    provider = SimpleNamespace(to_dict=lambda: {"name": "wf"})
    monkeypatch.setattr(
        controller_module.WorkflowToolManageService,
        "list_tenant_workflow_tools",
        MagicMock(return_value=[provider]),
    )

    with app.test_request_context("/tools/workflow"):
        resp = controller_module.ToolWorkflowListApi().get()

    assert resp == [{"name": "wf"}]


def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
    user = _mock_account("user-label")
    _set_current_account(monkeypatch, controller_module, user, "tenant-labels")
    monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"])

    with app.test_request_context("/tool-labels"):
        resp = controller_module.ToolLabelsApi().get()

    assert resp == ["a", "b"]
@@ -1,256 +0,0 @@
"""
Unit tests for API Token Cache module.
"""

import json
from datetime import datetime
from unittest.mock import MagicMock, patch

import pytest

from libs.api_token_cache import (
    CACHE_KEY_PREFIX,
    CACHE_NULL_TTL_SECONDS,
    CACHE_TTL_SECONDS,
    ApiTokenCache,
    CachedApiToken,
)


class TestApiTokenCache:
    """Test cases for ApiTokenCache class."""

    def setup_method(self):
        """Setup test fixtures."""
        self.mock_token = MagicMock()
        self.mock_token.id = "test-token-id-123"
        self.mock_token.app_id = "test-app-id-456"
        self.mock_token.tenant_id = "test-tenant-id-789"
        self.mock_token.type = "app"
        self.mock_token.token = "test-token-value-abc"
        self.mock_token.last_used_at = datetime(2026, 2, 3, 10, 0, 0)
        self.mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)

    def test_make_cache_key(self):
        """Test cache key generation."""
        # Test with scope
        key = ApiTokenCache._make_cache_key("my-token", "app")
        assert key == f"{CACHE_KEY_PREFIX}:app:my-token"

        # Test without scope
        key = ApiTokenCache._make_cache_key("my-token", None)
        assert key == f"{CACHE_KEY_PREFIX}:any:my-token"

    def test_serialize_token(self):
        """Test token serialization."""
        serialized = ApiTokenCache._serialize_token(self.mock_token)
        data = json.loads(serialized)

        assert data["id"] == "test-token-id-123"
        assert data["app_id"] == "test-app-id-456"
        assert data["tenant_id"] == "test-tenant-id-789"
        assert data["type"] == "app"
        assert data["token"] == "test-token-value-abc"
        assert data["last_used_at"] == "2026-02-03T10:00:00"
        assert data["created_at"] == "2026-01-01T00:00:00"

    def test_serialize_token_with_nulls(self):
        """Test token serialization with None values."""
        mock_token = MagicMock()
        mock_token.id = "test-id"
        mock_token.app_id = None
        mock_token.tenant_id = None
        mock_token.type = "dataset"
        mock_token.token = "test-token"
        mock_token.last_used_at = None
        mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)

        serialized = ApiTokenCache._serialize_token(mock_token)
        data = json.loads(serialized)

        assert data["app_id"] is None
        assert data["tenant_id"] is None
        assert data["last_used_at"] is None

    def test_deserialize_token(self):
        """Test token deserialization."""
        cached_data = json.dumps(
            {
                "id": "test-id",
                "app_id": "test-app",
                "tenant_id": "test-tenant",
                "type": "app",
                "token": "test-token",
                "last_used_at": "2026-02-03T10:00:00",
                "created_at": "2026-01-01T00:00:00",
            }
        )

        result = ApiTokenCache._deserialize_token(cached_data)

        assert isinstance(result, CachedApiToken)
        assert result.id == "test-id"
        assert result.app_id == "test-app"
        assert result.tenant_id == "test-tenant"
        assert result.type == "app"
        assert result.token == "test-token"
        assert result.last_used_at == datetime(2026, 2, 3, 10, 0, 0)
        assert result.created_at == datetime(2026, 1, 1, 0, 0, 0)

    def test_deserialize_null_token(self):
        """Test deserialization of null token (cached miss)."""
        result = ApiTokenCache._deserialize_token("null")
        assert result is None

    def test_deserialize_invalid_json(self):
        """Test deserialization with invalid JSON."""
        result = ApiTokenCache._deserialize_token("invalid-json{")
        assert result is None

    @patch("libs.api_token_cache.redis_client")
    def test_get_cache_hit(self, mock_redis):
        """Test cache hit scenario."""
        cached_data = json.dumps(
            {
                "id": "test-id",
                "app_id": "test-app",
                "tenant_id": "test-tenant",
                "type": "app",
                "token": "test-token",
                "last_used_at": "2026-02-03T10:00:00",
                "created_at": "2026-01-01T00:00:00",
            }
        )
        mock_redis.get.return_value = cached_data.encode("utf-8")

        result = ApiTokenCache.get("test-token", "app")

        assert result is not None
        assert isinstance(result, CachedApiToken)
        assert result.app_id == "test-app"
        mock_redis.get.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token")

    @patch("libs.api_token_cache.redis_client")
    def test_get_cache_miss(self, mock_redis):
        """Test cache miss scenario."""
        mock_redis.get.return_value = None

        result = ApiTokenCache.get("test-token", "app")

        assert result is None
        mock_redis.get.assert_called_once()

    @patch("libs.api_token_cache.redis_client")
    def test_set_valid_token(self, mock_redis):
        """Test setting a valid token in cache."""
        result = ApiTokenCache.set("test-token", "app", self.mock_token)

        assert result is True
        mock_redis.setex.assert_called_once()
        args = mock_redis.setex.call_args[0]
        assert args[0] == f"{CACHE_KEY_PREFIX}:app:test-token"
        assert args[1] == CACHE_TTL_SECONDS

    @patch("libs.api_token_cache.redis_client")
    def test_set_null_token(self, mock_redis):
        """Test setting a null token (cache penetration prevention)."""
        result = ApiTokenCache.set("invalid-token", "app", None)

        assert result is True
        mock_redis.setex.assert_called_once()
        args = mock_redis.setex.call_args[0]
        assert args[0] == f"{CACHE_KEY_PREFIX}:app:invalid-token"
        assert args[1] == CACHE_NULL_TTL_SECONDS
        assert args[2] == "null"

    @patch("libs.api_token_cache.redis_client")
    def test_delete_with_scope(self, mock_redis):
        """Test deleting token cache with specific scope."""
        result = ApiTokenCache.delete("test-token", "app")

        assert result is True
        mock_redis.delete.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token")

    @patch("libs.api_token_cache.redis_client")
    def test_delete_without_scope(self, mock_redis):
        """Test deleting token cache without scope (delete all)."""
        # Mock scan_iter to return an iterator of keys
        mock_redis.scan_iter.return_value = iter(
            [
                b"api_token:app:test-token",
                b"api_token:dataset:test-token",
            ]
        )

        result = ApiTokenCache.delete("test-token", None)

        assert result is True
        # Verify scan_iter was called with the correct pattern
        mock_redis.scan_iter.assert_called_once()
        call_args = mock_redis.scan_iter.call_args
        assert call_args[1]["match"] == f"{CACHE_KEY_PREFIX}:*:test-token"

        # Verify delete was called with all matched keys
        mock_redis.delete.assert_called_once_with(
            b"api_token:app:test-token",
            b"api_token:dataset:test-token",
        )

    @patch("libs.api_token_cache.redis_client")
    def test_redis_fallback_on_exception(self, mock_redis):
        """Test Redis fallback when Redis is unavailable."""
        from redis import RedisError

        mock_redis.get.side_effect = RedisError("Connection failed")

        result = ApiTokenCache.get("test-token", "app")

        # Should return None (fallback) instead of raising exception
        assert result is None


class TestApiTokenCacheIntegration:
    """Integration test scenarios."""

    @patch("libs.api_token_cache.redis_client")
    def test_full_cache_lifecycle(self, mock_redis):
        """Test complete cache lifecycle: set -> get -> delete."""
        # Setup mock token
        mock_token = MagicMock()
        mock_token.id = "id-123"
        mock_token.app_id = "app-456"
        mock_token.tenant_id = "tenant-789"
        mock_token.type = "app"
        mock_token.token = "token-abc"
        mock_token.last_used_at = datetime(2026, 2, 3, 10, 0, 0)
        mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)

        # 1. Set token in cache
        ApiTokenCache.set("token-abc", "app", mock_token)
        assert mock_redis.setex.called

        # 2. Simulate cache hit
        cached_data = ApiTokenCache._serialize_token(mock_token)
        mock_redis.get.return_value = cached_data.encode("utf-8")

        retrieved = ApiTokenCache.get("token-abc", "app")
        assert retrieved is not None
        assert isinstance(retrieved, CachedApiToken)

        # 3. Delete from cache
        ApiTokenCache.delete("token-abc", "app")
        assert mock_redis.delete.called

    @patch("libs.api_token_cache.redis_client")
    def test_cache_penetration_prevention(self, mock_redis):
        """Test that non-existent tokens are cached as null."""
        # Set null token (cache miss)
        ApiTokenCache.set("non-existent-token", "app", None)

        args = mock_redis.setex.call_args[0]
        assert args[2] == "null"
        assert args[1] == CACHE_NULL_TTL_SECONDS  # Shorter TTL for null values


if __name__ == "__main__":
    pytest.main([__file__, "-v"])