mirror of
https://github.com/langgenius/dify.git
synced 2025-12-23 15:57:29 +00:00
Compare commits
5 Commits
mysql-adap
...
release/e-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f93adf9193 | ||
|
|
ce2d000c5d | ||
|
|
80b017c94c | ||
|
|
bd0b7d399f | ||
|
|
1f988b0ada |
@@ -42,6 +42,15 @@ REDIS_PORT=6379
|
|||||||
REDIS_USERNAME=
|
REDIS_USERNAME=
|
||||||
REDIS_PASSWORD=difyai123456
|
REDIS_PASSWORD=difyai123456
|
||||||
REDIS_USE_SSL=false
|
REDIS_USE_SSL=false
|
||||||
|
# SSL configuration for Redis (when REDIS_USE_SSL=true)
|
||||||
|
REDIS_SSL_CERT_REQS=CERT_NONE
|
||||||
|
# Options: CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
|
||||||
|
REDIS_SSL_CA_CERTS=
|
||||||
|
# Path to CA certificate file for SSL verification
|
||||||
|
REDIS_SSL_CERTFILE=
|
||||||
|
# Path to client certificate file for SSL authentication
|
||||||
|
REDIS_SSL_KEYFILE=
|
||||||
|
# Path to client private key file for SSL authentication
|
||||||
REDIS_DB=0
|
REDIS_DB=0
|
||||||
|
|
||||||
# redis Sentinel configuration.
|
# redis Sentinel configuration.
|
||||||
|
|||||||
20
api/configs/middleware/cache/redis_config.py
vendored
20
api/configs/middleware/cache/redis_config.py
vendored
@@ -39,6 +39,26 @@ class RedisConfig(BaseSettings):
|
|||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
REDIS_SSL_CERT_REQS: str = Field(
|
||||||
|
description="SSL certificate requirements (CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED)",
|
||||||
|
default="CERT_NONE",
|
||||||
|
)
|
||||||
|
|
||||||
|
REDIS_SSL_CA_CERTS: Optional[str] = Field(
|
||||||
|
description="Path to the CA certificate file for SSL verification",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
REDIS_SSL_CERTFILE: Optional[str] = Field(
|
||||||
|
description="Path to the client certificate file for SSL authentication",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
REDIS_SSL_KEYFILE: Optional[str] = Field(
|
||||||
|
description="Path to the client private key file for SSL authentication",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
REDIS_USE_SENTINEL: Optional[bool] = Field(
|
REDIS_USE_SENTINEL: Optional[bool] = Field(
|
||||||
description="Enable Redis Sentinel mode for high availability",
|
description="Enable Redis Sentinel mode for high availability",
|
||||||
default=False,
|
default=False,
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
|
import ssl
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
import pytz
|
import pytz
|
||||||
from celery import Celery, Task # type: ignore
|
from celery import Celery, Task # type: ignore
|
||||||
@@ -8,6 +10,40 @@ from configs import dify_config
|
|||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
|
|
||||||
|
|
||||||
|
def _get_celery_ssl_options() -> Optional[dict[str, Any]]:
|
||||||
|
"""Get SSL configuration for Celery broker/backend connections."""
|
||||||
|
# Use REDIS_USE_SSL for consistency with the main Redis client
|
||||||
|
# Only apply SSL if we're using Redis as broker/backend
|
||||||
|
if not dify_config.REDIS_USE_SSL:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if Celery is actually using Redis
|
||||||
|
broker_is_redis = dify_config.CELERY_BROKER_URL and (
|
||||||
|
dify_config.CELERY_BROKER_URL.startswith("redis://") or dify_config.CELERY_BROKER_URL.startswith("rediss://")
|
||||||
|
)
|
||||||
|
|
||||||
|
if not broker_is_redis:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Map certificate requirement strings to SSL constants
|
||||||
|
cert_reqs_map = {
|
||||||
|
"CERT_NONE": ssl.CERT_NONE,
|
||||||
|
"CERT_OPTIONAL": ssl.CERT_OPTIONAL,
|
||||||
|
"CERT_REQUIRED": ssl.CERT_REQUIRED,
|
||||||
|
}
|
||||||
|
|
||||||
|
ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)
|
||||||
|
|
||||||
|
ssl_options = {
|
||||||
|
"ssl_cert_reqs": ssl_cert_reqs,
|
||||||
|
"ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
|
||||||
|
"ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
|
||||||
|
"ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
|
||||||
|
}
|
||||||
|
|
||||||
|
return ssl_options
|
||||||
|
|
||||||
|
|
||||||
def init_app(app: DifyApp) -> Celery:
|
def init_app(app: DifyApp) -> Celery:
|
||||||
class FlaskTask(Task):
|
class FlaskTask(Task):
|
||||||
def __call__(self, *args: object, **kwargs: object) -> object:
|
def __call__(self, *args: object, **kwargs: object) -> object:
|
||||||
@@ -33,14 +69,6 @@ def init_app(app: DifyApp) -> Celery:
|
|||||||
task_ignore_result=True,
|
task_ignore_result=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add SSL options to the Celery configuration
|
|
||||||
ssl_options = {
|
|
||||||
"ssl_cert_reqs": None,
|
|
||||||
"ssl_ca_certs": None,
|
|
||||||
"ssl_certfile": None,
|
|
||||||
"ssl_keyfile": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
celery_app.conf.update(
|
celery_app.conf.update(
|
||||||
result_backend=dify_config.CELERY_RESULT_BACKEND,
|
result_backend=dify_config.CELERY_RESULT_BACKEND,
|
||||||
broker_transport_options=broker_transport_options,
|
broker_transport_options=broker_transport_options,
|
||||||
@@ -51,9 +79,13 @@ def init_app(app: DifyApp) -> Celery:
|
|||||||
timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"),
|
timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"),
|
||||||
)
|
)
|
||||||
|
|
||||||
if dify_config.BROKER_USE_SSL:
|
# Apply SSL configuration if enabled
|
||||||
|
ssl_options = _get_celery_ssl_options()
|
||||||
|
if ssl_options:
|
||||||
celery_app.conf.update(
|
celery_app.conf.update(
|
||||||
broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration
|
broker_use_ssl=ssl_options,
|
||||||
|
# Also apply SSL to the backend if it's Redis
|
||||||
|
redis_backend_use_ssl=ssl_options if dify_config.CELERY_BACKEND == "redis" else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if dify_config.LOG_FILE:
|
if dify_config.LOG_FILE:
|
||||||
@@ -113,7 +145,7 @@ def init_app(app: DifyApp) -> Celery:
|
|||||||
minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
|
minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:
|
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
|
||||||
imports.append("schedule.check_upgradable_plugin_task")
|
imports.append("schedule.check_upgradable_plugin_task")
|
||||||
beat_schedule["check_upgradable_plugin_task"] = {
|
beat_schedule["check_upgradable_plugin_task"] = {
|
||||||
"task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task",
|
"task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task",
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
|
import ssl
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
@@ -53,73 +54,132 @@ class RedisClientWrapper:
|
|||||||
redis_client = RedisClientWrapper()
|
redis_client = RedisClientWrapper()
|
||||||
|
|
||||||
|
|
||||||
def init_app(app: DifyApp):
|
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
|
||||||
global redis_client
|
"""Get SSL configuration for Redis connection."""
|
||||||
connection_class: type[Union[Connection, SSLConnection]] = Connection
|
if not dify_config.REDIS_USE_SSL:
|
||||||
if dify_config.REDIS_USE_SSL:
|
return Connection, {}
|
||||||
connection_class = SSLConnection
|
|
||||||
resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL
|
|
||||||
if dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE:
|
|
||||||
if resp_protocol >= 3:
|
|
||||||
clientside_cache_config = CacheConfig()
|
|
||||||
else:
|
|
||||||
raise ValueError("Client side cache is only supported in RESP3")
|
|
||||||
else:
|
|
||||||
clientside_cache_config = None
|
|
||||||
|
|
||||||
redis_params: dict[str, Any] = {
|
cert_reqs_map = {
|
||||||
|
"CERT_NONE": ssl.CERT_NONE,
|
||||||
|
"CERT_OPTIONAL": ssl.CERT_OPTIONAL,
|
||||||
|
"CERT_REQUIRED": ssl.CERT_REQUIRED,
|
||||||
|
}
|
||||||
|
ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)
|
||||||
|
|
||||||
|
ssl_kwargs = {
|
||||||
|
"ssl_cert_reqs": ssl_cert_reqs,
|
||||||
|
"ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
|
||||||
|
"ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
|
||||||
|
"ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
|
||||||
|
}
|
||||||
|
|
||||||
|
return SSLConnection, ssl_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cache_configuration() -> CacheConfig | None:
|
||||||
|
"""Get client-side cache configuration if enabled."""
|
||||||
|
if not dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE:
|
||||||
|
return None
|
||||||
|
|
||||||
|
resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL
|
||||||
|
if resp_protocol < 3:
|
||||||
|
raise ValueError("Client side cache is only supported in RESP3")
|
||||||
|
|
||||||
|
return CacheConfig()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_base_redis_params() -> dict[str, Any]:
|
||||||
|
"""Get base Redis connection parameters."""
|
||||||
|
return {
|
||||||
"username": dify_config.REDIS_USERNAME,
|
"username": dify_config.REDIS_USERNAME,
|
||||||
"password": dify_config.REDIS_PASSWORD or None, # Temporary fix for empty password
|
"password": dify_config.REDIS_PASSWORD or None,
|
||||||
"db": dify_config.REDIS_DB,
|
"db": dify_config.REDIS_DB,
|
||||||
"encoding": "utf-8",
|
"encoding": "utf-8",
|
||||||
"encoding_errors": "strict",
|
"encoding_errors": "strict",
|
||||||
"decode_responses": False,
|
"decode_responses": False,
|
||||||
"protocol": resp_protocol,
|
"protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL,
|
||||||
"cache_config": clientside_cache_config,
|
"cache_config": _get_cache_configuration(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if dify_config.REDIS_USE_SENTINEL:
|
|
||||||
assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True"
|
|
||||||
sentinel_hosts = [
|
|
||||||
(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")
|
|
||||||
]
|
|
||||||
sentinel = Sentinel(
|
|
||||||
sentinel_hosts,
|
|
||||||
sentinel_kwargs={
|
|
||||||
"socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
|
|
||||||
"username": dify_config.REDIS_SENTINEL_USERNAME,
|
|
||||||
"password": dify_config.REDIS_SENTINEL_PASSWORD,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
|
|
||||||
redis_client.initialize(master)
|
|
||||||
elif dify_config.REDIS_USE_CLUSTERS:
|
|
||||||
assert dify_config.REDIS_CLUSTERS is not None, "REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True"
|
|
||||||
nodes = [
|
|
||||||
ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
|
|
||||||
for node in dify_config.REDIS_CLUSTERS.split(",")
|
|
||||||
]
|
|
||||||
redis_client.initialize(
|
|
||||||
RedisCluster(
|
|
||||||
startup_nodes=nodes,
|
|
||||||
password=dify_config.REDIS_CLUSTERS_PASSWORD,
|
|
||||||
protocol=resp_protocol,
|
|
||||||
cache_config=clientside_cache_config,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
redis_params.update(
|
|
||||||
{
|
|
||||||
"host": dify_config.REDIS_HOST,
|
|
||||||
"port": dify_config.REDIS_PORT,
|
|
||||||
"connection_class": connection_class,
|
|
||||||
"protocol": resp_protocol,
|
|
||||||
"cache_config": clientside_cache_config,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
pool = redis.ConnectionPool(**redis_params)
|
|
||||||
redis_client.initialize(redis.Redis(connection_pool=pool))
|
|
||||||
|
|
||||||
|
def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
|
||||||
|
"""Create Redis client using Sentinel configuration."""
|
||||||
|
if not dify_config.REDIS_SENTINELS:
|
||||||
|
raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True")
|
||||||
|
|
||||||
|
if not dify_config.REDIS_SENTINEL_SERVICE_NAME:
|
||||||
|
raise ValueError("REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True")
|
||||||
|
|
||||||
|
sentinel_hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")]
|
||||||
|
|
||||||
|
sentinel = Sentinel(
|
||||||
|
sentinel_hosts,
|
||||||
|
sentinel_kwargs={
|
||||||
|
"socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT,
|
||||||
|
"username": dify_config.REDIS_SENTINEL_USERNAME,
|
||||||
|
"password": dify_config.REDIS_SENTINEL_PASSWORD,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params)
|
||||||
|
return master
|
||||||
|
|
||||||
|
|
||||||
|
def _create_cluster_client() -> Union[redis.Redis, RedisCluster]:
|
||||||
|
"""Create Redis cluster client."""
|
||||||
|
if not dify_config.REDIS_CLUSTERS:
|
||||||
|
raise ValueError("REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True")
|
||||||
|
|
||||||
|
nodes = [
|
||||||
|
ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
|
||||||
|
for node in dify_config.REDIS_CLUSTERS.split(",")
|
||||||
|
]
|
||||||
|
|
||||||
|
cluster: RedisCluster = RedisCluster(
|
||||||
|
startup_nodes=nodes,
|
||||||
|
password=dify_config.REDIS_CLUSTERS_PASSWORD,
|
||||||
|
protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL,
|
||||||
|
cache_config=_get_cache_configuration(),
|
||||||
|
)
|
||||||
|
return cluster
|
||||||
|
|
||||||
|
|
||||||
|
def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]:
|
||||||
|
"""Create standalone Redis client."""
|
||||||
|
connection_class, ssl_kwargs = _get_ssl_configuration()
|
||||||
|
|
||||||
|
redis_params.update(
|
||||||
|
{
|
||||||
|
"host": dify_config.REDIS_HOST,
|
||||||
|
"port": dify_config.REDIS_PORT,
|
||||||
|
"connection_class": connection_class,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if ssl_kwargs:
|
||||||
|
redis_params.update(ssl_kwargs)
|
||||||
|
|
||||||
|
pool = redis.ConnectionPool(**redis_params)
|
||||||
|
client: redis.Redis = redis.Redis(connection_pool=pool)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def init_app(app: DifyApp):
|
||||||
|
"""Initialize Redis client and attach it to the app."""
|
||||||
|
global redis_client
|
||||||
|
|
||||||
|
# Determine Redis mode and create appropriate client
|
||||||
|
if dify_config.REDIS_USE_SENTINEL:
|
||||||
|
redis_params = _get_base_redis_params()
|
||||||
|
client = _create_sentinel_client(redis_params)
|
||||||
|
elif dify_config.REDIS_USE_CLUSTERS:
|
||||||
|
client = _create_cluster_client()
|
||||||
|
else:
|
||||||
|
redis_params = _get_base_redis_params()
|
||||||
|
client = _create_standalone_client(redis_params)
|
||||||
|
|
||||||
|
# Initialize the wrapper and attach to app
|
||||||
|
redis_client.initialize(client)
|
||||||
app.extensions["redis"] = redis_client
|
app.extensions["redis"] = redis_client
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,9 @@ class OAuthProxyService(BasePluginClient):
|
|||||||
if not context_id:
|
if not context_id:
|
||||||
raise ValueError("context_id is required")
|
raise ValueError("context_id is required")
|
||||||
# get data from redis
|
# get data from redis
|
||||||
data = redis_client.getdel(f"{OAuthProxyService.__KEY_PREFIX__}{context_id}")
|
key = f"{OAuthProxyService.__KEY_PREFIX__}{context_id}"
|
||||||
|
data = redis_client.get(key)
|
||||||
if not data:
|
if not data:
|
||||||
raise ValueError("context_id is invalid")
|
raise ValueError("context_id is invalid")
|
||||||
|
redis_client.delete(key)
|
||||||
return json.loads(data)
|
return json.loads(data)
|
||||||
|
|||||||
149
api/tests/unit_tests/extensions/test_celery_ssl.py
Normal file
149
api/tests/unit_tests/extensions/test_celery_ssl.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""Tests for Celery SSL configuration."""
|
||||||
|
|
||||||
|
import ssl
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
class TestCelerySSLConfiguration:
|
||||||
|
"""Test suite for Celery SSL configuration."""
|
||||||
|
|
||||||
|
def test_get_celery_ssl_options_when_ssl_disabled(self):
|
||||||
|
"""Test SSL options when REDIS_USE_SSL is False."""
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.REDIS_USE_SSL = False
|
||||||
|
|
||||||
|
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||||
|
from extensions.ext_celery import _get_celery_ssl_options
|
||||||
|
|
||||||
|
result = _get_celery_ssl_options()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_celery_ssl_options_when_broker_not_redis(self):
|
||||||
|
"""Test SSL options when broker is not Redis."""
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.REDIS_USE_SSL = True
|
||||||
|
mock_config.CELERY_BROKER_URL = "amqp://localhost:5672"
|
||||||
|
|
||||||
|
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||||
|
from extensions.ext_celery import _get_celery_ssl_options
|
||||||
|
|
||||||
|
result = _get_celery_ssl_options()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_celery_ssl_options_with_cert_none(self):
|
||||||
|
"""Test SSL options with CERT_NONE requirement."""
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.REDIS_USE_SSL = True
|
||||||
|
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
|
||||||
|
mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE"
|
||||||
|
mock_config.REDIS_SSL_CA_CERTS = None
|
||||||
|
mock_config.REDIS_SSL_CERTFILE = None
|
||||||
|
mock_config.REDIS_SSL_KEYFILE = None
|
||||||
|
|
||||||
|
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||||
|
from extensions.ext_celery import _get_celery_ssl_options
|
||||||
|
|
||||||
|
result = _get_celery_ssl_options()
|
||||||
|
assert result is not None
|
||||||
|
assert result["ssl_cert_reqs"] == ssl.CERT_NONE
|
||||||
|
assert result["ssl_ca_certs"] is None
|
||||||
|
assert result["ssl_certfile"] is None
|
||||||
|
assert result["ssl_keyfile"] is None
|
||||||
|
|
||||||
|
def test_get_celery_ssl_options_with_cert_required(self):
|
||||||
|
"""Test SSL options with CERT_REQUIRED and certificates."""
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.REDIS_USE_SSL = True
|
||||||
|
mock_config.CELERY_BROKER_URL = "rediss://localhost:6380/0"
|
||||||
|
mock_config.REDIS_SSL_CERT_REQS = "CERT_REQUIRED"
|
||||||
|
mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt"
|
||||||
|
mock_config.REDIS_SSL_CERTFILE = "/path/to/client.crt"
|
||||||
|
mock_config.REDIS_SSL_KEYFILE = "/path/to/client.key"
|
||||||
|
|
||||||
|
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||||
|
from extensions.ext_celery import _get_celery_ssl_options
|
||||||
|
|
||||||
|
result = _get_celery_ssl_options()
|
||||||
|
assert result is not None
|
||||||
|
assert result["ssl_cert_reqs"] == ssl.CERT_REQUIRED
|
||||||
|
assert result["ssl_ca_certs"] == "/path/to/ca.crt"
|
||||||
|
assert result["ssl_certfile"] == "/path/to/client.crt"
|
||||||
|
assert result["ssl_keyfile"] == "/path/to/client.key"
|
||||||
|
|
||||||
|
def test_get_celery_ssl_options_with_cert_optional(self):
|
||||||
|
"""Test SSL options with CERT_OPTIONAL requirement."""
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.REDIS_USE_SSL = True
|
||||||
|
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
|
||||||
|
mock_config.REDIS_SSL_CERT_REQS = "CERT_OPTIONAL"
|
||||||
|
mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt"
|
||||||
|
mock_config.REDIS_SSL_CERTFILE = None
|
||||||
|
mock_config.REDIS_SSL_KEYFILE = None
|
||||||
|
|
||||||
|
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||||
|
from extensions.ext_celery import _get_celery_ssl_options
|
||||||
|
|
||||||
|
result = _get_celery_ssl_options()
|
||||||
|
assert result is not None
|
||||||
|
assert result["ssl_cert_reqs"] == ssl.CERT_OPTIONAL
|
||||||
|
assert result["ssl_ca_certs"] == "/path/to/ca.crt"
|
||||||
|
|
||||||
|
def test_get_celery_ssl_options_with_invalid_cert_reqs(self):
|
||||||
|
"""Test SSL options with invalid cert requirement defaults to CERT_NONE."""
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.REDIS_USE_SSL = True
|
||||||
|
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
|
||||||
|
mock_config.REDIS_SSL_CERT_REQS = "INVALID_VALUE"
|
||||||
|
mock_config.REDIS_SSL_CA_CERTS = None
|
||||||
|
mock_config.REDIS_SSL_CERTFILE = None
|
||||||
|
mock_config.REDIS_SSL_KEYFILE = None
|
||||||
|
|
||||||
|
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||||
|
from extensions.ext_celery import _get_celery_ssl_options
|
||||||
|
|
||||||
|
result = _get_celery_ssl_options()
|
||||||
|
assert result is not None
|
||||||
|
assert result["ssl_cert_reqs"] == ssl.CERT_NONE # Should default to CERT_NONE
|
||||||
|
|
||||||
|
def test_celery_init_applies_ssl_to_broker_and_backend(self):
|
||||||
|
"""Test that SSL options are applied to both broker and backend when using Redis."""
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.REDIS_USE_SSL = True
|
||||||
|
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
|
||||||
|
mock_config.CELERY_BACKEND = "redis"
|
||||||
|
mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0"
|
||||||
|
mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE"
|
||||||
|
mock_config.REDIS_SSL_CA_CERTS = None
|
||||||
|
mock_config.REDIS_SSL_CERTFILE = None
|
||||||
|
mock_config.REDIS_SSL_KEYFILE = None
|
||||||
|
mock_config.CELERY_USE_SENTINEL = False
|
||||||
|
mock_config.LOG_FORMAT = "%(message)s"
|
||||||
|
mock_config.LOG_TZ = "UTC"
|
||||||
|
mock_config.LOG_FILE = None
|
||||||
|
|
||||||
|
# Mock all the scheduler configs
|
||||||
|
mock_config.CELERY_BEAT_SCHEDULER_TIME = 1
|
||||||
|
mock_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK = False
|
||||||
|
mock_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK = False
|
||||||
|
mock_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK = False
|
||||||
|
mock_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK = False
|
||||||
|
mock_config.ENABLE_CLEAN_MESSAGES = False
|
||||||
|
mock_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK = False
|
||||||
|
mock_config.ENABLE_DATASETS_QUEUE_MONITOR = False
|
||||||
|
mock_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK = False
|
||||||
|
|
||||||
|
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||||
|
from dify_app import DifyApp
|
||||||
|
from extensions.ext_celery import init_app
|
||||||
|
|
||||||
|
app = DifyApp(__name__)
|
||||||
|
celery_app = init_app(app)
|
||||||
|
|
||||||
|
# Check that SSL options were applied
|
||||||
|
assert "broker_use_ssl" in celery_app.conf
|
||||||
|
assert celery_app.conf["broker_use_ssl"] is not None
|
||||||
|
assert celery_app.conf["broker_use_ssl"]["ssl_cert_reqs"] == ssl.CERT_NONE
|
||||||
|
|
||||||
|
# Check that SSL is also applied to Redis backend
|
||||||
|
assert "redis_backend_use_ssl" in celery_app.conf
|
||||||
|
assert celery_app.conf["redis_backend_use_ssl"] is not None
|
||||||
@@ -264,6 +264,15 @@ REDIS_PORT=6379
|
|||||||
REDIS_USERNAME=
|
REDIS_USERNAME=
|
||||||
REDIS_PASSWORD=difyai123456
|
REDIS_PASSWORD=difyai123456
|
||||||
REDIS_USE_SSL=false
|
REDIS_USE_SSL=false
|
||||||
|
# SSL configuration for Redis (when REDIS_USE_SSL=true)
|
||||||
|
REDIS_SSL_CERT_REQS=CERT_NONE
|
||||||
|
# Options: CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
|
||||||
|
REDIS_SSL_CA_CERTS=
|
||||||
|
# Path to CA certificate file for SSL verification
|
||||||
|
REDIS_SSL_CERTFILE=
|
||||||
|
# Path to client certificate file for SSL authentication
|
||||||
|
REDIS_SSL_KEYFILE=
|
||||||
|
# Path to client private key file for SSL authentication
|
||||||
REDIS_DB=0
|
REDIS_DB=0
|
||||||
|
|
||||||
# Whether to use Redis Sentinel mode.
|
# Whether to use Redis Sentinel mode.
|
||||||
|
|||||||
@@ -71,6 +71,10 @@ x-shared-env: &shared-api-worker-env
|
|||||||
REDIS_USERNAME: ${REDIS_USERNAME:-}
|
REDIS_USERNAME: ${REDIS_USERNAME:-}
|
||||||
REDIS_PASSWORD: ${REDIS_PASSWORD:-difyai123456}
|
REDIS_PASSWORD: ${REDIS_PASSWORD:-difyai123456}
|
||||||
REDIS_USE_SSL: ${REDIS_USE_SSL:-false}
|
REDIS_USE_SSL: ${REDIS_USE_SSL:-false}
|
||||||
|
REDIS_SSL_CERT_REQS: ${REDIS_SSL_CERT_REQS:-CERT_NONE}
|
||||||
|
REDIS_SSL_CA_CERTS: ${REDIS_SSL_CA_CERTS:-}
|
||||||
|
REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-}
|
||||||
|
REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-}
|
||||||
REDIS_DB: ${REDIS_DB:-0}
|
REDIS_DB: ${REDIS_DB:-0}
|
||||||
REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false}
|
REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false}
|
||||||
REDIS_SENTINELS: ${REDIS_SENTINELS:-}
|
REDIS_SENTINELS: ${REDIS_SENTINELS:-}
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ import { convertUTCDaySecondsToLocalSeconds, timeOfDayToDayjs } from '../referen
|
|||||||
import useReferenceSetting from '../plugin-page/use-reference-setting'
|
import useReferenceSetting from '../plugin-page/use-reference-setting'
|
||||||
import { AUTO_UPDATE_MODE } from '../reference-setting-modal/auto-update-setting/types'
|
import { AUTO_UPDATE_MODE } from '../reference-setting-modal/auto-update-setting/types'
|
||||||
import { useAppContext } from '@/context/app-context'
|
import { useAppContext } from '@/context/app-context'
|
||||||
|
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||||
|
|
||||||
const i18nPrefix = 'plugin.action'
|
const i18nPrefix = 'plugin.action'
|
||||||
|
|
||||||
@@ -69,6 +70,7 @@ const DetailHeader = ({
|
|||||||
const { setShowUpdatePluginModal } = useModalContext()
|
const { setShowUpdatePluginModal } = useModalContext()
|
||||||
const { refreshModelProviders } = useProviderContext()
|
const { refreshModelProviders } = useProviderContext()
|
||||||
const invalidateAllToolProviders = useInvalidateAllToolProviders()
|
const invalidateAllToolProviders = useInvalidateAllToolProviders()
|
||||||
|
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
|
||||||
|
|
||||||
const {
|
const {
|
||||||
installation_id,
|
installation_id,
|
||||||
@@ -122,6 +124,8 @@ const DetailHeader = ({
|
|||||||
const { referenceSetting } = useReferenceSetting()
|
const { referenceSetting } = useReferenceSetting()
|
||||||
const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {}
|
const { auto_upgrade: autoUpgradeInfo } = referenceSetting || {}
|
||||||
const isAutoUpgradeEnabled = useMemo(() => {
|
const isAutoUpgradeEnabled = useMemo(() => {
|
||||||
|
if (!enable_marketplace)
|
||||||
|
return false
|
||||||
if (!autoUpgradeInfo || !isFromMarketplace)
|
if (!autoUpgradeInfo || !isFromMarketplace)
|
||||||
return false
|
return false
|
||||||
if(autoUpgradeInfo.strategy_setting === 'disabled')
|
if(autoUpgradeInfo.strategy_setting === 'disabled')
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import { PermissionType } from '@/app/components/plugins/types'
|
|||||||
import type { AutoUpdateConfig } from './auto-update-setting/types'
|
import type { AutoUpdateConfig } from './auto-update-setting/types'
|
||||||
import AutoUpdateSetting from './auto-update-setting'
|
import AutoUpdateSetting from './auto-update-setting'
|
||||||
import { defaultValue as autoUpdateDefaultValue } from './auto-update-setting/config'
|
import { defaultValue as autoUpdateDefaultValue } from './auto-update-setting/config'
|
||||||
|
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||||
import Label from './label'
|
import Label from './label'
|
||||||
|
|
||||||
const i18nPrefix = 'plugin.privilege'
|
const i18nPrefix = 'plugin.privilege'
|
||||||
@@ -28,6 +29,7 @@ const PluginSettingModal: FC<Props> = ({
|
|||||||
const { auto_upgrade: autoUpdateConfig, permission: privilege } = payload || {}
|
const { auto_upgrade: autoUpdateConfig, permission: privilege } = payload || {}
|
||||||
const [tempPrivilege, setTempPrivilege] = useState<Permissions>(privilege)
|
const [tempPrivilege, setTempPrivilege] = useState<Permissions>(privilege)
|
||||||
const [tempAutoUpdateConfig, setTempAutoUpdateConfig] = useState<AutoUpdateConfig>(autoUpdateConfig || autoUpdateDefaultValue)
|
const [tempAutoUpdateConfig, setTempAutoUpdateConfig] = useState<AutoUpdateConfig>(autoUpdateConfig || autoUpdateDefaultValue)
|
||||||
|
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
|
||||||
const handlePrivilegeChange = useCallback((key: string) => {
|
const handlePrivilegeChange = useCallback((key: string) => {
|
||||||
return (value: PermissionType) => {
|
return (value: PermissionType) => {
|
||||||
setTempPrivilege({
|
setTempPrivilege({
|
||||||
@@ -77,8 +79,11 @@ const PluginSettingModal: FC<Props> = ({
|
|||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
|
{
|
||||||
<AutoUpdateSetting payload={tempAutoUpdateConfig} onChange={setTempAutoUpdateConfig} />
|
enable_marketplace && (
|
||||||
|
<AutoUpdateSetting payload={tempAutoUpdateConfig} onChange={setTempAutoUpdateConfig} />
|
||||||
|
)
|
||||||
|
}
|
||||||
<div className='flex h-[76px] items-center justify-end gap-2 self-stretch p-6 pt-5'>
|
<div className='flex h-[76px] items-center justify-end gap-2 self-stretch p-6 pt-5'>
|
||||||
<Button
|
<Button
|
||||||
className='min-w-[72px]'
|
className='min-w-[72px]'
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ export const checkOrSetAccessToken = async (appCode?: string | null) => {
|
|||||||
[userId || 'DEFAULT']: res.access_token,
|
[userId || 'DEFAULT']: res.access_token,
|
||||||
}
|
}
|
||||||
localStorage.setItem('token', JSON.stringify(accessTokenJson))
|
localStorage.setItem('token', JSON.stringify(accessTokenJson))
|
||||||
|
localStorage.removeItem(CONVERSATION_ID_INFO)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import type { FC, PropsWithChildren } from 'react'
|
|||||||
import { useEffect } from 'react'
|
import { useEffect } from 'react'
|
||||||
import { useState } from 'react'
|
import { useState } from 'react'
|
||||||
import { create } from 'zustand'
|
import { create } from 'zustand'
|
||||||
|
import { useGlobalPublicStore } from './global-public-context'
|
||||||
|
|
||||||
type WebAppStore = {
|
type WebAppStore = {
|
||||||
shareCode: string | null
|
shareCode: string | null
|
||||||
@@ -56,6 +57,7 @@ const getShareCodeFromPathname = (pathname: string): string | null => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => {
|
const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => {
|
||||||
|
const isGlobalPending = useGlobalPublicStore(s => s.isGlobalPending)
|
||||||
const updateWebAppAccessMode = useWebAppStore(state => state.updateWebAppAccessMode)
|
const updateWebAppAccessMode = useWebAppStore(state => state.updateWebAppAccessMode)
|
||||||
const updateShareCode = useWebAppStore(state => state.updateShareCode)
|
const updateShareCode = useWebAppStore(state => state.updateShareCode)
|
||||||
const pathname = usePathname()
|
const pathname = usePathname()
|
||||||
@@ -67,7 +69,7 @@ const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => {
|
|||||||
updateShareCode(shareCode)
|
updateShareCode(shareCode)
|
||||||
|
|
||||||
const { isFetching, data: accessModeResult } = useGetWebAppAccessModeByCode(shareCode)
|
const { isFetching, data: accessModeResult } = useGetWebAppAccessModeByCode(shareCode)
|
||||||
const [isFetchingAccessToken, setIsFetchingAccessToken] = useState(false)
|
const [isFetchingAccessToken, setIsFetchingAccessToken] = useState(true)
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (accessModeResult?.accessMode) {
|
if (accessModeResult?.accessMode) {
|
||||||
@@ -84,7 +86,7 @@ const WebAppStoreProvider: FC<PropsWithChildren> = ({ children }) => {
|
|||||||
}
|
}
|
||||||
}, [accessModeResult, updateWebAppAccessMode, shareCode])
|
}, [accessModeResult, updateWebAppAccessMode, shareCode])
|
||||||
|
|
||||||
if (isFetching || isFetchingAccessToken) {
|
if (isGlobalPending || isFetching || isFetchingAccessToken) {
|
||||||
return <div className='flex h-full w-full items-center justify-center'>
|
return <div className='flex h-full w-full items-center justify-center'>
|
||||||
<Loading />
|
<Loading />
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -398,9 +398,7 @@ export const ssePost = async (
|
|||||||
.then((res) => {
|
.then((res) => {
|
||||||
if (!/^[23]\d{2}$/.test(String(res.status))) {
|
if (!/^[23]\d{2}$/.test(String(res.status))) {
|
||||||
if (res.status === 401) {
|
if (res.status === 401) {
|
||||||
refreshAccessTokenOrRelogin(TIME_OUT).then(() => {
|
if (isPublicAPI) {
|
||||||
ssePost(url, fetchOptions, otherOptions)
|
|
||||||
}).catch(() => {
|
|
||||||
res.json().then((data: any) => {
|
res.json().then((data: any) => {
|
||||||
if (isPublicAPI) {
|
if (isPublicAPI) {
|
||||||
if (data.code === 'web_app_access_denied')
|
if (data.code === 'web_app_access_denied')
|
||||||
@@ -417,7 +415,14 @@ export const ssePost = async (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
}
|
||||||
|
else {
|
||||||
|
refreshAccessTokenOrRelogin(TIME_OUT).then(() => {
|
||||||
|
ssePost(url, fetchOptions, otherOptions)
|
||||||
|
}).catch((err) => {
|
||||||
|
console.error(err)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
res.json().then((data) => {
|
res.json().then((data) => {
|
||||||
|
|||||||
@@ -1,20 +1,12 @@
|
|||||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
|
||||||
import { AccessMode } from '@/models/access-control'
|
|
||||||
import { useQuery } from '@tanstack/react-query'
|
import { useQuery } from '@tanstack/react-query'
|
||||||
import { fetchAppInfo, fetchAppMeta, fetchAppParams, getAppAccessModeByAppCode } from './share'
|
import { fetchAppInfo, fetchAppMeta, fetchAppParams, getAppAccessModeByAppCode } from './share'
|
||||||
|
|
||||||
const NAME_SPACE = 'webapp'
|
const NAME_SPACE = 'webapp'
|
||||||
|
|
||||||
export const useGetWebAppAccessModeByCode = (code: string | null) => {
|
export const useGetWebAppAccessModeByCode = (code: string | null) => {
|
||||||
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
|
|
||||||
return useQuery({
|
return useQuery({
|
||||||
queryKey: [NAME_SPACE, 'appAccessMode', code],
|
queryKey: [NAME_SPACE, 'appAccessMode', code],
|
||||||
queryFn: () => {
|
queryFn: () => {
|
||||||
if (systemFeatures.webapp_auth.enabled === false) {
|
|
||||||
return {
|
|
||||||
accessMode: AccessMode.PUBLIC,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!code || code.length === 0)
|
if (!code || code.length === 0)
|
||||||
return Promise.reject(new Error('App code is required to get access mode'))
|
return Promise.reject(new Error('App code is required to get access mode'))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user