Compare commits

...

3 Commits

Author SHA1 Message Date
Joel
d4783e8c14 chore: url in tool description support clicking jump directly (#35163)
Some checks are pending
autofix.ci / autofix (push) Waiting to run
Build and Push API & Web / build (api, {{defaultContext}}:api, Dockerfile, DIFY_API_IMAGE_NAME, linux/amd64, ubuntu-latest, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, {{defaultContext}}:api, Dockerfile, DIFY_API_IMAGE_NAME, linux/arm64, ubuntu-24.04-arm, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, {{defaultContext}}, web/Dockerfile, DIFY_WEB_IMAGE_NAME, linux/amd64, ubuntu-latest, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, {{defaultContext}}, web/Dockerfile, DIFY_WEB_IMAGE_NAME, linux/arm64, ubuntu-24.04-arm, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Skip Duplicate Checks (push) Waiting to run
Main CI Pipeline / Check Changed Files (push) Blocked by required conditions
Main CI Pipeline / Run API Tests (push) Blocked by required conditions
Main CI Pipeline / Skip API Tests (push) Blocked by required conditions
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Run Web Tests (push) Blocked by required conditions
Main CI Pipeline / Skip Web Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Run Web Full-Stack E2E (push) Blocked by required conditions
Main CI Pipeline / Skip Web Full-Stack E2E (push) Blocked by required conditions
Main CI Pipeline / Web Full-Stack E2E (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Blocked by required conditions
Main CI Pipeline / Run VDB Tests (push) Blocked by required conditions
Main CI Pipeline / Skip VDB Tests (push) Blocked by required conditions
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / Run DB Migration Test (push) Blocked by required conditions
Main CI Pipeline / Skip DB Migration Test (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Trigger i18n Sync on Push / trigger (push) Waiting to run
2026-04-14 09:55:55 +00:00
Blackoutta
736880e046 feat: support configurable redis key prefix (#35139) 2026-04-14 09:31:41 +00:00
wdeveloper16
bd7a9b5fcf refactor: replace bare dict with dict[str, Any] in model provider service and core modules (#35122)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-04-14 09:18:30 +00:00
24 changed files with 613 additions and 87 deletions

View File

@@ -57,6 +57,9 @@ REDIS_SSL_CERTFILE=
REDIS_SSL_KEYFILE=
# Path to client private key file for SSL authentication
REDIS_DB=0
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
# Leave empty to preserve current unprefixed behavior.
REDIS_KEY_PREFIX=
# redis Sentinel configuration.
REDIS_USE_SENTINEL=false

View File

@@ -32,6 +32,11 @@ class RedisConfig(BaseSettings):
default=0,
)
REDIS_KEY_PREFIX: str = Field(
description="Optional global prefix for Redis keys, topics, and transport artifacts",
default="",
)
REDIS_USE_SSL: bool = Field(
description="Enable SSL/TLS for the Redis connection",
default=False,

View File

@@ -1,7 +1,7 @@
import json
import re
from collections.abc import Generator
from typing import Union
from typing import Any, Union
from graphon.model_runtime.entities.llm_entities import LLMResultChunk
@@ -11,7 +11,7 @@ from core.agent.entities import AgentScratchpadUnit
class CotAgentOutputParser:
@classmethod
def handle_react_stream_output(
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
action_name = None

View File

@@ -254,7 +254,7 @@ def resolve_dify_schema_refs(
return resolver.resolve(schema)
def _remove_metadata_fields(schema: dict) -> dict:
def _remove_metadata_fields(schema: dict[str, Any]) -> dict[str, Any]:
"""
Remove metadata fields from schema that shouldn't be included in resolved output

View File

@@ -9,6 +9,7 @@ from typing_extensions import TypedDict
from configs import dify_config
from dify_app import DifyApp
from extensions.redis_names import normalize_redis_key_prefix
class _CelerySentinelKwargsDict(TypedDict):
@@ -16,9 +17,10 @@ class _CelerySentinelKwargsDict(TypedDict):
password: str | None
class CelerySentinelTransportDict(TypedDict):
class CelerySentinelTransportDict(TypedDict, total=False):
master_name: str | None
sentinel_kwargs: _CelerySentinelKwargsDict
global_keyprefix: str
class CelerySSLOptionsDict(TypedDict):
@@ -61,15 +63,31 @@ def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
"""Get broker transport options (e.g. Redis Sentinel) for Celery connections."""
transport_options: CelerySentinelTransportDict | dict[str, Any]
if dify_config.CELERY_USE_SENTINEL:
return CelerySentinelTransportDict(
transport_options = CelerySentinelTransportDict(
master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
sentinel_kwargs=_CelerySentinelKwargsDict(
socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
password=dify_config.CELERY_SENTINEL_PASSWORD,
),
)
return {}
else:
transport_options = {}
global_keyprefix = get_celery_redis_global_keyprefix()
if global_keyprefix:
transport_options["global_keyprefix"] = global_keyprefix
return transport_options
def get_celery_redis_global_keyprefix() -> str | None:
"""Return the Redis transport prefix for Celery when namespace isolation is enabled."""
normalized_prefix = normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX)
if not normalized_prefix:
return None
return f"{normalized_prefix}:"
def init_app(app: DifyApp) -> Celery:

View File

@@ -3,7 +3,7 @@ import logging
import ssl
from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Union
from typing import Any, Union, cast
import redis
from redis import RedisError
@@ -18,17 +18,26 @@ from typing_extensions import TypedDict
from configs import dify_config
from dify_app import DifyApp
from extensions.redis_names import (
normalize_redis_key_prefix,
serialize_redis_name,
serialize_redis_name_arg,
serialize_redis_name_args,
)
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
if TYPE_CHECKING:
from redis.lock import Lock
logger = logging.getLogger(__name__)
_normalize_redis_key_prefix = normalize_redis_key_prefix
_serialize_redis_name = serialize_redis_name
_serialize_redis_name_arg = serialize_redis_name_arg
_serialize_redis_name_args = serialize_redis_name_args
class RedisClientWrapper:
"""
A wrapper class for the Redis client that addresses the issue where the global
@@ -59,68 +68,148 @@ class RedisClientWrapper:
if self._client is None:
self._client = client
if TYPE_CHECKING:
# Type hints for IDE support and static analysis
# These are not executed at runtime but provide type information
def get(self, name: str | bytes) -> Any: ...
def set(
self,
name: str | bytes,
value: Any,
ex: int | None = None,
px: int | None = None,
nx: bool = False,
xx: bool = False,
keepttl: bool = False,
get: bool = False,
exat: int | None = None,
pxat: int | None = None,
) -> Any: ...
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ...
def setnx(self, name: str | bytes, value: Any) -> Any: ...
def delete(self, *names: str | bytes) -> Any: ...
def incr(self, name: str | bytes, amount: int = 1) -> Any: ...
def expire(
self,
name: str | bytes,
time: int | timedelta,
nx: bool = False,
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any: ...
def lock(
self,
name: str,
timeout: float | None = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: float | None = None,
thread_local: bool = True,
) -> Lock: ...
def zadd(
self,
name: str | bytes,
mapping: dict[str | bytes | int | float, float | int | str | bytes],
nx: bool = False,
xx: bool = False,
ch: bool = False,
incr: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any: ...
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
def zcard(self, name: str | bytes) -> Any: ...
def getdel(self, name: str | bytes) -> Any: ...
def pubsub(self) -> PubSub: ...
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ...
def __getattr__(self, item: str) -> Any:
def _require_client(self) -> redis.Redis | RedisCluster:
if self._client is None:
raise RuntimeError("Redis client is not initialized. Call init_app first.")
return getattr(self._client, item)
return self._client
def _get_prefix(self) -> str:
return dify_config.REDIS_KEY_PREFIX
def get(self, name: str | bytes) -> Any:
return self._require_client().get(_serialize_redis_name_arg(name, self._get_prefix()))
def set(
self,
name: str | bytes,
value: Any,
ex: int | None = None,
px: int | None = None,
nx: bool = False,
xx: bool = False,
keepttl: bool = False,
get: bool = False,
exat: int | None = None,
pxat: int | None = None,
) -> Any:
return self._require_client().set(
_serialize_redis_name_arg(name, self._get_prefix()),
value,
ex=ex,
px=px,
nx=nx,
xx=xx,
keepttl=keepttl,
get=get,
exat=exat,
pxat=pxat,
)
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any:
return self._require_client().setex(_serialize_redis_name_arg(name, self._get_prefix()), time, value)
def setnx(self, name: str | bytes, value: Any) -> Any:
return self._require_client().setnx(_serialize_redis_name_arg(name, self._get_prefix()), value)
def delete(self, *names: str | bytes) -> Any:
return self._require_client().delete(*_serialize_redis_name_args(names, self._get_prefix()))
def incr(self, name: str | bytes, amount: int = 1) -> Any:
return self._require_client().incr(_serialize_redis_name_arg(name, self._get_prefix()), amount)
def expire(
self,
name: str | bytes,
time: int | timedelta,
nx: bool = False,
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any:
return self._require_client().expire(
_serialize_redis_name_arg(name, self._get_prefix()),
time,
nx=nx,
xx=xx,
gt=gt,
lt=lt,
)
def exists(self, *names: str | bytes) -> Any:
return self._require_client().exists(*_serialize_redis_name_args(names, self._get_prefix()))
def ttl(self, name: str | bytes) -> Any:
return self._require_client().ttl(_serialize_redis_name_arg(name, self._get_prefix()))
def getdel(self, name: str | bytes) -> Any:
return self._require_client().getdel(_serialize_redis_name_arg(name, self._get_prefix()))
def lock(
self,
name: str,
timeout: float | None = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: float | None = None,
thread_local: bool = True,
) -> Any:
return self._require_client().lock(
_serialize_redis_name(name, self._get_prefix()),
timeout=timeout,
sleep=sleep,
blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
)
def hset(self, name: str | bytes, *args: Any, **kwargs: Any) -> Any:
return self._require_client().hset(_serialize_redis_name_arg(name, self._get_prefix()), *args, **kwargs)
def hgetall(self, name: str | bytes) -> Any:
return self._require_client().hgetall(_serialize_redis_name_arg(name, self._get_prefix()))
def hdel(self, name: str | bytes, *keys: str | bytes) -> Any:
return self._require_client().hdel(_serialize_redis_name_arg(name, self._get_prefix()), *keys)
def hlen(self, name: str | bytes) -> Any:
return self._require_client().hlen(_serialize_redis_name_arg(name, self._get_prefix()))
def zadd(
self,
name: str | bytes,
mapping: dict[str | bytes | int | float, float | int | str | bytes],
nx: bool = False,
xx: bool = False,
ch: bool = False,
incr: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any:
return self._require_client().zadd(
_serialize_redis_name_arg(name, self._get_prefix()),
cast(Any, mapping),
nx=nx,
xx=xx,
ch=ch,
incr=incr,
gt=gt,
lt=lt,
)
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any:
return self._require_client().zremrangebyscore(_serialize_redis_name_arg(name, self._get_prefix()), min, max)
def zcard(self, name: str | bytes) -> Any:
return self._require_client().zcard(_serialize_redis_name_arg(name, self._get_prefix()))
def pubsub(self) -> PubSub:
return self._require_client().pubsub()
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any:
return self._require_client().pipeline(transaction=transaction, shard_hint=shard_hint)
def __getattr__(self, item: str) -> Any:
return getattr(self._require_client(), item)
redis_client: RedisClientWrapper = RedisClientWrapper()

View File

@@ -0,0 +1,32 @@
from configs import dify_config
def normalize_redis_key_prefix(prefix: str | None) -> str:
"""Normalize the configured Redis key prefix for consistent runtime use."""
if prefix is None:
return ""
return prefix.strip()
def get_redis_key_prefix() -> str:
"""Read and normalize the current Redis key prefix from config."""
return normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX)
def serialize_redis_name(name: str, prefix: str | None = None) -> str:
"""Convert a logical Redis name into the physical name used in Redis."""
normalized_prefix = get_redis_key_prefix() if prefix is None else normalize_redis_key_prefix(prefix)
if not normalized_prefix:
return name
return f"{normalized_prefix}:{name}"
def serialize_redis_name_arg(name: str | bytes, prefix: str | None = None) -> str | bytes:
"""Prefix string Redis names while preserving bytes inputs unchanged."""
if isinstance(name, bytes):
return name
return serialize_redis_name(name, prefix)
def serialize_redis_name_args(names: tuple[str | bytes, ...], prefix: str | None = None) -> tuple[str | bytes, ...]:
return tuple(serialize_redis_name_arg(name, prefix) for name in names)

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Any
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis, RedisCluster
@@ -32,12 +33,13 @@ class Topic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
self._client = redis_client
self._topic = topic
self._redis_topic = serialize_redis_name(topic)
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.publish(self._topic, payload)
self._client.publish(self._redis_topic, payload)
def as_subscriber(self) -> Subscriber:
return self
@@ -46,7 +48,7 @@ class Topic:
return _RedisSubscription(
client=self._client,
pubsub=self._client.pubsub(),
topic=self._topic,
topic=self._redis_topic,
)

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Any
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis, RedisCluster
@@ -30,12 +31,13 @@ class ShardedTopic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
self._client = redis_client
self._topic = topic
self._redis_topic = serialize_redis_name(topic)
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.spublish(self._topic, payload) # type: ignore[attr-defined,union-attr]
self._client.spublish(self._redis_topic, payload) # type: ignore[attr-defined,union-attr]
def as_subscriber(self) -> Subscriber:
return self
@@ -44,7 +46,7 @@ class ShardedTopic:
return _RedisShardedSubscription(
client=self._client,
pubsub=self._client.pubsub(),
topic=self._topic,
topic=self._redis_topic,
)

View File

@@ -6,6 +6,7 @@ import threading
from collections.abc import Iterator
from typing import Self
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis, RedisCluster
@@ -35,7 +36,7 @@ class StreamsTopic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600):
self._client = redis_client
self._topic = topic
self._key = f"stream:{topic}"
self._key = serialize_redis_name(f"stream:{topic}")
self._retention_seconds = retention_seconds
self.max_length = 5000

View File

@@ -103,7 +103,10 @@ class DbMigrationAutoRenewLock:
timeout=self._ttl_seconds,
thread_local=False,
)
acquired = bool(self._lock.acquire(*args, **kwargs))
lock = self._lock
if lock is None:
raise RuntimeError("Redis lock initialization failed.")
acquired = bool(lock.acquire(*args, **kwargs))
self._acquired = acquired
if acquired:
self._start_heartbeat()

View File

@@ -1,4 +1,5 @@
import logging
from typing import Any
from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule
@@ -168,7 +169,9 @@ class ModelProviderService:
model_name=model,
)
def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None:
def get_provider_credential(
self, tenant_id: str, provider: str, credential_id: str | None = None
) -> dict[str, Any] | None:
"""
get provider credentials.
@@ -180,7 +183,7 @@ class ModelProviderService:
provider_configuration = self._get_provider_configuration(tenant_id, provider)
return provider_configuration.get_provider_credential(credential_id=credential_id)
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict[str, Any]):
"""
validate provider credentials before saving.
@@ -192,7 +195,7 @@ class ModelProviderService:
provider_configuration.validate_provider_credentials(credentials)
def create_provider_credential(
self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
self, tenant_id: str, provider: str, credentials: dict[str, Any], credential_name: str | None
) -> None:
"""
Create and save new provider credentials.
@@ -210,7 +213,7 @@ class ModelProviderService:
self,
tenant_id: str,
provider: str,
credentials: dict,
credentials: dict[str, Any],
credential_id: str,
credential_name: str | None,
) -> None:
@@ -254,7 +257,7 @@ class ModelProviderService:
def get_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
) -> dict | None:
) -> dict[str, Any] | None:
"""
Retrieve model-specific credentials.
@@ -270,7 +273,9 @@ class ModelProviderService:
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)
def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict):
def validate_model_credentials(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict[str, Any]
):
"""
validate model credentials.
@@ -287,7 +292,13 @@ class ModelProviderService:
)
def create_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
self,
tenant_id: str,
provider: str,
model_type: str,
model: str,
credentials: dict[str, Any],
credential_name: str | None,
) -> None:
"""
create and save model credentials.
@@ -314,7 +325,7 @@ class ModelProviderService:
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
credential_id: str,
credential_name: str | None,
) -> None:

View File

@@ -33,6 +33,7 @@ REDIS_USERNAME=
REDIS_PASSWORD=difyai123456
REDIS_USE_SSL=false
REDIS_DB=0
REDIS_KEY_PREFIX=
# PostgreSQL database configuration
DB_USERNAME=postgres

View File

@@ -236,6 +236,41 @@ def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest.
_ = DifyConfig().normalized_pubsub_redis_url
def test_dify_config_exposes_redis_key_prefix_default(monkeypatch: pytest.MonkeyPatch):
os.environ.clear()
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
config = DifyConfig(_env_file=None)
assert config.REDIS_KEY_PREFIX == ""
def test_dify_config_reads_redis_key_prefix_from_env(monkeypatch: pytest.MonkeyPatch):
os.environ.clear()
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
monkeypatch.setenv("REDIS_KEY_PREFIX", "enterprise-a")
config = DifyConfig(_env_file=None)
assert config.REDIS_KEY_PREFIX == "enterprise-a"
@pytest.mark.parametrize(
("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"),
[

View File

@@ -7,6 +7,47 @@ from unittest.mock import MagicMock, patch
class TestCelerySSLConfiguration:
"""Test suite for Celery SSL configuration."""
def test_get_celery_broker_transport_options_includes_global_keyprefix_for_redis(self):
mock_config = MagicMock()
mock_config.CELERY_USE_SENTINEL = False
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
with patch("extensions.ext_celery.dify_config", mock_config):
from extensions.ext_celery import get_celery_broker_transport_options
result = get_celery_broker_transport_options()
assert result["global_keyprefix"] == "enterprise-a:"
def test_get_celery_broker_transport_options_omits_global_keyprefix_when_prefix_empty(self):
mock_config = MagicMock()
mock_config.CELERY_USE_SENTINEL = False
mock_config.REDIS_KEY_PREFIX = " "
with patch("extensions.ext_celery.dify_config", mock_config):
from extensions.ext_celery import get_celery_broker_transport_options
result = get_celery_broker_transport_options()
assert "global_keyprefix" not in result
def test_get_celery_broker_transport_options_keeps_sentinel_and_adds_global_keyprefix(self):
mock_config = MagicMock()
mock_config.CELERY_USE_SENTINEL = True
mock_config.CELERY_SENTINEL_MASTER_NAME = "mymaster"
mock_config.CELERY_SENTINEL_SOCKET_TIMEOUT = 0.1
mock_config.CELERY_SENTINEL_PASSWORD = "secret"
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
with patch("extensions.ext_celery.dify_config", mock_config):
from extensions.ext_celery import get_celery_broker_transport_options
result = get_celery_broker_transport_options()
assert result["master_name"] == "mymaster"
assert result["sentinel_kwargs"]["password"] == "secret"
assert result["global_keyprefix"] == "enterprise-a:"
def test_get_celery_ssl_options_when_ssl_disabled(self):
"""Test SSL options when BROKER_USE_SSL is False."""
from configs import DifyConfig
@@ -151,3 +192,49 @@ class TestCelerySSLConfiguration:
# Check that SSL is also applied to Redis backend
assert "redis_backend_use_ssl" in celery_app.conf
assert celery_app.conf["redis_backend_use_ssl"] is not None
def test_celery_init_applies_global_keyprefix_to_broker_and_backend_transport(self):
mock_config = MagicMock()
mock_config.BROKER_USE_SSL = False
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
mock_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL = 1
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
mock_config.CELERY_BACKEND = "redis"
mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0"
mock_config.CELERY_USE_SENTINEL = False
mock_config.LOG_FORMAT = "%(message)s"
mock_config.LOG_TZ = "UTC"
mock_config.LOG_FILE = None
mock_config.CELERY_TASK_ANNOTATIONS = {}
mock_config.CELERY_BEAT_SCHEDULER_TIME = 1
mock_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK = False
mock_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK = False
mock_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK = False
mock_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK = False
mock_config.ENABLE_CLEAN_MESSAGES = False
mock_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK = False
mock_config.ENABLE_DATASETS_QUEUE_MONITOR = False
mock_config.ENABLE_HUMAN_INPUT_TIMEOUT_TASK = False
mock_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK = False
mock_config.MARKETPLACE_ENABLED = False
mock_config.WORKFLOW_LOG_CLEANUP_ENABLED = False
mock_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK = False
mock_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK = False
mock_config.WORKFLOW_SCHEDULE_POLLER_INTERVAL = 1
mock_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK = False
mock_config.TRIGGER_PROVIDER_REFRESH_INTERVAL = 15
mock_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK = False
mock_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL = 30
mock_config.ENTERPRISE_ENABLED = False
mock_config.ENTERPRISE_TELEMETRY_ENABLED = False
with patch("extensions.ext_celery.dify_config", mock_config):
from dify_app import DifyApp
from extensions.ext_celery import init_app
app = DifyApp(__name__)
celery_app = init_app(app)
assert celery_app.conf["broker_transport_options"]["global_keyprefix"] == "enterprise-a:"
assert celery_app.conf["result_backend_transport_options"]["global_keyprefix"] == "enterprise-a:"

View File

@@ -6,6 +6,7 @@ from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastCh
def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch):
monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub")
monkeypatch.setattr(ext_redis, "_pubsub_redis_client", object())
channel = ext_redis.get_pubsub_broadcast_channel()
@@ -14,6 +15,7 @@ def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch):
def test_get_pubsub_broadcast_channel_sharded(monkeypatch):
monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded")
monkeypatch.setattr(ext_redis, "_pubsub_redis_client", object())
channel = ext_redis.get_pubsub_broadcast_channel()

View File

@@ -1,12 +1,15 @@
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from redis import RedisError
from redis.retry import Retry
from extensions.ext_redis import (
RedisClientWrapper,
_get_base_redis_params,
_get_cluster_connection_health_params,
_get_connection_health_params,
_normalize_redis_key_prefix,
_serialize_redis_name,
redis_fallback,
)
@@ -123,3 +126,99 @@ class TestRedisFallback:
assert test_func.__name__ == "test_func"
assert test_func.__doc__ == "Test function docstring"
class TestRedisKeyPrefixHelpers:
def test_normalize_redis_key_prefix_trims_whitespace(self):
assert _normalize_redis_key_prefix(" enterprise-a ") == "enterprise-a"
def test_normalize_redis_key_prefix_treats_whitespace_only_as_empty(self):
assert _normalize_redis_key_prefix(" ") == ""
def test_serialize_redis_name_returns_original_when_prefix_empty(self):
assert _serialize_redis_name("model_lb_index:test", "") == "model_lb_index:test"
def test_serialize_redis_name_adds_single_colon_separator(self):
assert _serialize_redis_name("model_lb_index:test", "enterprise-a") == "enterprise-a:model_lb_index:test"
class TestRedisClientWrapperKeyPrefix:
def test_wrapper_get_prefixes_string_keys(self):
mock_client = MagicMock()
wrapper = RedisClientWrapper()
wrapper.initialize(mock_client)
with patch("extensions.ext_redis.dify_config") as mock_config:
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
wrapper.get("oauth_state:abc")
mock_client.get.assert_called_once_with("enterprise-a:oauth_state:abc")
def test_wrapper_delete_prefixes_multiple_keys(self):
mock_client = MagicMock()
wrapper = RedisClientWrapper()
wrapper.initialize(mock_client)
with patch("extensions.ext_redis.dify_config") as mock_config:
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
wrapper.delete("key:a", "key:b")
mock_client.delete.assert_called_once_with("enterprise-a:key:a", "enterprise-a:key:b")
def test_wrapper_lock_prefixes_lock_name(self):
mock_client = MagicMock()
wrapper = RedisClientWrapper()
wrapper.initialize(mock_client)
with patch("extensions.ext_redis.dify_config") as mock_config:
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
wrapper.lock("resource-lock", timeout=10)
mock_client.lock.assert_called_once()
args, kwargs = mock_client.lock.call_args
assert args == ("enterprise-a:resource-lock",)
assert kwargs["timeout"] == 10
def test_wrapper_hash_operations_prefix_key_name(self):
mock_client = MagicMock()
wrapper = RedisClientWrapper()
wrapper.initialize(mock_client)
with patch("extensions.ext_redis.dify_config") as mock_config:
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
wrapper.hset("hash:key", "field", "value")
wrapper.hgetall("hash:key")
mock_client.hset.assert_called_once_with("enterprise-a:hash:key", "field", "value")
mock_client.hgetall.assert_called_once_with("enterprise-a:hash:key")
def test_wrapper_zadd_prefixes_sorted_set_name(self):
mock_client = MagicMock()
wrapper = RedisClientWrapper()
wrapper.initialize(mock_client)
with patch("extensions.ext_redis.dify_config") as mock_config:
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
wrapper.zadd("zset:key", {"member": 1})
mock_client.zadd.assert_called_once()
args, kwargs = mock_client.zadd.call_args
assert args == ("enterprise-a:zset:key", {"member": 1})
assert kwargs["nx"] is False
def test_wrapper_preserves_keys_when_prefix_is_empty(self):
mock_client = MagicMock()
wrapper = RedisClientWrapper()
wrapper.initialize(mock_client)
with patch("extensions.ext_redis.dify_config") as mock_config:
mock_config.REDIS_KEY_PREFIX = " "
wrapper.get("plain:key")
mock_client.get.assert_called_once_with("plain:key")

View File

@@ -139,6 +139,28 @@ class TestTopic:
mock_redis_client.publish.assert_called_once_with("test-topic", payload)
def test_publish_prefixes_regular_topic(self, mock_redis_client: MagicMock):
with patch("extensions.redis_names.dify_config") as mock_config:
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
topic = Topic(mock_redis_client, "test-topic")
topic.publish(b"test message")
mock_redis_client.publish.assert_called_once_with("enterprise-a:test-topic", b"test message")
def test_subscribe_prefixes_regular_topic(self, mock_redis_client: MagicMock):
with patch("extensions.redis_names.dify_config") as mock_config:
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
topic = Topic(mock_redis_client, "test-topic")
subscription = topic.subscribe()
try:
subscription._start_if_needed()
finally:
subscription.close()
mock_redis_client.pubsub.return_value.subscribe.assert_called_once_with("enterprise-a:test-topic")
class TestShardedTopic:
"""Test cases for the ShardedTopic class."""
@@ -176,6 +198,15 @@ class TestShardedTopic:
mock_redis_client.spublish.assert_called_once_with("test-sharded-topic", payload)
def test_publish_prefixes_sharded_topic(self, mock_redis_client: MagicMock):
with patch("extensions.redis_names.dify_config") as mock_config:
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
sharded_topic = ShardedTopic(mock_redis_client, "test-sharded-topic")
sharded_topic.publish(b"test sharded message")
mock_redis_client.spublish.assert_called_once_with("enterprise-a:test-sharded-topic", b"test sharded message")
def test_subscribe_returns_sharded_subscription(self, sharded_topic: ShardedTopic, mock_redis_client: MagicMock):
"""Test that subscribe() returns a _RedisShardedSubscription instance."""
subscription = sharded_topic.subscribe()
@@ -185,6 +216,19 @@ class TestShardedTopic:
assert subscription._pubsub is mock_redis_client.pubsub.return_value
assert subscription._topic == "test-sharded-topic"
def test_subscribe_prefixes_sharded_topic(self, mock_redis_client: MagicMock):
with patch("extensions.redis_names.dify_config") as mock_config:
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
sharded_topic = ShardedTopic(mock_redis_client, "test-sharded-topic")
subscription = sharded_topic.subscribe()
try:
subscription._start_if_needed()
finally:
subscription.close()
mock_redis_client.pubsub.return_value.ssubscribe.assert_called_once_with("enterprise-a:test-sharded-topic")
@dataclasses.dataclass(frozen=True)
class SubscriptionTestCase:

View File

@@ -2,6 +2,7 @@ import threading
import time
from dataclasses import dataclass
from typing import cast
from unittest.mock import patch
import pytest
@@ -150,6 +151,25 @@ class TestStreamsBroadcastChannel:
# Expire called after publish
assert fake_redis._expire_calls.get("stream:beta", 0) >= 1
def test_topic_uses_prefixed_stream_key(self, fake_redis: FakeStreamsRedis):
with patch("extensions.redis_names.dify_config") as mock_config:
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
topic = StreamsBroadcastChannel(fake_redis, retention_seconds=60).topic("alpha")
assert topic._topic == "alpha"
assert topic._key == "enterprise-a:stream:alpha"
def test_publish_uses_prefixed_stream_key(self, fake_redis: FakeStreamsRedis):
with patch("extensions.redis_names.dify_config") as mock_config:
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
topic = StreamsBroadcastChannel(fake_redis, retention_seconds=60).topic("beta")
topic.publish(b"hello")
assert fake_redis._store["enterprise-a:stream:beta"][0][1] == {b"data": b"hello"}
assert fake_redis._expire_calls.get("enterprise-a:stream:beta", 0) >= 1
def test_topic_exposes_self_as_producer_and_subscriber(self, streams_channel: StreamsBroadcastChannel):
topic = streams_channel.topic("producer-subscriber")

View File

@@ -351,6 +351,9 @@ REDIS_SSL_CERTFILE=
REDIS_SSL_KEYFILE=
# Path to client private key file for SSL authentication
REDIS_DB=0
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
# Leave empty to preserve current unprefixed behavior.
REDIS_KEY_PREFIX=
# Optional: limit total Redis connections used by API/Worker (unset for default)
# Align with API's REDIS_MAX_CONNECTIONS in configs
REDIS_MAX_CONNECTIONS=

View File

@@ -88,6 +88,7 @@ The `.env.example` file provided in the Docker setup is extensive and covers a w
1. **Redis Configuration**:
- `REDIS_HOST`, `REDIS_PORT`, `REDIS_PASSWORD`: Redis server connection settings.
- `REDIS_KEY_PREFIX`: Optional global namespace prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
1. **Celery Configuration**:

View File

@@ -90,6 +90,7 @@ x-shared-env: &shared-api-worker-env
REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-}
REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-}
REDIS_DB: ${REDIS_DB:-0}
REDIS_KEY_PREFIX: ${REDIS_KEY_PREFIX:-}
REDIS_MAX_CONNECTIONS: ${REDIS_MAX_CONNECTIONS:-}
REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false}
REDIS_SENTINELS: ${REDIS_SENTINELS:-}

View File

@@ -78,6 +78,7 @@ describe('tool/tool-form/item', () => {
mockUseLanguage.mockReturnValue('en_US')
})
// Text input fields render their descriptions inline above the input.
it('should render text input labels and forward props to form input item', () => {
const handleChange = vi.fn()
const handleManageInputField = vi.fn()
@@ -121,6 +122,31 @@ describe('tool/tool-form/item', () => {
})
})
// URL fragments inside descriptions should be rendered as external links.
it('should render URLs in descriptions as external links', () => {
render(
<ToolFormItem
readOnly={false}
nodeId="tool-node"
schema={createSchema({
tooltip: {
en_US: 'Visit https://docs.dify.ai/tools for docs',
zh_Hans: 'Visit https://docs.dify.ai/tools for docs',
},
})}
value={{}}
onChange={vi.fn()}
/>,
)
const link = screen.getByRole('link', { name: 'https://docs.dify.ai/tools' })
expect(link).toHaveAttribute('href', 'https://docs.dify.ai/tools')
expect(link).toHaveAttribute('target', '_blank')
expect(link).toHaveAttribute('rel', 'noopener noreferrer')
expect(link.parentElement).toHaveTextContent('Visit https://docs.dify.ai/tools for docs')
})
// Non-text fields keep their descriptions inside the tooltip and support JSON schema preview.
it('should show tooltip for non-description fields and open the schema modal', () => {
const objectSchema = createSchema({
name: 'tool_config',

View File

@@ -1,5 +1,5 @@
'use client'
import type { FC } from 'react'
import type { FC, ReactNode } from 'react'
import type { ToolVarInputs } from '../../types'
import type { CredentialFormSchema } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { Tool } from '@/app/components/tools/types'
@@ -15,6 +15,45 @@ import { useLanguage } from '@/app/components/header/account-setting/model-provi
import { SchemaModal } from '@/app/components/plugins/plugin-detail-panel/tool-selector/components'
import FormInputItem from '@/app/components/workflow/nodes/_base/components/form-input-item'
const URL_REGEX = /(https?:\/\/\S+)/g
const renderDescriptionWithLinks = (description: string): ReactNode => {
const matches = [...description.matchAll(URL_REGEX)]
if (!matches.length)
return description
const parts: ReactNode[] = []
let currentIndex = 0
matches.forEach((match, index) => {
const [url] = match
const start = match.index ?? 0
if (start > currentIndex)
parts.push(description.slice(currentIndex, start))
parts.push(
<a
key={`${url}-${index}`}
href={url}
target="_blank"
rel="noopener noreferrer"
className="text-text-accent hover:underline"
>
{url}
</a>,
)
currentIndex = start + url.length
})
if (currentIndex < description.length)
parts.push(description.slice(currentIndex))
return parts
}
type Props = {
readOnly: boolean
nodeId: string
@@ -87,7 +126,9 @@ const ToolFormItem: FC<Props> = ({
)}
</div>
{showDescription && tooltip && (
<div className="body-xs-regular pb-0.5 text-text-tertiary">{tooltip[language] || tooltip.en_US}</div>
<div className="body-xs-regular break-words pb-0.5 text-text-tertiary">
{renderDescriptionWithLinks(tooltip[language] || tooltip.en_US)}
</div>
)}
</div>
<FormInputItem