From d861cc761f64b2a6c677716b703eb15d07bff3be Mon Sep 17 00:00:00 2001 From: hjlarry Date: Wed, 20 Aug 2025 12:46:10 +0800 Subject: [PATCH] improve load balance logic --- api/core/entities/provider_configuration.py | 20 +++++++++++++++++--- api/core/entities/provider_entities.py | 2 +- api/core/provider_manager.py | 13 ++++++++++--- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 628b282d67..682d401df0 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1591,10 +1591,15 @@ class ProviderConfiguration(BaseModel): if model_setting.enabled is False: status = ModelStatus.DISABLED - if len(model_setting.load_balancing_configs) > 1: + provider_model_lb_configs = [ + config for config in model_setting.load_balancing_configs + if config.credential_source_type != "custom_model" + ] + + if len(provider_model_lb_configs) > 1: load_balancing_enabled = True - if model_setting.has_invalid_load_balancing_configs: + if any(config.name == "__delete__" for config in provider_model_lb_configs): has_invalid_load_balancing_configs = True provider_models.append( @@ -1642,9 +1647,17 @@ class ProviderConfiguration(BaseModel): if model_setting.enabled is False: status = ModelStatus.DISABLED - if len(model_setting.load_balancing_configs) > 1: + custom_model_lb_configs = [ + config for config in model_setting.load_balancing_configs + if config.credential_source_type != "provider" + ] + + if len(custom_model_lb_configs) > 1: load_balancing_enabled = True + if any(config.name == "__delete__" for config in custom_model_lb_configs): + has_invalid_load_balancing_configs = True + if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials: status = ModelStatus.CREDENTIAL_REMOVED @@ -1660,6 +1673,7 @@ class ProviderConfiguration(BaseModel): provider=SimpleModelProviderEntity(self.provider), 
status=status, load_balancing_enabled=load_balancing_enabled, + has_invalid_load_balancing_configs=has_invalid_load_balancing_configs, ) ) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 98ba625c94..1b87bffe57 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -133,6 +133,7 @@ class ModelLoadBalancingConfiguration(BaseModel): id: str name: str credentials: dict + credential_source_type: str | None = None class ModelSettings(BaseModel): @@ -144,7 +145,6 @@ class ModelSettings(BaseModel): model_type: ModelType enabled: bool = True load_balancing_configs: list[ModelLoadBalancingConfiguration] = [] - has_invalid_load_balancing_configs: bool = False # pydantic configs model_config = ConfigDict(protected_namespaces=()) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index a99b4777f0..3e5aa63a8a 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -963,7 +963,6 @@ class ProviderManager: if not provider_model_settings: return model_settings - has_invalid_load_balancing_configs = False for provider_model_setting in provider_model_settings: load_balancing_configs = [] if provider_model_setting.load_balancing_enabled and load_balancing_model_configs: @@ -973,7 +972,15 @@ class ProviderManager: and load_balancing_model_config.model_type == provider_model_setting.model_type ): if load_balancing_model_config.name == "__delete__": - has_invalid_load_balancing_configs = True + # keep the deleted entry so the model can be flagged as having invalid lb configs + load_balancing_configs.append( + ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials={}, + credential_source_type=load_balancing_model_config.credential_source_type, + ) + ) continue if not load_balancing_model_config.enabled: @@ -1032,6 +1039,7 @@ class ProviderManager: id=load_balancing_model_config.id, 
name=load_balancing_model_config.name, credentials=provider_model_credentials, + credential_source_type=load_balancing_model_config.credential_source_type, ) ) @@ -1041,7 +1049,6 @@ class ProviderManager: model_type=ModelType.value_of(provider_model_setting.model_type), enabled=provider_model_setting.enabled, load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], - has_invalid_load_balancing_configs=has_invalid_load_balancing_configs, ) )