mirror of
https://github.com/langgenius/dify.git
synced 2026-03-10 09:45:11 +00:00
Compare commits
17 Commits
dependabot
...
copilot/su
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0e4f5cb38c | ||
|
|
c13d1872d4 | ||
|
|
c911de6a6c | ||
|
|
968bf10e1c | ||
|
|
3d77a5ec08 | ||
|
|
41af72449d | ||
|
|
de72bdef71 | ||
|
|
f97ade7053 | ||
|
|
a0dcd04546 | ||
|
|
b0138316f0 | ||
|
|
099568f3da | ||
|
|
0623522d04 | ||
|
|
a25d48c5bd | ||
|
|
4f3a020670 | ||
|
|
d2e1177478 | ||
|
|
8a21fd88fd | ||
|
|
1c1bcc67da |
@@ -1,16 +1,38 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from flask import request
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from controllers.console.error import UnauthorizedAndForceLogout
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Console bootstrap APIs exempt from license check:
|
||||
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
|
||||
# - setup: install/setup status check (AppInitializer)
|
||||
# - init: init password validation for fresh install (InitPasswordPopup)
|
||||
# - login: auto-login after setup completion (InstallForm)
|
||||
# - version: version check (AppContextProvider)
|
||||
# - activate/check: invitation link validation (signin page)
|
||||
# Without these exemptions, the signin page triggers location.reload()
|
||||
# on unauthorized_and_force_logout, causing an infinite loop.
|
||||
_CONSOLE_EXEMPT_PREFIXES = (
|
||||
"/console/api/system-features",
|
||||
"/console/api/setup",
|
||||
"/console/api/init",
|
||||
"/console/api/login",
|
||||
"/console/api/version",
|
||||
"/console/api/activate/check",
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Application Factory Function
|
||||
@@ -31,6 +53,39 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
init_request_context()
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
# Enterprise license validation for API endpoints (both console and webapp)
|
||||
# When license expires, block all API access except bootstrap endpoints needed
|
||||
# for the frontend to load the license expiration page without infinite reloads.
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
is_console_api = request.path.startswith("/console/api/")
|
||||
is_webapp_api = request.path.startswith("/api/")
|
||||
|
||||
if is_console_api or is_webapp_api:
|
||||
if is_console_api:
|
||||
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
|
||||
else: # webapp API
|
||||
is_exempt = request.path.startswith("/api/system-features")
|
||||
|
||||
if not is_exempt:
|
||||
try:
|
||||
# Check license status (cached — see EnterpriseService for TTL details)
|
||||
license_status = EnterpriseService.get_cached_license_status()
|
||||
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
|
||||
raise UnauthorizedAndForceLogout(
|
||||
f"Enterprise license is {license_status}. Please contact your administrator."
|
||||
)
|
||||
if license_status is None:
|
||||
raise UnauthorizedAndForceLogout(
|
||||
"Unable to verify enterprise license. Please contact your administrator."
|
||||
)
|
||||
except UnauthorizedAndForceLogout:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to check enterprise license status")
|
||||
raise UnauthorizedAndForceLogout(
|
||||
"Unable to verify enterprise license. Please contact your administrator."
|
||||
)
|
||||
|
||||
# add after request hook for injecting trace headers from OpenTelemetry span context
|
||||
# Only adds headers when OTEL is enabled and has valid context
|
||||
@dify_app.after_request
|
||||
|
||||
@@ -5,89 +5,89 @@ requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
"aliyun-log-python-sdk~=0.9.37",
|
||||
"arize-phoenix-otel~=0.15.0",
|
||||
"azure-identity==1.25.2",
|
||||
"beautifulsoup4==4.14.3",
|
||||
"boto3==1.42.63",
|
||||
"arize-phoenix-otel~=0.9.2",
|
||||
"azure-identity==1.16.1",
|
||||
"beautifulsoup4==4.12.2",
|
||||
"boto3==1.35.99",
|
||||
"bs4~=0.0.1",
|
||||
"cachetools~=5.3.0",
|
||||
"celery~=5.6.2",
|
||||
"celery~=5.5.2",
|
||||
"charset-normalizer>=3.4.4",
|
||||
"flask~=3.1.2",
|
||||
"flask-compress>=1.17,<1.24",
|
||||
"flask-compress>=1.17,<1.18",
|
||||
"flask-cors~=6.0.0",
|
||||
"flask-login~=0.6.3",
|
||||
"flask-migrate~=4.1.0",
|
||||
"flask-migrate~=4.0.7",
|
||||
"flask-orjson~=2.0.0",
|
||||
"flask-sqlalchemy~=3.1.1",
|
||||
"gevent~=25.9.1",
|
||||
"gmpy2~=2.3.0",
|
||||
"google-api-core>=2.19.1",
|
||||
"google-api-python-client==2.192.0",
|
||||
"google-api-python-client==2.189.0",
|
||||
"google-auth>=2.47.0",
|
||||
"google-auth-httplib2==0.3.0",
|
||||
"google-auth-httplib2==0.2.0",
|
||||
"google-cloud-aiplatform>=1.123.0",
|
||||
"googleapis-common-protos>=1.65.0",
|
||||
"gunicorn~=25.1.0",
|
||||
"gunicorn~=23.0.0",
|
||||
"httpx[socks]~=0.28.0",
|
||||
"jieba==0.42.1",
|
||||
"json-repair>=0.55.1",
|
||||
"jsonschema>=4.25.1",
|
||||
"langfuse~=2.51.3",
|
||||
"langsmith~=0.7.14",
|
||||
"markdown~=3.10.2",
|
||||
"langsmith~=0.1.77",
|
||||
"markdown~=3.8.1",
|
||||
"mlflow-skinny>=3.0.0",
|
||||
"numpy~=2.4.2",
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.10.29",
|
||||
"litellm==1.82.0", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.40.0",
|
||||
"opentelemetry-distro==0.61b0",
|
||||
"opentelemetry-exporter-otlp==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-common==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-grpc==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-http==1.40.0",
|
||||
"opentelemetry-instrumentation==0.61b0",
|
||||
"opentelemetry-instrumentation-celery==0.61b0",
|
||||
"opentelemetry-instrumentation-flask==0.61b0",
|
||||
"opentelemetry-instrumentation-httpx==0.61b0",
|
||||
"opentelemetry-instrumentation-redis==0.61b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.61b0",
|
||||
"opentelemetry-propagator-b3==1.40.0",
|
||||
"opentelemetry-proto==1.40.0",
|
||||
"opentelemetry-sdk==1.40.0",
|
||||
"opentelemetry-semantic-conventions==0.61b0",
|
||||
"opentelemetry-util-http==0.61b0",
|
||||
"pandas[excel,output-formatting,performance]~=3.0.1",
|
||||
"opik~=1.8.72",
|
||||
"litellm==1.77.1", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.28.0",
|
||||
"opentelemetry-distro==0.49b0",
|
||||
"opentelemetry-exporter-otlp==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-common==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-grpc==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-http==1.28.0",
|
||||
"opentelemetry-instrumentation==0.49b0",
|
||||
"opentelemetry-instrumentation-celery==0.49b0",
|
||||
"opentelemetry-instrumentation-flask==0.49b0",
|
||||
"opentelemetry-instrumentation-httpx==0.49b0",
|
||||
"opentelemetry-instrumentation-redis==0.49b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.49b0",
|
||||
"opentelemetry-propagator-b3==1.28.0",
|
||||
"opentelemetry-proto==1.28.0",
|
||||
"opentelemetry-sdk==1.28.0",
|
||||
"opentelemetry-semantic-conventions==0.49b0",
|
||||
"opentelemetry-util-http==0.49b0",
|
||||
"pandas[excel,output-formatting,performance]~=2.2.2",
|
||||
"psycogreen~=1.0.2",
|
||||
"psycopg2-binary~=2.9.6",
|
||||
"pycryptodome==3.23.0",
|
||||
"pydantic~=2.12.5",
|
||||
"pydantic-extra-types~=2.11.0",
|
||||
"pydantic-settings~=2.13.1",
|
||||
"pydantic-extra-types~=2.10.3",
|
||||
"pydantic-settings~=2.12.0",
|
||||
"pyjwt~=2.11.0",
|
||||
"pypdfium2==5.6.0",
|
||||
"pypdfium2==5.2.0",
|
||||
"python-docx~=1.2.0",
|
||||
"python-dotenv==1.2.2",
|
||||
"python-dotenv==1.0.1",
|
||||
"pyyaml~=6.0.1",
|
||||
"readabilipy~=0.3.0",
|
||||
"redis[hiredis]~=7.3.0",
|
||||
"resend~=2.23.0",
|
||||
"sentry-sdk[flask]~=2.54.0",
|
||||
"redis[hiredis]~=7.2.0",
|
||||
"resend~=2.9.0",
|
||||
"sentry-sdk[flask]~=2.28.0",
|
||||
"sqlalchemy~=2.0.29",
|
||||
"starlette==0.52.1",
|
||||
"tiktoken~=0.12.0",
|
||||
"transformers~=5.3.0",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
|
||||
"yarl~=1.23.0",
|
||||
"starlette==0.49.1",
|
||||
"tiktoken~=0.9.0",
|
||||
"transformers~=4.56.1",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
|
||||
"yarl~=1.18.3",
|
||||
"webvtt-py~=0.5.1",
|
||||
"sseclient-py~=1.9.0",
|
||||
"sseclient-py~=1.8.0",
|
||||
"httpx-sse~=0.4.0",
|
||||
"sendgrid~=6.12.3",
|
||||
"flask-restx~=1.3.2",
|
||||
"packaging~=23.2",
|
||||
"croniter>=6.0.0",
|
||||
"weaviate-client==4.20.3",
|
||||
"weaviate-client==4.17.0",
|
||||
"apscheduler>=3.11.0",
|
||||
"weave>=0.52.16",
|
||||
"fastopenapi[flask]>=0.7.0",
|
||||
@@ -109,46 +109,46 @@ package = false
|
||||
# Required for development and running tests
|
||||
############################################################
|
||||
dev = [
|
||||
"coverage~=7.13.4",
|
||||
"dotenv-linter~=0.7.0",
|
||||
"faker~=40.8.0",
|
||||
"coverage~=7.2.4",
|
||||
"dotenv-linter~=0.5.0",
|
||||
"faker~=38.2.0",
|
||||
"lxml-stubs~=0.5.1",
|
||||
"basedpyright~=1.38.2",
|
||||
"ruff~=0.15.5",
|
||||
"pytest~=9.0.2",
|
||||
"pytest-benchmark~=5.2.3",
|
||||
"pytest-cov~=7.0.0",
|
||||
"pytest-env~=1.5.0",
|
||||
"pytest-mock~=3.15.1",
|
||||
"testcontainers~=4.14.1",
|
||||
"ruff~=0.14.0",
|
||||
"pytest~=8.3.2",
|
||||
"pytest-benchmark~=4.0.0",
|
||||
"pytest-cov~=4.1.0",
|
||||
"pytest-env~=1.1.3",
|
||||
"pytest-mock~=3.14.0",
|
||||
"testcontainers~=4.13.2",
|
||||
"types-aiofiles~=25.1.0",
|
||||
"types-beautifulsoup4~=4.12.0",
|
||||
"types-cachetools~=6.2.0",
|
||||
"types-cachetools~=5.5.0",
|
||||
"types-colorama~=0.4.15",
|
||||
"types-defusedxml~=0.7.0",
|
||||
"types-deprecated~=1.3.1",
|
||||
"types-docutils~=0.22.3",
|
||||
"types-jsonschema~=4.26.0",
|
||||
"types-flask-cors~=6.0.0",
|
||||
"types-deprecated~=1.2.15",
|
||||
"types-docutils~=0.21.0",
|
||||
"types-jsonschema~=4.23.0",
|
||||
"types-flask-cors~=5.0.0",
|
||||
"types-flask-migrate~=4.1.0",
|
||||
"types-gevent~=25.9.0",
|
||||
"types-greenlet~=3.3.0",
|
||||
"types-html5lib~=1.1.11",
|
||||
"types-markdown~=3.10.2",
|
||||
"types-oauthlib~=3.3.0",
|
||||
"types-oauthlib~=3.2.0",
|
||||
"types-objgraph~=3.6.0",
|
||||
"types-olefile~=0.47.0",
|
||||
"types-openpyxl~=3.1.5",
|
||||
"types-pexpect~=4.9.0",
|
||||
"types-protobuf~=6.32.1",
|
||||
"types-protobuf~=5.29.1",
|
||||
"types-psutil~=7.2.2",
|
||||
"types-psycopg2~=2.9.21",
|
||||
"types-pygments~=2.19.0",
|
||||
"types-pymysql~=1.1.0",
|
||||
"types-python-dateutil~=2.9.0",
|
||||
"types-pywin32~=311.0.0",
|
||||
"types-pywin32~=310.0.0",
|
||||
"types-pyyaml~=6.0.12",
|
||||
"types-regex~=2026.2.28",
|
||||
"types-regex~=2024.11.6",
|
||||
"types-shapely~=2.1.0",
|
||||
"types-simplejson>=3.20.0",
|
||||
"types-six>=1.17.0",
|
||||
@@ -161,7 +161,7 @@ dev = [
|
||||
"types_pyOpenSSL>=24.1.0",
|
||||
"types_cffi>=1.17.0",
|
||||
"types_setuptools>=80.9.0",
|
||||
"pandas-stubs~=3.0.0",
|
||||
"pandas-stubs~=2.2.3",
|
||||
"scipy-stubs>=1.15.3.0",
|
||||
"types-python-http-client>=3.3.7.20240910",
|
||||
"import-linter>=2.3",
|
||||
@@ -180,13 +180,13 @@ dev = [
|
||||
# Required for storage clients
|
||||
############################################################
|
||||
storage = [
|
||||
"azure-storage-blob==12.28.0",
|
||||
"azure-storage-blob==12.26.0",
|
||||
"bce-python-sdk~=0.9.23",
|
||||
"cos-python-sdk-v5==1.9.41",
|
||||
"esdk-obs-python==3.26.2",
|
||||
"cos-python-sdk-v5==1.9.38",
|
||||
"esdk-obs-python==3.25.8",
|
||||
"google-cloud-storage>=3.0.0",
|
||||
"opendal~=0.46.0",
|
||||
"oss2==2.19.1",
|
||||
"oss2==2.18.5",
|
||||
"supabase~=2.18.1",
|
||||
"tos~=2.9.0",
|
||||
]
|
||||
@@ -201,29 +201,29 @@ tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
|
||||
# Required by vector store clients
|
||||
############################################################
|
||||
vdb = [
|
||||
"alibabacloud_gpdb20160503~=5.0.2",
|
||||
"alibabacloud_tea_openapi~=0.4.3",
|
||||
"alibabacloud_gpdb20160503~=3.8.0",
|
||||
"alibabacloud_tea_openapi~=0.3.9",
|
||||
"chromadb==0.5.20",
|
||||
"clickhouse-connect~=0.13.0",
|
||||
"clickhouse-connect~=0.10.0",
|
||||
"clickzetta-connector-python>=0.8.102",
|
||||
"couchbase~=4.5.0",
|
||||
"elasticsearch==9.3.0",
|
||||
"couchbase~=4.3.0",
|
||||
"elasticsearch==8.14.0",
|
||||
"opensearch-py==3.1.0",
|
||||
"oracledb==3.4.2",
|
||||
"oracledb==3.3.0",
|
||||
"pgvecto-rs[sqlalchemy]~=0.2.1",
|
||||
"pgvector==0.4.2",
|
||||
"pymilvus~=2.6.9",
|
||||
"pymochow==2.3.6",
|
||||
"pgvector==0.2.5",
|
||||
"pymilvus~=2.5.0",
|
||||
"pymochow==2.2.9",
|
||||
"pyobvector~=0.2.17",
|
||||
"qdrant-client==1.17.0",
|
||||
"qdrant-client==1.9.0",
|
||||
"intersystems-irispython>=5.1.0",
|
||||
"tablestore==6.4.1",
|
||||
"tcvectordb~=2.0.0",
|
||||
"tidb-vector==0.0.15",
|
||||
"upstash-vector==0.8.0",
|
||||
"tablestore==6.3.7",
|
||||
"tcvectordb~=1.6.4",
|
||||
"tidb-vector==0.0.9",
|
||||
"upstash-vector==0.6.0",
|
||||
"volcengine-compat~=1.0.0",
|
||||
"weaviate-client==4.20.3",
|
||||
"xinference-client~=2.2.0",
|
||||
"weaviate-client==4.17.0",
|
||||
"xinference-client~=1.2.2",
|
||||
"mo-vector~=0.1.13",
|
||||
"mysql-connector-python>=9.3.0",
|
||||
]
|
||||
|
||||
@@ -6,6 +6,13 @@ from typing import Any
|
||||
import httpx
|
||||
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
from services.errors.enterprise import (
|
||||
EnterpriseAPIBadRequestError,
|
||||
EnterpriseAPIError,
|
||||
EnterpriseAPIForbiddenError,
|
||||
EnterpriseAPINotFoundError,
|
||||
EnterpriseAPIUnauthorizedError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,7 +48,6 @@ class BaseRequest:
|
||||
params: Mapping[str, Any] | None = None,
|
||||
*,
|
||||
timeout: float | httpx.Timeout | None = None,
|
||||
raise_for_status: bool = False,
|
||||
) -> Any:
|
||||
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
@@ -64,10 +70,51 @@ class BaseRequest:
|
||||
request_kwargs["timeout"] = timeout
|
||||
|
||||
response = client.request(method, url, **request_kwargs)
|
||||
if raise_for_status:
|
||||
response.raise_for_status()
|
||||
|
||||
# Validate HTTP status and raise domain-specific errors
|
||||
if not response.is_success:
|
||||
cls._handle_error_response(response)
|
||||
return response.json()
|
||||
|
||||
@classmethod
|
||||
def _handle_error_response(cls, response: httpx.Response) -> None:
|
||||
"""
|
||||
Handle non-2xx HTTP responses by raising appropriate domain errors.
|
||||
|
||||
Attempts to extract error message from JSON response body,
|
||||
falls back to status text if parsing fails.
|
||||
"""
|
||||
error_message = f"Enterprise API request failed: {response.status_code} {response.reason_phrase}"
|
||||
|
||||
# Try to extract error message from JSON response
|
||||
try:
|
||||
error_data = response.json()
|
||||
if isinstance(error_data, dict):
|
||||
# Common error response formats:
|
||||
# {"error": "...", "message": "..."}
|
||||
# {"message": "..."}
|
||||
# {"detail": "..."}
|
||||
error_message = (
|
||||
error_data.get("message") or error_data.get("error") or error_data.get("detail") or error_message
|
||||
)
|
||||
except Exception:
|
||||
# If JSON parsing fails, use the default message
|
||||
logger.debug(
|
||||
"Failed to parse error response from enterprise API (status=%s)", response.status_code, exc_info=True
|
||||
)
|
||||
|
||||
# Raise specific error based on status code
|
||||
if response.status_code == 400:
|
||||
raise EnterpriseAPIBadRequestError(error_message)
|
||||
elif response.status_code == 401:
|
||||
raise EnterpriseAPIUnauthorizedError(error_message)
|
||||
elif response.status_code == 403:
|
||||
raise EnterpriseAPIForbiddenError(error_message)
|
||||
elif response.status_code == 404:
|
||||
raise EnterpriseAPINotFoundError(error_message)
|
||||
else:
|
||||
raise EnterpriseAPIError(error_message, status_code=response.status_code)
|
||||
|
||||
|
||||
class EnterpriseRequest(BaseRequest):
|
||||
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
|
||||
|
||||
@@ -1,15 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
|
||||
# License status cache configuration
|
||||
LICENSE_STATUS_CACHE_KEY = "enterprise:license:status"
|
||||
VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable
|
||||
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
|
||||
|
||||
|
||||
class WebAppSettings(BaseModel):
|
||||
@@ -52,7 +63,7 @@ class DefaultWorkspaceJoinResult(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult":
|
||||
def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult:
|
||||
if self.joined and not self.workspace_id:
|
||||
raise ValueError("workspace_id must be non-empty when joined is True")
|
||||
return self
|
||||
@@ -115,7 +126,6 @@ class EnterpriseService:
|
||||
"/default-workspace/members",
|
||||
json={"account_id": account_id},
|
||||
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
|
||||
raise_for_status=True,
|
||||
)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("Invalid response format from enterprise default workspace API")
|
||||
@@ -223,3 +233,64 @@ class EnterpriseService:
|
||||
|
||||
params = {"appId": app_id}
|
||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
||||
|
||||
@classmethod
|
||||
def get_cached_license_status(cls) -> LicenseStatus | None:
|
||||
"""Get enterprise license status with Redis caching to reduce HTTP calls.
|
||||
|
||||
Caches valid statuses (active/expiring) for 10 minutes and invalid statuses
|
||||
(inactive/expired/lost) for 30 seconds. The shorter TTL for invalid statuses
|
||||
balances prompt license-fix detection against DoS mitigation — without
|
||||
caching, every request on an expired license would hit the enterprise API.
|
||||
|
||||
Returns:
|
||||
LicenseStatus enum value, or None if enterprise is disabled / unreachable.
|
||||
"""
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
return None
|
||||
|
||||
cached = cls._read_cached_license_status()
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
return cls._fetch_and_cache_license_status()
|
||||
|
||||
@classmethod
|
||||
def _read_cached_license_status(cls) -> LicenseStatus | None:
|
||||
"""Read license status from Redis cache, returning None on miss or failure."""
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
try:
|
||||
raw = redis_client.get(LICENSE_STATUS_CACHE_KEY)
|
||||
if raw:
|
||||
value = raw.decode("utf-8") if isinstance(raw, bytes) else raw
|
||||
return LicenseStatus(value)
|
||||
except Exception:
|
||||
logger.warning("Failed to read license status from cache", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _fetch_and_cache_license_status(cls) -> LicenseStatus | None:
|
||||
"""Fetch license status from enterprise API and cache the result."""
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
try:
|
||||
info = cls.get_info()
|
||||
license_info = info.get("License")
|
||||
if not license_info:
|
||||
return None
|
||||
|
||||
status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
|
||||
ttl = (
|
||||
VALID_LICENSE_CACHE_TTL
|
||||
if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING)
|
||||
else INVALID_LICENSE_CACHE_TTL
|
||||
)
|
||||
try:
|
||||
redis_client.setex(LICENSE_STATUS_CACHE_KEY, ttl, status)
|
||||
except Exception:
|
||||
logger.warning("Failed to cache license status", exc_info=True)
|
||||
return status
|
||||
except Exception:
|
||||
logger.exception("Failed to get enterprise license status")
|
||||
return None
|
||||
|
||||
@@ -7,6 +7,7 @@ from . import (
|
||||
conversation,
|
||||
dataset,
|
||||
document,
|
||||
enterprise,
|
||||
file,
|
||||
index,
|
||||
message,
|
||||
@@ -21,6 +22,7 @@ __all__ = [
|
||||
"conversation",
|
||||
"dataset",
|
||||
"document",
|
||||
"enterprise",
|
||||
"file",
|
||||
"index",
|
||||
"message",
|
||||
|
||||
45
api/services/errors/enterprise.py
Normal file
45
api/services/errors/enterprise.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Enterprise service errors."""
|
||||
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class EnterpriseServiceError(BaseServiceError):
|
||||
"""Base exception for enterprise service errors."""
|
||||
|
||||
def __init__(self, description: str | None = None, status_code: int | None = None):
|
||||
super().__init__(description)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class EnterpriseAPIError(EnterpriseServiceError):
|
||||
"""Generic enterprise API error (non-2xx response)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class EnterpriseAPINotFoundError(EnterpriseServiceError):
|
||||
"""Enterprise API returned 404 Not Found."""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
super().__init__(description, status_code=404)
|
||||
|
||||
|
||||
class EnterpriseAPIForbiddenError(EnterpriseServiceError):
|
||||
"""Enterprise API returned 403 Forbidden."""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
super().__init__(description, status_code=403)
|
||||
|
||||
|
||||
class EnterpriseAPIUnauthorizedError(EnterpriseServiceError):
|
||||
"""Enterprise API returned 401 Unauthorized."""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
super().__init__(description, status_code=401)
|
||||
|
||||
|
||||
class EnterpriseAPIBadRequestError(EnterpriseServiceError):
|
||||
"""Enterprise API returned 400 Bad Request."""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
super().__init__(description, status_code=400)
|
||||
@@ -379,14 +379,19 @@ class FeatureService:
|
||||
)
|
||||
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
|
||||
|
||||
if is_authenticated and (license_info := enterprise_info.get("License")):
|
||||
# SECURITY NOTE: Only license *status* is exposed to unauthenticated callers
|
||||
# so the login page can detect an expired/inactive license after force-logout.
|
||||
# All other license details (expiry date, workspace usage) remain auth-gated.
|
||||
# This behavior reflects prior internal review of information-leakage risks.
|
||||
if license_info := enterprise_info.get("License"):
|
||||
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
|
||||
features.license.expired_at = license_info.get("expiredAt", "")
|
||||
|
||||
if workspaces_info := license_info.get("workspaces"):
|
||||
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
|
||||
features.license.workspaces.limit = workspaces_info.get("limit", 0)
|
||||
features.license.workspaces.size = workspaces_info.get("used", 0)
|
||||
if is_authenticated:
|
||||
features.license.expired_at = license_info.get("expiredAt", "")
|
||||
if workspaces_info := license_info.get("workspaces"):
|
||||
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
|
||||
features.license.workspaces.limit = workspaces_info.get("limit", 0)
|
||||
features.license.workspaces.size = workspaces_info.get("used", 0)
|
||||
|
||||
if "PluginInstallationPermission" in enterprise_info:
|
||||
plugin_installation_info = enterprise_info["PluginInstallationPermission"]
|
||||
|
||||
@@ -358,10 +358,9 @@ class TestFeatureService:
|
||||
assert result is not None
|
||||
assert isinstance(result, SystemFeatureModel)
|
||||
|
||||
# --- 1. Verify Response Payload Optimization (Data Minimization) ---
|
||||
# Ensure only essential UI flags are returned to unauthenticated clients
|
||||
# to keep the payload lightweight and adhere to architectural boundaries.
|
||||
assert result.license.status == LicenseStatus.NONE
|
||||
# --- 1. Verify only license *status* is exposed to unauthenticated clients ---
|
||||
# Detailed license info (expiry, workspaces) remains auth-gated.
|
||||
assert result.license.status == LicenseStatus.ACTIVE
|
||||
assert result.license.expired_at == ""
|
||||
assert result.license.workspaces.enabled is False
|
||||
assert result.license.workspaces.limit == 0
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""Unit tests for enterprise service integrations.
|
||||
|
||||
This module covers the enterprise-only default workspace auto-join behavior:
|
||||
- Enterprise mode disabled: no external calls
|
||||
- Successful join / skipped join: no errors
|
||||
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
|
||||
Covers:
|
||||
- Default workspace auto-join behavior
|
||||
- License status caching (get_cached_license_status)
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
@@ -11,6 +10,9 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from services.enterprise.enterprise_service import (
|
||||
INVALID_LICENSE_CACHE_TTL,
|
||||
LICENSE_STATUS_CACHE_KEY,
|
||||
VALID_LICENSE_CACHE_TTL,
|
||||
DefaultWorkspaceJoinResult,
|
||||
EnterpriseService,
|
||||
try_join_default_workspace,
|
||||
@@ -37,7 +39,6 @@ class TestJoinDefaultWorkspace:
|
||||
"/default-workspace/members",
|
||||
json={"account_id": account_id},
|
||||
timeout=1.0,
|
||||
raise_for_status=True,
|
||||
)
|
||||
|
||||
def test_join_default_workspace_invalid_response_format_raises(self):
|
||||
@@ -139,3 +140,134 @@ class TestTryJoinDefaultWorkspace:
|
||||
|
||||
# Should not raise even though UUID parsing fails inside join_default_workspace
|
||||
try_join_default_workspace("not-a-uuid")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_cached_license_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EE_SVC = "services.enterprise.enterprise_service"
|
||||
|
||||
|
||||
class TestGetCachedLicenseStatus:
|
||||
"""Tests for EnterpriseService.get_cached_license_status."""
|
||||
|
||||
def test_returns_none_when_enterprise_disabled(self):
|
||||
with patch(f"{_EE_SVC}.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = False
|
||||
|
||||
assert EnterpriseService.get_cached_license_status() is None
|
||||
|
||||
def test_cache_hit_returns_license_status_enum(self):
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
with (
|
||||
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_redis.get.return_value = b"active"
|
||||
|
||||
result = EnterpriseService.get_cached_license_status()
|
||||
|
||||
assert result == LicenseStatus.ACTIVE
|
||||
assert isinstance(result, LicenseStatus)
|
||||
mock_get_info.assert_not_called()
|
||||
|
||||
def test_cache_miss_fetches_api_and_caches_valid_status(self):
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
with (
|
||||
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_redis.get.return_value = None
|
||||
mock_get_info.return_value = {"License": {"status": "active"}}
|
||||
|
||||
result = EnterpriseService.get_cached_license_status()
|
||||
|
||||
assert result == LicenseStatus.ACTIVE
|
||||
mock_redis.setex.assert_called_once_with(
|
||||
LICENSE_STATUS_CACHE_KEY, VALID_LICENSE_CACHE_TTL, LicenseStatus.ACTIVE
|
||||
)
|
||||
|
||||
def test_cache_miss_fetches_api_and_caches_invalid_status_with_short_ttl(self):
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
with (
|
||||
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_redis.get.return_value = None
|
||||
mock_get_info.return_value = {"License": {"status": "expired"}}
|
||||
|
||||
result = EnterpriseService.get_cached_license_status()
|
||||
|
||||
assert result == LicenseStatus.EXPIRED
|
||||
mock_redis.setex.assert_called_once_with(
|
||||
LICENSE_STATUS_CACHE_KEY, INVALID_LICENSE_CACHE_TTL, LicenseStatus.EXPIRED
|
||||
)
|
||||
|
||||
def test_redis_read_failure_falls_through_to_api(self):
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
with (
|
||||
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_redis.get.side_effect = ConnectionError("redis down")
|
||||
mock_get_info.return_value = {"License": {"status": "active"}}
|
||||
|
||||
result = EnterpriseService.get_cached_license_status()
|
||||
|
||||
assert result == LicenseStatus.ACTIVE
|
||||
mock_get_info.assert_called_once()
|
||||
|
||||
def test_redis_write_failure_still_returns_status(self):
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
with (
|
||||
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.setex.side_effect = ConnectionError("redis down")
|
||||
mock_get_info.return_value = {"License": {"status": "expiring"}}
|
||||
|
||||
result = EnterpriseService.get_cached_license_status()
|
||||
|
||||
assert result == LicenseStatus.EXPIRING
|
||||
|
||||
def test_api_failure_returns_none(self):
|
||||
with (
|
||||
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_redis.get.return_value = None
|
||||
mock_get_info.side_effect = Exception("network failure")
|
||||
|
||||
assert EnterpriseService.get_cached_license_status() is None
|
||||
|
||||
def test_api_returns_no_license_info(self):
|
||||
with (
|
||||
patch(f"{_EE_SVC}.dify_config") as mock_config,
|
||||
patch(f"{_EE_SVC}.redis_client") as mock_redis,
|
||||
patch.object(EnterpriseService, "get_info") as mock_get_info,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_redis.get.return_value = None
|
||||
mock_get_info.return_value = {} # no "License" key
|
||||
|
||||
assert EnterpriseService.get_cached_license_status() is None
|
||||
mock_redis.setex.assert_not_called()
|
||||
|
||||
1992
api/uv.lock
generated
1992
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user