Compare commits


56 Commits

Author SHA1 Message Date
autofix-ci[bot]
e7556436e3 [autofix.ci] apply automated fixes 2026-02-06 07:49:23 +00:00
Yansong Zhang
9e939a414f Merge branch 'fix/api-token-lock' of github.com:langgenius/dify into fix/api-token-lock 2026-02-06 15:46:39 +08:00
Yansong Zhang
8ad46fa110 fix: Avoid the risk of long transactions and move db operations to the appropriate service layer 2026-02-06 15:45:50 +08:00
Yansong Zhang
f7c7feb433 fix: Avoid the risk of long transactions and move db operations to the appropriate service layer 2026-02-06 15:45:46 +08:00
hj24
92ff9648b9 Merge branch 'main' into fix/api-token-lock 2026-02-06 14:15:37 +08:00
QuantumGhost
e988266f53 chore: update HITL auto deploy workflow (#32040) 2026-02-06 14:15:32 +08:00
Yansong Zhang
af60f03225 fix: Avoid the risk of long transactions and move db operations to the appropriate service layer 2026-02-06 13:50:17 +08:00
Yansong Zhang
996b09371f fix: Avoid the risk of long transactions and move db operations to the appropriate service layer 2026-02-06 13:50:12 +08:00
zyssyz123
91bacc57d7 Update api/schedule/update_api_token_last_used_task.py
Co-authored-by: hj24 <mambahj24@gmail.com>
2026-02-06 13:39:37 +08:00
autofix-ci[bot]
11d0f907f0 [autofix.ci] apply automated fixes 2026-02-06 04:46:26 +00:00
Yansong Zhang
cbf2181b6c fix: Standardized adjustment 2026-02-06 12:44:26 +08:00
Yansong Zhang
1e476ad173 fix: Standardized adjustment 2026-02-06 12:44:22 +08:00
longbingljw
d9530f7bb7 fix: make flask upgrade-db fail on error (#32024) 2026-02-06 12:01:31 +08:00
Yansong Zhang
1ce330eeed fix test 2026-02-06 11:42:07 +08:00
Yansong Zhang
ee8ff23482 fix style check 2026-02-06 11:26:41 +08:00
wangxiaolei
b24e6edada fix: fix agent node tool type is not right (#32008)
Infer real tool type via querying relevant database tables.

The root cause for incorrect `type` field is still not clear.
2026-02-06 11:24:39 +08:00
Yansong Zhang
fa6f6730b5 add queue api_token_update 2026-02-06 11:08:12 +08:00
Yansong Zhang
b0876e0ec8 add queue api_token_update 2026-02-06 11:08:09 +08:00
Yansong Zhang
4763375c61 fix style check 2026-02-06 10:53:44 +08:00
autofix-ci[bot]
63248ed088 [autofix.ci] apply automated fixes 2026-02-06 02:42:42 +00:00
Yansong Zhang
18d66b3262 Modify to synchronize redis data to db regularly. 2026-02-06 10:40:44 +08:00
Yansong Zhang
ce3fdb604d Modify to synchronize redis data to db regularly. 2026-02-06 10:40:39 +08:00
autofix-ci[bot]
57f76c4072 [autofix.ci] apply automated fixes 2026-02-06 02:11:04 +00:00
Yansong Zhang
cb2b3e07ba fix: Standardized adjustment 2026-02-06 10:08:46 +08:00
Yansong Zhang
8cbd1af0d1 Merge remote-tracking branch 'origin/main' into fix/api-token-lock 2026-02-04 16:09:10 +08:00
Yansong Zhang
402b0e2cd6 Merge remote-tracking branch 'origin/main' into fix/api-token-lock 2026-02-04 15:55:38 +08:00
autofix-ci[bot]
07f2a40802 [autofix.ci] apply automated fixes 2026-02-04 07:50:51 +00:00
Yansong Zhang
93b535be95 make it great agin 2026-02-04 15:46:30 +08:00
Yansong Zhang
13be706202 make it great agin 2026-02-04 15:46:24 +08:00
autofix-ci[bot]
0685e294c4 [autofix.ci] apply automated fixes 2026-02-04 05:34:58 +00:00
Yansong Zhang
f8cc056604 fix start_time -> update_time 2026-02-04 13:30:45 +08:00
autofix-ci[bot]
ef1e233c2d [autofix.ci] apply automated fixes 2026-02-04 05:27:16 +00:00
Yansong Zhang
6d87424ab8 fix start_time -> update_time 2026-02-04 13:18:47 +08:00
Yansong Zhang
aaa98c9550 Merge branch 'fix/api-token-lock' of github.com:langgenius/dify into fix/api-token-lock 2026-02-04 12:20:51 +08:00
Yansong Zhang
4719f2569c fix start_time -> update_time 2026-02-04 12:20:42 +08:00
autofix-ci[bot]
282ec583db [autofix.ci] apply automated fixes 2026-02-04 04:06:50 +00:00
Yansong Zhang
c7337d5b67 make it great agin 2026-02-04 12:02:43 +08:00
Yansong Zhang
e1efea16a4 make it great agin 2026-02-04 12:00:43 +08:00
autofix-ci[bot]
dcba86b707 [autofix.ci] apply automated fixes 2026-02-04 03:22:17 +00:00
Yansong Zhang
d02ed82854 Merge branch 'fix/api-token-lock' of github.com:langgenius/dify into fix/api-token-lock 2026-02-04 11:18:03 +08:00
Yansong Zhang
60e3a7b419 make it great agin 2026-02-04 11:17:37 +08:00
Yansong Zhang
240684e723 make it great agin 2026-02-04 11:17:30 +08:00
autofix-ci[bot]
edfd34bc90 [autofix.ci] apply automated fixes 2026-02-03 08:32:34 +00:00
Yansong Zhang
292a9ff487 fix linter 2026-02-03 16:28:28 +08:00
Yansong Zhang
132684898b fix linter 2026-02-03 16:26:52 +08:00
autofix-ci[bot]
138117526a [autofix.ci] apply automated fixes 2026-02-03 07:53:38 +00:00
Yansong Zhang
396834c808 fix linter 2026-02-03 15:48:55 +08:00
autofix-ci[bot]
657b3f5990 [autofix.ci] apply automated fixes 2026-02-03 07:29:15 +00:00
Yansong Zhang
ea5089aba7 fix linter 2026-02-03 15:25:15 +08:00
Yansong Zhang
d69e4de47b fix linter 2026-02-03 15:22:01 +08:00
zyssyz123
79ead90487 Update api/tasks/update_api_token_last_used_task.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-03 15:19:59 +08:00
zyssyz123
c7de79dcbf Update api/libs/api_token_cache.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-02-03 15:16:29 +08:00
Yansong Zhang
da9abcf885 Merge branch 'fix/api-token-lock' of github.com:langgenius/dify into fix/api-token-lock 2026-02-03 15:15:01 +08:00
Yansong Zhang
d54b08701e fix linter 2026-02-03 15:14:17 +08:00
autofix-ci[bot]
c2fdfdc504 [autofix.ci] apply automated fixes 2026-02-03 07:08:47 +00:00
Yansong Zhang
d58d3f5bde add redis for api token 2026-02-03 15:03:11 +08:00
19 changed files with 1377 additions and 40 deletions

View File

@@ -4,8 +4,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "feat/hitl-frontend"
- "feat/hitl-backend"
- "feat/hitl"
types:
- completed
@@ -14,10 +13,7 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
-      (
-        github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
-        github.event.workflow_run.head_branch == 'feat/hitl-backend'
-      )
+      github.event.workflow_run.head_branch == 'feat/hitl'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v1

View File

@@ -102,6 +102,8 @@ forbidden_modules =
core.trigger
core.variables
ignore_imports =
core.workflow.nodes.agent.agent_node -> core.db.session_factory
core.workflow.nodes.agent.agent_node -> models.tools
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.workflow_entry -> core.app.workflow.layers.observability

View File

@@ -122,7 +122,7 @@ These commands assume you start from the repository root.
```bash
cd api
-uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
+uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).

View File

@@ -739,8 +739,10 @@ def upgrade_db():
click.echo(click.style("Database migration successful!", fg="green"))
-        except Exception:
+        except Exception as e:
logger.exception("Failed to execute database migration")
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
raise SystemExit(1)
finally:
lock.release()
else:

View File

@@ -1155,6 +1155,16 @@ class CeleryScheduleTasksConfig(BaseSettings):
default=0,
)
# API token last_used_at batch update
ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK: bool = Field(
description="Enable periodic batch update of API token last_used_at timestamps",
default=True,
)
API_TOKEN_LAST_USED_UPDATE_INTERVAL: int = Field(
description="Interval in minutes for batch updating API token last_used_at (default 30)",
default=30,
)
# Trigger provider refresh (simple version)
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
description="Enable trigger provider refresh poller",

View File

@@ -10,6 +10,7 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache
from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
@@ -131,6 +132,11 @@ class BaseApiKeyResource(Resource):
if key is None:
flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")
# Invalidate cache before deleting from database
# Type assertion: key is guaranteed to be non-None here because abort() raises
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

View File

@@ -55,6 +55,7 @@ from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
# Register models for flask_restx to avoid dict type issues in Swagger
@@ -820,6 +821,11 @@ class DatasetApiDeleteApi(Resource):
if key is None:
console_ns.abort(404, message="API key not found")
# Invalidate cache before deleting from database
# Type assertion: key is guaranteed to be non-None here because abort() raises
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

View File

@@ -1,27 +1,24 @@
import logging
import time
from collections.abc import Callable
from datetime import timedelta
from enum import StrEnum, auto
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from typing import Concatenate, ParamSpec, TypeVar, cast
from flask import current_app, request
from flask_login import user_logged_in
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache, fetch_token_with_single_flight, record_token_usage
from services.end_user_service import EndUserService
from services.feature_service import FeatureService
@@ -296,7 +293,14 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
def validate_and_get_api_token(scope: str | None = None):
"""
-    Validate and get API token.
+    Validate and get API token with Redis caching.
+    This function uses a two-tier approach:
+    1. First checks Redis cache for the token
+    2. If not cached, queries database and caches the result
+    The last_used_at field is updated asynchronously via Celery task
+    to avoid blocking the request.
"""
auth_header = request.headers.get("Authorization")
if auth_header is None or " " not in auth_header:
@@ -308,29 +312,18 @@ def validate_and_get_api_token(scope: str | None = None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")
-    current_time = naive_utc_now()
-    cutoff_time = current_time - timedelta(minutes=1)
-    with Session(db.engine, expire_on_commit=False) as session:
-        update_stmt = (
-            update(ApiToken)
-            .where(
-                ApiToken.token == auth_token,
-                (ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
-                ApiToken.type == scope,
-            )
-            .values(last_used_at=current_time)
-        )
-        stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
-        result = session.execute(update_stmt)
-        api_token = session.scalar(stmt)
+    # Try to get token from cache first
+    # Returns a CachedApiToken (plain Python object), not a SQLAlchemy model
+    cached_token = ApiTokenCache.get(auth_token, scope)
+    if cached_token is not None:
+        logger.debug("Token validation served from cache for scope: %s", scope)
+        # Record usage in Redis for later batch update (no Celery task per request)
+        record_token_usage(auth_token, scope)
+        return cast(ApiToken, cached_token)
-        if hasattr(result, "rowcount") and result.rowcount > 0:
-            session.commit()
-        if not api_token:
-            raise Unauthorized("Access token is invalid")
-        return api_token
+    # Cache miss - use Redis lock for single-flight mode
+    # This ensures only one request queries DB for the same token concurrently
+    return fetch_token_with_single_flight(auth_token, scope)
class DatasetApiResource(Resource):

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import json
from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, Union, cast
from packaging.version import Version
from pydantic import ValidationError
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.db.session_factory import session_factory
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@@ -49,6 +50,12 @@ from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from models.tools import (
ApiToolProvider,
BuiltinToolProvider,
MCPToolProvider,
WorkflowToolProvider,
)
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exc import (
@@ -259,7 +266,7 @@ class AgentNode(Node[AgentNodeData]):
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
-            provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
+            provider_type = self._infer_tool_provider_type(tool, self.tenant_id)
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
@@ -748,3 +755,34 @@ class AgentNode(Node[AgentNodeData]):
llm_usage=llm_usage,
)
)
@staticmethod
def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType:
provider_type_str = tool_config.get("type")
if provider_type_str:
return ToolProviderType(provider_type_str)
provider_id = tool_config.get("provider_name")
if not provider_id:
return ToolProviderType.BUILT_IN
with session_factory.create_session() as session:
provider_map: dict[
type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]],
ToolProviderType,
] = {
WorkflowToolProvider: ToolProviderType.WORKFLOW,
MCPToolProvider: ToolProviderType.MCP,
ApiToolProvider: ToolProviderType.API,
BuiltinToolProvider: ToolProviderType.BUILT_IN,
}
for provider_model, provider_type in provider_map.items():
stmt = select(provider_model).where(
provider_model.id == provider_id,
provider_model.tenant_id == tenant_id,
)
if session.scalar(stmt):
return provider_type
raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.")
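
For context on the fallback above, a small hypothetical example of the tool config shapes `_infer_tool_provider_type` handles (provider IDs are made up; the `type` strings match `ToolProviderType`, as exercised by the unit tests further down):

```python
# Hypothetical agent tool configs, illustrating _infer_tool_provider_type above.
explicit_config = {"type": "mcp", "provider_name": "some-provider-uuid"}
# "type" present: returned directly as ToolProviderType.MCP, no database lookup.

legacy_config = {"provider_name": "some-provider-uuid"}
# "type" missing: looked up against WorkflowToolProvider, MCPToolProvider,
# ApiToolProvider and BuiltinToolProvider (in that order) for the node's
# tenant_id; raises AgentNodeError if none match.

no_provider_config = {}
# Neither field present: falls back to ToolProviderType.BUILT_IN.
```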

View File

@@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"

View File

@@ -184,6 +184,15 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh",
"schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL),
}
if dify_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK:
imports.append("schedule.update_api_token_last_used_task")
beat_schedule["batch_update_api_token_last_used"] = {
"task": "schedule.update_api_token_last_used_task.batch_update_api_token_last_used",
# "schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL),
"schedule": timedelta(minutes=2),
}
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app
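
If you want to exercise the new task without waiting for beat, it can be enqueued directly; a sketch, assuming a worker is consuming the api_token queue as configured in the entrypoint and README changes above:

```python
from schedule.update_api_token_last_used_task import batch_update_api_token_last_used

# Enqueue one run on the "api_token" queue (the queue set on the task decorator).
batch_update_api_token_last_used.delay()
```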

View File

@@ -0,0 +1,114 @@
"""
Scheduled task to batch-update API token last_used_at timestamps.
Instead of updating the database on every request, token usage is recorded
in Redis as lightweight SET keys (api_token_active:{scope}:{token}).
This task runs periodically (default every 30 minutes) to flush those
records into the database in a single batch operation.
"""
import logging
import time
from datetime import datetime
import click
from sqlalchemy import update
from sqlalchemy.orm import Session
import app
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import ApiToken
from services.api_token_service import ACTIVE_TOKEN_KEY_PREFIX
logger = logging.getLogger(__name__)
@app.celery.task(queue="api_token")
def batch_update_api_token_last_used():
"""
Batch update last_used_at for all recently active API tokens.
Scans Redis for api_token_active:* keys, parses the token and scope
from each key, and performs a batch database update.
"""
click.echo(click.style("batch_update_api_token_last_used: start.", fg="green"))
start_at = time.perf_counter()
updated_count = 0
scanned_count = 0
try:
# Collect all active token keys and their values (the actual usage timestamps)
token_entries: list[tuple[str, str | None, datetime]] = [] # (token, scope, usage_time)
keys_to_delete: list[str | bytes] = []
for key in redis_client.scan_iter(match=f"{ACTIVE_TOKEN_KEY_PREFIX}*", count=200):
if isinstance(key, bytes):
key = key.decode("utf-8")
scanned_count += 1
# Read the value (ISO timestamp recorded at actual request time)
value = redis_client.get(key)
if not value:
keys_to_delete.append(key)
continue
if isinstance(value, bytes):
value = value.decode("utf-8")
try:
usage_time = datetime.fromisoformat(value)
except (ValueError, TypeError):
logger.warning("Invalid timestamp in key %s: %s", key, value)
keys_to_delete.append(key)
continue
# Parse token info from key: api_token_active:{scope}:{token}
suffix = key[len(ACTIVE_TOKEN_KEY_PREFIX) :]
parts = suffix.split(":", 1)
if len(parts) == 2:
scope_str, token = parts
scope = None if scope_str == "None" else scope_str
token_entries.append((token, scope, usage_time))
keys_to_delete.append(key)
if not token_entries:
click.echo(click.style("batch_update_api_token_last_used: no active tokens found.", fg="yellow"))
# Still clean up any invalid keys
if keys_to_delete:
redis_client.delete(*keys_to_delete)
return
# Update each token in its own short transaction to avoid long transactions
for token, scope, usage_time in token_entries:
with Session(db.engine, expire_on_commit=False) as session, session.begin():
stmt = (
update(ApiToken)
.where(
ApiToken.token == token,
ApiToken.type == scope,
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < usage_time)),
)
.values(last_used_at=usage_time)
)
result = session.execute(stmt)
rowcount = getattr(result, "rowcount", 0)
if rowcount > 0:
updated_count += 1
# Delete processed keys from Redis
if keys_to_delete:
redis_client.delete(*keys_to_delete)
except Exception:
logger.exception("batch_update_api_token_last_used failed")
elapsed = time.perf_counter() - start_at
click.echo(
click.style(
f"batch_update_api_token_last_used: done. "
f"scanned={scanned_count}, updated={updated_count}, elapsed={elapsed:.2f}s",
fg="green",
)
)
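
For reference, a self-contained sketch of the Redis key/value layout this task consumes; the key is written by record_token_usage() in services/api_token_service.py (shown further down) and parsed by the loop above. The token value is hypothetical:

```python
from datetime import datetime

ACTIVE_TOKEN_KEY_PREFIX = "api_token_active:"  # same constant as services.api_token_service

# What record_token_usage() writes per request:
key = f"{ACTIVE_TOKEN_KEY_PREFIX}app:app-1234567890abcdef"  # api_token_active:{scope}:{token}
value = "2026-02-06T07:49:23.000000"                        # naive UTC ISO timestamp, TTL 3600s

# How this task recovers (token, scope, usage_time) from a key/value pair:
suffix = key[len(ACTIVE_TOKEN_KEY_PREFIX):]
scope_str, token = suffix.split(":", 1)  # split once, so the token itself may contain ':'
scope = None if scope_str == "None" else scope_str
usage_time = datetime.fromisoformat(value)
```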

View File

@@ -0,0 +1,330 @@
"""
API Token Service
Handles all API token caching, validation, and usage recording.
Includes Redis cache operations, database queries, and single-flight concurrency control.
"""
import logging
from datetime import datetime
from typing import Any
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
from libs.datetime_utils import naive_utc_now
from models.model import ApiToken
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------
# Pydantic DTO
# ---------------------------------------------------------------------
class CachedApiToken(BaseModel):
"""
Pydantic model for cached API token data.
This is NOT a SQLAlchemy model instance, but a plain Pydantic model
that mimics the ApiToken model interface for read-only access.
"""
id: str
app_id: str | None
tenant_id: str | None
type: str
token: str
last_used_at: datetime | None
created_at: datetime | None
def __repr__(self) -> str:
return f"<CachedApiToken id={self.id} type={self.type}>"
# ---------------------------------------------------------------------
# Cache configuration
# ---------------------------------------------------------------------
CACHE_KEY_PREFIX = "api_token"
CACHE_TTL_SECONDS = 600 # 10 minutes
CACHE_NULL_TTL_SECONDS = 60 # 1 minute for non-existent tokens
ACTIVE_TOKEN_KEY_PREFIX = "api_token_active:"
# ---------------------------------------------------------------------
# Cache class
# ---------------------------------------------------------------------
class ApiTokenCache:
"""
Redis cache wrapper for API tokens.
Handles serialization, deserialization, and cache invalidation.
"""
@staticmethod
def make_active_key(token: str, scope: str | None = None) -> str:
"""Generate Redis key for recording token usage."""
return f"{ACTIVE_TOKEN_KEY_PREFIX}{scope}:{token}"
@staticmethod
def _make_tenant_index_key(tenant_id: str) -> str:
"""Generate Redis key for tenant token index."""
return f"tenant_tokens:{tenant_id}"
@staticmethod
def _make_cache_key(token: str, scope: str | None = None) -> str:
"""Generate cache key for the given token and scope."""
scope_str = scope or "any"
return f"{CACHE_KEY_PREFIX}:{scope_str}:{token}"
@staticmethod
def _serialize_token(api_token: Any) -> bytes:
"""Serialize ApiToken object to JSON bytes."""
if isinstance(api_token, CachedApiToken):
return api_token.model_dump_json().encode("utf-8")
cached = CachedApiToken(
id=str(api_token.id),
app_id=str(api_token.app_id) if api_token.app_id else None,
tenant_id=str(api_token.tenant_id) if api_token.tenant_id else None,
type=api_token.type,
token=api_token.token,
last_used_at=api_token.last_used_at,
created_at=api_token.created_at,
)
return cached.model_dump_json().encode("utf-8")
@staticmethod
def _deserialize_token(cached_data: bytes | str) -> Any:
"""Deserialize JSON bytes/string back to a CachedApiToken Pydantic model."""
if cached_data in {b"null", "null"}:
return None
try:
if isinstance(cached_data, bytes):
cached_data = cached_data.decode("utf-8")
return CachedApiToken.model_validate_json(cached_data)
except (ValueError, Exception) as e:
logger.warning("Failed to deserialize token from cache: %s", e)
return None
@staticmethod
@redis_fallback(default_return=None)
def get(token: str, scope: str | None) -> Any | None:
"""Get API token from cache."""
cache_key = ApiTokenCache._make_cache_key(token, scope)
cached_data = redis_client.get(cache_key)
if cached_data is None:
logger.debug("Cache miss for token key: %s", cache_key)
return None
logger.debug("Cache hit for token key: %s", cache_key)
return ApiTokenCache._deserialize_token(cached_data)
@staticmethod
def _add_to_tenant_index(tenant_id: str | None, cache_key: str) -> None:
"""Add cache key to tenant index for efficient invalidation."""
if not tenant_id:
return
try:
index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
redis_client.sadd(index_key, cache_key)
redis_client.expire(index_key, CACHE_TTL_SECONDS + 60)
except Exception as e:
logger.warning("Failed to update tenant index: %s", e)
@staticmethod
def _remove_from_tenant_index(tenant_id: str | None, cache_key: str) -> None:
"""Remove cache key from tenant index."""
if not tenant_id:
return
try:
index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
redis_client.srem(index_key, cache_key)
except Exception as e:
logger.warning("Failed to remove from tenant index: %s", e)
@staticmethod
@redis_fallback(default_return=False)
def set(token: str, scope: str | None, api_token: Any | None, ttl: int = CACHE_TTL_SECONDS) -> bool:
"""Set API token in cache."""
cache_key = ApiTokenCache._make_cache_key(token, scope)
if api_token is None:
cached_value = b"null"
ttl = CACHE_NULL_TTL_SECONDS
else:
cached_value = ApiTokenCache._serialize_token(api_token)
try:
redis_client.setex(cache_key, ttl, cached_value)
if api_token is not None and hasattr(api_token, "tenant_id"):
ApiTokenCache._add_to_tenant_index(api_token.tenant_id, cache_key)
logger.debug("Cached token with key: %s, ttl: %ss", cache_key, ttl)
return True
except Exception as e:
logger.warning("Failed to cache token: %s", e)
return False
@staticmethod
@redis_fallback(default_return=False)
def delete(token: str, scope: str | None = None) -> bool:
"""Delete API token from cache."""
if scope is None:
pattern = f"{CACHE_KEY_PREFIX}:*:{token}"
try:
keys_to_delete = list(redis_client.scan_iter(match=pattern))
if keys_to_delete:
redis_client.delete(*keys_to_delete)
logger.info("Deleted %d cache entries for token", len(keys_to_delete))
return True
except Exception as e:
logger.warning("Failed to delete token cache with pattern: %s", e)
return False
else:
cache_key = ApiTokenCache._make_cache_key(token, scope)
try:
tenant_id = None
try:
cached_data = redis_client.get(cache_key)
if cached_data and cached_data != b"null":
cached_token = ApiTokenCache._deserialize_token(cached_data)
if cached_token:
tenant_id = cached_token.tenant_id
except Exception as e:
logger.debug("Failed to get tenant_id for cache cleanup: %s", e)
redis_client.delete(cache_key)
if tenant_id:
ApiTokenCache._remove_from_tenant_index(tenant_id, cache_key)
logger.info("Deleted cache for key: %s", cache_key)
return True
except Exception as e:
logger.warning("Failed to delete token cache: %s", e)
return False
@staticmethod
@redis_fallback(default_return=False)
def invalidate_by_tenant(tenant_id: str) -> bool:
"""Invalidate all API token caches for a specific tenant via tenant index."""
try:
index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
cache_keys = redis_client.smembers(index_key)
if cache_keys:
deleted_count = 0
for cache_key in cache_keys:
if isinstance(cache_key, bytes):
cache_key = cache_key.decode("utf-8")
redis_client.delete(cache_key)
deleted_count += 1
redis_client.delete(index_key)
logger.info(
"Invalidated %d token cache entries for tenant: %s",
deleted_count,
tenant_id,
)
else:
logger.info(
"No tenant index found for %s, relying on TTL expiration",
tenant_id,
)
return True
except Exception as e:
logger.warning("Failed to invalidate tenant token cache: %s", e)
return False
# ---------------------------------------------------------------------
# Token usage recording (for batch update)
# ---------------------------------------------------------------------
def record_token_usage(auth_token: str, scope: str | None) -> None:
"""
Record token usage in Redis for later batch update by a scheduled job.
Instead of dispatching a Celery task per request, we simply SET a key in Redis.
A Celery Beat scheduled task will periodically scan these keys and batch-update
last_used_at in the database.
"""
try:
key = ApiTokenCache.make_active_key(auth_token, scope)
redis_client.set(key, naive_utc_now().isoformat(), ex=3600)
except Exception as e:
logger.warning("Failed to record token usage: %s", e)
# ---------------------------------------------------------------------
# Database query + single-flight
# ---------------------------------------------------------------------
def query_token_from_db(auth_token: str, scope: str | None) -> ApiToken:
"""
Query API token from database and cache the result.
Raises Unauthorized if token is invalid.
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
if not api_token:
ApiTokenCache.set(auth_token, scope, None)
raise Unauthorized("Access token is invalid")
ApiTokenCache.set(auth_token, scope, api_token)
record_token_usage(auth_token, scope)
return api_token
def fetch_token_with_single_flight(auth_token: str, scope: str | None) -> ApiToken | Any:
"""
Fetch token from DB with single-flight pattern using Redis lock.
Ensures only one concurrent request queries the database for the same token.
Falls back to direct query if lock acquisition fails.
"""
logger.debug("Token cache miss, attempting to acquire query lock for scope: %s", scope)
lock_key = f"api_token_query_lock:{scope}:{auth_token}"
lock = redis_client.lock(lock_key, timeout=10, blocking_timeout=5)
try:
if lock.acquire(blocking=True):
try:
cached_token = ApiTokenCache.get(auth_token, scope)
if cached_token is not None:
logger.debug("Token cached by concurrent request, using cached version")
return cached_token
return query_token_from_db(auth_token, scope)
finally:
lock.release()
else:
logger.warning("Lock timeout for token: %s, proceeding with direct query", auth_token[:10])
return query_token_from_db(auth_token, scope)
except Unauthorized:
raise
except Exception as e:
logger.warning("Redis lock failed for token query: %s, proceeding anyway", e)
return query_token_from_db(auth_token, scope)
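
Taken together, the service uses a handful of Redis key namespaces; a quick reference sketch with the key shapes copied from the helpers above (example values are hypothetical):

```python
token, scope, tenant_id = "app-1234567890abcdef", "app", "tenant-uuid"  # hypothetical values

cache_key = f"api_token:{scope or 'any'}:{token}"    # CachedApiToken JSON, TTL 600s; b"null" sentinel (TTL 60s) for invalid tokens
active_key = f"api_token_active:{scope}:{token}"     # usage marker flushed by the scheduled task, TTL 3600s
index_key = f"tenant_tokens:{tenant_id}"             # SET of cache keys, used by invalidate_by_tenant()
lock_key = f"api_token_query_lock:{scope}:{token}"   # single-flight lock around the DB lookup, timeout 10s
```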

View File

@@ -48,6 +48,7 @@ from models.workflow import (
WorkflowArchiveLog,
)
from repositories.factory import DifyAPIRepositoryFactory
from services.api_token_service import ApiTokenCache
logger = logging.getLogger(__name__)
@@ -134,6 +135,12 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(session, api_token_id: str):
# Fetch token details for cache invalidation
token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first()
if token_obj:
# Invalidate cache before deletion
ApiTokenCache.delete(token_obj.token, token_obj.type)
session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
_delete_records(

View File

@@ -0,0 +1,375 @@
"""
Integration tests for API Token Cache with Redis.
These tests require:
- Redis server running
- Test database configured
"""
import time
from datetime import datetime, timedelta
from unittest.mock import patch
import pytest
from extensions.ext_redis import redis_client
from models.model import ApiToken
from services.api_token_service import ApiTokenCache, CachedApiToken
class TestApiTokenCacheRedisIntegration:
"""Integration tests with real Redis."""
def setup_method(self):
"""Setup test fixtures and clean Redis."""
self.test_token = "test-integration-token-123"
self.test_scope = "app"
self.cache_key = f"api_token:{self.test_scope}:{self.test_token}"
# Clean up any existing test data
self._cleanup()
def teardown_method(self):
"""Cleanup test data from Redis."""
self._cleanup()
def _cleanup(self):
"""Remove test data from Redis."""
try:
redis_client.delete(self.cache_key)
redis_client.delete(ApiTokenCache._make_tenant_index_key("test-tenant-id"))
redis_client.delete(ApiTokenCache.make_active_key(self.test_token, self.test_scope))
except Exception:
pass # Ignore cleanup errors
def test_cache_set_and_get_with_real_redis(self):
"""Test cache set and get operations with real Redis."""
from unittest.mock import MagicMock
mock_token = MagicMock()
mock_token.id = "test-id-123"
mock_token.app_id = "test-app-456"
mock_token.tenant_id = "test-tenant-789"
mock_token.type = "app"
mock_token.token = self.test_token
mock_token.last_used_at = datetime.now()
mock_token.created_at = datetime.now() - timedelta(days=30)
# Set in cache
result = ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
assert result is True
# Verify in Redis
cached_data = redis_client.get(self.cache_key)
assert cached_data is not None
# Get from cache
cached_token = ApiTokenCache.get(self.test_token, self.test_scope)
assert cached_token is not None
assert isinstance(cached_token, CachedApiToken)
assert cached_token.id == "test-id-123"
assert cached_token.app_id == "test-app-456"
assert cached_token.tenant_id == "test-tenant-789"
assert cached_token.type == "app"
assert cached_token.token == self.test_token
def test_cache_ttl_with_real_redis(self):
"""Test cache TTL is set correctly."""
from unittest.mock import MagicMock
mock_token = MagicMock()
mock_token.id = "test-id"
mock_token.app_id = "test-app"
mock_token.tenant_id = "test-tenant"
mock_token.type = "app"
mock_token.token = self.test_token
mock_token.last_used_at = None
mock_token.created_at = datetime.now()
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
ttl = redis_client.ttl(self.cache_key)
assert 595 <= ttl <= 600 # Should be around 600 seconds (10 minutes)
def test_cache_null_value_for_invalid_token(self):
"""Test caching null value for invalid tokens."""
result = ApiTokenCache.set(self.test_token, self.test_scope, None)
assert result is True
cached_data = redis_client.get(self.cache_key)
assert cached_data == b"null"
cached_token = ApiTokenCache.get(self.test_token, self.test_scope)
assert cached_token is None
ttl = redis_client.ttl(self.cache_key)
assert 55 <= ttl <= 60
def test_cache_delete_with_real_redis(self):
"""Test cache deletion with real Redis."""
from unittest.mock import MagicMock
mock_token = MagicMock()
mock_token.id = "test-id"
mock_token.app_id = "test-app"
mock_token.tenant_id = "test-tenant"
mock_token.type = "app"
mock_token.token = self.test_token
mock_token.last_used_at = None
mock_token.created_at = datetime.now()
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
assert redis_client.exists(self.cache_key) == 1
result = ApiTokenCache.delete(self.test_token, self.test_scope)
assert result is True
assert redis_client.exists(self.cache_key) == 0
def test_tenant_index_creation(self):
"""Test tenant index is created when caching token."""
from unittest.mock import MagicMock
tenant_id = "test-tenant-id"
mock_token = MagicMock()
mock_token.id = "test-id"
mock_token.app_id = "test-app"
mock_token.tenant_id = tenant_id
mock_token.type = "app"
mock_token.token = self.test_token
mock_token.last_used_at = None
mock_token.created_at = datetime.now()
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
assert redis_client.exists(index_key) == 1
members = redis_client.smembers(index_key)
cache_keys = [m.decode("utf-8") if isinstance(m, bytes) else m for m in members]
assert self.cache_key in cache_keys
def test_invalidate_by_tenant_via_index(self):
"""Test tenant-wide cache invalidation using index (fast path)."""
from unittest.mock import MagicMock
tenant_id = "test-tenant-id"
for i in range(3):
token_value = f"test-token-{i}"
mock_token = MagicMock()
mock_token.id = f"test-id-{i}"
mock_token.app_id = "test-app"
mock_token.tenant_id = tenant_id
mock_token.type = "app"
mock_token.token = token_value
mock_token.last_used_at = None
mock_token.created_at = datetime.now()
ApiTokenCache.set(token_value, "app", mock_token)
for i in range(3):
key = f"api_token:app:test-token-{i}"
assert redis_client.exists(key) == 1
result = ApiTokenCache.invalidate_by_tenant(tenant_id)
assert result is True
for i in range(3):
key = f"api_token:app:test-token-{i}"
assert redis_client.exists(key) == 0
assert redis_client.exists(ApiTokenCache._make_tenant_index_key(tenant_id)) == 0
def test_concurrent_cache_access(self):
"""Test concurrent cache access doesn't cause issues."""
import concurrent.futures
from unittest.mock import MagicMock
mock_token = MagicMock()
mock_token.id = "test-id"
mock_token.app_id = "test-app"
mock_token.tenant_id = "test-tenant"
mock_token.type = "app"
mock_token.token = self.test_token
mock_token.last_used_at = None
mock_token.created_at = datetime.now()
ApiTokenCache.set(self.test_token, self.test_scope, mock_token)
def get_from_cache():
return ApiTokenCache.get(self.test_token, self.test_scope)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [executor.submit(get_from_cache) for _ in range(50)]
results = [f.result() for f in concurrent.futures.as_completed(futures)]
assert len(results) == 50
assert all(r is not None for r in results)
assert all(isinstance(r, CachedApiToken) for r in results)
class TestTokenUsageRecording:
"""Tests for recording token usage in Redis (batch update approach)."""
def setup_method(self):
self.test_token = "test-usage-token"
self.test_scope = "app"
self.active_key = ApiTokenCache.make_active_key(self.test_token, self.test_scope)
def teardown_method(self):
try:
redis_client.delete(self.active_key)
except Exception:
pass
def test_record_token_usage_sets_redis_key(self):
"""Test that record_token_usage writes an active key to Redis."""
from services.api_token_service import record_token_usage
record_token_usage(self.test_token, self.test_scope)
# Key should exist
assert redis_client.exists(self.active_key) == 1
# Value should be an ISO timestamp
value = redis_client.get(self.active_key)
if isinstance(value, bytes):
value = value.decode("utf-8")
datetime.fromisoformat(value) # Should not raise
def test_record_token_usage_has_ttl(self):
"""Test that active keys have a TTL as safety net."""
from services.api_token_service import record_token_usage
record_token_usage(self.test_token, self.test_scope)
ttl = redis_client.ttl(self.active_key)
assert 3595 <= ttl <= 3600 # ~1 hour
def test_record_token_usage_overwrites(self):
"""Test that repeated calls overwrite the same key (no accumulation)."""
from services.api_token_service import record_token_usage
record_token_usage(self.test_token, self.test_scope)
first_value = redis_client.get(self.active_key)
time.sleep(0.01) # Tiny delay so timestamp differs
record_token_usage(self.test_token, self.test_scope)
second_value = redis_client.get(self.active_key)
# Key count should still be 1 (overwritten, not accumulated)
assert redis_client.exists(self.active_key) == 1
class TestEndToEndCacheFlow:
"""End-to-end integration test for complete cache flow."""
@pytest.mark.usefixtures("db_session")
def test_complete_flow_cache_miss_then_hit(self, db_session):
"""
Test complete flow:
1. First request (cache miss) -> query DB -> cache result
2. Second request (cache hit) -> return from cache
3. Verify Redis state
"""
test_token_value = "test-e2e-token"
test_scope = "app"
test_token = ApiToken()
test_token.id = "test-e2e-id"
test_token.token = test_token_value
test_token.type = test_scope
test_token.app_id = "test-app"
test_token.tenant_id = "test-tenant"
test_token.last_used_at = None
test_token.created_at = datetime.now()
db_session.add(test_token)
db_session.commit()
try:
# Step 1: Cache miss - set token in cache
ApiTokenCache.set(test_token_value, test_scope, test_token)
cache_key = f"api_token:{test_scope}:{test_token_value}"
assert redis_client.exists(cache_key) == 1
# Step 2: Cache hit - get from cache
cached_token = ApiTokenCache.get(test_token_value, test_scope)
assert cached_token is not None
assert cached_token.id == test_token.id
assert cached_token.token == test_token_value
# Step 3: Verify tenant index
index_key = ApiTokenCache._make_tenant_index_key(test_token.tenant_id)
assert redis_client.exists(index_key) == 1
assert cache_key.encode() in redis_client.smembers(index_key)
# Step 4: Delete and verify cleanup
ApiTokenCache.delete(test_token_value, test_scope)
assert redis_client.exists(cache_key) == 0
assert cache_key.encode() not in redis_client.smembers(index_key)
finally:
db_session.delete(test_token)
db_session.commit()
redis_client.delete(f"api_token:{test_scope}:{test_token_value}")
redis_client.delete(ApiTokenCache._make_tenant_index_key(test_token.tenant_id))
def test_high_concurrency_simulation(self):
"""Simulate high concurrency access to cache."""
import concurrent.futures
from unittest.mock import MagicMock
test_token_value = "test-concurrent-token"
test_scope = "app"
mock_token = MagicMock()
mock_token.id = "concurrent-id"
mock_token.app_id = "test-app"
mock_token.tenant_id = "test-tenant"
mock_token.type = test_scope
mock_token.token = test_token_value
mock_token.last_used_at = datetime.now()
mock_token.created_at = datetime.now()
ApiTokenCache.set(test_token_value, test_scope, mock_token)
try:
def read_cache():
return ApiTokenCache.get(test_token_value, test_scope)
start_time = time.time()
with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
futures = [executor.submit(read_cache) for _ in range(100)]
results = [f.result() for f in concurrent.futures.as_completed(futures)]
elapsed = time.time() - start_time
assert len(results) == 100
assert all(r is not None for r in results)
assert elapsed < 1.0, f"Too slow: {elapsed}s for 100 cache reads"
finally:
ApiTokenCache.delete(test_token_value, test_scope)
redis_client.delete(ApiTokenCache._make_tenant_index_key(mock_token.tenant_id))
class TestRedisFailover:
"""Test behavior when Redis is unavailable."""
@patch("services.api_token_service.redis_client")
def test_graceful_degradation_when_redis_fails(self, mock_redis):
"""Test system degrades gracefully when Redis is unavailable."""
from redis import RedisError
mock_redis.get.side_effect = RedisError("Connection failed")
mock_redis.setex.side_effect = RedisError("Connection failed")
result_get = ApiTokenCache.get("test-token", "app")
assert result_get is None
result_set = ApiTokenCache.set("test-token", "app", None)
assert result_set is False

View File

@@ -0,0 +1,197 @@
from unittest.mock import MagicMock, patch
import pytest
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.agent.agent_node import AgentNode
class TestInferToolProviderType:
"""Test cases for AgentNode._infer_tool_provider_type method."""
def test_infer_type_from_config_workflow(self):
"""Test inferring workflow provider type from config."""
tool_config = {
"type": "workflow",
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.WORKFLOW
def test_infer_type_from_config_builtin(self):
"""Test inferring builtin provider type from config."""
tool_config = {
"type": "builtin",
"provider_name": "builtin-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
def test_infer_type_from_config_api(self):
"""Test inferring API provider type from config."""
tool_config = {
"type": "api",
"provider_name": "api-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.API
def test_infer_type_from_config_mcp(self):
"""Test inferring MCP provider type from config."""
tool_config = {
"type": "mcp",
"provider_name": "mcp-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.MCP
def test_infer_type_invalid_config_value_raises_error(self):
"""Test that invalid type value in config raises ValueError."""
tool_config = {
"type": "invalid-type",
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
with pytest.raises(ValueError):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
def test_infer_workflow_type_from_database(self):
"""Test inferring workflow provider type from database."""
tool_config = {
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns a result
mock_session.scalar.return_value = True
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.WORKFLOW
# Should only query once (after finding WorkflowToolProvider)
assert mock_session.scalar.call_count == 1
def test_infer_mcp_type_from_database(self):
"""Test inferring MCP provider type from database."""
tool_config = {
"provider_name": "mcp-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns None
# Second query (MCPToolProvider) returns a result
mock_session.scalar.side_effect = [None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.MCP
assert mock_session.scalar.call_count == 2
def test_infer_api_type_from_database(self):
"""Test inferring API provider type from database."""
tool_config = {
"provider_name": "api-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns None
# Second query (MCPToolProvider) returns None
# Third query (ApiToolProvider) returns a result
mock_session.scalar.side_effect = [None, None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.API
assert mock_session.scalar.call_count == 3
def test_infer_builtin_type_from_database(self):
"""Test inferring builtin provider type from database."""
tool_config = {
"provider_name": "builtin-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First three queries return None
# Fourth query (BuiltinToolProvider) returns a result
mock_session.scalar.side_effect = [None, None, None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
assert mock_session.scalar.call_count == 4
def test_infer_type_default_when_not_found(self):
"""Test raising AgentNodeError when provider is not found in database."""
tool_config = {
"provider_name": "unknown-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# All queries return None
mock_session.scalar.return_value = None
# Current implementation raises AgentNodeError when provider not found
from core.workflow.nodes.agent.exc import AgentNodeError
with pytest.raises(AgentNodeError, match="Tool provider with ID 'unknown-provider-id' not found"):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
def test_infer_type_default_when_no_provider_name(self):
"""Test defaulting to BUILT_IN when provider_name is missing."""
tool_config = {}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
def test_infer_type_database_exception_propagates(self):
"""Test that database exception propagates (current implementation doesn't catch it)."""
tool_config = {
"provider_name": "provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# Database query raises exception
mock_session.scalar.side_effect = Exception("Database error")
# Current implementation doesn't catch exceptions, so it propagates
with pytest.raises(Exception, match="Database error"):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)

View File

@@ -132,6 +132,8 @@ class TestCelerySSLConfiguration:
mock_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK = 0
mock_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK = False
mock_config.TRIGGER_PROVIDER_REFRESH_INTERVAL = 15
mock_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK = False
mock_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL = 30
with patch("extensions.ext_celery.dify_config", mock_config):
from dify_app import DifyApp

View File

@@ -0,0 +1,250 @@
"""
Unit tests for API Token Cache module.
"""
import json
from datetime import datetime
from unittest.mock import MagicMock, patch
from services.api_token_service import (
CACHE_KEY_PREFIX,
CACHE_NULL_TTL_SECONDS,
CACHE_TTL_SECONDS,
ApiTokenCache,
CachedApiToken,
)
class TestApiTokenCache:
"""Test cases for ApiTokenCache class."""
def setup_method(self):
"""Setup test fixtures."""
self.mock_token = MagicMock()
self.mock_token.id = "test-token-id-123"
self.mock_token.app_id = "test-app-id-456"
self.mock_token.tenant_id = "test-tenant-id-789"
self.mock_token.type = "app"
self.mock_token.token = "test-token-value-abc"
self.mock_token.last_used_at = datetime(2026, 2, 3, 10, 0, 0)
self.mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)
def test_make_cache_key(self):
"""Test cache key generation."""
# Test with scope
key = ApiTokenCache._make_cache_key("my-token", "app")
assert key == f"{CACHE_KEY_PREFIX}:app:my-token"
# Test without scope
key = ApiTokenCache._make_cache_key("my-token", None)
assert key == f"{CACHE_KEY_PREFIX}:any:my-token"
def test_serialize_token(self):
"""Test token serialization."""
serialized = ApiTokenCache._serialize_token(self.mock_token)
data = json.loads(serialized)
assert data["id"] == "test-token-id-123"
assert data["app_id"] == "test-app-id-456"
assert data["tenant_id"] == "test-tenant-id-789"
assert data["type"] == "app"
assert data["token"] == "test-token-value-abc"
assert data["last_used_at"] == "2026-02-03T10:00:00"
assert data["created_at"] == "2026-01-01T00:00:00"
def test_serialize_token_with_nulls(self):
"""Test token serialization with None values."""
mock_token = MagicMock()
mock_token.id = "test-id"
mock_token.app_id = None
mock_token.tenant_id = None
mock_token.type = "dataset"
mock_token.token = "test-token"
mock_token.last_used_at = None
mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)
serialized = ApiTokenCache._serialize_token(mock_token)
data = json.loads(serialized)
assert data["app_id"] is None
assert data["tenant_id"] is None
assert data["last_used_at"] is None
def test_deserialize_token(self):
"""Test token deserialization."""
cached_data = json.dumps(
{
"id": "test-id",
"app_id": "test-app",
"tenant_id": "test-tenant",
"type": "app",
"token": "test-token",
"last_used_at": "2026-02-03T10:00:00",
"created_at": "2026-01-01T00:00:00",
}
)
result = ApiTokenCache._deserialize_token(cached_data)
assert isinstance(result, CachedApiToken)
assert result.id == "test-id"
assert result.app_id == "test-app"
assert result.tenant_id == "test-tenant"
assert result.type == "app"
assert result.token == "test-token"
assert result.last_used_at == datetime(2026, 2, 3, 10, 0, 0)
assert result.created_at == datetime(2026, 1, 1, 0, 0, 0)
def test_deserialize_null_token(self):
"""Test deserialization of null token (cached miss)."""
result = ApiTokenCache._deserialize_token("null")
assert result is None
def test_deserialize_invalid_json(self):
"""Test deserialization with invalid JSON."""
result = ApiTokenCache._deserialize_token("invalid-json{")
assert result is None
@patch("services.api_token_service.redis_client")
def test_get_cache_hit(self, mock_redis):
"""Test cache hit scenario."""
cached_data = json.dumps(
{
"id": "test-id",
"app_id": "test-app",
"tenant_id": "test-tenant",
"type": "app",
"token": "test-token",
"last_used_at": "2026-02-03T10:00:00",
"created_at": "2026-01-01T00:00:00",
}
).encode("utf-8")
mock_redis.get.return_value = cached_data
result = ApiTokenCache.get("test-token", "app")
assert result is not None
assert isinstance(result, CachedApiToken)
assert result.app_id == "test-app"
mock_redis.get.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token")
@patch("services.api_token_service.redis_client")
def test_get_cache_miss(self, mock_redis):
"""Test cache miss scenario."""
mock_redis.get.return_value = None
result = ApiTokenCache.get("test-token", "app")
assert result is None
mock_redis.get.assert_called_once()
@patch("services.api_token_service.redis_client")
def test_set_valid_token(self, mock_redis):
"""Test setting a valid token in cache."""
result = ApiTokenCache.set("test-token", "app", self.mock_token)
assert result is True
mock_redis.setex.assert_called_once()
args = mock_redis.setex.call_args[0]
assert args[0] == f"{CACHE_KEY_PREFIX}:app:test-token"
assert args[1] == CACHE_TTL_SECONDS
@patch("services.api_token_service.redis_client")
def test_set_null_token(self, mock_redis):
"""Test setting a null token (cache penetration prevention)."""
result = ApiTokenCache.set("invalid-token", "app", None)
assert result is True
mock_redis.setex.assert_called_once()
args = mock_redis.setex.call_args[0]
assert args[0] == f"{CACHE_KEY_PREFIX}:app:invalid-token"
assert args[1] == CACHE_NULL_TTL_SECONDS
assert args[2] == b"null"
@patch("services.api_token_service.redis_client")
def test_delete_with_scope(self, mock_redis):
"""Test deleting token cache with specific scope."""
result = ApiTokenCache.delete("test-token", "app")
assert result is True
mock_redis.delete.assert_called_once_with(f"{CACHE_KEY_PREFIX}:app:test-token")
@patch("services.api_token_service.redis_client")
def test_delete_without_scope(self, mock_redis):
"""Test deleting token cache without scope (delete all)."""
# Mock scan_iter to return an iterator of keys
mock_redis.scan_iter.return_value = iter(
[
b"api_token:app:test-token",
b"api_token:dataset:test-token",
]
)
result = ApiTokenCache.delete("test-token", None)
assert result is True
# Verify scan_iter was called with the correct pattern
mock_redis.scan_iter.assert_called_once()
call_args = mock_redis.scan_iter.call_args
assert call_args[1]["match"] == f"{CACHE_KEY_PREFIX}:*:test-token"
# Verify delete was called with all matched keys
mock_redis.delete.assert_called_once_with(
b"api_token:app:test-token",
b"api_token:dataset:test-token",
)
@patch("services.api_token_service.redis_client")
def test_redis_fallback_on_exception(self, mock_redis):
"""Test Redis fallback when Redis is unavailable."""
from redis import RedisError
mock_redis.get.side_effect = RedisError("Connection failed")
result = ApiTokenCache.get("test-token", "app")
# Should return None (fallback) instead of raising exception
assert result is None
class TestApiTokenCacheIntegration:
"""Integration test scenarios."""
@patch("services.api_token_service.redis_client")
def test_full_cache_lifecycle(self, mock_redis):
"""Test complete cache lifecycle: set -> get -> delete."""
# Setup mock token
mock_token = MagicMock()
mock_token.id = "id-123"
mock_token.app_id = "app-456"
mock_token.tenant_id = "tenant-789"
mock_token.type = "app"
mock_token.token = "token-abc"
mock_token.last_used_at = datetime(2026, 2, 3, 10, 0, 0)
mock_token.created_at = datetime(2026, 1, 1, 0, 0, 0)
# 1. Set token in cache
ApiTokenCache.set("token-abc", "app", mock_token)
assert mock_redis.setex.called
# 2. Simulate cache hit
cached_data = ApiTokenCache._serialize_token(mock_token)
mock_redis.get.return_value = cached_data # bytes from model_dump_json().encode()
retrieved = ApiTokenCache.get("token-abc", "app")
assert retrieved is not None
assert isinstance(retrieved, CachedApiToken)
# 3. Delete from cache
ApiTokenCache.delete("token-abc", "app")
assert mock_redis.delete.called
@patch("services.api_token_service.redis_client")
def test_cache_penetration_prevention(self, mock_redis):
"""Test that non-existent tokens are cached as null."""
# Set null token (cache miss)
ApiTokenCache.set("non-existent-token", "app", None)
args = mock_redis.setex.call_args[0]
assert args[2] == b"null"
assert args[1] == CACHE_NULL_TTL_SECONDS # Shorter TTL for null values