Compare commits

...

9 Commits

Author SHA1 Message Date
Harry
d2a0e498ea fix: change default value of expires_at field in tool_builtin_providers to -1
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
2025-07-22 15:03:56 +08:00
Harry
3b44f11439 fix: update redirect URI and system credentials retrieval in tool manager 2025-07-22 14:57:55 +08:00
Harry
9bef8d3856 Merge remote-tracking branch 'origin/main' into feat/oauth 2025-07-22 13:01:11 +08:00
Harry
c538c9f127 Merge remote-tracking branch 'origin/main' into feat/oauth 2025-07-22 10:53:54 +08:00
Yeuoly
fe1a3ca943 fix: adjust expires_at check to allow for a 60-second buffer 2025-07-22 01:14:26 +08:00
Yeuoly
20ca2033ce fix: mypy 2025-07-22 01:10:27 +08:00
Yeuoly
de6708382b fix: circular import 2025-07-22 01:02:28 +08:00
Yeuoly
7fa952b1a2 feat: implement OAuth credentials refresh mechanism and update expires_at handling 2025-07-22 00:58:19 +08:00
Harry
5d986c2cdf feat: add expires_at field to OAuth credentials and default value for builtin tool provider 2025-07-22 00:53:03 +08:00
7 changed files with 125 additions and 4 deletions

View File

@@ -739,7 +739,7 @@ class ToolOAuthCallback(Resource):
raise Forbidden("no oauth available client config found for this tool provider") raise Forbidden("no oauth available client config found for this tool provider")
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials = oauth_handler.get_credentials( credentials_response = oauth_handler.get_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user_id, user_id=user_id,
plugin_id=plugin_id, plugin_id=plugin_id,
@@ -747,7 +747,10 @@ class ToolOAuthCallback(Resource):
redirect_uri=redirect_uri, redirect_uri=redirect_uri,
system_credentials=oauth_client_params, system_credentials=oauth_client_params,
request=request, request=request,
).credentials )
credentials = credentials_response.credentials
expires_at = credentials_response.expires_at
if not credentials: if not credentials:
raise Exception("the plugin credentials failed") raise Exception("the plugin credentials failed")
@@ -758,6 +761,7 @@ class ToolOAuthCallback(Resource):
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
credentials=dict(credentials), credentials=dict(credentials),
expires_at=expires_at,
api_type=CredentialType.OAUTH2, api_type=CredentialType.OAUTH2,
) )
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

View File

@@ -182,6 +182,10 @@ class PluginOAuthAuthorizationUrlResponse(BaseModel):
class PluginOAuthCredentialsResponse(BaseModel): class PluginOAuthCredentialsResponse(BaseModel):
metadata: Mapping[str, Any] = Field(
default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc."
)
expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.")
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.") credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")

View File

@@ -84,6 +84,41 @@ class OAuthHandler(BasePluginClient):
except Exception as e: except Exception as e:
raise ValueError(f"Error getting credentials: {e}") raise ValueError(f"Error getting credentials: {e}")
def refresh_credentials(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any],
credentials: Mapping[str, Any],
) -> PluginOAuthCredentialsResponse:
try:
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials",
PluginOAuthCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for refresh credentials request.")
except Exception as e:
raise ValueError(f"Error refreshing credentials: {e}")
def _convert_request_to_raw_data(self, request: Request) -> bytes: def _convert_request_to_raw_data(self, request: Request) -> bytes:
""" """
Convert a Request object to raw HTTP data. Convert a Request object to raw HTTP data.

View File

@@ -1,16 +1,19 @@
import json import json
import logging import logging
import mimetypes import mimetypes
from collections.abc import Generator import time
from collections.abc import Generator, Mapping
from os import listdir, path from os import listdir, path
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from pydantic import TypeAdapter
from yarl import URL from yarl import URL
import contexts import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.plugin.impl.tool import PluginToolManager from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
@@ -244,12 +247,47 @@ class ToolManager:
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
), ),
) )
# decrypt the credentials
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
# check if the credentials is expired
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
# TODO: circular import
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
# refresh the credentials
tool_provider = ToolProviderID(provider_id)
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
oauth_handler = OAuthHandler()
# refresh the credentials
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=builtin_provider.user_id,
plugin_id=tool_provider.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = (
TypeAdapter(dict[str, Any])
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
.decode("utf-8")
)
builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
decrypted_credentials = refreshed_credentials.credentials
return cast( return cast(
BuiltinTool, BuiltinTool,
builtin_tool.fork_tool_runtime( builtin_tool.fork_tool_runtime(
runtime=ToolRuntime( runtime=ToolRuntime(
tenant_id=tenant_id, tenant_id=tenant_id,
credentials=encrypter.decrypt(builtin_provider.credentials), credentials=dict(decrypted_credentials),
credential_type=CredentialType.of(builtin_provider.credential_type), credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={}, runtime_parameters={},
invoke_from=invoke_from, invoke_from=invoke_from,

View File

@@ -0,0 +1,34 @@
"""oauth_refresh_token
Revision ID: 375fe79ead14
Revises: 1a83934ad6d1
Create Date: 2025-07-22 00:19:45.599636
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '375fe79ead14'
down_revision = '1a83934ad6d1'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('expires_at', sa.BigInteger(), server_default=sa.text('-1'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_column('expires_at')
# ### end Alembic commands ###

View File

@@ -93,6 +93,7 @@ class BuiltinToolProvider(Base):
credential_type: Mapped[str] = mapped_column( credential_type: Mapped[str] = mapped_column(
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
) )
expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1"))
@property @property
def credentials(self) -> dict: def credentials(self) -> dict:

View File

@@ -38,6 +38,7 @@ logger = logging.getLogger(__name__)
class BuiltinToolManageService: class BuiltinToolManageService:
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
__DEFAULT_EXPIRES_AT__ = 2147483647
@staticmethod @staticmethod
def delete_custom_oauth_client_params(tenant_id: str, provider: str): def delete_custom_oauth_client_params(tenant_id: str, provider: str):
@@ -212,6 +213,7 @@ class BuiltinToolManageService:
tenant_id: str, tenant_id: str,
provider: str, provider: str,
credentials: dict, credentials: dict,
expires_at: int = -1,
name: str | None = None, name: str | None = None,
): ):
""" """
@@ -269,6 +271,9 @@ class BuiltinToolManageService:
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value, credential_type=api_type.value,
name=name, name=name,
expires_at=expires_at
if expires_at is not None
else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__,
) )
session.add(db_provider) session.add(db_provider)