mirror of
https://github.com/langgenius/dify.git
synced 2026-03-13 19:32:04 +00:00
test: added tests for backend core.ops module (#32639)
Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com>
This commit is contained in:
@@ -0,0 +1,326 @@
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
INVALID_SPAN_ID,
|
||||
SpanBuilder,
|
||||
TraceClient,
|
||||
build_endpoint,
|
||||
convert_datetime_to_nanoseconds,
|
||||
convert_string_to_id,
|
||||
convert_to_span_id,
|
||||
convert_to_trace_id,
|
||||
create_link,
|
||||
generate_span_id,
|
||||
)
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
|
||||
|
||||
@pytest.fixture
def trace_client_factory():
    """Yield a TraceClient factory; every client built through it is shut down on teardown."""
    created = []

    def _build(**kwargs):
        instance = TraceClient(**kwargs)
        created.append(instance)
        return instance

    yield _build

    # Teardown: stop the worker thread of every client the test created.
    for instance in created:
        instance.shutdown()
|
||||
|
||||
|
||||
class TestTraceClient:
    """Unit tests for TraceClient: construction, export, API checks, queueing and shutdown."""

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname")
    def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory):
        """Constructor defaults and the background worker thread lifecycle."""
        mock_gethostname.return_value = "test-host"
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")

        assert client.endpoint == "http://test-endpoint"
        assert client.max_queue_size == 1000
        assert client.schedule_delay_sec == 5
        assert client.done is False
        assert client.worker_thread.is_alive()

        # shutdown() must flip the done flag; the fixture's teardown call is a no-op after this.
        client.shutdown()
        assert client.done is True

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_export(self, mock_exporter_class, trace_client_factory):
        """export() delegates the span batch to the underlying OTLP exporter."""
        mock_exporter = mock_exporter_class.return_value
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        spans = [MagicMock(spec=ReadableSpan)]
        client.export(spans)
        mock_exporter.export.assert_called_once_with(spans)

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory):
        """A 405 from HEAD (method not allowed, endpoint alive) counts as success."""
        mock_response = MagicMock()
        mock_response.status_code = 405
        mock_head.return_value = mock_response

        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        assert client.api_check() is True

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory):
        """A 5xx status from HEAD is reported as a failed check."""
        mock_response = MagicMock()
        mock_response.status_code = 500
        mock_head.return_value = mock_response

        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        assert client.api_check() is False

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory):
        """A network error during the HEAD probe is surfaced as a ValueError."""
        mock_head.side_effect = httpx.RequestError("Connection error")

        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"):
            client.api_check()

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_get_project_url(self, mock_exporter_class, trace_client_factory):
        """The project URL is the fixed ARMS LLM console address."""
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm"

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_add_span(self, mock_exporter_class, trace_client_factory):
        """add_span() ignores None, enqueues real spans, and notifies at the batch size."""
        client = trace_client_factory(
            service_name="test-service",
            endpoint="http://test-endpoint",
            max_export_batch_size=2,
        )

        # None input must be dropped silently.
        client.add_span(None)
        assert len(client.queue) == 0

        span_data = SpanData(
            name="test-span",
            trace_id=123,
            span_id=456,
            parent_span_id=None,
            start_time=1000,
            end_time=2000,
            status=Status(StatusCode.OK),
            span_kind=SpanKind.INTERNAL,
        )

        mock_span = MagicMock(spec=ReadableSpan)
        client.span_builder.build_span = MagicMock(return_value=mock_span)

        with patch.object(client.condition, "notify") as mock_notify:
            # First span: below the batch size, so the worker is not woken.
            client.add_span(span_data)
            assert len(client.queue) == 1
            mock_notify.assert_not_called()

            # Second span reaches max_export_batch_size and triggers a notify.
            client.add_span(span_data)
            assert len(client.queue) == 2
            mock_notify.assert_called_once()

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.logger")
    def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory):
        """When the bounded queue is full, the span is dropped and a warning is logged."""
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1)

        span_data = SpanData(
            name="test-span",
            trace_id=123,
            span_id=456,
            parent_span_id=None,
            start_time=1000,
            end_time=2000,
            status=Status(StatusCode.OK),
            span_kind=SpanKind.INTERNAL,
        )
        mock_span = MagicMock(spec=ReadableSpan)
        client.span_builder.build_span = MagicMock(return_value=mock_span)

        client.add_span(span_data)
        assert len(client.queue) == 1

        # Second add overflows the size-1 queue.
        client.add_span(span_data)
        assert len(client.queue) == 1
        mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.")

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_export_batch_error(self, mock_exporter_class, trace_client_factory):
        """_export_batch() swallows exporter failures and logs a warning instead of raising."""
        mock_exporter = mock_exporter_class.return_value
        mock_exporter.export.side_effect = Exception("Export failed")

        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        mock_span = MagicMock(spec=ReadableSpan)
        client.queue.append(mock_span)

        with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger:
            client._export_batch()
            mock_logger.warning.assert_called()

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_worker_loop(self, mock_exporter_class, trace_client_factory):
        """The worker loop can run with a patched condition wait and still shut down cleanly."""
        client = trace_client_factory(
            service_name="test-service",
            endpoint="http://test-endpoint",
            schedule_delay_sec=0.1,
        )

        with patch.object(client.condition, "wait") as mock_wait:
            # Give the worker thread a chance to spin on the patched wait.
            time.sleep(0.2)
            client.shutdown()

        # Fix: the original assertion `mock_wait.called or client.done` was vacuous --
        # `client.done` is always True after shutdown(), so the test could never fail.
        # Assert the shutdown flag directly; whether wait() was observed depends on
        # thread timing, so it is deliberately not asserted.
        assert client.done is True

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory):
        """shutdown() flushes any queued spans and shuts down the exporter."""
        mock_exporter = mock_exporter_class.return_value
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")

        mock_span = MagicMock(spec=ReadableSpan)
        client.queue.append(mock_span)

        client.shutdown()
        # Export may happen in the worker, in shutdown, or both -- at least once is required.
        assert mock_exporter.export.called
        assert mock_exporter.shutdown.called
|
||||
|
||||
|
||||
class TestSpanBuilder:
    """Tests for SpanBuilder's conversion of SpanData into OpenTelemetry ReadableSpans."""

    def test_build_span(self):
        """A fully-populated SpanData maps field-for-field onto the built span."""
        fake_resource = MagicMock()
        builder = SpanBuilder(fake_resource)

        source = SpanData(
            name="test-span",
            trace_id=123,
            span_id=456,
            parent_span_id=789,
            start_time=1000,
            end_time=2000,
            status=Status(StatusCode.OK),
            span_kind=SpanKind.INTERNAL,
            attributes={"attr1": "val1"},
            events=[],
            links=[],
        )

        built = builder.build_span(source)

        assert isinstance(built, ReadableSpan)
        assert built.name == "test-span"
        assert built.context.trace_id == 123
        assert built.context.span_id == 456
        assert built.parent.span_id == 789
        assert built.resource == fake_resource
        assert built.attributes == {"attr1": "val1"}

    def test_build_span_no_parent(self):
        """A SpanData without a parent_span_id yields a root span (parent is None)."""
        fake_resource = MagicMock()
        builder = SpanBuilder(fake_resource)

        source = SpanData(
            name="test-span",
            trace_id=123,
            span_id=456,
            parent_span_id=None,
            start_time=1000,
            end_time=2000,
            status=Status(StatusCode.OK),
            span_kind=SpanKind.INTERNAL,
        )

        assert builder.build_span(source).parent is None
|
||||
|
||||
|
||||
def test_create_link():
    """create_link() parses a 32-hex-char trace id and rejects malformed input."""
    hex_trace_id = "0123456789abcdef0123456789abcdef"
    link = create_link(hex_trace_id)
    assert link.context.trace_id == int(hex_trace_id, 16)
    assert link.context.span_id == INVALID_SPAN_ID

    with pytest.raises(ValueError, match="Invalid trace ID format"):
        create_link("invalid-hex")


def test_generate_span_id():
    """generate_span_id() never returns INVALID_SPAN_ID, retrying until a valid id appears."""
    first = generate_span_id()
    assert isinstance(first, int)
    assert first != INVALID_SPAN_ID

    # Force the first draw to be invalid so the retry path is exercised.
    with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand:
        mock_rand.side_effect = [INVALID_SPAN_ID, 999]
        assert generate_span_id() == 999
        assert mock_rand.call_count == 2


def test_convert_to_trace_id():
    """convert_to_trace_id() maps a UUID string to its integer value; bad input raises."""
    uid = str(uuid.uuid4())
    assert convert_to_trace_id(uid) == uuid.UUID(uid).int

    with pytest.raises(ValueError, match="UUID cannot be None"):
        convert_to_trace_id(None)

    with pytest.raises(ValueError, match="Invalid UUID input"):
        convert_to_trace_id("not-a-uuid")


def test_convert_string_to_id():
    """convert_string_to_id() hashes a string to a positive id; None falls back to a random id."""
    assert convert_string_to_id("test") > 0

    with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen:
        mock_gen.return_value = 12345
        assert convert_string_to_id(None) == 12345


def test_convert_to_span_id():
    """convert_to_span_id() derives an int span id from a UUID plus a type tag."""
    uid = str(uuid.uuid4())
    assert isinstance(convert_to_span_id(uid, "test-type"), int)

    with pytest.raises(ValueError, match="UUID cannot be None"):
        convert_to_span_id(None, "test")

    with pytest.raises(ValueError, match="Invalid UUID input"):
        convert_to_span_id("not-a-uuid", "test")


def test_convert_datetime_to_nanoseconds():
    """Datetimes become epoch nanoseconds; None passes through as None."""
    moment = datetime(2023, 1, 1, 12, 0, 0)
    assert convert_datetime_to_nanoseconds(moment) == int(moment.timestamp() * 1e9)
    assert convert_datetime_to_nanoseconds(None) is None


def test_build_endpoint():
    """build_endpoint() picks the path suffix from the endpoint's host."""
    license_key = "abc"

    # SLS (CMS 2.0) hosts get the v1 traces path.
    assert build_endpoint("https://log.aliyuncs.com", license_key) == "https://log.aliyuncs.com/adapt_abc/api/v1/traces"

    # Everything else is treated as an XTrace endpoint.
    assert build_endpoint("https://example.com", license_key) == "https://example.com/adapt_abc/api/otlp/traces"
|
||||
@@ -0,0 +1,88 @@
|
||||
import pytest
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import SpanKind, Status, StatusCode
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
|
||||
|
||||
class TestTraceMetadata:
    """Tests for the TraceMetadata entity."""

    def test_trace_metadata_init(self):
        """All constructor arguments are stored unchanged on the instance."""
        span_links = [trace_api.Link(context=trace_api.SpanContext(0, 0, False))]
        metadata = TraceMetadata(
            trace_id=123, workflow_span_id=456, session_id="session_1", user_id="user_1", links=span_links
        )

        assert metadata.trace_id == 123
        assert metadata.workflow_span_id == 456
        assert metadata.session_id == "session_1"
        assert metadata.user_id == "user_1"
        assert metadata.links == span_links
|
||||
|
||||
|
||||
class TestSpanData:
    """Tests for the SpanData pydantic model: defaults, optional fields, validation."""

    def test_span_data_init_required_fields(self):
        """Only the five required fields are needed; the rest take documented defaults."""
        span_data = SpanData(trace_id=123, span_id=456, name="test_span", start_time=1000, end_time=2000)

        assert span_data.trace_id == 123
        assert span_data.span_id == 456
        assert span_data.name == "test_span"
        assert span_data.start_time == 1000
        assert span_data.end_time == 2000

        # Defaults for everything not supplied.
        assert span_data.parent_span_id is None
        assert span_data.attributes == {}
        assert span_data.events == []
        assert span_data.links == []
        assert span_data.status.status_code == StatusCode.UNSET
        assert span_data.span_kind == SpanKind.INTERNAL

    def test_span_data_with_optional_fields(self):
        """Optional fields round-trip unchanged when supplied explicitly."""
        sample_event = Event(name="event_1", timestamp=1500)
        sample_link = trace_api.Link(context=trace_api.SpanContext(0, 0, False))
        ok_status = Status(StatusCode.OK)

        span_data = SpanData(
            trace_id=123,
            parent_span_id=111,
            span_id=456,
            name="test_span",
            attributes={"key": "value"},
            events=[sample_event],
            links=[sample_link],
            status=ok_status,
            start_time=1000,
            end_time=2000,
            span_kind=SpanKind.SERVER,
        )

        assert span_data.parent_span_id == 111
        assert span_data.attributes == {"key": "value"}
        assert span_data.events == [sample_event]
        assert span_data.links == [sample_link]
        assert span_data.status.status_code == ok_status.status_code
        assert span_data.span_kind == SpanKind.SERVER

    def test_span_data_missing_required_fields(self):
        """Omitting a required field (span_id) fails pydantic validation."""
        with pytest.raises(ValidationError):
            SpanData(
                trace_id=123,
                # span_id missing
                name="test_span",
                start_time=1000,
                end_time=2000,
            )

    def test_span_data_arbitrary_types_allowed(self):
        """Status/Event are arbitrary (non-pydantic) types accepted via model_config."""
        error_status = Status(StatusCode.ERROR, description="error occurred")
        exc_event = Event(name="exception", timestamp=1234, attributes={"exception.type": "ValueError"})

        span_data = SpanData(
            trace_id=123, span_id=456, name="test_span", status=error_status, events=[exc_event], start_time=1000, end_time=2000
        )

        assert span_data.status.status_code == error_status.status_code
        assert span_data.status.description == error_status.description
        assert span_data.events == [exc_event]
|
||||
@@ -0,0 +1,68 @@
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
ACS_ARMS_SERVICE_FEATURE,
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_INPUT_MESSAGE,
|
||||
GEN_AI_OUTPUT_MESSAGE,
|
||||
GEN_AI_PROMPT,
|
||||
GEN_AI_PROVIDER_NAME,
|
||||
GEN_AI_REQUEST_MODEL,
|
||||
GEN_AI_RESPONSE_FINISH_REASON,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
GEN_AI_USAGE_INPUT_TOKENS,
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS,
|
||||
GEN_AI_USAGE_TOTAL_TOKENS,
|
||||
GEN_AI_USER_ID,
|
||||
GEN_AI_USER_NAME,
|
||||
INPUT_VALUE,
|
||||
OUTPUT_VALUE,
|
||||
RETRIEVAL_DOCUMENT,
|
||||
RETRIEVAL_QUERY,
|
||||
TOOL_DESCRIPTION,
|
||||
TOOL_NAME,
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
|
||||
|
||||
def test_constants():
    """Every semconv attribute constant maps to its published string value."""
    expected_pairs = [
        (ACS_ARMS_SERVICE_FEATURE, "acs.arms.service.feature"),
        (GEN_AI_SESSION_ID, "gen_ai.session.id"),
        (GEN_AI_USER_ID, "gen_ai.user.id"),
        (GEN_AI_USER_NAME, "gen_ai.user.name"),
        (GEN_AI_SPAN_KIND, "gen_ai.span.kind"),
        (GEN_AI_FRAMEWORK, "gen_ai.framework"),
        (INPUT_VALUE, "input.value"),
        (OUTPUT_VALUE, "output.value"),
        (RETRIEVAL_QUERY, "retrieval.query"),
        (RETRIEVAL_DOCUMENT, "retrieval.document"),
        (GEN_AI_REQUEST_MODEL, "gen_ai.request.model"),
        (GEN_AI_PROVIDER_NAME, "gen_ai.provider.name"),
        (GEN_AI_USAGE_INPUT_TOKENS, "gen_ai.usage.input_tokens"),
        (GEN_AI_USAGE_OUTPUT_TOKENS, "gen_ai.usage.output_tokens"),
        (GEN_AI_USAGE_TOTAL_TOKENS, "gen_ai.usage.total_tokens"),
        (GEN_AI_PROMPT, "gen_ai.prompt"),
        (GEN_AI_COMPLETION, "gen_ai.completion"),
        (GEN_AI_RESPONSE_FINISH_REASON, "gen_ai.response.finish_reason"),
        (GEN_AI_INPUT_MESSAGE, "gen_ai.input.messages"),
        (GEN_AI_OUTPUT_MESSAGE, "gen_ai.output.messages"),
        (TOOL_NAME, "tool.name"),
        (TOOL_DESCRIPTION, "tool.description"),
        (TOOL_PARAMETERS, "tool.parameters"),
    ]
    for constant, literal in expected_pairs:
        assert constant == literal


def test_gen_ai_span_kind_enum():
    """GenAISpanKind exposes exactly the eight expected members, each comparing equal to its name."""
    member_values = [
        (GenAISpanKind.CHAIN, "CHAIN"),
        (GenAISpanKind.RETRIEVER, "RETRIEVER"),
        (GenAISpanKind.RERANKER, "RERANKER"),
        (GenAISpanKind.LLM, "LLM"),
        (GenAISpanKind.EMBEDDING, "EMBEDDING"),
        (GenAISpanKind.TOOL, "TOOL"),
        (GenAISpanKind.AGENT, "AGENT"),
        (GenAISpanKind.TASK, "TASK"),
    ]
    for member, value in member_values:
        assert member == value

    # Iterating the enum covers the class definition itself.
    kinds = list(GenAISpanKind)
    assert len(kinds) == 8
    assert "LLM" in kinds
|
||||
647
api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py
Normal file
647
api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py
Normal file
@@ -0,0 +1,647 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags
|
||||
|
||||
import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module
|
||||
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_INPUT_MESSAGE,
|
||||
GEN_AI_OUTPUT_MESSAGE,
|
||||
GEN_AI_PROMPT,
|
||||
GEN_AI_REQUEST_MODEL,
|
||||
GEN_AI_RESPONSE_FINISH_REASON,
|
||||
GEN_AI_USAGE_TOTAL_TOKENS,
|
||||
RETRIEVAL_DOCUMENT,
|
||||
RETRIEVAL_QUERY,
|
||||
TOOL_DESCRIPTION,
|
||||
TOOL_NAME,
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.ops.entities.config_entity import AliyunConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
|
||||
|
||||
class RecordingTraceClient:
    """In-memory stand-in for TraceClient that records every span handed to add_span()."""

    def __init__(self, service_name: str = "service", endpoint: str = "endpoint"):
        self.service_name = service_name
        self.endpoint = endpoint
        # Spans captured in call order; tests inspect this list directly.
        self.added_spans: list[object] = []

    def add_span(self, span) -> None:
        """Record the span instead of exporting it."""
        self.added_spans.append(span)

    def api_check(self) -> bool:
        """Always report the (fake) endpoint as reachable."""
        return True

    def get_project_url(self) -> str:
        """Return a fixed placeholder project URL."""
        return "project-url"
|
||||
|
||||
|
||||
def _dt() -> datetime:
    """Fixed, timezone-aware timestamp shared by all trace-info factories."""
    return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)


def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
    """Build an OTel Link with a sampled, non-remote span context."""
    return Link(
        SpanContext(
            trace_id=trace_id,
            span_id=span_id,
            is_remote=False,
            trace_flags=TraceFlags.SAMPLED,
        )
    )


def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
    """WorkflowTraceInfo with sensible defaults; keyword overrides win."""
    base = {
        "workflow_id": "workflow-id",
        "tenant_id": "tenant-id",
        "workflow_run_id": "00000000-0000-0000-0000-000000000001",
        "workflow_run_elapsed_time": 1.0,
        "workflow_run_status": "succeeded",
        "workflow_run_inputs": {"sys.query": "hello"},
        "workflow_run_outputs": {"answer": "world"},
        "workflow_run_version": "v1",
        "total_tokens": 1,
        "file_list": [],
        "query": "hello",
        "metadata": {"conversation_id": "conv", "user_id": "u", "app_id": "app"},
        "message_id": None,
        "start_time": _dt(),
        "end_time": _dt(),
        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
    }
    return WorkflowTraceInfo(**{**base, **overrides})


def _make_message_trace_info(**overrides) -> MessageTraceInfo:
    """MessageTraceInfo with sensible defaults; keyword overrides win."""
    base = {
        "conversation_model": "chat",
        "message_tokens": 1,
        "answer_tokens": 2,
        "total_tokens": 3,
        "conversation_mode": "chat",
        "metadata": {"conversation_id": "conv", "ls_model_name": "m", "ls_provider": "p"},
        "message_id": "00000000-0000-0000-0000-000000000002",
        "message_data": SimpleNamespace(from_account_id="acc", from_end_user_id=None),
        "inputs": {"prompt": "hi"},
        "outputs": "ok",
        "start_time": _dt(),
        "end_time": _dt(),
        "error": None,
        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
    }
    return MessageTraceInfo(**{**base, **overrides})


def _make_dataset_retrieval_trace_info(**overrides) -> DatasetRetrievalTraceInfo:
    """DatasetRetrievalTraceInfo with sensible defaults; keyword overrides win."""
    base = {
        "metadata": {"conversation_id": "conv", "user_id": "u"},
        "message_id": "00000000-0000-0000-0000-000000000003",
        "message_data": SimpleNamespace(),
        "inputs": "q",
        "documents": [SimpleNamespace()],
        "start_time": _dt(),
        "end_time": _dt(),
        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
    }
    return DatasetRetrievalTraceInfo(**{**base, **overrides})


def _make_tool_trace_info(**overrides) -> ToolTraceInfo:
    """ToolTraceInfo with sensible defaults; keyword overrides win."""
    base = {
        "tool_name": "tool",
        "tool_inputs": {"x": 1},
        "tool_outputs": "out",
        "tool_config": {"desc": "d"},
        "tool_parameters": {},
        "time_cost": 0.1,
        "metadata": {"conversation_id": "conv", "user_id": "u"},
        "message_id": "00000000-0000-0000-0000-000000000004",
        "message_data": SimpleNamespace(),
        "inputs": {"i": "v"},
        "outputs": {"o": "v"},
        "start_time": _dt(),
        "end_time": _dt(),
        "error": None,
        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
    }
    return ToolTraceInfo(**{**base, **overrides})


def _make_suggested_question_trace_info(**overrides) -> SuggestedQuestionTraceInfo:
    """SuggestedQuestionTraceInfo with sensible defaults; keyword overrides win."""
    base = {
        "suggested_question": ["q1", "q2"],
        "level": "info",
        "total_tokens": 1,
        "metadata": {"conversation_id": "conv", "user_id": "u", "ls_model_name": "m", "ls_provider": "p"},
        "message_id": "00000000-0000-0000-0000-000000000005",
        "inputs": {"i": 1},
        "start_time": _dt(),
        "end_time": _dt(),
        "error": None,
        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
    }
    return SuggestedQuestionTraceInfo(**{**base, **overrides})
|
||||
|
||||
|
||||
@pytest.fixture
def trace_instance(monkeypatch: pytest.MonkeyPatch) -> AliyunDataTrace:
    """AliyunDataTrace wired to a RecordingTraceClient with DB access stubbed out."""
    monkeypatch.setattr(aliyun_trace_module, "build_endpoint", lambda base_url, license_key: "built-endpoint")
    monkeypatch.setattr(aliyun_trace_module, "TraceClient", RecordingTraceClient)
    # get_service_account_with_tenant hits the DB; replace it with a stub.
    monkeypatch.setattr(AliyunDataTrace, "get_service_account_with_tenant", lambda self, app_id: MagicMock())

    config = AliyunConfig(app_name="app", license_key="k", endpoint="https://example.com")
    return AliyunDataTrace(config)


def test_init_builds_endpoint_and_client(monkeypatch: pytest.MonkeyPatch):
    """The constructor derives the endpoint and instantiates TraceClient with it."""
    endpoint_builder = MagicMock(return_value="built")
    client_cls = MagicMock()
    monkeypatch.setattr(aliyun_trace_module, "build_endpoint", endpoint_builder)
    monkeypatch.setattr(aliyun_trace_module, "TraceClient", client_cls)

    config = AliyunConfig(app_name="my-app", license_key="license", endpoint="https://example.com")
    trace = AliyunDataTrace(config)

    endpoint_builder.assert_called_once_with("https://example.com", "license")
    client_cls.assert_called_once_with(service_name="my-app", endpoint="built")
    assert trace.trace_config == config


def test_trace_dispatches_to_correct_methods(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """trace() routes each TraceInfo subtype to the matching handler method."""
    handlers = {
        "workflow_trace": MagicMock(),
        "message_trace": MagicMock(),
        "suggested_question_trace": MagicMock(),
        "dataset_retrieval_trace": MagicMock(),
        "tool_trace": MagicMock(),
    }
    for attr_name, handler in handlers.items():
        monkeypatch.setattr(trace_instance, attr_name, handler)

    trace_instance.trace(_make_workflow_trace_info())
    handlers["workflow_trace"].assert_called_once()

    trace_instance.trace(_make_message_trace_info())
    handlers["message_trace"].assert_called_once()

    trace_instance.trace(_make_suggested_question_trace_info())
    handlers["suggested_question_trace"].assert_called_once()

    trace_instance.trace(_make_dataset_retrieval_trace_info())
    handlers["dataset_retrieval_trace"].assert_called_once()

    trace_instance.trace(_make_tool_trace_info())
    handlers["tool_trace"].assert_called_once()

    # These subtypes are no-ops, exercised purely for branch coverage.
    trace_instance.trace(ModerationTraceInfo(flagged=False, action="allow", preset_response="", query="", metadata={}))
    trace_instance.trace(GenerateNameTraceInfo(tenant_id="t", metadata={}))


def test_api_check_delegates(trace_instance: AliyunDataTrace):
    """api_check() forwards straight to the trace client."""
    trace_instance.trace_client.api_check = MagicMock(return_value=False)
    assert trace_instance.api_check() is False


def test_get_project_url_success(trace_instance: AliyunDataTrace):
    """get_project_url() returns whatever the client reports."""
    assert trace_instance.get_project_url() == "project-url"


def test_get_project_url_error(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """Client failures are logged and re-raised as a ValueError."""
    monkeypatch.setattr(trace_instance.trace_client, "get_project_url", MagicMock(side_effect=Exception("boom")))
    logger_mock = MagicMock()
    monkeypatch.setattr(aliyun_trace_module, "logger", logger_mock)

    with pytest.raises(ValueError, match=r"Aliyun get project url failed: boom"):
        trace_instance.get_project_url()
    logger_mock.info.assert_called()
|
||||
|
||||
|
||||
def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 111)
|
||||
monkeypatch.setattr(
|
||||
aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"workflow": 222}.get(span_type, 0)
|
||||
)
|
||||
monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
|
||||
|
||||
add_workflow_span = MagicMock()
|
||||
get_workflow_node_executions = MagicMock(return_value=[MagicMock(), MagicMock()])
|
||||
build_workflow_node_span = MagicMock(side_effect=["span-1", "span-2"])
|
||||
monkeypatch.setattr(trace_instance, "add_workflow_span", add_workflow_span)
|
||||
monkeypatch.setattr(trace_instance, "get_workflow_node_executions", get_workflow_node_executions)
|
||||
monkeypatch.setattr(trace_instance, "build_workflow_node_span", build_workflow_node_span)
|
||||
|
||||
trace_info = _make_workflow_trace_info(
|
||||
trace_id="abcd", metadata={"conversation_id": "c", "user_id": "u", "app_id": "app"}
|
||||
)
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
add_workflow_span.assert_called_once()
|
||||
passed_trace_metadata = add_workflow_span.call_args.args[1]
|
||||
assert passed_trace_metadata.trace_id == 111
|
||||
assert passed_trace_metadata.workflow_span_id == 222
|
||||
assert passed_trace_metadata.session_id == "c"
|
||||
assert passed_trace_metadata.user_id == "u"
|
||||
assert passed_trace_metadata.links == []
|
||||
|
||||
assert trace_instance.trace_client.added_spans == ["span-1", "span-2"]
|
||||
|
||||
|
||||
def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
|
||||
trace_info = _make_message_trace_info(message_data=None)
|
||||
trace_instance.message_trace(trace_info)
|
||||
assert trace_instance.trace_client.added_spans == []
|
||||
|
||||
|
||||
def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10)
|
||||
monkeypatch.setattr(
|
||||
aliyun_trace_module,
|
||||
"convert_to_span_id",
|
||||
lambda _, span_type: {"message": 20, "llm": 30}.get(span_type, 0),
|
||||
)
|
||||
monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
|
||||
monkeypatch.setattr(aliyun_trace_module, "get_user_id_from_message_data", lambda _: "user")
|
||||
monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
|
||||
|
||||
status = Status(StatusCode.OK)
|
||||
monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
|
||||
|
||||
trace_info = _make_message_trace_info(
|
||||
metadata={"conversation_id": "conv", "ls_model_name": "model", "ls_provider": "provider"},
|
||||
message_tokens=7,
|
||||
answer_tokens=11,
|
||||
total_tokens=18,
|
||||
outputs="completion",
|
||||
)
|
||||
trace_instance.message_trace(trace_info)
|
||||
|
||||
assert len(trace_instance.trace_client.added_spans) == 2
|
||||
message_span, llm_span = trace_instance.trace_client.added_spans
|
||||
|
||||
assert message_span.name == "message"
|
||||
assert message_span.trace_id == 10
|
||||
assert message_span.parent_span_id is None
|
||||
assert message_span.span_id == 20
|
||||
assert message_span.span_kind == SpanKind.SERVER
|
||||
assert message_span.status == status
|
||||
assert message_span.attributes["gen_ai.span.kind"] == GenAISpanKind.CHAIN
|
||||
|
||||
assert llm_span.name == "llm"
|
||||
assert llm_span.parent_span_id == 20
|
||||
assert llm_span.span_id == 30
|
||||
assert llm_span.status == status
|
||||
assert llm_span.attributes[GEN_AI_REQUEST_MODEL] == "model"
|
||||
assert llm_span.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "18"
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
    """A dataset-retrieval trace without message data must not emit any spans."""
    info_without_message = _make_dataset_retrieval_trace_info(message_data=None)

    trace_instance.dataset_retrieval_trace(info_without_message)

    assert trace_instance.trace_client.added_spans == []
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """dataset_retrieval_trace should emit one span carrying the query and the extracted documents."""
    monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 1)
    monkeypatch.setattr(
        aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 2}.get(span_type, 0)
    )
    monkeypatch.setattr(aliyun_trace_module, "generate_span_id", lambda: 3)
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
    monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}])

    trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query"))
    assert len(trace_instance.trace_client.added_spans) == 1
    span = trace_instance.trace_client.added_spans[0]
    assert span.name == "dataset_retrieval"
    assert span.attributes[RETRIEVAL_QUERY] == "query"
    # Documents are JSON-serialized into the span attribute.
    assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]'
|
||||
|
||||
|
||||
def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
    """A tool trace without message data must not emit any spans."""
    info_without_message = _make_tool_trace_info(message_data=None)

    trace_instance.tool_trace(info_without_message)

    assert trace_instance.trace_client.added_spans == []
|
||||
|
||||
|
||||
def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """tool_trace should emit one span named after the tool with TOOL_* attributes set."""
    monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10)
    monkeypatch.setattr(
        aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 20}.get(span_type, 0)
    )
    monkeypatch.setattr(aliyun_trace_module, "generate_span_id", lambda: 30)
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)

    trace_instance.tool_trace(
        _make_tool_trace_info(
            tool_name="my-tool",
            tool_inputs={"a": 1},
            tool_config={"description": "x"},
            inputs={"i": 1},
        )
    )

    assert len(trace_instance.trace_client.added_spans) == 1
    span = trace_instance.trace_client.added_spans[0]
    assert span.name == "my-tool"
    assert span.status == status
    assert span.attributes[TOOL_NAME] == "my-tool"
    # tool_config is JSON-serialized into the description attribute.
    assert span.attributes[TOOL_DESCRIPTION] == '{"description": "x"}'
|
||||
|
||||
|
||||
def test_get_workflow_node_executions_requires_app_id(trace_instance: AliyunDataTrace):
    """Missing app_id in the trace metadata must raise a descriptive ValueError."""
    info_missing_app = _make_workflow_trace_info(metadata={"conversation_id": "c"})

    with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
        trace_instance.get_workflow_node_executions(info_missing_app)
|
||||
|
||||
|
||||
def test_get_workflow_node_executions_builds_repo_and_fetches(
    trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch
):
    """The repository is built via DifyCoreRepositoryFactory and queried by workflow_run_id."""
    trace_info = _make_workflow_trace_info(metadata={"app_id": "app", "conversation_id": "c", "user_id": "u"})

    # Stub out account lookup and DB wiring so no real session is created.
    account = object()
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", MagicMock(return_value=account))
    monkeypatch.setattr(aliyun_trace_module, "sessionmaker", MagicMock())
    monkeypatch.setattr(aliyun_trace_module, "db", SimpleNamespace(engine="engine"))

    repo = MagicMock()
    repo.get_by_workflow_run.return_value = ["node1"]
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr(aliyun_trace_module, "DifyCoreRepositoryFactory", mock_factory)

    result = trace_instance.get_workflow_node_executions(trace_info)
    assert result == ["node1"]
    repo.get_by_workflow_run.assert_called_once_with(workflow_run_id=trace_info.workflow_run_id)
|
||||
|
||||
|
||||
def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """LLM node executions are routed to build_workflow_llm_span."""
    node_execution = MagicMock(spec=WorkflowNodeExecution)
    trace_info = _make_workflow_trace_info()
    trace_metadata = MagicMock()

    monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm"))

    node_execution.node_type = NodeType.LLM
    assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "llm"
|
||||
|
||||
|
||||
def test_build_workflow_node_span_routes_knowledge_retrieval_type(
    trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch
):
    """KNOWLEDGE_RETRIEVAL node executions are routed to build_workflow_retrieval_span."""
    node_execution = MagicMock(spec=WorkflowNodeExecution)
    trace_info = _make_workflow_trace_info()
    trace_metadata = MagicMock()

    monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval"))

    node_execution.node_type = NodeType.KNOWLEDGE_RETRIEVAL
    assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "retrieval"
|
||||
|
||||
|
||||
def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """TOOL node executions are routed to build_workflow_tool_span."""
    node_execution = MagicMock(spec=WorkflowNodeExecution)
    trace_info = _make_workflow_trace_info()
    trace_metadata = MagicMock()

    monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool"))

    node_execution.node_type = NodeType.TOOL
    assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "tool"
|
||||
|
||||
|
||||
def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """CODE node executions fall through to the generic build_workflow_task_span."""
    node_execution = MagicMock(spec=WorkflowNodeExecution)
    trace_info = _make_workflow_trace_info()
    trace_metadata = MagicMock()

    monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task"))

    node_execution.node_type = NodeType.CODE
    assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "task"
|
||||
|
||||
|
||||
def test_build_workflow_node_span_handles_errors(
    trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
):
    """Builder exceptions are swallowed, logged, and surface as a None span."""
    node_execution = MagicMock(spec=WorkflowNodeExecution)
    trace_info = _make_workflow_trace_info()
    trace_metadata = MagicMock()

    monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom")))
    node_execution.node_type = NodeType.CODE

    assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) is None
    assert "Error occurred in build_workflow_node_span" in caplog.text
|
||||
|
||||
|
||||
def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """Task spans inherit the workflow trace id and are tagged with the TASK span kind."""
    monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9)
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)

    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
    node_execution = MagicMock(spec=WorkflowNodeExecution)
    node_execution.id = "node-id"
    node_execution.title = "title"
    node_execution.inputs = {"a": 1}
    node_execution.outputs = {"b": 2}
    node_execution.created_at = _dt()
    node_execution.finished_at = _dt()

    span = trace_instance.build_workflow_task_span(_make_workflow_trace_info(), node_execution, trace_metadata)
    assert span.trace_id == 1
    assert span.span_id == 9
    assert span.status.status_code == StatusCode.OK
    assert span.attributes["gen_ai.span.kind"] == GenAISpanKind.TASK
|
||||
|
||||
|
||||
def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """Tool spans serialize tool metadata/inputs as JSON and fall back to "{}" when missing."""
    monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9)
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)

    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()])
    node_execution = MagicMock(spec=WorkflowNodeExecution)
    node_execution.id = "node-id"
    node_execution.title = "my-tool"
    node_execution.inputs = {"a": 1}
    node_execution.outputs = {"b": 2}
    node_execution.created_at = _dt()
    node_execution.finished_at = _dt()
    node_execution.metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"k": "v"}}

    span = trace_instance.build_workflow_tool_span(_make_workflow_trace_info(), node_execution, trace_metadata)
    assert span.attributes[TOOL_NAME] == "my-tool"
    assert span.attributes[TOOL_DESCRIPTION] == '{"k": "v"}'
    assert span.attributes[TOOL_PARAMETERS] == '{"a": 1}'
    assert span.status.status_code == StatusCode.OK

    # Cover metadata is None and inputs is None: both default to an empty JSON object.
    node_execution.metadata = None
    node_execution.inputs = None
    span2 = trace_instance.build_workflow_tool_span(_make_workflow_trace_info(), node_execution, trace_metadata)
    assert span2.attributes[TOOL_DESCRIPTION] == "{}"
    assert span2.attributes[TOOL_PARAMETERS] == "{}"
|
||||
|
||||
|
||||
def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """Retrieval spans carry the query and formatted documents; empty inputs/outputs degrade gracefully."""
    monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9)
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
    monkeypatch.setattr(
        aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else []
    )

    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
    node_execution = MagicMock(spec=WorkflowNodeExecution)
    node_execution.id = "node-id"
    node_execution.title = "retrieval"
    node_execution.inputs = {"query": "q"}
    node_execution.outputs = {"result": [{"doc": "d"}]}
    node_execution.created_at = _dt()
    node_execution.finished_at = _dt()

    span = trace_instance.build_workflow_retrieval_span(_make_workflow_trace_info(), node_execution, trace_metadata)
    assert span.attributes[RETRIEVAL_QUERY] == "q"
    assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"formatted": true}]'

    # Cover empty inputs/outputs: query defaults to "" and documents to "[]".
    node_execution.inputs = None
    node_execution.outputs = None
    span2 = trace_instance.build_workflow_retrieval_span(_make_workflow_trace_info(), node_execution, trace_metadata)
    assert span2.attributes[RETRIEVAL_QUERY] == ""
    assert span2.attributes[RETRIEVAL_DOCUMENT] == "[]"
|
||||
|
||||
|
||||
def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """LLM spans surface usage, model, prompt/completion, and formatted messages from process_data/outputs."""
    monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9)
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
    monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in")
    monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out")

    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
    node_execution = MagicMock(spec=WorkflowNodeExecution)
    node_execution.id = "node-id"
    node_execution.title = "llm"
    node_execution.process_data = {
        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
        "prompts": ["p"],
        "model_name": "m",
        "model_provider": "p1",
    }
    node_execution.outputs = {"text": "t", "finish_reason": "stop"}
    node_execution.created_at = _dt()
    node_execution.finished_at = _dt()

    span = trace_instance.build_workflow_llm_span(_make_workflow_trace_info(), node_execution, trace_metadata)
    assert span.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "3"
    assert span.attributes[GEN_AI_REQUEST_MODEL] == "m"
    assert span.attributes[GEN_AI_PROMPT] == '["p"]'
    assert span.attributes[GEN_AI_COMPLETION] == "t"
    assert span.attributes[GEN_AI_RESPONSE_FINISH_REASON] == "stop"
    assert span.attributes[GEN_AI_INPUT_MESSAGE] == "in"
    assert span.attributes[GEN_AI_OUTPUT_MESSAGE] == "out"

    # Cover usage from outputs if not in process_data.
    node_execution.process_data = {"prompts": []}
    node_execution.outputs = {"usage": {"total_tokens": 10}, "text": ""}
    span2 = trace_instance.build_workflow_llm_span(_make_workflow_trace_info(), node_execution, trace_metadata)
    assert span2.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "10"
|
||||
|
||||
|
||||
def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """With a message_id a parent message span is emitted; without it the workflow span is the root."""
    monkeypatch.setattr(
        aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 20}.get(span_type, 0)
    )
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)

    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])

    # CASE 1: With message_id — the workflow span is nested under a server-kind message span.
    trace_info = _make_workflow_trace_info(
        message_id="msg-1", workflow_run_inputs={"sys.query": "hi"}, workflow_run_outputs={"ans": "ok"}
    )
    trace_instance.add_workflow_span(trace_info, trace_metadata)

    assert len(trace_instance.trace_client.added_spans) == 2
    message_span = trace_instance.trace_client.added_spans[0]
    workflow_span = trace_instance.trace_client.added_spans[1]

    assert message_span.name == "message"
    assert message_span.span_kind == SpanKind.SERVER
    assert message_span.parent_span_id is None

    assert workflow_span.name == "workflow"
    assert workflow_span.span_kind == SpanKind.INTERNAL
    assert workflow_span.parent_span_id == 20

    trace_instance.trace_client.added_spans.clear()

    # CASE 2: Without message_id — a single root workflow span of server kind.
    trace_info_no_msg = _make_workflow_trace_info(message_id=None)
    trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata)
    assert len(trace_instance.trace_client.added_spans) == 1
    span = trace_instance.trace_client.added_spans[0]
    assert span.name == "workflow"
    assert span.span_kind == SpanKind.SERVER
    assert span.parent_span_id is None
|
||||
|
||||
|
||||
def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """suggested_question_trace should emit one span with the questions JSON-encoded as completion."""
    monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10)
    monkeypatch.setattr(
        aliyun_trace_module,
        "convert_to_span_id",
        lambda _, span_type: {"message": 20, "suggested_question": 21}.get(span_type, 0),
    )
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)

    trace_info = _make_suggested_question_trace_info(suggested_question=["how?"])
    trace_instance.suggested_question_trace(trace_info)

    assert len(trace_instance.trace_client.added_spans) == 1
    span = trace_instance.trace_client.added_spans[0]
    assert span.name == "suggested_question"
    assert span.attributes[GEN_AI_COMPLETION] == '["how?"]'
|
||||
@@ -0,0 +1,275 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from opentelemetry.trace import Link, StatusCode
|
||||
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
GEN_AI_USER_ID,
|
||||
INPUT_VALUE,
|
||||
OUTPUT_VALUE,
|
||||
)
|
||||
from core.ops.aliyun_trace.utils import (
|
||||
create_common_span_attributes,
|
||||
create_links_from_trace_id,
|
||||
create_status_from_error,
|
||||
extract_retrieval_documents,
|
||||
format_input_messages,
|
||||
format_output_messages,
|
||||
format_retrieval_documents,
|
||||
get_user_id_from_message_data,
|
||||
get_workflow_node_status,
|
||||
serialize_json_data,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from models import EndUser
|
||||
|
||||
|
||||
def test_get_user_id_from_message_data_no_end_user(monkeypatch):
    """Without an end-user id, the account id is returned directly."""
    msg = MagicMock()
    msg.from_account_id = "account_id"
    msg.from_end_user_id = None

    assert get_user_id_from_message_data(msg) == "account_id"
|
||||
|
||||
|
||||
def test_get_user_id_from_message_data_with_end_user(monkeypatch):
    """When an EndUser row is found, its session_id is preferred over the account id."""
    message_data = MagicMock()
    message_data.from_account_id = "account_id"
    message_data.from_end_user_id = "end_user_id"

    end_user_data = MagicMock(spec=EndUser)
    end_user_data.session_id = "session_id"

    # Fake the db.session.query(...).where(...).first() chain.
    mock_query = MagicMock()
    mock_query.where.return_value.first.return_value = end_user_data

    mock_session = MagicMock()
    mock_session.query.return_value = mock_query

    from core.ops.aliyun_trace.utils import db

    monkeypatch.setattr(db, "session", mock_session)

    assert get_user_id_from_message_data(message_data) == "session_id"
|
||||
|
||||
|
||||
def test_get_user_id_from_message_data_end_user_not_found(monkeypatch):
    """If the EndUser lookup returns nothing, fall back to the account id."""
    message_data = MagicMock()
    message_data.from_account_id = "account_id"
    message_data.from_end_user_id = "end_user_id"

    # Query chain resolves to None (no EndUser row).
    mock_query = MagicMock()
    mock_query.where.return_value.first.return_value = None

    mock_session = MagicMock()
    mock_session.query.return_value = mock_query

    from core.ops.aliyun_trace.utils import db

    monkeypatch.setattr(db, "session", mock_session)

    assert get_user_id_from_message_data(message_data) == "account_id"
|
||||
|
||||
|
||||
def test_create_status_from_error():
    """None maps to an OK status; any error string maps to ERROR with that description."""
    # Success path: no error supplied.
    ok_status = create_status_from_error(None)
    assert ok_status.status_code == StatusCode.OK

    # Failure path: the error text becomes the status description.
    error_status = create_status_from_error("some error")
    assert error_status.status_code == StatusCode.ERROR
    assert error_status.description == "some error"
|
||||
|
||||
|
||||
def test_get_workflow_node_status():
    """Node execution status maps to OK / ERROR (with description) / UNSET."""
    node_execution = MagicMock(spec=WorkflowNodeExecution)

    # SUCCEEDED -> OK
    node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
    status = get_workflow_node_status(node_execution)
    assert status.status_code == StatusCode.OK

    # FAILED -> ERROR with the node's error text
    node_execution.status = WorkflowNodeExecutionStatus.FAILED
    node_execution.error = "node fail"
    status = get_workflow_node_status(node_execution)
    assert status.status_code == StatusCode.ERROR
    assert status.description == "node fail"

    # EXCEPTION -> ERROR with the node's error text
    node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION
    node_execution.error = "node exception"
    status = get_workflow_node_status(node_execution)
    assert status.status_code == StatusCode.ERROR
    assert status.description == "node exception"

    # Any other state (e.g. RUNNING) -> UNSET
    node_execution.status = WorkflowNodeExecutionStatus.RUNNING
    status = get_workflow_node_status(node_execution)
    assert status.status_code == StatusCode.UNSET
|
||||
|
||||
|
||||
def test_create_links_from_trace_id(monkeypatch):
    """None yields no links; a trace id yields exactly one link built by create_link."""
    # Mock create_link on the traceclient module.
    # NOTE(review): this assumes utils resolves create_link through the traceclient
    # module at call time; if utils imported the name directly, this patch target
    # would not take effect — confirm against core.ops.aliyun_trace.utils.
    mock_link = MagicMock(spec=Link)
    import core.ops.aliyun_trace.data_exporter.traceclient

    monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link)

    # Trace ID None
    assert create_links_from_trace_id(None) == []

    # Trace ID Present
    links = create_links_from_trace_id("trace_id")
    assert len(links) == 1
    assert links[0] == mock_link
|
||||
|
||||
|
||||
def test_extract_retrieval_documents():
    """Documents are flattened to content/metadata/score dicts; missing metadata keys become None."""
    doc1 = MagicMock(spec=Document)
    doc1.page_content = "content1"
    doc1.metadata = {"dataset_id": "ds1", "doc_id": "di1", "document_id": "dd1", "score": 0.9}

    doc2 = MagicMock(spec=Document)
    doc2.page_content = "content2"
    doc2.metadata = {"dataset_id": "ds2"}  # Missing some keys

    documents = [doc1, doc2]
    extracted = extract_retrieval_documents(documents)

    assert len(extracted) == 2
    assert extracted[0]["content"] == "content1"
    assert extracted[0]["metadata"]["dataset_id"] == "ds1"
    assert extracted[0]["score"] == 0.9

    # Absent metadata entries are filled with None rather than raising.
    assert extracted[1]["content"] == "content2"
    assert extracted[1]["metadata"]["dataset_id"] == "ds2"
    assert extracted[1]["metadata"]["doc_id"] is None
    assert extracted[1]["score"] is None
|
||||
|
||||
|
||||
def test_serialize_json_data():
    """serialize_json_data mirrors json.dumps for both ensure_ascii settings."""
    payload = {"a": 1}

    # Default behaves like json.dumps with ensure_ascii=False.
    assert serialize_json_data(payload) == json.dumps(payload, ensure_ascii=False)
    # Explicit ensure_ascii=True is forwarded through.
    assert serialize_json_data(payload, ensure_ascii=True) == json.dumps(payload, ensure_ascii=True)
|
||||
|
||||
|
||||
def test_create_common_span_attributes():
    """Each argument lands under its corresponding semconv attribute key."""
    attrs = create_common_span_attributes(
        session_id="s1", user_id="u1", span_kind="kind1", framework="fw1", inputs="in1", outputs="out1"
    )
    assert attrs[GEN_AI_SESSION_ID] == "s1"
    assert attrs[GEN_AI_USER_ID] == "u1"
    assert attrs[GEN_AI_SPAN_KIND] == "kind1"
    assert attrs[GEN_AI_FRAMEWORK] == "fw1"
    assert attrs[INPUT_VALUE] == "in1"
    assert attrs[OUTPUT_VALUE] == "out1"
|
||||
|
||||
|
||||
def test_format_retrieval_documents():
    """Non-lists and failures yield []; dict entries are normalized, non-dict entries skipped."""
    # Not a list
    assert format_retrieval_documents("not a list") == []

    # Valid list
    docs = [
        {"metadata": {"score": 0.8, "document_id": "doc1", "source": "src1"}, "content": "c1", "title": "t1"},
        {
            "metadata": {"_source": "src2", "doc_metadata": {"extra": "val"}},
            "content": "c2",
            # Missing title
        },
        "not a dict",  # Should be skipped
    ]
    formatted = format_retrieval_documents(docs)

    assert len(formatted) == 2
    assert formatted[0]["document"]["content"] == "c1"
    assert formatted[0]["document"]["metadata"]["title"] == "t1"
    assert formatted[0]["document"]["metadata"]["source"] == "src1"
    assert formatted[0]["document"]["score"] == 0.8
    assert formatted[0]["document"]["id"] == "doc1"

    # Second doc: source read from "_source", doc_metadata merged in, no title, default score.
    assert formatted[1]["document"]["content"] == "c2"
    assert formatted[1]["document"]["metadata"]["source"] == "src2"
    assert formatted[1]["document"]["metadata"]["extra"] == "val"
    assert "title" not in formatted[1]["document"]["metadata"]
    assert formatted[1]["document"]["score"] == 0.0  # Default

    # Exception handling: the function's try/except wraps the whole loop, so an
    # entry whose .get() raises makes the function return [] instead of propagating.
    class BadDict:
        def get(self, *args, **kwargs):
            raise Exception("boom")

    assert format_retrieval_documents([BadDict()]) == []
|
||||
|
||||
|
||||
def test_format_input_messages():
    """Prompts are converted to role/parts messages; invalid roles, non-dicts, and empty text are skipped."""
    # Not a dict
    assert format_input_messages(None) == serialize_json_data([])

    # No prompts
    assert format_input_messages({}) == serialize_json_data([])

    # Valid prompts
    process_data = {
        "prompts": [
            {"role": "user", "text": "hello"},
            {"role": "assistant", "text": "hi"},
            {"role": "system", "text": "be helpful"},
            {"role": "tool", "text": "result"},
            {"role": "invalid", "text": "skip me"},
            "not a dict",
            {"role": "user", "text": ""},  # empty text is skipped (only 4 messages below)
        ]
    }
    result = format_input_messages(process_data)
    result_list = json.loads(result)

    assert len(result_list) == 4
    assert result_list[0]["role"] == "user"
    assert result_list[0]["parts"][0]["content"] == "hello"
    assert result_list[1]["role"] == "assistant"
    assert result_list[2]["role"] == "system"
    assert result_list[3]["role"] == "tool"

    # Exception path: a None prompt entry triggers the fallback empty result.
    assert format_input_messages({"prompts": [None]}) == serialize_json_data([])
|
||||
|
||||
|
||||
def test_format_output_messages():
    """Output text becomes a single assistant message; unknown finish reasons normalize to "stop"."""
    # Not a dict
    assert format_output_messages(None) == serialize_json_data([])

    # No text
    assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([])

    # Valid
    outputs = {"text": "done", "finish_reason": "length"}
    result = format_output_messages(outputs)
    result_list = json.loads(result)
    assert len(result_list) == 1
    assert result_list[0]["role"] == "assistant"
    assert result_list[0]["parts"][0]["content"] == "done"
    assert result_list[0]["finish_reason"] == "length"

    # Invalid finish reason falls back to "stop".
    outputs2 = {"text": "done", "finish_reason": "unknown"}
    result2 = format_output_messages(outputs2)
    result_list2 = json.loads(result2)
    assert result_list2[0]["finish_reason"] == "stop"

    # Exception path: a non-serializable text value triggers the empty fallback.
    assert format_output_messages({"text": MagicMock()}) == serialize_json_data([])
|
||||
@@ -0,0 +1,398 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from opentelemetry.sdk.trace import Tracer
|
||||
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
|
||||
from opentelemetry.trace import StatusCode
|
||||
|
||||
from core.ops.arize_phoenix_trace.arize_phoenix_trace import (
|
||||
ArizePhoenixDataTrace,
|
||||
datetime_to_nanos,
|
||||
error_to_string,
|
||||
safe_json_dumps,
|
||||
set_span_status,
|
||||
setup_tracer,
|
||||
wrap_span_metadata,
|
||||
)
|
||||
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
|
||||
def _dt():
|
||||
return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||
|
||||
|
||||
def _make_workflow_info(**kwargs):
    """Build a WorkflowTraceInfo with sensible defaults; kwargs override any field."""
    defaults = {
        "workflow_id": "w1",
        "tenant_id": "t1",
        "workflow_run_id": "r1",
        "workflow_run_elapsed_time": 1.0,
        "workflow_run_status": "succeeded",
        "workflow_run_inputs": {"in": "val"},
        "workflow_run_outputs": {"out": "val"},
        "workflow_run_version": "1.0",
        "total_tokens": 10,
        "file_list": ["f1"],
        "query": "hi",
        "metadata": {"app_id": "app1"},
        "start_time": _dt(),
        "end_time": _dt() + timedelta(seconds=1),
    }
    defaults.update(kwargs)
    return WorkflowTraceInfo(**defaults)
|
||||
|
||||
|
||||
def _make_message_info(**kwargs):
    """Build a MessageTraceInfo with sensible defaults; kwargs override any field."""
    defaults = {
        "conversation_model": "chat",
        "message_tokens": 5,
        "answer_tokens": 5,
        "total_tokens": 10,
        "conversation_mode": "chat",
        "metadata": {"app_id": "app1"},
        "inputs": {"in": "val"},
        "outputs": "val",
        "start_time": _dt(),
        "end_time": _dt(),
        "message_id": "m1",
    }
    defaults.update(kwargs)
    return MessageTraceInfo(**defaults)
|
||||
|
||||
|
||||
# --- Utility Function Tests ---
|
||||
|
||||
|
||||
def test_datetime_to_nanos():
    """A datetime converts to epoch nanoseconds; None falls back to the current time."""
    dt = _dt()
    expected = int(dt.timestamp() * 1_000_000_000)
    assert datetime_to_nanos(dt) == expected

    # None path: the module's datetime.now() is mocked to a fixed timestamp.
    with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt:
        mock_now = MagicMock()
        mock_now.timestamp.return_value = 1704110400.0
        mock_dt.now.return_value = mock_now
        assert datetime_to_nanos(None) == 1704110400000000000
|
||||
|
||||
|
||||
def test_error_to_string():
    """Exceptions render with a traceback, strings pass through, None gets a placeholder."""
    # Raise for real so the exception carries a traceback.
    try:
        raise ValueError("boom")
    except ValueError as e:
        err = e

    res = error_to_string(err)
    assert "ValueError: boom" in res
    assert "traceback" in res.lower() or "line" in res.lower()

    assert error_to_string("str error") == "str error"
    assert error_to_string(None) == "Empty Stack Trace"
|
||||
|
||||
|
||||
def test_set_span_status():
    """Span status: None -> OK; Exception -> ERROR + record_exception; string -> ERROR + event; empty str() -> repr()."""
    span = MagicMock()
    # OK
    set_span_status(span, None)
    span.set_status.assert_called()
    assert span.set_status.call_args[0][0].status_code == StatusCode.OK

    # Error Exception: recorded on the span as an exception.
    span.reset_mock()
    set_span_status(span, ValueError("fail"))
    assert span.set_status.call_args[0][0].status_code == StatusCode.ERROR
    span.record_exception.assert_called()

    # Error String: attached as an event rather than an exception.
    span.reset_mock()
    set_span_status(span, "fail-str")
    assert span.set_status.call_args[0][0].status_code == StatusCode.ERROR
    span.add_event.assert_called()

    # repr branch: when str(error) is empty, repr(error) is used as the message.
    class SilentError:
        def __str__(self):
            return ""

        def __repr__(self):
            return "SilentErrorRepr"

    span.reset_mock()
    set_span_status(span, SilentError())
    assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr"
|
||||
|
||||
|
||||
def test_safe_json_dumps():
    """Datetime values are serialized via their string form instead of raising."""
    payload = {"a": _dt()}
    assert safe_json_dumps(payload) == '{"a": "2024-01-01 00:00:00+00:00"}'
|
||||
|
||||
|
||||
def test_wrap_span_metadata():
    """Extra kwargs are merged in and a created_from marker is stamped on."""
    wrapped = wrap_span_metadata({"a": 1}, b=2)

    assert wrapped == {"a": 1, "b": 2, "created_from": "Dify"}
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter")
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
|
||||
def test_setup_tracer_arize(mock_provider, mock_exporter):
|
||||
config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p")
|
||||
setup_tracer(config)
|
||||
mock_exporter.assert_called_once()
|
||||
assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1"
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter")
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
|
||||
def test_setup_tracer_phoenix(mock_provider, mock_exporter):
|
||||
config = PhoenixConfig(endpoint="http://p.com", project="p")
|
||||
setup_tracer(config)
|
||||
mock_exporter.assert_called_once()
|
||||
assert mock_exporter.call_args[1]["endpoint"] == "http://p.com/v1/traces"
|
||||
|
||||
|
||||
def test_setup_tracer_exception():
    """Failures while parsing the endpoint propagate out of setup_tracer."""
    config = ArizeConfig(endpoint="http://a.com", project="p")
    with (
        patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")),
        pytest.raises(Exception, match="boom"),
    ):
        setup_tracer(config)
|
||||
|
||||
|
||||
# --- ArizePhoenixDataTrace Class Tests ---
|
||||
|
||||
|
||||
@pytest.fixture
def trace_instance():
    """ArizePhoenixDataTrace wired to a mocked (tracer, processor) pair — no real exporter."""
    with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup:
        mocked_tracer = MagicMock(spec=Tracer)
        mock_setup.return_value = (mocked_tracer, MagicMock())
        config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p")
        return ArizePhoenixDataTrace(config)
|
||||
|
||||
|
||||
def test_trace_dispatch(trace_instance):
    """trace() routes each TraceInfo subtype to its dedicated handler method."""
    with (
        patch.object(trace_instance, "workflow_trace") as m1,
        patch.object(trace_instance, "message_trace") as m2,
        patch.object(trace_instance, "moderation_trace") as m3,
        patch.object(trace_instance, "suggested_question_trace") as m4,
        patch.object(trace_instance, "dataset_retrieval_trace") as m5,
        patch.object(trace_instance, "tool_trace") as m6,
        patch.object(trace_instance, "generate_name_trace") as m7,
    ):
        # Workflow and message infos come from shared factory helpers defined elsewhere in this file.
        trace_instance.trace(_make_workflow_info())
        m1.assert_called()

        trace_instance.trace(_make_message_info())
        m2.assert_called()

        trace_instance.trace(ModerationTraceInfo(flagged=True, action="a", preset_response="p", query="q", metadata={}))
        m3.assert_called()

        trace_instance.trace(SuggestedQuestionTraceInfo(suggested_question=[], total_tokens=0, level="i", metadata={}))
        m4.assert_called()

        trace_instance.trace(DatasetRetrievalTraceInfo(metadata={}))
        m5.assert_called()

        trace_instance.trace(
            ToolTraceInfo(
                tool_name="t",
                tool_inputs={},
                tool_outputs="o",
                metadata={},
                tool_config={},
                time_cost=1,
                tool_parameters={},
            )
        )
        m6.assert_called()

        trace_instance.trace(GenerateNameTraceInfo(tenant_id="t", metadata={}))
        m7.assert_called()
|
||||
|
||||
|
||||
def test_trace_exception(trace_instance):
    """Errors raised by a handler bubble up through trace()."""
    with (
        patch.object(trace_instance, "workflow_trace", side_effect=RuntimeError("fail")),
        pytest.raises(RuntimeError),
    ):
        trace_instance.trace(_make_workflow_info())
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker")
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory")
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
|
||||
def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance):
|
||||
mock_db.engine = MagicMock()
|
||||
info = _make_workflow_info()
|
||||
repo = MagicMock()
|
||||
mock_repo_factory.create_workflow_node_execution_repository.return_value = repo
|
||||
|
||||
node1 = MagicMock()
|
||||
node1.node_type = "llm"
|
||||
node1.status = "succeeded"
|
||||
node1.inputs = {"q": "hi"}
|
||||
node1.outputs = {"a": "bye", "usage": {"total_tokens": 5}}
|
||||
node1.created_at = _dt()
|
||||
node1.elapsed_time = 1.0
|
||||
node1.process_data = {
|
||||
"prompts": [{"role": "user", "content": "hi"}],
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-4",
|
||||
}
|
||||
node1.metadata = {"k": "v"}
|
||||
node1.title = "title"
|
||||
node1.id = "n1"
|
||||
node1.error = None
|
||||
|
||||
repo.get_by_workflow_run.return_value = [node1]
|
||||
|
||||
with patch.object(trace_instance, "get_service_account_with_tenant"):
|
||||
trace_instance.workflow_trace(info)
|
||||
|
||||
assert trace_instance.tracer.start_span.call_count >= 2
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
|
||||
def test_workflow_trace_no_app_id(mock_db, trace_instance):
|
||||
mock_db.engine = MagicMock()
|
||||
info = _make_workflow_info()
|
||||
info.metadata = {}
|
||||
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
|
||||
trace_instance.workflow_trace(info)
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
|
||||
def test_message_trace_success(mock_db, trace_instance):
|
||||
mock_db.engine = MagicMock()
|
||||
info = _make_message_info()
|
||||
info.message_data = MagicMock()
|
||||
info.message_data.from_account_id = "acc1"
|
||||
info.message_data.from_end_user_id = None
|
||||
info.message_data.query = "q"
|
||||
info.message_data.answer = "a"
|
||||
info.message_data.status = "s"
|
||||
info.message_data.model_id = "m"
|
||||
info.message_data.model_provider = "p"
|
||||
info.message_data.message_metadata = "{}"
|
||||
info.message_data.error = None
|
||||
info.error = None
|
||||
|
||||
trace_instance.message_trace(info)
|
||||
assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
|
||||
def test_message_trace_with_error(mock_db, trace_instance):
|
||||
mock_db.engine = MagicMock()
|
||||
info = _make_message_info()
|
||||
info.message_data = MagicMock()
|
||||
info.message_data.from_account_id = "acc1"
|
||||
info.message_data.from_end_user_id = None
|
||||
info.message_data.query = "q"
|
||||
info.message_data.answer = "a"
|
||||
info.message_data.status = "s"
|
||||
info.message_data.model_id = "m"
|
||||
info.message_data.model_provider = "p"
|
||||
info.message_data.message_metadata = "{}"
|
||||
info.message_data.error = "processing failed"
|
||||
info.error = "message error"
|
||||
|
||||
trace_instance.message_trace(info)
|
||||
assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
def test_trace_methods_return_early_with_no_message_data(trace_instance):
    """Every message-based trace method is a no-op when message_data is absent."""
    info = MagicMock()
    info.message_data = None

    handlers = (
        trace_instance.moderation_trace,
        trace_instance.suggested_question_trace,
        trace_instance.dataset_retrieval_trace,
        trace_instance.tool_trace,
        trace_instance.generate_name_trace,
    )
    for handler in handlers:
        handler(info)

    assert trace_instance.tracer.start_span.call_count == 0
|
||||
|
||||
|
||||
def test_moderation_trace_ok(trace_instance):
    """moderation_trace emits at least one span when message_data is present."""
    moderation_info = ModerationTraceInfo(flagged=True, action="a", preset_response="p", query="q", metadata={})
    moderation_info.message_data = MagicMock()
    moderation_info.message_data.error = None

    trace_instance.moderation_trace(moderation_info)
    assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
def test_suggested_question_trace_ok(trace_instance):
    """suggested_question_trace emits spans for a populated suggestion list."""
    suggestion_info = SuggestedQuestionTraceInfo(suggested_question=["?"], total_tokens=1, level="i", metadata={})
    suggestion_info.message_data = MagicMock()
    suggestion_info.error = None

    trace_instance.suggested_question_trace(suggestion_info)
    assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace_ok(trace_instance):
    """dataset_retrieval_trace emits spans when message_data is present."""
    retrieval_info = DatasetRetrievalTraceInfo(documents=[], metadata={})
    retrieval_info.message_data = MagicMock()
    retrieval_info.error = None

    trace_instance.dataset_retrieval_trace(retrieval_info)
    assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
def test_tool_trace_ok(trace_instance):
    """tool_trace emits spans for a complete tool invocation record."""
    tool_info = ToolTraceInfo(
        tool_name="t",
        tool_inputs={},
        tool_outputs="o",
        metadata={},
        tool_config={},
        time_cost=1,
        tool_parameters={},
    )
    tool_info.message_data = MagicMock()
    tool_info.error = None

    trace_instance.tool_trace(tool_info)
    assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
def test_generate_name_trace_ok(trace_instance):
    """generate_name_trace emits spans when message_data is present."""
    naming_info = GenerateNameTraceInfo(tenant_id="t", metadata={})
    naming_info.message_data = MagicMock()
    naming_info.message_data.error = None

    trace_instance.generate_name_trace(naming_info)
    assert trace_instance.tracer.start_span.call_count >= 1
|
||||
|
||||
|
||||
def test_get_project_url_phoenix(trace_instance):
    """Phoenix configs produce a project URL under the configured host."""
    trace_instance.arize_phoenix_config = PhoenixConfig(endpoint="http://p.com", project="p")
    project_url = trace_instance.get_project_url()
    assert "p.com/projects/" in project_url
|
||||
|
||||
|
||||
def test_set_attribute_none_logic(trace_instance):
    """None-valued roles and tool-call ids are omitted from the LLM attributes."""
    # A None role must not produce a role attribute.
    no_role_attrs = trace_instance._construct_llm_attributes([{"role": None, "content": "hi"}])
    assert "llm.input_messages.0.message.role" not in no_role_attrs

    # A None tool-call id must not produce an id attribute.
    assistant_message = {"role": "assistant", "tool_calls": [{"id": None, "function": {"name": "f1"}}]}
    no_id_attrs = trace_instance._construct_llm_attributes([assistant_message])
    assert "llm.input_messages.0.message.tool_calls.0.tool_call.id" not in no_id_attrs
|
||||
|
||||
|
||||
def test_construct_llm_attributes_dict_branch(trace_instance):
    """A bare dict prompt is serialized as a single user message."""
    dict_attrs = trace_instance._construct_llm_attributes({"prompt": "hi"})
    assert '"prompt": "hi"' in dict_attrs["llm.input_messages.0.message.content"]
    assert dict_attrs["llm.input_messages.0.message.role"] == "user"
|
||||
|
||||
|
||||
def test_api_check_success(trace_instance):
    """api_check reports healthy connectivity against the mocked tracer."""
    check_result = trace_instance.api_check()
    assert check_result is True
|
||||
|
||||
|
||||
def test_ensure_root_span_basic(trace_instance):
    """ensure_root_span registers the trace id it was given."""
    trace_instance.ensure_root_span("tid")
    assert "tid" in trace_instance.dify_trace_ids
|
||||
@@ -0,0 +1,698 @@
|
||||
import collections
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
|
||||
LangfuseGeneration,
|
||||
LangfuseSpan,
|
||||
LangfuseTrace,
|
||||
LevelEnum,
|
||||
UnitEnum,
|
||||
)
|
||||
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
|
||||
from dify_graph.enums import NodeType
|
||||
from models import EndUser
|
||||
from models.enums import MessageStatus
|
||||
|
||||
|
||||
def _dt() -> datetime:
    """Return the fixed, timezone-aware timestamp shared by these tests."""
    return datetime(year=2024, month=1, day=1, tzinfo=UTC)
|
||||
|
||||
|
||||
@pytest.fixture
def langfuse_config():
    """Minimal Langfuse credentials pointing at the hosted endpoint."""
    return LangfuseConfig(
        public_key="pk-123",
        secret_key="sk-123",
        host="https://cloud.langfuse.com",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def trace_instance(langfuse_config, monkeypatch):
    """LangFuseDataTrace backed by a stubbed Langfuse client — no network calls."""
    stub_client = MagicMock()
    monkeypatch.setattr(
        "core.ops.langfuse_trace.langfuse_trace.Langfuse",
        lambda **kwargs: stub_client,
    )
    return LangFuseDataTrace(langfuse_config)
|
||||
|
||||
|
||||
def test_init(langfuse_config, monkeypatch):
    """Construction passes credentials through to Langfuse and reads FILES_URL."""
    mock_langfuse = MagicMock()
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse)
    monkeypatch.setenv("FILES_URL", "http://test.url")

    instance = LangFuseDataTrace(langfuse_config)

    expected_client_kwargs = {
        "public_key": langfuse_config.public_key,
        "secret_key": langfuse_config.secret_key,
        "host": langfuse_config.host,
    }
    mock_langfuse.assert_called_once_with(**expected_client_kwargs)
    assert instance.file_base_url == "http://test.url"
|
||||
|
||||
|
||||
def test_trace_dispatch(trace_instance, monkeypatch):
    """trace() routes each trace-info type to its dedicated handler, once, with the info."""
    dispatch_table = {
        WorkflowTraceInfo: "workflow_trace",
        MessageTraceInfo: "message_trace",
        ModerationTraceInfo: "moderation_trace",
        SuggestedQuestionTraceInfo: "suggested_question_trace",
        DatasetRetrievalTraceInfo: "dataset_retrieval_trace",
        ToolTraceInfo: "tool_trace",
        GenerateNameTraceInfo: "generate_name_trace",
    }

    # Replace every handler with a mock before dispatching anything.
    handler_mocks = {}
    for method_name in dispatch_table.values():
        handler_mocks[method_name] = MagicMock()
        monkeypatch.setattr(trace_instance, method_name, handler_mocks[method_name])

    # Each info type must reach exactly its own handler.
    for info_cls, method_name in dispatch_table.items():
        info = MagicMock(spec=info_cls)
        trace_instance.trace(info)
        handler_mocks[method_name].assert_called_once_with(info)
|
||||
|
||||
|
||||
def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
    """A workflow run tied to a message becomes a span under a message-level trace.

    Exercises three sinks: add_trace for the message trace, add_span for the
    workflow run and the non-LLM node, and add_generation for the LLM node.
    """
    # Setup trace info
    trace_info = WorkflowTraceInfo(
        workflow_id="wf-1",
        tenant_id="tenant-1",
        workflow_run_id="run-1",
        workflow_run_elapsed_time=1.0,
        workflow_run_status="succeeded",
        workflow_run_inputs={"input": "hi"},
        workflow_run_outputs={"output": "hello"},
        workflow_run_version="1.0",
        message_id="msg-1",
        conversation_id="conv-1",
        total_tokens=100,
        file_list=[],
        query="hi",
        start_time=_dt(),
        end_time=_dt() + timedelta(seconds=1),
        trace_id="trace-1",
        metadata={"app_id": "app-1", "user_id": "user-1"},
        workflow_app_log_id="log-1",
        error="",
    )

    # Mock DB and Repositories
    mock_session = MagicMock()
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session)
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))

    # Mock node executions
    node_llm = MagicMock()
    node_llm.id = "node-llm"
    node_llm.title = "LLM Node"
    node_llm.node_type = NodeType.LLM
    node_llm.status = "succeeded"
    node_llm.process_data = {
        "model_mode": "chat",
        "model_name": "gpt-4",
        "model_provider": "openai",
        "usage": {"prompt_tokens": 10, "completion_tokens": 20},
    }
    node_llm.inputs = {"prompts": "p"}
    node_llm.outputs = {"text": "t"}
    node_llm.created_at = _dt()
    node_llm.elapsed_time = 0.5
    node_llm.metadata = {"foo": "bar"}

    node_other = MagicMock()
    node_other.id = "node-other"
    node_other.title = "Other Node"
    node_other.node_type = NodeType.CODE
    node_other.status = "failed"
    node_other.process_data = None
    node_other.inputs = {"code": "print"}
    node_other.outputs = {"result": "ok"}
    node_other.created_at = None  # Trigger datetime.now() branch
    node_other.elapsed_time = 0.2
    node_other.metadata = None

    repo = MagicMock()
    repo.get_by_workflow_run.return_value = [node_llm, node_other]

    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)

    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

    # Track calls to add_trace, add_span, add_generation
    trace_instance.add_trace = MagicMock()
    trace_instance.add_span = MagicMock()
    trace_instance.add_generation = MagicMock()

    trace_instance.workflow_trace(trace_info)

    # Verify add_trace (Workflow Level)
    trace_instance.add_trace.assert_called_once()
    trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"]
    assert trace_data.id == "trace-1"
    assert trace_data.name == TraceTaskName.MESSAGE_TRACE
    assert "message" in trace_data.tags
    assert "workflow" in trace_data.tags

    # Verify add_span (Workflow Run Span)
    assert trace_instance.add_span.call_count >= 1
    # First span should be workflow run span because message_id is present
    workflow_span = trace_instance.add_span.call_args_list[0][1]["langfuse_span_data"]
    assert workflow_span.id == "run-1"
    assert workflow_span.name == TraceTaskName.WORKFLOW_TRACE

    # Verify Generation for LLM node
    trace_instance.add_generation.assert_called_once()
    gen_data = trace_instance.add_generation.call_args[1]["langfuse_generation_data"]
    assert gen_data.id == "node-llm"
    assert gen_data.usage.input == 10
    assert gen_data.usage.output == 20

    # Verify normal span for Other node
    # Second add_span call
    other_span = trace_instance.add_span.call_args_list[1][1]["langfuse_span_data"]
    assert other_span.id == "node-other"
    assert other_span.level == LevelEnum.ERROR
|
||||
|
||||
|
||||
def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
    """Without a message_id the workflow run itself becomes the root trace."""
    trace_info = WorkflowTraceInfo(
        workflow_id="wf-1",
        tenant_id="tenant-1",
        workflow_run_id="run-1",
        workflow_run_elapsed_time=1.0,
        workflow_run_status="succeeded",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_version="1.0",
        total_tokens=0,
        file_list=[],
        query="",
        message_id=None,
        conversation_id="conv-1",
        start_time=_dt(),
        end_time=_dt(),
        trace_id=None,  # Should fallback to workflow_run_id
        metadata={"app_id": "app-1"},
        workflow_app_log_id="log-1",
        error="",
    )

    # Stub out the DB/repository layer; no node executions are returned.
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
    repo = MagicMock()
    repo.get_by_workflow_run.return_value = []
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

    trace_instance.add_trace = MagicMock()
    trace_instance.workflow_trace(trace_info)

    # Trace id falls back to the run id and the trace is workflow-named.
    trace_instance.add_trace.assert_called_once()
    trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"]
    assert trace_data.id == "run-1"
    assert trace_data.name == TraceTaskName.WORKFLOW_TRACE
|
||||
|
||||
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
    """workflow_trace raises when the trace metadata carries no app_id."""
    trace_info = WorkflowTraceInfo(
        workflow_id="wf-1",
        tenant_id="tenant-1",
        workflow_run_id="run-1",
        workflow_run_elapsed_time=1.0,
        workflow_run_status="succeeded",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_version="1.0",
        total_tokens=0,
        file_list=[],
        query="",
        message_id=None,
        conversation_id="conv-1",
        start_time=_dt(),
        end_time=_dt(),
        metadata={},  # Missing app_id
        workflow_app_log_id="log-1",
        error="",
    )
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))

    with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
        trace_instance.workflow_trace(trace_info)
|
||||
|
||||
|
||||
def test_message_trace_basic(trace_instance, monkeypatch):
    """message_trace records a trace plus an 'llm' generation with the token totals."""
    message_data = MagicMock()
    message_data.id = "msg-1"
    message_data.from_account_id = "acc-1"
    message_data.from_end_user_id = None
    message_data.provider_response_latency = 0.5
    message_data.conversation_id = "conv-1"
    message_data.total_price = 0.01
    message_data.model_id = "gpt-4"
    message_data.answer = "hello"
    message_data.status = MessageStatus.NORMAL
    message_data.error = None

    trace_info = MessageTraceInfo(
        message_id="msg-1",
        message_data=message_data,
        inputs={"query": "hi"},
        outputs={"answer": "hello"},
        message_tokens=10,
        answer_tokens=20,
        total_tokens=30,
        start_time=_dt(),
        end_time=_dt() + timedelta(seconds=1),
        trace_id="trace-1",
        metadata={"foo": "bar"},
        conversation_mode="chat",
        conversation_model="gpt-4",
        file_list=[],
        error=None,
    )

    trace_instance.add_trace = MagicMock()
    trace_instance.add_generation = MagicMock()

    trace_instance.message_trace(trace_info)

    trace_instance.add_trace.assert_called_once()
    trace_instance.add_generation.assert_called_once()

    # The generation is passed positionally here (unlike the keyword style elsewhere).
    gen_data = trace_instance.add_generation.call_args[0][0]
    assert gen_data.name == "llm"
    assert gen_data.usage.total == 30
|
||||
|
||||
|
||||
def test_message_trace_with_end_user(trace_instance, monkeypatch):
    """When the message came from an end user, the EndUser session id becomes the trace user."""
    message_data = MagicMock()
    message_data.id = "msg-1"
    message_data.from_account_id = "acc-1"
    message_data.from_end_user_id = "end-user-1"
    message_data.conversation_id = "conv-1"
    message_data.status = MessageStatus.NORMAL
    message_data.model_id = "gpt-4"
    message_data.error = ""
    message_data.answer = "hello"
    message_data.total_price = 0.0
    message_data.provider_response_latency = 0.1

    trace_info = MessageTraceInfo(
        message_id="msg-1",
        message_data=message_data,
        inputs={},
        outputs={},
        message_tokens=0,
        answer_tokens=0,
        total_tokens=0,
        start_time=_dt(),
        end_time=_dt(),
        metadata={},
        conversation_mode="chat",
        conversation_model="gpt-4",
        file_list=[],
        error=None,
    )

    # Mock DB session for EndUser lookup
    mock_end_user = MagicMock(spec=EndUser)
    mock_end_user.session_id = "session-id-123"

    mock_query = MagicMock()
    mock_query.where.return_value.first.return_value = mock_end_user
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.query", lambda model: mock_query)

    trace_instance.add_trace = MagicMock()
    trace_instance.add_generation = MagicMock()

    trace_instance.message_trace(trace_info)

    # The session id shows up both as the trace user_id and inside the metadata.
    trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"]
    assert trace_data.user_id == "session-id-123"
    assert trace_data.metadata["user_id"] == "session-id-123"
|
||||
|
||||
|
||||
def test_message_trace_none_data(trace_instance):
    """message_trace is a no-op when there is no message_data."""
    trace_instance.add_trace = MagicMock()
    bare_info = SimpleNamespace(message_data=None, file_list=[], metadata={})
    trace_instance.message_trace(bare_info)
    trace_instance.add_trace.assert_not_called()
|
||||
|
||||
|
||||
def test_moderation_trace(trace_instance):
    """moderation_trace records a span whose output carries the moderation verdict."""
    fake_message = MagicMock()
    fake_message.created_at = _dt()

    trace_info = ModerationTraceInfo(
        message_id="msg-1",
        message_data=fake_message,
        inputs={"q": "hi"},
        action="stop",
        flagged=True,
        preset_response="blocked",
        start_time=None,
        end_time=None,
        metadata={"foo": "bar"},
        trace_id="trace-1",
        query="hi",
    )

    trace_instance.add_span = MagicMock()
    trace_instance.moderation_trace(trace_info)

    trace_instance.add_span.assert_called_once()
    recorded_span = trace_instance.add_span.call_args[1]["langfuse_span_data"]
    assert recorded_span.name == TraceTaskName.MODERATION_TRACE
    assert recorded_span.output["flagged"] is True
|
||||
|
||||
|
||||
def test_suggested_question_trace(trace_instance):
    """suggested_question_trace records a character-unit generation."""
    fake_message = MagicMock()
    fake_message.status = MessageStatus.NORMAL
    fake_message.error = None

    trace_info = SuggestedQuestionTraceInfo(
        message_id="msg-1",
        message_data=fake_message,
        inputs="hi",
        suggested_question=["q1"],
        total_tokens=10,
        level="info",
        start_time=_dt(),
        end_time=_dt(),
        metadata={},
        trace_id="trace-1",
    )

    trace_instance.add_generation = MagicMock()
    trace_instance.suggested_question_trace(trace_info)

    trace_instance.add_generation.assert_called_once()
    recorded_generation = trace_instance.add_generation.call_args[1]["langfuse_generation_data"]
    assert recorded_generation.name == TraceTaskName.SUGGESTED_QUESTION_TRACE
    assert recorded_generation.usage.unit == UnitEnum.CHARACTERS
|
||||
|
||||
|
||||
def test_dataset_retrieval_trace(trace_instance):
    """dataset_retrieval_trace records a span whose output lists the documents."""
    fake_message = MagicMock()
    fake_message.created_at = _dt()
    fake_message.updated_at = _dt()

    trace_info = DatasetRetrievalTraceInfo(
        message_id="msg-1",
        message_data=fake_message,
        inputs="query",
        documents=[{"id": "doc1"}],
        start_time=None,
        end_time=None,
        metadata={},
        trace_id="trace-1",
    )

    trace_instance.add_span = MagicMock()
    trace_instance.dataset_retrieval_trace(trace_info)

    trace_instance.add_span.assert_called_once()
    recorded_span = trace_instance.add_span.call_args[1]["langfuse_span_data"]
    assert recorded_span.name == TraceTaskName.DATASET_RETRIEVAL_TRACE
    assert recorded_span.output["documents"] == [{"id": "doc1"}]
|
||||
|
||||
|
||||
def test_tool_trace(trace_instance):
    """tool_trace names the span after the tool and flags errors as ERROR level."""
    trace_info = ToolTraceInfo(
        message_id="msg-1",
        message_data=MagicMock(),
        inputs={},
        outputs={},
        tool_name="my_tool",
        tool_inputs={"a": 1},
        tool_outputs="result_string",
        time_cost=0.1,
        start_time=_dt(),
        end_time=_dt(),
        metadata={},
        trace_id="trace-1",
        tool_config={},
        tool_parameters={},
        error="some error",
    )

    trace_instance.add_span = MagicMock()
    trace_instance.tool_trace(trace_info)

    trace_instance.add_span.assert_called_once()
    recorded_span = trace_instance.add_span.call_args[1]["langfuse_span_data"]
    assert recorded_span.name == "my_tool"
    assert recorded_span.level == LevelEnum.ERROR
|
||||
|
||||
|
||||
def test_generate_name_trace(trace_instance):
    """generate_name_trace records a tenant-scoped trace and a conversation-scoped span."""
    trace_info = GenerateNameTraceInfo(
        inputs={"q": "hi"},
        outputs={"name": "new"},
        tenant_id="tenant-1",
        conversation_id="conv-1",
        start_time=_dt(),
        end_time=_dt(),
        metadata={"m": 1},
    )

    trace_instance.add_trace = MagicMock()
    trace_instance.add_span = MagicMock()

    trace_instance.generate_name_trace(trace_info)

    trace_instance.add_trace.assert_called_once()
    trace_instance.add_span.assert_called_once()

    recorded_trace = trace_instance.add_trace.call_args[1]["langfuse_trace_data"]
    assert recorded_trace.name == TraceTaskName.GENERATE_NAME_TRACE
    assert recorded_trace.user_id == "tenant-1"

    recorded_span = trace_instance.add_span.call_args[1]["langfuse_span_data"]
    assert recorded_span.trace_id == "conv-1"
|
||||
|
||||
|
||||
def test_add_trace_success(trace_instance):
    """add_trace forwards to the Langfuse client's trace API."""
    trace_instance.add_trace(LangfuseTrace(id="t1", name="trace"))
    trace_instance.langfuse_client.trace.assert_called_once()
|
||||
|
||||
|
||||
def test_add_trace_error(trace_instance):
    """Client failures while creating a trace are wrapped in a descriptive ValueError."""
    trace_instance.langfuse_client.trace.side_effect = Exception("error")
    with pytest.raises(ValueError, match="LangFuse Failed to create trace: error"):
        trace_instance.add_trace(LangfuseTrace(id="t1", name="trace"))
|
||||
|
||||
|
||||
def test_add_span_success(trace_instance):
    """add_span forwards to the Langfuse client's span API."""
    trace_instance.add_span(LangfuseSpan(id="s1", name="span", trace_id="t1"))
    trace_instance.langfuse_client.span.assert_called_once()
|
||||
|
||||
|
||||
def test_add_span_error(trace_instance):
    """Client failures while creating a span are wrapped in a descriptive ValueError."""
    trace_instance.langfuse_client.span.side_effect = Exception("error")
    with pytest.raises(ValueError, match="LangFuse Failed to create span: error"):
        trace_instance.add_span(LangfuseSpan(id="s1", name="span", trace_id="t1"))
|
||||
|
||||
|
||||
def test_update_span(trace_instance):
    """update_span ends the underlying live span."""
    live_span = MagicMock()
    trace_instance.update_span(live_span, LangfuseSpan(id="s1", name="span", trace_id="t1"))
    live_span.end.assert_called_once()
|
||||
|
||||
|
||||
def test_add_generation_success(trace_instance):
    """add_generation forwards to the Langfuse client's generation API."""
    trace_instance.add_generation(LangfuseGeneration(id="g1", name="gen", trace_id="t1"))
    trace_instance.langfuse_client.generation.assert_called_once()
|
||||
|
||||
|
||||
def test_add_generation_error(trace_instance):
    """Client failures while creating a generation are wrapped in a descriptive ValueError."""
    trace_instance.langfuse_client.generation.side_effect = Exception("error")
    with pytest.raises(ValueError, match="LangFuse Failed to create generation: error"):
        trace_instance.add_generation(LangfuseGeneration(id="g1", name="gen", trace_id="t1"))
|
||||
|
||||
|
||||
def test_update_generation(trace_instance):
    """update_generation ends the underlying live generation."""
    live_generation = MagicMock()
    trace_instance.update_generation(live_generation, LangfuseGeneration(id="g1", name="gen", trace_id="t1"))
    live_generation.end.assert_called_once()
|
||||
|
||||
|
||||
def test_api_check_success(trace_instance):
    """A passing auth check makes api_check return True."""
    trace_instance.langfuse_client.auth_check.return_value = True
    check_result = trace_instance.api_check()
    assert check_result is True
|
||||
|
||||
|
||||
def test_api_check_error(trace_instance):
    """Auth-check failures surface as a descriptive ValueError."""
    trace_instance.langfuse_client.auth_check.side_effect = Exception("fail")
    with pytest.raises(ValueError, match="LangFuse API check failed: fail"):
        trace_instance.api_check()
|
||||
|
||||
|
||||
def test_get_project_key_success(trace_instance):
    """get_project_key returns the id of the first project in the listing."""
    project = MagicMock()
    project.id = "proj-1"
    trace_instance.langfuse_client.client.projects.get.return_value = MagicMock(data=[project])
    assert trace_instance.get_project_key() == "proj-1"
|
||||
|
||||
|
||||
def test_get_project_key_error(trace_instance):
|
||||
trace_instance.langfuse_client.client.projects.get.side_effect = Exception("fail")
|
||||
with pytest.raises(ValueError, match="LangFuse get project key failed: fail"):
|
||||
trace_instance.get_project_key()
|
||||
|
||||
|
||||
def test_moderation_trace_none(trace_instance):
    """moderation_trace is a no-op when the message payload is missing."""
    info = ModerationTraceInfo(
        message_id="m",
        message_data=None,
        inputs={},
        action="s",
        flagged=False,
        preset_response="",
        query="",
        metadata={},
    )
    trace_instance.add_span = MagicMock()

    trace_instance.moderation_trace(info)

    trace_instance.add_span.assert_not_called()


def test_suggested_question_trace_none(trace_instance):
    """suggested_question_trace is a no-op when the message payload is missing."""
    info = SuggestedQuestionTraceInfo(
        message_id="m", message_data=None, inputs={}, suggested_question=[], total_tokens=0, level="i", metadata={}
    )
    trace_instance.add_generation = MagicMock()

    trace_instance.suggested_question_trace(info)

    trace_instance.add_generation.assert_not_called()


def test_dataset_retrieval_trace_none(trace_instance):
    """dataset_retrieval_trace is a no-op when the message payload is missing."""
    info = DatasetRetrievalTraceInfo(message_id="m", message_data=None, inputs={}, documents=[], metadata={})
    trace_instance.add_span = MagicMock()

    trace_instance.dataset_retrieval_trace(info)

    trace_instance.add_span.assert_not_called()
|
||||
|
||||
|
||||
def test_langfuse_trace_entity_with_list_dict_input():
    """A list-of-dict `input` is normalized: "text" keys become "content" keys."""
    trace = LangfuseTrace(id="t1", name="n", input=[{"text": "hello"}])

    assert isinstance(trace.input, list)
    assert trace.input[0]["content"] == "hello"
|
||||
|
||||
|
||||
def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypatch, caplog):
    """A failure while reading an LLM node's usage is logged, not raised.

    The node's process_data raises on the "usage" key; workflow_trace must
    still produce a generation for the node and log the extraction error.
    """

    class ExplodingUsageDict(collections.UserDict):
        # Raise only for the "usage" key so every other process_data read succeeds.
        def get(self, key, default=None):
            if key == "usage":
                raise Exception("Usage extraction failed")
            return super().get(key, default)

    trace_info = WorkflowTraceInfo(
        workflow_id="wf-1",
        tenant_id="t",
        workflow_run_id="r",
        workflow_run_elapsed_time=1.0,
        workflow_run_status="s",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_version="1",
        total_tokens=0,
        file_list=[],
        query="",
        message_id=None,
        conversation_id="c",
        start_time=_dt(),
        end_time=_dt(),
        metadata={"app_id": "app-1"},
        workflow_app_log_id="l",
        error="",
    )

    llm_node = MagicMock()
    llm_node.id = "n1"
    llm_node.title = "LLM Node"
    llm_node.node_type = NodeType.LLM
    llm_node.status = "succeeded"
    llm_node.process_data = ExplodingUsageDict(
        {"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]}
    )
    llm_node.created_at = _dt()
    llm_node.elapsed_time = 0.1
    llm_node.metadata = {}
    llm_node.outputs = {}

    node_repo = MagicMock()
    node_repo.get_by_workflow_run.return_value = [llm_node]
    factory = MagicMock()
    factory.create_workflow_node_execution_repository.return_value = node_repo

    # Stub out the repository/session/db plumbing so no real engine is touched.
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", factory)
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

    trace_instance.add_trace = MagicMock()
    trace_instance.add_generation = MagicMock()

    with caplog.at_level(logging.ERROR):
        trace_instance.workflow_trace(trace_info)

    assert "Failed to extract usage" in caplog.text
    trace_instance.add_generation.assert_called_once()
|
||||
@@ -0,0 +1,608 @@
|
||||
import collections
import logging
from datetime import datetime, timedelta
from unittest.mock import MagicMock

import pytest

from core.ops.entities.config_entity import LangSmithConfig
from core.ops.entities.trace_entity import (
    DatasetRetrievalTraceInfo,
    GenerateNameTraceInfo,
    MessageTraceInfo,
    ModerationTraceInfo,
    SuggestedQuestionTraceInfo,
    ToolTraceInfo,
    TraceTaskName,
    WorkflowTraceInfo,
)
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
    LangSmithRunModel,
    LangSmithRunType,
    LangSmithRunUpdateModel,
)
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from models import EndUser
|
||||
|
||||
|
||||
def _dt() -> datetime:
|
||||
return datetime(2024, 1, 1, 0, 0, 0)
|
||||
|
||||
|
||||
@pytest.fixture
def langsmith_config():
    """Minimal LangSmith configuration pointing at the public endpoint."""
    return LangSmithConfig(api_key="ls-123", project="default", endpoint="https://api.smith.langchain.com")
|
||||
|
||||
|
||||
@pytest.fixture
def trace_instance(langsmith_config, monkeypatch):
    """LangSmithDataTrace wired to a stubbed LangSmith client."""
    stub_client = MagicMock()
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: stub_client)
    return LangSmithDataTrace(langsmith_config)
|
||||
|
||||
|
||||
def test_init(langsmith_config, monkeypatch):
    """Construction wires the client with the configured key/endpoint and FILES_URL."""
    client_cls = MagicMock()
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", client_cls)
    monkeypatch.setenv("FILES_URL", "http://test.url")

    instance = LangSmithDataTrace(langsmith_config)

    client_cls.assert_called_once_with(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint)
    assert instance.langsmith_key == langsmith_config.api_key
    assert instance.project_name == langsmith_config.project
    assert instance.file_base_url == "http://test.url"
|
||||
|
||||
|
||||
def test_trace_dispatch(trace_instance, monkeypatch):
    """trace() must route each TraceInfo subtype to its dedicated handler.

    The original test repeated the same three lines for every subtype;
    a dispatch table keeps the mapping visible and the assertions uniform.
    """
    dispatch_table = {
        WorkflowTraceInfo: "workflow_trace",
        MessageTraceInfo: "message_trace",
        ModerationTraceInfo: "moderation_trace",
        SuggestedQuestionTraceInfo: "suggested_question_trace",
        DatasetRetrievalTraceInfo: "dataset_retrieval_trace",
        ToolTraceInfo: "tool_trace",
        GenerateNameTraceInfo: "generate_name_trace",
    }
    mocks = {method: MagicMock() for method in dispatch_table.values()}
    for method, mock in mocks.items():
        monkeypatch.setattr(trace_instance, method, mock)

    for info_cls, method in dispatch_table.items():
        info = MagicMock(spec=info_cls)
        trace_instance.trace(info)
        # Exactly one call, with the very object that was dispatched.
        mocks[method].assert_called_once_with(info)
|
||||
|
||||
|
||||
def _make_node_execution(
    *,
    node_id,
    title,
    node_type,
    process_data=None,
    inputs=None,
    outputs=None,
    created_at=None,
    elapsed_time=0.2,
    metadata=None,
):
    """Build a MagicMock resembling a workflow node execution record.

    Keyword-only arguments mirror the attributes the tracer reads when it
    converts node executions into LangSmith runs; `status` is always
    "succeeded" because these tests only cover the happy path.
    """
    node = MagicMock()
    node.id = node_id
    node.title = title
    node.node_type = node_type
    node.status = "succeeded"
    node.process_data = process_data
    node.inputs = {} if inputs is None else inputs
    node.outputs = {} if outputs is None else outputs
    node.created_at = created_at
    node.elapsed_time = elapsed_time
    node.metadata = {} if metadata is None else metadata
    return node


def test_workflow_trace(trace_instance, monkeypatch):
    """workflow_trace emits a message run, a workflow run, and one run per
    node execution, mapping node types to the matching LangSmith run types."""
    workflow_data = MagicMock()
    workflow_data.created_at = _dt()
    workflow_data.finished_at = _dt() + timedelta(seconds=1)

    trace_info = WorkflowTraceInfo(
        tenant_id="tenant-1",
        workflow_id="wf-1",
        workflow_run_id="run-1",
        workflow_run_inputs={"input": "hi"},
        workflow_run_outputs={"output": "hello"},
        workflow_run_status="succeeded",
        workflow_run_version="1.0",
        workflow_run_elapsed_time=1.0,
        total_tokens=100,
        file_list=[],
        query="hi",
        message_id="msg-1",
        conversation_id="conv-1",
        start_time=_dt(),
        end_time=_dt() + timedelta(seconds=1),
        trace_id="trace-1",
        metadata={"app_id": "app-1"},
        workflow_app_log_id="log-1",
        error="",
        workflow_data=workflow_data,
    )

    # Stub the session/db plumbing so no real engine is touched.
    mock_session = MagicMock()
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))

    node_llm = _make_node_execution(
        node_id="node-llm",
        title="LLM Node",
        node_type=NodeType.LLM,
        process_data={
            "model_mode": "chat",
            "model_name": "gpt-4",
            "model_provider": "openai",
            "usage": {"prompt_tokens": 10, "completion_tokens": 20},
        },
        inputs={"prompts": "p"},
        outputs={"text": "t"},
        created_at=_dt(),
        elapsed_time=0.5,
        metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 30},
    )
    node_other = _make_node_execution(
        node_id="node-other",
        title="Tool Node",
        node_type=NodeType.TOOL,
        inputs={"tool_input": "val"},
        outputs={"tool_output": "val"},
        created_at=None,  # None triggers the tracer's datetime.now() fallback
    )
    node_retrieval = _make_node_execution(
        node_id="node-retrieval",
        title="Retrieval Node",
        node_type=NodeType.KNOWLEDGE_RETRIEVAL,
        inputs={"query": "val"},
        outputs={"results": "val"},
        created_at=_dt(),
    )

    repo = MagicMock()
    repo.get_by_workflow_run.return_value = [node_llm, node_other, node_retrieval]
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

    trace_instance.add_run = MagicMock()

    trace_instance.workflow_trace(trace_info)

    # Expected runs: message, workflow, then one per node execution.
    assert trace_instance.add_run.call_count == 5

    call_args = [call[0][0] for call in trace_instance.add_run.call_args_list]

    assert call_args[0].id == "msg-1"
    assert call_args[0].name == TraceTaskName.MESSAGE_TRACE

    assert call_args[1].id == "run-1"
    assert call_args[1].name == TraceTaskName.WORKFLOW_TRACE
    assert call_args[1].parent_run_id == "msg-1"

    assert call_args[2].id == "node-llm"
    assert call_args[2].run_type == LangSmithRunType.llm

    assert call_args[3].id == "node-other"
    assert call_args[3].run_type == LangSmithRunType.tool

    assert call_args[4].id == "node-retrieval"
    assert call_args[4].run_type == LangSmithRunType.retriever
|
||||
|
||||
|
||||
def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
    """workflow_trace still emits runs when start_time/end_time are None."""
    workflow_data = MagicMock()
    workflow_data.created_at = _dt()
    workflow_data.finished_at = _dt() + timedelta(seconds=1)

    trace_info = WorkflowTraceInfo(
        tenant_id="tenant-1",
        workflow_id="wf-1",
        workflow_run_id="run-1",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_status="succeeded",
        workflow_run_version="1.0",
        workflow_run_elapsed_time=1.0,
        total_tokens=10,
        file_list=[],
        query="hi",
        message_id="msg-1",
        conversation_id="conv-1",
        start_time=None,
        end_time=None,
        trace_id="trace-1",
        metadata={"app_id": "app-1"},
        workflow_app_log_id="log-1",
        error="",
        workflow_data=workflow_data,
    )

    # No node executions and stubbed DB plumbing: only the trace-level runs matter here.
    session_stub = MagicMock()
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: session_stub)
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
    node_repo = MagicMock()
    node_repo.get_by_workflow_run.return_value = []
    factory = MagicMock()
    factory.create_workflow_node_execution_repository.return_value = node_repo
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", factory)
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

    trace_instance.add_run = MagicMock()
    trace_instance.workflow_trace(trace_info)
    assert trace_instance.add_run.called
|
||||
|
||||
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
    """workflow_trace raises when metadata carries no app_id."""
    trace_info = MagicMock(spec=WorkflowTraceInfo)
    trace_info.trace_id = "trace-1"
    trace_info.message_id = None
    trace_info.workflow_run_id = "run-1"
    trace_info.start_time = None
    trace_info.workflow_data = MagicMock()
    trace_info.workflow_data.created_at = _dt()
    trace_info.metadata = {}  # deliberately missing "app_id"
    trace_info.workflow_app_log_id = "log-1"
    trace_info.file_list = []
    trace_info.total_tokens = 0
    trace_info.workflow_run_inputs = {}
    trace_info.workflow_run_outputs = {}
    trace_info.error = ""

    session_stub = MagicMock()
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: session_stub)
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))

    with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
        trace_instance.workflow_trace(trace_info)
|
||||
|
||||
|
||||
def test_message_trace(trace_instance, monkeypatch):
    """message_trace emits a message run plus a child "llm" run carrying the
    resolved end-user session id in its metadata."""
    message_data = MagicMock()
    message_data.id = "msg-1"
    message_data.from_account_id = "acc-1"
    message_data.from_end_user_id = "end-user-1"
    message_data.answer = "hello answer"

    trace_info = MessageTraceInfo(
        message_id="msg-1",
        message_data=message_data,
        inputs={"input": "hi"},
        outputs={"answer": "hello"},
        message_tokens=10,
        answer_tokens=20,
        total_tokens=30,
        start_time=_dt(),
        end_time=_dt() + timedelta(seconds=1),
        trace_id="trace-1",
        metadata={"foo": "bar"},
        conversation_mode="chat",
        conversation_model="gpt-4",
        file_list=[],
        error=None,
        message_file_data=MagicMock(url="file-url"),
    )

    # Resolve the EndUser row to a fixed session id without touching the DB.
    end_user = MagicMock(spec=EndUser)
    end_user.session_id = "session-id-123"
    query_stub = MagicMock()
    query_stub.where.return_value.first.return_value = end_user
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.query", lambda model: query_stub)

    trace_instance.add_run = MagicMock()

    trace_instance.message_trace(trace_info)

    # One run for the message itself and one for the LLM call.
    assert trace_instance.add_run.call_count == 2

    runs = [call[0][0] for call in trace_instance.add_run.call_args_list]
    assert runs[0].id == "msg-1"
    assert runs[0].extra["metadata"]["end_user_id"] == "session-id-123"
    assert runs[1].parent_run_id == "msg-1"
    assert runs[1].name == "llm"
|
||||
|
||||
|
||||
def test_message_trace_no_data(trace_instance):
    """message_trace is a no-op when the message payload is missing."""
    info = MagicMock(spec=MessageTraceInfo)
    info.message_data = None
    info.file_list = []
    info.message_file_data = None
    info.metadata = {}
    trace_instance.add_run = MagicMock()

    trace_instance.message_trace(info)

    trace_instance.add_run.assert_not_called()


def test_moderation_trace_no_data(trace_instance):
    """moderation_trace is a no-op when the message payload is missing."""
    info = MagicMock(spec=ModerationTraceInfo)
    info.message_data = None
    trace_instance.add_run = MagicMock()

    trace_instance.moderation_trace(info)

    trace_instance.add_run.assert_not_called()


def test_suggested_question_trace_no_data(trace_instance):
    """suggested_question_trace is a no-op when the message payload is missing."""
    info = MagicMock(spec=SuggestedQuestionTraceInfo)
    info.message_data = None
    trace_instance.add_run = MagicMock()

    trace_instance.suggested_question_trace(info)

    trace_instance.add_run.assert_not_called()


def test_dataset_retrieval_trace_no_data(trace_instance):
    """dataset_retrieval_trace is a no-op when the message payload is missing."""
    info = MagicMock(spec=DatasetRetrievalTraceInfo)
    info.message_data = None
    trace_instance.add_run = MagicMock()

    trace_instance.dataset_retrieval_trace(info)

    trace_instance.add_run.assert_not_called()
|
||||
|
||||
|
||||
def test_moderation_trace(trace_instance):
    """A populated moderation trace yields exactly one MODERATION_TRACE run."""
    message_data = MagicMock()
    message_data.created_at = _dt()
    message_data.updated_at = _dt()

    trace_info = ModerationTraceInfo(
        message_id="msg-1",
        message_data=message_data,
        inputs={"q": "hi"},
        action="stop",
        flagged=True,
        preset_response="blocked",
        start_time=None,
        end_time=None,
        metadata={},
        trace_id="trace-1",
        query="hi",
    )
    trace_instance.add_run = MagicMock()

    trace_instance.moderation_trace(trace_info)

    trace_instance.add_run.assert_called_once()
    assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.MODERATION_TRACE


def test_suggested_question_trace(trace_instance):
    """A populated suggested-question trace yields one SUGGESTED_QUESTION_TRACE run."""
    message_data = MagicMock()
    message_data.created_at = _dt()
    message_data.updated_at = _dt()

    trace_info = SuggestedQuestionTraceInfo(
        message_id="msg-1",
        message_data=message_data,
        inputs="hi",
        suggested_question=["q1"],
        total_tokens=10,
        level="info",
        start_time=None,
        end_time=None,
        metadata={},
        trace_id="trace-1",
    )
    trace_instance.add_run = MagicMock()

    trace_instance.suggested_question_trace(trace_info)

    trace_instance.add_run.assert_called_once()
    assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.SUGGESTED_QUESTION_TRACE


def test_dataset_retrieval_trace(trace_instance):
    """A populated dataset-retrieval trace yields one DATASET_RETRIEVAL_TRACE run."""
    message_data = MagicMock()
    message_data.created_at = _dt()
    message_data.updated_at = _dt()

    trace_info = DatasetRetrievalTraceInfo(
        message_id="msg-1",
        message_data=message_data,
        inputs="query",
        documents=[{"id": "doc1"}],
        start_time=None,
        end_time=None,
        metadata={},
        trace_id="trace-1",
    )
    trace_instance.add_run = MagicMock()

    trace_instance.dataset_retrieval_trace(trace_info)

    trace_instance.add_run.assert_called_once()
    assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.DATASET_RETRIEVAL_TRACE
|
||||
|
||||
|
||||
def test_tool_trace(trace_instance):
    """tool_trace emits a single run named after the tool."""
    trace_info = ToolTraceInfo(
        message_id="msg-1",
        message_data=MagicMock(),
        inputs={},
        outputs={},
        tool_name="my_tool",
        tool_inputs={"a": 1},
        tool_outputs="result",
        time_cost=0.1,
        start_time=_dt(),
        end_time=_dt(),
        metadata={},
        trace_id="trace-1",
        tool_config={},
        tool_parameters={},
        file_url="http://file",
    )
    trace_instance.add_run = MagicMock()

    trace_instance.tool_trace(trace_info)

    trace_instance.add_run.assert_called_once()
    assert trace_instance.add_run.call_args[0][0].name == "my_tool"


def test_generate_name_trace(trace_instance):
    """generate_name_trace emits a single GENERATE_NAME_TRACE run."""
    trace_info = GenerateNameTraceInfo(
        inputs={"q": "hi"},
        outputs={"name": "new"},
        tenant_id="tenant-1",
        conversation_id="conv-1",
        start_time=None,
        end_time=None,
        metadata={},
        trace_id="trace-1",
    )
    trace_instance.add_run = MagicMock()

    trace_instance.generate_name_trace(trace_info)

    trace_instance.add_run.assert_called_once()
    assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.GENERATE_NAME_TRACE
|
||||
|
||||
|
||||
def test_add_run_success(trace_instance):
    """add_run forwards the run to the client under the current project id."""
    run = LangSmithRunModel(
        id="run-1", name="test", inputs={}, outputs={}, run_type=LangSmithRunType.tool, start_time=_dt()
    )
    trace_instance.project_id = "proj-1"

    trace_instance.add_run(run)

    trace_instance.langsmith_client.create_run.assert_called_once()
    _, kwargs = trace_instance.langsmith_client.create_run.call_args
    assert kwargs["session_id"] == "proj-1"


def test_add_run_error(trace_instance):
    """add_run wraps client failures in a ValueError."""
    run = LangSmithRunModel(id="run-1", name="test", run_type=LangSmithRunType.tool, start_time=_dt())
    trace_instance.langsmith_client.create_run.side_effect = Exception("failed")

    with pytest.raises(ValueError, match="LangSmith Failed to create run: failed"):
        trace_instance.add_run(run)


def test_update_run_success(trace_instance):
    """update_run forwards the update payload to the client once."""
    update = LangSmithRunUpdateModel(run_id="run-1", outputs={"out": "val"})

    trace_instance.update_run(update)

    trace_instance.langsmith_client.update_run.assert_called_once()


def test_update_run_error(trace_instance):
    """update_run wraps client failures in a ValueError."""
    update = LangSmithRunUpdateModel(run_id="run-1")
    trace_instance.langsmith_client.update_run.side_effect = Exception("failed")

    with pytest.raises(ValueError, match="LangSmith Failed to update run: failed"):
        trace_instance.update_run(update)
|
||||
|
||||
|
||||
def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, caplog):
    """A failure while reading an LLM node's `usage` is logged, not raised.

    The node's process_data raises on the "usage" key; workflow_trace must
    still complete and record the extraction error.
    """
    workflow_data = MagicMock()
    workflow_data.created_at = _dt()
    workflow_data.finished_at = _dt() + timedelta(seconds=1)

    trace_info = WorkflowTraceInfo(
        tenant_id="tenant-1",
        workflow_id="wf-1",
        workflow_run_id="run-1",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_status="succeeded",
        workflow_run_version="1.0",
        workflow_run_elapsed_time=1.0,
        total_tokens=100,
        file_list=[],
        query="hi",
        message_id="msg-1",
        conversation_id="conv-1",
        start_time=_dt(),
        end_time=_dt(),
        trace_id="trace-1",
        metadata={"app_id": "app-1"},
        workflow_app_log_id="log-1",
        error="",
        workflow_data=workflow_data,
    )

    class BadDict(collections.UserDict):
        # Raise only for the "usage" key so other process_data reads succeed.
        def get(self, key, default=None):
            if key == "usage":
                raise Exception("Usage extraction failed")
            return super().get(key, default)

    node_llm = MagicMock()
    node_llm.id = "node-llm"
    node_llm.title = "LLM Node"
    node_llm.node_type = NodeType.LLM
    node_llm.status = "succeeded"
    node_llm.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]})
    node_llm.inputs = {}
    node_llm.outputs = {}
    node_llm.created_at = _dt()
    node_llm.elapsed_time = 0.5
    node_llm.metadata = {}

    repo = MagicMock()
    repo.get_by_workflow_run.return_value = [node_llm]

    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

    trace_instance.add_run = MagicMock()

    # `logging` is imported at module level (previously a function-local import).
    with caplog.at_level(logging.ERROR):
        trace_instance.workflow_trace(trace_info)

    assert "Failed to extract usage" in caplog.text
|
||||
|
||||
|
||||
def test_api_check_success(trace_instance):
    """api_check creates and deletes a probe project, returning True."""
    assert trace_instance.api_check() is True
    assert trace_instance.langsmith_client.create_project.called
    assert trace_instance.langsmith_client.delete_project.called


def test_api_check_error(trace_instance):
    """api_check wraps client failures in a ValueError."""
    trace_instance.langsmith_client.create_project.side_effect = Exception("error")

    with pytest.raises(ValueError, match="LangSmith API check failed: error"):
        trace_instance.api_check()


def test_get_project_url_success(trace_instance):
    """get_project_url strips the run suffix from a sample run URL."""
    trace_instance.langsmith_client.get_run_url.return_value = "https://smith.langchain.com/o/org/p/proj/r/run"

    assert trace_instance.get_project_url() == "https://smith.langchain.com/o/org/p/proj"


def test_get_project_url_error(trace_instance):
    """get_project_url wraps client failures in a ValueError."""
    trace_instance.langsmith_client.get_run_url.side_effect = Exception("error")

    with pytest.raises(ValueError, match="LangSmith get run url failed: error"):
        trace_instance.get_project_url()
|
||||
1019
api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py
Normal file
1019
api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py
Normal file
File diff suppressed because it is too large
Load Diff
678
api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py
Normal file
678
api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py
Normal file
@@ -0,0 +1,678 @@
|
||||
import collections
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.config_entity import OpikConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from models import EndUser
|
||||
from models.enums import MessageStatus
|
||||
|
||||
|
||||
def _dt() -> datetime:
|
||||
return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
|
||||
|
||||
|
||||
@pytest.fixture
def opik_config():
    """Minimal Opik configuration pointing at the hosted API."""
    return OpikConfig(
        project="test-project", workspace="test-workspace", url="https://cloud.opik.com/api/", api_key="api-key-123"
    )
|
||||
|
||||
|
||||
@pytest.fixture
def trace_instance(opik_config, monkeypatch):
    """OpikDataTrace wired to a stubbed Opik client."""
    stub_client = MagicMock()
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: stub_client)
    return OpikDataTrace(opik_config)
|
||||
|
||||
|
||||
def test_wrap_dict():
    """wrap_dict passes dicts through unchanged and boxes scalars under the key."""
    assert wrap_dict("input", "hello") == {"input": "hello"}
    assert wrap_dict("input", {"a": 1}) == {"a": 1}


def test_wrap_metadata():
    """wrap_metadata merges keyword overrides and stamps the dify origin marker."""
    assert wrap_metadata({"a": 1}, b=2) == {"a": 1, "b": 2, "created_from": "dify"}
|
||||
|
||||
|
||||
def test_prepare_opik_uuid():
    """prepare_opik_uuid yields an id for both populated and empty inputs."""
    # Valid datetime plus an existing uuid string; the exact generated value
    # is not pinned — only that an id comes back.
    when = datetime(2024, 1, 1)
    raw_uuid = "b3e8e918-472e-4b69-8051-12502c34fc07"
    assert prepare_opik_uuid(when, raw_uuid) is not None

    # With both inputs missing it presumably falls back to a fresh id — only
    # non-None is asserted here.
    assert prepare_opik_uuid(None, None) is not None
|
||||
|
||||
|
||||
def test_init(opik_config, monkeypatch):
    """Construction wires the Opik client from the config and reads FILES_URL."""
    opik_cls = MagicMock()
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", opik_cls)
    monkeypatch.setenv("FILES_URL", "http://test.url")

    instance = OpikDataTrace(opik_config)

    opik_cls.assert_called_once_with(
        project_name=opik_config.project,
        workspace=opik_config.workspace,
        host=opik_config.url,
        api_key=opik_config.api_key,
    )
    assert instance.file_base_url == "http://test.url"
    assert instance.project == opik_config.project
|
||||
|
||||
|
||||
def test_trace_dispatch(trace_instance, monkeypatch):
    """trace() must route each TraceInfo subtype to its dedicated handler.

    The original test repeated the same three lines for every subtype;
    a dispatch table keeps the mapping visible and the assertions uniform.
    """
    dispatch_table = {
        WorkflowTraceInfo: "workflow_trace",
        MessageTraceInfo: "message_trace",
        ModerationTraceInfo: "moderation_trace",
        SuggestedQuestionTraceInfo: "suggested_question_trace",
        DatasetRetrievalTraceInfo: "dataset_retrieval_trace",
        ToolTraceInfo: "tool_trace",
        GenerateNameTraceInfo: "generate_name_trace",
    }
    mocks = {method: MagicMock() for method in dispatch_table.values()}
    for method, mock in mocks.items():
        monkeypatch.setattr(trace_instance, method, mock)

    for info_cls, method in dispatch_table.items():
        info = MagicMock(spec=info_cls)
        trace_instance.trace(info)
        # Exactly one call, with the very object that was dispatched.
        mocks[method].assert_called_once_with(info)
|
||||
|
||||
|
||||
def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
    """Workflow run with a message_id produces a message-level trace plus node spans.

    Fix: the original assertion read
    ``call_args[1].get("opik_trace_data", call_args[0][0])`` — Python evaluates
    the default argument eagerly, so if add_trace had been invoked with a
    keyword-only argument this raised IndexError on the empty positional tuple
    even though the keyword was present. The lookup now checks kwargs first.
    """
    # Define constants for better readability
    WORKFLOW_ID = "fb05c7cd-6cec-4add-8a84-df03a408b4ce"
    WORKFLOW_RUN_ID = "33c67568-7a8a-450e-8916-a5f135baeaef"
    MESSAGE_ID = "04ec3956-85f3-488a-8539-1017251dc8c6"
    CONVERSATION_ID = "d3d01066-23ae-4830-9ce4-eb5640b42a7e"
    TRACE_ID = "bf26d929-6f15-4c2f-9abc-761c217056f3"
    WORKFLOW_APP_LOG_ID = "ca0e018e-edd4-43fb-a05a-ea001ca8ef4b"
    LLM_NODE_ID = "80d7dfa8-08f4-4ab7-aa37-0ca7d27207e3"
    CODE_NODE_ID = "b9cd9a7b-c534-4aa9-b5da-efd454140900"

    trace_info = WorkflowTraceInfo(
        workflow_id=WORKFLOW_ID,
        tenant_id="tenant-1",
        workflow_run_id=WORKFLOW_RUN_ID,
        workflow_run_elapsed_time=1.0,
        workflow_run_status="succeeded",
        workflow_run_inputs={"input": "hi"},
        workflow_run_outputs={"output": "hello"},
        workflow_run_version="1.0",
        message_id=MESSAGE_ID,
        conversation_id=CONVERSATION_ID,
        total_tokens=100,
        file_list=[],
        query="hi",
        start_time=_dt(),
        end_time=_dt() + timedelta(seconds=1),
        trace_id=TRACE_ID,
        metadata={"app_id": "app-1", "user_id": "user-1"},
        workflow_app_log_id=WORKFLOW_APP_LOG_ID,
        error="",
    )

    mock_session = MagicMock()
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session)
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))

    # One LLM node with full usage data...
    node_llm = MagicMock()
    node_llm.id = LLM_NODE_ID
    node_llm.title = "LLM Node"
    node_llm.node_type = NodeType.LLM
    node_llm.status = "succeeded"
    node_llm.process_data = {
        "model_mode": "chat",
        "model_name": "gpt-4",
        "model_provider": "openai",
        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    }
    node_llm.inputs = {"prompts": "p"}
    node_llm.outputs = {"text": "t"}
    node_llm.created_at = _dt()
    node_llm.elapsed_time = 0.5
    node_llm.metadata = {"foo": "bar"}

    # ...and one failed non-LLM node with only token-count metadata.
    node_other = MagicMock()
    node_other.id = CODE_NODE_ID
    node_other.title = "Other Node"
    node_other.node_type = NodeType.CODE
    node_other.status = "failed"
    node_other.process_data = None
    node_other.inputs = {"code": "print"}
    node_other.outputs = {"result": "ok"}
    node_other.created_at = None
    node_other.elapsed_time = 0.2
    node_other.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS.value: 10}

    repo = MagicMock()
    repo.get_by_workflow_run.return_value = [node_llm, node_other]

    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)

    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

    trace_instance.add_trace = MagicMock()
    trace_instance.add_span = MagicMock()

    trace_instance.workflow_trace(trace_info)

    trace_instance.add_trace.assert_called_once()
    # Accept either a positional or a keyword invocation of add_trace without
    # eagerly indexing an empty positional tuple (the original bug).
    call_args, call_kwargs = trace_instance.add_trace.call_args
    trace_data = call_kwargs["opik_trace_data"] if "opik_trace_data" in call_kwargs else call_args[0]
    assert trace_data["name"] == TraceTaskName.MESSAGE_TRACE
    assert "message" in trace_data["tags"]
    assert "workflow" in trace_data["tags"]

    assert trace_instance.add_span.call_count >= 1
def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
    """A workflow run without a message_id still produces exactly one trace."""
    # Define constants for better readability
    WORKFLOW_ID = "f0708b36-b1d7-42b3-a876-1d01b7d8f1a3"
    WORKFLOW_RUN_ID = "d42ec285-c2fd-4248-8866-5c9386b101ac"
    CONVERSATION_ID = "88a17f2e-9436-4472-bab9-4b1601d5af3c"
    WORKFLOW_APP_LOG_ID = "41780d0d-ffba-4220-bc0c-401e4c89cdfb"

    trace_info = WorkflowTraceInfo(
        workflow_id=WORKFLOW_ID,
        tenant_id="tenant-1",
        workflow_run_id=WORKFLOW_RUN_ID,
        workflow_run_elapsed_time=1.0,
        workflow_run_status="succeeded",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_version="1.0",
        total_tokens=0,
        file_list=[],
        query="",
        message_id=None,
        conversation_id=CONVERSATION_ID,
        start_time=_dt(),
        end_time=_dt(),
        trace_id=None,
        metadata={"app_id": "app-1"},
        workflow_app_log_id=WORKFLOW_APP_LOG_ID,
        error="",
    )

    # Stub out DB session and node-execution repository (no nodes for this run).
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
    repo = MagicMock()
    repo.get_by_workflow_run.return_value = []
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

    trace_instance.add_trace = MagicMock()
    trace_instance.workflow_trace(trace_info)

    trace_instance.add_trace.assert_called_once()
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
    """workflow_trace must raise ValueError when metadata carries no app_id."""
    trace_info = WorkflowTraceInfo(
        workflow_id="5745f1b8-f8e6-4859-8110-996acb6c8d6a",
        tenant_id="tenant-1",
        workflow_run_id="46f53304-1659-464b-bee5-116585f0bec8",
        workflow_run_elapsed_time=1.0,
        workflow_run_status="succeeded",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_version="1.0",
        total_tokens=0,
        file_list=[],
        query="",
        message_id=None,
        conversation_id="83f86b89-caef-4de8-a0f9-f164eddae1ea",
        start_time=_dt(),
        end_time=_dt(),
        metadata={},  # deliberately empty: no app_id
        workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e",
        error="",
    )
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))

    with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
        trace_instance.workflow_trace(trace_info)
def test_message_trace_basic(trace_instance, monkeypatch):
    """A complete message trace creates one Opik trace and one span."""
    # Define constants for better readability
    MESSAGE_DATA_ID = "e3a26712-8cac-4a25-94a4-a3bff21ee3ab"
    CONVERSATION_ID = "9d3f3751-7521-4c19-9307-20e3cf6789a3"
    MESSAGE_TRACE_ID = "710ace2f-bca8-41be-858c-54da42742a77"
    OPIT_TRACE_ID = "f7dfd978-0d10-4549-8abf-00f2cbc49d2c"

    message_data = MagicMock()
    message_data.id = MESSAGE_DATA_ID
    message_data.from_account_id = "acc-1"
    message_data.from_end_user_id = None  # account-authored message, no end user lookup
    message_data.provider_response_latency = 0.5
    message_data.conversation_id = CONVERSATION_ID
    message_data.total_price = 0.01
    message_data.model_id = "gpt-4"
    message_data.answer = "hello"
    message_data.status = MessageStatus.NORMAL
    message_data.error = None

    trace_info = MessageTraceInfo(
        message_id=MESSAGE_TRACE_ID,
        message_data=message_data,
        inputs={"query": "hi"},
        outputs={"answer": "hello"},
        message_tokens=10,
        answer_tokens=20,
        total_tokens=30,
        start_time=_dt(),
        end_time=_dt() + timedelta(seconds=1),
        trace_id=OPIT_TRACE_ID,
        metadata={"foo": "bar"},
        conversation_mode="chat",
        conversation_model="gpt-4",
        file_list=[],
        error=None,
        message_file_data=MagicMock(url="test.png"),
    )

    trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_1"))
    trace_instance.add_span = MagicMock()

    trace_instance.message_trace(trace_info)

    trace_instance.add_trace.assert_called_once()
    trace_instance.add_span.assert_called_once()
def test_message_trace_with_end_user(trace_instance, monkeypatch):
    """End-user messages resolve the EndUser session id into trace metadata."""
    message_data = MagicMock()
    message_data.id = "85411059-79fb-4deb-a76c-c2e215f1b97e"
    message_data.from_account_id = "acc-1"
    message_data.from_end_user_id = "end-user-1"  # triggers the EndUser DB lookup
    message_data.conversation_id = "7d9f96d8-3be2-4e93-9c0e-922ff98dccc6"
    message_data.status = MessageStatus.NORMAL
    message_data.model_id = "gpt-4"
    message_data.error = ""
    message_data.answer = "hello"
    message_data.total_price = 0.0
    message_data.provider_response_latency = 0.1

    trace_info = MessageTraceInfo(
        message_id="6bff35c7-33b7-4acb-ba21-44569a0327d0",
        message_data=message_data,
        inputs={},
        outputs={},
        message_tokens=0,
        answer_tokens=0,
        total_tokens=0,
        start_time=_dt(),
        end_time=_dt(),
        metadata={},
        conversation_mode="chat",
        conversation_model="gpt-4",
        file_list=["url1"],
        error=None,
    )

    # Stub the EndUser query so the session id is resolved without a real DB.
    mock_end_user = MagicMock(spec=EndUser)
    mock_end_user.session_id = "session-id-123"

    mock_query = MagicMock()
    mock_query.where.return_value.first.return_value = mock_end_user
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.query", lambda model: mock_query)

    trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2"))
    trace_instance.add_span = MagicMock()

    trace_instance.message_trace(trace_info)

    trace_data = trace_instance.add_trace.call_args[0][0]
    assert trace_data["metadata"]["user_id"] == "acc-1"
    assert trace_data["metadata"]["end_user_id"] == "session-id-123"
def test_message_trace_none_data(trace_instance):
    """A trace info whose message_data is None must be ignored entirely."""
    stub_info = SimpleNamespace(
        message_data=None,
        file_list=[],
        message_file_data=None,
        metadata={},
    )
    add_trace_mock = MagicMock()
    trace_instance.add_trace = add_trace_mock

    trace_instance.message_trace(stub_info)

    add_trace_mock.assert_not_called()
def test_moderation_trace(trace_instance):
    """Moderation results are exported as a single span carrying the flag outcome."""
    message_data = MagicMock()
    message_data.created_at = _dt()
    message_data.updated_at = _dt()

    trace_info = ModerationTraceInfo(
        message_id="489d0dfd-065c-4106-8f9c-daded296c92d",
        message_data=message_data,
        inputs={"q": "hi"},
        action="stop",
        flagged=True,
        preset_response="blocked",
        start_time=None,  # None start/end: implementation must fall back to message timestamps
        end_time=None,
        metadata={"foo": "bar"},
        trace_id="6f16cf18-9f4b-4955-8b6b-43cfa10978fc",
        query="hi",
    )

    trace_instance.add_span = MagicMock()
    trace_instance.moderation_trace(trace_info)

    trace_instance.add_span.assert_called_once()
    span_data = trace_instance.add_span.call_args[0][0]
    assert span_data["name"] == TraceTaskName.MODERATION_TRACE
    assert span_data["output"]["flagged"] is True
def test_moderation_trace_none(trace_instance):
    """No moderation span is emitted when message_data is None."""
    trace_info = ModerationTraceInfo(
        message_id="cd732e4e-37f1-4c7e-8c64-820308bedcbf",
        message_data=None,
        inputs={},
        action="s",
        flagged=False,
        preset_response="",
        query="",
        metadata={},
    )
    trace_instance.add_span = MagicMock()
    trace_instance.moderation_trace(trace_info)
    trace_instance.add_span.assert_not_called()
def test_suggested_question_trace(trace_instance):
    """Suggested questions are exported as one span named after the task."""
    message_data = MagicMock()
    message_data.created_at = _dt()
    message_data.updated_at = _dt()

    trace_info = SuggestedQuestionTraceInfo(
        message_id="7de55bda-a91d-477e-98ab-85c53c438469",
        message_data=message_data,
        inputs="hi",
        suggested_question=["q1"],
        total_tokens=10,
        level="info",
        start_time=_dt(),
        end_time=_dt(),
        metadata={},
        trace_id="a6687292-68c7-42ba-ae51-285579944d7b",
    )

    trace_instance.add_span = MagicMock()
    trace_instance.suggested_question_trace(trace_info)

    trace_instance.add_span.assert_called_once()
    span_data = trace_instance.add_span.call_args[0][0]
    assert span_data["name"] == TraceTaskName.SUGGESTED_QUESTION_TRACE
def test_suggested_question_trace_none(trace_instance):
    """No suggested-question span is emitted when message_data is None."""
    trace_info = SuggestedQuestionTraceInfo(
        message_id="23696fc5-7e7f-46ec-bce8-1adc3c7f297d",
        message_data=None,
        inputs={},
        suggested_question=[],
        total_tokens=0,
        level="i",
        metadata={},
    )
    trace_instance.add_span = MagicMock()
    trace_instance.suggested_question_trace(trace_info)
    trace_instance.add_span.assert_not_called()
def test_dataset_retrieval_trace(trace_instance):
    """Dataset retrieval results are exported as one span named after the task."""
    message_data = MagicMock()
    message_data.created_at = _dt()
    message_data.updated_at = _dt()

    trace_info = DatasetRetrievalTraceInfo(
        message_id="3e1a819f-c391-4950-adfd-96f82e5419a1",
        message_data=message_data,
        inputs="query",
        documents=[{"id": "doc1"}],
        start_time=None,  # None start/end: implementation must fall back to message timestamps
        end_time=None,
        metadata={},
        trace_id="41361000-e9be-4d11-b5e4-ab27ce0817d6",
    )

    trace_instance.add_span = MagicMock()
    trace_instance.dataset_retrieval_trace(trace_info)

    trace_instance.add_span.assert_called_once()
    span_data = trace_instance.add_span.call_args[0][0]
    assert span_data["name"] == TraceTaskName.DATASET_RETRIEVAL_TRACE
def test_dataset_retrieval_trace_none(trace_instance):
    """No retrieval span is emitted when the trace carries no message payload."""
    empty_info = DatasetRetrievalTraceInfo(
        message_id="35d6d44c-bccb-4e6e-8bd8-859257723ea8",
        message_data=None,
        inputs={},
        documents=[],
        metadata={},
    )
    span_mock = MagicMock()
    trace_instance.add_span = span_mock

    trace_instance.dataset_retrieval_trace(empty_info)

    span_mock.assert_not_called()
def test_tool_trace(trace_instance):
    """Tool invocations are exported as a single span named after the tool."""
    trace_info = ToolTraceInfo(
        message_id="99db92c4-2254-496a-b5cc-18153315ce35",
        message_data=MagicMock(),
        inputs={},
        outputs={},
        tool_name="my_tool",
        tool_inputs={"a": 1},
        tool_outputs="result_string",
        time_cost=0.1,
        start_time=_dt(),
        end_time=_dt(),
        metadata={},
        trace_id="a15a5fcb-7ffd-4458-8330-208f4cb1f796",
        tool_config={},
        tool_parameters={},
        error="some error",  # error text must not prevent span creation
    )

    trace_instance.add_span = MagicMock()
    trace_instance.tool_trace(trace_info)

    trace_instance.add_span.assert_called_once()
    span_data = trace_instance.add_span.call_args[0][0]
    assert span_data["name"] == "my_tool"
def test_generate_name_trace(trace_instance):
    """Name generation emits a trace plus one span linked to the new trace id."""
    trace_info = GenerateNameTraceInfo(
        inputs={"q": "hi"},
        outputs={"name": "new"},
        tenant_id="tenant-1",
        conversation_id="271fe28f-6b86-416b-8d6b-bbbbfa9db791",
        start_time=_dt(),
        end_time=_dt(),
        metadata={"921f010e-6878-4831-ae6b-271bf68c56fb": 1},
    )

    trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_3"))
    trace_instance.add_span = MagicMock()

    trace_instance.generate_name_trace(trace_info)

    trace_instance.add_trace.assert_called_once()
    trace_instance.add_span.assert_called_once()

    trace_data = trace_instance.add_trace.call_args[0][0]
    assert trace_data["name"] == TraceTaskName.GENERATE_NAME_TRACE

    # The span must be parented under the trace returned by add_trace.
    span_data = trace_instance.add_span.call_args[0][0]
    assert span_data["trace_id"] == "trace_id_3"
def test_add_trace_success(trace_instance):
    """add_trace forwards the payload to the Opik client and returns the trace."""
    trace_instance.opik_client.trace.return_value = MagicMock(id="t1")

    created = trace_instance.add_trace({"id": "t1", "name": "trace"})

    trace_instance.opik_client.trace.assert_called_once()
    assert created.id == "t1"
def test_add_trace_error(trace_instance):
    """Client failures while creating a trace are wrapped in ValueError."""
    trace_instance.opik_client.trace.side_effect = Exception("error")
    payload = {"id": "t1", "name": "trace"}

    with pytest.raises(ValueError, match="Opik Failed to create trace: error"):
        trace_instance.add_trace(payload)
def test_add_span_success(trace_instance):
    """add_span relays the span payload to the Opik client."""
    payload = {"id": "s1", "name": "span", "trace_id": "t1"}

    trace_instance.add_span(payload)

    trace_instance.opik_client.span.assert_called_once()
def test_add_span_error(trace_instance):
    """Client failures while creating a span are wrapped in ValueError."""
    trace_instance.opik_client.span.side_effect = Exception("error")

    with pytest.raises(ValueError, match="Opik Failed to create span: error"):
        trace_instance.add_span({"id": "s1", "name": "span", "trace_id": "t1"})
def test_api_check_success(trace_instance):
    """api_check reports the client's auth_check result on success."""
    trace_instance.opik_client.auth_check.return_value = True

    result = trace_instance.api_check()

    assert result is True
def test_api_check_error(trace_instance):
    """api_check wraps client auth failures in ValueError."""
    trace_instance.opik_client.auth_check.side_effect = Exception("fail")

    with pytest.raises(ValueError, match="Opik API check failed: fail"):
        trace_instance.api_check()
def test_get_project_url_success(trace_instance):
    """get_project_url queries the client with the configured project name."""
    trace_instance.opik_client.get_project_url.return_value = "http://project.url"

    url = trace_instance.get_project_url()

    assert url == "http://project.url"
    trace_instance.opik_client.get_project_url.assert_called_once_with(project_name=trace_instance.project)
def test_get_project_url_error(trace_instance):
    """get_project_url wraps client failures in ValueError."""
    trace_instance.opik_client.get_project_url.side_effect = Exception("fail")

    with pytest.raises(ValueError, match="Opik get run url failed: fail"):
        trace_instance.get_project_url()
def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch, caplog):
    """An exception while reading LLM usage is logged and does not abort the trace."""
    trace_info = WorkflowTraceInfo(
        workflow_id="86a52565-4a6b-4a1b-9bfd-98e4595e70de",
        tenant_id="66e8e918-472e-4b69-8051-12502c34fc07",
        workflow_run_id="8403965c-3344-4d22-a8fe-d8d55cee64d9",
        workflow_run_elapsed_time=1.0,
        workflow_run_status="s",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_version="1",
        total_tokens=0,
        file_list=[],
        query="",
        message_id=None,
        conversation_id="7a02cb9d-6949-4c59-a89d-f25bbc881e0e",
        start_time=_dt(),
        end_time=_dt(),
        metadata={"app_id": "77e8e918-472e-4b69-8051-12502c34fc07"},
        workflow_app_log_id="82268424-e193-476c-a6db-f473388ee5fe",
        error="",
    )

    node = MagicMock()
    node.id = "88e8e918-472e-4b69-8051-12502c34fc07"
    node.title = "LLM Node"
    node.node_type = NodeType.LLM
    node.status = "succeeded"

    class BadDict(collections.UserDict):
        # Raise only on the "usage" key to simulate a corrupt usage payload
        # while keeping every other lookup working normally.
        def get(self, key, default=None):
            if key == "usage":
                raise Exception("Usage extraction failed")
            return super().get(key, default)

    node.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]})
    node.created_at = _dt()
    node.elapsed_time = 0.1
    node.metadata = {}
    node.outputs = {}

    repo = MagicMock()
    repo.get_by_workflow_run.return_value = [node]
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())

    trace_instance.add_trace = MagicMock()
    trace_instance.add_span = MagicMock()

    with caplog.at_level(logging.ERROR):
        trace_instance.workflow_trace(trace_info)

    assert "Failed to extract usage" in caplog.text
    assert trace_instance.add_span.call_count >= 1
    # Verify that at least one of the spans is for the LLM Node
    span_names = [call.args[0]["name"] for call in trace_instance.add_span.call_args_list]
    assert "LLM Node" in span_names
# ==== new file: api/tests/unit_tests/core/ops/tencent_trace/test_client.py (583 lines added) ====
|
||||
"""Tests for the TencentTraceClient helpers that drive tracing and metrics."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from opentelemetry.sdk.trace import Event
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
from core.ops.tencent_trace import client as client_module
|
||||
from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version
|
||||
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
|
||||
|
||||
# Module-level registries populated by the dummy classes below so tests can
# inspect every metric reader / meter provider the client constructed.
# (Forward references are legal here thanks to `from __future__ import annotations`.)
metric_reader_instances: list[DummyMetricReader] = []
meter_provider_instances: list[DummyMeterProvider] = []
||||
class DummyHistogram:
    """Stand-in for the SDK Histogram type; deliberately behaviour-free."""
||||
class AggregationTemporality:
    """Minimal stand-in exposing only the DELTA temporality constant."""

    DELTA = "delta"
||||
class DummyMeter:
    """Meter double that records every histogram-creation request."""

    def __init__(self) -> None:
        # Each entry pairs the kwargs passed to create_histogram with the
        # mock histogram handed back to the caller.
        self.created: list[tuple[dict[str, object], MagicMock]] = []

    def create_histogram(self, **kwargs: object) -> MagicMock:
        histogram = MagicMock(name=f"hist-{kwargs.get('name')}")
        self.created.append((kwargs, histogram))
        return histogram
||||
class DummyMeterProvider:
    """MeterProvider double that registers itself for later inspection."""

    def __init__(self, resource: object, metric_readers: list[object]) -> None:
        self.resource = resource
        self.metric_readers = metric_readers
        self.meter = DummyMeter()
        self.shutdown = MagicMock(name="meter_provider_shutdown")
        # Track every provider built so tests can assert on construction.
        meter_provider_instances.append(self)

    def get_meter(self, name: str, version: str) -> DummyMeter:
        # The same meter is handed out regardless of the requested name/version.
        return self.meter
||||
class DummyMetricReader:
    """PeriodicExportingMetricReader double that registers itself for inspection."""

    def __init__(self, exporter: object, export_interval_millis: int) -> None:
        self.exporter = exporter
        self.export_interval_millis = export_interval_millis
        self.shutdown = MagicMock(name="metric_reader_shutdown")
        # Track every reader built so tests can assert on construction.
        metric_reader_instances.append(self)
||||
class DummyGrpcMetricExporter:
    """gRPC metric-exporter double; retains constructor kwargs for assertions."""

    def __init__(self, **kwargs: object) -> None:
        self.kwargs = kwargs
||||
class DummyHttpMetricExporter:
    """HTTP metric-exporter double; retains constructor kwargs for assertions."""

    def __init__(self, **kwargs: object) -> None:
        self.kwargs = kwargs
||||
class DummyJsonMetricExporter:
    """JSON metric-exporter double; retains constructor kwargs for assertions."""

    def __init__(self, **kwargs: object) -> None:
        self.kwargs = kwargs
||||
class DummyJsonMetricExporterNoTemporality:
    """Exporter that rejects preferred_temporality to exercise fallback."""

    def __init__(self, **kwargs: object) -> None:
        # Simulate an older exporter API that does not accept the kwarg.
        if "preferred_temporality" in kwargs:
            raise RuntimeError("unsupported preferred_temporality")
        self.kwargs = kwargs
||||
def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None:
    """Drop fake metric modules into sys.modules so the client imports resolve."""

    # Core SDK metrics module exposing the histogram/provider doubles.
    metrics_module = types.ModuleType("opentelemetry.sdk.metrics")
    metrics_module.Histogram = DummyHistogram
    metrics_module.MeterProvider = DummyMeterProvider
    monkeypatch.setitem(sys.modules, "opentelemetry.sdk.metrics", metrics_module)

    metrics_export_module = types.ModuleType("opentelemetry.sdk.metrics.export")
    metrics_export_module.AggregationTemporality = AggregationTemporality
    metrics_export_module.PeriodicExportingMetricReader = DummyMetricReader
    monkeypatch.setitem(sys.modules, "opentelemetry.sdk.metrics.export", metrics_export_module)

    # Each transport variant (gRPC / HTTP-proto / HTTP-JSON / legacy JSON)
    # gets its own exporter double so protocol selection can be asserted.
    grpc_module = types.ModuleType("opentelemetry.exporter.otlp.proto.grpc.metric_exporter")
    grpc_module.OTLPMetricExporter = DummyGrpcMetricExporter
    monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.proto.grpc.metric_exporter", grpc_module)

    http_module = types.ModuleType("opentelemetry.exporter.otlp.proto.http.metric_exporter")
    http_module.OTLPMetricExporter = DummyHttpMetricExporter
    monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.proto.http.metric_exporter", http_module)

    http_json_module = types.ModuleType("opentelemetry.exporter.otlp.http.json.metric_exporter")
    http_json_module.OTLPMetricExporter = DummyJsonMetricExporter
    monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.http.json.metric_exporter", http_json_module)

    legacy_json_module = types.ModuleType("opentelemetry.exporter.otlp.json.metric_exporter")
    legacy_json_module.OTLPMetricExporter = DummyJsonMetricExporter
    monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.json.metric_exporter", legacy_json_module)
||||
@pytest.fixture(autouse=True)
def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None:
    """Reset the instance registries and install the fake metric modules.

    Autouse so every test starts with clean registries and stubbed imports.
    """
    metric_reader_instances.clear()
    meter_provider_instances.clear()
    _add_stub_modules(monkeypatch)
||||
@pytest.fixture(autouse=True)
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
    """Replace the client's OTel collaborators with mocks; return them for assertions.

    Returns a dict keyed by collaborator name so tests can reach the exact
    mock the client module will use.
    """
    span_exporter = MagicMock(name="span_exporter")
    monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter))

    span_processor = MagicMock(name="span_processor")
    monkeypatch.setattr(client_module, "BatchSpanProcessor", MagicMock(return_value=span_processor))

    tracer = MagicMock(name="tracer")
    span = MagicMock(name="span")
    tracer.start_span.return_value = span

    tracer_provider = MagicMock(name="tracer_provider")
    tracer_provider.get_tracer.return_value = tracer
    tracer_provider.shutdown = MagicMock(name="tracer_provider_shutdown")
    monkeypatch.setattr(client_module, "TracerProvider", MagicMock(return_value=tracer_provider))

    resource = MagicMock(name="resource")
    monkeypatch.setattr(client_module, "Resource", MagicMock(return_value=resource))

    logger_mock = MagicMock(name="tencent_logger")
    monkeypatch.setattr(client_module, "logger", logger_mock)

    # trace_api stub: set_span_in_context returns a fixed context marker and
    # NonRecordingSpan tags whatever context it wraps, so parenting is visible.
    trace_api_stub = SimpleNamespace(
        set_span_in_context=MagicMock(name="set_span_in_context", return_value="trace-context"),
        NonRecordingSpan=MagicMock(name="non_recording_span", side_effect=lambda ctx: f"non-{ctx}"),
    )
    monkeypatch.setattr(client_module, "trace_api", trace_api_stub)

    fake_config = SimpleNamespace(
        project=SimpleNamespace(version="test"),
        COMMIT_SHA="sha",
        DEPLOY_ENV="dev",
        EDITION="cloud",
    )
    monkeypatch.setattr(client_module, "dify_config", fake_config)

    monkeypatch.setattr(client_module.socket, "gethostname", lambda: "fake-host")
    monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "")

    return {
        "span_exporter": span_exporter,
        "span_processor": span_processor,
        "tracer": tracer,
        "span": span,
        "tracer_provider": tracer_provider,
        "logger": logger_mock,
        "trace_api": trace_api_stub,
    }
||||
def _build_client() -> TencentTraceClient:
    """Construct a TencentTraceClient wired to the patched test doubles."""
    client_kwargs = {
        "service_name": "service",
        "endpoint": "https://trace.example.com:4317",
        "token": "token",
    }
    return TencentTraceClient(**client_kwargs)
||||
def test_get_opentelemetry_sdk_version_reads_install(monkeypatch: pytest.MonkeyPatch) -> None:
    """The installed package version is returned when metadata lookup succeeds."""
    monkeypatch.setattr(client_module, "version", lambda pkg: "2.0.0")
    assert _get_opentelemetry_sdk_version() == "2.0.0"
||||
def test_get_opentelemetry_sdk_version_falls_back(monkeypatch: pytest.MonkeyPatch) -> None:
    """A metadata lookup failure falls back to the pinned default version."""
    monkeypatch.setattr(client_module, "version", MagicMock(side_effect=RuntimeError("boom")))
    assert _get_opentelemetry_sdk_version() == "1.27.0"
||||
@pytest.mark.parametrize(
    ("endpoint", "expected"),
    [
        # Explicit scheme + port: kept as-is, https => secure channel.
        (
            "https://example.com:9090",
            ("example.com:9090", False, "example.com", 9090),
        ),
        # Missing port: default 4317, http => insecure channel.
        (
            "http://localhost",
            ("localhost:4317", True, "localhost", 4317),
        ),
        # Unparsable port: falls back to 4317, host preserved.
        (
            "example.com:bad",
            ("example.com:4317", False, "example.com", 4317),
        ),
    ],
)
def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[str, bool, str, int]) -> None:
    """_resolve_grpc_target normalizes endpoints into (target, insecure, host, port)."""
    assert TencentTraceClient._resolve_grpc_target(endpoint) == expected
||||
def test_resolve_grpc_target_handles_errors() -> None:
    """Non-string input falls back to the default insecure localhost target."""
    assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317)
||||
@pytest.mark.parametrize(
    ("method", "attr_name", "args"),
    [
        ("record_llm_duration", "hist_llm_duration", (0.3, {"foo": object()})),
        ("record_token_usage", "hist_token_usage", (5, "input", "chat", "gpt", "gpt", "addr", "dify")),
        ("record_time_to_first_token", "hist_time_to_first_token", (0.4, "dify", "gpt")),
        ("record_time_to_generate", "hist_time_to_generate", (0.6, "dify", "gpt")),
        ("record_trace_duration", "hist_trace_duration", (1.0, {"meta": object()})),
    ],
)
def test_record_methods_call_histograms(method: str, attr_name: str, args: tuple[object, ...]) -> None:
    """Each record_* helper forwards to its backing histogram exactly once."""
    client = _build_client()
    hist_mock = MagicMock(name=attr_name)
    setattr(client, attr_name, hist_mock)

    getattr(client, method)(*args)
    hist_mock.record.assert_called_once()
||||
def test_record_methods_skip_when_histogram_missing() -> None:
    """record_* helpers must be silent no-ops when their histogram is None."""
    client = _build_client()
    client.hist_llm_duration = None
    client.record_llm_duration(0.1)

    client.hist_token_usage = None
    client.record_token_usage(1, "go", "chat", "model", "model", "addr", "provider")

    client.hist_time_to_first_token = None
    client.record_time_to_first_token(0.2, "prov", "model")

    client.hist_time_to_generate = None
    client.record_time_to_generate(0.3, "prov", "model")

    client.hist_trace_duration = None
    client.record_trace_duration(0.5)
||||
def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None:
    """Histogram record failures are swallowed and reported at debug level."""
    client = _build_client()
    failing_hist = MagicMock(name="hist_llm_duration")
    failing_hist.record.side_effect = RuntimeError("boom")
    client.hist_llm_duration = failing_hist

    # Must not raise despite the failing histogram.
    client.record_llm_duration(0.2)

    patch_core_components["logger"].debug.assert_called()
||||
def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None:
    """Exporting a root span sets attributes/events/status, ends it, and caches its context."""
    client = _build_client()
    span = patch_core_components["span"]
    span.get_span_context.return_value = "ctx"

    data = SpanData(
        trace_id=1,
        parent_span_id=None,
        span_id=2,
        name="span",
        attributes={"key": "value"},
        events=[Event(name="evt", attributes={"k": "v"}, timestamp=123)],
        status=Status(StatusCode.OK),
        start_time=10,
        end_time=20,
    )

    client._create_and_export_span(data)
    span.set_attributes.assert_called_once()
    span.add_event.assert_called_once()
    span.set_status.assert_called_once()
    span.end.assert_called_once_with(end_time=20)
    # The new span's context must be cached under its span_id for future children.
    assert client.span_contexts[2] == "ctx"


def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None:
    """A child span is started inside the cached parent context."""
    client = _build_client()
    client.span_contexts[10] = "existing"
    span = patch_core_components["span"]
    span.get_span_context.return_value = "child"

    data = SpanData(
        trace_id=1,
        parent_span_id=10,
        span_id=11,
        name="span",
        attributes={},
        events=[],
        start_time=0,
        end_time=1,
    )

    client._create_and_export_span(data)
    trace_api = patch_core_components["trace_api"]
    trace_api.NonRecordingSpan.assert_called_once_with("existing")
    trace_api.set_span_in_context.assert_called_once()


def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None:
    """A tracer failure while starting the span is caught and logged via logger.exception."""
    client = _build_client()
    span = patch_core_components["span"]
    span.get_span_context.return_value = "ctx"
    client.tracer.start_span.side_effect = RuntimeError("boom")

    client._create_and_export_span(
        SpanData(
            trace_id=1,
            parent_span_id=None,
            span_id=2,
            name="span",
            attributes={},
            events=[],
            start_time=0,
            end_time=1,
        )
    )
    logger = patch_core_components["logger"]
    logger.exception.assert_called_once()
|
||||
|
||||
|
||||
def test_api_check_connects_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
    """api_check() returns True when a TCP probe to the resolved target succeeds."""
    client = _build_client()

    monkeypatch.setattr(
        TencentTraceClient,
        "_resolve_grpc_target",
        MagicMock(return_value=("host:123", False, "host", 123)),
    )

    socket_mock = MagicMock()
    socket_instance = MagicMock()
    socket_instance.connect_ex.return_value = 0  # 0 == connection succeeded
    socket_mock.return_value = socket_instance
    monkeypatch.setattr(client_module.socket, "socket", socket_mock)

    assert client.api_check()
    socket_instance.connect_ex.assert_called_once()


def test_api_check_returns_false_and_handles_local(monkeypatch: pytest.MonkeyPatch) -> None:
    """A failed probe yields False for remote targets but True for local collectors."""
    client = _build_client()

    monkeypatch.setattr(
        TencentTraceClient,
        "_resolve_grpc_target",
        MagicMock(return_value=("host:123", False, "host", 123)),
    )

    socket_mock = MagicMock()
    socket_instance = MagicMock()
    socket_instance.connect_ex.return_value = 1  # non-zero == connection failed
    socket_mock.return_value = socket_instance
    monkeypatch.setattr(client_module.socket, "socket", socket_mock)

    assert not client.api_check()

    # A local target is treated as reachable even when the probe fails.
    monkeypatch.setattr(
        TencentTraceClient,
        "_resolve_grpc_target",
        MagicMock(return_value=("localhost:4317", True, "localhost", 4317)),
    )
    socket_instance.connect_ex.return_value = 1
    assert client.api_check()


def test_api_check_handles_exceptions(monkeypatch: pytest.MonkeyPatch) -> None:
    """An unexpected socket error does not fail the check (best-effort probe)."""
    client = TencentTraceClient("svc", "https://localhost", "token")

    monkeypatch.setattr(client_module.socket, "socket", MagicMock(side_effect=RuntimeError("boom")))
    assert client.api_check()
|
||||
|
||||
|
||||
def test_get_project_url() -> None:
    """The project URL points at the Tencent APM console."""
    client = _build_client()
    assert client.get_project_url() == "https://console.cloud.tencent.com/apm"


def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None:
    """shutdown() flushes and shuts down tracing and metrics pipelines."""
    client = _build_client()
    span_processor = patch_core_components["span_processor"]
    tracer_provider = patch_core_components["tracer_provider"]

    client.shutdown()
    span_processor.force_flush.assert_called_once()
    span_processor.shutdown.assert_called_once()
    tracer_provider.shutdown.assert_called_once()

    # Metrics components are module-level test doubles recorded at construction time.
    meter_provider = meter_provider_instances[-1]
    metric_reader = metric_reader_instances[-1]
    meter_provider.shutdown.assert_called_once()
    metric_reader.shutdown.assert_called_once()


def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None:
    """Failures while shutting down metrics components are logged, not raised."""
    client = _build_client()
    meter_provider = meter_provider_instances[-1]
    meter_provider.shutdown.side_effect = RuntimeError("boom")
    client.metric_reader.shutdown.side_effect = RuntimeError("boom")

    client.shutdown()
    logger = patch_core_components["logger"]
    logger.debug.assert_any_call(
        "[Tencent APM] Error shutting down meter provider",
        exc_info=True,
    )
    logger.debug.assert_any_call(
        "[Tencent APM] Error shutting down metric reader",
        exc_info=True,
    )
|
||||
|
||||
|
||||
def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: pytest.MonkeyPatch) -> None:
    """If metrics setup fails, all metric attributes are defined but None."""
    monkeypatch.setattr(DummyMeterProvider, "__init__", MagicMock(side_effect=RuntimeError("err")))
    client = _build_client()

    assert client.meter is None
    assert client.meter_provider is None
    assert client.hist_llm_duration is None
    assert client.hist_token_usage is None
    assert client.hist_time_to_first_token is None
    assert client.hist_time_to_generate is None
    assert client.hist_trace_duration is None
    assert client.metric_reader is None


def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None:
    """add_span() catches export failures and logs them via logger.exception."""
    client = _build_client()
    monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))

    client.add_span(
        SpanData(
            trace_id=1,
            parent_span_id=None,
            span_id=2,
            name="span",
            attributes={},
            events=[],
            start_time=0,
            end_time=1,
        )
    )

    logger = patch_core_components["logger"]
    logger.exception.assert_called_once()
|
||||
|
||||
|
||||
def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None:
    """Primitive attribute values (int/bool/float/str) pass through unchanged."""
    client = _build_client()
    span = patch_core_components["span"]
    span.get_span_context.return_value = "ctx"

    # model_construct skips validation so links=[] can be injected directly.
    data = SpanData.model_construct(
        trace_id=1,
        parent_span_id=None,
        span_id=2,
        name="span",
        attributes={"num": 5, "flag": True, "pi": 3.14, "text": "value"},
        events=[],
        links=[],
        status=Status(StatusCode.OK),
        start_time=0,
        end_time=1,
    )

    client._create_and_export_span(data)
    (attrs,) = span.set_attributes.call_args.args
    assert attrs["num"] == 5
    assert attrs["flag"] is True
    assert attrs["pi"] == 3.14
    assert attrs["text"] == "value"


def test_record_llm_duration_converts_attributes() -> None:
    """Non-primitive metric attributes are stringified; primitives are kept."""
    client = _build_client()
    hist_mock = MagicMock(name="hist_llm_duration")
    client.hist_llm_duration = hist_mock

    client.record_llm_duration(0.3, {"foo": object(), "bar": 2})
    _, attrs = hist_mock.record.call_args.args
    assert isinstance(attrs["foo"], str)
    assert attrs["bar"] == 2


def test_record_trace_duration_converts_attributes() -> None:
    """Non-primitive trace-duration attributes are stringified; primitives are kept."""
    client = _build_client()
    hist_mock = MagicMock(name="hist_trace_duration")
    client.hist_trace_duration = hist_mock

    client.record_trace_duration(1.0, {"meta": object(), "ok": True})
    _, attrs = hist_mock.record.call_args.args
    assert isinstance(attrs["meta"], str)
    assert attrs["ok"] is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("method", "attr_name", "args"),
    [
        ("record_token_usage", "hist_token_usage", (5, "input", "chat", "gpt", "gpt", "addr", "dify")),
        ("record_time_to_first_token", "hist_time_to_first_token", (0.4, "dify", "gpt")),
        ("record_time_to_generate", "hist_time_to_generate", (0.6, "dify", "gpt")),
        ("record_trace_duration", "hist_trace_duration", (1.0, {"meta": object()})),
    ],
)
def test_record_methods_handle_exceptions(
    method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object]
) -> None:
    """Every record_* method swallows histogram failures and logs at debug level."""
    client = _build_client()
    hist_mock = MagicMock(name=attr_name)
    hist_mock.record.side_effect = RuntimeError("boom")
    setattr(client, attr_name, hist_mock)

    getattr(client, method)(*args)
    logger = patch_core_components["logger"]
    logger.debug.assert_called()
|
||||
|
||||
|
||||
def test_metrics_initializes_grpc_metric_exporter() -> None:
    """By default the client wires up the gRPC OTLP metric exporter."""
    client = _build_client()
    metric_reader = metric_reader_instances[-1]

    assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter)
    assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
    assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317"
    assert metric_reader.exporter.kwargs["insecure"] is False
    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"


def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
    """OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf selects the HTTP exporter."""
    monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
    client = _build_client()
    metric_reader = metric_reader_instances[-1]

    assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
    assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
    assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"


def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
    """OTEL_EXPORTER_OTLP_PROTOCOL=http/json selects the JSON exporter with temporality."""
    monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
    client = _build_client()
    metric_reader = metric_reader_instances[-1]

    assert isinstance(metric_reader.exporter, DummyJsonMetricExporter)
    assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
    assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
    assert "preferred_temporality" in metric_reader.exporter.kwargs


def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
    """A JSON exporter that rejects preferred_temporality is constructed without it."""
    monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
    exporter_module = sys.modules["opentelemetry.exporter.otlp.http.json.metric_exporter"]
    monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
    _ = _build_client()
    metric_reader = metric_reader_instances[-1]

    assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality)
    assert "preferred_temporality" not in metric_reader.exporter.kwargs


def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
    """When the JSON exporter module cannot be imported, the HTTP exporter is used."""
    monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")

    def _fail_import(mod_path: str) -> types.ModuleType:
        raise ModuleNotFoundError(mod_path)

    monkeypatch.setattr(client_module.importlib, "import_module", _fail_import)

    _ = _build_client()
    metric_reader = metric_reader_instances[-1]
    assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
|
||||
359
api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py
Normal file
359
api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py
Normal file
@@ -0,0 +1,359 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from opentelemetry.trace import StatusCode
|
||||
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.tencent_trace.entities.semconv import (
|
||||
GEN_AI_IS_ENTRY,
|
||||
GEN_AI_IS_STREAMING_REQUEST,
|
||||
GEN_AI_MODEL_NAME,
|
||||
GEN_AI_SPAN_KIND,
|
||||
GEN_AI_USAGE_INPUT_TOKENS,
|
||||
INPUT_VALUE,
|
||||
RETRIEVAL_DOCUMENT,
|
||||
RETRIEVAL_QUERY,
|
||||
TOOL_DESCRIPTION,
|
||||
TOOL_NAME,
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.ops.tencent_trace.span_builder import TencentSpanBuilder
|
||||
from core.rag.models.document import Document
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class TestTencentSpanBuilder:
    """Unit tests for TencentSpanBuilder span construction helpers."""

    def test_get_time_nanoseconds(self):
        """_get_time_nanoseconds delegates to TencentTraceUtils.convert_datetime_to_nanoseconds."""
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert:
            mock_convert.return_value = 123456789
            dt = datetime.now()
            result = TencentSpanBuilder._get_time_nanoseconds(dt)
            assert result == 123456789
            mock_convert.assert_called_once_with(dt)

    def test_build_workflow_spans(self):
        """With a conversation_id, a message span is emitted as parent of the workflow span."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.workflow_run_id = "run_id"
        trace_info.error = None
        trace_info.start_time = datetime.now()
        trace_info.end_time = datetime.now()
        trace_info.workflow_run_inputs = {"sys.query": "hello"}
        trace_info.workflow_run_outputs = {"answer": "world"}
        trace_info.metadata = {"conversation_id": "conv_id"}

        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.side_effect = [1, 2]  # workflow_span_id, message_span_id
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1")

        assert len(spans) == 2
        assert spans[0].name == "message"
        assert spans[0].span_id == 2
        assert spans[1].name == "workflow"
        assert spans[1].span_id == 1
        assert spans[1].parent_span_id == 2

    def test_build_workflow_spans_no_message(self):
        """Without a conversation_id only the workflow span is built, carrying error status."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.workflow_run_id = "run_id"
        trace_info.error = "some error"
        trace_info.start_time = datetime.now()
        trace_info.end_time = datetime.now()
        trace_info.workflow_run_inputs = {}
        trace_info.workflow_run_outputs = {}
        trace_info.metadata = {}  # No conversation_id

        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 1
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1")

        assert len(spans) == 1
        assert spans[0].name == "workflow"
        assert spans[0].status.status_code == StatusCode.ERROR
        assert spans[0].status.description == "some error"
        assert spans[0].attributes[GEN_AI_IS_ENTRY] == "true"

    def test_build_workflow_llm_span(self):
        """LLM span picks up model info and usage from process_data."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {"conversation_id": "conv_id"}

        node_execution = MagicMock(spec=WorkflowNodeExecution)
        node_execution.id = "node_id"
        node_execution.created_at = datetime.now()
        node_execution.finished_at = datetime.now()
        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        node_execution.process_data = {
            "model_name": "gpt-4",
            "model_provider": "openai",
            "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, "time_to_first_token": 0.5},
            "prompts": ["hello"],
        }
        node_execution.outputs = {"text": "world"}

        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 456
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution)

        assert span.name == "GENERATION"
        assert span.attributes[GEN_AI_MODEL_NAME] == "gpt-4"
        assert span.attributes[GEN_AI_IS_STREAMING_REQUEST] == "true"
        assert span.attributes[GEN_AI_USAGE_INPUT_TOKENS] == "10"

    def test_build_workflow_llm_span_usage_in_outputs(self):
        """Usage falls back to outputs when absent from process_data; no streaming flag then."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {}

        node_execution = MagicMock(spec=WorkflowNodeExecution)
        node_execution.id = "node_id"
        node_execution.created_at = datetime.now()
        node_execution.finished_at = datetime.now()
        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        node_execution.process_data = {}
        node_execution.outputs = {
            "text": "world",
            "usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40},
        }

        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 456
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution)

        assert span.attributes[GEN_AI_USAGE_INPUT_TOKENS] == "15"
        assert GEN_AI_IS_STREAMING_REQUEST not in span.attributes

    def test_build_message_span_standalone(self):
        """Standalone message span records streaming flag and stringified inputs."""
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.message_id = "msg_id"
        trace_info.error = None
        trace_info.start_time = datetime.now()
        trace_info.end_time = datetime.now()
        trace_info.inputs = {"q": "hi"}
        trace_info.outputs = "hello"
        trace_info.metadata = {"conversation_id": "conv_id"}
        trace_info.is_streaming_request = True

        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 789
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1")

        assert span.name == "message"
        assert span.attributes[GEN_AI_IS_STREAMING_REQUEST] == "true"
        assert span.attributes[INPUT_VALUE] == str(trace_info.inputs)

    def test_build_message_span_standalone_with_error(self):
        """Errors produce ERROR status; None inputs become an empty INPUT_VALUE."""
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.message_id = "msg_id"
        trace_info.error = "some error"
        trace_info.start_time = datetime.now()
        trace_info.end_time = datetime.now()
        trace_info.inputs = None
        trace_info.outputs = None
        trace_info.metadata = {}
        trace_info.is_streaming_request = False

        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 789
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1")

        assert span.status.status_code == StatusCode.ERROR
        assert span.status.description == "some error"
        assert span.attributes[INPUT_VALUE] == ""

    def test_build_tool_span(self):
        """Tool span is named after the tool and carries error status when present."""
        trace_info = MagicMock(spec=ToolTraceInfo)
        trace_info.message_id = "msg_id"
        trace_info.tool_name = "search"
        trace_info.error = "tool error"
        trace_info.start_time = datetime.now()
        trace_info.end_time = datetime.now()
        trace_info.tool_parameters = {"p": 1}
        trace_info.tool_inputs = {"i": 2}
        trace_info.tool_outputs = "result"

        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 101
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1)

        assert span.name == "search"
        assert span.status.status_code == StatusCode.ERROR
        assert span.attributes[TOOL_NAME] == "search"

    def test_build_retrieval_span(self):
        """Retrieval span records the query and serialized documents."""
        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
        trace_info.message_id = "msg_id"
        trace_info.inputs = "query"
        trace_info.error = None
        trace_info.start_time = datetime.now()
        trace_info.end_time = datetime.now()

        doc = Document(
            page_content="content", metadata={"dataset_id": "d1", "doc_id": "di1", "document_id": "du1", "score": 0.9}
        )
        trace_info.documents = [doc]

        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 202
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)

        assert span.name == "retrieval"
        assert span.attributes[RETRIEVAL_QUERY] == "query"
        assert "content" in span.attributes[RETRIEVAL_DOCUMENT]

    def test_build_retrieval_span_with_error(self):
        """A retrieval error becomes ERROR status with the error message as description."""
        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
        trace_info.message_id = "msg_id"
        trace_info.inputs = ""
        trace_info.error = "retrieval failed"
        trace_info.start_time = datetime.now()
        trace_info.end_time = datetime.now()
        trace_info.documents = []

        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 202
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)

        assert span.status.status_code == StatusCode.ERROR
        assert span.status.description == "retrieval failed"

    def test_get_workflow_node_status(self):
        """Node execution statuses map to OK / ERROR(+description) / UNSET."""
        node = MagicMock(spec=WorkflowNodeExecution)

        node.status = WorkflowNodeExecutionStatus.SUCCEEDED
        assert TencentSpanBuilder._get_workflow_node_status(node).status_code == StatusCode.OK

        node.status = WorkflowNodeExecutionStatus.FAILED
        node.error = "fail"
        status = TencentSpanBuilder._get_workflow_node_status(node)
        assert status.status_code == StatusCode.ERROR
        assert status.description == "fail"

        node.status = WorkflowNodeExecutionStatus.EXCEPTION
        node.error = "exc"
        status = TencentSpanBuilder._get_workflow_node_status(node)
        assert status.status_code == StatusCode.ERROR
        assert status.description == "exc"

        node.status = WorkflowNodeExecutionStatus.RUNNING
        assert TencentSpanBuilder._get_workflow_node_status(node).status_code == StatusCode.UNSET

    def test_build_workflow_retrieval_span(self):
        """Workflow retrieval span uses the node title, query input, and result contents."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {"conversation_id": "conv_id"}

        node_execution = MagicMock(spec=WorkflowNodeExecution)
        node_execution.id = "node_id"
        node_execution.title = "my retrieval"
        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        node_execution.inputs = {"query": "q1"}
        node_execution.outputs = {"result": [{"content": "c1"}]}
        node_execution.created_at = datetime.now()
        node_execution.finished_at = datetime.now()

        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 303
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)

        assert span.name == "my retrieval"
        assert span.attributes[RETRIEVAL_QUERY] == "q1"
        assert "c1" in span.attributes[RETRIEVAL_DOCUMENT]

    def test_build_workflow_retrieval_span_empty(self):
        """Empty inputs/outputs fall back to empty query and document attributes."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {}

        node_execution = MagicMock(spec=WorkflowNodeExecution)
        node_execution.id = "node_id"
        node_execution.title = "my retrieval"
        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        node_execution.inputs = {}
        node_execution.outputs = {}
        node_execution.created_at = datetime.now()
        node_execution.finished_at = datetime.now()

        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 303
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)

        assert span.attributes[RETRIEVAL_QUERY] == ""
        assert span.attributes[RETRIEVAL_DOCUMENT] == ""

    def test_build_workflow_tool_span(self):
        """Workflow tool span is named after the node and serializes TOOL_INFO metadata."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)

        node_execution = MagicMock(spec=WorkflowNodeExecution)
        node_execution.id = "node_id"
        node_execution.title = "my tool"
        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        node_execution.metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"info": "some"}}
        node_execution.inputs = {"param": "val"}
        node_execution.outputs = {"res": "ok"}
        node_execution.created_at = datetime.now()
        node_execution.finished_at = datetime.now()

        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 404
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)

        assert span.name == "my tool"
        assert span.attributes[TOOL_NAME] == "my tool"
        assert "some" in span.attributes[TOOL_DESCRIPTION]

    def test_build_workflow_tool_span_no_metadata(self):
        """Missing metadata/inputs degrade to empty-dict JSON for description/parameters."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)

        node_execution = MagicMock(spec=WorkflowNodeExecution)
        node_execution.id = "node_id"
        node_execution.title = "my tool"
        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        node_execution.metadata = None
        node_execution.inputs = None
        node_execution.outputs = {"res": "ok"}
        node_execution.created_at = datetime.now()
        node_execution.finished_at = datetime.now()

        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 404
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)

        assert span.attributes[TOOL_DESCRIPTION] == "{}"
        assert span.attributes[TOOL_PARAMETERS] == "{}"

    def test_build_workflow_task_span(self):
        """Generic task span uses the node title and the TASK span kind."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {"conversation_id": "conv_id"}

        node_execution = MagicMock(spec=WorkflowNodeExecution)
        node_execution.id = "node_id"
        node_execution.title = "my task"
        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        node_execution.inputs = {"in": 1}
        node_execution.outputs = {"out": 2}
        node_execution.created_at = datetime.now()
        node_execution.finished_at = datetime.now()

        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 505
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution)

        assert span.name == "my task"
        assert span.attributes[GEN_AI_SPAN_KIND] == GenAISpanKind.TASK.value
|
||||
@@ -0,0 +1,647 @@
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.entities.config_entity import TencentConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
|
||||
from dify_graph.entities import WorkflowNodeExecution
|
||||
from dify_graph.enums import NodeType
|
||||
from models import Account, App, TenantAccountJoin
|
||||
|
||||
logger = logging.getLogger(__name__)


@pytest.fixture
def tencent_config():
    """A minimal valid TencentConfig for constructing the trace instance."""
    return TencentConfig(service_name="test-service", endpoint="https://test-endpoint", token="test-token")


@pytest.fixture
def mock_trace_client():
    """Patch the TencentTraceClient class used by TencentDataTrace."""
    with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock:
        yield mock


@pytest.fixture
def mock_span_builder():
    """Patch the TencentSpanBuilder class used by TencentDataTrace."""
    with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock:
        yield mock


@pytest.fixture
def mock_trace_utils():
    """Patch the TencentTraceUtils helper used by TencentDataTrace."""
    with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock:
        yield mock


@pytest.fixture
def tencent_data_trace(tencent_config, mock_trace_client):
    """A TencentDataTrace wired against the mocked trace client."""
    return TencentDataTrace(tencent_config)
|
||||
|
||||
|
||||
class TestTencentDataTrace:
    """Unit tests for TencentDataTrace.

    The trace client, span builder, and trace-utils helpers are all patched
    (see module fixtures), so these tests assert on call wiring — which
    collaborator is invoked, with which arguments — rather than real export.
    """

    def test_init(self, tencent_config, mock_trace_client):
        # Constructor forwards config fields into the trace client.
        trace = TencentDataTrace(tencent_config)
        mock_trace_client.assert_called_once_with(
            service_name=tencent_config.service_name,
            endpoint=tencent_config.endpoint,
            token=tencent_config.token,
            metrics_export_interval_sec=5,
        )
        assert trace.trace_client == mock_trace_client.return_value

    def test_trace_dispatch(self, tencent_data_trace):
        # Each trace-info type must be dispatched to its handler method;
        # entries with method_name=None are expected to be silently ignored.
        methods = [
            (
                WorkflowTraceInfo(
                    workflow_id="wf",
                    tenant_id="t",
                    workflow_run_id="run",
                    workflow_run_elapsed_time=1.0,
                    workflow_run_status="s",
                    workflow_run_inputs={},
                    workflow_run_outputs={},
                    workflow_run_version="v",
                    total_tokens=0,
                    file_list=[],
                    query="",
                    metadata={},
                ),
                "workflow_trace",
            ),
            (
                MessageTraceInfo(
                    message_id="msg",
                    message_data={},
                    inputs={},
                    outputs={},
                    start_time=None,
                    end_time=None,
                    conversation_mode="chat",
                    conversation_model="gpt-3.5-turbo",
                    message_tokens=0,
                    answer_tokens=0,
                    total_tokens=0,
                    metadata={},
                ),
                "message_trace",
            ),
            (
                ModerationTraceInfo(
                    flagged=False, action="a", preset_response="p", query="q", metadata={}, message_id="m"
                ),
                None,
            ),  # Pass
            (
                SuggestedQuestionTraceInfo(
                    suggested_question=[],
                    level="l",
                    total_tokens=0,
                    metadata={},
                    message_id="m",
                    message_data={},
                    inputs={},
                    start_time=None,
                    end_time=None,
                ),
                "suggested_question_trace",
            ),
            (
                DatasetRetrievalTraceInfo(
                    metadata={},
                    message_id="m",
                    message_data={},
                    inputs={},
                    documents=[],
                    start_time=None,
                    end_time=None,
                ),
                "dataset_retrieval_trace",
            ),
            (
                ToolTraceInfo(
                    tool_name="t",
                    tool_inputs={},
                    tool_outputs="",
                    tool_config={},
                    tool_parameters={},
                    time_cost=0,
                    metadata={},
                    message_id="m",
                    inputs={},
                    outputs={},
                    start_time=None,
                    end_time=None,
                ),
                "tool_trace",
            ),
            (
                GenerateNameTraceInfo(
                    tenant_id="t", metadata={}, message_id="m", inputs={}, outputs={}, start_time=None, end_time=None
                ),
                None,
            ),  # Pass
        ]

        for trace_info, method_name in methods:
            if method_name:
                with patch.object(tencent_data_trace, method_name) as mock_method:
                    tencent_data_trace.trace(trace_info)
                    mock_method.assert_called_once_with(trace_info)
            else:
                # No handler configured: dispatch must not raise.
                tencent_data_trace.trace(trace_info)

    def test_api_check(self, tencent_data_trace):
        # api_check delegates straight to the trace client.
        tencent_data_trace.trace_client.api_check.return_value = True
        assert tencent_data_trace.api_check() is True
        tencent_data_trace.trace_client.api_check.assert_called_once()

    def test_get_project_url(self, tencent_data_trace):
        # get_project_url delegates straight to the trace client.
        tencent_data_trace.trace_client.get_project_url.return_value = "http://url"
        assert tencent_data_trace.get_project_url() == "http://url"
        tencent_data_trace.trace_client.get_project_url.assert_called_once()

    def test_workflow_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
        # Happy path: trace id is derived from the workflow run id, a link is
        # built from the external trace id, and every built span is exported.
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.workflow_run_id = "run-id"
        trace_info.trace_id = "parent-trace-id"

        mock_trace_utils.convert_to_trace_id.return_value = 123
        mock_trace_utils.create_link.return_value = "link"

        with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"):
            with patch.object(tencent_data_trace, "_process_workflow_nodes") as mock_proc:
                with patch.object(tencent_data_trace, "_record_workflow_trace_duration") as mock_dur:
                    mock_span_builder.build_workflow_spans.return_value = [MagicMock(), MagicMock()]

                    tencent_data_trace.workflow_trace(trace_info)

                    mock_trace_utils.convert_to_trace_id.assert_called_once_with("run-id")
                    mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
                    mock_span_builder.build_workflow_spans.assert_called_once()
                    # One add_span per span returned by the builder.
                    assert tencent_data_trace.trace_client.add_span.call_count == 2
                    mock_proc.assert_called_once_with(trace_info, 123)
                    mock_dur.assert_called_once_with(trace_info)

    def test_workflow_trace_exception(self, tencent_data_trace):
        # Failures inside workflow_trace are swallowed and logged, not raised.
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.workflow_run_id = "run-id"

        with patch(
            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
        ):
            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                tencent_data_trace.workflow_trace(trace_info)
                mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace")

    def test_message_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
        # Happy path mirrors workflow_trace, but builds a single message span
        # and records LLM metrics plus message duration.
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.message_id = "msg-id"
        trace_info.trace_id = "parent-trace-id"

        mock_trace_utils.convert_to_trace_id.return_value = 123
        mock_trace_utils.create_link.return_value = "link"

        with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"):
            with patch.object(tencent_data_trace, "_record_message_llm_metrics") as mock_metrics:
                with patch.object(tencent_data_trace, "_record_message_trace_duration") as mock_dur:
                    mock_span_builder.build_message_span.return_value = MagicMock()

                    tencent_data_trace.message_trace(trace_info)

                    mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
                    mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
                    mock_span_builder.build_message_span.assert_called_once()
                    tencent_data_trace.trace_client.add_span.assert_called_once()
                    mock_metrics.assert_called_once_with(trace_info)
                    mock_dur.assert_called_once_with(trace_info)

    def test_message_trace_exception(self, tencent_data_trace):
        # Failures inside message_trace are swallowed and logged, not raised.
        trace_info = MagicMock(spec=MessageTraceInfo)

        with patch(
            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
        ):
            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                tencent_data_trace.message_trace(trace_info)
                mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace")

    def test_tool_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
        # Tool spans attach under the message span ("message" span type).
        trace_info = MagicMock(spec=ToolTraceInfo)
        trace_info.message_id = "msg-id"

        mock_trace_utils.convert_to_span_id.return_value = 456
        mock_trace_utils.convert_to_trace_id.return_value = 123

        tencent_data_trace.tool_trace(trace_info)

        mock_trace_utils.convert_to_span_id.assert_called_once_with("msg-id", "message")
        mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
        mock_span_builder.build_tool_span.assert_called_once_with(trace_info, 123, 456)
        tencent_data_trace.trace_client.add_span.assert_called_once()

    def test_tool_trace_no_msg_id(self, tencent_data_trace):
        # Without a message id there is nothing to attach to: no span exported.
        trace_info = MagicMock(spec=ToolTraceInfo)
        trace_info.message_id = None

        tencent_data_trace.tool_trace(trace_info)
        tencent_data_trace.trace_client.add_span.assert_not_called()

    def test_tool_trace_exception(self, tencent_data_trace):
        # Failures inside tool_trace are swallowed and logged, not raised.
        trace_info = MagicMock(spec=ToolTraceInfo)
        trace_info.message_id = "msg-id"

        with patch(
            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
        ):
            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                tencent_data_trace.tool_trace(trace_info)
                mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace")

    def test_dataset_retrieval_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
        # Retrieval spans also attach under the message span.
        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
        trace_info.message_id = "msg-id"

        mock_trace_utils.convert_to_span_id.return_value = 456
        mock_trace_utils.convert_to_trace_id.return_value = 123

        tencent_data_trace.dataset_retrieval_trace(trace_info)

        mock_trace_utils.convert_to_span_id.assert_called_once_with("msg-id", "message")
        mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
        mock_span_builder.build_retrieval_span.assert_called_once_with(trace_info, 123, 456)
        tencent_data_trace.trace_client.add_span.assert_called_once()

    def test_dataset_retrieval_trace_no_msg_id(self, tencent_data_trace):
        # Without a message id, no retrieval span is exported.
        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
        trace_info.message_id = None

        tencent_data_trace.dataset_retrieval_trace(trace_info)
        tencent_data_trace.trace_client.add_span.assert_not_called()

    def test_dataset_retrieval_trace_exception(self, tencent_data_trace):
        # Failures inside dataset_retrieval_trace are swallowed and logged.
        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
        trace_info.message_id = "msg-id"

        with patch(
            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
        ):
            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                tencent_data_trace.dataset_retrieval_trace(trace_info)
                mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace")

    def test_suggested_question_trace(self, tencent_data_trace):
        # Suggested-question traces are only logged (no span export observed here).
        trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
        with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log:
            tencent_data_trace.suggested_question_trace(trace_info)
            mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace")

    def test_suggested_question_trace_exception(self, tencent_data_trace):
        # Even a failure while logging is caught and reported via logger.exception.
        trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
        with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")):
            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                tencent_data_trace.suggested_question_trace(trace_info)
                mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace")

    def test_process_workflow_nodes(self, tencent_data_trace, mock_trace_utils):
        # One span is exported per node execution; LLM metrics are only
        # recorded for LLM-type nodes (node1).
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.workflow_run_id = "run-id"
        mock_trace_utils.convert_to_span_id.return_value = 111

        node1 = MagicMock(spec=WorkflowNodeExecution)
        node1.id = "n1"
        node1.node_type = NodeType.LLM
        node2 = MagicMock(spec=WorkflowNodeExecution)
        node2.id = "n2"
        node2.node_type = NodeType.TOOL

        with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]):
            with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]):
                with patch.object(tencent_data_trace, "_record_llm_metrics") as mock_metrics:
                    tencent_data_trace._process_workflow_nodes(trace_info, 123)

                    assert tencent_data_trace.trace_client.add_span.call_count == 2
                    mock_metrics.assert_called_once_with(node1)

    def test_process_workflow_nodes_node_exception(self, tencent_data_trace, mock_trace_utils):
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        mock_trace_utils.convert_to_span_id.return_value = 111

        node = MagicMock(spec=WorkflowNodeExecution)
        node.id = "n1"

        with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]):
            with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")):
                with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                    tencent_data_trace._process_workflow_nodes(trace_info, 123)
                    # The exception should be caught by the outer handler since convert_to_span_id is called first
                    mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")

    def test_process_workflow_nodes_exception(self, tencent_data_trace, mock_trace_utils):
        # A failure before node iteration is also caught by the outer handler.
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error")

        with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
            tencent_data_trace._process_workflow_nodes(trace_info, 123)
            mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")

    def test_build_workflow_node_span(self, tencent_data_trace, mock_span_builder):
        # Each node type routes to its dedicated span-builder method with the
        # argument order (trace_id, parent_span_id, trace_info, node).
        trace_info = MagicMock(spec=WorkflowTraceInfo)

        nodes = [
            (NodeType.LLM, mock_span_builder.build_workflow_llm_span),
            (NodeType.KNOWLEDGE_RETRIEVAL, mock_span_builder.build_workflow_retrieval_span),
            (NodeType.TOOL, mock_span_builder.build_workflow_tool_span),
            (NodeType.CODE, mock_span_builder.build_workflow_task_span),
        ]

        for node_type, builder_method in nodes:
            node = MagicMock(spec=WorkflowNodeExecution)
            node.node_type = node_type
            builder_method.return_value = "span"

            result = tencent_data_trace._build_workflow_node_span(node, 123, trace_info, 456)

            assert result == "span"
            builder_method.assert_called_once_with(123, 456, trace_info, node)

    def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder):
        # A builder failure is logged at debug level and yields None instead of raising.
        node = MagicMock(spec=WorkflowNodeExecution)
        node.node_type = NodeType.LLM
        node.id = "n1"
        mock_span_builder.build_workflow_llm_span.side_effect = Exception("error")

        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
            result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456)
            assert result is None
            mock_log.assert_called_once()

    def test_get_workflow_node_executions(self, tencent_data_trace):
        # Full DB lookup chain: app -> creator account -> tenant join, then the
        # repository returns node executions for the workflow run.
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {"app_id": "app-1"}
        trace_info.workflow_run_id = "run-1"

        app = MagicMock(spec=App)
        app.id = "app-1"
        app.created_by = "user-1"

        account = MagicMock(spec=Account)
        account.id = "user-1"

        tenant_join = MagicMock(spec=TenantAccountJoin)
        tenant_join.tenant_id = "tenant-1"

        mock_executions = [MagicMock()]

        with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
            mock_db.engine = "engine"
            with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
                session = mock_session_ctx.return_value.__enter__.return_value
                # First scalar() call resolves the app, second the account.
                session.scalar.side_effect = [app, account]
                session.query.return_value.filter_by.return_value.first.return_value = tenant_join

                with patch(
                    "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository"
                ) as mock_repo:
                    mock_repo.return_value.get_by_workflow_run.return_value = mock_executions

                    results = tencent_data_trace._get_workflow_node_executions(trace_info)

                    assert results == mock_executions
                    account.set_tenant_id.assert_called_once_with("tenant-1")

    def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace):
        # Missing app_id in metadata: returns [] and logs the failure.
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {}

        with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
            results = tencent_data_trace._get_workflow_node_executions(trace_info)
            assert results == []
            mock_log.assert_called_once()

    def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace):
        # App lookup returning None: returns [] and logs the failure.
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {"app_id": "app-1"}

        with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
            mock_db.init_app = MagicMock()  # Ensure init_app is mocked
            mock_db.engine = "engine"
            with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
                session = mock_session_ctx.return_value.__enter__.return_value
                session.scalar.return_value = None

                with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                    results = tencent_data_trace._get_workflow_node_executions(trace_info)
                    assert results == []
                    mock_log.assert_called_once()

    def test_get_user_id_workflow(self, tencent_data_trace):
        # Workflow traces resolve the user via the DB; a DB failure falls back
        # to "unknown" rather than raising.
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.tenant_id = "tenant-1"
        trace_info.metadata = {"user_id": "user-1"}

        with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")):
            with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
                mock_db.init_app = MagicMock()
                mock_db.engine = MagicMock()

                user_id = tencent_data_trace._get_user_id(trace_info)
                assert user_id == "unknown"

    def test_get_user_id_only_user_id(self, tencent_data_trace):
        # Non-workflow traces read user_id straight from metadata.
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.metadata = {"user_id": "user-1"}

        user_id = tencent_data_trace._get_user_id(trace_info)
        assert user_id == "user-1"

    def test_get_user_id_anonymous(self, tencent_data_trace):
        # No user_id in metadata: "anonymous".
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.metadata = {}

        user_id = tencent_data_trace._get_user_id(trace_info)
        assert user_id == "anonymous"

    def test_get_user_id_exception(self, tencent_data_trace):
        # DB failure path logs and returns "unknown".
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.tenant_id = "t"
        trace_info.metadata = {"user_id": "u"}

        with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")):
            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                user_id = tencent_data_trace._get_user_id(trace_info)
                assert user_id == "unknown"
                mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID")

    def test_record_llm_metrics_usage_in_process_data(self, tencent_data_trace):
        # Usage found in process_data: all duration/TTFT/TTG metrics plus one
        # token-usage record each for prompt and completion tokens.
        node = MagicMock(spec=WorkflowNodeExecution)
        node.process_data = {
            "usage": {
                "latency": 2.5,
                "time_to_first_token": 0.5,
                "time_to_generate": 2.0,
                "prompt_tokens": 10,
                "completion_tokens": 20,
            },
            "model_provider": "openai",
            "model_name": "gpt-4",
            "model_mode": "chat",
        }
        node.outputs = {}

        tencent_data_trace._record_llm_metrics(node)

        tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
        tencent_data_trace.trace_client.record_time_to_first_token.assert_called_once()
        tencent_data_trace.trace_client.record_time_to_generate.assert_called_once()
        assert tencent_data_trace.trace_client.record_token_usage.call_count == 2

    def test_record_llm_metrics_usage_in_outputs(self, tencent_data_trace):
        # Usage may also live under outputs; only metrics present are recorded.
        node = MagicMock(spec=WorkflowNodeExecution)
        node.process_data = {}
        node.outputs = {"usage": {"latency": 1.0, "prompt_tokens": 5}}

        tencent_data_trace._record_llm_metrics(node)
        tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
        tencent_data_trace.trace_client.record_token_usage.assert_called_once()

    def test_record_llm_metrics_exception(self, tencent_data_trace):
        # None payloads must not raise (failure is logged at debug level).
        node = MagicMock(spec=WorkflowNodeExecution)
        node.process_data = None
        node.outputs = None

        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
            tencent_data_trace._record_llm_metrics(node)
            # Should not crash

    def test_record_message_llm_metrics(self, tencent_data_trace):
        # Streaming request with full metadata: duration, TTFT, TTG, and both
        # token-usage records are emitted.
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.metadata = {"ls_provider": "openai", "ls_model_name": "gpt-4"}
        trace_info.message_data = {"provider_response_latency": 1.1}
        trace_info.is_streaming_request = True
        trace_info.gen_ai_server_time_to_first_token = 0.2
        trace_info.llm_streaming_time_to_generate = 0.9
        trace_info.message_tokens = 15
        trace_info.answer_tokens = 25

        tencent_data_trace._record_message_llm_metrics(trace_info)

        tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
        tencent_data_trace.trace_client.record_time_to_first_token.assert_called_once()
        tencent_data_trace.trace_client.record_time_to_generate.assert_called_once()
        assert tencent_data_trace.trace_client.record_token_usage.call_count == 2

    def test_record_message_llm_metrics_object_data(self, tencent_data_trace):
        # message_data may be an object (attribute access) rather than a dict.
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.metadata = {}
        msg_data = MagicMock()
        msg_data.provider_response_latency = 1.1
        msg_data.model_provider = "anthropic"
        msg_data.model_id = "claude"
        trace_info.message_data = msg_data
        trace_info.is_streaming_request = False

        tencent_data_trace._record_message_llm_metrics(trace_info)
        tencent_data_trace.trace_client.record_llm_duration.assert_called_once()

    def test_record_message_llm_metrics_exception(self, tencent_data_trace):
        # None metadata must not raise.
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.metadata = None

        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
            tencent_data_trace._record_message_llm_metrics(trace_info)
            # Should not crash

    def test_record_workflow_trace_duration(self, tencent_data_trace):
        # Duration is end_time - start_time; attributes mark the workflow mode
        # and whether a conversation id is present.
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        from datetime import datetime, timedelta

        now = datetime.now()
        trace_info.start_time = now
        trace_info.end_time = now + timedelta(seconds=3)
        trace_info.workflow_run_status = "succeeded"
        trace_info.conversation_id = "conv-1"

        # Mock the record_trace_duration method to capture arguments
        with patch.object(tencent_data_trace.trace_client, "record_trace_duration") as mock_record:
            tencent_data_trace._record_workflow_trace_duration(trace_info)

            # Assert the method was called once
            mock_record.assert_called_once()

            # Extract arguments passed to the method
            args, kwargs = mock_record.call_args

            # Validate the duration argument
            assert args[0] == 3.0

            # Validate the attributes dict in kwargs
            attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {}
            assert attributes["conversation_mode"] == "workflow"
            assert attributes["has_conversation"] == "true"

    def test_record_workflow_trace_duration_fallback(self, tencent_data_trace):
        # With no start_time, workflow_run_elapsed_time is used as the duration.
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.start_time = None
        trace_info.workflow_run_elapsed_time = 4.5
        trace_info.workflow_run_status = "failed"
        trace_info.conversation_id = None

        with patch.object(tencent_data_trace.trace_client, "record_trace_duration") as mock_record:
            tencent_data_trace._record_workflow_trace_duration(trace_info)
            mock_record.assert_called_once()
            args, kwargs = mock_record.call_args
            assert args[0] == 4.5
            # Check attributes dict (either in kwargs or as second positional arg)
            attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {}
            assert attributes["has_conversation"] == "false"

    def test_record_workflow_trace_duration_exception(self, tencent_data_trace):
        # A broken start_time must not raise (failure is logged at debug level).
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.start_time = MagicMock()  # This might cause total_seconds() to fail if not mocked right

        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
            tencent_data_trace._record_workflow_trace_duration(trace_info)

    def test_record_message_trace_duration(self, tencent_data_trace):
        # Duration plus conversation_mode/stream attributes are forwarded.
        trace_info = MagicMock(spec=MessageTraceInfo)
        from datetime import datetime, timedelta

        now = datetime.now()
        trace_info.start_time = now
        trace_info.end_time = now + timedelta(seconds=2)
        trace_info.conversation_mode = "chat"
        trace_info.is_streaming_request = True

        tencent_data_trace._record_message_trace_duration(trace_info)
        tencent_data_trace.trace_client.record_trace_duration.assert_called_once_with(
            2.0, {"conversation_mode": "chat", "stream": "true"}
        )

    def test_record_message_trace_duration_exception(self, tencent_data_trace):
        # Missing start_time must not raise.
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.start_time = None

        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
            tencent_data_trace._record_message_trace_duration(trace_info)

    def test_del(self, tencent_data_trace):
        # Destructor shuts down the trace client.
        client = tencent_data_trace.trace_client
        tencent_data_trace.__del__()
        client.shutdown.assert_called_once()

    def test_del_exception(self, tencent_data_trace):
        # A shutdown failure during cleanup is logged, not raised.
        tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
        with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
            tencent_data_trace.__del__()
            mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Unit tests for Tencent APM tracing utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from opentelemetry.trace import Link, TraceFlags
|
||||
|
||||
from core.ops.tencent_trace.utils import TencentTraceUtils
|
||||
|
||||
|
||||
def test_convert_to_trace_id_with_valid_uuid() -> None:
    """A dashed UUID string maps to the integer value of that UUID."""
    raw = "12345678-1234-5678-1234-567812345678"
    expected = uuid.UUID(raw).int
    assert TencentTraceUtils.convert_to_trace_id(raw) == expected
|
||||
|
||||
|
||||
def test_convert_to_trace_id_uses_uuid4_when_none() -> None:
    """When no ID is supplied, a freshly generated uuid4 backs the trace id."""
    fixed = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
    with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fixed) as mock_uuid4:
        result = TencentTraceUtils.convert_to_trace_id(None)
        assert result == fixed.int
        mock_uuid4.assert_called_once()
|
||||
|
||||
|
||||
def test_convert_to_trace_id_raises_value_error_for_invalid_uuid() -> None:
    """Non-UUID input is rejected with a descriptive ValueError."""
    bad_input = "not-a-uuid"
    with pytest.raises(ValueError, match=r"^Invalid UUID input:"):
        TencentTraceUtils.convert_to_trace_id(bad_input)
|
||||
|
||||
|
||||
def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None:
    """Span ids derive from sha256 of "<uuid-hex>-<span-type>".

    The same (uuid, type) pair is stable, and a different type yields a
    different id.
    """
    raw = "12345678-1234-5678-1234-567812345678"
    digest = hashlib.sha256(f"{uuid.UUID(raw).hex}-llm".encode("utf-8")).digest()
    expected = int.from_bytes(digest[:8], byteorder="big", signed=False)

    assert TencentTraceUtils.convert_to_span_id(raw, "llm") == expected
    assert TencentTraceUtils.convert_to_span_id(raw, "other") != expected
|
||||
|
||||
|
||||
def test_convert_to_span_id_uses_uuid4_when_none() -> None:
    """A missing ID falls back to uuid4 and still produces an integer span id."""
    fixed = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
    with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fixed) as mock_uuid4:
        result = TencentTraceUtils.convert_to_span_id(None, "workflow")
        assert isinstance(result, int)
        mock_uuid4.assert_called_once()
|
||||
|
||||
|
||||
def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None:
    """Non-UUID input is rejected with a descriptive ValueError."""
    bad_input = "bad-uuid"
    with pytest.raises(ValueError, match=r"^Invalid UUID input:"):
        TencentTraceUtils.convert_to_span_id(bad_input, "span")
|
||||
|
||||
|
||||
def test_generate_span_id_skips_invalid_span_id() -> None:
    """generate_span_id must re-roll when random bits equal INVALID_SPAN_ID.

    The first getrandbits call yields the invalid sentinel, so a second call
    is required before a valid id (42) is returned.
    """
    with patch(
        "core.ops.tencent_trace.utils.random.getrandbits",
        side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42],
    ) as bits_mock:
        assert TencentTraceUtils.generate_span_id() == 42
        assert bits_mock.call_count == 2
|
||||
|
||||
|
||||
def test_convert_datetime_to_nanoseconds_accepts_datetime() -> None:
    """An aware datetime is converted to epoch nanoseconds."""
    moment = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
    assert TencentTraceUtils.convert_datetime_to_nanoseconds(moment) == int(moment.timestamp() * 1e9)
|
||||
|
||||
|
||||
def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None:
    """With no input, the current time (datetime.now) supplies the timestamp."""
    fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC)
    expected = int(fixed.timestamp() * 1e9)

    # Patch the whole datetime name in the utils module so .now() is controllable.
    with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock:
        datetime_mock.now.return_value = fixed
        assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected
        datetime_mock.now.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("trace_id_str", "expected_trace_id"),
    [
        # 32-char hex string (no dashes) parsed as base-16.
        ("0" * 31 + "1", int("0" * 31 + "1", 16)),
        # Canonical dashed UUID string.
        (str(uuid.UUID("cccccccc-cccc-cccc-cccc-cccccccccccc")), uuid.UUID("cccccccc-cccc-cccc-cccc-cccccccccccc").int),
    ],
)
def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: int) -> None:
    """create_link accepts both hex and UUID trace ids and builds a sampled,
    non-remote Link with an invalid (placeholder) span id."""
    link = TencentTraceUtils.create_link(trace_id_str)
    assert isinstance(link, Link)
    assert link.context.trace_id == expected_trace_id
    assert link.context.span_id == TencentTraceUtils.INVALID_SPAN_ID
    assert link.context.is_remote is False
    assert link.context.trace_flags == TraceFlags(TraceFlags.SAMPLED)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None])
def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None:
    """Unparseable or missing trace ids fall back to a fresh uuid4."""
    fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd")
    with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock:
        link = TencentTraceUtils.create_link(trace_id_str)  # type: ignore[arg-type]
        assert link.context.trace_id == fallback_uuid.int
        uuid4_mock.assert_called_once()
|
||||
112
api/tests/unit_tests/core/ops/test_base_trace_instance.py
Normal file
112
api/tests/unit_tests/core/ops/test_base_trace_instance.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import BaseTracingConfig
|
||||
from core.ops.entities.trace_entity import BaseTraceInfo
|
||||
from models import Account, App, TenantAccountJoin
|
||||
|
||||
|
||||
class ConcreteTraceInstance(BaseTraceInstance):
    """Minimal concrete subclass so the abstract BaseTraceInstance can be instantiated in tests."""

    def __init__(self, trace_config: BaseTracingConfig):
        super().__init__(trace_config)

    def trace(self, trace_info: BaseTraceInfo):
        # Delegate to the parent; tests only exercise inherited helpers.
        super().trace(trace_info)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_db_session(monkeypatch):
    # Fake SQLAlchemy Session usable as a context manager (with Session(...) as s).
    mock_session = MagicMock(spec=Session)
    mock_session.__enter__.return_value = mock_session
    mock_session.__exit__.return_value = None

    # Session(...) returns the same mock regardless of arguments.
    mock_session_class = MagicMock(return_value=mock_session)

    monkeypatch.setattr("core.ops.base_trace_instance.Session", mock_session_class)
    monkeypatch.setattr("core.ops.base_trace_instance.db", MagicMock())
    return mock_session
|
||||
|
||||
|
||||
def test_get_service_account_with_tenant_app_not_found(mock_db_session):
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
config = MagicMock(spec=BaseTracingConfig)
|
||||
instance = ConcreteTraceInstance(config)
|
||||
|
||||
with pytest.raises(ValueError, match="App with id some_app_id not found"):
|
||||
instance.get_service_account_with_tenant("some_app_id")
|
||||
|
||||
|
||||
def test_get_service_account_with_tenant_no_creator(mock_db_session):
|
||||
mock_app = MagicMock(spec=App)
|
||||
mock_app.id = "some_app_id"
|
||||
mock_app.created_by = None
|
||||
mock_db_session.scalar.return_value = mock_app
|
||||
|
||||
config = MagicMock(spec=BaseTracingConfig)
|
||||
instance = ConcreteTraceInstance(config)
|
||||
|
||||
with pytest.raises(ValueError, match="App with id some_app_id has no creator"):
|
||||
instance.get_service_account_with_tenant("some_app_id")
|
||||
|
||||
|
||||
def test_get_service_account_with_tenant_creator_not_found(mock_db_session):
|
||||
mock_app = MagicMock(spec=App)
|
||||
mock_app.id = "some_app_id"
|
||||
mock_app.created_by = "creator_id"
|
||||
|
||||
# First call to scalar returns app, second returns None (for account)
|
||||
mock_db_session.scalar.side_effect = [mock_app, None]
|
||||
|
||||
config = MagicMock(spec=BaseTracingConfig)
|
||||
instance = ConcreteTraceInstance(config)
|
||||
|
||||
with pytest.raises(ValueError, match="Creator account with id creator_id not found for app some_app_id"):
|
||||
instance.get_service_account_with_tenant("some_app_id")
|
||||
|
||||
|
||||
def test_get_service_account_with_tenant_tenant_not_found(mock_db_session):
|
||||
mock_app = MagicMock(spec=App)
|
||||
mock_app.id = "some_app_id"
|
||||
mock_app.created_by = "creator_id"
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "creator_id"
|
||||
|
||||
mock_db_session.scalar.side_effect = [mock_app, mock_account]
|
||||
|
||||
# session.query(TenantAccountJoin).filter_by(...).first() returns None
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
config = MagicMock(spec=BaseTracingConfig)
|
||||
instance = ConcreteTraceInstance(config)
|
||||
|
||||
with pytest.raises(ValueError, match="Current tenant not found for account creator_id"):
|
||||
instance.get_service_account_with_tenant("some_app_id")
|
||||
|
||||
|
||||
def test_get_service_account_with_tenant_success(mock_db_session):
|
||||
mock_app = MagicMock(spec=App)
|
||||
mock_app.id = "some_app_id"
|
||||
mock_app.created_by = "creator_id"
|
||||
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "creator_id"
|
||||
mock_account.set_tenant_id = MagicMock()
|
||||
|
||||
mock_db_session.scalar.side_effect = [mock_app, mock_account]
|
||||
|
||||
mock_tenant_join = MagicMock(spec=TenantAccountJoin)
|
||||
mock_tenant_join.tenant_id = "tenant_id"
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_tenant_join
|
||||
|
||||
config = MagicMock(spec=BaseTracingConfig)
|
||||
instance = ConcreteTraceInstance(config)
|
||||
|
||||
result = instance.get_service_account_with_tenant("some_app_id")
|
||||
|
||||
assert result == mock_account
|
||||
mock_account.set_tenant_id.assert_called_once_with("tenant_id")
|
||||
576
api/tests/unit_tests/core/ops/test_ops_trace_manager.py
Normal file
576
api/tests/unit_tests/core/ops/test_ops_trace_manager.py
Normal file
@@ -0,0 +1,576 @@
|
||||
import contextlib
|
||||
import json
|
||||
import queue
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.ops_trace_manager import (
|
||||
OpsTraceManager,
|
||||
TraceQueueManager,
|
||||
TraceTask,
|
||||
TraceTaskName,
|
||||
)
|
||||
|
||||
|
||||
class DummyConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self._data = kwargs
|
||||
|
||||
def model_dump(self):
|
||||
return dict(self._data)
|
||||
|
||||
|
||||
class DummyTraceInstance:
|
||||
instances: list["DummyTraceInstance"] = []
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
DummyTraceInstance.instances.append(self)
|
||||
|
||||
def api_check(self):
|
||||
return True
|
||||
|
||||
def get_project_key(self):
|
||||
return "fake-key"
|
||||
|
||||
def get_project_url(self):
|
||||
return "https://project.fake"
|
||||
|
||||
|
||||
FAKE_PROVIDER_ENTRY = {
|
||||
"config_class": DummyConfig,
|
||||
"secret_keys": ["secret_value"],
|
||||
"other_keys": ["other_value"],
|
||||
"trace_instance": DummyTraceInstance,
|
||||
}
|
||||
|
||||
|
||||
class FakeProviderMap:
|
||||
def __init__(self, data):
|
||||
self._data = data
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in self._data:
|
||||
return self._data[key]
|
||||
raise KeyError(f"Unsupported tracing provider: {key}")
|
||||
|
||||
|
||||
class DummyTimer:
|
||||
def __init__(self, interval, function):
|
||||
self.interval = interval
|
||||
self.function = function
|
||||
self.name = ""
|
||||
self.daemon = False
|
||||
self.started = False
|
||||
|
||||
def start(self):
|
||||
self.started = True
|
||||
|
||||
def is_alive(self):
|
||||
return False
|
||||
|
||||
|
||||
class FakeMessageFile:
|
||||
def __init__(self):
|
||||
self.url = "path/to/file"
|
||||
self.id = "file-id"
|
||||
self.type = "document"
|
||||
self.created_by_role = "role"
|
||||
self.created_by = "user"
|
||||
|
||||
|
||||
def make_message_data(**overrides):
|
||||
created_at = datetime(2025, 2, 20, 12, 0, 0)
|
||||
base = {
|
||||
"id": "msg-id",
|
||||
"conversation_id": "conv-id",
|
||||
"created_at": created_at,
|
||||
"updated_at": created_at + timedelta(seconds=3),
|
||||
"message": "hello",
|
||||
"provider_response_latency": 1,
|
||||
"message_tokens": 5,
|
||||
"answer_tokens": 7,
|
||||
"answer": "world",
|
||||
"error": "",
|
||||
"status": "complete",
|
||||
"model_provider": "provider",
|
||||
"model_id": "model",
|
||||
"from_end_user_id": "end-user",
|
||||
"from_account_id": "account",
|
||||
"agent_based": False,
|
||||
"workflow_run_id": "workflow-run",
|
||||
"from_source": "source",
|
||||
"message_metadata": json.dumps({"usage": {"time_to_first_token": 1, "time_to_generate": 2}}),
|
||||
"agent_thoughts": [],
|
||||
"query": "sample-query",
|
||||
"inputs": "sample-input",
|
||||
}
|
||||
base.update(overrides)
|
||||
|
||||
class MessageData:
|
||||
def __init__(self, data):
|
||||
self.__dict__.update(data)
|
||||
|
||||
def to_dict(self):
|
||||
return dict(self.__dict__)
|
||||
|
||||
return MessageData(base)
|
||||
|
||||
|
||||
def make_agent_thought(tool_name, created_at):
|
||||
return SimpleNamespace(
|
||||
tools=[tool_name],
|
||||
created_at=created_at,
|
||||
tool_meta={
|
||||
tool_name: {
|
||||
"tool_config": {"foo": "bar"},
|
||||
"time_cost": 5,
|
||||
"error": "",
|
||||
"tool_parameters": {"x": 1},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def make_workflow_run():
|
||||
return SimpleNamespace(
|
||||
workflow_id="wf-1",
|
||||
tenant_id="tenant",
|
||||
id="run-id",
|
||||
elapsed_time=10,
|
||||
status="finished",
|
||||
inputs_dict={"sys.file": ["f1"], "query": "search"},
|
||||
outputs_dict={"out": "value"},
|
||||
version="3",
|
||||
error=None,
|
||||
total_tokens=12,
|
||||
workflow_run_id="run-id",
|
||||
created_at=datetime(2025, 2, 20, 10, 0, 0),
|
||||
finished_at=datetime(2025, 2, 20, 10, 0, 5),
|
||||
triggered_from="user",
|
||||
app_id="app-id",
|
||||
to_dict=lambda self=None: {"run": "value"},
|
||||
)
|
||||
|
||||
|
||||
def configure_db_query(session, *, message_file=None, workflow_app_log=None):
|
||||
def _side_effect(model):
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value.first.return_value = None
|
||||
if message_file and model.__name__ == "MessageFile":
|
||||
query.filter_by.return_value.first.return_value = message_file
|
||||
if workflow_app_log and model.__name__ == "WorkflowAppLog":
|
||||
query.filter_by.return_value.first.return_value = workflow_app_log
|
||||
return query
|
||||
|
||||
session.query.side_effect = _side_effect
|
||||
|
||||
|
||||
class DummySessionContext:
|
||||
scalar_values = []
|
||||
|
||||
def __init__(self, engine):
|
||||
self._values = list(self.scalar_values)
|
||||
self._index = 0
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
def scalar(self, *args, **kwargs):
|
||||
if self._index >= len(self._values):
|
||||
return None
|
||||
value = self._values[self._index]
|
||||
self._index += 1
|
||||
return value
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_provider_map(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({"dummy": FAKE_PROVIDER_ENTRY})
|
||||
)
|
||||
OpsTraceManager.ops_trace_instances_cache.clear()
|
||||
OpsTraceManager.decrypted_configs_cache.clear()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_timer_and_current_app(monkeypatch):
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.threading.Timer", DummyTimer)
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.trace_manager_queue", queue.Queue())
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.trace_manager_timer", None)
|
||||
|
||||
class FakeApp:
|
||||
def app_context(self):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
fake_current = MagicMock()
|
||||
fake_current._get_current_object.return_value = FakeApp()
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.current_app", fake_current)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_sqlalchemy_session(monkeypatch):
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.Session", DummySessionContext)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def encryption_mocks(monkeypatch):
|
||||
encrypt_mock = MagicMock(side_effect=lambda tenant, value: f"enc-{value}")
|
||||
batch_decrypt_mock = MagicMock(side_effect=lambda tenant, values: [f"dec-{value}" for value in values])
|
||||
obfuscate_mock = MagicMock(side_effect=lambda value: f"ob-{value}")
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.encrypt_token", encrypt_mock)
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.batch_decrypt_token", batch_decrypt_mock)
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.obfuscated_token", obfuscate_mock)
|
||||
return encrypt_mock, batch_decrypt_mock, obfuscate_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(monkeypatch):
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = ["chat"]
|
||||
db_mock = MagicMock()
|
||||
db_mock.session = session
|
||||
db_mock.engine = MagicMock()
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.db", db_mock)
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_repo_fixture(monkeypatch):
|
||||
repo = MagicMock()
|
||||
repo.get_workflow_run_by_id_without_tenant.return_value = make_workflow_run()
|
||||
monkeypatch.setattr(TraceTask, "_get_workflow_run_repo", classmethod(lambda cls: repo))
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_task_message(monkeypatch, mock_db):
|
||||
message_data = make_message_data()
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data)
|
||||
configure_db_query(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id"))
|
||||
return message_data
|
||||
|
||||
|
||||
def test_encrypt_tracing_config_handles_star_and_encrypt(encryption_mocks):
|
||||
encrypted = OpsTraceManager.encrypt_tracing_config(
|
||||
"tenant",
|
||||
"dummy",
|
||||
{"secret_value": "value", "other_value": "info"},
|
||||
current_trace_config={"secret_value": "keep"},
|
||||
)
|
||||
assert encrypted["secret_value"] == "enc-value"
|
||||
assert encrypted["other_value"] == "info"
|
||||
|
||||
|
||||
def test_encrypt_tracing_config_preserves_star(encryption_mocks):
|
||||
encrypted = OpsTraceManager.encrypt_tracing_config(
|
||||
"tenant",
|
||||
"dummy",
|
||||
{"secret_value": "*", "other_value": "info"},
|
||||
current_trace_config={"secret_value": "keep"},
|
||||
)
|
||||
assert encrypted["secret_value"] == "keep"
|
||||
|
||||
|
||||
def test_decrypt_tracing_config_caches(encryption_mocks):
|
||||
_, decrypt_mock, _ = encryption_mocks
|
||||
payload = {"secret_value": "enc", "other_value": "info"}
|
||||
first = OpsTraceManager.decrypt_tracing_config("tenant", "dummy", payload)
|
||||
second = OpsTraceManager.decrypt_tracing_config("tenant", "dummy", payload)
|
||||
assert first == second
|
||||
assert decrypt_mock.call_count == 1
|
||||
|
||||
|
||||
def test_obfuscated_decrypt_token(encryption_mocks):
|
||||
_, _, obfuscate_mock = encryption_mocks
|
||||
result = OpsTraceManager.obfuscated_decrypt_token("dummy", {"secret_value": "value", "other_value": "info"})
|
||||
assert "secret_value" in result
|
||||
assert result["secret_value"] == "ob-value"
|
||||
obfuscate_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_get_decrypted_tracing_config_returns_config(encryption_mocks, mock_db):
|
||||
trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc", "other_value": "info"})
|
||||
mock_db.query.return_value.where.return_value.first.return_value = trace_config_data
|
||||
app = SimpleNamespace(id="app-id", tenant_id="tenant")
|
||||
mock_db.scalar.return_value = app
|
||||
|
||||
decrypted = OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
|
||||
assert decrypted["other_value"] == "info"
|
||||
|
||||
|
||||
def test_get_decrypted_tracing_config_missing_trace_config(mock_db):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = None
|
||||
assert OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") is None
|
||||
|
||||
|
||||
def test_get_decrypted_tracing_config_raises_for_missing_app(mock_db):
|
||||
trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc"})
|
||||
mock_db.query.return_value.where.return_value.first.return_value = trace_config_data
|
||||
mock_db.scalar.return_value = None
|
||||
with pytest.raises(ValueError, match="App not found"):
|
||||
OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
|
||||
|
||||
|
||||
def test_get_decrypted_tracing_config_raises_for_none_config(mock_db):
|
||||
trace_config_data = SimpleNamespace(tracing_config=None)
|
||||
mock_db.query.return_value.where.return_value.first.return_value = trace_config_data
|
||||
mock_db.scalar.return_value = SimpleNamespace(tenant_id="tenant")
|
||||
with pytest.raises(ValueError, match="Tracing config cannot be None"):
|
||||
OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
|
||||
|
||||
|
||||
def test_get_ops_trace_instance_handles_none_app(mock_db):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = None
|
||||
assert OpsTraceManager.get_ops_trace_instance("app-id") is None
|
||||
|
||||
|
||||
def test_get_ops_trace_instance_returns_none_when_disabled(mock_db, monkeypatch):
|
||||
app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": False}))
|
||||
mock_db.query.return_value.where.return_value.first.return_value = app
|
||||
assert OpsTraceManager.get_ops_trace_instance("app-id") is None
|
||||
|
||||
|
||||
def test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch):
|
||||
app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "missing"}))
|
||||
mock_db.query.return_value.where.return_value.first.return_value = app
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({}))
|
||||
assert OpsTraceManager.get_ops_trace_instance("app-id") is None
|
||||
|
||||
|
||||
def test_get_ops_trace_instance_success(monkeypatch, mock_db):
|
||||
app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "dummy"}))
|
||||
mock_db.query.return_value.where.return_value.first.return_value = app
|
||||
monkeypatch.setattr(
|
||||
"core.ops.ops_trace_manager.OpsTraceManager.get_decrypted_tracing_config",
|
||||
classmethod(lambda cls, aid, provider: {"secret_value": "decrypted", "other_value": "info"}),
|
||||
)
|
||||
instance = OpsTraceManager.get_ops_trace_instance("app-id")
|
||||
assert instance is not None
|
||||
cached_instance = OpsTraceManager.get_ops_trace_instance("app-id")
|
||||
assert instance is cached_instance
|
||||
|
||||
|
||||
def test_get_app_config_through_message_id_returns_none(mock_db):
|
||||
mock_db.scalar.return_value = None
|
||||
assert OpsTraceManager.get_app_config_through_message_id("m") is None
|
||||
|
||||
|
||||
def test_get_app_config_through_message_id_prefers_override(mock_db):
|
||||
message = SimpleNamespace(conversation_id="conv")
|
||||
conversation = SimpleNamespace(app_model_config_id=None, override_model_configs={"foo": "bar"})
|
||||
app_config = SimpleNamespace(id="config-id")
|
||||
mock_db.scalar.side_effect = [message, conversation]
|
||||
result = OpsTraceManager.get_app_config_through_message_id("m")
|
||||
assert result == {"foo": "bar"}
|
||||
|
||||
|
||||
def test_get_app_config_through_message_id_app_model_config(mock_db):
|
||||
message = SimpleNamespace(conversation_id="conv")
|
||||
conversation = SimpleNamespace(app_model_config_id="cfg", override_model_configs=None)
|
||||
mock_db.scalar.side_effect = [message, conversation, SimpleNamespace(id="cfg")]
|
||||
result = OpsTraceManager.get_app_config_through_message_id("m")
|
||||
assert result.id == "cfg"
|
||||
|
||||
|
||||
def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = None
|
||||
with pytest.raises(ValueError, match="Invalid tracing provider"):
|
||||
OpsTraceManager.update_app_tracing_config("app", True, "bad")
|
||||
with pytest.raises(ValueError, match="App not found"):
|
||||
OpsTraceManager.update_app_tracing_config("app", True, None)
|
||||
|
||||
|
||||
def test_update_app_tracing_config_success(mock_db):
|
||||
app = SimpleNamespace(id="app-id", tracing="{}")
|
||||
mock_db.query.return_value.where.return_value.first.return_value = app
|
||||
OpsTraceManager.update_app_tracing_config("app-id", True, "dummy")
|
||||
assert app.tracing is not None
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_get_app_tracing_config_errors_when_missing(mock_db):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = None
|
||||
with pytest.raises(ValueError, match="App not found"):
|
||||
OpsTraceManager.get_app_tracing_config("app")
|
||||
|
||||
|
||||
def test_get_app_tracing_config_returns_defaults(mock_db):
|
||||
mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=None)
|
||||
assert OpsTraceManager.get_app_tracing_config("app-id") == {"enabled": False, "tracing_provider": None}
|
||||
|
||||
|
||||
def test_get_app_tracing_config_returns_payload(mock_db):
|
||||
payload = {"enabled": True, "tracing_provider": "dummy"}
|
||||
mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=json.dumps(payload))
|
||||
assert OpsTraceManager.get_app_tracing_config("app-id") == payload
|
||||
|
||||
|
||||
def test_check_and_project_helpers(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"core.ops.ops_trace_manager.provider_config_map",
|
||||
FakeProviderMap(
|
||||
{
|
||||
"dummy": {
|
||||
"config_class": DummyConfig,
|
||||
"trace_instance": type(
|
||||
"Trace",
|
||||
(),
|
||||
{
|
||||
"__init__": lambda self, cfg: None,
|
||||
"api_check": lambda self: True,
|
||||
"get_project_key": lambda self: "key",
|
||||
"get_project_url": lambda self: "url",
|
||||
},
|
||||
),
|
||||
"secret_keys": [],
|
||||
"other_keys": [],
|
||||
}
|
||||
}
|
||||
),
|
||||
)
|
||||
assert OpsTraceManager.check_trace_config_is_effective({}, "dummy")
|
||||
assert OpsTraceManager.get_trace_config_project_key({}, "dummy") == "key"
|
||||
assert OpsTraceManager.get_trace_config_project_url({}, "dummy") == "url"
|
||||
|
||||
|
||||
def test_trace_task_conversation_and_extract(monkeypatch):
|
||||
task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE, message_id="msg")
|
||||
assert task.conversation_trace(foo="bar") == {"foo": "bar"}
|
||||
assert task._extract_streaming_metrics(make_message_data(message_metadata="not json")) == {}
|
||||
|
||||
|
||||
def test_trace_task_message_trace(trace_task_message, mock_db):
|
||||
task = TraceTask(trace_type=TraceTaskName.MESSAGE_TRACE, message_id="msg-id")
|
||||
result = task.message_trace("msg-id")
|
||||
assert result.message_id == "msg-id"
|
||||
|
||||
|
||||
def test_trace_task_workflow_trace(workflow_repo_fixture, mock_db):
|
||||
DummySessionContext.scalar_values = ["wf-app-log", "message-ref"]
|
||||
execution = SimpleNamespace(id_="run-id")
|
||||
task = TraceTask(
|
||||
trace_type=TraceTaskName.WORKFLOW_TRACE, workflow_execution=execution, conversation_id="conv", user_id="user"
|
||||
)
|
||||
result = task.workflow_trace(workflow_run_id="run-id", conversation_id="conv", user_id="user")
|
||||
assert result.workflow_run_id == "run-id"
|
||||
assert result.workflow_id == "wf-1"
|
||||
|
||||
|
||||
def test_trace_task_moderation_trace(trace_task_message):
|
||||
task = TraceTask(trace_type=TraceTaskName.MODERATION_TRACE, message_id="msg-id")
|
||||
moderation_result = SimpleNamespace(action="block", preset_response="no", query="q", flagged=True)
|
||||
timer = {"start": 1, "end": 2}
|
||||
result = task.moderation_trace("msg-id", timer, moderation_result=moderation_result, inputs={"src": "payload"})
|
||||
assert result.flagged is True
|
||||
assert result.message_id == "log-id"
|
||||
|
||||
|
||||
def test_trace_task_suggested_question_trace(trace_task_message):
|
||||
task = TraceTask(trace_type=TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id="msg-id")
|
||||
timer = {"start": 1, "end": 2}
|
||||
result = task.suggested_question_trace("msg-id", timer, suggested_question=["q1"])
|
||||
assert result.message_id == "log-id"
|
||||
assert "suggested_question" in result.__dict__
|
||||
|
||||
|
||||
def test_trace_task_dataset_retrieval_trace(trace_task_message):
|
||||
task = TraceTask(trace_type=TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id="msg-id")
|
||||
timer = {"start": 1, "end": 2}
|
||||
mock_doc = SimpleNamespace(model_dump=lambda: {"doc": "value"})
|
||||
result = task.dataset_retrieval_trace("msg-id", timer, documents=[mock_doc])
|
||||
assert result.documents == [{"doc": "value"}]
|
||||
|
||||
|
||||
def test_trace_task_tool_trace(monkeypatch, mock_db):
|
||||
custom_message = make_message_data(agent_thoughts=[make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 0))])
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _: custom_message)
|
||||
configure_db_query(mock_db, message_file=FakeMessageFile())
|
||||
task = TraceTask(trace_type=TraceTaskName.TOOL_TRACE, message_id="msg-id")
|
||||
timer = {"start": 1, "end": 5}
|
||||
result = task.tool_trace("msg-id", timer, tool_name="tool-a", tool_inputs={"foo": 1}, tool_outputs="result")
|
||||
assert result.tool_name == "tool-a"
|
||||
assert result.time_cost == 5
|
||||
|
||||
|
||||
def test_trace_task_generate_name_trace():
|
||||
task = TraceTask(trace_type=TraceTaskName.GENERATE_NAME_TRACE, conversation_id="conv-id")
|
||||
timer = {"start": 1, "end": 2}
|
||||
assert task.generate_name_trace("conv-id", timer, tenant_id=None) == {}
|
||||
result = task.generate_name_trace(
|
||||
"conv-id", timer, tenant_id="tenant", generate_conversation_name="name", inputs="q"
|
||||
)
|
||||
assert result.outputs == "name"
|
||||
assert result.tenant_id == "tenant"
|
||||
|
||||
|
||||
def test_extract_streaming_metrics_invalid_json():
|
||||
task = TraceTask(trace_type=TraceTaskName.MESSAGE_TRACE, message_id="msg-id")
|
||||
fake_message = make_message_data(message_metadata="invalid")
|
||||
assert task._extract_streaming_metrics(fake_message) == {}
|
||||
|
||||
|
||||
def test_trace_queue_manager_add_and_collect(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)
|
||||
)
|
||||
manager = TraceQueueManager(app_id="app-id", user_id="user")
|
||||
task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE)
|
||||
manager.add_trace_task(task)
|
||||
tasks = manager.collect_tasks()
|
||||
assert tasks == [task]
|
||||
|
||||
|
||||
def test_trace_queue_manager_run_invokes_send(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)
|
||||
)
|
||||
manager = TraceQueueManager(app_id="app-id", user_id="user")
|
||||
task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE)
|
||||
called = {}
|
||||
|
||||
def fake_collect():
|
||||
return [task]
|
||||
|
||||
def fake_send(tasks):
|
||||
called["tasks"] = tasks
|
||||
|
||||
monkeypatch.setattr(TraceQueueManager, "collect_tasks", lambda self: fake_collect())
|
||||
monkeypatch.setattr(TraceQueueManager, "send_to_celery", lambda self, t: fake_send(t))
|
||||
manager.run()
|
||||
assert called["tasks"] == [task]
|
||||
|
||||
|
||||
def test_trace_queue_manager_send_to_celery(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)
|
||||
)
|
||||
storage_save = MagicMock()
|
||||
process_delay = MagicMock()
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.storage.save", storage_save)
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.process_trace_tasks.delay", process_delay)
|
||||
monkeypatch.setattr("core.ops.ops_trace_manager.uuid4", MagicMock(return_value=SimpleNamespace(hex="file-123")))
|
||||
|
||||
manager = TraceQueueManager(app_id="app-id", user_id="user")
|
||||
|
||||
class DummyTraceInfo:
|
||||
def model_dump(self):
|
||||
return {"trace": "info"}
|
||||
|
||||
class DummyTask:
|
||||
def __init__(self):
|
||||
self.app_id = "app-id"
|
||||
|
||||
def execute(self):
|
||||
return DummyTraceInfo()
|
||||
|
||||
task = DummyTask()
|
||||
manager.send_to_celery([task])
|
||||
storage_save.assert_called_once()
|
||||
process_delay.assert_called_once_with({"file_id": "file-123", "app_id": "app-id"})
|
||||
@@ -1,9 +1,20 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.ops.utils import generate_dotted_order, validate_project_name, validate_url, validate_url_with_path
|
||||
from core.ops.utils import (
|
||||
filter_none_values,
|
||||
generate_dotted_order,
|
||||
get_message_data,
|
||||
measure_time,
|
||||
replace_text_with_content,
|
||||
validate_integer_id,
|
||||
validate_project_name,
|
||||
validate_url,
|
||||
validate_url_with_path,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateUrl:
|
||||
@@ -187,3 +198,92 @@ class TestGenerateDottedOrder:
|
||||
result = generate_dotted_order(run_id, start_time, None)
|
||||
|
||||
assert "." not in result
|
||||
|
||||
def test_dotted_order_with_string_start_time(self):
|
||||
"""Test dotted_order generation with string start_time."""
|
||||
start_time = "2025-12-23T04:19:55.111000"
|
||||
run_id = "test-run-id"
|
||||
result = generate_dotted_order(run_id, start_time)
|
||||
|
||||
assert result == "20251223T041955111000Ztest-run-id"
|
||||
|
||||
|
||||
class TestFilterNoneValues:
|
||||
"""Test cases for filter_none_values function"""
|
||||
|
||||
def test_filter_none_values(self):
|
||||
data = {"a": 1, "b": None, "c": "test", "d": datetime(2025, 1, 1, 12, 0, 0)}
|
||||
result = filter_none_values(data)
|
||||
assert result == {"a": 1, "c": "test", "d": "2025-01-01T12:00:00"}
|
||||
|
||||
def test_filter_none_values_empty(self):
|
||||
assert filter_none_values({}) == {}
|
||||
|
||||
|
||||
class TestGetMessageData:
|
||||
"""Test cases for get_message_data function"""
|
||||
|
||||
@patch("core.ops.utils.db")
|
||||
@patch("core.ops.utils.Message")
|
||||
@patch("core.ops.utils.select")
|
||||
def test_get_message_data(self, mock_select, mock_message, mock_db):
|
||||
mock_scalar = mock_db.session.scalar
|
||||
mock_msg_instance = MagicMock()
|
||||
mock_scalar.return_value = mock_msg_instance
|
||||
|
||||
result = get_message_data("message-id")
|
||||
|
||||
assert result == mock_msg_instance
|
||||
mock_select.assert_called_once()
|
||||
mock_scalar.assert_called_once()
|
||||
|
||||
|
||||
class TestMeasureTime:
|
||||
"""Test cases for measure_time function"""
|
||||
|
||||
def test_measure_time(self):
|
||||
with measure_time() as timing_info:
|
||||
assert "start" in timing_info
|
||||
assert isinstance(timing_info["start"], datetime)
|
||||
assert timing_info["end"] is None
|
||||
|
||||
assert timing_info["end"] is not None
|
||||
assert isinstance(timing_info["end"], datetime)
|
||||
assert timing_info["end"] >= timing_info["start"]
|
||||
|
||||
|
||||
class TestReplaceTextWithContent:
|
||||
"""Test cases for replace_text_with_content function"""
|
||||
|
||||
def test_replace_text_with_content_dict(self):
|
||||
data = {"text": "hello", "other": "world"}
|
||||
assert replace_text_with_content(data) == {"content": "hello", "other": "world"}
|
||||
|
||||
def test_replace_text_with_content_nested(self):
|
||||
data = {"text": "v1", "nested": {"text": "v2", "list": [{"text": "v3"}]}}
|
||||
expected = {"content": "v1", "nested": {"content": "v2", "list": [{"content": "v3"}]}}
|
||||
assert replace_text_with_content(data) == expected
|
||||
|
||||
def test_replace_text_with_content_list(self):
|
||||
data = [{"text": "v1"}, "v2"]
|
||||
assert replace_text_with_content(data) == [{"content": "v1"}, "v2"]
|
||||
|
||||
def test_replace_text_with_content_primitive(self):
|
||||
assert replace_text_with_content(123) == 123
|
||||
assert replace_text_with_content("text") == "text"
|
||||
|
||||
|
||||
class TestValidateIntegerId:
|
||||
"""Test cases for validate_integer_id function"""
|
||||
|
||||
def test_valid_integer_id(self):
|
||||
assert validate_integer_id("123") == "123"
|
||||
assert validate_integer_id(" 456 ") == "456"
|
||||
|
||||
def test_invalid_integer_id_raises_error(self):
|
||||
with pytest.raises(ValueError, match="ID must be a valid integer"):
|
||||
validate_integer_id("abc")
|
||||
|
||||
def test_empty_integer_id_raises_error(self):
|
||||
with pytest.raises(ValueError, match="ID must be a valid integer"):
|
||||
validate_integer_id("")
|
||||
|
||||
1196
api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py
Normal file
1196
api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user