mirror of
https://github.com/langgenius/dify.git
synced 2026-04-08 01:12:43 +00:00
Compare commits
6 Commits
feat/plugi
...
fix/unsele
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cb97a93508 | ||
|
|
407d66c62d | ||
|
|
9d75d3d04c | ||
|
|
02852ee543 | ||
|
|
32b2d19622 | ||
|
|
874406d934 |
@@ -2,7 +2,7 @@ import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
@@ -105,6 +105,12 @@ class ChatMessageApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
def post(self, app_model):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
|
||||
@@ -172,7 +172,7 @@ class MessageAnnotationApi(Resource):
|
||||
def post(self, app_model):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
@@ -2,8 +2,8 @@ import json
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@@ -13,7 +13,8 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.model import AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
@@ -25,6 +26,13 @@ class ModelConfigResource(Resource):
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model):
|
||||
"""Modify app model config"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
|
||||
@@ -69,7 +69,7 @@ class DraftWorkflowApi(Resource):
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
@@ -92,7 +92,7 @@ class DraftWorkflowApi(Resource):
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
@@ -170,7 +170,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
@@ -220,7 +220,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -256,7 +256,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -293,7 +293,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -330,7 +330,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -367,7 +367,7 @@ class DraftWorkflowRunApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -406,7 +406,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
@@ -428,7 +428,7 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -476,7 +476,7 @@ class PublishedWorkflowApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# fetch published workflow by app_model
|
||||
@@ -497,7 +497,7 @@ class PublishedWorkflowApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -547,7 +547,7 @@ class DefaultBlockConfigsApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# Get default block configs
|
||||
@@ -567,7 +567,7 @@ class DefaultBlockConfigApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -602,7 +602,7 @@ class ConvertToWorkflowApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
if request.data:
|
||||
@@ -651,7 +651,7 @@ class PublishedAllWorkflowApi(Resource):
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -702,7 +702,7 @@ class WorkflowByIdApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -758,7 +758,7 @@ class WorkflowByIdApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@@ -137,7 +137,7 @@ def _api_prerequisite(f):
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def wrapper(*args, **kwargs):
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -175,6 +175,22 @@ class ModelProviderCredentialSwitchApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class ModelProviderCredentialCancelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
service = ModelProviderService()
|
||||
service.cancel_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class ModelProviderValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -289,6 +305,9 @@ api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-provider
|
||||
api.add_resource(
|
||||
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
|
||||
)
|
||||
api.add_resource(
|
||||
ModelProviderCredentialCancelApi, "/workspaces/current/model-providers/<path:provider>/credentials/cancel"
|
||||
)
|
||||
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||
|
||||
api.add_resource(
|
||||
|
||||
@@ -165,7 +165,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
def put(self, app_model: App, annotation_id):
|
||||
"""Update an existing annotation."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
annotation_id = str(annotation_id)
|
||||
@@ -189,7 +189,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
"""Delete an annotation."""
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
annotation_id = str(annotation_id)
|
||||
|
||||
@@ -559,7 +559,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
def post(self, _, dataset_id):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_create_parser.parse_args()
|
||||
@@ -583,7 +583,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def patch(self, _, dataset_id):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_update_parser.parse_args()
|
||||
@@ -610,7 +610,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
def delete(self, _, dataset_id):
|
||||
"""Delete a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
args = tag_delete_parser.parse_args()
|
||||
TagService.delete_tag(args.get("tag_id"))
|
||||
@@ -634,7 +634,7 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_binding_parser.parse_args()
|
||||
@@ -660,7 +660,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
def post(self, _, dataset_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_unbinding_parser.parse_args()
|
||||
|
||||
@@ -33,6 +33,7 @@ from core.plugin.entities.plugin import ModelProviderID
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.provider import (
|
||||
CredentialStatus,
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
ProviderCredential,
|
||||
@@ -43,6 +44,7 @@ from models.provider import (
|
||||
TenantPreferredModelProvider,
|
||||
)
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from services.entities.model_provider_entities import CustomConfigurationStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -189,6 +191,22 @@ class ProviderConfiguration(BaseModel):
|
||||
else SystemConfigurationStatus.QUOTA_EXCEEDED
|
||||
)
|
||||
|
||||
def get_custom_configuration_status(self) -> Optional[CustomConfigurationStatus]:
|
||||
"""
|
||||
Get custom configuration status.
|
||||
:return:
|
||||
"""
|
||||
if self.is_custom_configuration_available():
|
||||
return CustomConfigurationStatus.ACTIVE
|
||||
|
||||
provider = self.custom_configuration.provider
|
||||
if provider and hasattr(provider, "current_credential_status"):
|
||||
status = provider.current_credential_status
|
||||
if status:
|
||||
return status
|
||||
|
||||
return CustomConfigurationStatus.NO_CONFIGURE
|
||||
|
||||
def is_custom_configuration_available(self) -> bool:
|
||||
"""
|
||||
Check custom configuration available.
|
||||
@@ -643,6 +661,7 @@ class ProviderConfiguration(BaseModel):
|
||||
self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session)
|
||||
elif provider_record and provider_record.credential_id == credential_id:
|
||||
provider_record.credential_id = None
|
||||
provider_record.credential_status = CredentialStatus.REMOVED.value
|
||||
provider_record.updated_at = naive_utc_now()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
@@ -681,6 +700,34 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
try:
|
||||
provider_record.credential_id = credential_record.id
|
||||
provider_record.credential_status = CredentialStatus.ACTIVE.value
|
||||
provider_record.updated_at = naive_utc_now()
|
||||
session.commit()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
provider_model_credentials_cache.delete()
|
||||
self.switch_preferred_provider_type(ProviderType.CUSTOM, session=session)
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
|
||||
def cancel_provider_credential(self):
|
||||
"""
|
||||
Cancel select the active provider credential.
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
provider_record = self._get_provider_record(session)
|
||||
if not provider_record:
|
||||
raise ValueError("Provider record not found.")
|
||||
|
||||
try:
|
||||
provider_record.credential_id = None
|
||||
provider_record.credential_status = CredentialStatus.CANCELED.value
|
||||
provider_record.updated_at = naive_utc_now()
|
||||
session.commit()
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from core.entities.parameter_entities import (
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from models.provider import CredentialStatus
|
||||
|
||||
|
||||
class ProviderQuotaType(Enum):
|
||||
@@ -97,6 +98,7 @@ class CustomProviderConfiguration(BaseModel):
|
||||
credentials: dict
|
||||
current_credential_id: Optional[str] = None
|
||||
current_credential_name: Optional[str] = None
|
||||
current_credential_status: Optional[CredentialStatus] = None
|
||||
available_credentials: list[CredentialConfiguration] = []
|
||||
|
||||
|
||||
|
||||
@@ -711,6 +711,7 @@ class ProviderManager:
|
||||
credentials=provider_credentials,
|
||||
current_credential_name=custom_provider_record.credential_name,
|
||||
current_credential_id=custom_provider_record.credential_id,
|
||||
current_credential_status=custom_provider_record.credential_status,
|
||||
available_credentials=self.get_provider_available_credentials(
|
||||
tenant_id, custom_provider_record.provider_name
|
||||
),
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Add credential status for provider table
|
||||
|
||||
Revision ID: cf7c38a32b2d
|
||||
Revises: c20211f18133
|
||||
Create Date: 2025-09-11 15:37:17.771298
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'cf7c38a32b2d'
|
||||
down_revision = 'c20211f18133'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('credential_status', sa.String(length=20), server_default=sa.text("'active'::character varying"), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('providers', schema=None) as batch_op:
|
||||
batch_op.drop_column('credential_status')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@@ -7,6 +7,7 @@ import sqlalchemy as sa
|
||||
from flask_login import UserMixin # type: ignore[import-untyped]
|
||||
from sqlalchemy import DateTime, String, func, select
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from models.base import Base
|
||||
|
||||
@@ -187,7 +188,28 @@ class Account(UserMixin, Base):
|
||||
return TenantAccountRole.is_admin_role(self.role)
|
||||
|
||||
@property
|
||||
@deprecated("Use has_edit_permission instead.")
|
||||
def is_editor(self):
|
||||
"""Determines if the account has edit permissions in their current tenant (workspace).
|
||||
|
||||
This property checks if the current role has editing privileges, which includes:
|
||||
- `OWNER`
|
||||
- `ADMIN`
|
||||
- `EDITOR`
|
||||
|
||||
Note: This checks for any role with editing permission, not just the 'EDITOR' role specifically.
|
||||
"""
|
||||
return self.has_edit_permission
|
||||
|
||||
@property
|
||||
def has_edit_permission(self):
|
||||
"""Determines if the account has editing permissions in their current tenant (workspace).
|
||||
|
||||
This property checks if the current role has editing privileges, which includes:
|
||||
- `OWNER`
|
||||
- `ADMIN`
|
||||
- `EDITOR`
|
||||
"""
|
||||
return TenantAccountRole.is_editing_role(self.role)
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from enum import Enum, StrEnum
|
||||
from functools import cached_property
|
||||
from typing import Optional
|
||||
|
||||
@@ -42,6 +42,12 @@ class ProviderQuotaType(Enum):
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class CredentialStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
CANCELED = "canceled"
|
||||
REMOVED = "removed"
|
||||
|
||||
|
||||
class Provider(Base):
|
||||
"""
|
||||
Provider model representing the API providers and their configurations.
|
||||
@@ -65,6 +71,9 @@ class Provider(Base):
|
||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
|
||||
last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
|
||||
credential_status: Mapped[Optional[str]] = mapped_column(
|
||||
String(20), nullable=True, server_default=text("'active'::character varying")
|
||||
)
|
||||
|
||||
quota_type: Mapped[Optional[str]] = mapped_column(
|
||||
String(40), nullable=True, server_default=text("''::character varying")
|
||||
|
||||
@@ -34,6 +34,8 @@ class CustomConfigurationStatus(Enum):
|
||||
|
||||
ACTIVE = "active"
|
||||
NO_CONFIGURE = "no-configure"
|
||||
CANCELED = "canceled"
|
||||
REMOVED = "removed"
|
||||
|
||||
|
||||
class CustomConfigurationResponse(BaseModel):
|
||||
|
||||
@@ -89,9 +89,7 @@ class ModelProviderService:
|
||||
model_credential_schema=provider_configuration.provider.model_credential_schema,
|
||||
preferred_provider_type=provider_configuration.preferred_provider_type,
|
||||
custom_configuration=CustomConfigurationResponse(
|
||||
status=CustomConfigurationStatus.ACTIVE
|
||||
if provider_configuration.is_custom_configuration_available()
|
||||
else CustomConfigurationStatus.NO_CONFIGURE,
|
||||
status=provider_configuration.get_custom_configuration_status(),
|
||||
current_credential_id=getattr(provider_config, "current_credential_id", None),
|
||||
current_credential_name=getattr(provider_config, "current_credential_name", None),
|
||||
available_credentials=getattr(provider_config, "available_credentials", []),
|
||||
@@ -214,6 +212,16 @@ class ModelProviderService:
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.switch_active_provider_credential(credential_id=credential_id)
|
||||
|
||||
def cancel_provider_credential(self, tenant_id: str, provider: str):
|
||||
"""
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.cancel_provider_credential()
|
||||
|
||||
def get_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
|
||||
) -> Optional[dict]:
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Integration tests for ChatMessageApi permission verification."""
|
||||
|
||||
import uuid
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from controllers.console.app import completion as completion_api
|
||||
from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, Tenant
|
||||
from models.account import TenantAccountRole
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
|
||||
class TestChatMessageApiPermissions:
|
||||
"""Test permission verification for ChatMessageApi endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model for testing."""
|
||||
app = App()
|
||||
app.id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
"""Create a mock Account for testing."""
|
||||
|
||||
account = Account()
|
||||
account.id = str(uuid.uuid4())
|
||||
account.name = "Test User"
|
||||
account.email = "test@example.com"
|
||||
account.last_active_at = naive_utc_now()
|
||||
account.created_at = naive_utc_now()
|
||||
account.updated_at = naive_utc_now()
|
||||
|
||||
# Create mock tenant
|
||||
tenant = Tenant()
|
||||
tenant.id = str(uuid.uuid4())
|
||||
tenant.name = "Test Tenant"
|
||||
|
||||
account._current_tenant = tenant
|
||||
return account
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "status"),
|
||||
[
|
||||
(TenantAccountRole.OWNER, 200),
|
||||
(TenantAccountRole.ADMIN, 200),
|
||||
(TenantAccountRole.EDITOR, 200),
|
||||
(TenantAccountRole.NORMAL, 403),
|
||||
(TenantAccountRole.DATASET_OPERATOR, 403),
|
||||
],
|
||||
)
|
||||
def test_post_with_owner_role_succeeds(
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
status: int,
|
||||
):
|
||||
"""Test that OWNER role can access chat-messages endpoint."""
|
||||
|
||||
"""Setup common mocks for testing."""
|
||||
# Mock app loading
|
||||
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
# Mock current user
|
||||
monkeypatch.setattr(completion_api, "current_user", mock_account)
|
||||
|
||||
mock_generate = mock.Mock(return_value={"message": "Test response"})
|
||||
monkeypatch.setattr(AppGenerateService, "generate", mock_generate)
|
||||
|
||||
# Set user role to OWNER
|
||||
mock_account.role = role
|
||||
|
||||
response = test_client.post(
|
||||
f"/console/api/apps/{mock_app_model.id}/chat-messages",
|
||||
headers=auth_header,
|
||||
json={
|
||||
"inputs": {},
|
||||
"query": "Hello, world!",
|
||||
"model_config": {
|
||||
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}}
|
||||
},
|
||||
"response_mode": "blocking",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Integration tests for ModelConfigResource permission verification."""
|
||||
|
||||
import uuid
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from controllers.console.app import model_config as model_config_api
|
||||
from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, Tenant
|
||||
from models.account import TenantAccountRole
|
||||
from models.model import AppMode
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
||||
class TestModelConfigResourcePermissions:
|
||||
"""Test permission verification for ModelConfigResource endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_model(self):
|
||||
"""Create a mock App model for testing."""
|
||||
app = App()
|
||||
app.id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT.value
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.app_model_config_id = str(uuid.uuid4())
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
"""Create a mock Account for testing."""
|
||||
|
||||
account = Account()
|
||||
account.id = str(uuid.uuid4())
|
||||
account.name = "Test User"
|
||||
account.email = "test@example.com"
|
||||
account.last_active_at = naive_utc_now()
|
||||
account.created_at = naive_utc_now()
|
||||
account.updated_at = naive_utc_now()
|
||||
|
||||
# Create mock tenant
|
||||
tenant = Tenant()
|
||||
tenant.id = str(uuid.uuid4())
|
||||
tenant.name = "Test Tenant"
|
||||
|
||||
account._current_tenant = tenant
|
||||
return account
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "status"),
|
||||
[
|
||||
(TenantAccountRole.OWNER, 200),
|
||||
(TenantAccountRole.ADMIN, 200),
|
||||
(TenantAccountRole.EDITOR, 200),
|
||||
(TenantAccountRole.NORMAL, 403),
|
||||
(TenantAccountRole.DATASET_OPERATOR, 403),
|
||||
],
|
||||
)
|
||||
def test_post_with_owner_role_succeeds(
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
status: int,
|
||||
):
|
||||
"""Test that OWNER role can access model-config endpoint."""
|
||||
# Set user role to OWNER
|
||||
mock_account.role = role
|
||||
|
||||
# Mock app loading
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
|
||||
# Mock current user
|
||||
monkeypatch.setattr(model_config_api, "current_user", mock_account)
|
||||
|
||||
# Mock AccountService.load_user to prevent authentication issues
|
||||
from services.account_service import AccountService
|
||||
|
||||
mock_load_user = mock.Mock(return_value=mock_account)
|
||||
monkeypatch.setattr(AccountService, "load_user", mock_load_user)
|
||||
|
||||
mock_validate_config = mock.Mock(
|
||||
return_value={
|
||||
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}},
|
||||
"pre_prompt": "You are a helpful assistant.",
|
||||
"user_input_form": [],
|
||||
"dataset_query_variable": "",
|
||||
"agent_mode": {"enabled": False, "tools": []},
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(AppModelConfigService, "validate_configuration", mock_validate_config)
|
||||
|
||||
# Mock database operations
|
||||
mock_db_session = mock.Mock()
|
||||
mock_db_session.add = mock.Mock()
|
||||
mock_db_session.flush = mock.Mock()
|
||||
mock_db_session.commit = mock.Mock()
|
||||
monkeypatch.setattr(model_config_api.db, "session", mock_db_session)
|
||||
|
||||
# Mock app_model_config_was_updated event
|
||||
mock_event = mock.Mock()
|
||||
mock_event.send = mock.Mock()
|
||||
monkeypatch.setattr(model_config_api, "app_model_config_was_updated", mock_event)
|
||||
|
||||
response = test_client.post(
|
||||
f"/console/api/apps/{mock_app_model.id}/model-config",
|
||||
headers=auth_header,
|
||||
json={
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4",
|
||||
"mode": "chat",
|
||||
"completion_params": {"temperature": 0.7, "max_tokens": 1000},
|
||||
},
|
||||
"user_input_form": [],
|
||||
"dataset_query_variable": "",
|
||||
"pre_prompt": "You are a helpful assistant.",
|
||||
"agent_mode": {"enabled": False, "tools": []},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status
|
||||
@@ -962,7 +962,8 @@ class TestAccountService:
|
||||
Test getting user through non-existent email.
|
||||
"""
|
||||
fake = Faker()
|
||||
non_existent_email = fake.email()
|
||||
domain = f"test-{fake.random_letters(10)}.com"
|
||||
non_existent_email = fake.email(domain=domain)
|
||||
found_user = AccountService.get_user_through_email(non_existent_email)
|
||||
assert found_user is None
|
||||
|
||||
|
||||
@@ -107,6 +107,8 @@ export const MODEL_STATUS_TEXT: { [k: string]: TypeWithI18N } = {
|
||||
export enum CustomConfigurationStatusEnum {
|
||||
active = 'active',
|
||||
noConfigure = 'no-configure',
|
||||
canceled = 'canceled',
|
||||
removed = 'removed',
|
||||
}
|
||||
|
||||
export type FormShowOnObject = {
|
||||
|
||||
@@ -46,6 +46,7 @@ const ModelProviderPage = ({ searchText }: Props) => {
|
||||
providers.forEach((provider) => {
|
||||
if (
|
||||
provider.custom_configuration.status === CustomConfigurationStatusEnum.active
|
||||
|| provider.custom_configuration.status === CustomConfigurationStatusEnum.removed
|
||||
|| (
|
||||
provider.system_configuration.enabled === true
|
||||
&& provider.system_configuration.quota_configurations.find(item => item.quota_type === provider.system_configuration.current_quota_type)
|
||||
|
||||
@@ -64,6 +64,7 @@ type AuthorizedProps = {
|
||||
showModelTitle?: boolean
|
||||
disableDeleteButShowAction?: boolean
|
||||
disableDeleteTip?: string
|
||||
showDeselect?: boolean
|
||||
}
|
||||
const Authorized = ({
|
||||
provider,
|
||||
@@ -88,6 +89,7 @@ const Authorized = ({
|
||||
showModelTitle,
|
||||
disableDeleteButShowAction,
|
||||
disableDeleteTip,
|
||||
showDeselect,
|
||||
}: AuthorizedProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [isLocalOpen, setIsLocalOpen] = useState(false)
|
||||
@@ -112,6 +114,7 @@ const Authorized = ({
|
||||
handleConfirmDelete,
|
||||
deleteCredentialId,
|
||||
handleOpenModal,
|
||||
handleDeselect,
|
||||
} = useAuth(
|
||||
provider,
|
||||
configurationMethod,
|
||||
@@ -171,8 +174,15 @@ const Authorized = ({
|
||||
)}>
|
||||
{
|
||||
popupTitle && (
|
||||
<div className='system-xs-medium px-3 pb-0.5 pt-[10px] text-text-tertiary'>
|
||||
<div className='system-xs-medium flex items-center justify-between px-3 pb-0.5 pt-[10px] text-text-tertiary'>
|
||||
{popupTitle}
|
||||
{
|
||||
showDeselect && (
|
||||
<div onClick={() => handleDeselect()} className='cursor-pointer'>
|
||||
{t('common.modelProvider.auth.deselect')}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -29,10 +29,11 @@ const ConfigProvider = ({
|
||||
const { t } = useTranslation()
|
||||
const {
|
||||
hasCredential,
|
||||
authorized,
|
||||
current_credential_id,
|
||||
current_credential_name,
|
||||
available_credentials,
|
||||
authorized,
|
||||
unAuthorized,
|
||||
} = useCredentialStatus(provider)
|
||||
const notAllowCustomCredential = provider.allow_custom_token === false
|
||||
|
||||
@@ -41,11 +42,11 @@ const ConfigProvider = ({
|
||||
<Button
|
||||
className='grow'
|
||||
size='small'
|
||||
variant={!authorized ? 'secondary-accent' : 'secondary'}
|
||||
variant={unAuthorized ? 'secondary-accent' : 'secondary'}
|
||||
>
|
||||
<RiEqualizer2Line className='mr-1 h-3.5 w-3.5' />
|
||||
{hasCredential && t('common.operation.config')}
|
||||
{!hasCredential && t('common.operation.setup')}
|
||||
{!unAuthorized && t('common.operation.config')}
|
||||
{unAuthorized && t('common.operation.setup')}
|
||||
</Button>
|
||||
)
|
||||
if (notAllowCustomCredential && !hasCredential) {
|
||||
@@ -59,7 +60,7 @@ const ConfigProvider = ({
|
||||
)
|
||||
}
|
||||
return Item
|
||||
}, [authorized, hasCredential, notAllowCustomCredential, t])
|
||||
}, [hasCredential, notAllowCustomCredential, t])
|
||||
|
||||
return (
|
||||
<Authorized
|
||||
@@ -68,7 +69,6 @@ const ConfigProvider = ({
|
||||
currentCustomConfigurationModelFixedFields={currentCustomConfigurationModelFixedFields}
|
||||
items={[
|
||||
{
|
||||
title: t('common.modelProvider.auth.apiKeys'),
|
||||
credentials: available_credentials ?? [],
|
||||
selectedCredential: {
|
||||
credential_id: current_credential_id ?? '',
|
||||
@@ -77,9 +77,10 @@ const ConfigProvider = ({
|
||||
},
|
||||
]}
|
||||
showItemSelectedIcon
|
||||
showModelTitle
|
||||
renderTrigger={renderTrigger}
|
||||
triggerOnlyOpenModal={!hasCredential && !notAllowCustomCredential}
|
||||
showDeselect={authorized}
|
||||
popupTitle={t('common.modelProvider.auth.apiKeys')}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,10 @@ import {
|
||||
useModelModalHandler,
|
||||
useRefreshModel,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { useDeleteModel } from '@/service/use-models'
|
||||
import {
|
||||
useDeleteModel,
|
||||
useDeselectModelCredential,
|
||||
} from '@/service/use-models'
|
||||
|
||||
export const useAuth = (
|
||||
provider: ModelProvider,
|
||||
@@ -46,6 +49,7 @@ export const useAuth = (
|
||||
getAddCredentialService,
|
||||
} = useAuthService(provider.provider)
|
||||
const { mutateAsync: deleteModelService } = useDeleteModel(provider.provider)
|
||||
const { mutateAsync: deselectModelCredentialService } = useDeselectModelCredential(provider.provider)
|
||||
const handleOpenModelModal = useModelModalHandler()
|
||||
const { handleRefreshModel } = useRefreshModel()
|
||||
const pendingOperationCredentialId = useRef<string | null>(null)
|
||||
@@ -177,6 +181,11 @@ export const useAuth = (
|
||||
mode,
|
||||
])
|
||||
|
||||
const handleDeselect = useCallback(async () => {
|
||||
await deselectModelCredentialService()
|
||||
handleRefreshModel(provider, configurationMethod, undefined)
|
||||
}, [deselectModelCredentialService, handleRefreshModel, provider, configurationMethod])
|
||||
|
||||
return {
|
||||
pendingOperationCredentialId,
|
||||
pendingOperationModel,
|
||||
@@ -189,5 +198,6 @@ export const useAuth = (
|
||||
deleteModel,
|
||||
handleSaveCredential,
|
||||
handleOpenModal,
|
||||
handleDeselect,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,16 +2,19 @@ import { useMemo } from 'react'
|
||||
import type {
|
||||
ModelProvider,
|
||||
} from '../../declarations'
|
||||
import { CustomConfigurationStatusEnum } from '../../declarations'
|
||||
|
||||
export const useCredentialStatus = (provider: ModelProvider) => {
|
||||
const {
|
||||
current_credential_id,
|
||||
current_credential_name,
|
||||
available_credentials,
|
||||
status: customConfigurationStatus,
|
||||
} = provider.custom_configuration
|
||||
const hasCredential = !!available_credentials?.length
|
||||
const authorized = current_credential_id && current_credential_name
|
||||
const authRemoved = hasCredential && !current_credential_id && !current_credential_name
|
||||
const authorized = customConfigurationStatus === CustomConfigurationStatusEnum.active
|
||||
const authRemoved = customConfigurationStatus === CustomConfigurationStatusEnum.removed
|
||||
const unAuthorized = customConfigurationStatus === CustomConfigurationStatusEnum.noConfigure || customConfigurationStatus === CustomConfigurationStatusEnum.canceled
|
||||
const currentCredential = available_credentials?.find(credential => credential.credential_id === current_credential_id)
|
||||
|
||||
return useMemo(() => ({
|
||||
@@ -21,6 +24,8 @@ export const useCredentialStatus = (provider: ModelProvider) => {
|
||||
current_credential_id,
|
||||
current_credential_name,
|
||||
available_credentials,
|
||||
customConfigurationStatus,
|
||||
notAllowedToUse: currentCredential?.not_allowed_to_use,
|
||||
unAuthorized,
|
||||
}), [hasCredential, authorized, authRemoved, current_credential_id, current_credential_name, available_credentials])
|
||||
}
|
||||
|
||||
@@ -45,6 +45,7 @@ const CredentialPanel = ({
|
||||
authRemoved,
|
||||
current_credential_name,
|
||||
notAllowedToUse,
|
||||
unAuthorized,
|
||||
} = useCredentialStatus(provider)
|
||||
|
||||
const handleChangePriority = async (key: PreferredProviderTypeEnum) => {
|
||||
@@ -70,23 +71,23 @@ const CredentialPanel = ({
|
||||
}
|
||||
}
|
||||
const credentialLabel = useMemo(() => {
|
||||
if (!hasCredential)
|
||||
if (unAuthorized)
|
||||
return t('common.modelProvider.auth.unAuthorized')
|
||||
if (authorized)
|
||||
return current_credential_name
|
||||
if (authRemoved)
|
||||
return t('common.modelProvider.auth.authRemoved')
|
||||
if (authorized)
|
||||
return current_credential_name
|
||||
|
||||
return ''
|
||||
}, [authorized, authRemoved, current_credential_name, hasCredential])
|
||||
}, [authorized, authRemoved, current_credential_name, unAuthorized])
|
||||
|
||||
const color = useMemo(() => {
|
||||
if (authRemoved || !hasCredential)
|
||||
if (authRemoved || !hasCredential || unAuthorized)
|
||||
return 'red'
|
||||
if (notAllowedToUse)
|
||||
return 'gray'
|
||||
return 'green'
|
||||
}, [authRemoved, notAllowedToUse, hasCredential])
|
||||
}, [authRemoved, notAllowedToUse, hasCredential, unAuthorized])
|
||||
|
||||
return (
|
||||
<>
|
||||
|
||||
@@ -523,6 +523,7 @@ const translation = {
|
||||
removeModel: 'Remove Model',
|
||||
selectModelCredential: 'Select a model credential',
|
||||
customModelCredentialsDeleteTip: 'Credential is in use and cannot be deleted',
|
||||
deselect: 'Deselect',
|
||||
},
|
||||
},
|
||||
dataSource: {
|
||||
|
||||
@@ -517,6 +517,7 @@ const translation = {
|
||||
removeModel: '移除模型',
|
||||
selectModelCredential: '选择模型凭据',
|
||||
customModelCredentialsDeleteTip: '模型凭据正在使用中,无法删除',
|
||||
deselect: '取消选择',
|
||||
},
|
||||
},
|
||||
dataSource: {
|
||||
|
||||
@@ -153,3 +153,9 @@ export const useUpdateModelLoadBalancingConfig = (provider: string) => {
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
export const useDeselectModelCredential = (provider: string) => {
|
||||
return useMutation({
|
||||
mutationFn: () => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials/cancel`),
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user