From 47da565e30d7e8a2c12a221c22cf3c772c15080b Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Wed, 25 Feb 2026 16:17:56 +0800
Subject: [PATCH] fix(rag): support legacy model attr in mock model instances

---
 api/core/rag/embedding/cached_embedding.py | 38 +++++++++++++++-------
 api/core/rag/rerank/rerank_model.py        | 15 ++++++++-
 2 files changed, 40 insertions(+), 13 deletions(-)

diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py
index de1122ef80..8b43eefe85 100644
--- a/api/core/rag/embedding/cached_embedding.py
+++ b/api/core/rag/embedding/cached_embedding.py
@@ -25,9 +25,22 @@ class CacheEmbedding(Embeddings):
         self._model_instance = model_instance
         self._user = user
 
+    @staticmethod
+    def _resolve_model_name(model_instance: ModelInstance) -> str:
+        model_name = getattr(model_instance, "model_name", None)
+        if isinstance(model_name, str):
+            return model_name
+
+        legacy_model_name = getattr(model_instance, "model", None)
+        if isinstance(legacy_model_name, str):
+            return legacy_model_name
+
+        raise ValueError("Model instance does not include a valid model name.")
+
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Embed search docs in batches of 10."""
         # use doc embedding cache or store if not exists
+        model_name = self._resolve_model_name(self._model_instance)
         text_embeddings: list[Any] = [None for _ in range(len(texts))]
         embedding_queue_indices = []
         for i, text in enumerate(texts):
@@ -35,7 +48,9 @@ class CacheEmbedding(Embeddings):
             embedding = (
                 db.session.query(Embedding)
                 .filter_by(
-                    model_name=self._model_instance.model_name, hash=hash, provider_name=self._model_instance.provider
+                    model_name=model_name,
+                    hash=hash,
+                    provider_name=self._model_instance.provider,
                 )
                 .first()
             )
@@ -51,9 +66,7 @@ class CacheEmbedding(Embeddings):
             embedding_queue_embeddings = []
             try:
                 model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
-                model_schema = model_type_instance.get_model_schema(
-                    self._model_instance.model_name, self._model_instance.credentials
-                )
+                model_schema = model_type_instance.get_model_schema(model_name, self._model_instance.credentials)
                 max_chunks = (
                     model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
                     if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
@@ -87,7 +100,7 @@ class CacheEmbedding(Embeddings):
                     hash = helper.generate_text_hash(texts[i])
                     if hash not in cache_embeddings:
                         embedding_cache = Embedding(
-                            model_name=self._model_instance.model_name,
+                            model_name=model_name,
                             hash=hash,
                             provider_name=self._model_instance.provider,
                             embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
@@ -107,6 +120,7 @@ class CacheEmbedding(Embeddings):
     def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
         """Embed file documents."""
         # use doc embedding cache or store if not exists
+        model_name = self._resolve_model_name(self._model_instance)
         multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))]
         embedding_queue_indices = []
         for i, multimodel_document in enumerate(multimodel_documents):
@@ -114,7 +128,7 @@ class CacheEmbedding(Embeddings):
             embedding = (
                 db.session.query(Embedding)
                 .filter_by(
-                    model_name=self._model_instance.model_name,
+                    model_name=model_name,
                     hash=file_id,
                     provider_name=self._model_instance.provider,
                 )
@@ -132,9 +146,7 @@ class CacheEmbedding(Embeddings):
             embedding_queue_embeddings = []
             try:
                 model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
-                model_schema = model_type_instance.get_model_schema(
-                    self._model_instance.model_name, self._model_instance.credentials
-                )
+                model_schema = model_type_instance.get_model_schema(model_name, self._model_instance.credentials)
                 max_chunks = (
                     model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
                     if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
@@ -170,7 +182,7 @@ class CacheEmbedding(Embeddings):
                     file_id = multimodel_documents[i]["file_id"]
                     if file_id not in cache_embeddings:
                         embedding_cache = Embedding(
-                            model_name=self._model_instance.model_name,
+                            model_name=model_name,
                             hash=file_id,
                             provider_name=self._model_instance.provider,
                             embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
@@ -191,8 +203,9 @@ class CacheEmbedding(Embeddings):
     def embed_query(self, text: str) -> list[float]:
         """Embed query text."""
         # use doc embedding cache or store if not exists
+        model_name = self._resolve_model_name(self._model_instance)
         hash = helper.generate_text_hash(text)
-        embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{hash}"
+        embedding_cache_key = f"{self._model_instance.provider}_{model_name}_{hash}"
         embedding = redis_client.get(embedding_cache_key)
         if embedding:
             redis_client.expire(embedding_cache_key, 600)
@@ -234,8 +247,9 @@ class CacheEmbedding(Embeddings):
     def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
         """Embed multimodal documents."""
         # use doc embedding cache or store if not exists
+        model_name = self._resolve_model_name(self._model_instance)
         file_id = multimodel_document["file_id"]
-        embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{file_id}"
+        embedding_cache_key = f"{self._model_instance.provider}_{model_name}_{file_id}"
         embedding = redis_client.get(embedding_cache_key)
         if embedding:
             redis_client.expire(embedding_cache_key, 600)
diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py
index 690e780921..ceea83bb89 100644
--- a/api/core/rag/rerank/rerank_model.py
+++ b/api/core/rag/rerank/rerank_model.py
@@ -16,6 +16,18 @@ class RerankModelRunner(BaseRerankRunner):
     def __init__(self, rerank_model_instance: ModelInstance):
         self.rerank_model_instance = rerank_model_instance
 
+    @staticmethod
+    def _resolve_model_name(model_instance: ModelInstance) -> str:
+        model_name = getattr(model_instance, "model_name", None)
+        if isinstance(model_name, str):
+            return model_name
+
+        legacy_model_name = getattr(model_instance, "model", None)
+        if isinstance(legacy_model_name, str):
+            return legacy_model_name
+
+        raise ValueError("Model instance does not include a valid model name.")
+
     def run(
         self,
         query: str,
@@ -34,11 +46,12 @@
         :param user: unique user id if needed
         :return:
         """
+        model_name = self._resolve_model_name(self.rerank_model_instance)
         model_manager = ModelManager()
         is_support_vision = model_manager.check_model_support_vision(
             tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
             provider=self.rerank_model_instance.provider,
-            model=self.rerank_model_instance.model_name,
+            model=model_name,
             model_type=ModelType.RERANK,
         )
         if not is_support_vision:
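
Note: both files gain an identical _resolve_model_name helper. It prefers the
current model_name attribute and falls back to the legacy model attribute that
older mock instances still expose, raising ValueError when neither is a string.
Below is a minimal standalone sketch of that fallback, re-stated outside the
classes so it runs on its own; the SimpleNamespace stand-ins and the model
names "my-embedding-model" / "my-rerank-model" are illustrative, not part of
the patch.

    from types import SimpleNamespace

    def resolve_model_name(model_instance) -> str:
        # Prefer the current attribute name used by real ModelInstance objects.
        model_name = getattr(model_instance, "model_name", None)
        if isinstance(model_name, str):
            return model_name
        # Fall back to the legacy attribute that older mocks still carry.
        legacy_model_name = getattr(model_instance, "model", None)
        if isinstance(legacy_model_name, str):
            return legacy_model_name
        raise ValueError("Model instance does not include a valid model name.")

    # A mock that only has the legacy attribute resolves via the fallback.
    assert resolve_model_name(SimpleNamespace(model="my-embedding-model")) == "my-embedding-model"
    # A current instance resolves via model_name directly.
    assert resolve_model_name(SimpleNamespace(model_name="my-rerank-model")) == "my-rerank-model"
    # An instance with neither attribute raises ValueError, as in the patch.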