fix(rag): support legacy model attr in mock model instances

This commit is contained in:
-LAN-
2026-02-25 16:17:56 +08:00
parent ad5c0d7b6b
commit 47da565e30
2 changed files with 40 additions and 13 deletions

View File

@@ -25,9 +25,22 @@ class CacheEmbedding(Embeddings):
self._model_instance = model_instance
self._user = user
@staticmethod
def _resolve_model_name(model_instance: ModelInstance) -> str:
model_name = getattr(model_instance, "model_name", None)
if isinstance(model_name, str):
return model_name
legacy_model_name = getattr(model_instance, "model", None)
if isinstance(legacy_model_name, str):
return legacy_model_name
raise ValueError("Model instance does not include a valid model name.")
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
# use doc embedding cache or store if not exists
model_name = self._resolve_model_name(self._model_instance)
text_embeddings: list[Any] = [None for _ in range(len(texts))]
embedding_queue_indices = []
for i, text in enumerate(texts):
@@ -35,7 +48,9 @@ class CacheEmbedding(Embeddings):
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model_name, hash=hash, provider_name=self._model_instance.provider
model_name=model_name,
hash=hash,
provider_name=self._model_instance.provider,
)
.first()
)
@@ -51,9 +66,7 @@ class CacheEmbedding(Embeddings):
embedding_queue_embeddings = []
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(
self._model_instance.model_name, self._model_instance.credentials
)
model_schema = model_type_instance.get_model_schema(model_name, self._model_instance.credentials)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
@@ -87,7 +100,7 @@ class CacheEmbedding(Embeddings):
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(
model_name=self._model_instance.model_name,
model_name=model_name,
hash=hash,
provider_name=self._model_instance.provider,
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
@@ -107,6 +120,7 @@ class CacheEmbedding(Embeddings):
def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
"""Embed file documents."""
# use doc embedding cache or store if not exists
model_name = self._resolve_model_name(self._model_instance)
multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))]
embedding_queue_indices = []
for i, multimodel_document in enumerate(multimodel_documents):
@@ -114,7 +128,7 @@ class CacheEmbedding(Embeddings):
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model_name,
model_name=model_name,
hash=file_id,
provider_name=self._model_instance.provider,
)
@@ -132,9 +146,7 @@ class CacheEmbedding(Embeddings):
embedding_queue_embeddings = []
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(
self._model_instance.model_name, self._model_instance.credentials
)
model_schema = model_type_instance.get_model_schema(model_name, self._model_instance.credentials)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
@@ -170,7 +182,7 @@ class CacheEmbedding(Embeddings):
file_id = multimodel_documents[i]["file_id"]
if file_id not in cache_embeddings:
embedding_cache = Embedding(
model_name=self._model_instance.model_name,
model_name=model_name,
hash=file_id,
provider_name=self._model_instance.provider,
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
@@ -191,8 +203,9 @@ class CacheEmbedding(Embeddings):
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
model_name = self._resolve_model_name(self._model_instance)
hash = helper.generate_text_hash(text)
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{hash}"
embedding_cache_key = f"{self._model_instance.provider}_{model_name}_{hash}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)
@@ -234,8 +247,9 @@ class CacheEmbedding(Embeddings):
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
"""Embed multimodal documents."""
# use doc embedding cache or store if not exists
model_name = self._resolve_model_name(self._model_instance)
file_id = multimodel_document["file_id"]
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{file_id}"
embedding_cache_key = f"{self._model_instance.provider}_{model_name}_{file_id}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)

View File

@@ -16,6 +16,18 @@ class RerankModelRunner(BaseRerankRunner):
def __init__(self, rerank_model_instance: ModelInstance):
self.rerank_model_instance = rerank_model_instance
@staticmethod
def _resolve_model_name(model_instance: ModelInstance) -> str:
model_name = getattr(model_instance, "model_name", None)
if isinstance(model_name, str):
return model_name
legacy_model_name = getattr(model_instance, "model", None)
if isinstance(legacy_model_name, str):
return legacy_model_name
raise ValueError("Model instance does not include a valid model name.")
def run(
self,
query: str,
@@ -34,11 +46,12 @@ class RerankModelRunner(BaseRerankRunner):
:param user: unique user id if needed
:return:
"""
model_name = self._resolve_model_name(self.rerank_model_instance)
model_manager = ModelManager()
is_support_vision = model_manager.check_model_support_vision(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
provider=self.rerank_model_instance.provider,
model=self.rerank_model_instance.model_name,
model=model_name,
model_type=ModelType.RERANK,
)
if not is_support_vision: