mirror of
https://github.com/langgenius/dify.git
synced 2026-03-20 15:17:02 +00:00
Compare commits
1 Commits
1.13.2
...
3-18-dev-w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dacba93e00 |
@@ -78,7 +78,7 @@ class UserProfile(TypedDict):
|
||||
nickname: NotRequired[str]
|
||||
```
|
||||
|
||||
- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance:
|
||||
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
|
||||
|
||||
```python
|
||||
from datetime import datetime
|
||||
|
||||
12
api/configs/middleware/cache/redis_config.py
vendored
12
api/configs/middleware/cache/redis_config.py
vendored
@@ -1,4 +1,4 @@
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@@ -116,13 +116,3 @@ class RedisConfig(BaseSettings):
|
||||
description="Maximum connections in the Redis connection pool (unset for library default)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
|
||||
@classmethod
|
||||
def _empty_string_to_none_for_max_conns(cls, v):
|
||||
"""Allow empty string in env/.env to mean 'unset' (None)."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str) and v.strip() == "":
|
||||
return None
|
||||
return v
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Literal, Protocol, cast
|
||||
from typing import Literal, Protocol
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
@@ -12,13 +12,16 @@ class RedisConfigDefaults(Protocol):
|
||||
REDIS_PASSWORD: str | None
|
||||
REDIS_DB: int
|
||||
REDIS_USE_SSL: bool
|
||||
REDIS_USE_SENTINEL: bool | None
|
||||
REDIS_USE_CLUSTERS: bool
|
||||
|
||||
|
||||
def _redis_defaults(config: object) -> RedisConfigDefaults:
|
||||
return cast(RedisConfigDefaults, config)
|
||||
class RedisConfigDefaultsMixin:
|
||||
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
|
||||
return self
|
||||
|
||||
|
||||
class RedisPubSubConfig(BaseSettings):
|
||||
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
"""
|
||||
Configuration settings for event transport between API and workers.
|
||||
|
||||
@@ -71,7 +74,7 @@ class RedisPubSubConfig(BaseSettings):
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = _redis_defaults(self)
|
||||
defaults = self._redis_defaults()
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
|
||||
|
||||
@@ -88,9 +91,11 @@ class RedisPubSubConfig(BaseSettings):
|
||||
if userinfo:
|
||||
userinfo = f"{userinfo}@"
|
||||
|
||||
host = defaults.REDIS_HOST
|
||||
port = defaults.REDIS_PORT
|
||||
db = defaults.REDIS_DB
|
||||
|
||||
netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}"
|
||||
netloc = f"{userinfo}{host}:{port}"
|
||||
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
|
||||
|
||||
@property
|
||||
|
||||
@@ -473,21 +473,9 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
|
||||
else:
|
||||
# some historical data may have a provider record but not be set as valid
|
||||
provider_record.is_valid = True
|
||||
|
||||
if provider_record.credential_id is None:
|
||||
provider_record.credential_id = new_record.id
|
||||
provider_record.updated_at = naive_utc_now()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
|
||||
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
|
||||
@@ -196,8 +196,6 @@ class ProviderManager:
|
||||
|
||||
if preferred_provider_type_record:
|
||||
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
|
||||
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
|
||||
preferred_provider_type = ProviderType.SYSTEM
|
||||
elif custom_configuration.provider or custom_configuration.models:
|
||||
preferred_provider_type = ProviderType.CUSTOM
|
||||
elif system_configuration.enabled:
|
||||
|
||||
@@ -5,7 +5,6 @@ This module provides integration with Weaviate vector database for storing and r
|
||||
document embeddings used in retrieval-augmented generation workflows.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
@@ -38,32 +37,6 @@ _weaviate_client: weaviate.WeaviateClient | None = None
|
||||
_weaviate_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def _shutdown_weaviate_client() -> None:
|
||||
"""
|
||||
Best-effort shutdown hook to close the module-level Weaviate client.
|
||||
|
||||
This is registered with atexit so that HTTP/gRPC resources are released
|
||||
when the Python interpreter exits.
|
||||
"""
|
||||
global _weaviate_client
|
||||
|
||||
# Ensure thread-safety when accessing the shared client instance
|
||||
with _weaviate_client_lock:
|
||||
client = _weaviate_client
|
||||
_weaviate_client = None
|
||||
|
||||
if client is not None:
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
# Best-effort cleanup; log at debug level and ignore errors.
|
||||
logger.debug("Failed to close Weaviate client during shutdown", exc_info=True)
|
||||
|
||||
|
||||
# Register the shutdown hook once per process.
|
||||
atexit.register(_shutdown_weaviate_client)
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
"""
|
||||
Configuration model for Weaviate connection settings.
|
||||
@@ -112,6 +85,18 @@ class WeaviateVector(BaseVector):
|
||||
self._client = self._init_client(config)
|
||||
self._attributes = attributes
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Destructor to properly close the Weaviate client connection.
|
||||
Prevents connection leaks and resource warnings.
|
||||
"""
|
||||
if hasattr(self, "_client") and self._client is not None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception as e:
|
||||
# Ignore errors during cleanup as object is being destroyed
|
||||
logger.warning("Error closing Weaviate client %s", e, exc_info=True)
|
||||
|
||||
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
|
||||
"""
|
||||
Initializes and returns a connected Weaviate client.
|
||||
|
||||
@@ -101,6 +101,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
timeout=self._get_request_timeout(self.node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
http_request_config=self._http_request_config,
|
||||
max_retries=0,
|
||||
ssl_verify=self.node_data.ssl_verify,
|
||||
http_client=self._http_client,
|
||||
file_manager=self._file_manager,
|
||||
|
||||
@@ -256,13 +256,9 @@ def fetch_prompt_messages(
|
||||
):
|
||||
continue
|
||||
prompt_message_content.append(content_item)
|
||||
if not prompt_message_content:
|
||||
continue
|
||||
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
|
||||
prompt_message.content = prompt_message_content[0].data
|
||||
else:
|
||||
if prompt_message_content:
|
||||
prompt_message.content = prompt_message_content
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
elif not prompt_message.is_empty():
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from dify_graph.file.models import File
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.variables.segments import Segment
|
||||
pass
|
||||
|
||||
|
||||
class ArrayValidation(StrEnum):
|
||||
@@ -219,7 +219,7 @@ class SegmentType(StrEnum):
|
||||
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
||||
|
||||
@staticmethod
|
||||
def get_zero_value(t: SegmentType) -> Segment:
|
||||
def get_zero_value(t: SegmentType):
|
||||
# Lazy import to avoid circular dependency
|
||||
from factories import variable_factory
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Protocol, cast
|
||||
|
||||
from fastopenapi.routers import FlaskRouter
|
||||
from flask_cors import CORS
|
||||
|
||||
@@ -11,10 +9,6 @@ from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
|
||||
DOCS_PREFIX = "/fastopenapi"
|
||||
|
||||
|
||||
class SupportsIncludeRouter(Protocol):
|
||||
def include_router(self, router: object, *, prefix: str = "") -> None: ...
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> None:
|
||||
docs_enabled = dify_config.SWAGGER_UI_ENABLED
|
||||
docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
|
||||
@@ -42,7 +36,7 @@ def init_app(app: DifyApp) -> None:
|
||||
_ = remote_files
|
||||
_ = setup
|
||||
|
||||
cast(SupportsIncludeRouter, router).include_router(console_router, prefix="/console/api")
|
||||
router.include_router(console_router, prefix="/console/api")
|
||||
CORS(
|
||||
app,
|
||||
resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
||||
|
||||
@@ -55,7 +55,7 @@ class TypeMismatchError(Exception):
|
||||
|
||||
|
||||
# Define the constant
|
||||
SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = {
|
||||
SEGMENT_TO_VARIABLE_MAP = {
|
||||
ArrayAnySegment: ArrayAnyVariable,
|
||||
ArrayBooleanSegment: ArrayBooleanVariable,
|
||||
ArrayFileSegment: ArrayFileVariable,
|
||||
@@ -296,11 +296,13 @@ def segment_to_variable(
|
||||
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
|
||||
|
||||
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
|
||||
return variable_class(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
value_type=segment.value_type,
|
||||
value=segment.value,
|
||||
selector=list(selector),
|
||||
return cast(
|
||||
VariableBase,
|
||||
variable_class(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
value=segment.value,
|
||||
selector=list(selector),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -32,11 +32,6 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _stream_with_request_context(response: object) -> Any:
|
||||
"""Bridge Flask's loosely-typed streaming helper without leaking casts into callers."""
|
||||
return cast(Any, stream_with_context)(response)
|
||||
|
||||
|
||||
def escape_like_pattern(pattern: str) -> str:
|
||||
"""
|
||||
Escape special characters in a string for safe use in SQL LIKE patterns.
|
||||
@@ -291,32 +286,22 @@ def generate_text_hash(text: str) -> str:
|
||||
return sha256(hash_text.encode()).hexdigest()
|
||||
|
||||
|
||||
def compact_generate_response(
|
||||
response: Mapping[str, Any] | Generator[str, None, None] | RateLimitGenerator,
|
||||
) -> Response:
|
||||
if isinstance(response, Mapping):
|
||||
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(
|
||||
response=json.dumps(jsonable_encoder(response)),
|
||||
status=200,
|
||||
content_type="application/json; charset=utf-8",
|
||||
)
|
||||
else:
|
||||
stream_response = response
|
||||
|
||||
def generate() -> Generator[str, None, None]:
|
||||
yield from stream_response
|
||||
def generate() -> Generator:
|
||||
yield from response
|
||||
|
||||
return Response(
|
||||
_stream_with_request_context(generate()),
|
||||
status=200,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
|
||||
|
||||
|
||||
def length_prefixed_response(
|
||||
magic_number: int,
|
||||
response: Mapping[str, Any] | BaseModel | Generator[str | bytes, None, None] | RateLimitGenerator,
|
||||
) -> Response:
|
||||
def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
|
||||
"""
|
||||
This function is used to return a response with a length prefix.
|
||||
Magic number is a one byte number that indicates the type of the response.
|
||||
@@ -347,7 +332,7 @@ def length_prefixed_response(
|
||||
# | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
|
||||
return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response
|
||||
|
||||
if isinstance(response, Mapping):
|
||||
if isinstance(response, dict):
|
||||
return Response(
|
||||
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
|
||||
status=200,
|
||||
@@ -360,20 +345,14 @@ def length_prefixed_response(
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
stream_response = response
|
||||
|
||||
def generate() -> Generator[bytes, None, None]:
|
||||
for chunk in stream_response:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
if isinstance(chunk, str):
|
||||
yield pack_response_with_length_prefix(chunk.encode("utf-8"))
|
||||
else:
|
||||
yield pack_response_with_length_prefix(chunk)
|
||||
|
||||
return Response(
|
||||
_stream_with_request_context(generate()),
|
||||
status=200,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
|
||||
|
||||
|
||||
class TokenManager:
|
||||
|
||||
@@ -77,14 +77,12 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
|
||||
@wraps(func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
|
||||
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
|
||||
user = _get_user()
|
||||
if user is None or not user.is_authenticated:
|
||||
pass
|
||||
elif current_user is not None and not current_user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized() # type: ignore
|
||||
# we put csrf validation here for less conflicts
|
||||
# TODO: maybe find a better place for it.
|
||||
check_csrf_token(request, user.id)
|
||||
check_csrf_token(request, current_user.id)
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
@@ -7,10 +7,9 @@ https://github.com/django/django/blob/main/django/utils/module_loading.py
|
||||
|
||||
import sys
|
||||
from importlib import import_module
|
||||
from typing import Any
|
||||
|
||||
|
||||
def cached_import(module_path: str, class_name: str) -> Any:
|
||||
def cached_import(module_path: str, class_name: str):
|
||||
"""
|
||||
Import a module and return the named attribute/class from it, with caching.
|
||||
|
||||
@@ -21,14 +20,16 @@ def cached_import(module_path: str, class_name: str) -> Any:
|
||||
Returns:
|
||||
The imported attribute/class
|
||||
"""
|
||||
module = sys.modules.get(module_path)
|
||||
spec = getattr(module, "__spec__", None) if module is not None else None
|
||||
if module is None or getattr(spec, "_initializing", False):
|
||||
if not (
|
||||
(module := sys.modules.get(module_path))
|
||||
and (spec := getattr(module, "__spec__", None))
|
||||
and getattr(spec, "_initializing", False) is False
|
||||
):
|
||||
module = import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
def import_string(dotted_path: str) -> Any:
|
||||
def import_string(dotted_path: str):
|
||||
"""
|
||||
Import a dotted module path and return the attribute/class designated by
|
||||
the last name in the path. Raise ImportError if the import failed.
|
||||
|
||||
@@ -1,48 +1,7 @@
|
||||
import sys
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from typing import NotRequired
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
JsonObject = dict[str, object]
|
||||
JsonObjectList = list[JsonObject]
|
||||
|
||||
JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
|
||||
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
|
||||
|
||||
|
||||
class AccessTokenResponse(TypedDict, total=False):
|
||||
access_token: str
|
||||
|
||||
|
||||
class GitHubEmailRecord(TypedDict, total=False):
|
||||
email: str
|
||||
primary: bool
|
||||
|
||||
|
||||
class GitHubRawUserInfo(TypedDict):
|
||||
id: int | str
|
||||
login: str
|
||||
name: NotRequired[str]
|
||||
email: NotRequired[str]
|
||||
|
||||
|
||||
class GoogleRawUserInfo(TypedDict):
|
||||
sub: str
|
||||
email: str
|
||||
|
||||
|
||||
ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse)
|
||||
GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo)
|
||||
GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord])
|
||||
GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -52,38 +11,26 @@ class OAuthUserInfo:
|
||||
email: str
|
||||
|
||||
|
||||
def _json_object(response: httpx.Response) -> JsonObject:
|
||||
return JSON_OBJECT_ADAPTER.validate_python(response.json())
|
||||
|
||||
|
||||
def _json_list(response: httpx.Response) -> JsonObjectList:
|
||||
return JSON_OBJECT_LIST_ADAPTER.validate_python(response.json())
|
||||
|
||||
|
||||
class OAuth:
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
def get_authorization_url(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_access_token(self, code: str) -> str:
|
||||
def get_access_token(self, code: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
def get_raw_user_info(self, token: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_user_info(self, token: str) -> OAuthUserInfo:
|
||||
raw_info = self.get_raw_user_info(token)
|
||||
return self._transform_user_info(raw_info)
|
||||
|
||||
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -93,7 +40,7 @@ class GitHubOAuth(OAuth):
|
||||
_USER_INFO_URL = "https://api.github.com/user"
|
||||
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
def get_authorization_url(self, invite_token: str | None = None):
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
@@ -103,7 +50,7 @@ class GitHubOAuth(OAuth):
|
||||
params["state"] = invite_token
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str) -> str:
|
||||
def get_access_token(self, code: str):
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
@@ -113,7 +60,7 @@ class GitHubOAuth(OAuth):
|
||||
headers = {"Accept": "application/json"}
|
||||
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
|
||||
response_json = response.json()
|
||||
access_token = response_json.get("access_token")
|
||||
|
||||
if not access_token:
|
||||
@@ -121,24 +68,23 @@ class GitHubOAuth(OAuth):
|
||||
|
||||
return access_token
|
||||
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
def get_raw_user_info(self, token: str):
|
||||
headers = {"Authorization": f"token {token}"}
|
||||
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
|
||||
user_info = response.json()
|
||||
|
||||
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
|
||||
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
|
||||
primary_email = next((email for email in email_info if email.get("primary") is True), None)
|
||||
email_info = email_response.json()
|
||||
primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
|
||||
|
||||
return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
|
||||
return {**user_info, "email": primary_email.get("email", "")}
|
||||
|
||||
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||
email = payload.get("email")
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
email = raw_info.get("email")
|
||||
if not email:
|
||||
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
|
||||
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
|
||||
email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
|
||||
return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email)
|
||||
|
||||
|
||||
class GoogleOAuth(OAuth):
|
||||
@@ -146,7 +92,7 @@ class GoogleOAuth(OAuth):
|
||||
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
def get_authorization_url(self, invite_token: str | None = None):
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"response_type": "code",
|
||||
@@ -157,7 +103,7 @@ class GoogleOAuth(OAuth):
|
||||
params["state"] = invite_token
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str) -> str:
|
||||
def get_access_token(self, code: str):
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
@@ -168,7 +114,7 @@ class GoogleOAuth(OAuth):
|
||||
headers = {"Accept": "application/json"}
|
||||
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
|
||||
response_json = response.json()
|
||||
access_token = response_json.get("access_token")
|
||||
|
||||
if not access_token:
|
||||
@@ -176,12 +122,11 @@ class GoogleOAuth(OAuth):
|
||||
|
||||
return access_token
|
||||
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
def get_raw_user_info(self, token: str):
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
return _json_object(response)
|
||||
return response.json()
|
||||
|
||||
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||
payload = GOOGLE_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||
return OAuthUserInfo(id=str(payload["sub"]), name="", email=payload["email"])
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])
|
||||
|
||||
@@ -1,57 +1,25 @@
|
||||
import sys
|
||||
import urllib.parse
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from flask_login import current_user
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.source import DataSourceOauthBinding
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class NotionPageSummary(TypedDict):
|
||||
page_id: str
|
||||
page_name: str
|
||||
page_icon: dict[str, str] | None
|
||||
parent_id: str
|
||||
type: Literal["page", "database"]
|
||||
|
||||
|
||||
class NotionSourceInfo(TypedDict):
|
||||
workspace_name: str | None
|
||||
workspace_icon: str | None
|
||||
workspace_id: str | None
|
||||
pages: list[NotionPageSummary]
|
||||
total: int
|
||||
|
||||
|
||||
SOURCE_INFO_STORAGE_ADAPTER = TypeAdapter(dict[str, object])
|
||||
NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo)
|
||||
NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary)
|
||||
|
||||
|
||||
class OAuthDataSource:
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self) -> str:
|
||||
def get_authorization_url(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_access_token(self, code: str) -> None:
|
||||
def get_access_token(self, code: str):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -62,7 +30,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
_NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
|
||||
_NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
|
||||
|
||||
def get_authorization_url(self) -> str:
|
||||
def get_authorization_url(self):
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"response_type": "code",
|
||||
@@ -71,7 +39,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str) -> None:
|
||||
def get_access_token(self, code: str):
|
||||
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
|
||||
headers = {"Accept": "application/json"}
|
||||
auth = (self.client_id, self.client_secret)
|
||||
@@ -86,12 +54,13 @@ class NotionOAuth(OAuthDataSource):
|
||||
workspace_id = response_json.get("workspace_id")
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(access_token)
|
||||
source_info = self._build_source_info(
|
||||
workspace_name=workspace_name,
|
||||
workspace_icon=workspace_icon,
|
||||
workspace_id=workspace_id,
|
||||
pages=pages,
|
||||
)
|
||||
source_info = {
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
"workspace_id": workspace_id,
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@@ -101,7 +70,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
)
|
||||
)
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.source_info = source_info
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
@@ -109,24 +78,25 @@ class NotionOAuth(OAuthDataSource):
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
source_info=source_info,
|
||||
provider="notion",
|
||||
)
|
||||
db.session.add(new_data_source_binding)
|
||||
db.session.commit()
|
||||
|
||||
def save_internal_access_token(self, access_token: str) -> None:
|
||||
def save_internal_access_token(self, access_token: str):
|
||||
workspace_name = self.notion_workspace_name(access_token)
|
||||
workspace_icon = None
|
||||
workspace_id = current_user.current_tenant_id
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(access_token)
|
||||
source_info = self._build_source_info(
|
||||
workspace_name=workspace_name,
|
||||
workspace_icon=workspace_icon,
|
||||
workspace_id=workspace_id,
|
||||
pages=pages,
|
||||
)
|
||||
source_info = {
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
"workspace_id": workspace_id,
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@@ -136,7 +106,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
)
|
||||
)
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.source_info = source_info
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
@@ -144,13 +114,13 @@ class NotionOAuth(OAuthDataSource):
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
source_info=source_info,
|
||||
provider="notion",
|
||||
)
|
||||
db.session.add(new_data_source_binding)
|
||||
db.session.commit()
|
||||
|
||||
def sync_data_source(self, binding_id: str) -> None:
|
||||
def sync_data_source(self, binding_id: str):
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@@ -164,22 +134,23 @@ class NotionOAuth(OAuthDataSource):
|
||||
if data_source_binding:
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(data_source_binding.access_token)
|
||||
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
|
||||
new_source_info = self._build_source_info(
|
||||
workspace_name=source_info["workspace_name"],
|
||||
workspace_icon=source_info["workspace_icon"],
|
||||
workspace_id=source_info["workspace_id"],
|
||||
pages=pages,
|
||||
)
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
|
||||
source_info = data_source_binding.source_info
|
||||
new_source_info = {
|
||||
"workspace_name": source_info["workspace_name"],
|
||||
"workspace_icon": source_info["workspace_icon"],
|
||||
"workspace_id": source_info["workspace_id"],
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
data_source_binding.source_info = new_source_info
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source binding not found")
|
||||
|
||||
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
|
||||
pages: list[NotionPageSummary] = []
|
||||
def get_authorized_pages(self, access_token: str):
|
||||
pages = []
|
||||
page_results = self.notion_page_search(access_token)
|
||||
database_results = self.notion_database_search(access_token)
|
||||
# get page detail
|
||||
@@ -216,7 +187,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
"parent_id": parent_id,
|
||||
"type": "page",
|
||||
}
|
||||
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
|
||||
pages.append(page)
|
||||
# get database detail
|
||||
for database_result in database_results:
|
||||
page_id = database_result["id"]
|
||||
@@ -249,11 +220,11 @@ class NotionOAuth(OAuthDataSource):
|
||||
"parent_id": parent_id,
|
||||
"type": "database",
|
||||
}
|
||||
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
|
||||
pages.append(page)
|
||||
return pages
|
||||
|
||||
def notion_page_search(self, access_token: str) -> list[dict[str, Any]]:
|
||||
results: list[dict[str, Any]] = []
|
||||
def notion_page_search(self, access_token: str):
|
||||
results = []
|
||||
next_cursor = None
|
||||
has_more = True
|
||||
|
||||
@@ -278,7 +249,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
|
||||
return results
|
||||
|
||||
def notion_block_parent_page_id(self, access_token: str, block_id: str) -> str:
|
||||
def notion_block_parent_page_id(self, access_token: str, block_id: str):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Notion-Version": "2022-06-28",
|
||||
@@ -294,7 +265,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
return self.notion_block_parent_page_id(access_token, parent[parent_type])
|
||||
return parent[parent_type]
|
||||
|
||||
def notion_workspace_name(self, access_token: str) -> str:
|
||||
def notion_workspace_name(self, access_token: str):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Notion-Version": "2022-06-28",
|
||||
@@ -308,8 +279,8 @@ class NotionOAuth(OAuthDataSource):
|
||||
return user_info["workspace_name"]
|
||||
return "workspace"
|
||||
|
||||
def notion_database_search(self, access_token: str) -> list[dict[str, Any]]:
|
||||
results: list[dict[str, Any]] = []
|
||||
def notion_database_search(self, access_token: str):
|
||||
results = []
|
||||
next_cursor = None
|
||||
has_more = True
|
||||
|
||||
@@ -332,19 +303,3 @@ class NotionOAuth(OAuthDataSource):
|
||||
next_cursor = response_json.get("next_cursor", None)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _build_source_info(
|
||||
*,
|
||||
workspace_name: str | None,
|
||||
workspace_icon: str | None,
|
||||
workspace_id: str | None,
|
||||
pages: list[NotionPageSummary],
|
||||
) -> NotionSourceInfo:
|
||||
return {
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
"workspace_id": workspace_id,
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
|
||||
@@ -11,13 +11,6 @@ class CreatorUserRole(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if value == "end-user":
|
||||
return cls.END_USER
|
||||
else:
|
||||
return super()._missing_(value)
|
||||
|
||||
|
||||
class WorkflowRunTriggeredFrom(StrEnum):
|
||||
DEBUGGING = "debugging"
|
||||
|
||||
@@ -23,9 +23,6 @@ from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTr
|
||||
from .model import Account
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
TriggerJsonObject = dict[str, object]
|
||||
TriggerCredentials = dict[str, str]
|
||||
|
||||
|
||||
class WorkflowTriggerLogDict(TypedDict):
|
||||
id: str
|
||||
@@ -92,14 +89,10 @@ class TriggerSubscription(TypeBase):
|
||||
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
|
||||
)
|
||||
endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
|
||||
parameters: Mapped[TriggerJsonObject] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription parameters JSON"
|
||||
)
|
||||
properties: Mapped[TriggerJsonObject] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription properties JSON"
|
||||
)
|
||||
parameters: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
|
||||
properties: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
|
||||
|
||||
credentials: Mapped[TriggerCredentials] = mapped_column(
|
||||
credentials: Mapped[dict[str, Any]] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription credentials JSON"
|
||||
)
|
||||
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
|
||||
@@ -207,8 +200,8 @@ class TriggerOAuthTenantClient(TypeBase):
|
||||
)
|
||||
|
||||
@property
|
||||
def oauth_params(self) -> Mapping[str, object]:
|
||||
return cast(TriggerJsonObject, json.loads(self.encrypted_oauth_params or "{}"))
|
||||
def oauth_params(self) -> Mapping[str, Any]:
|
||||
return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
|
||||
|
||||
|
||||
class WorkflowTriggerLog(TypeBase):
|
||||
|
||||
@@ -19,7 +19,7 @@ from sqlalchemy import (
|
||||
orm,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||
@@ -33,7 +33,7 @@ from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
|
||||
from dify_graph.file.constants import maybe_file_object
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.variables import utils as variable_utils
|
||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
|
||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable
|
||||
from extensions.ext_storage import Storage
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -59,9 +59,6 @@ from .types import EnumText, LongText, StringUUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SerializedWorkflowValue = dict[str, Any]
|
||||
SerializedWorkflowVariables = dict[str, SerializedWorkflowValue]
|
||||
|
||||
|
||||
class WorkflowContentDict(TypedDict):
|
||||
graph: Mapping[str, Any]
|
||||
@@ -408,7 +405,7 @@ class Workflow(Base): # bug
|
||||
|
||||
def rag_pipeline_user_input_form(self) -> list:
|
||||
# get user_input_form from start node
|
||||
variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables
|
||||
variables: list[Any] = self.rag_pipeline_variables
|
||||
|
||||
return variables
|
||||
|
||||
@@ -451,13 +448,17 @@ class Workflow(Base): # bug
|
||||
def environment_variables(
|
||||
self,
|
||||
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||
# TODO: find some way to init `self._environment_variables` when instance created.
|
||||
if self._environment_variables is None:
|
||||
self._environment_variables = "{}"
|
||||
|
||||
# Use workflow.tenant_id to avoid relying on request user in background threads
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
if not tenant_id:
|
||||
return []
|
||||
|
||||
environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}"))
|
||||
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
|
||||
results = [
|
||||
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
|
||||
]
|
||||
@@ -535,7 +536,11 @@ class Workflow(Base): # bug
|
||||
|
||||
@property
|
||||
def conversation_variables(self) -> Sequence[VariableBase]:
|
||||
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}"))
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._conversation_variables is None:
|
||||
self._conversation_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
|
||||
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
|
||||
return results
|
||||
|
||||
@@ -547,20 +552,19 @@ class Workflow(Base): # bug
|
||||
)
|
||||
|
||||
@property
|
||||
def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]:
|
||||
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}"))
|
||||
return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()]
|
||||
def rag_pipeline_variables(self) -> list[dict]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._rag_pipeline_variables is None:
|
||||
self._rag_pipeline_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
|
||||
results = list(variables_dict.values())
|
||||
return results
|
||||
|
||||
@rag_pipeline_variables.setter
|
||||
def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None:
|
||||
def rag_pipeline_variables(self, values: list[dict]) -> None:
|
||||
self._rag_pipeline_variables = json.dumps(
|
||||
{
|
||||
rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json")
|
||||
for rag_pipeline_variable in (
|
||||
item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item)
|
||||
for item in values
|
||||
)
|
||||
},
|
||||
{item["variable"]: item for item in values},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -798,36 +802,44 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
|
||||
__tablename__ = "workflow_node_executions"
|
||||
|
||||
__table_args__ = (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
"workflow_node_execution_workflow_run_id_idx",
|
||||
"workflow_run_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_node_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_id_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
Index(
|
||||
None,
|
||||
"tenant_id",
|
||||
"workflow_id",
|
||||
"node_id",
|
||||
sa.desc("created_at"),
|
||||
),
|
||||
)
|
||||
@declared_attr.directive
|
||||
@classmethod
|
||||
def __table_args__(cls) -> Any:
|
||||
return (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
"workflow_node_execution_workflow_run_id_idx",
|
||||
"workflow_run_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_node_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_id_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
Index(
|
||||
# The first argument is the index name,
|
||||
# which we leave as `None`` to allow auto-generation by the ORM.
|
||||
None,
|
||||
cls.tenant_id,
|
||||
cls.workflow_id,
|
||||
cls.node_id,
|
||||
# MyPy may flag the following line because it doesn't recognize that
|
||||
# the `declared_attr` decorator passes the receiving class as the first
|
||||
# argument to this method, allowing us to reference class attributes.
|
||||
cls.created_at.desc(),
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.13.2"
|
||||
version = "1.13.0"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
configs/middleware/cache/redis_pubsub_config.py
|
||||
controllers/console/app/annotation.py
|
||||
controllers/console/app/app.py
|
||||
controllers/console/app/app_import.py
|
||||
@@ -137,6 +138,8 @@ dify_graph/nodes/trigger_webhook/node.py
|
||||
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
|
||||
dify_graph/nodes/variable_assigner/v1/node.py
|
||||
dify_graph/nodes/variable_assigner/v2/node.py
|
||||
dify_graph/variables/types.py
|
||||
extensions/ext_fastopenapi.py
|
||||
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
|
||||
extensions/otel/instrumentation.py
|
||||
extensions/otel/runtime.py
|
||||
@@ -153,7 +156,19 @@ extensions/storage/oracle_oci_storage.py
|
||||
extensions/storage/supabase_storage.py
|
||||
extensions/storage/tencent_cos_storage.py
|
||||
extensions/storage/volcengine_tos_storage.py
|
||||
factories/variable_factory.py
|
||||
libs/external_api.py
|
||||
libs/gmpy2_pkcs10aep_cipher.py
|
||||
libs/helper.py
|
||||
libs/login.py
|
||||
libs/module_loading.py
|
||||
libs/oauth.py
|
||||
libs/oauth_data_source.py
|
||||
models/trigger.py
|
||||
models/workflow.py
|
||||
repositories/sqlalchemy_api_workflow_node_execution_repository.py
|
||||
repositories/sqlalchemy_api_workflow_run_repository.py
|
||||
repositories/sqlalchemy_execution_extra_content_repository.py
|
||||
schedule/queue_monitor_task.py
|
||||
services/account_service.py
|
||||
services/audio_service.py
|
||||
|
||||
@@ -8,7 +8,7 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Protocol, cast
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import asc, delete, desc, func, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
@@ -22,20 +22,6 @@ from repositories.api_workflow_node_execution_repository import (
|
||||
)
|
||||
|
||||
|
||||
class _WorkflowNodeExecutionSnapshotRow(Protocol):
|
||||
id: str
|
||||
node_execution_id: str | None
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
index: int
|
||||
status: WorkflowNodeExecutionStatus
|
||||
elapsed_time: float | None
|
||||
created_at: datetime
|
||||
finished_at: datetime | None
|
||||
execution_metadata: str | None
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
|
||||
"""
|
||||
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
|
||||
@@ -54,8 +40,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
- Thread-safe database operations using session-per-request pattern
|
||||
"""
|
||||
|
||||
_session_maker: sessionmaker[Session]
|
||||
|
||||
def __init__(self, session_maker: sessionmaker[Session]):
|
||||
"""
|
||||
Initialize the repository with a sessionmaker.
|
||||
@@ -172,12 +156,12 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
rows = cast(Sequence[_WorkflowNodeExecutionSnapshotRow], session.execute(stmt).all())
|
||||
rows = session.execute(stmt).all()
|
||||
|
||||
return [self._row_to_snapshot(row) for row in rows]
|
||||
|
||||
@staticmethod
|
||||
def _row_to_snapshot(row: _WorkflowNodeExecutionSnapshotRow) -> WorkflowNodeExecutionSnapshot:
|
||||
def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot:
|
||||
metadata: dict[str, object] = {}
|
||||
execution_metadata = getattr(row, "execution_metadata", None)
|
||||
if execution_metadata:
|
||||
|
||||
@@ -30,7 +30,7 @@ from core.plugin.impl.debugging import PluginDebuggingClient
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
|
||||
from models.provider import Provider, ProviderCredential
|
||||
from models.provider_ids import GenericProviderID
|
||||
from services.enterprise.plugin_manager_service import (
|
||||
PluginManagerService,
|
||||
@@ -534,13 +534,6 @@ class PluginService:
|
||||
plugin_id = plugin.plugin_id
|
||||
logger.info("Deleting credentials for plugin: %s", plugin_id)
|
||||
|
||||
session.execute(
|
||||
delete(TenantPreferredModelProvider).where(
|
||||
TenantPreferredModelProvider.tenant_id == tenant_id,
|
||||
TenantPreferredModelProvider.provider_name.like(f"{plugin_id}/%"),
|
||||
)
|
||||
)
|
||||
|
||||
# Delete provider credentials that match this plugin
|
||||
credential_ids = session.scalars(
|
||||
select(ProviderCredential.id).where(
|
||||
|
||||
@@ -734,7 +734,7 @@ def test_create_provider_credential_creates_provider_record_when_missing() -> No
|
||||
def test_create_provider_credential_marks_existing_provider_as_valid() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
session = Mock()
|
||||
provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id="existing-cred")
|
||||
provider_record = SimpleNamespace(is_valid=False)
|
||||
|
||||
with _patched_session(session):
|
||||
with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False):
|
||||
@@ -743,25 +743,6 @@ def test_create_provider_credential_marks_existing_provider_as_valid() -> None:
|
||||
configuration.create_provider_credential({"api_key": "raw"}, "Main")
|
||||
|
||||
assert provider_record.is_valid is True
|
||||
assert provider_record.credential_id == "existing-cred"
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_create_provider_credential_auto_activates_when_no_active_credential() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
session = Mock()
|
||||
provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id=None, updated_at=None)
|
||||
|
||||
with _patched_session(session):
|
||||
with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False):
|
||||
with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}):
|
||||
with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record):
|
||||
with patch("core.entities.provider_configuration.ProviderCredentialsCache"):
|
||||
with patch.object(ProviderConfiguration, "switch_preferred_provider_type"):
|
||||
configuration.create_provider_credential({"api_key": "raw"}, "Main")
|
||||
|
||||
assert provider_record.is_valid is True
|
||||
assert provider_record.credential_id is not None
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent
|
||||
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage
|
||||
from dify_graph.nodes.llm import llm_utils
|
||||
from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage
|
||||
from dify_graph.nodes.llm.exc import NoPromptFoundError
|
||||
from dify_graph.runtime import VariablePool
|
||||
|
||||
|
||||
def _fetch_prompt_messages_with_mocked_content(content):
|
||||
variable_pool = VariablePool.empty()
|
||||
model_instance = mock.MagicMock(spec=ModelInstance)
|
||||
prompt_template = [
|
||||
LLMNodeChatModelMessage(
|
||||
text="You are a classifier.",
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
edition_type="basic",
|
||||
)
|
||||
]
|
||||
|
||||
with (
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.fetch_model_schema",
|
||||
return_value=mock.MagicMock(features=[]),
|
||||
),
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.handle_list_messages",
|
||||
return_value=[SystemPromptMessage(content=content)],
|
||||
),
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
return llm_utils.fetch_prompt_messages(
|
||||
sys_query=None,
|
||||
sys_files=[],
|
||||
context=None,
|
||||
memory=None,
|
||||
model_instance=model_instance,
|
||||
prompt_template=prompt_template,
|
||||
stop=["END"],
|
||||
memory_config=None,
|
||||
vision_enabled=False,
|
||||
vision_detail=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=[],
|
||||
template_renderer=None,
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out():
|
||||
with pytest.raises(NoPromptFoundError):
|
||||
_fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_flattens_single_text_content_after_filtering_unsupported_multimodal_items():
|
||||
prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
TextPromptMessageContent(data="You are a classifier."),
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert stop == ["END"]
|
||||
assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")]
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_remain():
|
||||
prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
TextPromptMessageContent(data="You are"),
|
||||
TextPromptMessageContent(data=" a classifier."),
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert stop == ["END"]
|
||||
assert prompt_messages == [
|
||||
SystemPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="You are"),
|
||||
TextPromptMessageContent(data=" a classifier."),
|
||||
]
|
||||
)
|
||||
]
|
||||
@@ -1,19 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
def test_creator_user_role_missing_maps_hyphen_to_enum():
|
||||
# given an alias with hyphen
|
||||
value = "end-user"
|
||||
|
||||
# when converting to enum (invokes StrEnum._missing_ override)
|
||||
role = CreatorUserRole(value)
|
||||
|
||||
# then it should map to END_USER
|
||||
assert role is CreatorUserRole.END_USER
|
||||
|
||||
|
||||
def test_creator_user_role_missing_raises_for_unknown():
|
||||
with pytest.raises(ValueError):
|
||||
CreatorUserRole("unknown")
|
||||
2
api/uv.lock
generated
2
api/uv.lock
generated
@@ -1533,7 +1533,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "dify-api"
|
||||
version = "1.13.2"
|
||||
version = "1.13.0"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "aliyun-log-python-sdk" },
|
||||
|
||||
@@ -21,7 +21,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.13.2
|
||||
image: langgenius/dify-api:1.13.0
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -63,7 +63,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.13.2
|
||||
image: langgenius/dify-api:1.13.0
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -102,7 +102,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.13.2
|
||||
image: langgenius/dify-api:1.13.0
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -132,7 +132,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.13.2
|
||||
image: langgenius/dify-web:1.13.0
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
||||
@@ -728,7 +728,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.13.2
|
||||
image: langgenius/dify-api:1.13.0
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -770,7 +770,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.13.2
|
||||
image: langgenius/dify-api:1.13.0
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -809,7 +809,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.13.2
|
||||
image: langgenius/dify-api:1.13.0
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -839,7 +839,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.13.2
|
||||
image: langgenius/dify-web:1.13.0
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
||||
@@ -12,7 +12,7 @@ NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api
|
||||
# console or api domain.
|
||||
# example: http://udify.app/api
|
||||
NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api
|
||||
# Dev-only Hono proxy targets. The frontend keeps requesting http://localhost:5001 directly.
|
||||
# Dev-only Hono proxy targets. Set the api prefixes above to https://localhost:5001/... to start the proxy with HTTPS.
|
||||
HONO_PROXY_HOST=127.0.0.1
|
||||
HONO_PROXY_PORT=5001
|
||||
HONO_CONSOLE_API_PROXY_TARGET=
|
||||
|
||||
@@ -328,7 +328,7 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
const setupHandlers = (overrides: { isPublicAPI?: boolean, isTimedOut?: () => boolean } = {}) => {
|
||||
const setupHandlers = (overrides: { isTimedOut?: () => boolean } = {}) => {
|
||||
let completionRes = ''
|
||||
let currentTaskId: string | null = null
|
||||
let isStopping = false
|
||||
@@ -359,7 +359,6 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
const handlers = createWorkflowStreamHandlers({
|
||||
getCompletionRes: () => completionRes,
|
||||
getWorkflowProcessData: () => workflowProcessData,
|
||||
isPublicAPI: overrides.isPublicAPI ?? false,
|
||||
isTimedOut: overrides.isTimedOut ?? (() => false),
|
||||
markEnded,
|
||||
notify,
|
||||
@@ -392,7 +391,7 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
}
|
||||
|
||||
it('should process workflow success and paused events', () => {
|
||||
const setup = setupHandlers({ isPublicAPI: true })
|
||||
const setup = setupHandlers()
|
||||
const handlers = setup.handlers as Required<Pick<IOtherOptions, 'onWorkflowStarted' | 'onTextChunk' | 'onHumanInputRequired' | 'onHumanInputFormFilled' | 'onHumanInputFormTimeout' | 'onWorkflowPaused' | 'onWorkflowFinished' | 'onNodeStarted' | 'onNodeFinished' | 'onIterationStart' | 'onIterationNext' | 'onIterationFinish' | 'onLoopStart' | 'onLoopNext' | 'onLoopFinish'>>
|
||||
|
||||
act(() => {
|
||||
@@ -547,11 +546,7 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
resultText: 'Hello',
|
||||
status: WorkflowRunningStatus.Succeeded,
|
||||
}))
|
||||
expect(sseGetMock).toHaveBeenCalledWith(
|
||||
'/workflow/run-1/events',
|
||||
{},
|
||||
expect.objectContaining({ isPublicAPI: true }),
|
||||
)
|
||||
expect(sseGetMock).toHaveBeenCalledWith('/workflow/run-1/events', {}, expect.any(Object))
|
||||
expect(setup.messageId()).toBe('run-1')
|
||||
expect(setup.onCompleted).toHaveBeenCalledWith('{"answer":"Hello"}', 3, true)
|
||||
expect(setup.setRespondingFalse).toHaveBeenCalled()
|
||||
@@ -652,7 +647,6 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
const handlers = createWorkflowStreamHandlers({
|
||||
getCompletionRes: () => '',
|
||||
getWorkflowProcessData: () => existingProcess,
|
||||
isPublicAPI: false,
|
||||
isTimedOut: () => false,
|
||||
markEnded: vi.fn(),
|
||||
notify: setup.notify,
|
||||
|
||||
@@ -351,7 +351,6 @@ describe('useResultSender', () => {
|
||||
await waitFor(() => {
|
||||
expect(createWorkflowStreamHandlersMock).toHaveBeenCalledWith(expect.objectContaining({
|
||||
getCompletionRes: harness.runState.getCompletionRes,
|
||||
isPublicAPI: true,
|
||||
resetRunState: harness.runState.resetRunState,
|
||||
setWorkflowProcessData: harness.runState.setWorkflowProcessData,
|
||||
}))
|
||||
@@ -374,30 +373,6 @@ describe('useResultSender', () => {
|
||||
expect(harness.runState.clearMoreLikeThis).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should configure workflow handlers for installed apps as non-public', async () => {
|
||||
const harness = createRunStateHarness()
|
||||
|
||||
const { result } = renderSender({
|
||||
appSourceType: AppSourceTypeEnum.installedApp,
|
||||
isWorkflow: true,
|
||||
runState: harness.runState,
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
expect(await result.current.handleSend()).toBe(true)
|
||||
})
|
||||
|
||||
expect(createWorkflowStreamHandlersMock).toHaveBeenCalledWith(expect.objectContaining({
|
||||
isPublicAPI: false,
|
||||
}))
|
||||
expect(sendWorkflowMessageMock).toHaveBeenCalledWith(
|
||||
{ inputs: { name: 'Alice' } },
|
||||
expect.any(Object),
|
||||
AppSourceTypeEnum.installedApp,
|
||||
'app-1',
|
||||
)
|
||||
})
|
||||
|
||||
it('should stringify non-Error workflow failures', async () => {
|
||||
const harness = createRunStateHarness()
|
||||
sendWorkflowMessageMock.mockRejectedValue('workflow failed')
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import type { ResultInputValue } from '../result-request'
|
||||
import type { ResultRunStateController } from './use-result-run-state'
|
||||
import type { PromptConfig } from '@/models/debug'
|
||||
import type { AppSourceType } from '@/service/share'
|
||||
import type { VisionFile, VisionSettings } from '@/types/app'
|
||||
import { useCallback, useEffect, useRef } from 'react'
|
||||
import { TEXT_GENERATION_TIMEOUT_MS } from '@/config'
|
||||
import {
|
||||
AppSourceType,
|
||||
sendCompletionMessage,
|
||||
sendWorkflowMessage,
|
||||
} from '@/service/share'
|
||||
@@ -117,7 +117,6 @@ export const useResultSender = ({
|
||||
const otherOptions = createWorkflowStreamHandlers({
|
||||
getCompletionRes: runState.getCompletionRes,
|
||||
getWorkflowProcessData: runState.getWorkflowProcessData,
|
||||
isPublicAPI: appSourceType === AppSourceType.webApp,
|
||||
isTimedOut: () => isTimeout,
|
||||
markEnded: () => {
|
||||
isEnd = true
|
||||
|
||||
@@ -13,7 +13,6 @@ type Translate = (key: string, options?: Record<string, unknown>) => string
|
||||
type CreateWorkflowStreamHandlersParams = {
|
||||
getCompletionRes: () => string
|
||||
getWorkflowProcessData: () => WorkflowProcess | undefined
|
||||
isPublicAPI: boolean
|
||||
isTimedOut: () => boolean
|
||||
markEnded: () => void
|
||||
notify: Notify
|
||||
@@ -256,7 +255,6 @@ const serializeWorkflowOutputs = (outputs: WorkflowFinishedResponse['data']['out
|
||||
export const createWorkflowStreamHandlers = ({
|
||||
getCompletionRes,
|
||||
getWorkflowProcessData,
|
||||
isPublicAPI,
|
||||
isTimedOut,
|
||||
markEnded,
|
||||
notify,
|
||||
@@ -289,7 +287,6 @@ export const createWorkflowStreamHandlers = ({
|
||||
}
|
||||
|
||||
const otherOptions: IOtherOptions = {
|
||||
isPublicAPI,
|
||||
onWorkflowStarted: ({ workflow_run_id, task_id }) => {
|
||||
const workflowProcessData = getWorkflowProcessData()
|
||||
if (workflowProcessData?.tracing.length) {
|
||||
@@ -381,7 +378,6 @@ export const createWorkflowStreamHandlers = ({
|
||||
},
|
||||
onWorkflowPaused: ({ data }) => {
|
||||
tempMessageId = data.workflow_run_id
|
||||
// WebApp workflows must keep using the public API namespace after pause/resume.
|
||||
void sseGet(`/workflow/${data.workflow_run_id}/events`, {}, otherOptions)
|
||||
setWorkflowProcessData(applyWorkflowPaused(getWorkflowProcessData()))
|
||||
},
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"name": "dify-web",
|
||||
"type": "module",
|
||||
"version": "1.13.2",
|
||||
"version": "1.13.0",
|
||||
"private": true,
|
||||
"packageManager": "pnpm@10.32.1",
|
||||
"imports": {
|
||||
@@ -210,6 +210,7 @@
|
||||
"@types/sortablejs": "1.15.9",
|
||||
"@typescript-eslint/parser": "8.57.0",
|
||||
"@typescript/native-preview": "7.0.0-dev.20260312.1",
|
||||
"@vitejs/plugin-basic-ssl": "2.2.0",
|
||||
"@vitejs/plugin-react": "6.0.0",
|
||||
"@vitejs/plugin-rsc": "0.5.21",
|
||||
"@vitest/coverage-v8": "4.1.0",
|
||||
|
||||
@@ -34,7 +34,16 @@ const toUpstreamCookieName = (cookieName: string) => {
|
||||
return `__Host-${cookieName}`
|
||||
}
|
||||
|
||||
const toLocalCookieName = (cookieName: string) => cookieName.replace(SECURE_COOKIE_PREFIX_PATTERN, '')
|
||||
const toLocalCookieName = (cookieName: string, options: LocalCookieRewriteOptions) => {
|
||||
if (options.localSecure)
|
||||
return cookieName
|
||||
|
||||
return cookieName.replace(SECURE_COOKIE_PREFIX_PATTERN, '')
|
||||
}
|
||||
|
||||
type LocalCookieRewriteOptions = {
|
||||
localSecure: boolean
|
||||
}
|
||||
|
||||
export const rewriteCookieHeaderForUpstream = (cookieHeader?: string) => {
|
||||
if (!cookieHeader)
|
||||
@@ -55,7 +64,10 @@ export const rewriteCookieHeaderForUpstream = (cookieHeader?: string) => {
|
||||
.join('; ')
|
||||
}
|
||||
|
||||
const rewriteSetCookieValueForLocal = (setCookieValue: string) => {
|
||||
const rewriteSetCookieValueForLocal = (
|
||||
setCookieValue: string,
|
||||
options: LocalCookieRewriteOptions,
|
||||
) => {
|
||||
const [rawCookiePair, ...rawAttributes] = setCookieValue.split(';')
|
||||
const separatorIndex = rawCookiePair.indexOf('=')
|
||||
|
||||
@@ -68,11 +80,11 @@ const rewriteSetCookieValueForLocal = (setCookieValue: string) => {
|
||||
.map(attribute => attribute.trim())
|
||||
.filter(attribute =>
|
||||
!COOKIE_DOMAIN_PATTERN.test(attribute)
|
||||
&& !COOKIE_SECURE_PATTERN.test(attribute)
|
||||
&& !COOKIE_PARTITIONED_PATTERN.test(attribute),
|
||||
&& (options.localSecure || !COOKIE_SECURE_PATTERN.test(attribute))
|
||||
&& (options.localSecure || !COOKIE_PARTITIONED_PATTERN.test(attribute)),
|
||||
)
|
||||
.map((attribute) => {
|
||||
if (SAME_SITE_NONE_PATTERN.test(attribute))
|
||||
if (!options.localSecure && SAME_SITE_NONE_PATTERN.test(attribute))
|
||||
return 'SameSite=Lax'
|
||||
|
||||
if (COOKIE_PATH_PATTERN.test(attribute))
|
||||
@@ -81,10 +93,13 @@ const rewriteSetCookieValueForLocal = (setCookieValue: string) => {
|
||||
return attribute
|
||||
})
|
||||
|
||||
return [`${toLocalCookieName(cookieName)}=${cookieValue}`, ...rewrittenAttributes].join('; ')
|
||||
return [`${toLocalCookieName(cookieName, options)}=${cookieValue}`, ...rewrittenAttributes].join('; ')
|
||||
}
|
||||
|
||||
export const rewriteSetCookieHeadersForLocal = (setCookieHeaders?: string | string[]): string[] | undefined => {
|
||||
export const rewriteSetCookieHeadersForLocal = (
|
||||
setCookieHeaders: string | string[] | undefined,
|
||||
options: LocalCookieRewriteOptions,
|
||||
): string[] | undefined => {
|
||||
if (!setCookieHeaders)
|
||||
return undefined
|
||||
|
||||
@@ -92,7 +107,7 @@ export const rewriteSetCookieHeadersForLocal = (setCookieHeaders?: string | stri
|
||||
? setCookieHeaders
|
||||
: [setCookieHeaders]
|
||||
|
||||
return normalizedHeaders.map(rewriteSetCookieValueForLocal)
|
||||
return normalizedHeaders.map(setCookieValue => rewriteSetCookieValueForLocal(setCookieValue, options))
|
||||
}
|
||||
|
||||
export { DEFAULT_PROXY_TARGET }
|
||||
|
||||
21
web/plugins/dev-proxy/protocol.ts
Normal file
21
web/plugins/dev-proxy/protocol.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
export type DevProxyProtocolEnv = Partial<Record<
|
||||
| 'NEXT_PUBLIC_API_PREFIX'
|
||||
| 'NEXT_PUBLIC_PUBLIC_API_PREFIX',
|
||||
string
|
||||
>>
|
||||
|
||||
const isHttpsUrl = (value?: string) => {
|
||||
if (!value)
|
||||
return false
|
||||
|
||||
try {
|
||||
return new URL(value).protocol === 'https:'
|
||||
}
|
||||
catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
export const shouldUseHttpsForDevProxy = (env: DevProxyProtocolEnv = {}) => {
|
||||
return isHttpsUrl(env.NEXT_PUBLIC_API_PREFIX) || isHttpsUrl(env.NEXT_PUBLIC_PUBLIC_API_PREFIX)
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { buildUpstreamUrl, createDevProxyApp, isAllowedDevOrigin, resolveDevProxyTargets } from './server'
|
||||
import { buildUpstreamUrl, createDevProxyApp, isAllowedDevOrigin, resolveDevProxyTargets, shouldUseHttpsForDevProxy } from './server'
|
||||
|
||||
describe('dev proxy server', () => {
|
||||
beforeEach(() => {
|
||||
@@ -19,6 +19,21 @@ describe('dev proxy server', () => {
|
||||
expect(targets.publicApiTarget).toBe('https://public.example.com')
|
||||
})
|
||||
|
||||
// Scenario: the local dev proxy should switch to https when api prefixes are configured with https.
|
||||
it('should enable https for the local dev proxy when api prefixes use https', () => {
|
||||
// Assert
|
||||
expect(shouldUseHttpsForDevProxy({
|
||||
NEXT_PUBLIC_API_PREFIX: 'https://localhost:5001/console/api',
|
||||
})).toBe(true)
|
||||
expect(shouldUseHttpsForDevProxy({
|
||||
NEXT_PUBLIC_PUBLIC_API_PREFIX: 'https://localhost:5001/api',
|
||||
})).toBe(true)
|
||||
expect(shouldUseHttpsForDevProxy({
|
||||
NEXT_PUBLIC_API_PREFIX: 'http://localhost:5001/console/api',
|
||||
NEXT_PUBLIC_PUBLIC_API_PREFIX: 'http://localhost:5001/api',
|
||||
})).toBe(false)
|
||||
})
|
||||
|
||||
// Scenario: target paths should not be duplicated when the incoming route already includes them.
|
||||
it('should preserve prefixed targets when building upstream URLs', () => {
|
||||
// Act
|
||||
@@ -32,6 +47,7 @@ describe('dev proxy server', () => {
|
||||
it('should only allow local development origins', () => {
|
||||
// Assert
|
||||
expect(isAllowedDevOrigin('http://localhost:3000')).toBe(true)
|
||||
expect(isAllowedDevOrigin('https://localhost:3000')).toBe(true)
|
||||
expect(isAllowedDevOrigin('http://127.0.0.1:3000')).toBe(true)
|
||||
expect(isAllowedDevOrigin('https://example.com')).toBe(false)
|
||||
})
|
||||
@@ -86,6 +102,39 @@ describe('dev proxy server', () => {
|
||||
])
|
||||
})
|
||||
|
||||
// Scenario: secure local proxy responses should keep secure cross-site cookie attributes intact.
|
||||
it('should preserve secure cookie attributes when the local proxy is https', async () => {
|
||||
// Arrange
|
||||
const fetchImpl = vi.fn<typeof fetch>().mockResolvedValue(new Response('ok', {
|
||||
status: 200,
|
||||
headers: [
|
||||
['set-cookie', '__Host-access_token=abc; Path=/console/api; Domain=cloud.dify.ai; Secure; SameSite=None; Partitioned'],
|
||||
['set-cookie', '__Host-csrf_token=csrf; Path=/console/api; Domain=cloud.dify.ai; Secure; SameSite=None'],
|
||||
],
|
||||
}))
|
||||
const app = createDevProxyApp({
|
||||
consoleApiTarget: 'https://cloud.dify.ai',
|
||||
publicApiTarget: 'https://public.dify.ai',
|
||||
fetchImpl,
|
||||
})
|
||||
|
||||
// Act
|
||||
const response = await app.request('https://127.0.0.1:5001/console/api/apps?page=1', {
|
||||
headers: {
|
||||
Origin: 'https://localhost:3000',
|
||||
Cookie: 'access_token=abc',
|
||||
},
|
||||
})
|
||||
|
||||
// Assert
|
||||
expect(response.headers.getSetCookie()).toEqual([
|
||||
'__Host-access_token=abc; Path=/; Secure; SameSite=None; Partitioned',
|
||||
'__Host-csrf_token=csrf; Path=/; Secure; SameSite=None',
|
||||
])
|
||||
expect(response.headers.get('access-control-allow-origin')).toBe('https://localhost:3000')
|
||||
expect(response.headers.get('access-control-allow-credentials')).toBe('true')
|
||||
})
|
||||
|
||||
// Scenario: preflight requests should advertise allowed headers for credentialed cross-origin calls.
|
||||
it('should answer CORS preflight requests', async () => {
|
||||
// Arrange
|
||||
|
||||
@@ -2,10 +2,16 @@ import type { Context, Hono } from 'hono'
|
||||
import { Hono as HonoApp } from 'hono'
|
||||
import { DEFAULT_PROXY_TARGET, rewriteCookieHeaderForUpstream, rewriteSetCookieHeadersForLocal } from './cookies'
|
||||
|
||||
export { shouldUseHttpsForDevProxy } from './protocol'
|
||||
|
||||
type DevProxyEnv = Partial<Record<
|
||||
| 'HONO_CONSOLE_API_PROXY_TARGET'
|
||||
| 'HONO_PUBLIC_API_PROXY_TARGET',
|
||||
string
|
||||
> & Record<
|
||||
| 'NEXT_PUBLIC_API_PREFIX'
|
||||
| 'NEXT_PUBLIC_PUBLIC_API_PREFIX',
|
||||
string | undefined
|
||||
>>
|
||||
|
||||
export type DevProxyTargets = {
|
||||
@@ -93,11 +99,15 @@ const createProxyRequestHeaders = (request: Request, targetUrl: URL) => {
|
||||
return headers
|
||||
}
|
||||
|
||||
const createUpstreamResponseHeaders = (response: Response, requestOrigin?: string | null) => {
|
||||
const createUpstreamResponseHeaders = (
|
||||
response: Response,
|
||||
requestOrigin: string | null | undefined,
|
||||
localSecure: boolean,
|
||||
) => {
|
||||
const headers = new Headers(response.headers)
|
||||
RESPONSE_HEADERS_TO_DROP.forEach(header => headers.delete(header))
|
||||
|
||||
const rewrittenSetCookies = rewriteSetCookieHeadersForLocal(response.headers.getSetCookie())
|
||||
const rewrittenSetCookies = rewriteSetCookieHeadersForLocal(response.headers.getSetCookie(), { localSecure })
|
||||
rewrittenSetCookies?.forEach((cookie) => {
|
||||
headers.append('set-cookie', cookie)
|
||||
})
|
||||
@@ -126,7 +136,11 @@ const proxyRequest = async (
|
||||
}
|
||||
|
||||
const upstreamResponse = await fetchImpl(targetUrl, requestInit)
|
||||
const responseHeaders = createUpstreamResponseHeaders(upstreamResponse, context.req.header('origin'))
|
||||
const responseHeaders = createUpstreamResponseHeaders(
|
||||
upstreamResponse,
|
||||
context.req.header('origin'),
|
||||
requestUrl.protocol === 'https:',
|
||||
)
|
||||
|
||||
return new Response(upstreamResponse.body, {
|
||||
status: upstreamResponse.status,
|
||||
|
||||
13
web/pnpm-lock.yaml
generated
13
web/pnpm-lock.yaml
generated
@@ -512,6 +512,9 @@ importers:
|
||||
'@typescript/native-preview':
|
||||
specifier: 7.0.0-dev.20260312.1
|
||||
version: 7.0.0-dev.20260312.1
|
||||
'@vitejs/plugin-basic-ssl':
|
||||
specifier: 2.2.0
|
||||
version: 2.2.0(@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))
|
||||
'@vitejs/plugin-react':
|
||||
specifier: 6.0.0
|
||||
version: 6.0.0(@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))
|
||||
@@ -3603,6 +3606,12 @@ packages:
|
||||
resolution: {integrity: sha512-hBcWIOppZV14bi+eAmCZj8Elj8hVSUZJTpf1lgGBhVD85pervzQ1poM/qYfFUlPraYSZYP+ASg6To5BwYmUSGQ==}
|
||||
engines: {node: '>=16'}
|
||||
|
||||
'@vitejs/plugin-basic-ssl@2.2.0':
|
||||
resolution: {integrity: sha512-nmyQ1HGRkfUxjsv3jw0+hMhEdZdrtkvMTdkzRUaRWfiO6PCWw2V2Pz3gldCq96Tn9S8htcgdTxw/gmbLLEbfYw==}
|
||||
engines: {node: ^18.0.0 || ^20.0.0 || >=22.0.0}
|
||||
peerDependencies:
|
||||
vite: ^6.0.0 || ^7.0.0 || ^8.0.0
|
||||
|
||||
'@vitejs/plugin-react@5.2.0':
|
||||
resolution: {integrity: sha512-YmKkfhOAi3wsB1PhJq5Scj3GXMn3WvtQ/JC0xoopuHoXSdmtdStOpFrYaT1kie2YgFBcIe64ROzMYRjCrYOdYw==}
|
||||
engines: {node: ^20.19.0 || >=22.12.0}
|
||||
@@ -11030,6 +11039,10 @@ snapshots:
|
||||
'@resvg/resvg-wasm': 2.4.0
|
||||
satori: 0.16.0
|
||||
|
||||
'@vitejs/plugin-basic-ssl@2.2.0(@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))':
|
||||
dependencies:
|
||||
vite: '@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)'
|
||||
|
||||
'@vitejs/plugin-react@5.2.0(@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))':
|
||||
dependencies:
|
||||
'@babel/core': 7.29.0
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import { createSecureServer } from 'node:http2'
|
||||
import path from 'node:path'
|
||||
import { fileURLToPath } from 'node:url'
|
||||
import { serve } from '@hono/node-server'
|
||||
import { getCertificate } from '@vitejs/plugin-basic-ssl'
|
||||
import { loadEnv } from 'vite'
|
||||
import { createDevProxyApp, resolveDevProxyTargets } from '../plugins/dev-proxy/server'
|
||||
import { createDevProxyApp, resolveDevProxyTargets, shouldUseHttpsForDevProxy } from '../plugins/dev-proxy/server'
|
||||
|
||||
const projectRoot = path.resolve(path.dirname(fileURLToPath(import.meta.url)), '..')
|
||||
const mode = process.env.MODE || process.env.NODE_ENV || 'development'
|
||||
@@ -11,11 +13,33 @@ const env = loadEnv(mode, projectRoot, '')
|
||||
const host = env.HONO_PROXY_HOST || '127.0.0.1'
|
||||
const port = Number(env.HONO_PROXY_PORT || 5001)
|
||||
const app = createDevProxyApp(resolveDevProxyTargets(env))
|
||||
const useHttps = shouldUseHttpsForDevProxy(env)
|
||||
|
||||
serve({
|
||||
fetch: app.fetch,
|
||||
hostname: host,
|
||||
port,
|
||||
})
|
||||
if (useHttps) {
|
||||
const certificate = await getCertificate(
|
||||
path.join(projectRoot, 'node_modules/.vite/basic-ssl'),
|
||||
'localhost',
|
||||
Array.from(new Set(['localhost', '127.0.0.1', host])),
|
||||
)
|
||||
|
||||
console.log(`[dev-hono-proxy] listening on http://${host}:${port}`)
|
||||
serve({
|
||||
fetch: app.fetch,
|
||||
hostname: host,
|
||||
port,
|
||||
createServer: createSecureServer,
|
||||
serverOptions: {
|
||||
allowHTTP1: true,
|
||||
cert: certificate,
|
||||
key: certificate,
|
||||
},
|
||||
})
|
||||
}
|
||||
else {
|
||||
serve({
|
||||
fetch: app.fetch,
|
||||
hostname: host,
|
||||
port,
|
||||
})
|
||||
}
|
||||
|
||||
console.log(`[dev-hono-proxy] listening on ${useHttps ? 'https' : 'http'}://${host}:${port}`)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import path from 'node:path'
|
||||
import { fileURLToPath } from 'node:url'
|
||||
import basicSsl from '@vitejs/plugin-basic-ssl'
|
||||
import react from '@vitejs/plugin-react'
|
||||
import vinext from 'vinext'
|
||||
import { loadEnv } from 'vite'
|
||||
import Inspect from 'vite-plugin-inspect'
|
||||
import { defineConfig } from 'vite-plus'
|
||||
import { shouldUseHttpsForDevProxy } from './plugins/dev-proxy/protocol'
|
||||
import { createCodeInspectorPlugin, createForceInspectorClientInjectionPlugin } from './plugins/vite/code-inspector'
|
||||
import { customI18nHmrPlugin } from './plugins/vite/custom-i18n-hmr'
|
||||
import { nextStaticImageTestPlugin } from './plugins/vite/next-static-image-test'
|
||||
@@ -21,6 +24,8 @@ export default defineConfig(({ mode }) => {
|
||||
const isTest = mode === 'test'
|
||||
const isStorybook = process.env.STORYBOOK === 'true'
|
||||
|| process.argv.some(arg => arg.toLowerCase().includes('storybook'))
|
||||
const env = loadEnv(mode, projectRoot, '')
|
||||
const useHttpsForDevServer = shouldUseHttpsForDevProxy(env)
|
||||
const isAppComponentsCoverage = coverageScope === 'app-components'
|
||||
const excludedComponentCoverageFiles = isAppComponentsCoverage
|
||||
? collectComponentCoverageExcludedFiles(path.join(projectRoot, 'app/components'), { pathPrefix: 'app/components' })
|
||||
@@ -57,6 +62,7 @@ export default defineConfig(({ mode }) => {
|
||||
react(),
|
||||
vinext({ react: false }),
|
||||
customI18nHmrPlugin({ injectTarget: browserInitializerInjectTarget }),
|
||||
...(useHttpsForDevServer ? [basicSsl()] : []),
|
||||
// reactGrabOpenFilePlugin({
|
||||
// injectTarget: browserInitializerInjectTarget,
|
||||
// projectRoot,
|
||||
|
||||
Reference in New Issue
Block a user