Fixed the issue of asynchronous call failure for MosecEmbeddings (#871)

* Fixed the issue of asynchronous call failure for MosecEmbeddings

Signed-off-by: Yao, Qing <qing.yao@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add import asyncio

Signed-off-by: Yao, Qing <qing.yao@intel.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Yao, Qing <qing.yao@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Yao Qing
2024-11-08 15:54:16 +08:00
committed by GitHub
parent ef507ce6fa
commit 46ff36c008

View File

@@ -1,6 +1,7 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
import time
from typing import List, Optional
@@ -23,7 +24,7 @@ logflag = os.getenv("LOGFLAG", False)
class MosecEmbeddings(OpenAIEmbeddings):
def _get_len_safe_embeddings(
async def _aget_len_safe_embeddings(
self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
) -> List[List[float]]:
_chunk_size = chunk_size or self.chunk_size
@@ -35,7 +36,7 @@ class MosecEmbeddings(OpenAIEmbeddings):
_cached_empty_embedding: Optional[List[float]] = None
def empty_embedding() -> List[float]:
async def empty_embedding() -> List[float]:
nonlocal _cached_empty_embedding
if _cached_empty_embedding is None:
average_embedded = self.client.create(input="", **self._invocation_params)
@@ -44,7 +45,11 @@ class MosecEmbeddings(OpenAIEmbeddings):
_cached_empty_embedding = average_embedded["data"][0]["embedding"]
return _cached_empty_embedding
return [e if e is not None else empty_embedding() for e in batched_embeddings]
async def get_embedding(e: Optional[List[float]]) -> List[float]:
return e if e is not None else await empty_embedding()
embeddings = await asyncio.gather(*[get_embedding(e) for e in batched_embeddings])
return embeddings
@register_microservice(