diff --git a/comps/embeddings/mosec/langchain/embedding_mosec.py b/comps/embeddings/mosec/langchain/embedding_mosec.py index bab254ae4..fde9e17af 100644 --- a/comps/embeddings/mosec/langchain/embedding_mosec.py +++ b/comps/embeddings/mosec/langchain/embedding_mosec.py @@ -1,6 +1,7 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import asyncio import os import time from typing import List, Optional @@ -23,7 +24,7 @@ logflag = os.getenv("LOGFLAG", False) class MosecEmbeddings(OpenAIEmbeddings): - def _get_len_safe_embeddings( + async def _aget_len_safe_embeddings( self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None ) -> List[List[float]]: _chunk_size = chunk_size or self.chunk_size @@ -35,7 +36,7 @@ class MosecEmbeddings(OpenAIEmbeddings): _cached_empty_embedding: Optional[List[float]] = None - def empty_embedding() -> List[float]: + async def empty_embedding() -> List[float]: nonlocal _cached_empty_embedding if _cached_empty_embedding is None: average_embedded = self.client.create(input="", **self._invocation_params) @@ -44,7 +45,11 @@ class MosecEmbeddings(OpenAIEmbeddings): _cached_empty_embedding = average_embedded["data"][0]["embedding"] return _cached_empty_embedding - return [e if e is not None else empty_embedding() for e in batched_embeddings] + async def get_embedding(e: Optional[List[float]]) -> List[float]: + return e if e is not None else await empty_embedding() + + embeddings = await asyncio.gather(*[get_embedding(e) for e in batched_embeddings]) + return embeddings @register_microservice(