Fixed the issue of asynchronous call failure for MosecEmbeddings (#871)
* Fixed the issue of asynchronous call failure for MosecEmbeddings Signed-off-by: Yao, Qing <qing.yao@intel.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add import asyncio Signed-off-by: Yao, Qing <qing.yao@intel.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Yao, Qing <qing.yao@intel.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# Copyright (C) 2024 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional
|
||||
@@ -23,7 +24,7 @@ logflag = os.getenv("LOGFLAG", False)
|
||||
|
||||
|
||||
class MosecEmbeddings(OpenAIEmbeddings):
|
||||
def _get_len_safe_embeddings(
|
||||
async def _aget_len_safe_embeddings(
|
||||
self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
|
||||
) -> List[List[float]]:
|
||||
_chunk_size = chunk_size or self.chunk_size
|
||||
@@ -35,7 +36,7 @@ class MosecEmbeddings(OpenAIEmbeddings):
|
||||
|
||||
_cached_empty_embedding: Optional[List[float]] = None
|
||||
|
||||
def empty_embedding() -> List[float]:
|
||||
async def empty_embedding() -> List[float]:
|
||||
nonlocal _cached_empty_embedding
|
||||
if _cached_empty_embedding is None:
|
||||
average_embedded = self.client.create(input="", **self._invocation_params)
|
||||
@@ -44,7 +45,11 @@ class MosecEmbeddings(OpenAIEmbeddings):
|
||||
_cached_empty_embedding = average_embedded["data"][0]["embedding"]
|
||||
return _cached_empty_embedding
|
||||
|
||||
return [e if e is not None else empty_embedding() for e in batched_embeddings]
|
||||
async def get_embedding(e: Optional[List[float]]) -> List[float]:
|
||||
return e if e is not None else await empty_embedding()
|
||||
|
||||
embeddings = await asyncio.gather(*[get_embedding(e) for e in batched_embeddings])
|
||||
return embeddings
|
||||
|
||||
|
||||
@register_microservice(
|
||||
|
||||
Reference in New Issue
Block a user