test: added tests for backend core.ops module (#32639)

Co-authored-by: rajatagarwal-oss <rajat.agarwal@infocusp.com>
This commit is contained in:
mahammadasim
2026-03-12 13:03:15 +05:30
committed by GitHub
parent ed5511ce28
commit 3dabdc8282
18 changed files with 8485 additions and 1 deletions

View File

@@ -0,0 +1,326 @@
import time
import uuid
from datetime import datetime
from unittest.mock import MagicMock, patch
import httpx
import pytest
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.trace import SpanKind, Status, StatusCode
from core.ops.aliyun_trace.data_exporter.traceclient import (
INVALID_SPAN_ID,
SpanBuilder,
TraceClient,
build_endpoint,
convert_datetime_to_nanoseconds,
convert_string_to_id,
convert_to_span_id,
convert_to_trace_id,
create_link,
generate_span_id,
)
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
@pytest.fixture
def trace_client_factory():
    """Yield a TraceClient factory; every client built through it is shut down at teardown."""
    created_clients = []

    def _create(**kwargs):
        instance = TraceClient(**kwargs)
        created_clients.append(instance)
        return instance

    yield _create
    # Teardown: stop the worker thread of every client the test created.
    for instance in created_clients:
        instance.shutdown()
class TestTraceClient:
    """TraceClient lifecycle tests: construction/worker thread, export delegation,
    API health check via HEAD probe, span queueing and batching, and shutdown flushing."""

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname")
    def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory):
        """Constructor applies defaults and starts a live worker thread until shutdown()."""
        mock_gethostname.return_value = "test-host"
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        assert client.endpoint == "http://test-endpoint"
        assert client.max_queue_size == 1000
        assert client.schedule_delay_sec == 5
        assert client.done is False
        assert client.worker_thread.is_alive()
        client.shutdown()
        assert client.done is True

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_export(self, mock_exporter_class, trace_client_factory):
        """export() forwards the span batch to the underlying OTLP exporter."""
        mock_exporter = mock_exporter_class.return_value
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        spans = [MagicMock(spec=ReadableSpan)]
        client.export(spans)
        mock_exporter.export.assert_called_once_with(spans)

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory):
        """An HTTP 405 reply to the HEAD probe counts as a reachable endpoint."""
        mock_response = MagicMock()
        mock_response.status_code = 405
        mock_head.return_value = mock_response
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        assert client.api_check() is True

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory):
        """An HTTP 500 reply to the HEAD probe makes the check fail."""
        mock_response = MagicMock()
        mock_response.status_code = 500
        mock_head.return_value = mock_response
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        assert client.api_check() is False

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory):
        """A transport error surfaces as ValueError carrying the original message."""
        mock_head.side_effect = httpx.RequestError("Connection error")
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"):
            client.api_check()

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_get_project_url(self, mock_exporter_class, trace_client_factory):
        """The project URL is the fixed ARMS console address."""
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm"

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_add_span(self, mock_exporter_class, trace_client_factory):
        """None is ignored; the worker is only notified once a full batch has queued."""
        client = trace_client_factory(
            service_name="test-service",
            endpoint="http://test-endpoint",
            max_export_batch_size=2,
        )
        # Test add None
        client.add_span(None)
        assert len(client.queue) == 0
        # Test add valid SpanData
        span_data = SpanData(
            name="test-span",
            trace_id=123,
            span_id=456,
            parent_span_id=None,
            start_time=1000,
            end_time=2000,
            status=Status(StatusCode.OK),
            span_kind=SpanKind.INTERNAL,
        )
        mock_span = MagicMock(spec=ReadableSpan)
        client.span_builder.build_span = MagicMock(return_value=mock_span)
        with patch.object(client.condition, "notify") as mock_notify:
            client.add_span(span_data)
            assert len(client.queue) == 1
            # Below batch size: no wake-up of the worker yet.
            mock_notify.assert_not_called()
            client.add_span(span_data)
            assert len(client.queue) == 2
            # Batch size reached: the worker is notified exactly once.
            mock_notify.assert_called_once()

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    @patch("core.ops.aliyun_trace.data_exporter.traceclient.logger")
    def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory):
        """When the bounded queue is full, extra spans are dropped with a warning."""
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1)
        span_data = SpanData(
            name="test-span",
            trace_id=123,
            span_id=456,
            parent_span_id=None,
            start_time=1000,
            end_time=2000,
            status=Status(StatusCode.OK),
            span_kind=SpanKind.INTERNAL,
        )
        mock_span = MagicMock(spec=ReadableSpan)
        client.span_builder.build_span = MagicMock(return_value=mock_span)
        client.add_span(span_data)
        assert len(client.queue) == 1
        client.add_span(span_data)
        # Second span did not fit; the queue length is unchanged.
        assert len(client.queue) == 1
        mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.")

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_export_batch_error(self, mock_exporter_class, trace_client_factory):
        """Exporter failures inside _export_batch are logged, not raised."""
        mock_exporter = mock_exporter_class.return_value
        mock_exporter.export.side_effect = Exception("Export failed")
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        mock_span = MagicMock(spec=ReadableSpan)
        client.queue.append(mock_span)
        with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger:
            client._export_batch()
            mock_logger.warning.assert_called()

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_worker_loop(self, mock_exporter_class, trace_client_factory):
        """Smoke-test the worker loop's wait/shutdown interaction."""
        # We need to test the wait timeout in _worker
        # But _worker runs in a thread. Let's mock condition.wait.
        client = trace_client_factory(
            service_name="test-service",
            endpoint="http://test-endpoint",
            schedule_delay_sec=0.1,
        )
        with patch.object(client.condition, "wait") as mock_wait:
            # Let it run for a bit then shut down
            time.sleep(0.2)
            client.shutdown()
        # mock_wait might have been called
        assert mock_wait.called or client.done

    @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
    def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory):
        """shutdown() drains the queue through export and shuts the exporter down."""
        mock_exporter = mock_exporter_class.return_value
        client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
        mock_span = MagicMock(spec=ReadableSpan)
        client.queue.append(mock_span)
        client.shutdown()
        # Should have called export twice (once in worker/export_batch, once in shutdown)
        # or at least once if worker was waiting
        assert mock_exporter.export.called
        assert mock_exporter.shutdown.called
class TestSpanBuilder:
    """SpanBuilder turns SpanData into a ReadableSpan, with or without a parent."""

    def test_build_span(self):
        """Every SpanData field is carried onto the built ReadableSpan."""
        mock_resource = MagicMock()
        data = SpanData(
            name="test-span",
            trace_id=123,
            span_id=456,
            parent_span_id=789,
            start_time=1000,
            end_time=2000,
            status=Status(StatusCode.OK),
            span_kind=SpanKind.INTERNAL,
            attributes={"attr1": "val1"},
            events=[],
            links=[],
        )
        result = SpanBuilder(mock_resource).build_span(data)
        assert isinstance(result, ReadableSpan)
        assert result.name == "test-span"
        assert result.context.trace_id == 123
        assert result.context.span_id == 456
        assert result.parent.span_id == 789
        assert result.resource == mock_resource
        assert result.attributes == {"attr1": "val1"}

    def test_build_span_no_parent(self):
        """A None parent_span_id yields a span without a parent context."""
        data = SpanData(
            name="test-span",
            trace_id=123,
            span_id=456,
            parent_span_id=None,
            start_time=1000,
            end_time=2000,
            status=Status(StatusCode.OK),
            span_kind=SpanKind.INTERNAL,
        )
        result = SpanBuilder(MagicMock()).build_span(data)
        assert result.parent is None
def test_create_link():
    """A 32-hex-char trace id yields a Link with the placeholder span id; junk raises."""
    hex_trace_id = "0123456789abcdef0123456789abcdef"
    result = create_link(hex_trace_id)
    assert result.context.trace_id == int(hex_trace_id, 16)
    assert result.context.span_id == INVALID_SPAN_ID
    # Non-hex input is rejected with a ValueError.
    with pytest.raises(ValueError, match="Invalid trace ID format"):
        create_link("invalid-hex")
def test_generate_span_id():
    """generate_span_id returns a non-sentinel int, re-drawing on INVALID_SPAN_ID."""
    first = generate_span_id()
    assert isinstance(first, int)
    assert first != INVALID_SPAN_ID
    with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_bits:
        # The first draw collides with the sentinel, forcing one retry.
        mock_bits.side_effect = [INVALID_SPAN_ID, 999]
        assert generate_span_id() == 999
        assert mock_bits.call_count == 2
def test_convert_to_trace_id():
    """A UUID string maps to its 128-bit integer; None and junk raise ValueError."""
    candidate = str(uuid.uuid4())
    assert convert_to_trace_id(candidate) == uuid.UUID(candidate).int
    with pytest.raises(ValueError, match="UUID cannot be None"):
        convert_to_trace_id(None)
    with pytest.raises(ValueError, match="Invalid UUID input"):
        convert_to_trace_id("not-a-uuid")
def test_convert_string_to_id():
    """A string yields a positive id; None falls back to a freshly generated span id."""
    assert convert_string_to_id("test") > 0
    with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_generate:
        mock_generate.return_value = 12345
        assert convert_string_to_id(None) == 12345
def test_convert_to_span_id():
    """A UUID string converts to an int span id; None and junk raise ValueError."""
    candidate = str(uuid.uuid4())
    assert isinstance(convert_to_span_id(candidate, "test-type"), int)
    with pytest.raises(ValueError, match="UUID cannot be None"):
        convert_to_span_id(None, "test")
    with pytest.raises(ValueError, match="Invalid UUID input"):
        convert_to_span_id("not-a-uuid", "test")
def test_convert_datetime_to_nanoseconds():
    """A datetime converts to epoch nanoseconds; None passes through as None."""
    moment = datetime(2023, 1, 1, 12, 0, 0)
    expected = int(moment.timestamp() * 1e9)
    assert convert_datetime_to_nanoseconds(moment) == expected
    assert convert_datetime_to_nanoseconds(None) is None
def test_build_endpoint():
    """aliyuncs.com log hosts get the v1 trace path; other hosts get the OTLP path."""
    key = "abc"
    # CMS 2.0 endpoint
    assert build_endpoint("https://log.aliyuncs.com", key) == "https://log.aliyuncs.com/adapt_abc/api/v1/traces"
    # XTrace endpoint
    assert build_endpoint("https://example.com", key) == "https://example.com/adapt_abc/api/otlp/traces"

View File

@@ -0,0 +1,88 @@
import pytest
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import SpanKind, Status, StatusCode
from pydantic import ValidationError
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
class TestTraceMetadata:
    """TraceMetadata should store every constructor argument unchanged."""

    def test_trace_metadata_init(self):
        span_links = [trace_api.Link(context=trace_api.SpanContext(0, 0, False))]
        metadata = TraceMetadata(
            trace_id=123,
            workflow_span_id=456,
            session_id="session_1",
            user_id="user_1",
            links=span_links,
        )
        assert metadata.trace_id == 123
        assert metadata.workflow_span_id == 456
        assert metadata.session_id == "session_1"
        assert metadata.user_id == "user_1"
        assert metadata.links == span_links
class TestSpanData:
    """SpanData construction: required fields, defaults, optional fields, and validation."""

    def test_span_data_init_required_fields(self):
        """Only the five required fields are needed; the rest take declared defaults."""
        data = SpanData(trace_id=123, span_id=456, name="test_span", start_time=1000, end_time=2000)
        assert data.trace_id == 123
        assert data.span_id == 456
        assert data.name == "test_span"
        assert data.start_time == 1000
        assert data.end_time == 2000
        # Defaults for everything not supplied.
        assert data.parent_span_id is None
        assert data.attributes == {}
        assert data.events == []
        assert data.links == []
        assert data.status.status_code == StatusCode.UNSET
        assert data.span_kind == SpanKind.INTERNAL

    def test_span_data_with_optional_fields(self):
        """Optional fields supplied at construction are stored verbatim."""
        sample_event = Event(name="event_1", timestamp=1500)
        sample_link = trace_api.Link(context=trace_api.SpanContext(0, 0, False))
        ok_status = Status(StatusCode.OK)
        data = SpanData(
            trace_id=123,
            parent_span_id=111,
            span_id=456,
            name="test_span",
            attributes={"key": "value"},
            events=[sample_event],
            links=[sample_link],
            status=ok_status,
            start_time=1000,
            end_time=2000,
            span_kind=SpanKind.SERVER,
        )
        assert data.parent_span_id == 111
        assert data.attributes == {"key": "value"}
        assert data.events == [sample_event]
        assert data.links == [sample_link]
        assert data.status.status_code == ok_status.status_code
        assert data.span_kind == SpanKind.SERVER

    def test_span_data_missing_required_fields(self):
        """Omitting a required field (span_id) fails pydantic validation."""
        with pytest.raises(ValidationError):
            SpanData(
                trace_id=123,
                name="test_span",
                start_time=1000,
                end_time=2000,
            )

    def test_span_data_arbitrary_types_allowed(self):
        # opentelemetry.trace.Status and Event are "arbitrary types" for Pydantic;
        # model_config must allow them to pass through unchanged.
        err_status = Status(StatusCode.ERROR, description="error occurred")
        exc_event = Event(name="exception", timestamp=1234, attributes={"exception.type": "ValueError"})
        data = SpanData(
            trace_id=123, span_id=456, name="test_span", status=err_status, events=[exc_event], start_time=1000, end_time=2000
        )
        assert data.status.status_code == err_status.status_code
        assert data.status.description == err_status.description
        assert data.events == [exc_event]

View File

@@ -0,0 +1,68 @@
from core.ops.aliyun_trace.entities.semconv import (
ACS_ARMS_SERVICE_FEATURE,
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
GEN_AI_INPUT_MESSAGE,
GEN_AI_OUTPUT_MESSAGE,
GEN_AI_PROMPT,
GEN_AI_PROVIDER_NAME,
GEN_AI_REQUEST_MODEL,
GEN_AI_RESPONSE_FINISH_REASON,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
GEN_AI_USAGE_INPUT_TOKENS,
GEN_AI_USAGE_OUTPUT_TOKENS,
GEN_AI_USAGE_TOTAL_TOKENS,
GEN_AI_USER_ID,
GEN_AI_USER_NAME,
INPUT_VALUE,
OUTPUT_VALUE,
RETRIEVAL_DOCUMENT,
RETRIEVAL_QUERY,
TOOL_DESCRIPTION,
TOOL_NAME,
TOOL_PARAMETERS,
GenAISpanKind,
)
def test_constants():
    """Every semconv attribute constant maps to its expected wire-format string."""
    expected_pairs = [
        (ACS_ARMS_SERVICE_FEATURE, "acs.arms.service.feature"),
        (GEN_AI_SESSION_ID, "gen_ai.session.id"),
        (GEN_AI_USER_ID, "gen_ai.user.id"),
        (GEN_AI_USER_NAME, "gen_ai.user.name"),
        (GEN_AI_SPAN_KIND, "gen_ai.span.kind"),
        (GEN_AI_FRAMEWORK, "gen_ai.framework"),
        (INPUT_VALUE, "input.value"),
        (OUTPUT_VALUE, "output.value"),
        (RETRIEVAL_QUERY, "retrieval.query"),
        (RETRIEVAL_DOCUMENT, "retrieval.document"),
        (GEN_AI_REQUEST_MODEL, "gen_ai.request.model"),
        (GEN_AI_PROVIDER_NAME, "gen_ai.provider.name"),
        (GEN_AI_USAGE_INPUT_TOKENS, "gen_ai.usage.input_tokens"),
        (GEN_AI_USAGE_OUTPUT_TOKENS, "gen_ai.usage.output_tokens"),
        (GEN_AI_USAGE_TOTAL_TOKENS, "gen_ai.usage.total_tokens"),
        (GEN_AI_PROMPT, "gen_ai.prompt"),
        (GEN_AI_COMPLETION, "gen_ai.completion"),
        (GEN_AI_RESPONSE_FINISH_REASON, "gen_ai.response.finish_reason"),
        (GEN_AI_INPUT_MESSAGE, "gen_ai.input.messages"),
        (GEN_AI_OUTPUT_MESSAGE, "gen_ai.output.messages"),
        (TOOL_NAME, "tool.name"),
        (TOOL_DESCRIPTION, "tool.description"),
        (TOOL_PARAMETERS, "tool.parameters"),
    ]
    for constant, literal in expected_pairs:
        assert constant == literal
def test_gen_ai_span_kind_enum():
    """Each GenAISpanKind member compares equal to its upper-case name; eight members exist."""
    expected_names = ["CHAIN", "RETRIEVER", "RERANKER", "LLM", "EMBEDDING", "TOOL", "AGENT", "TASK"]
    for member_name in expected_names:
        assert getattr(GenAISpanKind, member_name) == member_name
    # Verify iteration works (covers the class definition)
    members = list(GenAISpanKind)
    assert len(members) == 8
    assert "LLM" in members

View File

@@ -0,0 +1,647 @@
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags
import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
from core.ops.aliyun_trace.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
GEN_AI_OUTPUT_MESSAGE,
GEN_AI_PROMPT,
GEN_AI_REQUEST_MODEL,
GEN_AI_RESPONSE_FINISH_REASON,
GEN_AI_USAGE_TOTAL_TOKENS,
RETRIEVAL_DOCUMENT,
RETRIEVAL_QUERY,
TOOL_DESCRIPTION,
TOOL_NAME,
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.entities.config_entity import AliyunConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
class RecordingTraceClient:
    """Stand-in for TraceClient that records spans in memory instead of exporting them."""

    def __init__(self, service_name: str = "service", endpoint: str = "endpoint"):
        self.service_name = service_name
        self.endpoint = endpoint
        # Spans handed to add_span, in call order; tests assert against this list.
        self.added_spans: list[object] = []

    def add_span(self, span) -> None:
        """Record the span instead of queueing it for export."""
        self.added_spans.append(span)

    def api_check(self) -> bool:
        # Always report the backend as reachable.
        return True

    def get_project_url(self) -> str:
        # Fixed sentinel URL the tests assert against.
        return "project-url"
def _dt() -> datetime:
    """Fixed UTC timestamp (2024-01-01T00:00:00Z) so trace times are deterministic."""
    return datetime(2024, 1, 1, tzinfo=UTC)
def _make_link(trace_id: int = 1, span_id: int = 2) -> Link:
    """Build a sampled, non-remote Link for the given trace/span ids."""
    ctx = SpanContext(
        trace_id=trace_id,
        span_id=span_id,
        is_remote=False,
        trace_flags=TraceFlags.SAMPLED,
    )
    return Link(ctx)
def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo:
    """WorkflowTraceInfo with complete defaults; keyword overrides replace them."""
    params = {
        "workflow_id": "workflow-id",
        "tenant_id": "tenant-id",
        "workflow_run_id": "00000000-0000-0000-0000-000000000001",
        "workflow_run_elapsed_time": 1.0,
        "workflow_run_status": "succeeded",
        "workflow_run_inputs": {"sys.query": "hello"},
        "workflow_run_outputs": {"answer": "world"},
        "workflow_run_version": "v1",
        "total_tokens": 1,
        "file_list": [],
        "query": "hello",
        "metadata": {"conversation_id": "conv", "user_id": "u", "app_id": "app"},
        "message_id": None,
        "start_time": _dt(),
        "end_time": _dt(),
        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
    }
    return WorkflowTraceInfo(**(params | overrides))
def _make_message_trace_info(**overrides) -> MessageTraceInfo:
    """MessageTraceInfo with complete defaults; keyword overrides replace them."""
    params = {
        "conversation_model": "chat",
        "message_tokens": 1,
        "answer_tokens": 2,
        "total_tokens": 3,
        "conversation_mode": "chat",
        "metadata": {"conversation_id": "conv", "ls_model_name": "m", "ls_provider": "p"},
        "message_id": "00000000-0000-0000-0000-000000000002",
        "message_data": SimpleNamespace(from_account_id="acc", from_end_user_id=None),
        "inputs": {"prompt": "hi"},
        "outputs": "ok",
        "start_time": _dt(),
        "end_time": _dt(),
        "error": None,
        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
    }
    return MessageTraceInfo(**(params | overrides))
def _make_dataset_retrieval_trace_info(**overrides) -> DatasetRetrievalTraceInfo:
    """DatasetRetrievalTraceInfo with complete defaults; keyword overrides replace them."""
    params = {
        "metadata": {"conversation_id": "conv", "user_id": "u"},
        "message_id": "00000000-0000-0000-0000-000000000003",
        "message_data": SimpleNamespace(),
        "inputs": "q",
        "documents": [SimpleNamespace()],
        "start_time": _dt(),
        "end_time": _dt(),
        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
    }
    return DatasetRetrievalTraceInfo(**(params | overrides))
def _make_tool_trace_info(**overrides) -> ToolTraceInfo:
    """ToolTraceInfo with complete defaults; keyword overrides replace them."""
    params = {
        "tool_name": "tool",
        "tool_inputs": {"x": 1},
        "tool_outputs": "out",
        "tool_config": {"desc": "d"},
        "tool_parameters": {},
        "time_cost": 0.1,
        "metadata": {"conversation_id": "conv", "user_id": "u"},
        "message_id": "00000000-0000-0000-0000-000000000004",
        "message_data": SimpleNamespace(),
        "inputs": {"i": "v"},
        "outputs": {"o": "v"},
        "start_time": _dt(),
        "end_time": _dt(),
        "error": None,
        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
    }
    return ToolTraceInfo(**(params | overrides))
def _make_suggested_question_trace_info(**overrides) -> SuggestedQuestionTraceInfo:
    """SuggestedQuestionTraceInfo with complete defaults; keyword overrides replace them."""
    params = {
        "suggested_question": ["q1", "q2"],
        "level": "info",
        "total_tokens": 1,
        "metadata": {"conversation_id": "conv", "user_id": "u", "ls_model_name": "m", "ls_provider": "p"},
        "message_id": "00000000-0000-0000-0000-000000000005",
        "inputs": {"i": 1},
        "start_time": _dt(),
        "end_time": _dt(),
        "error": None,
        "trace_id": "550e8400-e29b-41d4-a716-446655440000",
    }
    return SuggestedQuestionTraceInfo(**(params | overrides))
@pytest.fixture
def trace_instance(monkeypatch: pytest.MonkeyPatch) -> AliyunDataTrace:
    """AliyunDataTrace wired to RecordingTraceClient so no network or DB is touched."""
    # Bypass real endpoint construction and the real exporter-backed client.
    monkeypatch.setattr(aliyun_trace_module, "build_endpoint", lambda base_url, license_key: "built-endpoint")
    monkeypatch.setattr(aliyun_trace_module, "TraceClient", RecordingTraceClient)
    # Mock get_service_account_with_tenant to avoid DB errors
    monkeypatch.setattr(AliyunDataTrace, "get_service_account_with_tenant", lambda self, app_id: MagicMock())
    config = AliyunConfig(app_name="app", license_key="k", endpoint="https://example.com")
    trace = AliyunDataTrace(config)
    return trace
def test_init_builds_endpoint_and_client(monkeypatch: pytest.MonkeyPatch):
    """__init__ derives the endpoint from config and constructs the client with it."""
    endpoint_builder = MagicMock(return_value="built")
    client_cls = MagicMock()
    monkeypatch.setattr(aliyun_trace_module, "build_endpoint", endpoint_builder)
    monkeypatch.setattr(aliyun_trace_module, "TraceClient", client_cls)
    config = AliyunConfig(app_name="my-app", license_key="license", endpoint="https://example.com")
    instance = AliyunDataTrace(config)
    endpoint_builder.assert_called_once_with("https://example.com", "license")
    client_cls.assert_called_once_with(service_name="my-app", endpoint="built")
    assert instance.trace_config == config
def test_trace_dispatches_to_correct_methods(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """trace() routes each trace-info type to its handler; unsupported types are no-ops."""
    handler_names = (
        "workflow_trace",
        "message_trace",
        "suggested_question_trace",
        "dataset_retrieval_trace",
        "tool_trace",
    )
    handlers = {}
    for handler_name in handler_names:
        handlers[handler_name] = MagicMock()
        monkeypatch.setattr(trace_instance, handler_name, handlers[handler_name])
    dispatch_cases = [
        (_make_workflow_trace_info(), "workflow_trace"),
        (_make_message_trace_info(), "message_trace"),
        (_make_suggested_question_trace_info(), "suggested_question_trace"),
        (_make_dataset_retrieval_trace_info(), "dataset_retrieval_trace"),
        (_make_tool_trace_info(), "tool_trace"),
    ]
    for info, handler_name in dispatch_cases:
        trace_instance.trace(info)
        handlers[handler_name].assert_called_once()
    # Branches that do nothing but should be covered
    trace_instance.trace(ModerationTraceInfo(flagged=False, action="allow", preset_response="", query="", metadata={}))
    trace_instance.trace(GenerateNameTraceInfo(tenant_id="t", metadata={}))
def test_api_check_delegates(trace_instance: AliyunDataTrace):
    """api_check forwards the trace client's answer verbatim."""
    trace_instance.trace_client.api_check = MagicMock(return_value=False)
    result = trace_instance.api_check()
    assert result is False
def test_get_project_url_success(trace_instance: AliyunDataTrace):
    """The recording client's project URL is returned unchanged."""
    result = trace_instance.get_project_url()
    assert result == "project-url"
def test_get_project_url_error(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """Client failures are logged and re-raised as ValueError."""
    failing_call = MagicMock(side_effect=Exception("boom"))
    monkeypatch.setattr(trace_instance.trace_client, "get_project_url", failing_call)
    recording_logger = MagicMock()
    monkeypatch.setattr(aliyun_trace_module, "logger", recording_logger)
    with pytest.raises(ValueError, match=r"Aliyun get project url failed: boom"):
        trace_instance.get_project_url()
    recording_logger.info.assert_called()
def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """workflow_trace builds TraceMetadata via the converters and adds one span per node execution."""
    monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 111)
    monkeypatch.setattr(
        aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"workflow": 222}.get(span_type, 0)
    )
    monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
    add_workflow_span = MagicMock()
    get_workflow_node_executions = MagicMock(return_value=[MagicMock(), MagicMock()])
    build_workflow_node_span = MagicMock(side_effect=["span-1", "span-2"])
    monkeypatch.setattr(trace_instance, "add_workflow_span", add_workflow_span)
    monkeypatch.setattr(trace_instance, "get_workflow_node_executions", get_workflow_node_executions)
    monkeypatch.setattr(trace_instance, "build_workflow_node_span", build_workflow_node_span)
    trace_info = _make_workflow_trace_info(
        trace_id="abcd", metadata={"conversation_id": "c", "user_id": "u", "app_id": "app"}
    )
    trace_instance.workflow_trace(trace_info)
    add_workflow_span.assert_called_once()
    # TraceMetadata is the second positional argument passed to add_workflow_span.
    passed_trace_metadata = add_workflow_span.call_args.args[1]
    assert passed_trace_metadata.trace_id == 111
    assert passed_trace_metadata.workflow_span_id == 222
    assert passed_trace_metadata.session_id == "c"
    assert passed_trace_metadata.user_id == "u"
    assert passed_trace_metadata.links == []
    # One built span per node execution, forwarded to the trace client in order.
    assert trace_instance.trace_client.added_spans == ["span-1", "span-2"]
def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
    """Without message_data, message_trace emits no spans at all."""
    info = _make_message_trace_info(message_data=None)
    trace_instance.message_trace(info)
    assert trace_instance.trace_client.added_spans == []
def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """message_trace emits a parent message span plus a child llm span."""
    monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10)
    monkeypatch.setattr(
        aliyun_trace_module,
        "convert_to_span_id",
        lambda _, span_type: {"message": 20, "llm": 30}.get(span_type, 0),
    )
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    monkeypatch.setattr(aliyun_trace_module, "get_user_id_from_message_data", lambda _: "user")
    monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
    trace_info = _make_message_trace_info(
        metadata={"conversation_id": "conv", "ls_model_name": "model", "ls_provider": "provider"},
        message_tokens=7,
        answer_tokens=11,
        total_tokens=18,
        outputs="completion",
    )
    trace_instance.message_trace(trace_info)
    assert len(trace_instance.trace_client.added_spans) == 2
    message_span, llm_span = trace_instance.trace_client.added_spans
    # Parent message span: root of the trace, SERVER kind, CHAIN gen-ai span kind.
    assert message_span.name == "message"
    assert message_span.trace_id == 10
    assert message_span.parent_span_id is None
    assert message_span.span_id == 20
    assert message_span.span_kind == SpanKind.SERVER
    assert message_span.status == status
    assert message_span.attributes["gen_ai.span.kind"] == GenAISpanKind.CHAIN
    # Child llm span: parented under the message span, carries model/token attributes.
    assert llm_span.name == "llm"
    assert llm_span.parent_span_id == 20
    assert llm_span.span_id == 30
    assert llm_span.status == status
    assert llm_span.attributes[GEN_AI_REQUEST_MODEL] == "model"
    assert llm_span.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "18"
def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
    """Without message_data, dataset_retrieval_trace emits no spans at all."""
    info = _make_dataset_retrieval_trace_info(message_data=None)
    trace_instance.dataset_retrieval_trace(info)
    assert trace_instance.trace_client.added_spans == []
def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """dataset_retrieval_trace emits one span carrying the query and JSON-encoded documents."""
    monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 1)
    monkeypatch.setattr(
        aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 2}.get(span_type, 0)
    )
    monkeypatch.setattr(aliyun_trace_module, "generate_span_id", lambda: 3)
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
    monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}])
    trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query"))
    assert len(trace_instance.trace_client.added_spans) == 1
    span = trace_instance.trace_client.added_spans[0]
    assert span.name == "dataset_retrieval"
    assert span.attributes[RETRIEVAL_QUERY] == "query"
    # Extracted documents are serialized to a JSON string attribute.
    assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]'
def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace):
    """Without message_data, tool_trace emits no spans at all."""
    info = _make_tool_trace_info(message_data=None)
    trace_instance.tool_trace(info)
    assert trace_instance.trace_client.added_spans == []
def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """tool_trace emits one span named after the tool with name/description attributes."""
    monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10)
    monkeypatch.setattr(
        aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 20}.get(span_type, 0)
    )
    monkeypatch.setattr(aliyun_trace_module, "generate_span_id", lambda: 30)
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
    trace_instance.tool_trace(
        _make_tool_trace_info(
            tool_name="my-tool",
            tool_inputs={"a": 1},
            tool_config={"description": "x"},
            inputs={"i": 1},
        )
    )
    assert len(trace_instance.trace_client.added_spans) == 1
    span = trace_instance.trace_client.added_spans[0]
    assert span.name == "my-tool"
    assert span.status == status
    assert span.attributes[TOOL_NAME] == "my-tool"
    # tool_config is serialized to JSON for the description attribute.
    assert span.attributes[TOOL_DESCRIPTION] == '{"description": "x"}'
def test_get_workflow_node_executions_requires_app_id(trace_instance: AliyunDataTrace):
    """Metadata lacking an app_id is rejected with a ValueError."""
    info = _make_workflow_trace_info(metadata={"conversation_id": "c"})
    with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
        trace_instance.get_workflow_node_executions(info)
def test_get_workflow_node_executions_builds_repo_and_fetches(
    trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch
):
    """Node executions are fetched through a repository built by the factory."""
    trace_info = _make_workflow_trace_info(metadata={"app_id": "app", "conversation_id": "c", "user_id": "u"})
    account = object()
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", MagicMock(return_value=account))
    # Stub out the DB session plumbing the repository factory needs.
    monkeypatch.setattr(aliyun_trace_module, "sessionmaker", MagicMock())
    monkeypatch.setattr(aliyun_trace_module, "db", SimpleNamespace(engine="engine"))
    repo = MagicMock()
    repo.get_by_workflow_run.return_value = ["node1"]
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr(aliyun_trace_module, "DifyCoreRepositoryFactory", mock_factory)
    result = trace_instance.get_workflow_node_executions(trace_info)
    assert result == ["node1"]
    repo.get_by_workflow_run.assert_called_once_with(workflow_run_id=trace_info.workflow_run_id)
def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """LLM nodes are routed to build_workflow_llm_span."""
    execution = MagicMock(spec=WorkflowNodeExecution)
    execution.node_type = NodeType.LLM
    monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm"))
    result = trace_instance.build_workflow_node_span(execution, _make_workflow_trace_info(), MagicMock())
    assert result == "llm"
def test_build_workflow_node_span_routes_knowledge_retrieval_type(
    trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch
):
    """Knowledge-retrieval nodes are routed to build_workflow_retrieval_span."""
    execution = MagicMock(spec=WorkflowNodeExecution)
    execution.node_type = NodeType.KNOWLEDGE_RETRIEVAL
    monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval"))
    result = trace_instance.build_workflow_node_span(execution, _make_workflow_trace_info(), MagicMock())
    assert result == "retrieval"
def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """Tool nodes are routed to build_workflow_tool_span."""
    execution = MagicMock(spec=WorkflowNodeExecution)
    execution.node_type = NodeType.TOOL
    monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool"))
    result = trace_instance.build_workflow_node_span(execution, _make_workflow_trace_info(), MagicMock())
    assert result == "tool"
def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """A code node execution falls through to build_workflow_task_span."""
    task_builder = MagicMock(return_value="task")
    monkeypatch.setattr(trace_instance, "build_workflow_task_span", task_builder)
    execution = MagicMock(spec=WorkflowNodeExecution)
    execution.node_type = NodeType.CODE
    result = trace_instance.build_workflow_node_span(execution, _make_workflow_trace_info(), MagicMock())
    assert result == "task"
def test_build_workflow_node_span_handles_errors(
    trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
):
    """A builder that raises is swallowed: the router returns None and logs the failure."""
    failing_builder = MagicMock(side_effect=RuntimeError("boom"))
    monkeypatch.setattr(trace_instance, "build_workflow_task_span", failing_builder)
    execution = MagicMock(spec=WorkflowNodeExecution)
    execution.node_type = NodeType.CODE
    result = trace_instance.build_workflow_node_span(execution, _make_workflow_trace_info(), MagicMock())
    assert result is None
    assert "Error occurred in build_workflow_node_span" in caplog.text
def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """A generic node execution becomes a TASK span carrying the stubbed ids, times and status."""
    # Pin the id/time/status helpers so the produced span fields are deterministic.
    monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9)
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
    node_execution = MagicMock(spec=WorkflowNodeExecution)
    node_execution.id = "node-id"
    node_execution.title = "title"
    node_execution.inputs = {"a": 1}
    node_execution.outputs = {"b": 2}
    node_execution.created_at = _dt()
    node_execution.finished_at = _dt()
    span = trace_instance.build_workflow_task_span(_make_workflow_trace_info(), node_execution, trace_metadata)
    assert span.trace_id == 1
    assert span.span_id == 9
    assert span.status.status_code == StatusCode.OK
    assert span.attributes["gen_ai.span.kind"] == GenAISpanKind.TASK
def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """A tool node maps title/metadata/inputs onto the tool span attributes; None metadata/inputs become '{}'."""
    # Pin the id/time/status helpers so the produced span fields are deterministic.
    monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9)
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()])
    node_execution = MagicMock(spec=WorkflowNodeExecution)
    node_execution.id = "node-id"
    node_execution.title = "my-tool"
    node_execution.inputs = {"a": 1}
    node_execution.outputs = {"b": 2}
    node_execution.created_at = _dt()
    node_execution.finished_at = _dt()
    node_execution.metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"k": "v"}}
    span = trace_instance.build_workflow_tool_span(_make_workflow_trace_info(), node_execution, trace_metadata)
    assert span.attributes[TOOL_NAME] == "my-tool"
    assert span.attributes[TOOL_DESCRIPTION] == '{"k": "v"}'
    assert span.attributes[TOOL_PARAMETERS] == '{"a": 1}'
    assert span.status.status_code == StatusCode.OK
    # Cover metadata is None and inputs is None
    node_execution.metadata = None
    node_execution.inputs = None
    span2 = trace_instance.build_workflow_tool_span(_make_workflow_trace_info(), node_execution, trace_metadata)
    assert span2.attributes[TOOL_DESCRIPTION] == "{}"
    assert span2.attributes[TOOL_PARAMETERS] == "{}"
def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """A retrieval node maps the query and formatted documents onto the retrieval span attributes."""
    # Pin the id/time/status helpers and the document formatter for determinism.
    monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9)
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
    monkeypatch.setattr(
        aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else []
    )
    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
    node_execution = MagicMock(spec=WorkflowNodeExecution)
    node_execution.id = "node-id"
    node_execution.title = "retrieval"
    node_execution.inputs = {"query": "q"}
    node_execution.outputs = {"result": [{"doc": "d"}]}
    node_execution.created_at = _dt()
    node_execution.finished_at = _dt()
    span = trace_instance.build_workflow_retrieval_span(_make_workflow_trace_info(), node_execution, trace_metadata)
    assert span.attributes[RETRIEVAL_QUERY] == "q"
    assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"formatted": true}]'
    # Cover empty inputs/outputs
    node_execution.inputs = None
    node_execution.outputs = None
    span2 = trace_instance.build_workflow_retrieval_span(_make_workflow_trace_info(), node_execution, trace_metadata)
    assert span2.attributes[RETRIEVAL_QUERY] == ""
    assert span2.attributes[RETRIEVAL_DOCUMENT] == "[]"
def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """An LLM node maps usage/model/prompt/output data onto GenAI span attributes.

    Usage is read from process_data first; if absent there, it falls back to outputs.
    """
    # Pin the id/time/status helpers and message formatters for determinism.
    monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9)
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status)
    monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in")
    monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out")
    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
    node_execution = MagicMock(spec=WorkflowNodeExecution)
    node_execution.id = "node-id"
    node_execution.title = "llm"
    node_execution.process_data = {
        "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
        "prompts": ["p"],
        "model_name": "m",
        "model_provider": "p1",
    }
    node_execution.outputs = {"text": "t", "finish_reason": "stop"}
    node_execution.created_at = _dt()
    node_execution.finished_at = _dt()
    span = trace_instance.build_workflow_llm_span(_make_workflow_trace_info(), node_execution, trace_metadata)
    assert span.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "3"
    assert span.attributes[GEN_AI_REQUEST_MODEL] == "m"
    assert span.attributes[GEN_AI_PROMPT] == '["p"]'
    assert span.attributes[GEN_AI_COMPLETION] == "t"
    assert span.attributes[GEN_AI_RESPONSE_FINISH_REASON] == "stop"
    assert span.attributes[GEN_AI_INPUT_MESSAGE] == "in"
    assert span.attributes[GEN_AI_OUTPUT_MESSAGE] == "out"
    # Cover usage from outputs if not in process_data
    node_execution.process_data = {"prompts": []}
    node_execution.outputs = {"usage": {"total_tokens": 10}, "text": ""}
    span2 = trace_instance.build_workflow_llm_span(_make_workflow_trace_info(), node_execution, trace_metadata)
    assert span2.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "10"
def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """add_workflow_span emits message+workflow spans when a message id exists, else a lone workflow span.

    With a message id the message span is the SERVER root and the workflow span is an
    INTERNAL child of it; without one the workflow span itself becomes the SERVER root.
    """
    # convert_to_span_id returns 20 only for the "message" span type, 0 otherwise.
    monkeypatch.setattr(
        aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 20}.get(span_type, 0)
    )
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
    trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[])
    # CASE 1: With message_id
    trace_info = _make_workflow_trace_info(
        message_id="msg-1", workflow_run_inputs={"sys.query": "hi"}, workflow_run_outputs={"ans": "ok"}
    )
    trace_instance.add_workflow_span(trace_info, trace_metadata)
    assert len(trace_instance.trace_client.added_spans) == 2
    message_span = trace_instance.trace_client.added_spans[0]
    workflow_span = trace_instance.trace_client.added_spans[1]
    assert message_span.name == "message"
    assert message_span.span_kind == SpanKind.SERVER
    assert message_span.parent_span_id is None
    assert workflow_span.name == "workflow"
    assert workflow_span.span_kind == SpanKind.INTERNAL
    assert workflow_span.parent_span_id == 20
    trace_instance.trace_client.added_spans.clear()
    # CASE 2: Without message_id
    trace_info_no_msg = _make_workflow_trace_info(message_id=None)
    trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata)
    assert len(trace_instance.trace_client.added_spans) == 1
    span = trace_instance.trace_client.added_spans[0]
    assert span.name == "workflow"
    assert span.span_kind == SpanKind.SERVER
    assert span.parent_span_id is None
def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch):
    """suggested_question_trace emits one span whose completion carries the JSON question list."""
    # Pin all id/time/link/status helpers so the emitted span is deterministic.
    monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10)
    monkeypatch.setattr(
        aliyun_trace_module,
        "convert_to_span_id",
        lambda _, span_type: {"message": 20, "suggested_question": 21}.get(span_type, 0),
    )
    monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123)
    monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: [])
    status = Status(StatusCode.OK)
    monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status)
    trace_info = _make_suggested_question_trace_info(suggested_question=["how?"])
    trace_instance.suggested_question_trace(trace_info)
    assert len(trace_instance.trace_client.added_spans) == 1
    span = trace_instance.trace_client.added_spans[0]
    assert span.name == "suggested_question"
    assert span.attributes[GEN_AI_COMPLETION] == '["how?"]'

View File

@@ -0,0 +1,275 @@
import json
from unittest.mock import MagicMock
from opentelemetry.trace import Link, StatusCode
from core.ops.aliyun_trace.entities.semconv import (
GEN_AI_FRAMEWORK,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
GEN_AI_USER_ID,
INPUT_VALUE,
OUTPUT_VALUE,
)
from core.ops.aliyun_trace.utils import (
create_common_span_attributes,
create_links_from_trace_id,
create_status_from_error,
extract_retrieval_documents,
format_input_messages,
format_output_messages,
format_retrieval_documents,
get_user_id_from_message_data,
get_workflow_node_status,
serialize_json_data,
)
from core.rag.models.document import Document
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import WorkflowNodeExecutionStatus
from models import EndUser
def test_get_user_id_from_message_data_no_end_user():
    """Fall back to the account id when the message carries no end-user id.

    Note: the `monkeypatch` fixture parameter was removed — this test patches
    nothing, and the unused fixture only obscured that fact.
    """
    message_data = MagicMock()
    message_data.from_account_id = "account_id"
    message_data.from_end_user_id = None
    assert get_user_id_from_message_data(message_data) == "account_id"
def test_get_user_id_from_message_data_with_end_user(monkeypatch):
    """When the referenced end user exists, its session_id is preferred over the account id."""
    message_data = MagicMock()
    message_data.from_account_id = "account_id"
    message_data.from_end_user_id = "end_user_id"
    end_user_data = MagicMock(spec=EndUser)
    end_user_data.session_id = "session_id"
    # Stub the SQLAlchemy chain db.session.query(...).where(...).first() -> end user.
    mock_query = MagicMock()
    mock_query.where.return_value.first.return_value = end_user_data
    mock_session = MagicMock()
    mock_session.query.return_value = mock_query
    from core.ops.aliyun_trace.utils import db

    monkeypatch.setattr(db, "session", mock_session)
    assert get_user_id_from_message_data(message_data) == "session_id"
def test_get_user_id_from_message_data_end_user_not_found(monkeypatch):
    """If the end-user lookup returns nothing, fall back to the account id."""
    message_data = MagicMock()
    message_data.from_account_id = "account_id"
    message_data.from_end_user_id = "end_user_id"
    # Stub the SQLAlchemy chain so the end-user query finds no row.
    mock_query = MagicMock()
    mock_query.where.return_value.first.return_value = None
    mock_session = MagicMock()
    mock_session.query.return_value = mock_query
    from core.ops.aliyun_trace.utils import db

    monkeypatch.setattr(db, "session", mock_session)
    assert get_user_id_from_message_data(message_data) == "account_id"
def test_create_status_from_error():
    """None maps to an OK status; an error string maps to ERROR carrying that description."""
    assert create_status_from_error(None).status_code == StatusCode.OK
    failure = create_status_from_error("some error")
    assert failure.status_code == StatusCode.ERROR
    assert failure.description == "some error"
def test_get_workflow_node_status():
    """Each workflow node state maps to the matching OTEL status code."""
    node_execution = MagicMock(spec=WorkflowNodeExecution)

    # Success maps to OK.
    node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
    assert get_workflow_node_status(node_execution).status_code == StatusCode.OK

    # Both failure-like states map to ERROR and carry the node's error message.
    for failed_state, message in (
        (WorkflowNodeExecutionStatus.FAILED, "node fail"),
        (WorkflowNodeExecutionStatus.EXCEPTION, "node exception"),
    ):
        node_execution.status = failed_state
        node_execution.error = message
        status = get_workflow_node_status(node_execution)
        assert status.status_code == StatusCode.ERROR
        assert status.description == message

    # Anything else (e.g. still running) is UNSET.
    node_execution.status = WorkflowNodeExecutionStatus.RUNNING
    assert get_workflow_node_status(node_execution).status_code == StatusCode.UNSET
def test_create_links_from_trace_id(monkeypatch):
    """A missing trace id yields no links; a present one yields exactly one created link."""
    import core.ops.aliyun_trace.data_exporter.traceclient

    stub_link = MagicMock(spec=Link)
    monkeypatch.setattr(
        core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: stub_link
    )
    assert create_links_from_trace_id(None) == []
    assert create_links_from_trace_id("trace_id") == [stub_link]
def test_extract_retrieval_documents():
    """Documents flatten to content/metadata/score dicts; absent metadata keys become None."""
    full_doc = MagicMock(spec=Document)
    full_doc.page_content = "content1"
    full_doc.metadata = {"dataset_id": "ds1", "doc_id": "di1", "document_id": "dd1", "score": 0.9}
    sparse_doc = MagicMock(spec=Document)
    sparse_doc.page_content = "content2"
    sparse_doc.metadata = {"dataset_id": "ds2"}  # most metadata keys absent
    first, second = extract_retrieval_documents([full_doc, sparse_doc])
    assert first["content"] == "content1"
    assert first["metadata"]["dataset_id"] == "ds1"
    assert first["score"] == 0.9
    assert second["content"] == "content2"
    assert second["metadata"]["dataset_id"] == "ds2"
    assert second["metadata"]["doc_id"] is None
    assert second["score"] is None
def test_serialize_json_data():
    """serialize_json_data mirrors json.dumps, with ensure_ascii defaulting to False."""
    payload = {"a": 1}
    assert serialize_json_data(payload) == json.dumps(payload, ensure_ascii=False)
    assert serialize_json_data(payload, ensure_ascii=True) == json.dumps(payload, ensure_ascii=True)
def test_create_common_span_attributes():
    """Every common value lands under its semconv attribute key."""
    attrs = create_common_span_attributes(
        session_id="s1", user_id="u1", span_kind="kind1", framework="fw1", inputs="in1", outputs="out1"
    )
    expected = {
        GEN_AI_SESSION_ID: "s1",
        GEN_AI_USER_ID: "u1",
        GEN_AI_SPAN_KIND: "kind1",
        GEN_AI_FRAMEWORK: "fw1",
        INPUT_VALUE: "in1",
        OUTPUT_VALUE: "out1",
    }
    for key, value in expected.items():
        assert attrs[key] == value
def test_format_retrieval_documents():
    """Dict entries are normalized into the document envelope; everything else is dropped."""
    # A non-list input is rejected outright.
    assert format_retrieval_documents("not a list") == []

    docs = [
        {"metadata": {"score": 0.8, "document_id": "doc1", "source": "src1"}, "content": "c1", "title": "t1"},
        {"metadata": {"_source": "src2", "doc_metadata": {"extra": "val"}}, "content": "c2"},  # no title
        "not a dict",  # skipped
    ]
    first, second = format_retrieval_documents(docs)
    assert first["document"]["content"] == "c1"
    assert first["document"]["metadata"]["title"] == "t1"
    assert first["document"]["metadata"]["source"] == "src1"
    assert first["document"]["score"] == 0.8
    assert first["document"]["id"] == "doc1"
    assert second["document"]["content"] == "c2"
    assert second["document"]["metadata"]["source"] == "src2"
    assert second["document"]["metadata"]["extra"] == "val"
    assert "title" not in second["document"]["metadata"]
    assert second["document"]["score"] == 0.0  # default score

    class ExplodingDict:
        """Raises from .get() to drive the function's blanket exception handler."""

        def get(self, *args, **kwargs):
            raise Exception("boom")

    # Any exception while formatting collapses the whole result to [].
    assert format_retrieval_documents([ExplodingDict()]) == []
def test_format_input_messages():
    """Only dict prompts with a known role and non-empty text become messages."""
    empty = serialize_json_data([])
    assert format_input_messages(None) == empty  # not a dict
    assert format_input_messages({}) == empty  # no prompts key
    process_data = {
        "prompts": [
            {"role": "user", "text": "hello"},
            {"role": "assistant", "text": "hi"},
            {"role": "system", "text": "be helpful"},
            {"role": "tool", "text": "result"},
            {"role": "invalid", "text": "skip me"},  # unknown role: skipped
            "not a dict",  # non-dict entry: skipped
            {"role": "user", "text": ""},  # empty text: skipped
        ]
    }
    messages = json.loads(format_input_messages(process_data))
    assert [m["role"] for m in messages] == ["user", "assistant", "system", "tool"]
    assert messages[0]["parts"][0]["content"] == "hello"
    # A prompt entry that blows up during formatting falls back to the empty list.
    assert format_input_messages({"prompts": [None]}) == empty
def test_format_output_messages():
    """Output text becomes a single assistant message; unknown finish reasons default to 'stop'."""
    empty = serialize_json_data([])
    assert format_output_messages(None) == empty  # not a dict
    assert format_output_messages({"finish_reason": "stop"}) == empty  # no text
    messages = json.loads(format_output_messages({"text": "done", "finish_reason": "length"}))
    assert len(messages) == 1
    assert messages[0]["role"] == "assistant"
    assert messages[0]["parts"][0]["content"] == "done"
    assert messages[0]["finish_reason"] == "length"
    fallback = json.loads(format_output_messages({"text": "done", "finish_reason": "unknown"}))
    assert fallback[0]["finish_reason"] == "stop"
    # A non-serializable payload trips the exception path and yields the empty list.
    assert format_output_messages({"text": MagicMock()}) == empty

View File

@@ -0,0 +1,398 @@
from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
from opentelemetry.sdk.trace import Tracer
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import StatusCode
from core.ops.arize_phoenix_trace.arize_phoenix_trace import (
ArizePhoenixDataTrace,
datetime_to_nanos,
error_to_string,
safe_json_dumps,
set_span_status,
setup_tracer,
wrap_span_metadata,
)
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
# --- Helpers ---
def _dt():
return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
def _make_workflow_info(**kwargs):
    """Build a WorkflowTraceInfo with sensible defaults; any kwarg overrides the matching field."""
    defaults = {
        "workflow_id": "w1",
        "tenant_id": "t1",
        "workflow_run_id": "r1",
        "workflow_run_elapsed_time": 1.0,
        "workflow_run_status": "succeeded",
        "workflow_run_inputs": {"in": "val"},
        "workflow_run_outputs": {"out": "val"},
        "workflow_run_version": "1.0",
        "total_tokens": 10,
        "file_list": ["f1"],
        "query": "hi",
        "metadata": {"app_id": "app1"},
        "start_time": _dt(),
        "end_time": _dt() + timedelta(seconds=1),
    }
    defaults.update(kwargs)
    return WorkflowTraceInfo(**defaults)
def _make_message_info(**kwargs):
    """Build a MessageTraceInfo with sensible defaults; any kwarg overrides the matching field."""
    defaults = {
        "conversation_model": "chat",
        "message_tokens": 5,
        "answer_tokens": 5,
        "total_tokens": 10,
        "conversation_mode": "chat",
        "metadata": {"app_id": "app1"},
        "inputs": {"in": "val"},
        "outputs": "val",
        "start_time": _dt(),
        "end_time": _dt(),
        "message_id": "m1",
    }
    defaults.update(kwargs)
    return MessageTraceInfo(**defaults)
# --- Utility Function Tests ---
def test_datetime_to_nanos():
    """A datetime converts to epoch nanoseconds; None falls back to the current time."""
    dt = _dt()
    expected = int(dt.timestamp() * 1_000_000_000)
    assert datetime_to_nanos(dt) == expected
    # The None branch consults the module's datetime.now(), so patch it for determinism.
    with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt:
        mock_now = MagicMock()
        mock_now.timestamp.return_value = 1704110400.0
        mock_dt.now.return_value = mock_now
        assert datetime_to_nanos(None) == 1704110400000000000
def test_error_to_string():
    """Exceptions render with their type and traceback; strings pass through; None gets a placeholder."""
    caught = None
    try:
        raise ValueError("boom")
    except ValueError as exc:
        caught = exc
    rendered = error_to_string(caught)
    assert "ValueError: boom" in rendered
    assert "traceback" in rendered.lower() or "line" in rendered.lower()
    assert error_to_string("str error") == "str error"
    assert error_to_string(None) == "Empty Stack Trace"
def test_set_span_status():
    """set_span_status maps None -> OK, exceptions -> ERROR + record_exception,
    strings -> ERROR + event, and falls back to repr() for errors with empty str()."""
    span = MagicMock()
    # OK
    set_span_status(span, None)
    span.set_status.assert_called()
    assert span.set_status.call_args[0][0].status_code == StatusCode.OK
    # Error Exception
    span.reset_mock()
    set_span_status(span, ValueError("fail"))
    assert span.set_status.call_args[0][0].status_code == StatusCode.ERROR
    span.record_exception.assert_called()
    # Error String
    span.reset_mock()
    set_span_status(span, "fail-str")
    assert span.set_status.call_args[0][0].status_code == StatusCode.ERROR
    span.add_event.assert_called()
    # repr branch
    class SilentError:
        # str() yields "" so the implementation must fall back to repr().
        def __str__(self):
            return ""

        def __repr__(self):
            return "SilentErrorRepr"

    span.reset_mock()
    set_span_status(span, SilentError())
    assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr"
def test_safe_json_dumps():
    """Datetimes are stringified instead of raising TypeError."""
    serialized = safe_json_dumps({"a": _dt()})
    assert serialized == '{"a": "2024-01-01 00:00:00+00:00"}'
def test_wrap_span_metadata():
    """Keyword overrides merge into the metadata and the Dify marker is stamped on."""
    wrapped = wrap_span_metadata({"a": 1}, b=2)
    assert wrapped == {"a": 1, "b": 2, "created_from": "Dify"}
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
def test_setup_tracer_arize(mock_provider, mock_exporter):
    """Arize configs use the gRPC exporter with '/v1' appended to the endpoint."""
    config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p")
    setup_tracer(config)
    mock_exporter.assert_called_once()
    assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1"
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
def test_setup_tracer_phoenix(mock_provider, mock_exporter):
    """Phoenix configs use the HTTP exporter with '/v1/traces' appended to the endpoint."""
    config = PhoenixConfig(endpoint="http://p.com", project="p")
    setup_tracer(config)
    mock_exporter.assert_called_once()
    assert mock_exporter.call_args[1]["endpoint"] == "http://p.com/v1/traces"
def test_setup_tracer_exception():
    """Errors raised while parsing the endpoint propagate out of setup_tracer."""
    config = ArizeConfig(endpoint="http://a.com", project="p")
    broken_urlparse = patch(
        "core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")
    )
    with broken_urlparse, pytest.raises(Exception, match="boom"):
        setup_tracer(config)
# --- ArizePhoenixDataTrace Class Tests ---
@pytest.fixture
def trace_instance():
    """An ArizePhoenixDataTrace wired to a mocked tracer/processor — no exporter, no network."""
    with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup:
        mock_tracer = MagicMock(spec=Tracer)
        mock_processor = MagicMock()
        mock_setup.return_value = (mock_tracer, mock_processor)
        config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p")
        return ArizePhoenixDataTrace(config)
def test_trace_dispatch(trace_instance):
    """trace() routes each trace-info type to its dedicated handler."""
    with (
        patch.object(trace_instance, "workflow_trace") as m1,
        patch.object(trace_instance, "message_trace") as m2,
        patch.object(trace_instance, "moderation_trace") as m3,
        patch.object(trace_instance, "suggested_question_trace") as m4,
        patch.object(trace_instance, "dataset_retrieval_trace") as m5,
        patch.object(trace_instance, "tool_trace") as m6,
        patch.object(trace_instance, "generate_name_trace") as m7,
    ):
        trace_instance.trace(_make_workflow_info())
        m1.assert_called()
        trace_instance.trace(_make_message_info())
        m2.assert_called()
        trace_instance.trace(ModerationTraceInfo(flagged=True, action="a", preset_response="p", query="q", metadata={}))
        m3.assert_called()
        trace_instance.trace(SuggestedQuestionTraceInfo(suggested_question=[], total_tokens=0, level="i", metadata={}))
        m4.assert_called()
        trace_instance.trace(DatasetRetrievalTraceInfo(metadata={}))
        m5.assert_called()
        trace_instance.trace(
            ToolTraceInfo(
                tool_name="t",
                tool_inputs={},
                tool_outputs="o",
                metadata={},
                tool_config={},
                time_cost=1,
                tool_parameters={},
            )
        )
        m6.assert_called()
        trace_instance.trace(GenerateNameTraceInfo(tenant_id="t", metadata={}))
        m7.assert_called()
def test_trace_exception(trace_instance):
    """Handler failures are not swallowed by trace()."""
    failing = patch.object(trace_instance, "workflow_trace", side_effect=RuntimeError("fail"))
    with failing, pytest.raises(RuntimeError):
        trace_instance.trace(_make_workflow_info())
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory")
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance):
    """A full workflow trace with one LLM node emits at least a workflow span and a node span."""
    mock_db.engine = MagicMock()
    info = _make_workflow_info()
    repo = MagicMock()
    mock_repo_factory.create_workflow_node_execution_repository.return_value = repo
    # One successful LLM node execution with usage/model details in process_data.
    node1 = MagicMock()
    node1.node_type = "llm"
    node1.status = "succeeded"
    node1.inputs = {"q": "hi"}
    node1.outputs = {"a": "bye", "usage": {"total_tokens": 5}}
    node1.created_at = _dt()
    node1.elapsed_time = 1.0
    node1.process_data = {
        "prompts": [{"role": "user", "content": "hi"}],
        "model_provider": "openai",
        "model_name": "gpt-4",
    }
    node1.metadata = {"k": "v"}
    node1.title = "title"
    node1.id = "n1"
    node1.error = None
    repo.get_by_workflow_run.return_value = [node1]
    with patch.object(trace_instance, "get_service_account_with_tenant"):
        trace_instance.workflow_trace(info)
    # At least the workflow span plus one node span are started.
    assert trace_instance.tracer.start_span.call_count >= 2
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
def test_workflow_trace_no_app_id(mock_db, trace_instance):
    """workflow_trace refuses to run without an app_id in the trace metadata."""
    mock_db.engine = MagicMock()
    trace_info = _make_workflow_info()
    trace_info.metadata = {}
    with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
        trace_instance.workflow_trace(trace_info)
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
def test_message_trace_success(mock_db, trace_instance):
    """A message trace with complete, error-free message data starts at least one span."""
    mock_db.engine = MagicMock()
    info = _make_message_info()
    # Stub the ORM message row that message_trace reads fields from.
    info.message_data = MagicMock()
    info.message_data.from_account_id = "acc1"
    info.message_data.from_end_user_id = None
    info.message_data.query = "q"
    info.message_data.answer = "a"
    info.message_data.status = "s"
    info.message_data.model_id = "m"
    info.message_data.model_provider = "p"
    info.message_data.message_metadata = "{}"
    info.message_data.error = None
    info.error = None
    trace_instance.message_trace(info)
    assert trace_instance.tracer.start_span.call_count >= 1
@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
def test_message_trace_with_error(mock_db, trace_instance):
    """A message trace still emits spans when the message and trace carry error strings."""
    mock_db.engine = MagicMock()
    info = _make_message_info()
    # Stub the ORM message row, this time with error fields populated.
    info.message_data = MagicMock()
    info.message_data.from_account_id = "acc1"
    info.message_data.from_end_user_id = None
    info.message_data.query = "q"
    info.message_data.answer = "a"
    info.message_data.status = "s"
    info.message_data.model_id = "m"
    info.message_data.model_provider = "p"
    info.message_data.message_metadata = "{}"
    info.message_data.error = "processing failed"
    info.error = "message error"
    trace_instance.message_trace(info)
    assert trace_instance.tracer.start_span.call_count >= 1
def test_trace_methods_return_early_with_no_message_data(trace_instance):
    """Every message-based handler is a no-op when message_data is absent."""
    info = MagicMock()
    info.message_data = None
    handlers = (
        trace_instance.moderation_trace,
        trace_instance.suggested_question_trace,
        trace_instance.dataset_retrieval_trace,
        trace_instance.tool_trace,
        trace_instance.generate_name_trace,
    )
    for handler in handlers:
        handler(info)
    assert trace_instance.tracer.start_span.call_count == 0
def test_moderation_trace_ok(trace_instance):
    """A moderation trace with message data emits at least one span."""
    trace_info = ModerationTraceInfo(flagged=True, action="a", preset_response="p", query="q", metadata={})
    message_data = MagicMock()
    message_data.error = None
    trace_info.message_data = message_data
    trace_instance.moderation_trace(trace_info)
    assert trace_instance.tracer.start_span.call_count >= 1
def test_suggested_question_trace_ok(trace_instance):
    """A suggested-question trace with message data emits at least one span."""
    trace_info = SuggestedQuestionTraceInfo(suggested_question=["?"], total_tokens=1, level="i", metadata={})
    trace_info.message_data = MagicMock()
    trace_info.error = None
    trace_instance.suggested_question_trace(trace_info)
    assert trace_instance.tracer.start_span.call_count >= 1
def test_dataset_retrieval_trace_ok(trace_instance):
    """A dataset-retrieval trace with message data emits at least one span."""
    trace_info = DatasetRetrievalTraceInfo(documents=[], metadata={})
    trace_info.message_data = MagicMock()
    trace_info.error = None
    trace_instance.dataset_retrieval_trace(trace_info)
    assert trace_instance.tracer.start_span.call_count >= 1
def test_tool_trace_ok(trace_instance):
    """A tool trace with message data emits at least one span."""
    trace_info = ToolTraceInfo(
        tool_name="t", tool_inputs={}, tool_outputs="o", metadata={}, tool_config={}, time_cost=1, tool_parameters={}
    )
    trace_info.message_data = MagicMock()
    trace_info.error = None
    trace_instance.tool_trace(trace_info)
    assert trace_instance.tracer.start_span.call_count >= 1
def test_generate_name_trace_ok(trace_instance):
    """A generate-name trace with message data emits at least one span."""
    trace_info = GenerateNameTraceInfo(tenant_id="t", metadata={})
    message_data = MagicMock()
    message_data.error = None
    trace_info.message_data = message_data
    trace_instance.generate_name_trace(trace_info)
    assert trace_instance.tracer.start_span.call_count >= 1
def test_get_project_url_phoenix(trace_instance):
    """Phoenix project URLs are rooted at the configured endpoint's /projects/ path."""
    trace_instance.arize_phoenix_config = PhoenixConfig(endpoint="http://p.com", project="p")
    project_url = trace_instance.get_project_url()
    assert "p.com/projects/" in project_url
def test_set_attribute_none_logic(trace_instance):
    """None-valued fields (role, tool-call id) are omitted from the LLM attribute dict."""
    # Test role can be None
    attrs = trace_instance._construct_llm_attributes([{"role": None, "content": "hi"}])
    assert "llm.input_messages.0.message.role" not in attrs
    # Test tool call id can be None
    tool_call_none_id = {"id": None, "function": {"name": "f1"}}
    attrs = trace_instance._construct_llm_attributes([{"role": "assistant", "tool_calls": [tool_call_none_id]}])
    assert "llm.input_messages.0.message.tool_calls.0.tool_call.id" not in attrs
def test_construct_llm_attributes_dict_branch(trace_instance):
    """A dict prompt is serialized as the content of a single user message."""
    attrs = trace_instance._construct_llm_attributes({"prompt": "hi"})
    assert attrs["llm.input_messages.0.message.role"] == "user"
    assert '"prompt": "hi"' in attrs["llm.input_messages.0.message.content"]
def test_api_check_success(trace_instance):
    """api_check reports success when the mocked tracer is wired up cleanly."""
    result = trace_instance.api_check()
    assert result is True
def test_ensure_root_span_basic(trace_instance):
    """ensure_root_span registers the trace id for later reuse."""
    trace_instance.ensure_root_span("tid")
    assert "tid" in trace_instance.dify_trace_ids

View File

@@ -0,0 +1,698 @@
import collections
import logging
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
LangfuseGeneration,
LangfuseSpan,
LangfuseTrace,
LevelEnum,
UnitEnum,
)
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from dify_graph.enums import NodeType
from models import EndUser
from models.enums import MessageStatus
def _dt() -> datetime:
return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
@pytest.fixture
def langfuse_config():
    """A LangfuseConfig pointing at the hosted cloud instance with dummy credentials."""
    return LangfuseConfig(
        public_key="pk-123",
        secret_key="sk-123",
        host="https://cloud.langfuse.com",
    )
@pytest.fixture
def trace_instance(langfuse_config, monkeypatch):
    """A LangFuseDataTrace whose Langfuse client is a MagicMock — no network calls."""
    # Mock Langfuse client to avoid network calls
    mock_client = MagicMock()
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
    instance = LangFuseDataTrace(langfuse_config)
    return instance
def test_init(langfuse_config, monkeypatch):
    """The constructor forwards credentials to the Langfuse client and reads FILES_URL from the env."""
    mock_langfuse = MagicMock()
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse)
    monkeypatch.setenv("FILES_URL", "http://test.url")
    instance = LangFuseDataTrace(langfuse_config)
    mock_langfuse.assert_called_once_with(
        public_key=langfuse_config.public_key,
        secret_key=langfuse_config.secret_key,
        host=langfuse_config.host,
    )
    assert instance.file_base_url == "http://test.url"
def test_trace_dispatch(trace_instance, monkeypatch):
    """Each TraceInfo subtype is routed to its dedicated handler method."""
    dispatch_table = [
        (WorkflowTraceInfo, "workflow_trace"),
        (MessageTraceInfo, "message_trace"),
        (ModerationTraceInfo, "moderation_trace"),
        (SuggestedQuestionTraceInfo, "suggested_question_trace"),
        (DatasetRetrievalTraceInfo, "dataset_retrieval_trace"),
        (ToolTraceInfo, "tool_trace"),
        (GenerateNameTraceInfo, "generate_name_trace"),
    ]
    handler_mocks = {handler: MagicMock() for _, handler in dispatch_table}
    for handler, handler_mock in handler_mocks.items():
        monkeypatch.setattr(trace_instance, handler, handler_mock)
    for info_cls, handler in dispatch_table:
        info = MagicMock(spec=info_cls)
        trace_instance.trace(info)
        handler_mocks[handler].assert_called_once_with(info)
def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
    """When message_id is set, workflow_trace emits a message-level trace, a
    workflow span, a generation for the LLM node, and an ERROR-level span for
    the failed non-LLM node."""
    # Setup trace info
    trace_info = WorkflowTraceInfo(
        workflow_id="wf-1",
        tenant_id="tenant-1",
        workflow_run_id="run-1",
        workflow_run_elapsed_time=1.0,
        workflow_run_status="succeeded",
        workflow_run_inputs={"input": "hi"},
        workflow_run_outputs={"output": "hello"},
        workflow_run_version="1.0",
        message_id="msg-1",
        conversation_id="conv-1",
        total_tokens=100,
        file_list=[],
        query="hi",
        start_time=_dt(),
        end_time=_dt() + timedelta(seconds=1),
        trace_id="trace-1",
        metadata={"app_id": "app-1", "user_id": "user-1"},
        workflow_app_log_id="log-1",
        error="",
    )
    # Mock DB and Repositories
    mock_session = MagicMock()
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session)
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
    # Mock node executions: one LLM node (should become a generation) ...
    node_llm = MagicMock()
    node_llm.id = "node-llm"
    node_llm.title = "LLM Node"
    node_llm.node_type = NodeType.LLM
    node_llm.status = "succeeded"
    node_llm.process_data = {
        "model_mode": "chat",
        "model_name": "gpt-4",
        "model_provider": "openai",
        "usage": {"prompt_tokens": 10, "completion_tokens": 20},
    }
    node_llm.inputs = {"prompts": "p"}
    node_llm.outputs = {"text": "t"}
    node_llm.created_at = _dt()
    node_llm.elapsed_time = 0.5
    node_llm.metadata = {"foo": "bar"}
    # ... and one failed non-LLM node (should become an ERROR-level span).
    node_other = MagicMock()
    node_other.id = "node-other"
    node_other.title = "Other Node"
    node_other.node_type = NodeType.CODE
    node_other.status = "failed"
    node_other.process_data = None
    node_other.inputs = {"code": "print"}
    node_other.outputs = {"result": "ok"}
    node_other.created_at = None  # Trigger datetime.now() branch
    node_other.elapsed_time = 0.2
    node_other.metadata = None
    repo = MagicMock()
    repo.get_by_workflow_run.return_value = [node_llm, node_other]
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
    # Track calls to add_trace, add_span, add_generation
    trace_instance.add_trace = MagicMock()
    trace_instance.add_span = MagicMock()
    trace_instance.add_generation = MagicMock()
    trace_instance.workflow_trace(trace_info)
    # Verify add_trace (Workflow Level)
    trace_instance.add_trace.assert_called_once()
    trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"]
    assert trace_data.id == "trace-1"
    assert trace_data.name == TraceTaskName.MESSAGE_TRACE
    assert "message" in trace_data.tags
    assert "workflow" in trace_data.tags
    # Verify add_span (Workflow Run Span)
    assert trace_instance.add_span.call_count >= 1
    # First span should be workflow run span because message_id is present
    workflow_span = trace_instance.add_span.call_args_list[0][1]["langfuse_span_data"]
    assert workflow_span.id == "run-1"
    assert workflow_span.name == TraceTaskName.WORKFLOW_TRACE
    # Verify Generation for LLM node
    trace_instance.add_generation.assert_called_once()
    gen_data = trace_instance.add_generation.call_args[1]["langfuse_generation_data"]
    assert gen_data.id == "node-llm"
    assert gen_data.usage.input == 10
    assert gen_data.usage.output == 20
    # Verify normal span for Other node
    # Second add_span call
    other_span = trace_instance.add_span.call_args_list[1][1]["langfuse_span_data"]
    assert other_span.id == "node-other"
    assert other_span.level == LevelEnum.ERROR
def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
    """Without a message_id the trace id falls back to workflow_run_id and the
    trace is named WORKFLOW_TRACE (no message-level wrapper)."""
    trace_info = WorkflowTraceInfo(
        workflow_id="wf-1",
        tenant_id="tenant-1",
        workflow_run_id="run-1",
        workflow_run_elapsed_time=1.0,
        workflow_run_status="succeeded",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_version="1.0",
        total_tokens=0,
        file_list=[],
        query="",
        message_id=None,
        conversation_id="conv-1",
        start_time=_dt(),
        end_time=_dt(),
        trace_id=None,  # Should fallback to workflow_run_id
        metadata={"app_id": "app-1"},
        workflow_app_log_id="log-1",
        error="",
    )
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
    repo = MagicMock()
    repo.get_by_workflow_run.return_value = []  # no node executions needed for this path
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
    trace_instance.add_trace = MagicMock()
    trace_instance.workflow_trace(trace_info)
    trace_instance.add_trace.assert_called_once()
    trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"]
    assert trace_data.id == "run-1"
    assert trace_data.name == TraceTaskName.WORKFLOW_TRACE
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
    """workflow_trace raises ValueError when trace metadata lacks an app_id."""
    trace_info = WorkflowTraceInfo(
        workflow_id="wf-1",
        tenant_id="tenant-1",
        workflow_run_id="run-1",
        workflow_run_elapsed_time=1.0,
        workflow_run_status="succeeded",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_version="1.0",
        total_tokens=0,
        file_list=[],
        query="",
        message_id=None,
        conversation_id="conv-1",
        start_time=_dt(),
        end_time=_dt(),
        metadata={},  # Missing app_id
        workflow_app_log_id="log-1",
        error="",
    )
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
    with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
        trace_instance.workflow_trace(trace_info)
def test_message_trace_basic(trace_instance, monkeypatch):
    """message_trace emits one trace plus one 'llm' generation carrying the
    aggregated token usage."""
    message_data = MagicMock()
    message_data.id = "msg-1"
    message_data.from_account_id = "acc-1"
    message_data.from_end_user_id = None  # account-originated message, no EndUser lookup
    message_data.provider_response_latency = 0.5
    message_data.conversation_id = "conv-1"
    message_data.total_price = 0.01
    message_data.model_id = "gpt-4"
    message_data.answer = "hello"
    message_data.status = MessageStatus.NORMAL
    message_data.error = None
    trace_info = MessageTraceInfo(
        message_id="msg-1",
        message_data=message_data,
        inputs={"query": "hi"},
        outputs={"answer": "hello"},
        message_tokens=10,
        answer_tokens=20,
        total_tokens=30,
        start_time=_dt(),
        end_time=_dt() + timedelta(seconds=1),
        trace_id="trace-1",
        metadata={"foo": "bar"},
        conversation_mode="chat",
        conversation_model="gpt-4",
        file_list=[],
        error=None,
    )
    trace_instance.add_trace = MagicMock()
    trace_instance.add_generation = MagicMock()
    trace_instance.message_trace(trace_info)
    trace_instance.add_trace.assert_called_once()
    trace_instance.add_generation.assert_called_once()
    gen_data = trace_instance.add_generation.call_args[0][0]
    assert gen_data.name == "llm"
    assert gen_data.usage.total == 30
def test_message_trace_with_end_user(trace_instance, monkeypatch):
    """For end-user-originated messages the trace user_id is resolved to the
    EndUser's session_id via a DB lookup."""
    message_data = MagicMock()
    message_data.id = "msg-1"
    message_data.from_account_id = "acc-1"
    message_data.from_end_user_id = "end-user-1"  # triggers the EndUser lookup branch
    message_data.conversation_id = "conv-1"
    message_data.status = MessageStatus.NORMAL
    message_data.model_id = "gpt-4"
    message_data.error = ""
    message_data.answer = "hello"
    message_data.total_price = 0.0
    message_data.provider_response_latency = 0.1
    trace_info = MessageTraceInfo(
        message_id="msg-1",
        message_data=message_data,
        inputs={},
        outputs={},
        message_tokens=0,
        answer_tokens=0,
        total_tokens=0,
        start_time=_dt(),
        end_time=_dt(),
        metadata={},
        conversation_mode="chat",
        conversation_model="gpt-4",
        file_list=[],
        error=None,
    )
    # Mock DB session for EndUser lookup
    mock_end_user = MagicMock(spec=EndUser)
    mock_end_user.session_id = "session-id-123"
    mock_query = MagicMock()
    mock_query.where.return_value.first.return_value = mock_end_user
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.query", lambda model: mock_query)
    trace_instance.add_trace = MagicMock()
    trace_instance.add_generation = MagicMock()
    trace_instance.message_trace(trace_info)
    trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"]
    assert trace_data.user_id == "session-id-123"
    assert trace_data.metadata["user_id"] == "session-id-123"
def test_message_trace_none_data(trace_instance):
    """message_trace is a no-op when message_data is None."""
    trace_instance.add_trace = MagicMock()
    info = SimpleNamespace(message_data=None, file_list=[], metadata={})
    trace_instance.message_trace(info)
    trace_instance.add_trace.assert_not_called()
def test_moderation_trace(trace_instance):
    """moderation_trace emits a span whose output carries the flagged verdict."""
    message_data = MagicMock()
    message_data.created_at = _dt()
    trace_info = ModerationTraceInfo(
        message_id="msg-1",
        message_data=message_data,
        inputs={"q": "hi"},
        action="stop",
        flagged=True,
        preset_response="blocked",
        start_time=None,
        end_time=None,
        metadata={"foo": "bar"},
        trace_id="trace-1",
        query="hi",
    )
    trace_instance.add_span = MagicMock()
    trace_instance.moderation_trace(trace_info)
    trace_instance.add_span.assert_called_once()
    span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"]
    assert span_data.name == TraceTaskName.MODERATION_TRACE
    assert span_data.output["flagged"] is True
def test_suggested_question_trace(trace_instance):
    """suggested_question_trace emits a generation with CHARACTERS usage unit."""
    message_data = MagicMock()
    message_data.status = MessageStatus.NORMAL
    message_data.error = None
    trace_info = SuggestedQuestionTraceInfo(
        message_id="msg-1",
        message_data=message_data,
        inputs="hi",
        suggested_question=["q1"],
        total_tokens=10,
        level="info",
        start_time=_dt(),
        end_time=_dt(),
        metadata={},
        trace_id="trace-1",
    )
    trace_instance.add_generation = MagicMock()
    trace_instance.suggested_question_trace(trace_info)
    trace_instance.add_generation.assert_called_once()
    gen_data = trace_instance.add_generation.call_args[1]["langfuse_generation_data"]
    assert gen_data.name == TraceTaskName.SUGGESTED_QUESTION_TRACE
    assert gen_data.usage.unit == UnitEnum.CHARACTERS
def test_dataset_retrieval_trace(trace_instance):
    """dataset_retrieval_trace emits a span with the retrieved documents as output."""
    message_data = MagicMock()
    message_data.created_at = _dt()
    message_data.updated_at = _dt()
    trace_info = DatasetRetrievalTraceInfo(
        message_id="msg-1",
        message_data=message_data,
        inputs="query",
        documents=[{"id": "doc1"}],
        start_time=None,
        end_time=None,
        metadata={},
        trace_id="trace-1",
    )
    trace_instance.add_span = MagicMock()
    trace_instance.dataset_retrieval_trace(trace_info)
    trace_instance.add_span.assert_called_once()
    span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"]
    assert span_data.name == TraceTaskName.DATASET_RETRIEVAL_TRACE
    assert span_data.output["documents"] == [{"id": "doc1"}]
def test_tool_trace(trace_instance):
    """tool_trace emits a span named after the tool; an error makes it ERROR level."""
    trace_info = ToolTraceInfo(
        message_id="msg-1",
        message_data=MagicMock(),
        inputs={},
        outputs={},
        tool_name="my_tool",
        tool_inputs={"a": 1},
        tool_outputs="result_string",
        time_cost=0.1,
        start_time=_dt(),
        end_time=_dt(),
        metadata={},
        trace_id="trace-1",
        tool_config={},
        tool_parameters={},
        error="some error",  # non-empty error should downgrade the span level
    )
    trace_instance.add_span = MagicMock()
    trace_instance.tool_trace(trace_info)
    trace_instance.add_span.assert_called_once()
    span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"]
    assert span_data.name == "my_tool"
    assert span_data.level == LevelEnum.ERROR
def test_generate_name_trace(trace_instance):
    """generate_name_trace emits a trace keyed by tenant plus a span keyed by
    conversation id."""
    trace_info = GenerateNameTraceInfo(
        inputs={"q": "hi"},
        outputs={"name": "new"},
        tenant_id="tenant-1",
        conversation_id="conv-1",
        start_time=_dt(),
        end_time=_dt(),
        metadata={"m": 1},
    )
    trace_instance.add_trace = MagicMock()
    trace_instance.add_span = MagicMock()
    trace_instance.generate_name_trace(trace_info)
    trace_instance.add_trace.assert_called_once()
    trace_instance.add_span.assert_called_once()
    trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"]
    assert trace_data.name == TraceTaskName.GENERATE_NAME_TRACE
    assert trace_data.user_id == "tenant-1"
    span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"]
    assert span_data.trace_id == "conv-1"
def test_add_trace_success(trace_instance):
    """add_trace forwards the trace payload to the Langfuse client."""
    trace_payload = LangfuseTrace(id="t1", name="trace")
    trace_instance.add_trace(trace_payload)
    trace_instance.langfuse_client.trace.assert_called_once()
def test_add_trace_error(trace_instance):
    """A client failure in trace() is wrapped into a ValueError."""
    trace_instance.langfuse_client.trace.side_effect = Exception("error")
    trace_payload = LangfuseTrace(id="t1", name="trace")
    with pytest.raises(ValueError, match="LangFuse Failed to create trace: error"):
        trace_instance.add_trace(trace_payload)
def test_add_span_success(trace_instance):
    """add_span forwards the span payload to the Langfuse client."""
    span_payload = LangfuseSpan(id="s1", name="span", trace_id="t1")
    trace_instance.add_span(span_payload)
    trace_instance.langfuse_client.span.assert_called_once()
def test_add_span_error(trace_instance):
    """A client failure in span() is wrapped into a ValueError."""
    trace_instance.langfuse_client.span.side_effect = Exception("error")
    span_payload = LangfuseSpan(id="s1", name="span", trace_id="t1")
    with pytest.raises(ValueError, match="LangFuse Failed to create span: error"):
        trace_instance.add_span(span_payload)
def test_update_span(trace_instance):
    """update_span ends the given span object."""
    mock_span = MagicMock()
    span_payload = LangfuseSpan(id="s1", name="span", trace_id="t1")
    trace_instance.update_span(mock_span, span_payload)
    mock_span.end.assert_called_once()
def test_add_generation_success(trace_instance):
    """add_generation forwards the generation payload to the Langfuse client."""
    generation_payload = LangfuseGeneration(id="g1", name="gen", trace_id="t1")
    trace_instance.add_generation(generation_payload)
    trace_instance.langfuse_client.generation.assert_called_once()
def test_add_generation_error(trace_instance):
    """A client failure in generation() is wrapped into a ValueError."""
    trace_instance.langfuse_client.generation.side_effect = Exception("error")
    generation_payload = LangfuseGeneration(id="g1", name="gen", trace_id="t1")
    with pytest.raises(ValueError, match="LangFuse Failed to create generation: error"):
        trace_instance.add_generation(generation_payload)
def test_update_generation(trace_instance):
    """update_generation ends the given generation object."""
    mock_generation = MagicMock()
    generation_payload = LangfuseGeneration(id="g1", name="gen", trace_id="t1")
    trace_instance.update_generation(mock_generation, generation_payload)
    mock_generation.end.assert_called_once()
def test_api_check_success(trace_instance):
    """api_check returns True when the client's auth check passes."""
    trace_instance.langfuse_client.auth_check.return_value = True
    result = trace_instance.api_check()
    assert result is True
def test_api_check_error(trace_instance):
    """An auth-check failure is wrapped into a ValueError."""
    trace_instance.langfuse_client.auth_check.side_effect = Exception("fail")
    with pytest.raises(ValueError, match="LangFuse API check failed: fail"):
        trace_instance.api_check()
def test_get_project_key_success(trace_instance):
    """get_project_key returns the id of the first project entry."""
    project = MagicMock()
    project.id = "proj-1"
    projects_response = MagicMock(data=[project])
    trace_instance.langfuse_client.client.projects.get.return_value = projects_response
    assert trace_instance.get_project_key() == "proj-1"
def test_get_project_key_error(trace_instance):
    """A projects.get failure is wrapped into a ValueError."""
    trace_instance.langfuse_client.client.projects.get.side_effect = Exception("fail")
    with pytest.raises(ValueError, match="LangFuse get project key failed: fail"):
        trace_instance.get_project_key()
def test_moderation_trace_none(trace_instance):
    """No span is emitted when the moderation trace has no message data."""
    trace_instance.add_span = MagicMock()
    info = ModerationTraceInfo(
        message_id="m",
        message_data=None,
        inputs={},
        action="s",
        flagged=False,
        preset_response="",
        query="",
        metadata={},
    )
    trace_instance.moderation_trace(info)
    trace_instance.add_span.assert_not_called()
def test_suggested_question_trace_none(trace_instance):
    """No generation is emitted when message_data is absent."""
    trace_instance.add_generation = MagicMock()
    info = SuggestedQuestionTraceInfo(
        message_id="m", message_data=None, inputs={}, suggested_question=[], total_tokens=0, level="i", metadata={}
    )
    trace_instance.suggested_question_trace(info)
    trace_instance.add_generation.assert_not_called()
def test_dataset_retrieval_trace_none(trace_instance):
    """No span is emitted when the retrieval trace has no message data."""
    trace_instance.add_span = MagicMock()
    info = DatasetRetrievalTraceInfo(message_id="m", message_data=None, inputs={}, documents=[], metadata={})
    trace_instance.dataset_retrieval_trace(info)
    trace_instance.add_span.assert_not_called()
def test_langfuse_trace_entity_with_list_dict_input():
    """A list-of-dict input has its 'text' keys rewritten to 'content'.

    Exercises the replace_text_with_content normalization applied by the
    LangfuseTrace entity's input validator.
    """
    trace = LangfuseTrace(id="t1", name="n", input=[{"text": "hello"}])
    assert isinstance(trace.input, list)
    assert trace.input[0]["content"] == "hello"
def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypatch, caplog):
    """A failure while extracting LLM usage is logged but does not stop the
    generation from being emitted."""
    # Setup trace info to trigger LLM node usage extraction
    trace_info = WorkflowTraceInfo(
        workflow_id="wf-1",
        tenant_id="t",
        workflow_run_id="r",
        workflow_run_elapsed_time=1.0,
        workflow_run_status="s",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_version="1",
        total_tokens=0,
        file_list=[],
        query="",
        message_id=None,
        conversation_id="c",
        start_time=_dt(),
        end_time=_dt(),
        metadata={"app_id": "app-1"},
        workflow_app_log_id="l",
        error="",
    )
    node = MagicMock()
    node.id = "n1"
    node.title = "LLM Node"
    node.node_type = NodeType.LLM
    node.status = "succeeded"
    # Dict whose "usage" lookup raises, forcing the extraction error path.
    class BadDict(collections.UserDict):
        def get(self, key, default=None):
            if key == "usage":
                raise Exception("Usage extraction failed")
            return super().get(key, default)
    node.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]})
    node.created_at = _dt()
    node.elapsed_time = 0.1
    node.metadata = {}
    node.outputs = {}
    repo = MagicMock()
    repo.get_by_workflow_run.return_value = [node]
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
    trace_instance.add_trace = MagicMock()
    trace_instance.add_generation = MagicMock()
    with caplog.at_level(logging.ERROR):
        trace_instance.workflow_trace(trace_info)
    assert "Failed to extract usage" in caplog.text
    # The generation is still produced despite the usage-extraction failure.
    trace_instance.add_generation.assert_called_once()

View File

@@ -0,0 +1,608 @@
import collections
from datetime import datetime, timedelta
from unittest.mock import MagicMock
import pytest
from core.ops.entities.config_entity import LangSmithConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunModel,
LangSmithRunType,
LangSmithRunUpdateModel,
)
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from models import EndUser
def _dt() -> datetime:
return datetime(2024, 1, 1, 0, 0, 0)
@pytest.fixture
def langsmith_config():
    """Minimal LangSmithConfig with dummy credentials for the tests."""
    return LangSmithConfig(api_key="ls-123", project="default", endpoint="https://api.smith.langchain.com")
@pytest.fixture
def trace_instance(langsmith_config, monkeypatch):
    """LangSmithDataTrace wired to a MagicMock LangSmith client (no network I/O)."""
    # Mock LangSmith client
    mock_client = MagicMock()
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client)
    instance = LangSmithDataTrace(langsmith_config)
    return instance
def test_init(langsmith_config, monkeypatch):
    """__init__ constructs the LangSmith client from the config and reads FILES_URL."""
    mock_client_class = MagicMock()
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class)
    monkeypatch.setenv("FILES_URL", "http://test.url")
    instance = LangSmithDataTrace(langsmith_config)
    mock_client_class.assert_called_once_with(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint)
    assert instance.langsmith_key == langsmith_config.api_key
    assert instance.project_name == langsmith_config.project
    assert instance.file_base_url == "http://test.url"
def test_trace_dispatch(trace_instance, monkeypatch):
    """Each TraceInfo subtype is routed to its dedicated handler method."""
    dispatch_table = [
        (WorkflowTraceInfo, "workflow_trace"),
        (MessageTraceInfo, "message_trace"),
        (ModerationTraceInfo, "moderation_trace"),
        (SuggestedQuestionTraceInfo, "suggested_question_trace"),
        (DatasetRetrievalTraceInfo, "dataset_retrieval_trace"),
        (ToolTraceInfo, "tool_trace"),
        (GenerateNameTraceInfo, "generate_name_trace"),
    ]
    handler_mocks = {handler: MagicMock() for _, handler in dispatch_table}
    for handler, handler_mock in handler_mocks.items():
        monkeypatch.setattr(trace_instance, handler, handler_mock)
    for info_cls, handler in dispatch_table:
        info = MagicMock(spec=info_cls)
        trace_instance.trace(info)
        handler_mocks[handler].assert_called_once_with(info)
def test_workflow_trace(trace_instance, monkeypatch):
    """workflow_trace emits a message run, a workflow run parented to it, and one
    child run per node whose run_type matches the node type (llm/tool/retriever)."""
    # Setup trace info
    workflow_data = MagicMock()
    workflow_data.created_at = _dt()
    workflow_data.finished_at = _dt() + timedelta(seconds=1)
    trace_info = WorkflowTraceInfo(
        tenant_id="tenant-1",
        workflow_id="wf-1",
        workflow_run_id="run-1",
        workflow_run_inputs={"input": "hi"},
        workflow_run_outputs={"output": "hello"},
        workflow_run_status="succeeded",
        workflow_run_version="1.0",
        workflow_run_elapsed_time=1.0,
        total_tokens=100,
        file_list=[],
        query="hi",
        message_id="msg-1",
        conversation_id="conv-1",
        start_time=_dt(),
        end_time=_dt() + timedelta(seconds=1),
        trace_id="trace-1",
        metadata={"app_id": "app-1"},
        workflow_app_log_id="log-1",
        error="",
        workflow_data=workflow_data,
    )
    # Mock dependencies
    mock_session = MagicMock()
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
    # Mock node executions: LLM -> llm run, TOOL -> tool run, KNOWLEDGE_RETRIEVAL -> retriever run
    node_llm = MagicMock()
    node_llm.id = "node-llm"
    node_llm.title = "LLM Node"
    node_llm.node_type = NodeType.LLM
    node_llm.status = "succeeded"
    node_llm.process_data = {
        "model_mode": "chat",
        "model_name": "gpt-4",
        "model_provider": "openai",
        "usage": {"prompt_tokens": 10, "completion_tokens": 20},
    }
    node_llm.inputs = {"prompts": "p"}
    node_llm.outputs = {"text": "t"}
    node_llm.created_at = _dt()
    node_llm.elapsed_time = 0.5
    node_llm.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 30}
    node_other = MagicMock()
    node_other.id = "node-other"
    node_other.title = "Tool Node"
    node_other.node_type = NodeType.TOOL
    node_other.status = "succeeded"
    node_other.process_data = None
    node_other.inputs = {"tool_input": "val"}
    node_other.outputs = {"tool_output": "val"}
    node_other.created_at = None  # Trigger datetime.now()
    node_other.elapsed_time = 0.2
    node_other.metadata = {}
    node_retrieval = MagicMock()
    node_retrieval.id = "node-retrieval"
    node_retrieval.title = "Retrieval Node"
    node_retrieval.node_type = NodeType.KNOWLEDGE_RETRIEVAL
    node_retrieval.status = "succeeded"
    node_retrieval.process_data = None
    node_retrieval.inputs = {"query": "val"}
    node_retrieval.outputs = {"results": "val"}
    node_retrieval.created_at = _dt()
    node_retrieval.elapsed_time = 0.2
    node_retrieval.metadata = {}
    repo = MagicMock()
    repo.get_by_workflow_run.return_value = [node_llm, node_other, node_retrieval]
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
    trace_instance.add_run = MagicMock()
    trace_instance.workflow_trace(trace_info)
    # Verify add_run calls
    # 1. message run (id="msg-1")
    # 2. workflow run (id="run-1")
    # 3. node llm run (id="node-llm")
    # 4. node other run (id="node-other")
    # 5. node retrieval run (id="node-retrieval")
    assert trace_instance.add_run.call_count == 5
    call_args = [call[0][0] for call in trace_instance.add_run.call_args_list]
    assert call_args[0].id == "msg-1"
    assert call_args[0].name == TraceTaskName.MESSAGE_TRACE
    assert call_args[1].id == "run-1"
    assert call_args[1].name == TraceTaskName.WORKFLOW_TRACE
    assert call_args[1].parent_run_id == "msg-1"
    assert call_args[2].id == "node-llm"
    assert call_args[2].run_type == LangSmithRunType.llm
    assert call_args[3].id == "node-other"
    assert call_args[3].run_type == LangSmithRunType.tool
    assert call_args[4].id == "node-retrieval"
    assert call_args[4].run_type == LangSmithRunType.retriever
def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
    """workflow_trace still emits runs when start/end times are missing
    (timestamps fall back to the workflow_data values or now())."""
    workflow_data = MagicMock()
    workflow_data.created_at = _dt()
    workflow_data.finished_at = _dt() + timedelta(seconds=1)
    trace_info = WorkflowTraceInfo(
        tenant_id="tenant-1",
        workflow_id="wf-1",
        workflow_run_id="run-1",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_status="succeeded",
        workflow_run_version="1.0",
        workflow_run_elapsed_time=1.0,
        total_tokens=10,
        file_list=[],
        query="hi",
        message_id="msg-1",
        conversation_id="conv-1",
        start_time=None,
        end_time=None,
        trace_id="trace-1",
        metadata={"app_id": "app-1"},
        workflow_app_log_id="log-1",
        error="",
        workflow_data=workflow_data,
    )
    mock_session = MagicMock()
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
    repo = MagicMock()
    repo.get_by_workflow_run.return_value = []
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
    trace_instance.add_run = MagicMock()
    trace_instance.workflow_trace(trace_info)
    assert trace_instance.add_run.called
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
    """workflow_trace raises ValueError when trace metadata lacks an app_id."""
    trace_info = MagicMock(spec=WorkflowTraceInfo)
    trace_info.trace_id = "trace-1"
    trace_info.message_id = None
    trace_info.workflow_run_id = "run-1"
    trace_info.start_time = None
    trace_info.workflow_data = MagicMock()
    trace_info.workflow_data.created_at = _dt()
    trace_info.metadata = {}  # Empty metadata
    trace_info.workflow_app_log_id = "log-1"
    trace_info.file_list = []
    trace_info.total_tokens = 0
    trace_info.workflow_run_inputs = {}
    trace_info.workflow_run_outputs = {}
    trace_info.error = ""
    mock_session = MagicMock()
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
    with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
        trace_instance.workflow_trace(trace_info)
def test_message_trace(trace_instance, monkeypatch):
    """message_trace emits a message run (with the end user's session_id in its
    metadata) plus a child 'llm' run."""
    message_data = MagicMock()
    message_data.id = "msg-1"
    message_data.from_account_id = "acc-1"
    message_data.from_end_user_id = "end-user-1"  # triggers the EndUser lookup branch
    message_data.answer = "hello answer"
    trace_info = MessageTraceInfo(
        message_id="msg-1",
        message_data=message_data,
        inputs={"input": "hi"},
        outputs={"answer": "hello"},
        message_tokens=10,
        answer_tokens=20,
        total_tokens=30,
        start_time=_dt(),
        end_time=_dt() + timedelta(seconds=1),
        trace_id="trace-1",
        metadata={"foo": "bar"},
        conversation_mode="chat",
        conversation_model="gpt-4",
        file_list=[],
        error=None,
        message_file_data=MagicMock(url="file-url"),
    )
    # Mock EndUser lookup
    mock_end_user = MagicMock(spec=EndUser)
    mock_end_user.session_id = "session-id-123"
    mock_query = MagicMock()
    mock_query.where.return_value.first.return_value = mock_end_user
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.query", lambda model: mock_query)
    trace_instance.add_run = MagicMock()
    trace_instance.message_trace(trace_info)
    # 1. message run
    # 2. llm run
    assert trace_instance.add_run.call_count == 2
    call_args = [call[0][0] for call in trace_instance.add_run.call_args_list]
    assert call_args[0].id == "msg-1"
    assert call_args[0].extra["metadata"]["end_user_id"] == "session-id-123"
    assert call_args[1].parent_run_id == "msg-1"
    assert call_args[1].name == "llm"
def test_message_trace_no_data(trace_instance):
    """message_trace emits nothing when message_data is missing."""
    info = MagicMock(spec=MessageTraceInfo)
    info.message_data = None
    info.file_list = []
    info.message_file_data = None
    info.metadata = {}
    trace_instance.add_run = MagicMock()
    trace_instance.message_trace(info)
    trace_instance.add_run.assert_not_called()
def test_moderation_trace_no_data(trace_instance):
    """moderation_trace emits nothing when message_data is missing."""
    info = MagicMock(spec=ModerationTraceInfo)
    info.message_data = None
    trace_instance.add_run = MagicMock()
    trace_instance.moderation_trace(info)
    trace_instance.add_run.assert_not_called()
def test_suggested_question_trace_no_data(trace_instance):
    """suggested_question_trace emits nothing when message_data is missing."""
    info = MagicMock(spec=SuggestedQuestionTraceInfo)
    info.message_data = None
    trace_instance.add_run = MagicMock()
    trace_instance.suggested_question_trace(info)
    trace_instance.add_run.assert_not_called()
def test_dataset_retrieval_trace_no_data(trace_instance):
    """dataset_retrieval_trace emits nothing when message_data is missing."""
    info = MagicMock(spec=DatasetRetrievalTraceInfo)
    info.message_data = None
    trace_instance.add_run = MagicMock()
    trace_instance.dataset_retrieval_trace(info)
    trace_instance.add_run.assert_not_called()
def test_moderation_trace(trace_instance):
    """moderation_trace emits one run named MODERATION_TRACE."""
    message_data = MagicMock()
    message_data.created_at = _dt()
    message_data.updated_at = _dt()
    trace_info = ModerationTraceInfo(
        message_id="msg-1",
        message_data=message_data,
        inputs={"q": "hi"},
        action="stop",
        flagged=True,
        preset_response="blocked",
        start_time=None,
        end_time=None,
        metadata={},
        trace_id="trace-1",
        query="hi",
    )
    trace_instance.add_run = MagicMock()
    trace_instance.moderation_trace(trace_info)
    trace_instance.add_run.assert_called_once()
    assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.MODERATION_TRACE
def test_suggested_question_trace(trace_instance):
    """suggested_question_trace emits one run named SUGGESTED_QUESTION_TRACE."""
    message_data = MagicMock()
    message_data.created_at = _dt()
    message_data.updated_at = _dt()
    trace_info = SuggestedQuestionTraceInfo(
        message_id="msg-1",
        message_data=message_data,
        inputs="hi",
        suggested_question=["q1"],
        total_tokens=10,
        level="info",
        start_time=None,
        end_time=None,
        metadata={},
        trace_id="trace-1",
    )
    trace_instance.add_run = MagicMock()
    trace_instance.suggested_question_trace(trace_info)
    trace_instance.add_run.assert_called_once()
    assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.SUGGESTED_QUESTION_TRACE
def test_dataset_retrieval_trace(trace_instance):
message_data = MagicMock()
message_data.created_at = _dt()
message_data.updated_at = _dt()
trace_info = DatasetRetrievalTraceInfo(
message_id="msg-1",
message_data=message_data,
inputs="query",
documents=[{"id": "doc1"}],
start_time=None,
end_time=None,
metadata={},
trace_id="trace-1",
)
trace_instance.add_run = MagicMock()
trace_instance.dataset_retrieval_trace(trace_info)
trace_instance.add_run.assert_called_once()
assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.DATASET_RETRIEVAL_TRACE
def test_tool_trace(trace_instance):
    """Tool invocations are recorded under the tool's own name."""
    info = ToolTraceInfo(
        message_id="msg-1",
        message_data=MagicMock(),
        inputs={},
        outputs={},
        tool_name="my_tool",
        tool_inputs={"a": 1},
        tool_outputs="result",
        time_cost=0.1,
        start_time=_dt(),
        end_time=_dt(),
        metadata={},
        trace_id="trace-1",
        tool_config={},
        tool_parameters={},
        file_url="http://file",
    )
    trace_instance.add_run = MagicMock()
    trace_instance.tool_trace(info)
    trace_instance.add_run.assert_called_once()
    assert trace_instance.add_run.call_args[0][0].name == "my_tool"


def test_generate_name_trace(trace_instance):
    """Conversation-name generation is recorded under GENERATE_NAME_TRACE."""
    info = GenerateNameTraceInfo(
        inputs={"q": "hi"},
        outputs={"name": "new"},
        tenant_id="tenant-1",
        conversation_id="conv-1",
        start_time=None,
        end_time=None,
        metadata={},
        trace_id="trace-1",
    )
    trace_instance.add_run = MagicMock()
    trace_instance.generate_name_trace(info)
    trace_instance.add_run.assert_called_once()
    assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.GENERATE_NAME_TRACE
def test_add_run_success(trace_instance):
    """add_run forwards the run with the project id as session_id."""
    run = LangSmithRunModel(
        id="run-1", name="test", inputs={}, outputs={}, run_type=LangSmithRunType.tool, start_time=_dt()
    )
    trace_instance.project_id = "proj-1"
    trace_instance.add_run(run)
    create_run = trace_instance.langsmith_client.create_run
    create_run.assert_called_once()
    assert create_run.call_args.kwargs["session_id"] == "proj-1"


def test_add_run_error(trace_instance):
    """Client failures during run creation are wrapped in ValueError."""
    run = LangSmithRunModel(id="run-1", name="test", run_type=LangSmithRunType.tool, start_time=_dt())
    trace_instance.langsmith_client.create_run.side_effect = Exception("failed")
    with pytest.raises(ValueError, match="LangSmith Failed to create run: failed"):
        trace_instance.add_run(run)


def test_update_run_success(trace_instance):
    """update_run delegates straight to the LangSmith client."""
    update = LangSmithRunUpdateModel(run_id="run-1", outputs={"out": "val"})
    trace_instance.update_run(update)
    trace_instance.langsmith_client.update_run.assert_called_once()


def test_update_run_error(trace_instance):
    """Client failures during run update are wrapped in ValueError."""
    update = LangSmithRunUpdateModel(run_id="run-1")
    trace_instance.langsmith_client.update_run.side_effect = Exception("failed")
    with pytest.raises(ValueError, match="LangSmith Failed to update run: failed"):
        trace_instance.update_run(update)
def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, caplog):
    """A crash while reading node usage is logged, not propagated.

    The LLM node's process_data raises as soon as the tracer asks for the
    "usage" key; workflow_trace must log the failure and keep going.
    """
    # Hoisted from the middle of the test body: keep the dependency visible
    # up front instead of between mock wiring and its single use.
    import logging

    workflow_data = MagicMock()
    workflow_data.created_at = _dt()
    workflow_data.finished_at = _dt() + timedelta(seconds=1)
    trace_info = WorkflowTraceInfo(
        tenant_id="tenant-1",
        workflow_id="wf-1",
        workflow_run_id="run-1",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_status="succeeded",
        workflow_run_version="1.0",
        workflow_run_elapsed_time=1.0,
        total_tokens=100,
        file_list=[],
        query="hi",
        message_id="msg-1",
        conversation_id="conv-1",
        start_time=_dt(),
        end_time=_dt(),
        trace_id="trace-1",
        metadata={"app_id": "app-1"},
        workflow_app_log_id="log-1",
        error="",
        workflow_data=workflow_data,
    )

    class BadDict(collections.UserDict):
        # Blow up only when the "usage" key is requested; every other key
        # behaves normally so the rest of the node can be processed.
        def get(self, key, default=None):
            if key == "usage":
                raise Exception("Usage extraction failed")
            return super().get(key, default)

    node_llm = MagicMock()
    node_llm.id = "node-llm"
    node_llm.title = "LLM Node"
    node_llm.node_type = NodeType.LLM
    node_llm.status = "succeeded"
    node_llm.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]})
    node_llm.inputs = {}
    node_llm.outputs = {}
    node_llm.created_at = _dt()
    node_llm.elapsed_time = 0.5
    node_llm.metadata = {}
    repo = MagicMock()
    repo.get_by_workflow_run.return_value = [node_llm]
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
    trace_instance.add_run = MagicMock()
    with caplog.at_level(logging.ERROR):
        trace_instance.workflow_trace(trace_info)
    assert "Failed to extract usage" in caplog.text
def test_api_check_success(trace_instance):
    """api_check probes by creating and deleting a throwaway project."""
    assert trace_instance.api_check() is True
    assert trace_instance.langsmith_client.create_project.called
    assert trace_instance.langsmith_client.delete_project.called


def test_api_check_error(trace_instance):
    """Probe failures are wrapped in ValueError."""
    trace_instance.langsmith_client.create_project.side_effect = Exception("error")
    with pytest.raises(ValueError, match="LangSmith API check failed: error"):
        trace_instance.api_check()


def test_get_project_url_success(trace_instance):
    """The run URL is trimmed back to its project prefix."""
    run_url = "https://smith.langchain.com/o/org/p/proj/r/run"
    trace_instance.langsmith_client.get_run_url.return_value = run_url
    assert trace_instance.get_project_url() == "https://smith.langchain.com/o/org/p/proj"


def test_get_project_url_error(trace_instance):
    """URL lookup failures are wrapped in ValueError."""
    trace_instance.langsmith_client.get_run_url.side_effect = Exception("error")
    with pytest.raises(ValueError, match="LangSmith get run url failed: error"):
        trace_instance.get_project_url()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,678 @@
import collections
import logging
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from models import EndUser
from models.enums import MessageStatus
def _dt() -> datetime:
return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
@pytest.fixture
def opik_config():
    """Minimal OpikConfig pointing at a fake cloud workspace."""
    return OpikConfig(
        project="test-project", workspace="test-workspace", url="https://cloud.opik.com/api/", api_key="api-key-123"
    )


@pytest.fixture
def trace_instance(opik_config, monkeypatch):
    """OpikDataTrace whose Opik client is replaced by a MagicMock (no network)."""
    mock_client = MagicMock()
    # Intercept the Opik constructor so OpikDataTrace.__init__ gets the mock.
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client)
    instance = OpikDataTrace(opik_config)
    return instance
def test_wrap_dict():
    """Dict values pass through untouched; scalars get wrapped under the key."""
    assert wrap_dict("input", {"a": 1}) == {"a": 1}
    assert wrap_dict("input", "hello") == {"input": "hello"}


def test_wrap_metadata():
    """Extra kwargs are merged in and the dify origin marker is stamped."""
    merged = wrap_metadata({"a": 1}, b=2)
    assert merged == {"a": 1, "b": 2, "created_from": "dify"}


def test_prepare_opik_uuid():
    """An id is produced from explicit inputs as well as from nothing at all."""
    when = datetime(2024, 1, 1)
    existing = "b3e8e918-472e-4b69-8051-12502c34fc07"
    # The exact uuid7 value is not asserted; one only has to be produced.
    assert prepare_opik_uuid(when, existing) is not None
    # Both the timestamp and the uuid string may be absent.
    assert prepare_opik_uuid(None, None) is not None
def test_init(opik_config, monkeypatch):
    """Construction wires the Opik client from config and reads FILES_URL."""
    monkeypatch.setenv("FILES_URL", "http://test.url")
    opik_ctor = MagicMock()
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", opik_ctor)
    instance = OpikDataTrace(opik_config)
    opik_ctor.assert_called_once_with(
        project_name=opik_config.project,
        workspace=opik_config.workspace,
        host=opik_config.url,
        api_key=opik_config.api_key,
    )
    assert instance.file_base_url == "http://test.url"
    assert instance.project == opik_config.project
def test_trace_dispatch(trace_instance, monkeypatch):
    """trace() must route each *TraceInfo type to its dedicated handler."""
    dispatch = {
        WorkflowTraceInfo: "workflow_trace",
        MessageTraceInfo: "message_trace",
        ModerationTraceInfo: "moderation_trace",
        SuggestedQuestionTraceInfo: "suggested_question_trace",
        DatasetRetrievalTraceInfo: "dataset_retrieval_trace",
        ToolTraceInfo: "tool_trace",
        GenerateNameTraceInfo: "generate_name_trace",
    }
    handler_mocks = {}
    for handler_name in dispatch.values():
        handler_mocks[handler_name] = MagicMock()
        monkeypatch.setattr(trace_instance, handler_name, handler_mocks[handler_name])
    for info_cls, handler_name in dispatch.items():
        info = MagicMock(spec=info_cls)
        trace_instance.trace(info)
        handler_mocks[handler_name].assert_called_once_with(info)
def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
    """Workflow runs tied to a message become a MESSAGE_TRACE tagged both ways."""
    # Define constants for better readability
    WORKFLOW_ID = "fb05c7cd-6cec-4add-8a84-df03a408b4ce"
    WORKFLOW_RUN_ID = "33c67568-7a8a-450e-8916-a5f135baeaef"
    MESSAGE_ID = "04ec3956-85f3-488a-8539-1017251dc8c6"
    CONVERSATION_ID = "d3d01066-23ae-4830-9ce4-eb5640b42a7e"
    TRACE_ID = "bf26d929-6f15-4c2f-9abc-761c217056f3"
    WORKFLOW_APP_LOG_ID = "ca0e018e-edd4-43fb-a05a-ea001ca8ef4b"
    LLM_NODE_ID = "80d7dfa8-08f4-4ab7-aa37-0ca7d27207e3"
    CODE_NODE_ID = "b9cd9a7b-c534-4aa9-b5da-efd454140900"
    trace_info = WorkflowTraceInfo(
        workflow_id=WORKFLOW_ID,
        tenant_id="tenant-1",
        workflow_run_id=WORKFLOW_RUN_ID,
        workflow_run_elapsed_time=1.0,
        workflow_run_status="succeeded",
        workflow_run_inputs={"input": "hi"},
        workflow_run_outputs={"output": "hello"},
        workflow_run_version="1.0",
        message_id=MESSAGE_ID,
        conversation_id=CONVERSATION_ID,
        total_tokens=100,
        file_list=[],
        query="hi",
        start_time=_dt(),
        end_time=_dt() + timedelta(seconds=1),
        trace_id=TRACE_ID,
        metadata={"app_id": "app-1", "user_id": "user-1"},
        workflow_app_log_id=WORKFLOW_APP_LOG_ID,
        error="",
    )
    mock_session = MagicMock()
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session)
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
    # One successful LLM node (with usage) and one failed CODE node.
    node_llm = MagicMock()
    node_llm.id = LLM_NODE_ID
    node_llm.title = "LLM Node"
    node_llm.node_type = NodeType.LLM
    node_llm.status = "succeeded"
    node_llm.process_data = {
        "model_mode": "chat",
        "model_name": "gpt-4",
        "model_provider": "openai",
        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    }
    node_llm.inputs = {"prompts": "p"}
    node_llm.outputs = {"text": "t"}
    node_llm.created_at = _dt()
    node_llm.elapsed_time = 0.5
    node_llm.metadata = {"foo": "bar"}
    node_other = MagicMock()
    node_other.id = CODE_NODE_ID
    node_other.title = "Other Node"
    node_other.node_type = NodeType.CODE
    node_other.status = "failed"
    node_other.process_data = None
    node_other.inputs = {"code": "print"}
    node_other.outputs = {"result": "ok"}
    node_other.created_at = None
    node_other.elapsed_time = 0.2
    node_other.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS.value: 10}
    repo = MagicMock()
    repo.get_by_workflow_run.return_value = [node_llm, node_other]
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
    trace_instance.add_trace = MagicMock()
    trace_instance.add_span = MagicMock()
    trace_instance.workflow_trace(trace_info)
    trace_instance.add_trace.assert_called_once()
    # Accept either keyword or positional invocation of add_trace.
    trace_data = trace_instance.add_trace.call_args[1].get("opik_trace_data", trace_instance.add_trace.call_args[0][0])
    assert trace_data["name"] == TraceTaskName.MESSAGE_TRACE
    assert "message" in trace_data["tags"]
    assert "workflow" in trace_data["tags"]
    assert trace_instance.add_span.call_count >= 1
def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
    """Without a message id the run is still traced as a plain workflow."""
    # Define constants for better readability
    WORKFLOW_ID = "f0708b36-b1d7-42b3-a876-1d01b7d8f1a3"
    WORKFLOW_RUN_ID = "d42ec285-c2fd-4248-8866-5c9386b101ac"
    CONVERSATION_ID = "88a17f2e-9436-4472-bab9-4b1601d5af3c"
    WORKFLOW_APP_LOG_ID = "41780d0d-ffba-4220-bc0c-401e4c89cdfb"
    trace_info = WorkflowTraceInfo(
        workflow_id=WORKFLOW_ID,
        tenant_id="tenant-1",
        workflow_run_id=WORKFLOW_RUN_ID,
        workflow_run_elapsed_time=1.0,
        workflow_run_status="succeeded",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_version="1.0",
        total_tokens=0,
        file_list=[],
        query="",
        message_id=None,
        conversation_id=CONVERSATION_ID,
        start_time=_dt(),
        end_time=_dt(),
        trace_id=None,
        metadata={"app_id": "app-1"},
        workflow_app_log_id=WORKFLOW_APP_LOG_ID,
        error="",
    )
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
    # No node executions for this run.
    repo = MagicMock()
    repo.get_by_workflow_run.return_value = []
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
    trace_instance.add_trace = MagicMock()
    trace_instance.workflow_trace(trace_info)
    trace_instance.add_trace.assert_called_once()
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
    """workflow_trace refuses to run when the metadata lacks an app_id."""
    trace_info = WorkflowTraceInfo(
        workflow_id="5745f1b8-f8e6-4859-8110-996acb6c8d6a",
        tenant_id="tenant-1",
        workflow_run_id="46f53304-1659-464b-bee5-116585f0bec8",
        workflow_run_elapsed_time=1.0,
        workflow_run_status="succeeded",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_version="1.0",
        total_tokens=0,
        file_list=[],
        query="",
        message_id=None,
        conversation_id="83f86b89-caef-4de8-a0f9-f164eddae1ea",
        start_time=_dt(),
        end_time=_dt(),
        metadata={},  # deliberately missing app_id
        workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e",
        error="",
    )
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
    with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
        trace_instance.workflow_trace(trace_info)
def test_message_trace_basic(trace_instance, monkeypatch):
    """A complete message trace produces one trace plus one span."""
    message_row_id = "e3a26712-8cac-4a25-94a4-a3bff21ee3ab"
    conversation_id = "9d3f3751-7521-4c19-9307-20e3cf6789a3"
    message_trace_id = "710ace2f-bca8-41be-858c-54da42742a77"
    opik_trace_id = "f7dfd978-0d10-4549-8abf-00f2cbc49d2c"
    message_data = MagicMock()
    message_data.id = message_row_id
    message_data.from_account_id = "acc-1"
    message_data.from_end_user_id = None
    message_data.provider_response_latency = 0.5
    message_data.conversation_id = conversation_id
    message_data.total_price = 0.01
    message_data.model_id = "gpt-4"
    message_data.answer = "hello"
    message_data.status = MessageStatus.NORMAL
    message_data.error = None
    info = MessageTraceInfo(
        message_id=message_trace_id,
        message_data=message_data,
        inputs={"query": "hi"},
        outputs={"answer": "hello"},
        message_tokens=10,
        answer_tokens=20,
        total_tokens=30,
        start_time=_dt(),
        end_time=_dt() + timedelta(seconds=1),
        trace_id=opik_trace_id,
        metadata={"foo": "bar"},
        conversation_mode="chat",
        conversation_model="gpt-4",
        file_list=[],
        error=None,
        message_file_data=MagicMock(url="test.png"),
    )
    trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_1"))
    trace_instance.add_span = MagicMock()
    trace_instance.message_trace(info)
    trace_instance.add_trace.assert_called_once()
    trace_instance.add_span.assert_called_once()
def test_message_trace_with_end_user(trace_instance, monkeypatch):
    """End-user messages resolve the EndUser session id into trace metadata."""
    message_data = MagicMock()
    message_data.id = "85411059-79fb-4deb-a76c-c2e215f1b97e"
    message_data.from_account_id = "acc-1"
    message_data.from_end_user_id = "end-user-1"
    message_data.conversation_id = "7d9f96d8-3be2-4e93-9c0e-922ff98dccc6"
    message_data.status = MessageStatus.NORMAL
    message_data.model_id = "gpt-4"
    message_data.error = ""
    message_data.answer = "hello"
    message_data.total_price = 0.0
    message_data.provider_response_latency = 0.1
    trace_info = MessageTraceInfo(
        message_id="6bff35c7-33b7-4acb-ba21-44569a0327d0",
        message_data=message_data,
        inputs={},
        outputs={},
        message_tokens=0,
        answer_tokens=0,
        total_tokens=0,
        start_time=_dt(),
        end_time=_dt(),
        metadata={},
        conversation_mode="chat",
        conversation_model="gpt-4",
        file_list=["url1"],
        error=None,
    )
    # The trace code looks the end user up via db.session.query(...).where(...).first().
    mock_end_user = MagicMock(spec=EndUser)
    mock_end_user.session_id = "session-id-123"
    mock_query = MagicMock()
    mock_query.where.return_value.first.return_value = mock_end_user
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.query", lambda model: mock_query)
    trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2"))
    trace_instance.add_span = MagicMock()
    trace_instance.message_trace(trace_info)
    trace_data = trace_instance.add_trace.call_args[0][0]
    assert trace_data["metadata"]["user_id"] == "acc-1"
    assert trace_data["metadata"]["end_user_id"] == "session-id-123"


def test_message_trace_none_data(trace_instance):
    """Nothing is traced when message_data is absent."""
    trace_info = SimpleNamespace(message_data=None, file_list=[], message_file_data=None, metadata={})
    trace_instance.add_trace = MagicMock()
    trace_instance.message_trace(trace_info)
    trace_instance.add_trace.assert_not_called()
def test_moderation_trace(trace_instance):
    """Moderation results are recorded as a span named MODERATION_TRACE."""
    message_data = MagicMock()
    message_data.created_at = _dt()
    message_data.updated_at = _dt()
    info = ModerationTraceInfo(
        message_id="489d0dfd-065c-4106-8f9c-daded296c92d",
        message_data=message_data,
        inputs={"q": "hi"},
        action="stop",
        flagged=True,
        preset_response="blocked",
        start_time=None,
        end_time=None,
        metadata={"foo": "bar"},
        trace_id="6f16cf18-9f4b-4955-8b6b-43cfa10978fc",
        query="hi",
    )
    trace_instance.add_span = MagicMock()
    trace_instance.moderation_trace(info)
    trace_instance.add_span.assert_called_once()
    recorded = trace_instance.add_span.call_args[0][0]
    assert recorded["name"] == TraceTaskName.MODERATION_TRACE
    assert recorded["output"]["flagged"] is True


def test_moderation_trace_none(trace_instance):
    """No span is emitted when the moderation info lacks message data."""
    info = ModerationTraceInfo(
        message_id="cd732e4e-37f1-4c7e-8c64-820308bedcbf",
        message_data=None,
        inputs={},
        action="s",
        flagged=False,
        preset_response="",
        query="",
        metadata={},
    )
    trace_instance.add_span = MagicMock()
    trace_instance.moderation_trace(info)
    trace_instance.add_span.assert_not_called()
def test_suggested_question_trace(trace_instance):
    """Suggested questions are recorded as a SUGGESTED_QUESTION_TRACE span."""
    message_data = MagicMock()
    message_data.created_at = _dt()
    message_data.updated_at = _dt()
    info = SuggestedQuestionTraceInfo(
        message_id="7de55bda-a91d-477e-98ab-85c53c438469",
        message_data=message_data,
        inputs="hi",
        suggested_question=["q1"],
        total_tokens=10,
        level="info",
        start_time=_dt(),
        end_time=_dt(),
        metadata={},
        trace_id="a6687292-68c7-42ba-ae51-285579944d7b",
    )
    trace_instance.add_span = MagicMock()
    trace_instance.suggested_question_trace(info)
    trace_instance.add_span.assert_called_once()
    assert trace_instance.add_span.call_args[0][0]["name"] == TraceTaskName.SUGGESTED_QUESTION_TRACE


def test_suggested_question_trace_none(trace_instance):
    """No span is emitted when the info carries no message data."""
    info = SuggestedQuestionTraceInfo(
        message_id="23696fc5-7e7f-46ec-bce8-1adc3c7f297d",
        message_data=None,
        inputs={},
        suggested_question=[],
        total_tokens=0,
        level="i",
        metadata={},
    )
    trace_instance.add_span = MagicMock()
    trace_instance.suggested_question_trace(info)
    trace_instance.add_span.assert_not_called()
def test_dataset_retrieval_trace(trace_instance):
    """Dataset retrievals are recorded as a DATASET_RETRIEVAL_TRACE span."""
    message_data = MagicMock()
    message_data.created_at = _dt()
    message_data.updated_at = _dt()
    info = DatasetRetrievalTraceInfo(
        message_id="3e1a819f-c391-4950-adfd-96f82e5419a1",
        message_data=message_data,
        inputs="query",
        documents=[{"id": "doc1"}],
        start_time=None,
        end_time=None,
        metadata={},
        trace_id="41361000-e9be-4d11-b5e4-ab27ce0817d6",
    )
    trace_instance.add_span = MagicMock()
    trace_instance.dataset_retrieval_trace(info)
    trace_instance.add_span.assert_called_once()
    assert trace_instance.add_span.call_args[0][0]["name"] == TraceTaskName.DATASET_RETRIEVAL_TRACE


def test_dataset_retrieval_trace_none(trace_instance):
    """No span is emitted when the retrieval info lacks message data."""
    info = DatasetRetrievalTraceInfo(
        message_id="35d6d44c-bccb-4e6e-8bd8-859257723ea8", message_data=None, inputs={}, documents=[], metadata={}
    )
    trace_instance.add_span = MagicMock()
    trace_instance.dataset_retrieval_trace(info)
    trace_instance.add_span.assert_not_called()
def test_tool_trace(trace_instance):
    """Tool invocations become spans named after the tool itself."""
    info = ToolTraceInfo(
        message_id="99db92c4-2254-496a-b5cc-18153315ce35",
        message_data=MagicMock(),
        inputs={},
        outputs={},
        tool_name="my_tool",
        tool_inputs={"a": 1},
        tool_outputs="result_string",
        time_cost=0.1,
        start_time=_dt(),
        end_time=_dt(),
        metadata={},
        trace_id="a15a5fcb-7ffd-4458-8330-208f4cb1f796",
        tool_config={},
        tool_parameters={},
        error="some error",
    )
    trace_instance.add_span = MagicMock()
    trace_instance.tool_trace(info)
    trace_instance.add_span.assert_called_once()
    assert trace_instance.add_span.call_args[0][0]["name"] == "my_tool"


def test_generate_name_trace(trace_instance):
    """Name generation records a trace plus a span linked to its id."""
    info = GenerateNameTraceInfo(
        inputs={"q": "hi"},
        outputs={"name": "new"},
        tenant_id="tenant-1",
        conversation_id="271fe28f-6b86-416b-8d6b-bbbbfa9db791",
        start_time=_dt(),
        end_time=_dt(),
        metadata={"921f010e-6878-4831-ae6b-271bf68c56fb": 1},
    )
    trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_3"))
    trace_instance.add_span = MagicMock()
    trace_instance.generate_name_trace(info)
    trace_instance.add_trace.assert_called_once()
    trace_instance.add_span.assert_called_once()
    assert trace_instance.add_trace.call_args[0][0]["name"] == TraceTaskName.GENERATE_NAME_TRACE
    assert trace_instance.add_span.call_args[0][0]["trace_id"] == "trace_id_3"
def test_add_trace_success(trace_instance):
    """add_trace proxies to the Opik client and returns the created trace."""
    trace_instance.opik_client.trace.return_value = MagicMock(id="t1")
    created = trace_instance.add_trace({"id": "t1", "name": "trace"})
    trace_instance.opik_client.trace.assert_called_once()
    assert created.id == "t1"


def test_add_trace_error(trace_instance):
    """Client failures during trace creation are wrapped in ValueError."""
    trace_instance.opik_client.trace.side_effect = Exception("error")
    with pytest.raises(ValueError, match="Opik Failed to create trace: error"):
        trace_instance.add_trace({"id": "t1", "name": "trace"})


def test_add_span_success(trace_instance):
    """add_span proxies to the Opik client."""
    trace_instance.add_span({"id": "s1", "name": "span", "trace_id": "t1"})
    trace_instance.opik_client.span.assert_called_once()


def test_add_span_error(trace_instance):
    """Client failures during span creation are wrapped in ValueError."""
    trace_instance.opik_client.span.side_effect = Exception("error")
    with pytest.raises(ValueError, match="Opik Failed to create span: error"):
        trace_instance.add_span({"id": "s1", "name": "span", "trace_id": "t1"})


def test_api_check_success(trace_instance):
    """A passing auth check reports True."""
    trace_instance.opik_client.auth_check.return_value = True
    assert trace_instance.api_check() is True


def test_api_check_error(trace_instance):
    """Auth-check failures are wrapped in ValueError."""
    trace_instance.opik_client.auth_check.side_effect = Exception("fail")
    with pytest.raises(ValueError, match="Opik API check failed: fail"):
        trace_instance.api_check()


def test_get_project_url_success(trace_instance):
    """The project URL comes straight from the client, keyed by project name."""
    trace_instance.opik_client.get_project_url.return_value = "http://project.url"
    assert trace_instance.get_project_url() == "http://project.url"
    trace_instance.opik_client.get_project_url.assert_called_once_with(project_name=trace_instance.project)


def test_get_project_url_error(trace_instance):
    """URL lookup failures are wrapped in ValueError."""
    trace_instance.opik_client.get_project_url.side_effect = Exception("fail")
    with pytest.raises(ValueError, match="Opik get run url failed: fail"):
        trace_instance.get_project_url()
def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch, caplog):
    """A crash while reading node usage is logged and the span is still emitted."""
    trace_info = WorkflowTraceInfo(
        workflow_id="86a52565-4a6b-4a1b-9bfd-98e4595e70de",
        tenant_id="66e8e918-472e-4b69-8051-12502c34fc07",
        workflow_run_id="8403965c-3344-4d22-a8fe-d8d55cee64d9",
        workflow_run_elapsed_time=1.0,
        workflow_run_status="s",
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_version="1",
        total_tokens=0,
        file_list=[],
        query="",
        message_id=None,
        conversation_id="7a02cb9d-6949-4c59-a89d-f25bbc881e0e",
        start_time=_dt(),
        end_time=_dt(),
        metadata={"app_id": "77e8e918-472e-4b69-8051-12502c34fc07"},
        workflow_app_log_id="82268424-e193-476c-a6db-f473388ee5fe",
        error="",
    )
    node = MagicMock()
    node.id = "88e8e918-472e-4b69-8051-12502c34fc07"
    node.title = "LLM Node"
    node.node_type = NodeType.LLM
    node.status = "succeeded"

    class BadDict(collections.UserDict):
        # Blows up only when the "usage" key is requested.
        def get(self, key, default=None):
            if key == "usage":
                raise Exception("Usage extraction failed")
            return super().get(key, default)

    node.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]})
    node.created_at = _dt()
    node.elapsed_time = 0.1
    node.metadata = {}
    node.outputs = {}
    repo = MagicMock()
    repo.get_by_workflow_run.return_value = [node]
    mock_factory = MagicMock()
    mock_factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
    trace_instance.add_trace = MagicMock()
    trace_instance.add_span = MagicMock()
    with caplog.at_level(logging.ERROR):
        trace_instance.workflow_trace(trace_info)
    assert "Failed to extract usage" in caplog.text
    assert trace_instance.add_span.call_count >= 1
    # Verify that at least one of the spans is for the LLM Node
    span_names = [call.args[0]["name"] for call in trace_instance.add_span.call_args_list]
    assert "LLM Node" in span_names

View File

@@ -0,0 +1,583 @@
"""Tests for the TencentTraceClient helpers that drive tracing and metrics."""
from __future__ import annotations
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from core.ops.tencent_trace import client as client_module
from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version
from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
# Recorders for every stub reader/provider constructed during a test run;
# cleared by the autouse fixture before each test.
metric_reader_instances: "list[DummyMetricReader]" = []
meter_provider_instances: "list[DummyMeterProvider]" = []


class DummyHistogram:
    """Placeholder histogram type used by the stubbed metric stack."""


class AggregationTemporality:
    """Stand-in for the OTel enum; only DELTA is exercised here."""

    DELTA = "delta"


class DummyMeter:
    """Meter stub that remembers every histogram it hands out."""

    def __init__(self) -> None:
        self.created: "list[tuple[dict[str, object], MagicMock]]" = []

    def create_histogram(self, **kwargs: object) -> MagicMock:
        histogram = MagicMock(name=f"hist-{kwargs.get('name')}")
        self.created.append((kwargs, histogram))
        return histogram


class DummyMeterProvider:
    """Provider stub recording its construction args and exposing one meter."""

    def __init__(self, resource: object, metric_readers: "list[object]") -> None:
        meter_provider_instances.append(self)
        self.resource = resource
        self.metric_readers = metric_readers
        self.meter = DummyMeter()
        self.shutdown = MagicMock(name="meter_provider_shutdown")

    def get_meter(self, name: str, version: str) -> DummyMeter:
        return self.meter


class DummyMetricReader:
    """Reader stub recording its exporter and export interval."""

    def __init__(self, exporter: object, export_interval_millis: int) -> None:
        metric_reader_instances.append(self)
        self.exporter = exporter
        self.export_interval_millis = export_interval_millis
        self.shutdown = MagicMock(name="metric_reader_shutdown")
class DummyGrpcMetricExporter:
    """Stub for the gRPC OTLP metric exporter; records its kwargs."""

    def __init__(self, **kwargs: object) -> None:
        self.kwargs = kwargs


class DummyHttpMetricExporter:
    """Stub for the HTTP/protobuf OTLP metric exporter; records its kwargs."""

    def __init__(self, **kwargs: object) -> None:
        self.kwargs = kwargs


class DummyJsonMetricExporter:
    """Stub for the HTTP/JSON OTLP metric exporter; records its kwargs."""

    def __init__(self, **kwargs: object) -> None:
        self.kwargs = kwargs


class DummyJsonMetricExporterNoTemporality:
    """Exporter that rejects preferred_temporality to exercise fallback."""

    def __init__(self, **kwargs: object) -> None:
        # Mimics older exporters whose constructor lacks this keyword.
        if "preferred_temporality" in kwargs:
            raise RuntimeError("unsupported preferred_temporality")
        self.kwargs = kwargs
def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None:
    """Drop fake metric modules into sys.modules so the client imports resolve."""
    stubs: "list[tuple[str, dict[str, object]]]" = [
        ("opentelemetry.sdk.metrics", {"Histogram": DummyHistogram, "MeterProvider": DummyMeterProvider}),
        (
            "opentelemetry.sdk.metrics.export",
            {"AggregationTemporality": AggregationTemporality, "PeriodicExportingMetricReader": DummyMetricReader},
        ),
        ("opentelemetry.exporter.otlp.proto.grpc.metric_exporter", {"OTLPMetricExporter": DummyGrpcMetricExporter}),
        ("opentelemetry.exporter.otlp.proto.http.metric_exporter", {"OTLPMetricExporter": DummyHttpMetricExporter}),
        ("opentelemetry.exporter.otlp.http.json.metric_exporter", {"OTLPMetricExporter": DummyJsonMetricExporter}),
        ("opentelemetry.exporter.otlp.json.metric_exporter", {"OTLPMetricExporter": DummyJsonMetricExporter}),
    ]
    for module_name, attrs in stubs:
        module = types.ModuleType(module_name)
        for attr_name, value in attrs.items():
            setattr(module, attr_name, value)
        monkeypatch.setitem(sys.modules, module_name, module)


@pytest.fixture(autouse=True)
def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None:
    """Reset the recorder lists and install the metric stubs before every test."""
    metric_reader_instances.clear()
    meter_provider_instances.clear()
    _add_stub_modules(monkeypatch)
@pytest.fixture(autouse=True)
def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
    """Replace the OTel trace stack inside the client module with mocks.

    Returns the mock handles so tests can assert against the exporter,
    processor, tracer, span, provider, logger and the trace_api stub.
    """
    span_exporter = MagicMock(name="span_exporter")
    monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter))
    span_processor = MagicMock(name="span_processor")
    monkeypatch.setattr(client_module, "BatchSpanProcessor", MagicMock(return_value=span_processor))
    tracer = MagicMock(name="tracer")
    span = MagicMock(name="span")
    tracer.start_span.return_value = span
    tracer_provider = MagicMock(name="tracer_provider")
    tracer_provider.get_tracer.return_value = tracer
    tracer_provider.shutdown = MagicMock(name="tracer_provider_shutdown")
    monkeypatch.setattr(client_module, "TracerProvider", MagicMock(return_value=tracer_provider))
    resource = MagicMock(name="resource")
    monkeypatch.setattr(client_module, "Resource", MagicMock(return_value=resource))
    logger_mock = MagicMock(name="tencent_logger")
    monkeypatch.setattr(client_module, "logger", logger_mock)
    # Minimal trace_api surface the client touches when building contexts.
    trace_api_stub = SimpleNamespace(
        set_span_in_context=MagicMock(name="set_span_in_context", return_value="trace-context"),
        NonRecordingSpan=MagicMock(name="non_recording_span", side_effect=lambda ctx: f"non-{ctx}"),
    )
    monkeypatch.setattr(client_module, "trace_api", trace_api_stub)
    # Deterministic deployment metadata for resource attributes.
    fake_config = SimpleNamespace(
        project=SimpleNamespace(version="test"),
        COMMIT_SHA="sha",
        DEPLOY_ENV="dev",
        EDITION="cloud",
    )
    monkeypatch.setattr(client_module, "dify_config", fake_config)
    monkeypatch.setattr(client_module.socket, "gethostname", lambda: "fake-host")
    # Clear any protocol override leaking in from the environment.
    monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "")
    return {
        "span_exporter": span_exporter,
        "span_processor": span_processor,
        "tracer": tracer,
        "span": span,
        "tracer_provider": tracer_provider,
        "logger": logger_mock,
        "trace_api": trace_api_stub,
    }
def _build_client() -> TencentTraceClient:
    """Construct a client against the fixture endpoint used throughout this module."""
    client = TencentTraceClient(
        service_name="service",
        endpoint="https://trace.example.com:4317",
        token="token",
    )
    return client
def test_get_opentelemetry_sdk_version_reads_install(monkeypatch: pytest.MonkeyPatch) -> None:
    """The helper reports whatever the patched version() lookup returns."""
    monkeypatch.setattr(client_module, "version", lambda pkg: "2.0.0")
    assert _get_opentelemetry_sdk_version() == "2.0.0"
def test_get_opentelemetry_sdk_version_falls_back(monkeypatch: pytest.MonkeyPatch) -> None:
    """When the version lookup raises, the helper falls back to a pinned default."""
    monkeypatch.setattr(client_module, "version", MagicMock(side_effect=RuntimeError("boom")))
    assert _get_opentelemetry_sdk_version() == "1.27.0"
@pytest.mark.parametrize(
    ("endpoint", "expected"),
    [
        # Explicit host and port are preserved verbatim.
        (
            "https://example.com:9090",
            ("example.com:9090", False, "example.com", 9090),
        ),
        # Missing port falls back to the default 4317; expected insecure=True here.
        (
            "http://localhost",
            ("localhost:4317", True, "localhost", 4317),
        ),
        # A non-numeric port is replaced by the default 4317.
        (
            "example.com:bad",
            ("example.com:4317", False, "example.com", 4317),
        ),
    ],
)
def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[str, bool, str, int]) -> None:
    """Endpoint strings resolve to a (target, insecure, host, port) tuple."""
    assert TencentTraceClient._resolve_grpc_target(endpoint) == expected
def test_resolve_grpc_target_handles_errors() -> None:
    """Non-string input is tolerated and yields the local default target."""
    assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317)
@pytest.mark.parametrize(
    ("method", "attr_name", "args"),
    [
        ("record_llm_duration", "hist_llm_duration", (0.3, {"foo": object()})),
        ("record_token_usage", "hist_token_usage", (5, "input", "chat", "gpt", "gpt", "addr", "dify")),
        ("record_time_to_first_token", "hist_time_to_first_token", (0.4, "dify", "gpt")),
        ("record_time_to_generate", "hist_time_to_generate", (0.6, "dify", "gpt")),
        ("record_trace_duration", "hist_trace_duration", (1.0, {"meta": object()})),
    ],
)
def test_record_methods_call_histograms(method: str, attr_name: str, args: tuple[object, ...]) -> None:
    """Each record_* method forwards exactly one record() call to its histogram."""
    client = _build_client()
    hist_mock = MagicMock(name=attr_name)
    setattr(client, attr_name, hist_mock)
    getattr(client, method)(*args)
    hist_mock.record.assert_called_once()
def test_record_methods_skip_when_histogram_missing() -> None:
    """Every record_* method must be a silent no-op when its histogram is absent."""
    client = _build_client()
    invocations = [
        ("hist_llm_duration", lambda: client.record_llm_duration(0.1)),
        ("hist_token_usage", lambda: client.record_token_usage(1, "go", "chat", "model", "model", "addr", "provider")),
        ("hist_time_to_first_token", lambda: client.record_time_to_first_token(0.2, "prov", "model")),
        ("hist_time_to_generate", lambda: client.record_time_to_generate(0.3, "prov", "model")),
        ("hist_trace_duration", lambda: client.record_trace_duration(0.5)),
    ]
    for histogram_attr, invoke in invocations:
        setattr(client, histogram_attr, None)
        invoke()  # must not raise
def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None:
    """A histogram that raises is swallowed; the failure is logged at debug level."""
    client = _build_client()
    client.hist_llm_duration = MagicMock(name="hist_llm_duration")
    client.hist_llm_duration.record.side_effect = RuntimeError("boom")
    client.record_llm_duration(0.2)  # must not propagate the RuntimeError
    logger = patch_core_components["logger"]
    logger.debug.assert_called()
def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None:
    """Attributes, events, status and end time are applied, and the span context is cached."""
    client = _build_client()
    span = patch_core_components["span"]
    span.get_span_context.return_value = "ctx"
    data = SpanData(
        trace_id=1,
        parent_span_id=None,
        span_id=2,
        name="span",
        attributes={"key": "value"},
        events=[Event(name="evt", attributes={"k": "v"}, timestamp=123)],
        status=Status(StatusCode.OK),
        start_time=10,
        end_time=20,
    )
    client._create_and_export_span(data)
    span.set_attributes.assert_called_once()
    span.add_event.assert_called_once()
    span.set_status.assert_called_once()
    span.end.assert_called_once_with(end_time=20)
    # Context must be cached under the span's own id for later children.
    assert client.span_contexts[2] == "ctx"
def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None:
    """A known parent_span_id makes the child start inside the cached parent context."""
    client = _build_client()
    client.span_contexts[10] = "existing"
    span = patch_core_components["span"]
    span.get_span_context.return_value = "child"
    data = SpanData(
        trace_id=1,
        parent_span_id=10,
        span_id=11,
        name="span",
        attributes={},
        events=[],
        start_time=0,
        end_time=1,
    )
    client._create_and_export_span(data)
    trace_api = patch_core_components["trace_api"]
    # The cached parent context is wrapped in a non-recording span for propagation.
    trace_api.NonRecordingSpan.assert_called_once_with("existing")
    trace_api.set_span_in_context.assert_called_once()
def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None:
    """Failures while starting a span are swallowed and logged via logger.exception."""
    client = _build_client()
    span = patch_core_components["span"]
    span.get_span_context.return_value = "ctx"
    client.tracer.start_span.side_effect = RuntimeError("boom")
    client._create_and_export_span(
        SpanData(
            trace_id=1,
            parent_span_id=None,
            span_id=2,
            name="span",
            attributes={},
            events=[],
            start_time=0,
            end_time=1,
        )
    )
    logger = patch_core_components["logger"]
    logger.exception.assert_called_once()
def test_api_check_connects_successfully(monkeypatch: pytest.MonkeyPatch) -> None:
    """api_check() is True when a raw TCP connect to the resolved target succeeds."""
    client = _build_client()
    monkeypatch.setattr(
        TencentTraceClient,
        "_resolve_grpc_target",
        MagicMock(return_value=("host:123", False, "host", 123)),
    )
    socket_mock = MagicMock()
    socket_instance = MagicMock()
    socket_instance.connect_ex.return_value = 0  # 0 == connect succeeded
    socket_mock.return_value = socket_instance
    monkeypatch.setattr(client_module.socket, "socket", socket_mock)
    assert client.api_check()
    socket_instance.connect_ex.assert_called_once()
def test_api_check_returns_false_and_handles_local(monkeypatch: pytest.MonkeyPatch) -> None:
    """A failed connect yields False for remote hosts but True for local targets."""
    client = _build_client()
    monkeypatch.setattr(
        TencentTraceClient,
        "_resolve_grpc_target",
        MagicMock(return_value=("host:123", False, "host", 123)),
    )
    socket_mock = MagicMock()
    socket_instance = MagicMock()
    socket_instance.connect_ex.return_value = 1  # non-zero == connect failed
    socket_mock.return_value = socket_instance
    monkeypatch.setattr(client_module.socket, "socket", socket_mock)
    assert not client.api_check()
    # Same failing connect, but a local target is still considered reachable.
    monkeypatch.setattr(
        TencentTraceClient,
        "_resolve_grpc_target",
        MagicMock(return_value=("localhost:4317", True, "localhost", 4317)),
    )
    socket_instance.connect_ex.return_value = 1
    assert client.api_check()
def test_api_check_handles_exceptions(monkeypatch: pytest.MonkeyPatch) -> None:
    """Socket-level exceptions do not fail the check; api_check() stays True."""
    # NOTE(review): builds the client directly rather than via _build_client();
    # presumably intentional to use a localhost endpoint — confirm.
    client = TencentTraceClient("svc", "https://localhost", "token")
    monkeypatch.setattr(client_module.socket, "socket", MagicMock(side_effect=RuntimeError("boom")))
    assert client.api_check()
def test_get_project_url() -> None:
    """The console URL for Tencent APM is a fixed constant."""
    url = _build_client().get_project_url()
    assert url == "https://console.cloud.tencent.com/apm"
def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None:
    """shutdown() flushes and shuts down the span pipeline and the metric pipeline."""
    client = _build_client()
    span_processor = patch_core_components["span_processor"]
    tracer_provider = patch_core_components["tracer_provider"]
    client.shutdown()
    span_processor.force_flush.assert_called_once()
    span_processor.shutdown.assert_called_once()
    tracer_provider.shutdown.assert_called_once()
    # Metric stubs are recorded by the stub modules installed in stub_metric_modules.
    meter_provider = meter_provider_instances[-1]
    metric_reader = metric_reader_instances[-1]
    meter_provider.shutdown.assert_called_once()
    metric_reader.shutdown.assert_called_once()
def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None:
    """Metric-pipeline shutdown failures are logged at debug level, not raised."""
    client = _build_client()
    meter_provider = meter_provider_instances[-1]
    meter_provider.shutdown.side_effect = RuntimeError("boom")
    client.metric_reader.shutdown.side_effect = RuntimeError("boom")
    client.shutdown()
    logger = patch_core_components["logger"]
    logger.debug.assert_any_call(
        "[Tencent APM] Error shutting down meter provider",
        exc_info=True,
    )
    logger.debug.assert_any_call(
        "[Tencent APM] Error shutting down metric reader",
        exc_info=True,
    )
def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: pytest.MonkeyPatch) -> None:
    """If metric setup raises, all metric-related attributes are left as None."""
    monkeypatch.setattr(DummyMeterProvider, "__init__", MagicMock(side_effect=RuntimeError("err")))
    client = _build_client()
    assert client.meter is None
    assert client.meter_provider is None
    assert client.hist_llm_duration is None
    assert client.hist_token_usage is None
    assert client.hist_time_to_first_token is None
    assert client.hist_time_to_generate is None
    assert client.hist_trace_duration is None
    assert client.metric_reader is None
def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None:
    """Errors inside _create_and_export_span are caught by add_span and logged."""
    client = _build_client()
    monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))
    client.add_span(
        SpanData(
            trace_id=1,
            parent_span_id=None,
            span_id=2,
            name="span",
            attributes={},
            events=[],
            start_time=0,
            end_time=1,
        )
    )
    logger = patch_core_components["logger"]
    logger.exception.assert_called_once()
def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None:
    """Primitive attribute values (int/bool/float/str) pass through unchanged."""
    client = _build_client()
    span = patch_core_components["span"]
    span.get_span_context.return_value = "ctx"
    # model_construct bypasses SpanData validation so raw attribute types survive.
    data = SpanData.model_construct(
        trace_id=1,
        parent_span_id=None,
        span_id=2,
        name="span",
        attributes={"num": 5, "flag": True, "pi": 3.14, "text": "value"},
        events=[],
        links=[],
        status=Status(StatusCode.OK),
        start_time=0,
        end_time=1,
    )
    client._create_and_export_span(data)
    (attrs,) = span.set_attributes.call_args.args
    assert attrs["num"] == 5
    assert attrs["flag"] is True
    assert attrs["pi"] == 3.14
    assert attrs["text"] == "value"
def test_record_llm_duration_converts_attributes() -> None:
    """Non-primitive metric attributes are stringified; primitives are kept as-is."""
    client = _build_client()
    hist_mock = MagicMock(name="hist_llm_duration")
    client.hist_llm_duration = hist_mock
    client.record_llm_duration(0.3, {"foo": object(), "bar": 2})
    _, attrs = hist_mock.record.call_args.args
    assert isinstance(attrs["foo"], str)
    assert attrs["bar"] == 2
def test_record_trace_duration_converts_attributes() -> None:
    """record_trace_duration applies the same attribute conversion as LLM duration."""
    client = _build_client()
    hist_mock = MagicMock(name="hist_trace_duration")
    client.hist_trace_duration = hist_mock
    client.record_trace_duration(1.0, {"meta": object(), "ok": True})
    _, attrs = hist_mock.record.call_args.args
    assert isinstance(attrs["meta"], str)
    assert attrs["ok"] is True
@pytest.mark.parametrize(
    ("method", "attr_name", "args"),
    [
        # record_llm_duration is covered separately in
        # test_record_llm_duration_handles_exceptions above.
        ("record_token_usage", "hist_token_usage", (5, "input", "chat", "gpt", "gpt", "addr", "dify")),
        ("record_time_to_first_token", "hist_time_to_first_token", (0.4, "dify", "gpt")),
        ("record_time_to_generate", "hist_time_to_generate", (0.6, "dify", "gpt")),
        ("record_trace_duration", "hist_trace_duration", (1.0, {"meta": object()})),
    ],
)
def test_record_methods_handle_exceptions(
    method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object]
) -> None:
    """Each record_* method swallows histogram errors and logs at debug level."""
    client = _build_client()
    hist_mock = MagicMock(name=attr_name)
    hist_mock.record.side_effect = RuntimeError("boom")
    setattr(client, attr_name, hist_mock)
    getattr(client, method)(*args)
    logger = patch_core_components["logger"]
    logger.debug.assert_called()
def test_metrics_initializes_grpc_metric_exporter() -> None:
    """With no protocol override, the gRPC metric exporter is selected and configured."""
    client = _build_client()
    metric_reader = metric_reader_instances[-1]
    assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter)
    assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
    assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317"
    assert metric_reader.exporter.kwargs["insecure"] is False
    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
    """OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf selects the HTTP protobuf exporter."""
    monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
    client = _build_client()
    metric_reader = metric_reader_instances[-1]
    assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
    assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
    assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
    """http/json selects the JSON exporter and passes preferred_temporality through."""
    monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
    client = _build_client()
    metric_reader = metric_reader_instances[-1]
    assert isinstance(metric_reader.exporter, DummyJsonMetricExporter)
    assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
    assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
    assert "preferred_temporality" in metric_reader.exporter.kwargs
def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
    """A JSON exporter without preferred_temporality support is retried without it."""
    monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
    exporter_module = sys.modules["opentelemetry.exporter.otlp.http.json.metric_exporter"]
    monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
    _ = _build_client()
    metric_reader = metric_reader_instances[-1]
    assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality)
    assert "preferred_temporality" not in metric_reader.exporter.kwargs
def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
    """If the JSON exporter module cannot be imported, the HTTP exporter is used instead."""
    monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
    def _fail_import(mod_path: str) -> types.ModuleType:
        # Simulate the JSON exporter package being absent.
        raise ModuleNotFoundError(mod_path)
    monkeypatch.setattr(client_module.importlib, "import_module", _fail_import)
    _ = _build_client()
    metric_reader = metric_reader_instances[-1]
    assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)

View File

@@ -0,0 +1,359 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
from opentelemetry.trace import StatusCode
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
MessageTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.entities.semconv import (
GEN_AI_IS_ENTRY,
GEN_AI_IS_STREAMING_REQUEST,
GEN_AI_MODEL_NAME,
GEN_AI_SPAN_KIND,
GEN_AI_USAGE_INPUT_TOKENS,
INPUT_VALUE,
RETRIEVAL_DOCUMENT,
RETRIEVAL_QUERY,
TOOL_DESCRIPTION,
TOOL_NAME,
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.tencent_trace.span_builder import TencentSpanBuilder
from core.rag.models.document import Document
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class TestTencentSpanBuilder:
    """Unit tests for TencentSpanBuilder's span-construction class methods.

    Trace infos and node executions are MagicMocks built with spec= so only
    attributes of the real entity classes can be read; id-conversion and
    time-conversion helpers are patched to fixed values throughout.
    """
    def test_get_time_nanoseconds(self):
        """_get_time_nanoseconds delegates to TencentTraceUtils.convert_datetime_to_nanoseconds."""
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert:
            mock_convert.return_value = 123456789
            dt = datetime.now()
            result = TencentSpanBuilder._get_time_nanoseconds(dt)
            assert result == 123456789
            mock_convert.assert_called_once_with(dt)
    def test_build_workflow_spans(self):
        """With a conversation_id, a message span wraps the workflow span as its parent."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.workflow_run_id = "run_id"
        trace_info.error = None
        trace_info.start_time = datetime.now()
        trace_info.end_time = datetime.now()
        trace_info.workflow_run_inputs = {"sys.query": "hello"}
        trace_info.workflow_run_outputs = {"answer": "world"}
        trace_info.metadata = {"conversation_id": "conv_id"}
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.side_effect = [1, 2]  # workflow_span_id, message_span_id
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1")
                assert len(spans) == 2
                assert spans[0].name == "message"
                assert spans[0].span_id == 2
                assert spans[1].name == "workflow"
                assert spans[1].span_id == 1
                assert spans[1].parent_span_id == 2
    def test_build_workflow_spans_no_message(self):
        """Without a conversation_id only the workflow span is built; errors mark it ERROR."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.workflow_run_id = "run_id"
        trace_info.error = "some error"
        trace_info.start_time = datetime.now()
        trace_info.end_time = datetime.now()
        trace_info.workflow_run_inputs = {}
        trace_info.workflow_run_outputs = {}
        trace_info.metadata = {}  # No conversation_id
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 1
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1")
                assert len(spans) == 1
                assert spans[0].name == "workflow"
                assert spans[0].status.status_code == StatusCode.ERROR
                assert spans[0].status.description == "some error"
                assert spans[0].attributes[GEN_AI_IS_ENTRY] == "true"
    def test_build_workflow_llm_span(self):
        """LLM node spans carry model name, streaming flag and token usage from process_data."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {"conversation_id": "conv_id"}
        node_execution = MagicMock(spec=WorkflowNodeExecution)
        node_execution.id = "node_id"
        node_execution.created_at = datetime.now()
        node_execution.finished_at = datetime.now()
        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        node_execution.process_data = {
            "model_name": "gpt-4",
            "model_provider": "openai",
            "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, "time_to_first_token": 0.5},
            "prompts": ["hello"],
        }
        node_execution.outputs = {"text": "world"}
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 456
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution)
                assert span.name == "GENERATION"
                assert span.attributes[GEN_AI_MODEL_NAME] == "gpt-4"
                assert span.attributes[GEN_AI_IS_STREAMING_REQUEST] == "true"
                assert span.attributes[GEN_AI_USAGE_INPUT_TOKENS] == "10"
    def test_build_workflow_llm_span_usage_in_outputs(self):
        """When process_data is empty, usage is read from outputs and no streaming flag is set."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {}
        node_execution = MagicMock(spec=WorkflowNodeExecution)
        node_execution.id = "node_id"
        node_execution.created_at = datetime.now()
        node_execution.finished_at = datetime.now()
        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        node_execution.process_data = {}
        node_execution.outputs = {
            "text": "world",
            "usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40},
        }
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 456
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution)
                assert span.attributes[GEN_AI_USAGE_INPUT_TOKENS] == "15"
                assert GEN_AI_IS_STREAMING_REQUEST not in span.attributes
    def test_build_message_span_standalone(self):
        """A standalone message span records streaming flag and stringified inputs."""
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.message_id = "msg_id"
        trace_info.error = None
        trace_info.start_time = datetime.now()
        trace_info.end_time = datetime.now()
        trace_info.inputs = {"q": "hi"}
        trace_info.outputs = "hello"
        trace_info.metadata = {"conversation_id": "conv_id"}
        trace_info.is_streaming_request = True
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 789
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1")
                assert span.name == "message"
                assert span.attributes[GEN_AI_IS_STREAMING_REQUEST] == "true"
                assert span.attributes[INPUT_VALUE] == str(trace_info.inputs)
    def test_build_message_span_standalone_with_error(self):
        """Errors set ERROR status; None inputs become an empty input value."""
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.message_id = "msg_id"
        trace_info.error = "some error"
        trace_info.start_time = datetime.now()
        trace_info.end_time = datetime.now()
        trace_info.inputs = None
        trace_info.outputs = None
        trace_info.metadata = {}
        trace_info.is_streaming_request = False
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 789
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1")
                assert span.status.status_code == StatusCode.ERROR
                assert span.status.description == "some error"
                assert span.attributes[INPUT_VALUE] == ""
    def test_build_tool_span(self):
        """Tool spans are named after the tool and propagate tool errors as ERROR status."""
        trace_info = MagicMock(spec=ToolTraceInfo)
        trace_info.message_id = "msg_id"
        trace_info.tool_name = "search"
        trace_info.error = "tool error"
        trace_info.start_time = datetime.now()
        trace_info.end_time = datetime.now()
        trace_info.tool_parameters = {"p": 1}
        trace_info.tool_inputs = {"i": 2}
        trace_info.tool_outputs = "result"
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 101
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1)
                assert span.name == "search"
                assert span.status.status_code == StatusCode.ERROR
                assert span.attributes[TOOL_NAME] == "search"
    def test_build_retrieval_span(self):
        """Retrieval spans carry the query and serialized retrieved documents."""
        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
        trace_info.message_id = "msg_id"
        trace_info.inputs = "query"
        trace_info.error = None
        trace_info.start_time = datetime.now()
        trace_info.end_time = datetime.now()
        doc = Document(
            page_content="content", metadata={"dataset_id": "d1", "doc_id": "di1", "document_id": "du1", "score": 0.9}
        )
        trace_info.documents = [doc]
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 202
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
                assert span.name == "retrieval"
                assert span.attributes[RETRIEVAL_QUERY] == "query"
                assert "content" in span.attributes[RETRIEVAL_DOCUMENT]
    def test_build_retrieval_span_with_error(self):
        """Retrieval errors set ERROR status with the error text as description."""
        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
        trace_info.message_id = "msg_id"
        trace_info.inputs = ""
        trace_info.error = "retrieval failed"
        trace_info.start_time = datetime.now()
        trace_info.end_time = datetime.now()
        trace_info.documents = []
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 202
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
                assert span.status.status_code == StatusCode.ERROR
                assert span.status.description == "retrieval failed"
    def test_get_workflow_node_status(self):
        """Node status maps SUCCEEDED->OK, FAILED/EXCEPTION->ERROR(+desc), RUNNING->UNSET."""
        node = MagicMock(spec=WorkflowNodeExecution)
        node.status = WorkflowNodeExecutionStatus.SUCCEEDED
        assert TencentSpanBuilder._get_workflow_node_status(node).status_code == StatusCode.OK
        node.status = WorkflowNodeExecutionStatus.FAILED
        node.error = "fail"
        status = TencentSpanBuilder._get_workflow_node_status(node)
        assert status.status_code == StatusCode.ERROR
        assert status.description == "fail"
        node.status = WorkflowNodeExecutionStatus.EXCEPTION
        node.error = "exc"
        status = TencentSpanBuilder._get_workflow_node_status(node)
        assert status.status_code == StatusCode.ERROR
        assert status.description == "exc"
        node.status = WorkflowNodeExecutionStatus.RUNNING
        assert TencentSpanBuilder._get_workflow_node_status(node).status_code == StatusCode.UNSET
    def test_build_workflow_retrieval_span(self):
        """Workflow retrieval spans use the node title and pull query/result content."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {"conversation_id": "conv_id"}
        node_execution = MagicMock(spec=WorkflowNodeExecution)
        node_execution.id = "node_id"
        node_execution.title = "my retrieval"
        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        node_execution.inputs = {"query": "q1"}
        node_execution.outputs = {"result": [{"content": "c1"}]}
        node_execution.created_at = datetime.now()
        node_execution.finished_at = datetime.now()
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 303
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
                assert span.name == "my retrieval"
                assert span.attributes[RETRIEVAL_QUERY] == "q1"
                assert "c1" in span.attributes[RETRIEVAL_DOCUMENT]
    def test_build_workflow_retrieval_span_empty(self):
        """Empty node inputs/outputs yield empty query and document attributes."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {}
        node_execution = MagicMock(spec=WorkflowNodeExecution)
        node_execution.id = "node_id"
        node_execution.title = "my retrieval"
        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        node_execution.inputs = {}
        node_execution.outputs = {}
        node_execution.created_at = datetime.now()
        node_execution.finished_at = datetime.now()
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 303
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
                assert span.attributes[RETRIEVAL_QUERY] == ""
                assert span.attributes[RETRIEVAL_DOCUMENT] == ""
    def test_build_workflow_tool_span(self):
        """Workflow tool spans pull the description from TOOL_INFO node metadata."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        node_execution = MagicMock(spec=WorkflowNodeExecution)
        node_execution.id = "node_id"
        node_execution.title = "my tool"
        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        node_execution.metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"info": "some"}}
        node_execution.inputs = {"param": "val"}
        node_execution.outputs = {"res": "ok"}
        node_execution.created_at = datetime.now()
        node_execution.finished_at = datetime.now()
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 404
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
                assert span.name == "my tool"
                assert span.attributes[TOOL_NAME] == "my tool"
                assert "some" in span.attributes[TOOL_DESCRIPTION]
    def test_build_workflow_tool_span_no_metadata(self):
        """Missing metadata/inputs degrade to empty JSON-object strings, not errors."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        node_execution = MagicMock(spec=WorkflowNodeExecution)
        node_execution.id = "node_id"
        node_execution.title = "my tool"
        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        node_execution.metadata = None
        node_execution.inputs = None
        node_execution.outputs = {"res": "ok"}
        node_execution.created_at = datetime.now()
        node_execution.finished_at = datetime.now()
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 404
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
                assert span.attributes[TOOL_DESCRIPTION] == "{}"
                assert span.attributes[TOOL_PARAMETERS] == "{}"
    def test_build_workflow_task_span(self):
        """Generic task nodes are tagged with the TASK GenAI span kind."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {"conversation_id": "conv_id"}
        node_execution = MagicMock(spec=WorkflowNodeExecution)
        node_execution.id = "node_id"
        node_execution.title = "my task"
        node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        node_execution.inputs = {"in": 1}
        node_execution.outputs = {"out": 2}
        node_execution.created_at = datetime.now()
        node_execution.finished_at = datetime.now()
        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
            mock_convert_id.return_value = 505
            with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution)
                assert span.name == "my task"
                assert span.attributes[GEN_AI_SPAN_KIND] == GenAISpanKind.TASK.value

View File

@@ -0,0 +1,647 @@
import logging
from unittest.mock import MagicMock, patch
import pytest
from core.ops.entities.config_entity import TencentConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.ops.tencent_trace.tencent_trace import TencentDataTrace
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import NodeType
from models import Account, App, TenantAccountJoin
logger = logging.getLogger(__name__)
@pytest.fixture
def tencent_config():
    """A minimal TencentConfig pointing at a test endpoint."""
    return TencentConfig(service_name="test-service", endpoint="https://test-endpoint", token="test-token")
@pytest.fixture
def mock_trace_client():
    """Patch the TencentTraceClient class used by TencentDataTrace."""
    with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock:
        yield mock
@pytest.fixture
def mock_span_builder():
    """Patch the TencentSpanBuilder class used by TencentDataTrace."""
    with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock:
        yield mock
@pytest.fixture
def mock_trace_utils():
    """Patch the TencentTraceUtils helper used by TencentDataTrace."""
    with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock:
        yield mock
@pytest.fixture
def tencent_data_trace(tencent_config, mock_trace_client):
    """A TencentDataTrace whose trace client is the mocked class instance."""
    return TencentDataTrace(tencent_config)
class TestTencentDataTrace:
    def test_init(self, tencent_config, mock_trace_client):
        """Construction forwards config values (and a 5s metrics interval) to the client."""
        trace = TencentDataTrace(tencent_config)
        mock_trace_client.assert_called_once_with(
            service_name=tencent_config.service_name,
            endpoint=tencent_config.endpoint,
            token=tencent_config.token,
            metrics_export_interval_sec=5,
        )
        assert trace.trace_client == mock_trace_client.return_value
    def test_trace_dispatch(self, tencent_data_trace):
        """trace() dispatches each trace-info type to its handler; unsupported types are ignored.

        Each tuple pairs a concrete trace-info instance with the expected handler
        method name, or None when trace() is expected to silently pass.
        """
        methods = [
            (
                WorkflowTraceInfo(
                    workflow_id="wf",
                    tenant_id="t",
                    workflow_run_id="run",
                    workflow_run_elapsed_time=1.0,
                    workflow_run_status="s",
                    workflow_run_inputs={},
                    workflow_run_outputs={},
                    workflow_run_version="v",
                    total_tokens=0,
                    file_list=[],
                    query="",
                    metadata={},
                ),
                "workflow_trace",
            ),
            (
                MessageTraceInfo(
                    message_id="msg",
                    message_data={},
                    inputs={},
                    outputs={},
                    start_time=None,
                    end_time=None,
                    conversation_mode="chat",
                    conversation_model="gpt-3.5-turbo",
                    message_tokens=0,
                    answer_tokens=0,
                    total_tokens=0,
                    metadata={},
                ),
                "message_trace",
            ),
            (
                ModerationTraceInfo(
                    flagged=False, action="a", preset_response="p", query="q", metadata={}, message_id="m"
                ),
                None,
            ),  # Pass
            (
                SuggestedQuestionTraceInfo(
                    suggested_question=[],
                    level="l",
                    total_tokens=0,
                    metadata={},
                    message_id="m",
                    message_data={},
                    inputs={},
                    start_time=None,
                    end_time=None,
                ),
                "suggested_question_trace",
            ),
            (
                DatasetRetrievalTraceInfo(
                    metadata={},
                    message_id="m",
                    message_data={},
                    inputs={},
                    documents=[],
                    start_time=None,
                    end_time=None,
                ),
                "dataset_retrieval_trace",
            ),
            (
                ToolTraceInfo(
                    tool_name="t",
                    tool_inputs={},
                    tool_outputs="",
                    tool_config={},
                    tool_parameters={},
                    time_cost=0,
                    metadata={},
                    message_id="m",
                    inputs={},
                    outputs={},
                    start_time=None,
                    end_time=None,
                ),
                "tool_trace",
            ),
            (
                GenerateNameTraceInfo(
                    tenant_id="t", metadata={}, message_id="m", inputs={}, outputs={}, start_time=None, end_time=None
                ),
                None,
            ),  # Pass
        ]
        for trace_info, method_name in methods:
            if method_name:
                with patch.object(tencent_data_trace, method_name) as mock_method:
                    tencent_data_trace.trace(trace_info)
                    mock_method.assert_called_once_with(trace_info)
            else:
                # Unhandled types must not raise.
                tencent_data_trace.trace(trace_info)
    def test_api_check(self, tencent_data_trace):
        """api_check() delegates straight to the underlying trace client."""
        tencent_data_trace.trace_client.api_check.return_value = True
        assert tencent_data_trace.api_check() is True
        tencent_data_trace.trace_client.api_check.assert_called_once()
    def test_get_project_url(self, tencent_data_trace):
        """get_project_url() delegates straight to the underlying trace client."""
        tencent_data_trace.trace_client.get_project_url.return_value = "http://url"
        assert tencent_data_trace.get_project_url() == "http://url"
        tencent_data_trace.trace_client.get_project_url.assert_called_once()
def test_workflow_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.workflow_run_id = "run-id"
trace_info.trace_id = "parent-trace-id"
mock_trace_utils.convert_to_trace_id.return_value = 123
mock_trace_utils.create_link.return_value = "link"
with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"):
with patch.object(tencent_data_trace, "_process_workflow_nodes") as mock_proc:
with patch.object(tencent_data_trace, "_record_workflow_trace_duration") as mock_dur:
mock_span_builder.build_workflow_spans.return_value = [MagicMock(), MagicMock()]
tencent_data_trace.workflow_trace(trace_info)
mock_trace_utils.convert_to_trace_id.assert_called_once_with("run-id")
mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
mock_span_builder.build_workflow_spans.assert_called_once()
assert tencent_data_trace.trace_client.add_span.call_count == 2
mock_proc.assert_called_once_with(trace_info, 123)
mock_dur.assert_called_once_with(trace_info)
    def test_workflow_trace_exception(self, tencent_data_trace):
        """A failure inside trace-id conversion is swallowed and logged via logger.exception."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.workflow_run_id = "run-id"
        with patch(
            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
        ):
            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                tencent_data_trace.workflow_trace(trace_info)
                mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace")
    def test_message_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
        """Happy path: a message span is built/added and LLM metrics plus duration are recorded."""
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.message_id = "msg-id"
        trace_info.trace_id = "parent-trace-id"
        mock_trace_utils.convert_to_trace_id.return_value = 123
        mock_trace_utils.create_link.return_value = "link"
        with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"):
            with patch.object(tencent_data_trace, "_record_message_llm_metrics") as mock_metrics:
                with patch.object(tencent_data_trace, "_record_message_trace_duration") as mock_dur:
                    mock_span_builder.build_message_span.return_value = MagicMock()
                    tencent_data_trace.message_trace(trace_info)
                    mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
                    mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
                    mock_span_builder.build_message_span.assert_called_once()
                    tencent_data_trace.trace_client.add_span.assert_called_once()
                    mock_metrics.assert_called_once_with(trace_info)
                    mock_dur.assert_called_once_with(trace_info)
    def test_message_trace_exception(self, tencent_data_trace):
        """A failure during message tracing is swallowed and logged via logger.exception."""
        trace_info = MagicMock(spec=MessageTraceInfo)
        with patch(
            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
        ):
            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                tencent_data_trace.message_trace(trace_info)
                mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace")
    def test_tool_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
        """Tool spans are parented on the message span id and added to the client."""
        trace_info = MagicMock(spec=ToolTraceInfo)
        trace_info.message_id = "msg-id"
        mock_trace_utils.convert_to_span_id.return_value = 456
        mock_trace_utils.convert_to_trace_id.return_value = 123
        tencent_data_trace.tool_trace(trace_info)
        mock_trace_utils.convert_to_span_id.assert_called_once_with("msg-id", "message")
        mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
        mock_span_builder.build_tool_span.assert_called_once_with(trace_info, 123, 456)
        tencent_data_trace.trace_client.add_span.assert_called_once()
    def test_tool_trace_no_msg_id(self, tencent_data_trace):
        """Without a message_id no tool span is produced."""
        trace_info = MagicMock(spec=ToolTraceInfo)
        trace_info.message_id = None
        tencent_data_trace.tool_trace(trace_info)
        tencent_data_trace.trace_client.add_span.assert_not_called()
    def test_tool_trace_exception(self, tencent_data_trace):
        """A failure during tool tracing is swallowed and logged via logger.exception."""
        trace_info = MagicMock(spec=ToolTraceInfo)
        trace_info.message_id = "msg-id"
        with patch(
            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
        ):
            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                tencent_data_trace.tool_trace(trace_info)
                mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace")
    def test_dataset_retrieval_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
        """Retrieval spans are parented on the message span id and added to the client."""
        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
        trace_info.message_id = "msg-id"
        mock_trace_utils.convert_to_span_id.return_value = 456
        mock_trace_utils.convert_to_trace_id.return_value = 123
        tencent_data_trace.dataset_retrieval_trace(trace_info)
        mock_trace_utils.convert_to_span_id.assert_called_once_with("msg-id", "message")
        mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
        mock_span_builder.build_retrieval_span.assert_called_once_with(trace_info, 123, 456)
        tencent_data_trace.trace_client.add_span.assert_called_once()
    def test_dataset_retrieval_trace_no_msg_id(self, tencent_data_trace):
        """Without a message_id no retrieval span is produced."""
        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
        trace_info.message_id = None
        tencent_data_trace.dataset_retrieval_trace(trace_info)
        tencent_data_trace.trace_client.add_span.assert_not_called()
    def test_dataset_retrieval_trace_exception(self, tencent_data_trace):
        """A failure during retrieval tracing is swallowed and logged via logger.exception."""
        trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
        trace_info.message_id = "msg-id"
        with patch(
            "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
        ):
            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                tencent_data_trace.dataset_retrieval_trace(trace_info)
                mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace")
    def test_suggested_question_trace(self, tencent_data_trace):
        """Suggested-question traces are only acknowledged with an info log."""
        trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
        with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log:
            tencent_data_trace.suggested_question_trace(trace_info)
            mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace")
    def test_suggested_question_trace_exception(self, tencent_data_trace):
        """A failure while handling a suggested-question trace is swallowed and logged."""
        trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
        with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")):
            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                tencent_data_trace.suggested_question_trace(trace_info)
                mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace")
    def test_process_workflow_nodes(self, tencent_data_trace, mock_trace_utils):
        """Each node execution yields a span; LLM metrics are recorded only for LLM nodes."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.workflow_run_id = "run-id"
        mock_trace_utils.convert_to_span_id.return_value = 111
        node1 = MagicMock(spec=WorkflowNodeExecution)
        node1.id = "n1"
        node1.node_type = NodeType.LLM
        node2 = MagicMock(spec=WorkflowNodeExecution)
        node2.id = "n2"
        node2.node_type = NodeType.TOOL
        with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]):
            with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]):
                with patch.object(tencent_data_trace, "_record_llm_metrics") as mock_metrics:
                    tencent_data_trace._process_workflow_nodes(trace_info, 123)
                    assert tencent_data_trace.trace_client.add_span.call_count == 2
                    # Only node1 is an LLM node, so metrics are recorded once.
                    mock_metrics.assert_called_once_with(node1)
    def test_process_workflow_nodes_node_exception(self, tencent_data_trace, mock_trace_utils):
        """An error while building a node span is caught and logged once."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        mock_trace_utils.convert_to_span_id.return_value = 111
        node = MagicMock(spec=WorkflowNodeExecution)
        node.id = "n1"
        with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]):
            with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")):
                with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                    tencent_data_trace._process_workflow_nodes(trace_info, 123)
                    # The exception should be caught by the outer handler since convert_to_span_id is called first
                    mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
    def test_process_workflow_nodes_exception(self, tencent_data_trace, mock_trace_utils):
        """An error before node iteration (span-id conversion) is caught and logged once."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error")
        with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
            tencent_data_trace._process_workflow_nodes(trace_info, 123)
            mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
    def test_build_workflow_node_span(self, tencent_data_trace, mock_span_builder):
        """Each node type is dispatched to the matching span-builder method."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        nodes = [
            (NodeType.LLM, mock_span_builder.build_workflow_llm_span),
            (NodeType.KNOWLEDGE_RETRIEVAL, mock_span_builder.build_workflow_retrieval_span),
            (NodeType.TOOL, mock_span_builder.build_workflow_tool_span),
            (NodeType.CODE, mock_span_builder.build_workflow_task_span),
        ]
        for node_type, builder_method in nodes:
            node = MagicMock(spec=WorkflowNodeExecution)
            node.node_type = node_type
            builder_method.return_value = "span"
            result = tencent_data_trace._build_workflow_node_span(node, 123, trace_info, 456)
            assert result == "span"
            builder_method.assert_called_once_with(123, 456, trace_info, node)
    def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder):
        """A builder failure yields None and a debug log instead of raising."""
        node = MagicMock(spec=WorkflowNodeExecution)
        node.node_type = NodeType.LLM
        node.id = "n1"
        mock_span_builder.build_workflow_llm_span.side_effect = Exception("error")
        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
            result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456)
            assert result is None
            mock_log.assert_called_once()
    def test_get_workflow_node_executions(self, tencent_data_trace):
        """Happy path: the repository is queried with the resolved service account and tenant."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {"app_id": "app-1"}
        trace_info.workflow_run_id = "run-1"
        app = MagicMock(spec=App)
        app.id = "app-1"
        app.created_by = "user-1"
        account = MagicMock(spec=Account)
        account.id = "user-1"
        tenant_join = MagicMock(spec=TenantAccountJoin)
        tenant_join.tenant_id = "tenant-1"
        mock_executions = [MagicMock()]
        with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
            mock_db.engine = "engine"
            with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
                session = mock_session_ctx.return_value.__enter__.return_value
                # First scalar() returns the app, second the creator account.
                session.scalar.side_effect = [app, account]
                session.query.return_value.filter_by.return_value.first.return_value = tenant_join
                with patch(
                    "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository"
                ) as mock_repo:
                    mock_repo.return_value.get_by_workflow_run.return_value = mock_executions
                    results = tencent_data_trace._get_workflow_node_executions(trace_info)
                    assert results == mock_executions
                    account.set_tenant_id.assert_called_once_with("tenant-1")
    def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace):
        """Missing app_id in metadata yields an empty list and an exception log."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {}
        with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
            results = tencent_data_trace._get_workflow_node_executions(trace_info)
            assert results == []
            mock_log.assert_called_once()
    def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace):
        """An unknown app id yields an empty list and an exception log."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.metadata = {"app_id": "app-1"}
        with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
            mock_db.init_app = MagicMock()  # Ensure init_app is mocked
            mock_db.engine = "engine"
            with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
                session = mock_session_ctx.return_value.__enter__.return_value
                session.scalar.return_value = None
                with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                    results = tencent_data_trace._get_workflow_node_executions(trace_info)
                    assert results == []
                    mock_log.assert_called_once()
    def test_get_user_id_workflow(self, tencent_data_trace):
        """A database failure while resolving the workflow user falls back to 'unknown'."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.tenant_id = "tenant-1"
        trace_info.metadata = {"user_id": "user-1"}
        with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")):
            with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
                mock_db.init_app = MagicMock()
                mock_db.engine = MagicMock()
                user_id = tencent_data_trace._get_user_id(trace_info)
                assert user_id == "unknown"
    def test_get_user_id_only_user_id(self, tencent_data_trace):
        """For non-workflow traces the metadata user_id is returned as-is."""
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.metadata = {"user_id": "user-1"}
        user_id = tencent_data_trace._get_user_id(trace_info)
        assert user_id == "user-1"
    def test_get_user_id_anonymous(self, tencent_data_trace):
        """Empty metadata falls back to 'anonymous'."""
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.metadata = {}
        user_id = tencent_data_trace._get_user_id(trace_info)
        assert user_id == "anonymous"
    def test_get_user_id_exception(self, tencent_data_trace):
        """A lookup failure returns 'unknown' and logs via logger.exception."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.tenant_id = "t"
        trace_info.metadata = {"user_id": "u"}
        with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")):
            with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
                user_id = tencent_data_trace._get_user_id(trace_info)
                assert user_id == "unknown"
                mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID")
    def test_record_llm_metrics_usage_in_process_data(self, tencent_data_trace):
        """Usage stats in process_data drive duration, TTFT/TTG, and both token-usage records."""
        node = MagicMock(spec=WorkflowNodeExecution)
        node.process_data = {
            "usage": {
                "latency": 2.5,
                "time_to_first_token": 0.5,
                "time_to_generate": 2.0,
                "prompt_tokens": 10,
                "completion_tokens": 20,
            },
            "model_provider": "openai",
            "model_name": "gpt-4",
            "model_mode": "chat",
        }
        node.outputs = {}
        tencent_data_trace._record_llm_metrics(node)
        tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
        tencent_data_trace.trace_client.record_time_to_first_token.assert_called_once()
        tencent_data_trace.trace_client.record_time_to_generate.assert_called_once()
        # One call for prompt tokens and one for completion tokens.
        assert tencent_data_trace.trace_client.record_token_usage.call_count == 2
    def test_record_llm_metrics_usage_in_outputs(self, tencent_data_trace):
        """When process_data lacks usage, stats are read from node.outputs instead."""
        node = MagicMock(spec=WorkflowNodeExecution)
        node.process_data = {}
        node.outputs = {"usage": {"latency": 1.0, "prompt_tokens": 5}}
        tencent_data_trace._record_llm_metrics(node)
        tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
        tencent_data_trace.trace_client.record_token_usage.assert_called_once()
    def test_record_llm_metrics_exception(self, tencent_data_trace):
        """None process_data/outputs must not raise; the patch only silences debug logging."""
        node = MagicMock(spec=WorkflowNodeExecution)
        node.process_data = None
        node.outputs = None
        # NOTE(review): mock_log is intentionally unasserted — the test only checks no crash.
        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
            tencent_data_trace._record_llm_metrics(node)
            # Should not crash
    def test_record_message_llm_metrics(self, tencent_data_trace):
        """Streaming message metrics record duration, TTFT/TTG, and both token counts."""
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.metadata = {"ls_provider": "openai", "ls_model_name": "gpt-4"}
        trace_info.message_data = {"provider_response_latency": 1.1}
        trace_info.is_streaming_request = True
        trace_info.gen_ai_server_time_to_first_token = 0.2
        trace_info.llm_streaming_time_to_generate = 0.9
        trace_info.message_tokens = 15
        trace_info.answer_tokens = 25
        tencent_data_trace._record_message_llm_metrics(trace_info)
        tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
        tencent_data_trace.trace_client.record_time_to_first_token.assert_called_once()
        tencent_data_trace.trace_client.record_time_to_generate.assert_called_once()
        assert tencent_data_trace.trace_client.record_token_usage.call_count == 2
    def test_record_message_llm_metrics_object_data(self, tencent_data_trace):
        """message_data given as an object (not dict) is read via attributes."""
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.metadata = {}
        msg_data = MagicMock()
        msg_data.provider_response_latency = 1.1
        msg_data.model_provider = "anthropic"
        msg_data.model_id = "claude"
        trace_info.message_data = msg_data
        trace_info.is_streaming_request = False
        tencent_data_trace._record_message_llm_metrics(trace_info)
        tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
    def test_record_message_llm_metrics_exception(self, tencent_data_trace):
        """None metadata must not raise; the patch only silences debug logging."""
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.metadata = None
        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
            tencent_data_trace._record_message_llm_metrics(trace_info)
            # Should not crash
def test_record_workflow_trace_duration(self, tencent_data_trace):
trace_info = MagicMock(spec=WorkflowTraceInfo)
from datetime import datetime, timedelta
now = datetime.now()
trace_info.start_time = now
trace_info.end_time = now + timedelta(seconds=3)
trace_info.workflow_run_status = "succeeded"
trace_info.conversation_id = "conv-1"
# Mock the record_trace_duration method to capture arguments
with patch.object(tencent_data_trace.trace_client, "record_trace_duration") as mock_record:
tencent_data_trace._record_workflow_trace_duration(trace_info)
# Assert the method was called once
mock_record.assert_called_once()
# Extract arguments passed to the method
args, kwargs = mock_record.call_args
# Validate the duration argument
assert args[0] == 3.0
# Validate the attributes dict in kwargs
attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {}
assert attributes["conversation_mode"] == "workflow"
assert attributes["has_conversation"] == "true"
def test_record_workflow_trace_duration_fallback(self, tencent_data_trace):
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.start_time = None
trace_info.workflow_run_elapsed_time = 4.5
trace_info.workflow_run_status = "failed"
trace_info.conversation_id = None
with patch.object(tencent_data_trace.trace_client, "record_trace_duration") as mock_record:
tencent_data_trace._record_workflow_trace_duration(trace_info)
mock_record.assert_called_once()
args, kwargs = mock_record.call_args
assert args[0] == 4.5
# Check attributes dict (either in kwargs or as second positional arg)
attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {}
assert attributes["has_conversation"] == "false"
    def test_record_workflow_trace_duration_exception(self, tencent_data_trace):
        """A bad start_time object must not raise; the patch only silences debug logging."""
        trace_info = MagicMock(spec=WorkflowTraceInfo)
        trace_info.start_time = MagicMock()  # This might cause total_seconds() to fail if not mocked right
        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
            tencent_data_trace._record_workflow_trace_duration(trace_info)
    def test_record_message_trace_duration(self, tencent_data_trace):
        """Message duration is start/end delta, tagged with conversation mode and stream flag."""
        trace_info = MagicMock(spec=MessageTraceInfo)
        from datetime import datetime, timedelta

        now = datetime.now()
        trace_info.start_time = now
        trace_info.end_time = now + timedelta(seconds=2)
        trace_info.conversation_mode = "chat"
        trace_info.is_streaming_request = True
        tencent_data_trace._record_message_trace_duration(trace_info)
        tencent_data_trace.trace_client.record_trace_duration.assert_called_once_with(
            2.0, {"conversation_mode": "chat", "stream": "true"}
        )
    def test_record_message_trace_duration_exception(self, tencent_data_trace):
        """A None start_time must not raise; the patch only silences debug logging."""
        trace_info = MagicMock(spec=MessageTraceInfo)
        trace_info.start_time = None
        with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
            tencent_data_trace._record_message_trace_duration(trace_info)
    def test_del(self, tencent_data_trace):
        """__del__ shuts down the trace client."""
        client = tencent_data_trace.trace_client
        tencent_data_trace.__del__()
        client.shutdown.assert_called_once()
    def test_del_exception(self, tencent_data_trace):
        """A shutdown failure in __del__ is swallowed and logged."""
        tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
        with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
            tencent_data_trace.__del__()
            mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")

View File

@@ -0,0 +1,106 @@
"""Unit tests for Tencent APM tracing utilities."""
from __future__ import annotations
import hashlib
import uuid
from datetime import UTC, datetime
from unittest.mock import patch
import pytest
from opentelemetry.trace import Link, TraceFlags
from core.ops.tencent_trace.utils import TencentTraceUtils
def test_convert_to_trace_id_with_valid_uuid() -> None:
    """A valid UUID string converts to its 128-bit integer value."""
    uuid_str = "12345678-1234-5678-1234-567812345678"
    assert TencentTraceUtils.convert_to_trace_id(uuid_str) == uuid.UUID(uuid_str).int
def test_convert_to_trace_id_uses_uuid4_when_none() -> None:
    """When no UUID is given, a random uuid4 supplies the trace id."""
    expected_uuid = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
    with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
        assert TencentTraceUtils.convert_to_trace_id(None) == expected_uuid.int
        uuid4_mock.assert_called_once()
def test_convert_to_trace_id_raises_value_error_for_invalid_uuid() -> None:
    """A malformed UUID string raises ValueError with the documented prefix."""
    with pytest.raises(ValueError, match=r"^Invalid UUID input:"):
        TencentTraceUtils.convert_to_trace_id("not-a-uuid")
def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None:
    """Span ids are the first 8 bytes of sha256(hex-uuid + '-' + span_type), so both inputs matter."""
    uuid_str = "12345678-1234-5678-1234-567812345678"
    span_type = "llm"
    uuid_obj = uuid.UUID(uuid_str)
    combined_key = f"{uuid_obj.hex}-{span_type}"
    hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
    expected = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
    assert TencentTraceUtils.convert_to_span_id(uuid_str, span_type) == expected
    # Changing the span type must change the id.
    assert TencentTraceUtils.convert_to_span_id(uuid_str, "other") != expected
def test_convert_to_span_id_uses_uuid4_when_none() -> None:
    """When no UUID is given, a random uuid4 seeds the span id."""
    expected_uuid = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
    with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
        span_id = TencentTraceUtils.convert_to_span_id(None, "workflow")
        assert isinstance(span_id, int)
        uuid4_mock.assert_called_once()
def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None:
    """A malformed UUID string raises ValueError with the documented prefix."""
    with pytest.raises(ValueError, match=r"^Invalid UUID input:"):
        TencentTraceUtils.convert_to_span_id("bad-uuid", "span")
def test_generate_span_id_skips_invalid_span_id() -> None:
    """generate_span_id re-rolls when getrandbits yields the invalid sentinel."""
    with patch(
        "core.ops.tencent_trace.utils.random.getrandbits",
        side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42],
    ) as bits_mock:
        assert TencentTraceUtils.generate_span_id() == 42
        # One rejected draw plus one accepted draw.
        assert bits_mock.call_count == 2
def test_convert_datetime_to_nanoseconds_accepts_datetime() -> None:
    """A datetime converts to its epoch timestamp in nanoseconds."""
    start_time = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
    expected = int(start_time.timestamp() * 1e9)
    assert TencentTraceUtils.convert_datetime_to_nanoseconds(start_time) == expected
def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None:
    """When no datetime is given, the current time is used."""
    fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC)
    expected = int(fixed.timestamp() * 1e9)
    with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock:
        datetime_mock.now.return_value = fixed
        assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected
        datetime_mock.now.assert_called_once()
@pytest.mark.parametrize(
    ("trace_id_str", "expected_trace_id"),
    [
        ("0" * 31 + "1", int("0" * 31 + "1", 16)),
        (str(uuid.UUID("cccccccc-cccc-cccc-cccc-cccccccccccc")), uuid.UUID("cccccccc-cccc-cccc-cccc-cccccccccccc").int),
    ],
)
def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: int) -> None:
    """create_link accepts either a 32-char hex string or a canonical UUID string."""
    link = TencentTraceUtils.create_link(trace_id_str)
    assert isinstance(link, Link)
    assert link.context.trace_id == expected_trace_id
    assert link.context.span_id == TencentTraceUtils.INVALID_SPAN_ID
    assert link.context.is_remote is False
    assert link.context.trace_flags == TraceFlags(TraceFlags.SAMPLED)
@pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None])
def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None:
    """Unparseable (or missing) trace ids fall back to a random uuid4."""
    fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd")
    with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock:
        link = TencentTraceUtils.create_link(trace_id_str)  # type: ignore[arg-type]
        assert link.context.trace_id == fallback_uuid.int
        uuid4_mock.assert_called_once()

View File

@@ -0,0 +1,112 @@
from unittest.mock import MagicMock
import pytest
from sqlalchemy.orm import Session
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.entities.trace_entity import BaseTraceInfo
from models import Account, App, TenantAccountJoin
class ConcreteTraceInstance(BaseTraceInstance):
    """Minimal concrete subclass so the abstract BaseTraceInstance can be exercised in tests."""

    def __init__(self, trace_config: BaseTracingConfig):
        super().__init__(trace_config)

    def trace(self, trace_info: BaseTraceInfo):
        # Delegate straight to the base implementation under test.
        super().trace(trace_info)
@pytest.fixture
def mock_db_session(monkeypatch):
    """Patch base_trace_instance's Session/db so tests control DB lookups via one MagicMock."""
    mock_session = MagicMock(spec=Session)
    # Make the mock usable as a context manager (``with Session(...) as session``).
    mock_session.__enter__.return_value = mock_session
    mock_session.__exit__.return_value = None
    mock_session_class = MagicMock(return_value=mock_session)
    monkeypatch.setattr("core.ops.base_trace_instance.Session", mock_session_class)
    monkeypatch.setattr("core.ops.base_trace_instance.db", MagicMock())
    return mock_session
def test_get_service_account_with_tenant_app_not_found(mock_db_session):
    """Unknown app id raises ValueError."""
    mock_db_session.scalar.return_value = None
    config = MagicMock(spec=BaseTracingConfig)
    instance = ConcreteTraceInstance(config)
    with pytest.raises(ValueError, match="App with id some_app_id not found"):
        instance.get_service_account_with_tenant("some_app_id")
def test_get_service_account_with_tenant_no_creator(mock_db_session):
    """App without a created_by raises ValueError."""
    mock_app = MagicMock(spec=App)
    mock_app.id = "some_app_id"
    mock_app.created_by = None
    mock_db_session.scalar.return_value = mock_app
    config = MagicMock(spec=BaseTracingConfig)
    instance = ConcreteTraceInstance(config)
    with pytest.raises(ValueError, match="App with id some_app_id has no creator"):
        instance.get_service_account_with_tenant("some_app_id")
def test_get_service_account_with_tenant_creator_not_found(mock_db_session):
    """Creator account missing from the DB raises ValueError."""
    mock_app = MagicMock(spec=App)
    mock_app.id = "some_app_id"
    mock_app.created_by = "creator_id"
    # First call to scalar returns app, second returns None (for account)
    mock_db_session.scalar.side_effect = [mock_app, None]
    config = MagicMock(spec=BaseTracingConfig)
    instance = ConcreteTraceInstance(config)
    with pytest.raises(ValueError, match="Creator account with id creator_id not found for app some_app_id"):
        instance.get_service_account_with_tenant("some_app_id")
def test_get_service_account_with_tenant_tenant_not_found(mock_db_session):
    """Missing tenant membership for the creator raises ValueError."""
    mock_app = MagicMock(spec=App)
    mock_app.id = "some_app_id"
    mock_app.created_by = "creator_id"
    mock_account = MagicMock(spec=Account)
    mock_account.id = "creator_id"
    mock_db_session.scalar.side_effect = [mock_app, mock_account]
    # session.query(TenantAccountJoin).filter_by(...).first() returns None
    mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
    config = MagicMock(spec=BaseTracingConfig)
    instance = ConcreteTraceInstance(config)
    with pytest.raises(ValueError, match="Current tenant not found for account creator_id"):
        instance.get_service_account_with_tenant("some_app_id")
def test_get_service_account_with_tenant_success(mock_db_session):
    """Happy path: the creator account is returned with its tenant id set."""
    mock_app = MagicMock(spec=App)
    mock_app.id = "some_app_id"
    mock_app.created_by = "creator_id"
    mock_account = MagicMock(spec=Account)
    mock_account.id = "creator_id"
    mock_account.set_tenant_id = MagicMock()
    mock_db_session.scalar.side_effect = [mock_app, mock_account]
    mock_tenant_join = MagicMock(spec=TenantAccountJoin)
    mock_tenant_join.tenant_id = "tenant_id"
    mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_tenant_join
    config = MagicMock(spec=BaseTracingConfig)
    instance = ConcreteTraceInstance(config)
    result = instance.get_service_account_with_tenant("some_app_id")
    assert result == mock_account
    mock_account.set_tenant_id.assert_called_once_with("tenant_id")

View File

@@ -0,0 +1,576 @@
import contextlib
import json
import queue
from datetime import datetime, timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.ops.ops_trace_manager import (
OpsTraceManager,
TraceQueueManager,
TraceTask,
TraceTaskName,
)
class DummyConfig:
    """Minimal tracing-config stand-in: remembers its kwargs and echoes them from model_dump()."""

    def __init__(self, **kwargs):
        self._data = dict(kwargs)

    def model_dump(self):
        # Return a fresh copy so callers cannot mutate the stored data.
        return {**self._data}
class DummyTraceInstance:
    """Fake trace instance that registers every construction on the class-level list."""

    instances: list["DummyTraceInstance"] = []

    def __init__(self, config):
        self.config = config
        # Record creation so tests can inspect which configs were instantiated.
        DummyTraceInstance.instances.append(self)

    def api_check(self):
        return True

    def get_project_key(self):
        return "fake-key"

    def get_project_url(self):
        return "https://project.fake"
# Registry entry mimicking one record of provider_config_map for the "dummy" provider.
FAKE_PROVIDER_ENTRY = {
    "config_class": DummyConfig,
    "secret_keys": ["secret_value"],
    "other_keys": ["other_value"],
    "trace_instance": DummyTraceInstance,
}
class FakeProviderMap:
    """Dict-like provider map that raises a descriptive KeyError for unknown providers."""

    def __init__(self, data):
        self._data = data

    def __getitem__(self, key):
        try:
            return self._data[key]
        except KeyError:
            raise KeyError(f"Unsupported tracing provider: {key}") from None
class DummyTimer:
    """threading.Timer replacement: records start() but never schedules the callback."""

    def __init__(self, interval, function):
        self.interval = interval
        self.function = function
        self.name = ""
        self.daemon = False
        self.started = False

    def start(self):
        # Only record the call; the function is never invoked.
        self.started = True

    def is_alive(self):
        return False
class FakeMessageFile:
    """Stand-in for a MessageFile row exposing the attributes the trace tasks read."""

    def __init__(self):
        self.url = "path/to/file"
        self.id = "file-id"
        self.type = "document"
        self.created_by_role = "role"
        self.created_by = "user"
def make_message_data(**overrides):
    """Build a message-record stand-in; keyword overrides replace default field values."""
    created_at = datetime(2025, 2, 20, 12, 0, 0)
    defaults = {
        "id": "msg-id",
        "conversation_id": "conv-id",
        "created_at": created_at,
        "updated_at": created_at + timedelta(seconds=3),
        "message": "hello",
        "provider_response_latency": 1,
        "message_tokens": 5,
        "answer_tokens": 7,
        "answer": "world",
        "error": "",
        "status": "complete",
        "model_provider": "provider",
        "model_id": "model",
        "from_end_user_id": "end-user",
        "from_account_id": "account",
        "agent_based": False,
        "workflow_run_id": "workflow-run",
        "from_source": "source",
        "message_metadata": json.dumps({"usage": {"time_to_first_token": 1, "time_to_generate": 2}}),
        "agent_thoughts": [],
        "query": "sample-query",
        "inputs": "sample-input",
    }

    class MessageData:
        """Attribute-style wrapper over the field dict, with a dict view via to_dict()."""

        def __init__(self, data):
            self.__dict__.update(data)

        def to_dict(self):
            return dict(self.__dict__)

    return MessageData({**defaults, **overrides})
def make_agent_thought(tool_name, created_at):
    """Create a SimpleNamespace resembling an agent thought for a single tool invocation."""
    meta = {
        "tool_config": {"foo": "bar"},
        "time_cost": 5,
        "error": "",
        "tool_parameters": {"x": 1},
    }
    return SimpleNamespace(
        tools=[tool_name],
        created_at=created_at,
        tool_meta={tool_name: meta},
    )
def make_workflow_run():
    """Return a SimpleNamespace mimicking a finished workflow-run record."""
    run_attrs = {
        "workflow_id": "wf-1",
        "tenant_id": "tenant",
        "id": "run-id",
        "elapsed_time": 10,
        "status": "finished",
        "inputs_dict": {"sys.file": ["f1"], "query": "search"},
        "outputs_dict": {"out": "value"},
        "version": "3",
        "error": None,
        "total_tokens": 12,
        "workflow_run_id": "run-id",
        "created_at": datetime(2025, 2, 20, 10, 0, 0),
        "finished_at": datetime(2025, 2, 20, 10, 0, 5),
        "triggered_from": "user",
        "app_id": "app-id",
        # Plain lambda stored on the instance; call sites invoke run.to_dict() with no args.
        "to_dict": lambda self=None: {"run": "value"},
    }
    return SimpleNamespace(**run_attrs)
def configure_db_query(session, *, message_file=None, workflow_app_log=None):
    """Wire session.query(Model) so only MessageFile/WorkflowAppLog lookups hit the given objects."""
    results_by_model = {}
    if message_file:
        results_by_model["MessageFile"] = message_file
    if workflow_app_log:
        results_by_model["WorkflowAppLog"] = workflow_app_log

    def fake_query(model):
        query = MagicMock()
        # Any other model's filter_by(...).first() returns None.
        query.filter_by.return_value.first.return_value = results_by_model.get(model.__name__)
        return query

    session.query.side_effect = fake_query
class DummySessionContext:
    """Context-manager stub for a SQLAlchemy Session whose scalar() yields preset values in order."""

    # Values consumed one per scalar() call; once exhausted, scalar() returns None.
    scalar_values = []

    def __init__(self, engine):
        # The engine argument is accepted for signature compatibility and ignored.
        self._pending = iter(list(self.scalar_values))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return False

    def scalar(self, *args, **kwargs):
        return next(self._pending, None)
@pytest.fixture(autouse=True)
def patch_provider_map(monkeypatch):
    """Route every provider lookup to the dummy entry and reset OpsTraceManager caches."""
    monkeypatch.setattr(
        "core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({"dummy": FAKE_PROVIDER_ENTRY})
    )
    # Clear caches so previous tests cannot leak instances or decrypted configs.
    OpsTraceManager.ops_trace_instances_cache.clear()
    OpsTraceManager.decrypted_configs_cache.clear()
@pytest.fixture(autouse=True)
def patch_timer_and_current_app(monkeypatch):
    """Replace the real timer/queue/Flask app so no background scheduling or app context runs."""
    monkeypatch.setattr("core.ops.ops_trace_manager.threading.Timer", DummyTimer)
    monkeypatch.setattr("core.ops.ops_trace_manager.trace_manager_queue", queue.Queue())
    monkeypatch.setattr("core.ops.ops_trace_manager.trace_manager_timer", None)

    class FakeApp:
        def app_context(self):
            # No-op context manager standing in for Flask's app context.
            return contextlib.nullcontext()

    fake_current = MagicMock()
    fake_current._get_current_object.return_value = FakeApp()
    monkeypatch.setattr("core.ops.ops_trace_manager.current_app", fake_current)
@pytest.fixture(autouse=True)
def patch_sqlalchemy_session(monkeypatch):
    """Swap the SQLAlchemy Session for the DummySessionContext stub."""
    monkeypatch.setattr("core.ops.ops_trace_manager.Session", DummySessionContext)
@pytest.fixture
def encryption_mocks(monkeypatch):
    """Replace token encrypt/decrypt/obfuscate helpers with traceable prefix transforms."""
    encrypt_mock = MagicMock(side_effect=lambda tenant, value: f"enc-{value}")
    batch_decrypt_mock = MagicMock(side_effect=lambda tenant, values: [f"dec-{value}" for value in values])
    obfuscate_mock = MagicMock(side_effect=lambda value: f"ob-{value}")
    monkeypatch.setattr("core.ops.ops_trace_manager.encrypt_token", encrypt_mock)
    monkeypatch.setattr("core.ops.ops_trace_manager.batch_decrypt_token", batch_decrypt_mock)
    monkeypatch.setattr("core.ops.ops_trace_manager.obfuscated_token", obfuscate_mock)
    return encrypt_mock, batch_decrypt_mock, obfuscate_mock
@pytest.fixture
def mock_db(monkeypatch):
    """Install a MagicMock db whose session.scalars() reports a 'chat' app mode.

    Returns the session mock, which tests configure directly (query/scalar).
    """
    session = MagicMock()
    session.scalars.return_value.all.return_value = ["chat"]
    fake_db = MagicMock()
    fake_db.session = session
    fake_db.engine = MagicMock()
    monkeypatch.setattr("core.ops.ops_trace_manager.db", fake_db)
    return session
@pytest.fixture
def workflow_repo_fixture(monkeypatch):
    """Replace TraceTask's workflow-run repository with a mock returning a canned run."""
    fake_repo = MagicMock()
    fake_repo.get_workflow_run_by_id_without_tenant.return_value = make_workflow_run()
    monkeypatch.setattr(TraceTask, "_get_workflow_run_repo", classmethod(lambda cls: fake_repo))
    return fake_repo
@pytest.fixture
def trace_task_message(monkeypatch, mock_db):
    """Provide canned message data plus db rows (message file + workflow app log)."""
    message_data = make_message_data()
    monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data)
    configure_db_query(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id"))
    return message_data
def test_encrypt_tracing_config_handles_star_and_encrypt(encryption_mocks):
    """Plain secret values are encrypted; non-secret values pass through untouched."""
    new_config = {"secret_value": "value", "other_value": "info"}
    result = OpsTraceManager.encrypt_tracing_config(
        "tenant", "dummy", new_config, current_trace_config={"secret_value": "keep"}
    )
    assert result["secret_value"] == "enc-value"
    assert result["other_value"] == "info"
def test_encrypt_tracing_config_preserves_star(encryption_mocks):
    """A '*' placeholder keeps the previously stored secret instead of re-encrypting."""
    result = OpsTraceManager.encrypt_tracing_config(
        "tenant",
        "dummy",
        {"secret_value": "*", "other_value": "info"},
        current_trace_config={"secret_value": "keep"},
    )
    assert result["secret_value"] == "keep"
def test_decrypt_tracing_config_caches(encryption_mocks):
    """Repeated decryption of the same payload is served from the cache."""
    _, decrypt_mock, _ = encryption_mocks
    config = {"secret_value": "enc", "other_value": "info"}
    results = [OpsTraceManager.decrypt_tracing_config("tenant", "dummy", config) for _ in range(2)]
    assert results[0] == results[1]
    # Only the first call may hit the real decryption helper.
    assert decrypt_mock.call_count == 1
def test_obfuscated_decrypt_token(encryption_mocks):
    """Secret values are obfuscated exactly once; the key itself is preserved."""
    _, _, obfuscate_mock = encryption_mocks
    result = OpsTraceManager.obfuscated_decrypt_token("dummy", {"secret_value": "value", "other_value": "info"})
    assert "secret_value" in result
    assert result["secret_value"] == "ob-value"
    obfuscate_mock.assert_called_once()
def test_get_decrypted_tracing_config_returns_config(encryption_mocks, mock_db):
    """A stored trace config is decrypted and returned for an existing app."""
    stored_config = SimpleNamespace(tracing_config={"secret_value": "enc", "other_value": "info"})
    mock_db.query.return_value.where.return_value.first.return_value = stored_config
    mock_db.scalar.return_value = SimpleNamespace(id="app-id", tenant_id="tenant")
    decrypted = OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
    assert decrypted["other_value"] == "info"
def test_get_decrypted_tracing_config_missing_trace_config(mock_db):
    """No stored trace config means the lookup resolves to None."""
    mock_db.query.return_value.where.return_value.first.return_value = None
    result = OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
    assert result is None
def test_get_decrypted_tracing_config_raises_for_missing_app(mock_db):
    """A trace config without a matching app record raises ValueError."""
    stored_config = SimpleNamespace(tracing_config={"secret_value": "enc"})
    mock_db.query.return_value.where.return_value.first.return_value = stored_config
    mock_db.scalar.return_value = None
    with pytest.raises(ValueError, match="App not found"):
        OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
def test_get_decrypted_tracing_config_raises_for_none_config(mock_db):
    """A record whose tracing_config is None is rejected explicitly."""
    mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing_config=None)
    mock_db.scalar.return_value = SimpleNamespace(tenant_id="tenant")
    with pytest.raises(ValueError, match="Tracing config cannot be None"):
        OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy")
def test_get_ops_trace_instance_handles_none_app(mock_db):
    """An unknown app id yields no trace instance."""
    mock_db.query.return_value.where.return_value.first.return_value = None
    result = OpsTraceManager.get_ops_trace_instance("app-id")
    assert result is None
def test_get_ops_trace_instance_returns_none_when_disabled(mock_db, monkeypatch):
    """Tracing disabled in the app's config short-circuits to None."""
    disabled_app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": False}))
    mock_db.query.return_value.where.return_value.first.return_value = disabled_app
    assert OpsTraceManager.get_ops_trace_instance("app-id") is None
def test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch):
    """An enabled config naming an unregistered provider yields None."""
    monkeypatch.setattr("core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({}))
    tracing_json = json.dumps({"enabled": True, "tracing_provider": "missing"})
    mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(
        id="app-id", tracing=tracing_json
    )
    assert OpsTraceManager.get_ops_trace_instance("app-id") is None
def test_get_ops_trace_instance_success(monkeypatch, mock_db):
    """A valid enabled config builds a trace instance and caches it for reuse."""
    tracing_json = json.dumps({"enabled": True, "tracing_provider": "dummy"})
    mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(
        id="app-id", tracing=tracing_json
    )
    monkeypatch.setattr(
        "core.ops.ops_trace_manager.OpsTraceManager.get_decrypted_tracing_config",
        classmethod(lambda cls, aid, provider: {"secret_value": "decrypted", "other_value": "info"}),
    )
    first = OpsTraceManager.get_ops_trace_instance("app-id")
    assert first is not None
    # The second lookup must be served from the instance cache.
    assert OpsTraceManager.get_ops_trace_instance("app-id") is first
def test_get_app_config_through_message_id_returns_none(mock_db):
    """An unknown message id resolves to no app config."""
    mock_db.scalar.return_value = None
    result = OpsTraceManager.get_app_config_through_message_id("m")
    assert result is None
def test_get_app_config_through_message_id_prefers_override(mock_db):
    """override_model_configs on the conversation wins over app_model_config_id.

    Only two scalar lookups happen (message, then conversation); the override
    is returned directly, so no app-model-config row is ever fetched.
    """
    message = SimpleNamespace(conversation_id="conv")
    conversation = SimpleNamespace(app_model_config_id=None, override_model_configs={"foo": "bar"})
    mock_db.scalar.side_effect = [message, conversation]
    result = OpsTraceManager.get_app_config_through_message_id("m")
    assert result == {"foo": "bar"}
def test_get_app_config_through_message_id_app_model_config(mock_db):
    """With no override, the conversation's app_model_config record is returned."""
    message = SimpleNamespace(conversation_id="conv")
    conversation = SimpleNamespace(app_model_config_id="cfg", override_model_configs=None)
    app_model_config = SimpleNamespace(id="cfg")
    mock_db.scalar.side_effect = [message, conversation, app_model_config]
    result = OpsTraceManager.get_app_config_through_message_id("m")
    assert result.id == "cfg"
def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch):
    """Bad provider names and missing apps are both rejected with ValueError."""
    mock_db.query.return_value.where.return_value.first.return_value = None
    with pytest.raises(ValueError, match="Invalid tracing provider"):
        OpsTraceManager.update_app_tracing_config("app", True, "bad")
    # A None provider passes provider validation, so the missing-app check fires.
    with pytest.raises(ValueError, match="App not found"):
        OpsTraceManager.update_app_tracing_config("app", True, None)
def test_update_app_tracing_config_success(mock_db):
    """A successful update rewrites app.tracing and commits the session."""
    app_record = SimpleNamespace(id="app-id", tracing="{}")
    mock_db.query.return_value.where.return_value.first.return_value = app_record
    OpsTraceManager.update_app_tracing_config("app-id", True, "dummy")
    assert app_record.tracing is not None
    mock_db.commit.assert_called_once()
def test_get_app_tracing_config_errors_when_missing(mock_db):
    """Requesting config for a nonexistent app raises ValueError."""
    mock_db.query.return_value.where.return_value.first.return_value = None
    with pytest.raises(ValueError, match="App not found"):
        OpsTraceManager.get_app_tracing_config("app")
def test_get_app_tracing_config_returns_defaults(mock_db):
    """An app with no tracing column falls back to the disabled default."""
    mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=None)
    result = OpsTraceManager.get_app_tracing_config("app-id")
    assert result == {"enabled": False, "tracing_provider": None}
def test_get_app_tracing_config_returns_payload(mock_db):
    """A stored JSON tracing config is decoded and returned verbatim."""
    expected = {"enabled": True, "tracing_provider": "dummy"}
    mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(
        tracing=json.dumps(expected)
    )
    assert OpsTraceManager.get_app_tracing_config("app-id") == expected
def test_check_and_project_helpers(monkeypatch):
    """Effectiveness/project helpers delegate to the provider's trace instance."""

    class FakeTrace:
        def __init__(self, cfg):
            pass

        def api_check(self):
            return True

        def get_project_key(self):
            return "key"

        def get_project_url(self):
            return "url"

    provider_entry = {
        "config_class": DummyConfig,
        "trace_instance": FakeTrace,
        "secret_keys": [],
        "other_keys": [],
    }
    monkeypatch.setattr(
        "core.ops.ops_trace_manager.provider_config_map",
        FakeProviderMap({"dummy": provider_entry}),
    )
    assert OpsTraceManager.check_trace_config_is_effective({}, "dummy")
    assert OpsTraceManager.get_trace_config_project_key({}, "dummy") == "key"
    assert OpsTraceManager.get_trace_config_project_url({}, "dummy") == "url"
def test_trace_task_conversation_and_extract(monkeypatch):
    """conversation_trace echoes kwargs; bad metadata yields empty streaming metrics."""
    task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE, message_id="msg")
    assert task.conversation_trace(foo="bar") == {"foo": "bar"}
    bad_metadata_message = make_message_data(message_metadata="not json")
    assert task._extract_streaming_metrics(bad_metadata_message) == {}
def test_trace_task_message_trace(trace_task_message, mock_db):
    """message_trace produces trace info keyed by the originating message id."""
    task = TraceTask(trace_type=TraceTaskName.MESSAGE_TRACE, message_id="msg-id")
    trace_info = task.message_trace("msg-id")
    assert trace_info.message_id == "msg-id"
def test_trace_task_workflow_trace(workflow_repo_fixture, mock_db):
    """workflow_trace resolves ids through the repo-backed workflow run.

    The two queued scalar() values feed the session lookups the trace performs.
    The queue lives on the class, so it is reset in a finally block to avoid
    leaking state into any later test that opens a DummySessionContext.
    """
    DummySessionContext.scalar_values = ["wf-app-log", "message-ref"]
    try:
        execution = SimpleNamespace(id_="run-id")
        task = TraceTask(
            trace_type=TraceTaskName.WORKFLOW_TRACE,
            workflow_execution=execution,
            conversation_id="conv",
            user_id="user",
        )
        result = task.workflow_trace(workflow_run_id="run-id", conversation_id="conv", user_id="user")
        assert result.workflow_run_id == "run-id"
        assert result.workflow_id == "wf-1"
    finally:
        # Restore the class-level default so subsequent tests see an empty queue.
        DummySessionContext.scalar_values = []
def test_trace_task_moderation_trace(trace_task_message):
    """moderation_trace carries the flagged verdict and the workflow log message id."""
    task = TraceTask(trace_type=TraceTaskName.MODERATION_TRACE, message_id="msg-id")
    verdict = SimpleNamespace(action="block", preset_response="no", query="q", flagged=True)
    result = task.moderation_trace(
        "msg-id", {"start": 1, "end": 2}, moderation_result=verdict, inputs={"src": "payload"}
    )
    assert result.flagged is True
    assert result.message_id == "log-id"
def test_trace_task_suggested_question_trace(trace_task_message):
    """suggested_question_trace records questions against the workflow log id."""
    task = TraceTask(trace_type=TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id="msg-id")
    result = task.suggested_question_trace("msg-id", {"start": 1, "end": 2}, suggested_question=["q1"])
    assert result.message_id == "log-id"
    assert "suggested_question" in result.__dict__
def test_trace_task_dataset_retrieval_trace(trace_task_message):
    """dataset_retrieval_trace serializes each retrieved document via model_dump."""
    task = TraceTask(trace_type=TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id="msg-id")
    fake_document = SimpleNamespace(model_dump=lambda: {"doc": "value"})
    result = task.dataset_retrieval_trace("msg-id", {"start": 1, "end": 2}, documents=[fake_document])
    assert result.documents == [{"doc": "value"}]
def test_trace_task_tool_trace(monkeypatch, mock_db):
    """tool_trace matches the agent thought by tool name and reports elapsed cost."""
    thought = make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 0))
    message_with_thought = make_message_data(agent_thoughts=[thought])
    monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _msg_id: message_with_thought)
    configure_db_query(mock_db, message_file=FakeMessageFile())
    task = TraceTask(trace_type=TraceTaskName.TOOL_TRACE, message_id="msg-id")
    result = task.tool_trace(
        "msg-id", {"start": 1, "end": 5}, tool_name="tool-a", tool_inputs={"foo": 1}, tool_outputs="result"
    )
    assert result.tool_name == "tool-a"
    assert result.time_cost == 5
def test_trace_task_generate_name_trace():
    """generate_name_trace needs a tenant; with one it records the generated name."""
    task = TraceTask(trace_type=TraceTaskName.GENERATE_NAME_TRACE, conversation_id="conv-id")
    timer = {"start": 1, "end": 2}
    # Without a tenant the trace degrades to an empty dict.
    assert task.generate_name_trace("conv-id", timer, tenant_id=None) == {}
    trace_info = task.generate_name_trace(
        "conv-id", timer, tenant_id="tenant", generate_conversation_name="name", inputs="q"
    )
    assert trace_info.outputs == "name"
    assert trace_info.tenant_id == "tenant"
def test_extract_streaming_metrics_invalid_json():
    """Metadata that is not valid JSON yields empty streaming metrics."""
    task = TraceTask(trace_type=TraceTaskName.MESSAGE_TRACE, message_id="msg-id")
    bad_metadata_message = make_message_data(message_metadata="invalid")
    assert task._extract_streaming_metrics(bad_metadata_message) == {}
def test_trace_queue_manager_add_and_collect(monkeypatch):
    """Tasks added to the queue come back out of collect_tasks."""
    monkeypatch.setattr(
        "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)
    )
    manager = TraceQueueManager(app_id="app-id", user_id="user")
    queued_task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE)
    manager.add_trace_task(queued_task)
    assert manager.collect_tasks() == [queued_task]
def test_trace_queue_manager_run_invokes_send(monkeypatch):
    """run() forwards whatever collect_tasks returns into send_to_celery."""
    monkeypatch.setattr(
        "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)
    )
    manager = TraceQueueManager(app_id="app-id", user_id="user")
    expected_task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE)
    sent = {}

    def record_send(self, tasks):
        sent["tasks"] = tasks

    monkeypatch.setattr(TraceQueueManager, "collect_tasks", lambda self: [expected_task])
    monkeypatch.setattr(TraceQueueManager, "send_to_celery", record_send)
    manager.run()
    assert sent["tasks"] == [expected_task]
def test_trace_queue_manager_send_to_celery(monkeypatch):
    """send_to_celery persists the trace payload and enqueues the celery job."""

    class FakeTraceInfo:
        def model_dump(self):
            return {"trace": "info"}

    class FakeTask:
        app_id = "app-id"

        def execute(self):
            return FakeTraceInfo()

    storage_save = MagicMock()
    process_delay = MagicMock()
    monkeypatch.setattr(
        "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True)
    )
    monkeypatch.setattr("core.ops.ops_trace_manager.storage.save", storage_save)
    monkeypatch.setattr("core.ops.ops_trace_manager.process_trace_tasks.delay", process_delay)
    # Pin the generated file id so the celery payload is predictable.
    monkeypatch.setattr("core.ops.ops_trace_manager.uuid4", MagicMock(return_value=SimpleNamespace(hex="file-123")))
    manager = TraceQueueManager(app_id="app-id", user_id="user")
    manager.send_to_celery([FakeTask()])
    storage_save.assert_called_once()
    process_delay.assert_called_once_with({"file_id": "file-123", "app_id": "app-id"})

View File

@@ -1,9 +1,20 @@
import re
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from core.ops.utils import generate_dotted_order, validate_project_name, validate_url, validate_url_with_path
from core.ops.utils import (
filter_none_values,
generate_dotted_order,
get_message_data,
measure_time,
replace_text_with_content,
validate_integer_id,
validate_project_name,
validate_url,
validate_url_with_path,
)
class TestValidateUrl:
@@ -187,3 +198,92 @@ class TestGenerateDottedOrder:
result = generate_dotted_order(run_id, start_time, None)
assert "." not in result
def test_dotted_order_with_string_start_time(self):
    """An ISO-format string start_time is accepted and compacted into the prefix."""
    result = generate_dotted_order("test-run-id", "2025-12-23T04:19:55.111000")
    assert result == "20251223T041955111000Ztest-run-id"
class TestFilterNoneValues:
    """Test cases for the filter_none_values helper."""

    def test_filter_none_values(self):
        """None entries are dropped; datetimes are serialized to ISO strings."""
        payload = {"a": 1, "b": None, "c": "test", "d": datetime(2025, 1, 1, 12, 0, 0)}
        assert filter_none_values(payload) == {"a": 1, "c": "test", "d": "2025-01-01T12:00:00"}

    def test_filter_none_values_empty(self):
        """An empty mapping stays empty."""
        assert filter_none_values({}) == {}
class TestGetMessageData:
    """Test cases for the get_message_data helper."""

    @patch("core.ops.utils.db")
    @patch("core.ops.utils.Message")
    @patch("core.ops.utils.select")
    def test_get_message_data(self, mock_select, mock_message, mock_db):
        """The message row returned by session.scalar is passed straight through."""
        expected_message = MagicMock()
        mock_db.session.scalar.return_value = expected_message
        assert get_message_data("message-id") == expected_message
        # Exactly one select statement is built and executed.
        mock_select.assert_called_once()
        mock_db.session.scalar.assert_called_once()
class TestMeasureTime:
    """Test cases for the measure_time context manager."""

    def test_measure_time(self):
        """'start' is a datetime on entry; 'end' is set only after the block exits."""
        with measure_time() as timing:
            assert "start" in timing
            assert isinstance(timing["start"], datetime)
            assert timing["end"] is None
        assert timing["end"] is not None
        assert isinstance(timing["end"], datetime)
        assert timing["end"] >= timing["start"]
class TestReplaceTextWithContent:
    """Test cases for the replace_text_with_content helper."""

    def test_replace_text_with_content_dict(self):
        """A top-level 'text' key is renamed to 'content'; other keys survive."""
        assert replace_text_with_content({"text": "hello", "other": "world"}) == {
            "content": "hello",
            "other": "world",
        }

    def test_replace_text_with_content_nested(self):
        """The rename recurses through nested dicts and lists."""
        source = {"text": "v1", "nested": {"text": "v2", "list": [{"text": "v3"}]}}
        expected = {"content": "v1", "nested": {"content": "v2", "list": [{"content": "v3"}]}}
        assert replace_text_with_content(source) == expected

    def test_replace_text_with_content_list(self):
        """Lists are processed element-wise; non-dict items pass through."""
        assert replace_text_with_content([{"text": "v1"}, "v2"]) == [{"content": "v1"}, "v2"]

    def test_replace_text_with_content_primitive(self):
        """Primitives are returned unchanged."""
        assert replace_text_with_content(123) == 123
        assert replace_text_with_content("text") == "text"
class TestValidateIntegerId:
    """Test cases for the validate_integer_id helper."""

    def test_valid_integer_id(self):
        """Digit strings are accepted and surrounding whitespace is stripped."""
        assert validate_integer_id("123") == "123"
        assert validate_integer_id(" 456 ") == "456"

    def test_invalid_integer_id_raises_error(self):
        """Non-numeric input is rejected with a descriptive ValueError."""
        with pytest.raises(ValueError, match="ID must be a valid integer"):
            validate_integer_id("abc")

    def test_empty_integer_id_raises_error(self):
        """An empty string is rejected just like non-numeric input."""
        with pytest.raises(ValueError, match="ID must be a valid integer"):
            validate_integer_id("")

File diff suppressed because it is too large Load Diff