Mirror of https://github.com/langgenius/dify.git, synced 2026-02-28 04:15:10 +00:00
fix(rag): support legacy model attr in mock model instances
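This change teaches both CacheEmbedding and RerankModelRunner to resolve the model name through a small _resolve_model_name helper, so model instances that expose only the legacy "model" attribute (such as test mocks) keep working. A minimal standalone sketch of the resolution logic the diff introduces; SimpleNamespace stands in for ModelInstance here and is purely illustrative:

    from types import SimpleNamespace

    def resolve_model_name(model_instance) -> str:
        # Prefer the current attribute, then fall back to the legacy one.
        model_name = getattr(model_instance, "model_name", None)
        if isinstance(model_name, str):
            return model_name
        legacy_model_name = getattr(model_instance, "model", None)
        if isinstance(legacy_model_name, str):
            return legacy_model_name
        raise ValueError("Model instance does not include a valid model name.")

    # Hypothetical mocks; real ModelInstance objects come from the model manager.
    assert resolve_model_name(SimpleNamespace(model_name="embed-v3")) == "embed-v3"
    assert resolve_model_name(SimpleNamespace(model="embed-legacy")) == "embed-legacy"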
@@ -25,9 +25,22 @@ class CacheEmbedding(Embeddings):
         self._model_instance = model_instance
         self._user = user
 
+    @staticmethod
+    def _resolve_model_name(model_instance: ModelInstance) -> str:
+        model_name = getattr(model_instance, "model_name", None)
+        if isinstance(model_name, str):
+            return model_name
+
+        legacy_model_name = getattr(model_instance, "model", None)
+        if isinstance(legacy_model_name, str):
+            return legacy_model_name
+
+        raise ValueError("Model instance does not include a valid model name.")
+
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Embed search docs in batches of 10."""
         # use doc embedding cache or store if not exists
+        model_name = self._resolve_model_name(self._model_instance)
         text_embeddings: list[Any] = [None for _ in range(len(texts))]
         embedding_queue_indices = []
         for i, text in enumerate(texts):
@@ -35,7 +48,9 @@ class CacheEmbedding(Embeddings):
             embedding = (
                 db.session.query(Embedding)
                 .filter_by(
-                    model_name=self._model_instance.model_name, hash=hash, provider_name=self._model_instance.provider
+                    model_name=model_name,
+                    hash=hash,
+                    provider_name=self._model_instance.provider,
                 )
                 .first()
             )
@@ -51,9 +66,7 @@ class CacheEmbedding(Embeddings):
             embedding_queue_embeddings = []
             try:
                 model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
-                model_schema = model_type_instance.get_model_schema(
-                    self._model_instance.model_name, self._model_instance.credentials
-                )
+                model_schema = model_type_instance.get_model_schema(model_name, self._model_instance.credentials)
                 max_chunks = (
                     model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
                     if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
@@ -87,7 +100,7 @@ class CacheEmbedding(Embeddings):
                 hash = helper.generate_text_hash(texts[i])
                 if hash not in cache_embeddings:
                     embedding_cache = Embedding(
-                        model_name=self._model_instance.model_name,
+                        model_name=model_name,
                         hash=hash,
                         provider_name=self._model_instance.provider,
                         embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
@@ -107,6 +120,7 @@ class CacheEmbedding(Embeddings):
     def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
         """Embed file documents."""
         # use doc embedding cache or store if not exists
+        model_name = self._resolve_model_name(self._model_instance)
         multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))]
         embedding_queue_indices = []
         for i, multimodel_document in enumerate(multimodel_documents):
@@ -114,7 +128,7 @@ class CacheEmbedding(Embeddings):
             embedding = (
                 db.session.query(Embedding)
                 .filter_by(
-                    model_name=self._model_instance.model_name,
+                    model_name=model_name,
                     hash=file_id,
                     provider_name=self._model_instance.provider,
                 )
@@ -132,9 +146,7 @@ class CacheEmbedding(Embeddings):
             embedding_queue_embeddings = []
             try:
                 model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
-                model_schema = model_type_instance.get_model_schema(
-                    self._model_instance.model_name, self._model_instance.credentials
-                )
+                model_schema = model_type_instance.get_model_schema(model_name, self._model_instance.credentials)
                 max_chunks = (
                     model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
                     if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
@@ -170,7 +182,7 @@ class CacheEmbedding(Embeddings):
                 file_id = multimodel_documents[i]["file_id"]
                 if file_id not in cache_embeddings:
                     embedding_cache = Embedding(
-                        model_name=self._model_instance.model_name,
+                        model_name=model_name,
                         hash=file_id,
                         provider_name=self._model_instance.provider,
                         embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
@@ -191,8 +203,9 @@ class CacheEmbedding(Embeddings):
     def embed_query(self, text: str) -> list[float]:
         """Embed query text."""
         # use doc embedding cache or store if not exists
+        model_name = self._resolve_model_name(self._model_instance)
         hash = helper.generate_text_hash(text)
-        embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{hash}"
+        embedding_cache_key = f"{self._model_instance.provider}_{model_name}_{hash}"
         embedding = redis_client.get(embedding_cache_key)
         if embedding:
             redis_client.expire(embedding_cache_key, 600)
@@ -234,8 +247,9 @@ class CacheEmbedding(Embeddings):
     def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
         """Embed multimodal documents."""
         # use doc embedding cache or store if not exists
+        model_name = self._resolve_model_name(self._model_instance)
         file_id = multimodel_document["file_id"]
-        embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{file_id}"
+        embedding_cache_key = f"{self._model_instance.provider}_{model_name}_{file_id}"
         embedding = redis_client.get(embedding_cache_key)
         if embedding:
             redis_client.expire(embedding_cache_key, 600)
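The second file applies the same pattern to the rerank runner. Note that both query paths above also fold the resolved name into the Redis cache key; a hedged sketch of that key shape, with illustrative provider and hash values (not taken from the codebase):

    # Key layout used by embed_query / embed_multimodal_query after this change.
    provider = "openai"          # illustrative
    model_name = "embed-v3"      # whatever _resolve_model_name returned
    content_hash = "deadbeef"    # text hash, or file_id for multimodal queries
    embedding_cache_key = f"{provider}_{model_name}_{content_hash}"
    assert embedding_cache_key == "openai_embed-v3_deadbeef"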
@@ -16,6 +16,18 @@ class RerankModelRunner(BaseRerankRunner):
     def __init__(self, rerank_model_instance: ModelInstance):
         self.rerank_model_instance = rerank_model_instance
 
+    @staticmethod
+    def _resolve_model_name(model_instance: ModelInstance) -> str:
+        model_name = getattr(model_instance, "model_name", None)
+        if isinstance(model_name, str):
+            return model_name
+
+        legacy_model_name = getattr(model_instance, "model", None)
+        if isinstance(legacy_model_name, str):
+            return legacy_model_name
+
+        raise ValueError("Model instance does not include a valid model name.")
+
     def run(
         self,
         query: str,
@@ -34,11 +46,12 @@ class RerankModelRunner(BaseRerankRunner):
         :param user: unique user id if needed
         :return:
         """
+        model_name = self._resolve_model_name(self.rerank_model_instance)
         model_manager = ModelManager()
         is_support_vision = model_manager.check_model_support_vision(
             tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
             provider=self.rerank_model_instance.provider,
-            model=self.rerank_model_instance.model_name,
+            model=model_name,
             model_type=ModelType.RERANK,
         )
         if not is_support_vision:
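When neither attribute is a string the helper raises, so callers fail fast instead of writing cache rows or cache keys with an empty model name. A hedged pytest-style sketch of that failure path, reusing the resolve_model_name sketch above (the test name and mock are hypothetical):

    import pytest
    from types import SimpleNamespace

    def test_resolve_model_name_rejects_nameless_instances():
        # Neither attribute holds a str, so resolution must fail loudly.
        bogus = SimpleNamespace(model_name=None, model=123)
        with pytest.raises(ValueError):
            resolve_model_name(bogus)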