Compare commits

...

5 Commits

Author SHA1 Message Date
Yanli 盐粒
42e8af65d8 fix module_loading cached import fallback 2026-03-15 02:27:19 +08:00
Yanli 盐粒
7b3a6b953c document explicit class attribute typing 2026-03-15 02:07:59 +08:00
Yanli 盐粒
31b60a4966 finish phase 1 oauth contract cleanup 2026-03-15 02:05:23 +08:00
Yanli 盐粒
00eb30ddb8 shrink more phase 1 pyrefly excludes 2026-03-13 18:44:12 +08:00
Yanli 盐粒
57a8f00b1e fix phase 1 shared type contracts 2026-03-13 18:20:16 +08:00
14 changed files with 313 additions and 196 deletions

View File

@@ -78,7 +78,7 @@ class UserProfile(TypedDict):
nickname: NotRequired[str]
```
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance:
```python
from datetime import datetime

View File

@@ -1,4 +1,4 @@
from typing import Literal, Protocol
from typing import Literal, Protocol, cast
from urllib.parse import quote_plus, urlunparse
from pydantic import AliasChoices, Field
@@ -12,16 +12,13 @@ class RedisConfigDefaults(Protocol):
REDIS_PASSWORD: str | None
REDIS_DB: int
REDIS_USE_SSL: bool
REDIS_USE_SENTINEL: bool | None
REDIS_USE_CLUSTERS: bool
class RedisConfigDefaultsMixin:
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
return self
def _redis_defaults(config: object) -> RedisConfigDefaults:
return cast(RedisConfigDefaults, config)
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
class RedisPubSubConfig(BaseSettings):
"""
Configuration settings for event transport between API and workers.
@@ -74,7 +71,7 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
)
def _build_default_pubsub_url(self) -> str:
defaults = self._redis_defaults()
defaults = _redis_defaults(self)
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
@@ -91,11 +88,9 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
if userinfo:
userinfo = f"{userinfo}@"
host = defaults.REDIS_HOST
port = defaults.REDIS_PORT
db = defaults.REDIS_DB
netloc = f"{userinfo}{host}:{port}"
netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}"
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
@property

View File

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
from dify_graph.file.models import File
if TYPE_CHECKING:
pass
from dify_graph.variables.segments import Segment
class ArrayValidation(StrEnum):
@@ -219,7 +219,7 @@ class SegmentType(StrEnum):
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
@staticmethod
def get_zero_value(t: SegmentType):
def get_zero_value(t: SegmentType) -> Segment:
# Lazy import to avoid circular dependency
from factories import variable_factory

View File

@@ -1,3 +1,5 @@
from typing import Protocol, cast
from fastopenapi.routers import FlaskRouter
from flask_cors import CORS
@@ -9,6 +11,10 @@ from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
DOCS_PREFIX = "/fastopenapi"
class SupportsIncludeRouter(Protocol):
def include_router(self, router: object, *, prefix: str = "") -> None: ...
def init_app(app: DifyApp) -> None:
docs_enabled = dify_config.SWAGGER_UI_ENABLED
docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
@@ -36,7 +42,7 @@ def init_app(app: DifyApp) -> None:
_ = remote_files
_ = setup
router.include_router(console_router, prefix="/console/api")
cast(SupportsIncludeRouter, router).include_router(console_router, prefix="/console/api")
CORS(
app,
resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},

View File

@@ -55,7 +55,7 @@ class TypeMismatchError(Exception):
# Define the constant
SEGMENT_TO_VARIABLE_MAP = {
SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = {
ArrayAnySegment: ArrayAnyVariable,
ArrayBooleanSegment: ArrayBooleanVariable,
ArrayFileSegment: ArrayFileVariable,
@@ -296,13 +296,11 @@ def segment_to_variable(
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return cast(
VariableBase,
variable_class(
id=id,
name=name,
description=description,
value=segment.value,
selector=list(selector),
),
return variable_class(
id=id,
name=name,
description=description,
value_type=segment.value_type,
value=segment.value,
selector=list(selector),
)

View File

@@ -32,6 +32,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _stream_with_request_context(response: object) -> Any:
"""Bridge Flask's loosely-typed streaming helper without leaking casts into callers."""
return cast(Any, stream_with_context)(response)
def escape_like_pattern(pattern: str) -> str:
"""
Escape special characters in a string for safe use in SQL LIKE patterns.
@@ -286,22 +291,32 @@ def generate_text_hash(text: str) -> str:
return sha256(hash_text.encode()).hexdigest()
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
def compact_generate_response(
response: Mapping[str, Any] | Generator[str, None, None] | RateLimitGenerator,
) -> Response:
if isinstance(response, Mapping):
return Response(
response=json.dumps(jsonable_encoder(response)),
status=200,
content_type="application/json; charset=utf-8",
)
else:
stream_response = response
def generate() -> Generator:
yield from response
def generate() -> Generator[str, None, None]:
yield from stream_response
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
return Response(
_stream_with_request_context(generate()),
status=200,
mimetype="text/event-stream",
)
def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
def length_prefixed_response(
magic_number: int,
response: Mapping[str, Any] | BaseModel | Generator[str | bytes, None, None] | RateLimitGenerator,
) -> Response:
"""
This function is used to return a response with a length prefix.
Magic number is a one byte number that indicates the type of the response.
@@ -332,7 +347,7 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
# | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response
if isinstance(response, dict):
if isinstance(response, Mapping):
return Response(
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
status=200,
@@ -345,14 +360,20 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
mimetype="application/json",
)
def generate() -> Generator:
for chunk in response:
stream_response = response
def generate() -> Generator[bytes, None, None]:
for chunk in stream_response:
if isinstance(chunk, str):
yield pack_response_with_length_prefix(chunk.encode("utf-8"))
else:
yield pack_response_with_length_prefix(chunk)
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
return Response(
_stream_with_request_context(generate()),
status=200,
mimetype="text/event-stream",
)
class TokenManager:

View File

@@ -77,12 +77,14 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
@wraps(func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
pass
elif current_user is not None and not current_user.is_authenticated:
return current_app.ensure_sync(func)(*args, **kwargs)
user = _get_user()
if user is None or not user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore
# we put csrf validation here for less conflicts
# TODO: maybe find a better place for it.
check_csrf_token(request, current_user.id)
check_csrf_token(request, user.id)
return current_app.ensure_sync(func)(*args, **kwargs)
return decorated_view

View File

@@ -7,9 +7,10 @@ https://github.com/django/django/blob/main/django/utils/module_loading.py
import sys
from importlib import import_module
from typing import Any
def cached_import(module_path: str, class_name: str):
def cached_import(module_path: str, class_name: str) -> Any:
"""
Import a module and return the named attribute/class from it, with caching.
@@ -20,16 +21,14 @@ def cached_import(module_path: str, class_name: str):
Returns:
The imported attribute/class
"""
if not (
(module := sys.modules.get(module_path))
and (spec := getattr(module, "__spec__", None))
and getattr(spec, "_initializing", False) is False
):
module = sys.modules.get(module_path)
spec = getattr(module, "__spec__", None) if module is not None else None
if module is None or spec is None or getattr(spec, "_initializing", False):
module = import_module(module_path)
return getattr(module, class_name)
def import_string(dotted_path: str):
def import_string(dotted_path: str) -> Any:
"""
Import a dotted module path and return the attribute/class designated by
the last name in the path. Raise ImportError if the import failed.

View File

@@ -1,7 +1,48 @@
import sys
import urllib.parse
from dataclasses import dataclass
from typing import NotRequired
import httpx
from pydantic import TypeAdapter
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
JsonObject = dict[str, object]
JsonObjectList = list[JsonObject]
JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
class AccessTokenResponse(TypedDict, total=False):
access_token: str
class GitHubEmailRecord(TypedDict, total=False):
email: str
primary: bool
class GitHubRawUserInfo(TypedDict):
id: int | str
login: str
name: NotRequired[str]
email: NotRequired[str]
class GoogleRawUserInfo(TypedDict):
sub: str
email: str
ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse)
GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo)
GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord])
GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo)
@dataclass
@@ -11,26 +52,38 @@ class OAuthUserInfo:
email: str
def _json_object(response: httpx.Response) -> JsonObject:
return JSON_OBJECT_ADAPTER.validate_python(response.json())
def _json_list(response: httpx.Response) -> JsonObjectList:
return JSON_OBJECT_LIST_ADAPTER.validate_python(response.json())
class OAuth:
client_id: str
client_secret: str
redirect_uri: str
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_authorization_url(self):
def get_authorization_url(self, invite_token: str | None = None) -> str:
raise NotImplementedError()
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> str:
raise NotImplementedError()
def get_raw_user_info(self, token: str):
def get_raw_user_info(self, token: str) -> JsonObject:
raise NotImplementedError()
def get_user_info(self, token: str) -> OAuthUserInfo:
raw_info = self.get_raw_user_info(token)
return self._transform_user_info(raw_info)
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
raise NotImplementedError()
@@ -40,7 +93,7 @@ class GitHubOAuth(OAuth):
_USER_INFO_URL = "https://api.github.com/user"
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
def get_authorization_url(self, invite_token: str | None = None):
def get_authorization_url(self, invite_token: str | None = None) -> str:
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
@@ -50,7 +103,7 @@ class GitHubOAuth(OAuth):
params["state"] = invite_token
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> str:
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
@@ -60,7 +113,7 @@ class GitHubOAuth(OAuth):
headers = {"Accept": "application/json"}
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
access_token = response_json.get("access_token")
if not access_token:
@@ -68,23 +121,24 @@ class GitHubOAuth(OAuth):
return access_token
def get_raw_user_info(self, token: str):
def get_raw_user_info(self, token: str) -> JsonObject:
headers = {"Authorization": f"token {token}"}
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
user_info = response.json()
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_info = email_response.json()
primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
primary_email = next((email for email in email_info if email.get("primary") is True), None)
return {**user_info, "email": primary_email.get("email", "")}
return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
email = raw_info.get("email")
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
email = payload.get("email")
if not email:
email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email)
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
class GoogleOAuth(OAuth):
@@ -92,7 +146,7 @@ class GoogleOAuth(OAuth):
_TOKEN_URL = "https://oauth2.googleapis.com/token"
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
def get_authorization_url(self, invite_token: str | None = None):
def get_authorization_url(self, invite_token: str | None = None) -> str:
params = {
"client_id": self.client_id,
"response_type": "code",
@@ -103,7 +157,7 @@ class GoogleOAuth(OAuth):
params["state"] = invite_token
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> str:
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
@@ -114,7 +168,7 @@ class GoogleOAuth(OAuth):
headers = {"Accept": "application/json"}
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
access_token = response_json.get("access_token")
if not access_token:
@@ -122,11 +176,12 @@ class GoogleOAuth(OAuth):
return access_token
def get_raw_user_info(self, token: str):
def get_raw_user_info(self, token: str) -> JsonObject:
headers = {"Authorization": f"Bearer {token}"}
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
return response.json()
return _json_object(response)
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
payload = GOOGLE_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
return OAuthUserInfo(id=str(payload["sub"]), name="", email=payload["email"])

View File

@@ -1,25 +1,57 @@
import sys
import urllib.parse
from typing import Any
from typing import Any, Literal
import httpx
from flask_login import current_user
from pydantic import TypeAdapter
from sqlalchemy import select
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
class NotionPageSummary(TypedDict):
page_id: str
page_name: str
page_icon: dict[str, str] | None
parent_id: str
type: Literal["page", "database"]
class NotionSourceInfo(TypedDict):
workspace_name: str | None
workspace_icon: str | None
workspace_id: str | None
pages: list[NotionPageSummary]
total: int
SOURCE_INFO_STORAGE_ADAPTER = TypeAdapter(dict[str, object])
NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo)
NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary)
class OAuthDataSource:
client_id: str
client_secret: str
redirect_uri: str
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_authorization_url(self):
def get_authorization_url(self) -> str:
raise NotImplementedError()
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> None:
raise NotImplementedError()
@@ -30,7 +62,7 @@ class NotionOAuth(OAuthDataSource):
_NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
_NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
def get_authorization_url(self):
def get_authorization_url(self) -> str:
params = {
"client_id": self.client_id,
"response_type": "code",
@@ -39,7 +71,7 @@ class NotionOAuth(OAuthDataSource):
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> None:
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
headers = {"Accept": "application/json"}
auth = (self.client_id, self.client_secret)
@@ -54,13 +86,12 @@ class NotionOAuth(OAuthDataSource):
workspace_id = response_json.get("workspace_id")
# get all authorized pages
pages = self.get_authorized_pages(access_token)
source_info = {
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
"workspace_id": workspace_id,
"pages": pages,
"total": len(pages),
}
source_info = self._build_source_info(
workspace_name=workspace_name,
workspace_icon=workspace_icon,
workspace_id=workspace_id,
pages=pages,
)
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
@@ -70,7 +101,7 @@ class NotionOAuth(OAuthDataSource):
)
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
@@ -78,25 +109,24 @@ class NotionOAuth(OAuthDataSource):
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
def save_internal_access_token(self, access_token: str):
def save_internal_access_token(self, access_token: str) -> None:
workspace_name = self.notion_workspace_name(access_token)
workspace_icon = None
workspace_id = current_user.current_tenant_id
# get all authorized pages
pages = self.get_authorized_pages(access_token)
source_info = {
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
"workspace_id": workspace_id,
"pages": pages,
"total": len(pages),
}
source_info = self._build_source_info(
workspace_name=workspace_name,
workspace_icon=workspace_icon,
workspace_id=workspace_id,
pages=pages,
)
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
@@ -106,7 +136,7 @@ class NotionOAuth(OAuthDataSource):
)
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
@@ -114,13 +144,13 @@ class NotionOAuth(OAuthDataSource):
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
def sync_data_source(self, binding_id: str):
def sync_data_source(self, binding_id: str) -> None:
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
@@ -134,23 +164,22 @@ class NotionOAuth(OAuthDataSource):
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)
source_info = data_source_binding.source_info
new_source_info = {
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
"total": len(pages),
}
data_source_binding.source_info = new_source_info
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
new_source_info = self._build_source_info(
workspace_name=source_info["workspace_name"],
workspace_icon=source_info["workspace_icon"],
workspace_id=source_info["workspace_id"],
pages=pages,
)
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
raise ValueError("Data source binding not found")
def get_authorized_pages(self, access_token: str):
pages = []
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
pages: list[NotionPageSummary] = []
page_results = self.notion_page_search(access_token)
database_results = self.notion_database_search(access_token)
# get page detail
@@ -187,7 +216,7 @@ class NotionOAuth(OAuthDataSource):
"parent_id": parent_id,
"type": "page",
}
pages.append(page)
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
# get database detail
for database_result in database_results:
page_id = database_result["id"]
@@ -220,11 +249,11 @@ class NotionOAuth(OAuthDataSource):
"parent_id": parent_id,
"type": "database",
}
pages.append(page)
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
return pages
def notion_page_search(self, access_token: str):
results = []
def notion_page_search(self, access_token: str) -> list[dict[str, Any]]:
results: list[dict[str, Any]] = []
next_cursor = None
has_more = True
@@ -249,7 +278,7 @@ class NotionOAuth(OAuthDataSource):
return results
def notion_block_parent_page_id(self, access_token: str, block_id: str):
def notion_block_parent_page_id(self, access_token: str, block_id: str) -> str:
headers = {
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
@@ -265,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
return self.notion_block_parent_page_id(access_token, parent[parent_type])
return parent[parent_type]
def notion_workspace_name(self, access_token: str):
def notion_workspace_name(self, access_token: str) -> str:
headers = {
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
@@ -279,8 +308,8 @@ class NotionOAuth(OAuthDataSource):
return user_info["workspace_name"]
return "workspace"
def notion_database_search(self, access_token: str):
results = []
def notion_database_search(self, access_token: str) -> list[dict[str, Any]]:
results: list[dict[str, Any]] = []
next_cursor = None
has_more = True
@@ -303,3 +332,19 @@ class NotionOAuth(OAuthDataSource):
next_cursor = response_json.get("next_cursor", None)
return results
@staticmethod
def _build_source_info(
*,
workspace_name: str | None,
workspace_icon: str | None,
workspace_id: str | None,
pages: list[NotionPageSummary],
) -> NotionSourceInfo:
return {
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
"workspace_id": workspace_id,
"pages": pages,
"total": len(pages),
}

View File

@@ -23,6 +23,9 @@ from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTr
from .model import Account
from .types import EnumText, LongText, StringUUID
TriggerJsonObject = dict[str, object]
TriggerCredentials = dict[str, str]
class TriggerSubscription(TypeBase):
"""
@@ -51,10 +54,14 @@ class TriggerSubscription(TypeBase):
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
)
endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
parameters: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
properties: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
parameters: Mapped[TriggerJsonObject] = mapped_column(
sa.JSON, nullable=False, comment="Subscription parameters JSON"
)
properties: Mapped[TriggerJsonObject] = mapped_column(
sa.JSON, nullable=False, comment="Subscription properties JSON"
)
credentials: Mapped[dict[str, Any]] = mapped_column(
credentials: Mapped[TriggerCredentials] = mapped_column(
sa.JSON, nullable=False, comment="Subscription credentials JSON"
)
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
@@ -162,8 +169,8 @@ class TriggerOAuthTenantClient(TypeBase):
)
@property
def oauth_params(self) -> Mapping[str, Any]:
return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
def oauth_params(self) -> Mapping[str, object]:
return cast(TriggerJsonObject, json.loads(self.encrypted_oauth_params or "{}"))
class WorkflowTriggerLog(TypeBase):

View File

@@ -19,7 +19,7 @@ from sqlalchemy import (
orm,
select,
)
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from dify_graph.constants import (
@@ -32,7 +32,7 @@ from dify_graph.enums import NodeType, WorkflowExecutionStatus
from dify_graph.file.constants import maybe_file_object
from dify_graph.file.models import File
from dify_graph.variables import utils as variable_utils
from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable
from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.datetime_utils import naive_utc_now
@@ -58,6 +58,9 @@ from .types import EnumText, LongText, StringUUID
logger = logging.getLogger(__name__)
SerializedWorkflowValue = dict[str, Any]
SerializedWorkflowVariables = dict[str, SerializedWorkflowValue]
class WorkflowType(StrEnum):
"""
@@ -390,7 +393,7 @@ class Workflow(Base): # bug
def rag_pipeline_user_input_form(self) -> list:
# get user_input_form from start node
variables: list[Any] = self.rag_pipeline_variables
variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables
return variables
@@ -433,17 +436,13 @@ class Workflow(Base): # bug
def environment_variables(
self,
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
# Use workflow.tenant_id to avoid relying on request user in background threads
tenant_id = self.tenant_id
if not tenant_id:
return []
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}"))
results = [
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
]
@@ -521,11 +520,7 @@ class Workflow(Base): # bug
@property
def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}"))
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
return results
@@ -537,19 +532,20 @@ class Workflow(Base): # bug
)
@property
def rag_pipeline_variables(self) -> list[dict]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._rag_pipeline_variables is None:
self._rag_pipeline_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
results = list(variables_dict.values())
return results
def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]:
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}"))
return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()]
@rag_pipeline_variables.setter
def rag_pipeline_variables(self, values: list[dict]) -> None:
def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None:
self._rag_pipeline_variables = json.dumps(
{item["variable"]: item for item in values},
{
rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json")
for rag_pipeline_variable in (
item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item)
for item in values
)
},
ensure_ascii=False,
)
@@ -787,44 +783,36 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
__tablename__ = "workflow_node_executions"
@declared_attr.directive
@classmethod
def __table_args__(cls) -> Any:
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
"workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
"workflow_node_execution_node_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_id",
),
Index(
"workflow_node_execution_id_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_execution_id",
),
Index(
# The first argument is the index name,
# which we leave as `None`` to allow auto-generation by the ORM.
None,
cls.tenant_id,
cls.workflow_id,
cls.node_id,
# MyPy may flag the following line because it doesn't recognize that
# the `declared_attr` decorator passes the receiving class as the first
# argument to this method, allowing us to reference class attributes.
cls.created_at.desc(),
),
)
__table_args__ = (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
"workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
"workflow_node_execution_node_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_id",
),
Index(
"workflow_node_execution_id_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_execution_id",
),
Index(
None,
"tenant_id",
"workflow_id",
"node_id",
sa.desc("created_at"),
),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID)

View File

@@ -1,4 +1,3 @@
configs/middleware/cache/redis_pubsub_config.py
controllers/console/app/annotation.py
controllers/console/app/app.py
controllers/console/app/app_import.py
@@ -138,8 +137,6 @@ dify_graph/nodes/trigger_webhook/node.py
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
dify_graph/nodes/variable_assigner/v1/node.py
dify_graph/nodes/variable_assigner/v2/node.py
dify_graph/variables/types.py
extensions/ext_fastopenapi.py
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
extensions/otel/instrumentation.py
extensions/otel/runtime.py
@@ -156,19 +153,7 @@ extensions/storage/oracle_oci_storage.py
extensions/storage/supabase_storage.py
extensions/storage/tencent_cos_storage.py
extensions/storage/volcengine_tos_storage.py
factories/variable_factory.py
libs/external_api.py
libs/gmpy2_pkcs10aep_cipher.py
libs/helper.py
libs/login.py
libs/module_loading.py
libs/oauth.py
libs/oauth_data_source.py
models/trigger.py
models/workflow.py
repositories/sqlalchemy_api_workflow_node_execution_repository.py
repositories/sqlalchemy_api_workflow_run_repository.py
repositories/sqlalchemy_execution_extra_content_repository.py
schedule/queue_monitor_task.py
services/account_service.py
services/audio_service.py

View File

@@ -8,7 +8,7 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
import json
from collections.abc import Sequence
from datetime import datetime
from typing import cast
from typing import Protocol, cast
from sqlalchemy import asc, delete, desc, func, select
from sqlalchemy.engine import CursorResult
@@ -22,6 +22,20 @@ from repositories.api_workflow_node_execution_repository import (
)
class _WorkflowNodeExecutionSnapshotRow(Protocol):
id: str
node_execution_id: str | None
node_id: str
node_type: str
title: str
index: int
status: WorkflowNodeExecutionStatus
elapsed_time: float | None
created_at: datetime
finished_at: datetime | None
execution_metadata: str | None
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
"""
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
@@ -40,6 +54,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
- Thread-safe database operations using session-per-request pattern
"""
_session_maker: sessionmaker[Session]
def __init__(self, session_maker: sessionmaker[Session]):
"""
Initialize the repository with a sessionmaker.
@@ -156,12 +172,12 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
)
with self._session_maker() as session:
rows = session.execute(stmt).all()
rows = cast(Sequence[_WorkflowNodeExecutionSnapshotRow], session.execute(stmt).all())
return [self._row_to_snapshot(row) for row in rows]
@staticmethod
def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot:
def _row_to_snapshot(row: _WorkflowNodeExecutionSnapshotRow) -> WorkflowNodeExecutionSnapshot:
metadata: dict[str, object] = {}
execution_metadata = getattr(row, "execution_metadata", None)
if execution_metadata: