mirror of
https://github.com/langgenius/dify.git
synced 2026-04-14 12:32:44 +00:00
Compare commits
3 Commits
feat/creat
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d4783e8c14 | ||
|
|
736880e046 | ||
|
|
bd7a9b5fcf |
@@ -57,6 +57,9 @@ REDIS_SSL_CERTFILE=
|
||||
REDIS_SSL_KEYFILE=
|
||||
# Path to client private key file for SSL authentication
|
||||
REDIS_DB=0
|
||||
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
|
||||
# Leave empty to preserve current unprefixed behavior.
|
||||
REDIS_KEY_PREFIX=
|
||||
|
||||
# redis Sentinel configuration.
|
||||
REDIS_USE_SENTINEL=false
|
||||
|
||||
5
api/configs/middleware/cache/redis_config.py
vendored
5
api/configs/middleware/cache/redis_config.py
vendored
@@ -32,6 +32,11 @@ class RedisConfig(BaseSettings):
|
||||
default=0,
|
||||
)
|
||||
|
||||
REDIS_KEY_PREFIX: str = Field(
|
||||
description="Optional global prefix for Redis keys, topics, and transport artifacts",
|
||||
default="",
|
||||
)
|
||||
|
||||
REDIS_USE_SSL: bool = Field(
|
||||
description="Enable SSL/TLS for the Redis connection",
|
||||
default=False,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
|
||||
@@ -11,7 +11,7 @@ from core.agent.entities import AgentScratchpadUnit
|
||||
class CotAgentOutputParser:
|
||||
@classmethod
|
||||
def handle_react_stream_output(
|
||||
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
|
||||
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
|
||||
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
|
||||
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
|
||||
action_name = None
|
||||
|
||||
@@ -254,7 +254,7 @@ def resolve_dify_schema_refs(
|
||||
return resolver.resolve(schema)
|
||||
|
||||
|
||||
def _remove_metadata_fields(schema: dict) -> dict:
|
||||
def _remove_metadata_fields(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Remove metadata fields from schema that shouldn't be included in resolved output
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from extensions.redis_names import normalize_redis_key_prefix
|
||||
|
||||
|
||||
class _CelerySentinelKwargsDict(TypedDict):
|
||||
@@ -16,9 +17,10 @@ class _CelerySentinelKwargsDict(TypedDict):
|
||||
password: str | None
|
||||
|
||||
|
||||
class CelerySentinelTransportDict(TypedDict):
|
||||
class CelerySentinelTransportDict(TypedDict, total=False):
|
||||
master_name: str | None
|
||||
sentinel_kwargs: _CelerySentinelKwargsDict
|
||||
global_keyprefix: str
|
||||
|
||||
|
||||
class CelerySSLOptionsDict(TypedDict):
|
||||
@@ -61,15 +63,31 @@ def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
|
||||
|
||||
def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
|
||||
"""Get broker transport options (e.g. Redis Sentinel) for Celery connections."""
|
||||
transport_options: CelerySentinelTransportDict | dict[str, Any]
|
||||
if dify_config.CELERY_USE_SENTINEL:
|
||||
return CelerySentinelTransportDict(
|
||||
transport_options = CelerySentinelTransportDict(
|
||||
master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
|
||||
sentinel_kwargs=_CelerySentinelKwargsDict(
|
||||
socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
|
||||
password=dify_config.CELERY_SENTINEL_PASSWORD,
|
||||
),
|
||||
)
|
||||
return {}
|
||||
else:
|
||||
transport_options = {}
|
||||
|
||||
global_keyprefix = get_celery_redis_global_keyprefix()
|
||||
if global_keyprefix:
|
||||
transport_options["global_keyprefix"] = global_keyprefix
|
||||
|
||||
return transport_options
|
||||
|
||||
|
||||
def get_celery_redis_global_keyprefix() -> str | None:
|
||||
"""Return the Redis transport prefix for Celery when namespace isolation is enabled."""
|
||||
normalized_prefix = normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX)
|
||||
if not normalized_prefix:
|
||||
return None
|
||||
return f"{normalized_prefix}:"
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> Celery:
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import ssl
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
from typing import Any, Union, cast
|
||||
|
||||
import redis
|
||||
from redis import RedisError
|
||||
@@ -18,17 +18,26 @@ from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from extensions.redis_names import (
|
||||
normalize_redis_key_prefix,
|
||||
serialize_redis_name,
|
||||
serialize_redis_name_arg,
|
||||
serialize_redis_name_args,
|
||||
)
|
||||
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.lock import Lock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_normalize_redis_key_prefix = normalize_redis_key_prefix
|
||||
_serialize_redis_name = serialize_redis_name
|
||||
_serialize_redis_name_arg = serialize_redis_name_arg
|
||||
_serialize_redis_name_args = serialize_redis_name_args
|
||||
|
||||
|
||||
class RedisClientWrapper:
|
||||
"""
|
||||
A wrapper class for the Redis client that addresses the issue where the global
|
||||
@@ -59,68 +68,148 @@ class RedisClientWrapper:
|
||||
if self._client is None:
|
||||
self._client = client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Type hints for IDE support and static analysis
|
||||
# These are not executed at runtime but provide type information
|
||||
def get(self, name: str | bytes) -> Any: ...
|
||||
|
||||
def set(
|
||||
self,
|
||||
name: str | bytes,
|
||||
value: Any,
|
||||
ex: int | None = None,
|
||||
px: int | None = None,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
keepttl: bool = False,
|
||||
get: bool = False,
|
||||
exat: int | None = None,
|
||||
pxat: int | None = None,
|
||||
) -> Any: ...
|
||||
|
||||
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ...
|
||||
def setnx(self, name: str | bytes, value: Any) -> Any: ...
|
||||
def delete(self, *names: str | bytes) -> Any: ...
|
||||
def incr(self, name: str | bytes, amount: int = 1) -> Any: ...
|
||||
def expire(
|
||||
self,
|
||||
name: str | bytes,
|
||||
time: int | timedelta,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
gt: bool = False,
|
||||
lt: bool = False,
|
||||
) -> Any: ...
|
||||
def lock(
|
||||
self,
|
||||
name: str,
|
||||
timeout: float | None = None,
|
||||
sleep: float = 0.1,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
thread_local: bool = True,
|
||||
) -> Lock: ...
|
||||
def zadd(
|
||||
self,
|
||||
name: str | bytes,
|
||||
mapping: dict[str | bytes | int | float, float | int | str | bytes],
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
ch: bool = False,
|
||||
incr: bool = False,
|
||||
gt: bool = False,
|
||||
lt: bool = False,
|
||||
) -> Any: ...
|
||||
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
|
||||
def zcard(self, name: str | bytes) -> Any: ...
|
||||
def getdel(self, name: str | bytes) -> Any: ...
|
||||
def pubsub(self) -> PubSub: ...
|
||||
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ...
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
def _require_client(self) -> redis.Redis | RedisCluster:
|
||||
if self._client is None:
|
||||
raise RuntimeError("Redis client is not initialized. Call init_app first.")
|
||||
return getattr(self._client, item)
|
||||
return self._client
|
||||
|
||||
def _get_prefix(self) -> str:
|
||||
return dify_config.REDIS_KEY_PREFIX
|
||||
|
||||
def get(self, name: str | bytes) -> Any:
|
||||
return self._require_client().get(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def set(
|
||||
self,
|
||||
name: str | bytes,
|
||||
value: Any,
|
||||
ex: int | None = None,
|
||||
px: int | None = None,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
keepttl: bool = False,
|
||||
get: bool = False,
|
||||
exat: int | None = None,
|
||||
pxat: int | None = None,
|
||||
) -> Any:
|
||||
return self._require_client().set(
|
||||
_serialize_redis_name_arg(name, self._get_prefix()),
|
||||
value,
|
||||
ex=ex,
|
||||
px=px,
|
||||
nx=nx,
|
||||
xx=xx,
|
||||
keepttl=keepttl,
|
||||
get=get,
|
||||
exat=exat,
|
||||
pxat=pxat,
|
||||
)
|
||||
|
||||
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any:
|
||||
return self._require_client().setex(_serialize_redis_name_arg(name, self._get_prefix()), time, value)
|
||||
|
||||
def setnx(self, name: str | bytes, value: Any) -> Any:
|
||||
return self._require_client().setnx(_serialize_redis_name_arg(name, self._get_prefix()), value)
|
||||
|
||||
def delete(self, *names: str | bytes) -> Any:
|
||||
return self._require_client().delete(*_serialize_redis_name_args(names, self._get_prefix()))
|
||||
|
||||
def incr(self, name: str | bytes, amount: int = 1) -> Any:
|
||||
return self._require_client().incr(_serialize_redis_name_arg(name, self._get_prefix()), amount)
|
||||
|
||||
def expire(
|
||||
self,
|
||||
name: str | bytes,
|
||||
time: int | timedelta,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
gt: bool = False,
|
||||
lt: bool = False,
|
||||
) -> Any:
|
||||
return self._require_client().expire(
|
||||
_serialize_redis_name_arg(name, self._get_prefix()),
|
||||
time,
|
||||
nx=nx,
|
||||
xx=xx,
|
||||
gt=gt,
|
||||
lt=lt,
|
||||
)
|
||||
|
||||
def exists(self, *names: str | bytes) -> Any:
|
||||
return self._require_client().exists(*_serialize_redis_name_args(names, self._get_prefix()))
|
||||
|
||||
def ttl(self, name: str | bytes) -> Any:
|
||||
return self._require_client().ttl(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def getdel(self, name: str | bytes) -> Any:
|
||||
return self._require_client().getdel(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def lock(
|
||||
self,
|
||||
name: str,
|
||||
timeout: float | None = None,
|
||||
sleep: float = 0.1,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
thread_local: bool = True,
|
||||
) -> Any:
|
||||
return self._require_client().lock(
|
||||
_serialize_redis_name(name, self._get_prefix()),
|
||||
timeout=timeout,
|
||||
sleep=sleep,
|
||||
blocking=blocking,
|
||||
blocking_timeout=blocking_timeout,
|
||||
thread_local=thread_local,
|
||||
)
|
||||
|
||||
def hset(self, name: str | bytes, *args: Any, **kwargs: Any) -> Any:
|
||||
return self._require_client().hset(_serialize_redis_name_arg(name, self._get_prefix()), *args, **kwargs)
|
||||
|
||||
def hgetall(self, name: str | bytes) -> Any:
|
||||
return self._require_client().hgetall(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def hdel(self, name: str | bytes, *keys: str | bytes) -> Any:
|
||||
return self._require_client().hdel(_serialize_redis_name_arg(name, self._get_prefix()), *keys)
|
||||
|
||||
def hlen(self, name: str | bytes) -> Any:
|
||||
return self._require_client().hlen(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def zadd(
|
||||
self,
|
||||
name: str | bytes,
|
||||
mapping: dict[str | bytes | int | float, float | int | str | bytes],
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
ch: bool = False,
|
||||
incr: bool = False,
|
||||
gt: bool = False,
|
||||
lt: bool = False,
|
||||
) -> Any:
|
||||
return self._require_client().zadd(
|
||||
_serialize_redis_name_arg(name, self._get_prefix()),
|
||||
cast(Any, mapping),
|
||||
nx=nx,
|
||||
xx=xx,
|
||||
ch=ch,
|
||||
incr=incr,
|
||||
gt=gt,
|
||||
lt=lt,
|
||||
)
|
||||
|
||||
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any:
|
||||
return self._require_client().zremrangebyscore(_serialize_redis_name_arg(name, self._get_prefix()), min, max)
|
||||
|
||||
def zcard(self, name: str | bytes) -> Any:
|
||||
return self._require_client().zcard(_serialize_redis_name_arg(name, self._get_prefix()))
|
||||
|
||||
def pubsub(self) -> PubSub:
|
||||
return self._require_client().pubsub()
|
||||
|
||||
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any:
|
||||
return self._require_client().pipeline(transaction=transaction, shard_hint=shard_hint)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
return getattr(self._require_client(), item)
|
||||
|
||||
|
||||
redis_client: RedisClientWrapper = RedisClientWrapper()
|
||||
|
||||
32
api/extensions/redis_names.py
Normal file
32
api/extensions/redis_names.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def normalize_redis_key_prefix(prefix: str | None) -> str:
|
||||
"""Normalize the configured Redis key prefix for consistent runtime use."""
|
||||
if prefix is None:
|
||||
return ""
|
||||
return prefix.strip()
|
||||
|
||||
|
||||
def get_redis_key_prefix() -> str:
|
||||
"""Read and normalize the current Redis key prefix from config."""
|
||||
return normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX)
|
||||
|
||||
|
||||
def serialize_redis_name(name: str, prefix: str | None = None) -> str:
|
||||
"""Convert a logical Redis name into the physical name used in Redis."""
|
||||
normalized_prefix = get_redis_key_prefix() if prefix is None else normalize_redis_key_prefix(prefix)
|
||||
if not normalized_prefix:
|
||||
return name
|
||||
return f"{normalized_prefix}:{name}"
|
||||
|
||||
|
||||
def serialize_redis_name_arg(name: str | bytes, prefix: str | None = None) -> str | bytes:
|
||||
"""Prefix string Redis names while preserving bytes inputs unchanged."""
|
||||
if isinstance(name, bytes):
|
||||
return name
|
||||
return serialize_redis_name(name, prefix)
|
||||
|
||||
|
||||
def serialize_redis_name_args(names: tuple[str | bytes, ...], prefix: str | None = None) -> tuple[str | bytes, ...]:
|
||||
return tuple(serialize_redis_name_arg(name, prefix) for name in names)
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
@@ -32,12 +33,13 @@ class Topic:
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._redis_topic = serialize_redis_name(topic)
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
|
||||
def publish(self, payload: bytes) -> None:
|
||||
self._client.publish(self._topic, payload)
|
||||
self._client.publish(self._redis_topic, payload)
|
||||
|
||||
def as_subscriber(self) -> Subscriber:
|
||||
return self
|
||||
@@ -46,7 +48,7 @@ class Topic:
|
||||
return _RedisSubscription(
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._topic,
|
||||
topic=self._redis_topic,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
@@ -30,12 +31,13 @@ class ShardedTopic:
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._redis_topic = serialize_redis_name(topic)
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
|
||||
def publish(self, payload: bytes) -> None:
|
||||
self._client.spublish(self._topic, payload) # type: ignore[attr-defined,union-attr]
|
||||
self._client.spublish(self._redis_topic, payload) # type: ignore[attr-defined,union-attr]
|
||||
|
||||
def as_subscriber(self) -> Subscriber:
|
||||
return self
|
||||
@@ -44,7 +46,7 @@ class ShardedTopic:
|
||||
return _RedisShardedSubscription(
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._topic,
|
||||
topic=self._redis_topic,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import threading
|
||||
from collections.abc import Iterator
|
||||
from typing import Self
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from redis import Redis, RedisCluster
|
||||
@@ -35,7 +36,7 @@ class StreamsTopic:
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._key = f"stream:{topic}"
|
||||
self._key = serialize_redis_name(f"stream:{topic}")
|
||||
self._retention_seconds = retention_seconds
|
||||
self.max_length = 5000
|
||||
|
||||
|
||||
@@ -103,7 +103,10 @@ class DbMigrationAutoRenewLock:
|
||||
timeout=self._ttl_seconds,
|
||||
thread_local=False,
|
||||
)
|
||||
acquired = bool(self._lock.acquire(*args, **kwargs))
|
||||
lock = self._lock
|
||||
if lock is None:
|
||||
raise RuntimeError("Redis lock initialization failed.")
|
||||
acquired = bool(lock.acquire(*args, **kwargs))
|
||||
self._acquired = acquired
|
||||
if acquired:
|
||||
self._start_heartbeat()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule
|
||||
|
||||
@@ -168,7 +169,9 @@ class ModelProviderService:
|
||||
model_name=model,
|
||||
)
|
||||
|
||||
def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None:
|
||||
def get_provider_credential(
|
||||
self, tenant_id: str, provider: str, credential_id: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
get provider credentials.
|
||||
|
||||
@@ -180,7 +183,7 @@ class ModelProviderService:
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
return provider_configuration.get_provider_credential(credential_id=credential_id)
|
||||
|
||||
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
|
||||
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate provider credentials before saving.
|
||||
|
||||
@@ -192,7 +195,7 @@ class ModelProviderService:
|
||||
provider_configuration.validate_provider_credentials(credentials)
|
||||
|
||||
def create_provider_credential(
|
||||
self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
|
||||
self, tenant_id: str, provider: str, credentials: dict[str, Any], credential_name: str | None
|
||||
) -> None:
|
||||
"""
|
||||
Create and save new provider credentials.
|
||||
@@ -210,7 +213,7 @@ class ModelProviderService:
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
credential_id: str,
|
||||
credential_name: str | None,
|
||||
) -> None:
|
||||
@@ -254,7 +257,7 @@ class ModelProviderService:
|
||||
|
||||
def get_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
|
||||
) -> dict | None:
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Retrieve model-specific credentials.
|
||||
|
||||
@@ -270,7 +273,9 @@ class ModelProviderService:
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict):
|
||||
def validate_model_credentials(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict[str, Any]
|
||||
):
|
||||
"""
|
||||
validate model credentials.
|
||||
|
||||
@@ -287,7 +292,13 @@ class ModelProviderService:
|
||||
)
|
||||
|
||||
def create_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
credential_name: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
create and save model credentials.
|
||||
@@ -314,7 +325,7 @@ class ModelProviderService:
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
credentials: dict[str, Any],
|
||||
credential_id: str,
|
||||
credential_name: str | None,
|
||||
) -> None:
|
||||
|
||||
@@ -33,6 +33,7 @@ REDIS_USERNAME=
|
||||
REDIS_PASSWORD=difyai123456
|
||||
REDIS_USE_SSL=false
|
||||
REDIS_DB=0
|
||||
REDIS_KEY_PREFIX=
|
||||
|
||||
# PostgreSQL database configuration
|
||||
DB_USERNAME=postgres
|
||||
|
||||
@@ -236,6 +236,41 @@ def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest.
|
||||
_ = DifyConfig().normalized_pubsub_redis_url
|
||||
|
||||
|
||||
def test_dify_config_exposes_redis_key_prefix_default(monkeypatch: pytest.MonkeyPatch):
|
||||
os.environ.clear()
|
||||
|
||||
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
|
||||
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
|
||||
monkeypatch.setenv("DB_TYPE", "postgresql")
|
||||
monkeypatch.setenv("DB_USERNAME", "postgres")
|
||||
monkeypatch.setenv("DB_PASSWORD", "postgres")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
monkeypatch.setenv("DB_PORT", "5432")
|
||||
monkeypatch.setenv("DB_DATABASE", "dify")
|
||||
|
||||
config = DifyConfig(_env_file=None)
|
||||
|
||||
assert config.REDIS_KEY_PREFIX == ""
|
||||
|
||||
|
||||
def test_dify_config_reads_redis_key_prefix_from_env(monkeypatch: pytest.MonkeyPatch):
|
||||
os.environ.clear()
|
||||
|
||||
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
|
||||
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
|
||||
monkeypatch.setenv("DB_TYPE", "postgresql")
|
||||
monkeypatch.setenv("DB_USERNAME", "postgres")
|
||||
monkeypatch.setenv("DB_PASSWORD", "postgres")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
monkeypatch.setenv("DB_PORT", "5432")
|
||||
monkeypatch.setenv("DB_DATABASE", "dify")
|
||||
monkeypatch.setenv("REDIS_KEY_PREFIX", "enterprise-a")
|
||||
|
||||
config = DifyConfig(_env_file=None)
|
||||
|
||||
assert config.REDIS_KEY_PREFIX == "enterprise-a"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"),
|
||||
[
|
||||
|
||||
@@ -7,6 +7,47 @@ from unittest.mock import MagicMock, patch
|
||||
class TestCelerySSLConfiguration:
|
||||
"""Test suite for Celery SSL configuration."""
|
||||
|
||||
def test_get_celery_broker_transport_options_includes_global_keyprefix_for_redis(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.CELERY_USE_SENTINEL = False
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from extensions.ext_celery import get_celery_broker_transport_options
|
||||
|
||||
result = get_celery_broker_transport_options()
|
||||
|
||||
assert result["global_keyprefix"] == "enterprise-a:"
|
||||
|
||||
def test_get_celery_broker_transport_options_omits_global_keyprefix_when_prefix_empty(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.CELERY_USE_SENTINEL = False
|
||||
mock_config.REDIS_KEY_PREFIX = " "
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from extensions.ext_celery import get_celery_broker_transport_options
|
||||
|
||||
result = get_celery_broker_transport_options()
|
||||
|
||||
assert "global_keyprefix" not in result
|
||||
|
||||
def test_get_celery_broker_transport_options_keeps_sentinel_and_adds_global_keyprefix(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.CELERY_USE_SENTINEL = True
|
||||
mock_config.CELERY_SENTINEL_MASTER_NAME = "mymaster"
|
||||
mock_config.CELERY_SENTINEL_SOCKET_TIMEOUT = 0.1
|
||||
mock_config.CELERY_SENTINEL_PASSWORD = "secret"
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from extensions.ext_celery import get_celery_broker_transport_options
|
||||
|
||||
result = get_celery_broker_transport_options()
|
||||
|
||||
assert result["master_name"] == "mymaster"
|
||||
assert result["sentinel_kwargs"]["password"] == "secret"
|
||||
assert result["global_keyprefix"] == "enterprise-a:"
|
||||
|
||||
def test_get_celery_ssl_options_when_ssl_disabled(self):
|
||||
"""Test SSL options when BROKER_USE_SSL is False."""
|
||||
from configs import DifyConfig
|
||||
@@ -151,3 +192,49 @@ class TestCelerySSLConfiguration:
|
||||
# Check that SSL is also applied to Redis backend
|
||||
assert "redis_backend_use_ssl" in celery_app.conf
|
||||
assert celery_app.conf["redis_backend_use_ssl"] is not None
|
||||
|
||||
def test_celery_init_applies_global_keyprefix_to_broker_and_backend_transport(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.BROKER_USE_SSL = False
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
mock_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL = 1
|
||||
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
|
||||
mock_config.CELERY_BACKEND = "redis"
|
||||
mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0"
|
||||
mock_config.CELERY_USE_SENTINEL = False
|
||||
mock_config.LOG_FORMAT = "%(message)s"
|
||||
mock_config.LOG_TZ = "UTC"
|
||||
mock_config.LOG_FILE = None
|
||||
mock_config.CELERY_TASK_ANNOTATIONS = {}
|
||||
|
||||
mock_config.CELERY_BEAT_SCHEDULER_TIME = 1
|
||||
mock_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK = False
|
||||
mock_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK = False
|
||||
mock_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK = False
|
||||
mock_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK = False
|
||||
mock_config.ENABLE_CLEAN_MESSAGES = False
|
||||
mock_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK = False
|
||||
mock_config.ENABLE_DATASETS_QUEUE_MONITOR = False
|
||||
mock_config.ENABLE_HUMAN_INPUT_TIMEOUT_TASK = False
|
||||
mock_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK = False
|
||||
mock_config.MARKETPLACE_ENABLED = False
|
||||
mock_config.WORKFLOW_LOG_CLEANUP_ENABLED = False
|
||||
mock_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK = False
|
||||
mock_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK = False
|
||||
mock_config.WORKFLOW_SCHEDULE_POLLER_INTERVAL = 1
|
||||
mock_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK = False
|
||||
mock_config.TRIGGER_PROVIDER_REFRESH_INTERVAL = 15
|
||||
mock_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK = False
|
||||
mock_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL = 30
|
||||
mock_config.ENTERPRISE_ENABLED = False
|
||||
mock_config.ENTERPRISE_TELEMETRY_ENABLED = False
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_celery import init_app
|
||||
|
||||
app = DifyApp(__name__)
|
||||
celery_app = init_app(app)
|
||||
|
||||
assert celery_app.conf["broker_transport_options"]["global_keyprefix"] == "enterprise-a:"
|
||||
assert celery_app.conf["result_backend_transport_options"]["global_keyprefix"] == "enterprise-a:"
|
||||
|
||||
@@ -6,6 +6,7 @@ from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastCh
|
||||
|
||||
def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch):
|
||||
monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub")
|
||||
monkeypatch.setattr(ext_redis, "_pubsub_redis_client", object())
|
||||
|
||||
channel = ext_redis.get_pubsub_broadcast_channel()
|
||||
|
||||
@@ -14,6 +15,7 @@ def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch):
|
||||
|
||||
def test_get_pubsub_broadcast_channel_sharded(monkeypatch):
|
||||
monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded")
|
||||
monkeypatch.setattr(ext_redis, "_pubsub_redis_client", object())
|
||||
|
||||
channel = ext_redis.get_pubsub_broadcast_channel()
|
||||
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from redis import RedisError
|
||||
from redis.retry import Retry
|
||||
|
||||
from extensions.ext_redis import (
|
||||
RedisClientWrapper,
|
||||
_get_base_redis_params,
|
||||
_get_cluster_connection_health_params,
|
||||
_get_connection_health_params,
|
||||
_normalize_redis_key_prefix,
|
||||
_serialize_redis_name,
|
||||
redis_fallback,
|
||||
)
|
||||
|
||||
@@ -123,3 +126,99 @@ class TestRedisFallback:
|
||||
|
||||
assert test_func.__name__ == "test_func"
|
||||
assert test_func.__doc__ == "Test function docstring"
|
||||
|
||||
|
||||
class TestRedisKeyPrefixHelpers:
|
||||
def test_normalize_redis_key_prefix_trims_whitespace(self):
|
||||
assert _normalize_redis_key_prefix(" enterprise-a ") == "enterprise-a"
|
||||
|
||||
def test_normalize_redis_key_prefix_treats_whitespace_only_as_empty(self):
|
||||
assert _normalize_redis_key_prefix(" ") == ""
|
||||
|
||||
def test_serialize_redis_name_returns_original_when_prefix_empty(self):
|
||||
assert _serialize_redis_name("model_lb_index:test", "") == "model_lb_index:test"
|
||||
|
||||
def test_serialize_redis_name_adds_single_colon_separator(self):
|
||||
assert _serialize_redis_name("model_lb_index:test", "enterprise-a") == "enterprise-a:model_lb_index:test"
|
||||
|
||||
|
||||
class TestRedisClientWrapperKeyPrefix:
|
||||
def test_wrapper_get_prefixes_string_keys(self):
|
||||
mock_client = MagicMock()
|
||||
wrapper = RedisClientWrapper()
|
||||
wrapper.initialize(mock_client)
|
||||
|
||||
with patch("extensions.ext_redis.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
wrapper.get("oauth_state:abc")
|
||||
|
||||
mock_client.get.assert_called_once_with("enterprise-a:oauth_state:abc")
|
||||
|
||||
def test_wrapper_delete_prefixes_multiple_keys(self):
|
||||
mock_client = MagicMock()
|
||||
wrapper = RedisClientWrapper()
|
||||
wrapper.initialize(mock_client)
|
||||
|
||||
with patch("extensions.ext_redis.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
wrapper.delete("key:a", "key:b")
|
||||
|
||||
mock_client.delete.assert_called_once_with("enterprise-a:key:a", "enterprise-a:key:b")
|
||||
|
||||
def test_wrapper_lock_prefixes_lock_name(self):
|
||||
mock_client = MagicMock()
|
||||
wrapper = RedisClientWrapper()
|
||||
wrapper.initialize(mock_client)
|
||||
|
||||
with patch("extensions.ext_redis.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
wrapper.lock("resource-lock", timeout=10)
|
||||
|
||||
mock_client.lock.assert_called_once()
|
||||
args, kwargs = mock_client.lock.call_args
|
||||
assert args == ("enterprise-a:resource-lock",)
|
||||
assert kwargs["timeout"] == 10
|
||||
|
||||
def test_wrapper_hash_operations_prefix_key_name(self):
|
||||
mock_client = MagicMock()
|
||||
wrapper = RedisClientWrapper()
|
||||
wrapper.initialize(mock_client)
|
||||
|
||||
with patch("extensions.ext_redis.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
wrapper.hset("hash:key", "field", "value")
|
||||
wrapper.hgetall("hash:key")
|
||||
|
||||
mock_client.hset.assert_called_once_with("enterprise-a:hash:key", "field", "value")
|
||||
mock_client.hgetall.assert_called_once_with("enterprise-a:hash:key")
|
||||
|
||||
def test_wrapper_zadd_prefixes_sorted_set_name(self):
|
||||
mock_client = MagicMock()
|
||||
wrapper = RedisClientWrapper()
|
||||
wrapper.initialize(mock_client)
|
||||
|
||||
with patch("extensions.ext_redis.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
wrapper.zadd("zset:key", {"member": 1})
|
||||
|
||||
mock_client.zadd.assert_called_once()
|
||||
args, kwargs = mock_client.zadd.call_args
|
||||
assert args == ("enterprise-a:zset:key", {"member": 1})
|
||||
assert kwargs["nx"] is False
|
||||
|
||||
def test_wrapper_preserves_keys_when_prefix_is_empty(self):
|
||||
mock_client = MagicMock()
|
||||
wrapper = RedisClientWrapper()
|
||||
wrapper.initialize(mock_client)
|
||||
|
||||
with patch("extensions.ext_redis.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = " "
|
||||
|
||||
wrapper.get("plain:key")
|
||||
|
||||
mock_client.get.assert_called_once_with("plain:key")
|
||||
|
||||
@@ -139,6 +139,28 @@ class TestTopic:
|
||||
|
||||
mock_redis_client.publish.assert_called_once_with("test-topic", payload)
|
||||
|
||||
def test_publish_prefixes_regular_topic(self, mock_redis_client: MagicMock):
|
||||
with patch("extensions.redis_names.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
topic = Topic(mock_redis_client, "test-topic")
|
||||
|
||||
topic.publish(b"test message")
|
||||
|
||||
mock_redis_client.publish.assert_called_once_with("enterprise-a:test-topic", b"test message")
|
||||
|
||||
def test_subscribe_prefixes_regular_topic(self, mock_redis_client: MagicMock):
|
||||
with patch("extensions.redis_names.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
topic = Topic(mock_redis_client, "test-topic")
|
||||
|
||||
subscription = topic.subscribe()
|
||||
try:
|
||||
subscription._start_if_needed()
|
||||
finally:
|
||||
subscription.close()
|
||||
|
||||
mock_redis_client.pubsub.return_value.subscribe.assert_called_once_with("enterprise-a:test-topic")
|
||||
|
||||
|
||||
class TestShardedTopic:
|
||||
"""Test cases for the ShardedTopic class."""
|
||||
@@ -176,6 +198,15 @@ class TestShardedTopic:
|
||||
|
||||
mock_redis_client.spublish.assert_called_once_with("test-sharded-topic", payload)
|
||||
|
||||
def test_publish_prefixes_sharded_topic(self, mock_redis_client: MagicMock):
|
||||
with patch("extensions.redis_names.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
sharded_topic = ShardedTopic(mock_redis_client, "test-sharded-topic")
|
||||
|
||||
sharded_topic.publish(b"test sharded message")
|
||||
|
||||
mock_redis_client.spublish.assert_called_once_with("enterprise-a:test-sharded-topic", b"test sharded message")
|
||||
|
||||
def test_subscribe_returns_sharded_subscription(self, sharded_topic: ShardedTopic, mock_redis_client: MagicMock):
|
||||
"""Test that subscribe() returns a _RedisShardedSubscription instance."""
|
||||
subscription = sharded_topic.subscribe()
|
||||
@@ -185,6 +216,19 @@ class TestShardedTopic:
|
||||
assert subscription._pubsub is mock_redis_client.pubsub.return_value
|
||||
assert subscription._topic == "test-sharded-topic"
|
||||
|
||||
def test_subscribe_prefixes_sharded_topic(self, mock_redis_client: MagicMock):
|
||||
with patch("extensions.redis_names.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
sharded_topic = ShardedTopic(mock_redis_client, "test-sharded-topic")
|
||||
|
||||
subscription = sharded_topic.subscribe()
|
||||
try:
|
||||
subscription._start_if_needed()
|
||||
finally:
|
||||
subscription.close()
|
||||
|
||||
mock_redis_client.pubsub.return_value.ssubscribe.assert_called_once_with("enterprise-a:test-sharded-topic")
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class SubscriptionTestCase:
|
||||
|
||||
@@ -2,6 +2,7 @@ import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -150,6 +151,25 @@ class TestStreamsBroadcastChannel:
|
||||
# Expire called after publish
|
||||
assert fake_redis._expire_calls.get("stream:beta", 0) >= 1
|
||||
|
||||
def test_topic_uses_prefixed_stream_key(self, fake_redis: FakeStreamsRedis):
|
||||
with patch("extensions.redis_names.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
|
||||
topic = StreamsBroadcastChannel(fake_redis, retention_seconds=60).topic("alpha")
|
||||
|
||||
assert topic._topic == "alpha"
|
||||
assert topic._key == "enterprise-a:stream:alpha"
|
||||
|
||||
def test_publish_uses_prefixed_stream_key(self, fake_redis: FakeStreamsRedis):
|
||||
with patch("extensions.redis_names.dify_config") as mock_config:
|
||||
mock_config.REDIS_KEY_PREFIX = "enterprise-a"
|
||||
topic = StreamsBroadcastChannel(fake_redis, retention_seconds=60).topic("beta")
|
||||
|
||||
topic.publish(b"hello")
|
||||
|
||||
assert fake_redis._store["enterprise-a:stream:beta"][0][1] == {b"data": b"hello"}
|
||||
assert fake_redis._expire_calls.get("enterprise-a:stream:beta", 0) >= 1
|
||||
|
||||
def test_topic_exposes_self_as_producer_and_subscriber(self, streams_channel: StreamsBroadcastChannel):
|
||||
topic = streams_channel.topic("producer-subscriber")
|
||||
|
||||
|
||||
@@ -351,6 +351,9 @@ REDIS_SSL_CERTFILE=
|
||||
REDIS_SSL_KEYFILE=
|
||||
# Path to client private key file for SSL authentication
|
||||
REDIS_DB=0
|
||||
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
|
||||
# Leave empty to preserve current unprefixed behavior.
|
||||
REDIS_KEY_PREFIX=
|
||||
# Optional: limit total Redis connections used by API/Worker (unset for default)
|
||||
# Align with API's REDIS_MAX_CONNECTIONS in configs
|
||||
REDIS_MAX_CONNECTIONS=
|
||||
|
||||
@@ -88,6 +88,7 @@ The `.env.example` file provided in the Docker setup is extensive and covers a w
|
||||
1. **Redis Configuration**:
|
||||
|
||||
- `REDIS_HOST`, `REDIS_PORT`, `REDIS_PASSWORD`: Redis server connection settings.
|
||||
- `REDIS_KEY_PREFIX`: Optional global namespace prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
|
||||
|
||||
1. **Celery Configuration**:
|
||||
|
||||
|
||||
@@ -90,6 +90,7 @@ x-shared-env: &shared-api-worker-env
|
||||
REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-}
|
||||
REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-}
|
||||
REDIS_DB: ${REDIS_DB:-0}
|
||||
REDIS_KEY_PREFIX: ${REDIS_KEY_PREFIX:-}
|
||||
REDIS_MAX_CONNECTIONS: ${REDIS_MAX_CONNECTIONS:-}
|
||||
REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false}
|
||||
REDIS_SENTINELS: ${REDIS_SENTINELS:-}
|
||||
|
||||
@@ -78,6 +78,7 @@ describe('tool/tool-form/item', () => {
|
||||
mockUseLanguage.mockReturnValue('en_US')
|
||||
})
|
||||
|
||||
// Text input fields render their descriptions inline above the input.
|
||||
it('should render text input labels and forward props to form input item', () => {
|
||||
const handleChange = vi.fn()
|
||||
const handleManageInputField = vi.fn()
|
||||
@@ -121,6 +122,31 @@ describe('tool/tool-form/item', () => {
|
||||
})
|
||||
})
|
||||
|
||||
// URL fragments inside descriptions should be rendered as external links.
|
||||
it('should render URLs in descriptions as external links', () => {
|
||||
render(
|
||||
<ToolFormItem
|
||||
readOnly={false}
|
||||
nodeId="tool-node"
|
||||
schema={createSchema({
|
||||
tooltip: {
|
||||
en_US: 'Visit https://docs.dify.ai/tools for docs',
|
||||
zh_Hans: 'Visit https://docs.dify.ai/tools for docs',
|
||||
},
|
||||
})}
|
||||
value={{}}
|
||||
onChange={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
const link = screen.getByRole('link', { name: 'https://docs.dify.ai/tools' })
|
||||
expect(link).toHaveAttribute('href', 'https://docs.dify.ai/tools')
|
||||
expect(link).toHaveAttribute('target', '_blank')
|
||||
expect(link).toHaveAttribute('rel', 'noopener noreferrer')
|
||||
expect(link.parentElement).toHaveTextContent('Visit https://docs.dify.ai/tools for docs')
|
||||
})
|
||||
|
||||
// Non-text fields keep their descriptions inside the tooltip and support JSON schema preview.
|
||||
it('should show tooltip for non-description fields and open the schema modal', () => {
|
||||
const objectSchema = createSchema({
|
||||
name: 'tool_config',
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import type { FC, ReactNode } from 'react'
|
||||
import type { ToolVarInputs } from '../../types'
|
||||
import type { CredentialFormSchema } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { Tool } from '@/app/components/tools/types'
|
||||
@@ -15,6 +15,45 @@ import { useLanguage } from '@/app/components/header/account-setting/model-provi
|
||||
import { SchemaModal } from '@/app/components/plugins/plugin-detail-panel/tool-selector/components'
|
||||
import FormInputItem from '@/app/components/workflow/nodes/_base/components/form-input-item'
|
||||
|
||||
const URL_REGEX = /(https?:\/\/\S+)/g
|
||||
|
||||
const renderDescriptionWithLinks = (description: string): ReactNode => {
|
||||
const matches = [...description.matchAll(URL_REGEX)]
|
||||
|
||||
if (!matches.length)
|
||||
return description
|
||||
|
||||
const parts: ReactNode[] = []
|
||||
let currentIndex = 0
|
||||
|
||||
matches.forEach((match, index) => {
|
||||
const [url] = match
|
||||
const start = match.index ?? 0
|
||||
|
||||
if (start > currentIndex)
|
||||
parts.push(description.slice(currentIndex, start))
|
||||
|
||||
parts.push(
|
||||
<a
|
||||
key={`${url}-${index}`}
|
||||
href={url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-text-accent hover:underline"
|
||||
>
|
||||
{url}
|
||||
</a>,
|
||||
)
|
||||
|
||||
currentIndex = start + url.length
|
||||
})
|
||||
|
||||
if (currentIndex < description.length)
|
||||
parts.push(description.slice(currentIndex))
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
type Props = {
|
||||
readOnly: boolean
|
||||
nodeId: string
|
||||
@@ -87,7 +126,9 @@ const ToolFormItem: FC<Props> = ({
|
||||
)}
|
||||
</div>
|
||||
{showDescription && tooltip && (
|
||||
<div className="body-xs-regular pb-0.5 text-text-tertiary">{tooltip[language] || tooltip.en_US}</div>
|
||||
<div className="body-xs-regular break-words pb-0.5 text-text-tertiary">
|
||||
{renderDescriptionWithLinks(tooltip[language] || tooltip.en_US)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<FormInputItem
|
||||
|
||||
Reference in New Issue
Block a user