Use hook to get userid (#26839)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-10-14 10:20:37 +09:00
committed by GitHub
parent 56ee8f7d64
commit 0a6b78f883
28 changed files with 503 additions and 495 deletions

View File

@@ -5,18 +5,10 @@ from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_user, login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.plugin.endpoint_service import EndpointService
def _current_account_with_tenant() -> tuple[Account, str]:
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
return current_user, tenant_id
@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
@api.doc("create_endpoint")
@@ -41,7 +33,7 @@ class EndpointCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
@@ -87,7 +79,7 @@ class EndpointListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@@ -130,7 +122,7 @@ class EndpointListForSinglePluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@@ -172,7 +164,7 @@ class EndpointDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@@ -212,7 +204,7 @@ class EndpointUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@@ -255,7 +247,7 @@ class EndpointEnableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@@ -288,7 +280,7 @@ class EndpointDisableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user, tenant_id = _current_account_with_tenant()
user, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)

View File

@@ -25,7 +25,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields
from libs.helper import extract_remote_ip
from libs.login import current_user, login_required
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
@@ -41,8 +41,7 @@ class MemberListApi(Resource):
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
@@ -69,9 +68,7 @@ class MemberInviteEmailApi(Resource):
interface_language = args["language"]
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
@@ -120,8 +117,7 @@ class MemberCancelInviteApi(Resource):
@login_required
@account_initialization_required
def delete(self, member_id):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.query(Account).where(Account.id == str(member_id)).first()
@@ -160,9 +156,7 @@ class MemberUpdateRoleApi(Resource):
if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id))
@@ -189,8 +183,7 @@ class DatasetOperatorMemberListApi(Resource):
@account_initialization_required
@marshal_with(account_with_role_list_fields)
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
@@ -212,10 +205,8 @@ class SendOwnerTransferEmailApi(Resource):
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
current_user, _ = current_account_with_tenant()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
@@ -250,8 +241,7 @@ class OwnerTransferCheckApi(Resource):
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
@@ -296,8 +286,7 @@ class OwnerTransfer(Resource):
args = parser.parse_args()
# check if the current user is the owner of the workspace
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):

View File

@@ -1,7 +1,6 @@
import io
from flask import send_file
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@@ -11,8 +10,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.login import login_required
from models.account import Account
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
@@ -23,11 +21,8 @@ class ModelProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument(
@@ -52,11 +47,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def get(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
@@ -73,8 +65,7 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@@ -85,11 +76,9 @@ class ModelProviderCredentialApi(Resource):
model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try:
model_provider_service.create_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_name=args["name"],
@@ -103,8 +92,7 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def put(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@@ -116,11 +104,9 @@ class ModelProviderCredentialApi(Resource):
model_provider_service = ModelProviderService()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
try:
model_provider_service.update_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_id=args["credential_id"],
@@ -135,19 +121,16 @@ class ModelProviderCredentialApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential(
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
)
return {"result": "success"}, 204
@@ -159,19 +142,16 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
service = ModelProviderService()
service.switch_active_provider_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
credential_id=args["credential_id"],
)
@@ -184,15 +164,12 @@ class ModelProviderValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
_, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
tenant_id = current_tenant_id
model_provider_service = ModelProviderService()
@@ -240,14 +217,11 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
tenant_id = current_user.current_tenant_id
tenant_id = current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument(
@@ -276,14 +250,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
def get(self, provider: str):
if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid")
if not isinstance(current_user, Account):
raise ValueError("Invalid user account")
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
if not current_user.current_tenant_id:
raise ValueError("No current tenant")
data = BillingService.get_model_provider_payment_link(
provider_name=provider,
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
account_id=current_user.id,
prefilled_email=current_user.email,
)

View File

@@ -1,6 +1,5 @@
import logging
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@@ -10,7 +9,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
@@ -23,6 +22,8 @@ class DefaultModelApi(Resource):
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument(
"model_type",
@@ -34,8 +35,6 @@ class DefaultModelApi(Resource):
)
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id, model_type=args["model_type"]
@@ -47,15 +46,14 @@ class DefaultModelApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
model_settings = args["model_settings"]
for model_setting in model_settings:
@@ -92,7 +90,7 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
@@ -104,11 +102,11 @@ class ModelProviderModelApi(Resource):
@account_initialization_required
def post(self, provider: str):
# To save the model's load balance configs
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
@@ -129,7 +127,7 @@ class ModelProviderModelApi(Resource):
raise ValueError("credential_id is required when configuring a custom-model")
service = ModelProviderService()
service.switch_active_custom_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@@ -164,11 +162,11 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
parser.add_argument(
@@ -195,7 +193,7 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def get(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
@@ -257,6 +255,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@@ -274,7 +274,6 @@ class ModelProviderModelCredentialApi(Resource):
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
try:
@@ -301,6 +300,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def put(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
@@ -323,7 +324,7 @@ class ModelProviderModelCredentialApi(Resource):
try:
model_provider_service.update_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@@ -340,6 +341,8 @@ class ModelProviderModelCredentialApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -357,7 +360,7 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@@ -373,6 +376,8 @@ class ModelProviderModelCredentialSwitchApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
@@ -390,7 +395,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
service = ModelProviderService()
service.add_model_credential_to_model_list(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider=provider,
model_type=args["model_type"],
model=args["model"],
@@ -407,7 +412,7 @@ class ModelProviderModelEnableApi(Resource):
@login_required
@account_initialization_required
def patch(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@@ -437,7 +442,7 @@ class ModelProviderModelDisableApi(Resource):
@login_required
@account_initialization_required
def patch(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@@ -465,7 +470,7 @@ class ModelProviderModelValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@@ -514,8 +519,7 @@ class ModelProviderModelParameterRuleApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules(
@@ -531,8 +535,7 @@ class ModelProviderAvailableModelApi(Resource):
@login_required
@account_initialization_required
def get(self, model_type):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

View File

@@ -1,7 +1,6 @@
import io
from flask import request, send_file
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
@@ -11,7 +10,7 @@ from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService
@@ -26,7 +25,7 @@ class PluginDebuggingKeyApi(Resource):
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {
@@ -44,7 +43,7 @@ class PluginListApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=False, location="args", default=1)
parser.add_argument("page_size", type=int, required=False, location="args", default=256)
@@ -81,7 +80,7 @@ class PluginListInstallationsFromIdsApi(Resource):
@login_required
@account_initialization_required
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_ids", type=list, required=True, location="json")
@@ -120,7 +119,7 @@ class PluginUploadFromPkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
file = request.files["pkg"]
@@ -144,7 +143,7 @@ class PluginUploadFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
@@ -167,7 +166,7 @@ class PluginUploadFromBundleApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
file = request.files["bundle"]
@@ -191,7 +190,7 @@ class PluginInstallFromPkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
@@ -217,7 +216,7 @@ class PluginInstallFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
@@ -247,7 +246,7 @@ class PluginInstallFromMarketplaceApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
@@ -273,7 +272,7 @@ class PluginFetchMarketplacePkgApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
@@ -299,7 +298,7 @@ class PluginFetchManifestApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
@@ -324,7 +323,7 @@ class PluginFetchInstallTasksApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@@ -346,7 +345,7 @@ class PluginFetchInstallTaskApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self, task_id: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
@@ -361,7 +360,7 @@ class PluginDeleteInstallTaskApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self, task_id: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
@@ -376,7 +375,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
@@ -391,7 +390,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
@@ -406,7 +405,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
@@ -430,7 +429,7 @@ class PluginUpgradeFromGithubApi(Resource):
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
@@ -466,7 +465,7 @@ class PluginUninstallApi(Resource):
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args()
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
@@ -480,6 +479,7 @@ class PluginChangePermissionApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
@@ -492,7 +492,7 @@ class PluginChangePermissionApi(Resource):
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
tenant_id = user.current_tenant_id
tenant_id = current_tenant_id
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
@@ -503,7 +503,7 @@ class PluginFetchPermissionApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
permission = PluginPermissionService.get_permission(tenant_id)
if not permission:
@@ -529,10 +529,10 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
@account_initialization_required
def get(self):
# check if the user is admin or owner
current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
user_id = current_user.id
parser = reqparse.RequestParser()
@@ -565,7 +565,7 @@ class PluginChangePreferencesApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
@@ -574,8 +574,6 @@ class PluginChangePreferencesApi(Resource):
req.add_argument("auto_upgrade", type=dict, required=True, location="json")
args = req.parse_args()
tenant_id = user.current_tenant_id
permission = args["permission"]
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
@@ -621,7 +619,7 @@ class PluginFetchPreferencesApi(Resource):
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
permission = PluginPermissionService.get_permission(tenant_id)
permission_dict = {
@@ -661,7 +659,7 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
@account_initialization_required
def post(self):
# exclude one single plugin
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
req = reqparse.RequestParser()
req.add_argument("plugin_id", type=str, required=True, location="json")

View File

@@ -2,7 +2,6 @@ import io
from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file
from flask_login import current_user
from flask_restx import (
Resource,
reqparse,
@@ -24,7 +23,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
@@ -53,10 +52,9 @@ class ToolProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument(
@@ -78,9 +76,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.list_builtin_tool_provider_tools(
@@ -96,9 +92,7 @@ class ToolBuiltinProviderInfoApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@@ -109,11 +103,10 @@ class ToolBuiltinProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = req.parse_args()
@@ -131,10 +124,9 @@ class ToolBuiltinProviderAddApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@@ -161,13 +153,12 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
@@ -193,7 +184,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credentials(
@@ -218,13 +209,12 @@ class ToolApiProviderAddApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@@ -258,10 +248,9 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@@ -282,10 +271,9 @@ class ToolApiProviderListToolsApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@@ -308,13 +296,12 @@ class ToolApiProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@@ -350,13 +337,12 @@ class ToolApiProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@@ -377,10 +363,9 @@ class ToolApiProviderGetApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
@@ -401,8 +386,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider, credential_type):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.list_builtin_provider_credentials_schema(
@@ -444,9 +428,9 @@ class ToolApiProviderPreviousTestApi(Resource):
parser.add_argument("schema", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
current_tenant_id,
args["provider_name"] or "",
args["tool_name"],
args["credentials"],
@@ -462,13 +446,12 @@ class ToolWorkflowProviderCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
@@ -502,13 +485,12 @@ class ToolWorkflowProviderUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@@ -545,13 +527,12 @@ class ToolWorkflowProviderDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@@ -571,10 +552,9 @@ class ToolWorkflowProviderGetApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
@@ -606,10 +586,9 @@ class ToolWorkflowProviderListToolApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
@@ -631,10 +610,9 @@ class ToolBuiltinListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
@@ -653,8 +631,7 @@ class ToolApiListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
[
@@ -672,10 +649,9 @@ class ToolWorkflowListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = current_account_with_tenant()
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
@@ -709,19 +685,18 @@ class ToolPluginOAuthApi(Resource):
provider_name = tool_provider.provider_name
# todo check permission
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
user_id=user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
authorization_url_response = oauth_handler.get_authorization_url(
@@ -800,11 +775,12 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
)
@@ -819,13 +795,13 @@ class ToolOAuthCustomClient(Resource):
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
user = current_user
user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
provider=provider,
client_params=args.get("client_params", {}),
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
@@ -835,20 +811,18 @@ class ToolOAuthCustomClient(Resource):
@login_required
@account_initialization_required
def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
)
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
)
@@ -858,9 +832,10 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
tenant_id=current_user.current_tenant_id, provider_name=provider
tenant_id=current_tenant_id, provider_name=provider
)
)
@@ -871,7 +846,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
@@ -900,12 +875,12 @@ class ToolProviderMCPApi(Resource):
)
parser.add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
args = parser.parse_args()
user = current_user
user, tenant_id = current_account_with_tenant()
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
return jsonable_encoder(
MCPToolManageService.create_mcp_provider(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
@@ -940,8 +915,9 @@ class ToolProviderMCPApi(Resource):
pass
else:
raise ValueError("Server URL is not valid.")
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.update_mcp_provider(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
@@ -962,7 +938,8 @@ class ToolProviderMCPApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
_, current_tenant_id = current_account_with_tenant()
MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
@@ -977,7 +954,7 @@ class ToolMCPAuthApi(Resource):
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
provider_id = args["provider_id"]
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider:
raise ValueError("provider not found")
@@ -1018,8 +995,8 @@ class ToolMCPDetailApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_id):
user = current_user
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
_, tenant_id = current_account_with_tenant()
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@@ -1029,8 +1006,7 @@ class ToolMCPListAllApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
tenant_id = user.current_tenant_id
_, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
@@ -1043,7 +1019,7 @@ class ToolMCPUpdateApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_id):
tenant_id = current_user.current_tenant_id
_, tenant_id = current_account_with_tenant()
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
tenant_id=tenant_id,
provider_id=provider_id,