diff --git a/comps/cores/mega/orchestrator.py b/comps/cores/mega/orchestrator.py
index 851fb67d0..405398856 100644
--- a/comps/cores/mega/orchestrator.py
+++ b/comps/cores/mega/orchestrator.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import asyncio
+import contextlib
 import copy
 import json
 import os
@@ -16,6 +17,7 @@ from prometheus_client import Gauge, Histogram
 from pydantic import BaseModel
 
 from ..proto.docarray import LLMParams
+from ..telemetry.opea_telemetry import opea_telemetry, tracer
 from .constants import ServiceType
 from .dag import DAG
 from .logger import CustomLogger
@@ -80,6 +82,7 @@ class ServiceOrchestrator(DAG):
             logger.error(e)
             return False
 
+    @opea_telemetry
     async def schedule(self, initial_inputs: Dict | BaseModel, llm_parameters: LLMParams = LLMParams(), **kwargs):
         req_start = time.time()
         self.metrics.pending_update(True)
@@ -166,6 +169,26 @@ class ServiceOrchestrator(DAG):
                 all_outputs.update(result_dict[prev_node])
         return all_outputs
 
+    def wrap_iterable(self, iterable, is_first=True):
+
+        with tracer.start_as_current_span("llm_generate_stream"):
+            while True:
+                with (
+                    tracer.start_as_current_span("llm_generate_stream_first_token")
+                    if is_first
+                    else contextlib.nullcontext()
+                ):  # else tracer.start_as_current_span(f"llm_generate_stream_next_token")
+                    try:
+                        token = next(iterable)
+                        yield token
+                        is_first = False
+                    except StopIteration:
+                        # Exiting the iterable loop cleanly
+                        break
+                    except Exception as e:
+                        raise e
+
+    @opea_telemetry
     async def execute(
         self,
         session: aiohttp.client.ClientSession,
@@ -193,14 +216,15 @@ class ServiceOrchestrator(DAG):
             # Still leave to sync requests.post for StreamingResponse
             if LOGFLAG:
                 logger.info(inputs)
-            response = requests.post(
-                url=endpoint,
-                data=json.dumps(inputs),
-                headers={"Content-type": "application/json"},
-                proxies={"http": None},
-                stream=True,
-                timeout=1000,
-            )
+            with tracer.start_as_current_span(f"{cur_node}_asyn_generate"):
+                response = requests.post(
+                    url=endpoint,
+                    data=json.dumps(inputs),
+                    headers={"Content-type": "application/json"},
+                    proxies={"http": None},
+                    stream=True,
+                    timeout=1000,
+                )
             downstream = runtime_graph.downstream(cur_node)
             if downstream:
                 assert len(downstream) == 1, "Not supported multiple stream downstreams yet!"
@@ -214,7 +238,9 @@ class ServiceOrchestrator(DAG):
                 # response.elapsed = time until first headers received
                 buffered_chunk_str = ""
                 is_first = True
-                for chunk in response.iter_content(chunk_size=None):
+
+                for chunk in self.wrap_iterable(response.iter_content(chunk_size=None)):
+
                     if chunk:
                         if downstream:
                             chunk = chunk.decode("utf-8")
@@ -240,6 +266,7 @@ class ServiceOrchestrator(DAG):
                             token_start = self.metrics.token_update(token_start, is_first)
                             yield chunk
                             is_first = False
+
                 self.metrics.request_update(req_start)
                 self.metrics.pending_update(False)
 
@@ -256,19 +283,18 @@ class ServiceOrchestrator(DAG):
                 input_data = {k: v for k, v in input_data.items() if v is not None}
             else:
                 input_data = inputs
-            async with session.post(endpoint, json=input_data) as response:
-                if response.content_type == "audio/wav":
-                    audio_data = await response.read()
-                    data = self.align_outputs(
-                        audio_data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs
-                    )
-                else:
-                    # Parse as JSON
-                    data = await response.json()
-                    # post process
-                    data = self.align_outputs(data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs)
+            with tracer.start_as_current_span(f"{cur_node}_generate"):
+                response = await session.post(endpoint, json=input_data)
+                if response.content_type == "audio/wav":
+                    audio_data = await response.read()
+                    data = self.align_outputs(audio_data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs)
+                else:
+                    # Parse as JSON
+                    data = await response.json()
+                    # post process
+                    data = self.align_outputs(data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs)
 
-            return data, cur_node
+            return data, cur_node
 
     def align_inputs(self, inputs, *args, **kwargs):
         """Override this method in megaservice definition."""
diff --git a/comps/cores/telemetry/opea_telemetry.py b/comps/cores/telemetry/opea_telemetry.py
index 4d66b9c16..5f08d4bd4 100644
--- a/comps/cores/telemetry/opea_telemetry.py
+++ b/comps/cores/telemetry/opea_telemetry.py
@@ -6,12 +6,29 @@ import os
 from functools import wraps
 
 from opentelemetry import trace
+from opentelemetry.context.contextvars_context import ContextVarsRuntimeContext
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
 from opentelemetry.sdk.resources import SERVICE_NAME, Resource
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
 from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
 
+
+def detach_ignore_err(self, token: object) -> None:
+    """Resets Context to a previous value.
+
+    Args:
+        token: A reference to a previous Context.
+ """ + try: + self._current_context.reset(token) # type: ignore + except Exception as e: + pass + + +# bypass the ValueError that ContextVar context was created in a different Context from StreamingResponse +ContextVarsRuntimeContext.detach = detach_ignore_err + telemetry_endpoint = os.environ.get("TELEMETRY_ENDPOINT", "http://localhost:4318/v1/traces") resource = Resource.create({SERVICE_NAME: "opea"}) @@ -26,7 +43,6 @@ tracer = trace.get_tracer(__name__) def opea_telemetry(func): - print(f"[*** telemetry ***] {func.__name__} under telemetry.") if inspect.iscoroutinefunction(func): @wraps(func) diff --git a/comps/embeddings/src/opea_embedding_microservice.py b/comps/embeddings/src/opea_embedding_microservice.py index 335173127..ac6788510 100644 --- a/comps/embeddings/src/opea_embedding_microservice.py +++ b/comps/embeddings/src/opea_embedding_microservice.py @@ -17,6 +17,7 @@ from comps import ( statistics_dict, ) from comps.cores.proto.api_protocol import EmbeddingRequest, EmbeddingResponse +from comps.cores.telemetry.opea_telemetry import opea_telemetry logger = CustomLogger("opea_embedding_microservice") logflag = os.getenv("LOGFLAG", False) @@ -36,6 +37,7 @@ loader = OpeaComponentLoader( host="0.0.0.0", port=6000, ) +@opea_telemetry @register_statistics(names=["opea_service@embedding"]) async def embedding(input: EmbeddingRequest) -> EmbeddingResponse: start = time.time() diff --git a/comps/llms/src/text-generation/opea_llm_microservice.py b/comps/llms/src/text-generation/opea_llm_microservice.py index 013a46238..d430d2acb 100644 --- a/comps/llms/src/text-generation/opea_llm_microservice.py +++ b/comps/llms/src/text-generation/opea_llm_microservice.py @@ -17,6 +17,7 @@ from comps import ( statistics_dict, ) from comps.cores.proto.api_protocol import ChatCompletionRequest +from comps.cores.telemetry.opea_telemetry import opea_telemetry logger = CustomLogger("llm") logflag = os.getenv("LOGFLAG", False) @@ -42,6 +43,7 @@ loader = OpeaComponentLoader(llm_component_name, description=f"OPEA LLM Componen host="0.0.0.0", port=9000, ) +@opea_telemetry @register_statistics(names=["opea_service@llm"]) async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]): start = time.time() diff --git a/comps/rerankings/src/opea_reranking_microservice.py b/comps/rerankings/src/opea_reranking_microservice.py index 4cf94c2cd..7cf073ef7 100644 --- a/comps/rerankings/src/opea_reranking_microservice.py +++ b/comps/rerankings/src/opea_reranking_microservice.py @@ -19,6 +19,7 @@ from comps import ( ) from comps.cores.proto.api_protocol import ChatCompletionRequest, RerankingRequest, RerankingResponse from comps.cores.proto.docarray import LLMParamsDoc, LVMVideoDoc, RerankedDoc, SearchedDoc, SearchedMultimodalDoc +from comps.cores.telemetry.opea_telemetry import opea_telemetry logger = CustomLogger("opea_reranking_microservice") logflag = os.getenv("LOGFLAG", False) @@ -35,6 +36,7 @@ loader = OpeaComponentLoader(rerank_component_name, description=f"OPEA RERANK Co host="0.0.0.0", port=8000, ) +@opea_telemetry @register_statistics(names=["opea_service@reranking"]) async def reranking( input: Union[SearchedMultimodalDoc, SearchedDoc, RerankingRequest, ChatCompletionRequest] diff --git a/comps/retrievers/src/opea_retrievers_microservice.py b/comps/retrievers/src/opea_retrievers_microservice.py index 1592440c8..3bf0d91e2 100644 --- a/comps/retrievers/src/opea_retrievers_microservice.py +++ b/comps/retrievers/src/opea_retrievers_microservice.py @@ -37,6 +37,7 @@ from 
comps.cores.proto.api_protocol import ( RetrievalResponse, RetrievalResponseData, ) +from comps.cores.telemetry.opea_telemetry import opea_telemetry logger = CustomLogger("opea_retrievers_microservice") logflag = os.getenv("LOGFLAG", False) @@ -56,6 +57,7 @@ loader = OpeaComponentLoader( host="0.0.0.0", port=7000, ) +@opea_telemetry @register_statistics(names=["opea_service@retrievers"]) async def ingest_files( input: Union[EmbedDoc, EmbedMultimodalDoc, RetrievalRequest, ChatCompletionRequest]
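The pattern this patch applies everywhere is a decorator that opens an OpenTelemetry span around a handler, whether it is a plain function or a coroutine, so stacking @opea_telemetry on top of @register_statistics traces each endpoint without touching its body. The sketch below is a minimal, self-contained illustration of that sync/async-aware decorator shape; the `traced` name and the toy `llm_generate` handler are hypothetical stand-ins, not the repo's API, and the real decorator (with the OTLP export wiring) lives in comps/cores/telemetry/opea_telemetry.py.

# Illustrative sketch only: a sync/async-aware tracing decorator in the
# style of opea_telemetry. `traced` and the toy handler are made-up names.
import asyncio
import inspect
from functools import wraps

from opentelemetry import trace
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider

trace.set_tracer_provider(TracerProvider(resource=Resource.create({SERVICE_NAME: "opea"})))
tracer = trace.get_tracer(__name__)


def traced(func):
    """Open a span named after `func` around every call, sync or async."""
    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            with tracer.start_as_current_span(func.__name__):
                return await func(*args, **kwargs)

        return async_wrapper

    @wraps(func)
    def sync_wrapper(*args, **kwargs):
        with tracer.start_as_current_span(func.__name__):
            return func(*args, **kwargs)

    return sync_wrapper


@traced
async def llm_generate(prompt: str) -> str:
    return prompt.upper()  # stand-in for a real handler


# A "llm_generate" span wraps the awaited call; attach a span processor
# (e.g. BatchSpanProcessor with an OTLP exporter) to export it.
print(asyncio.run(llm_generate("hello")))

Dispatching on inspect.iscoroutinefunction matters here because the microservice handlers are coroutines: wrapping one with a plain sync wrapper would return an un-awaited coroutine instead of tracing its execution.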