feat(oauth): refactor session management in tool provider operations

This commit is contained in:
Harry
2025-07-09 14:44:36 +08:00
parent ef330fec2c
commit f35b8d6245

View File

@@ -7,6 +7,7 @@ from typing import Any, Optional
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
@@ -114,52 +115,65 @@ class BuiltinToolManageService:
"""
update builtin tool provider
"""
# get if the provider exists
db_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
try:
if CredentialType.of(db_provider.credential_type).is_editable():
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials")
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, db_provider, provider, provider_controller
with Session(db.engine) as session:
# get if the provider exists
db_provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
# Decrypt and restore original credentials for masked values
original_credentials = encrypter.decrypt(db_provider.credentials)
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
try:
if CredentialType.of(db_provider.credential_type).is_editable():
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials")
# check if the credential has changed, save the original credential
for key, value in credentials.items():
if key in masked_credentials and value == masked_credentials[key]:
credentials[key] = original_credentials[key]
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, db_provider, provider, provider_controller
)
if CredentialType.of(db_provider.credential_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, credentials)
original_credentials = encrypter.decrypt(db_provider.credentials)
new_credentials: dict = {
key: value if value != HIDDEN_VALUE else original_credentials.get(key, HIDDEN_VALUE)
for key, value in credentials.items()
}
# encrypt credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
if CredentialType.of(db_provider.credential_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, new_credentials)
cache.delete()
# encrypt credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
# update name if provided
if name is not None and db_provider.name != name:
db_provider.name = name
cache.delete()
db.session.commit()
except (
PluginDaemonClientSideError,
ToolProviderNotFoundError,
ToolNotFoundError,
ToolProviderCredentialValidationError,
) as e:
db.session.rollback()
raise ValueError(str(e))
# update name if provided
if name is not None and db_provider.name != name:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
):
raise ValueError(f"the credential name '{name}' is already used")
db_provider.name = name
session.commit()
except (
PluginDaemonClientSideError,
ToolProviderNotFoundError,
ToolNotFoundError,
ToolProviderCredentialValidationError,
) as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@@ -175,59 +189,69 @@ class BuiltinToolManageService:
"""
add builtin tool provider
"""
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
try:
with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials")
with Session(db.engine) as session:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials")
provider_count = (
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
)
# check if the provider count is reached the limit
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
raise ValueError(f"you have reached the maximum number of providers for {provider}")
# validate credentials if allowed
if CredentialType.of(api_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, credentials)
# generate name if not provided
if name is None:
name = BuiltinToolManageService.generate_builtin_tool_provider_name(
tenant_id=tenant_id, provider=provider, credential_type=api_type
provider_count = (
session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
)
# create encrypter
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(api_type)
],
cache=NoOpProviderCredentialCache(),
)
# check if the provider count is reached the limit
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
raise ValueError(f"you have reached the maximum number of providers for {provider}")
db_provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value,
name=name,
)
# validate credentials if allowed
if CredentialType.of(api_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, credentials)
db.session.add(db_provider)
db.session.commit()
# generate name if not provided
if name is None:
name = BuiltinToolManageService.generate_builtin_tool_provider_name(
session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
)
else:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
):
raise ValueError(f"the credential name '{name}' is already used")
# create encrypter
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(api_type)
],
cache=NoOpProviderCredentialCache(),
)
db_provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value,
name=name,
)
session.add(db_provider)
session.commit()
except (
PluginDaemonClientSideError,
ToolProviderNotFoundError,
ToolNotFoundError,
ToolProviderCredentialValidationError,
) as e:
db.session.rollback()
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@@ -249,10 +273,12 @@ class BuiltinToolManageService:
return encrypter, cache
@staticmethod
def generate_builtin_tool_provider_name(tenant_id: str, provider: str, credential_type: CredentialType) -> str:
def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
try:
db_providers = (
db.session.query(BuiltinToolProvider)
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
@@ -308,7 +334,7 @@ class BuiltinToolManageService:
default_provider = providers[0]
default_provider.is_default = True
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
tenant_id, default_provider, default_provider.provider, provider_controller
)
@@ -343,20 +369,28 @@ class BuiltinToolManageService:
"""
delete tool provider
"""
tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
with Session(db.engine) as session:
db_provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if tool_provider is None:
raise ValueError(f"you have not added provider {provider}")
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
db.session.delete(tool_provider)
db.session.commit()
session.delete(db_provider)
session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
_, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, tool_provider, provider, provider_controller
)
cache.delete()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
_, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, db_provider, provider, provider_controller
)
cache.delete()
return {"result": "success"}
@@ -507,18 +541,6 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
provider: Optional[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
return provider
@staticmethod
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
"""