add more typing (#24949)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-09-08 11:40:00 +09:00
committed by GitHub
parent ce2281d31b
commit f6059ef389
9 changed files with 97 additions and 74 deletions

View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource, reqparse
@@ -6,6 +8,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
@@ -14,9 +18,9 @@ from extensions.ext_database import db
from models.model import App, InstalledApp, RecommendedApp
def admin_required(view):
def admin_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")

View File

@@ -1,5 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import cast
from typing import Concatenate, ParamSpec, TypeVar, cast
import flask_login
from flask import jsonify, request
@@ -15,10 +16,14 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
from .. import api
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def oauth_server_client_id_required(view):
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser()
parser.add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args()
@@ -30,18 +35,15 @@ def oauth_server_client_id_required(view):
if not oauth_provider_app:
raise NotFound("client_id is invalid")
kwargs["oauth_provider_app"] = oauth_provider_app
return view(*args, **kwargs)
return view(self, oauth_provider_app, *args, **kwargs)
return decorated
def oauth_server_access_token_required(view):
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
@wraps(view)
def decorated(*args, **kwargs):
oauth_provider_app = kwargs.get("oauth_provider_app")
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
if not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")
authorization_header = request.headers.get("Authorization")
@@ -79,9 +81,7 @@ def oauth_server_access_token_required(view):
response.headers["WWW-Authenticate"] = "Bearer"
return response
kwargs["account"] = account
return view(*args, **kwargs)
return view(self, oauth_provider_app, account, *args, **kwargs)
return decorated

View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, Optional, ParamSpec, TypeVar
from flask_login import current_user
from flask_restx import Resource
@@ -13,19 +15,15 @@ from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def installed_app_required(view=None):
def decorator(view):
def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(*args, **kwargs):
if not kwargs.get("installed_app_id"):
raise ValueError("missing installed_app_id in path parameters")
installed_app_id = kwargs.get("installed_app_id")
installed_app_id = str(installed_app_id)
del kwargs["installed_app_id"]
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
installed_app = (
db.session.query(InstalledApp)
.where(
@@ -52,10 +50,10 @@ def installed_app_required(view=None):
return decorator
def user_allowed_to_access_app(view=None):
def decorator(view):
def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app: InstalledApp, *args, **kwargs):
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled:
app_id = installed_app.app_id

View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask_login import current_user
from sqlalchemy.orm import Session
@@ -7,14 +9,17 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from models.account import TenantPluginPermission
P = ParamSpec("P")
R = TypeVar("R")
def plugin_permission_required(
install_required: bool = False,
debug_required: bool = False,
):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
user = current_user
tenant_id = user.current_tenant_id

View File

@@ -2,7 +2,9 @@ import contextlib
import json
import os
import time
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request
from flask_login import current_user
@@ -19,10 +21,13 @@ from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
P = ParamSpec("P")
R = TypeVar("R")
def account_initialization_required(view):
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization
account = current_user
@@ -34,9 +39,9 @@ def account_initialization_required(view):
return decorated
def only_edition_cloud(view):
def only_edition_cloud(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "CLOUD":
abort(404)
@@ -45,9 +50,9 @@ def only_edition_cloud(view):
return decorated
def only_edition_enterprise(view):
def only_edition_enterprise(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ENTERPRISE_ENABLED:
abort(404)
@@ -56,9 +61,9 @@ def only_edition_enterprise(view):
return decorated
def only_edition_self_hosted(view):
def only_edition_self_hosted(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "SELF_HOSTED":
abort(404)
@@ -67,9 +72,9 @@ def only_edition_self_hosted(view):
return decorated
def cloud_edition_billing_enabled(view):
def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
@@ -79,9 +84,9 @@ def cloud_edition_billing_enabled(view):
def cloud_edition_billing_resource_check(resource: str):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
members = features.members
@@ -120,9 +125,9 @@ def cloud_edition_billing_resource_check(resource: str):
def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
@@ -142,9 +147,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def cloud_edition_billing_rate_limit_check(resource: str):
def interceptor(view):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
if knowledge_rate_limit.enabled:
@@ -176,9 +181,9 @@ def cloud_edition_billing_rate_limit_check(resource: str):
return interceptor
def cloud_utm_record(view):
def cloud_utm_record(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
features = FeatureService.get_features(current_user.current_tenant_id)
@@ -194,9 +199,9 @@ def cloud_utm_record(view):
return decorated
def setup_required(view):
def setup_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
# check setup
if (
dify_config.EDITION == "SELF_HOSTED"
@@ -212,9 +217,9 @@ def setup_required(view):
return decorated
def enterprise_license_required(view):
def enterprise_license_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features()
if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]:
raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.")
@@ -224,9 +229,9 @@ def enterprise_license_required(view):
return decorated
def email_password_login_enabled(view):
def email_password_login_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
if features.enable_email_password_login:
return view(*args, **kwargs)
@@ -237,9 +242,9 @@ def email_password_login_enabled(view):
return decorated
def enable_change_email(view):
def enable_change_email(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
if features.enable_change_email:
return view(*args, **kwargs)
@@ -250,9 +255,9 @@ def enable_change_email(view):
return decorated
def is_allow_transfer_owner(view):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if features.is_allow_transfer_workspace:
return view(*args, **kwargs)