fix RetrievalMethod StrEnum (#26768)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-10-13 11:29:37 +09:00
committed by GitHub
parent d299e75e1b
commit 24cd7bbc62
25 changed files with 65 additions and 43 deletions

View File

@@ -1,10 +1,12 @@
import os
from pytest_mock import MockerFixture
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
def test_firecrawl_web_extractor_crawl_mode(mocker):
def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture):
url = "https://firecrawl.dev"
api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
base_url = "https://api.firecrawl.dev"

View File

@@ -1,5 +1,7 @@
from unittest import mock
from pytest_mock import MockerFixture
from core.rag.extractor import notion_extractor
user_id = "user1"
@@ -57,7 +59,7 @@ def _remove_multiple_new_lines(text):
return text.strip()
def test_notion_page(mocker):
def test_notion_page(mocker: MockerFixture):
texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
mocked_notion_page = {
"object": "list",
@@ -77,7 +79,7 @@ def test_notion_page(mocker):
assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
def test_notion_database(mocker):
def test_notion_database(mocker: MockerFixture):
page_title_list = ["page1", "page2", "page3"]
mocked_notion_database = {
"object": "list",

View File

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
import redis
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.model_manager import LBModelManager
@@ -39,7 +40,7 @@ def lb_model_manager():
return lb_model_manager
def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
# initialize redis client
redis_client.initialize(redis.Redis())

View File

@@ -1,4 +1,5 @@
import pytest
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelSettings
from core.model_runtime.entities.model_entities import ModelType
@@ -7,19 +8,25 @@ from models.provider import LoadBalancingModelConfig, ProviderModelSetting
@pytest.fixture
def mock_provider_entity(mocker):
def mock_provider_entity(mocker: MockerFixture):
mock_entity = mocker.Mock()
mock_entity.provider = "openai"
mock_entity.configurate_methods = ["predefined-model"]
mock_entity.supported_model_types = [ModelType.LLM]
mock_entity.model_credential_schema = mocker.Mock()
mock_entity.model_credential_schema.credential_form_schemas = []
# Use PropertyMock to ensure credential_form_schemas is iterable
provider_credential_schema = mocker.Mock()
type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
mock_entity.provider_credential_schema = provider_credential_schema
model_credential_schema = mocker.Mock()
type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
mock_entity.model_credential_schema = model_credential_schema
return mock_entity
def test__to_model_settings(mocker, mock_provider_entity):
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(
@@ -79,7 +86,7 @@ def test__to_model_settings(mocker, mock_provider_entity):
assert result[0].load_balancing_configs[1].name == "first"
def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(
@@ -127,7 +134,7 @@ def test__to_model_settings_only_one_lb(mocker, mock_provider_entity):
assert len(result[0].load_balancing_configs) == 0
def test__to_model_settings_lb_disabled(mocker, mock_provider_entity):
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
# Mocking the inputs
provider_model_settings = [
ProviderModelSetting(