diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py b/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py new file mode 100644 index 0000000000..acb43d4036 --- /dev/null +++ b/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py @@ -0,0 +1,326 @@ +import time +import uuid +from datetime import datetime +from unittest.mock import MagicMock, patch + +import httpx +import pytest +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.trace import SpanKind, Status, StatusCode + +from core.ops.aliyun_trace.data_exporter.traceclient import ( + INVALID_SPAN_ID, + SpanBuilder, + TraceClient, + build_endpoint, + convert_datetime_to_nanoseconds, + convert_string_to_id, + convert_to_span_id, + convert_to_trace_id, + create_link, + generate_span_id, +) +from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData + + +@pytest.fixture +def trace_client_factory(): + """Factory fixture for creating TraceClient instances with automatic cleanup.""" + clients_to_shutdown = [] + + def _factory(**kwargs): + client = TraceClient(**kwargs) + clients_to_shutdown.append(client) + return client + + yield _factory + + # Cleanup: shutdown all created clients + for client in clients_to_shutdown: + client.shutdown() + + +class TestTraceClient: + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname") + def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory): + mock_gethostname.return_value = "test-host" + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + + assert client.endpoint == "http://test-endpoint" + assert client.max_queue_size == 1000 + assert client.schedule_delay_sec == 5 + assert client.done is False + assert client.worker_thread.is_alive() + + client.shutdown() + assert client.done is True + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_export(self, mock_exporter_class, trace_client_factory): + mock_exporter = mock_exporter_class.return_value + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + spans = [MagicMock(spec=ReadableSpan)] + client.export(spans) + mock_exporter.export.assert_called_once_with(spans) + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory): + mock_response = MagicMock() + mock_response.status_code = 405 + mock_head.return_value = mock_response + + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + assert client.api_check() is True + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory): + mock_response = MagicMock() + mock_response.status_code = 500 + mock_head.return_value = mock_response + + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + assert client.api_check() is False + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory): + mock_head.side_effect = httpx.RequestError("Connection error") + + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"): + client.api_check() + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_get_project_url(self, mock_exporter_class, trace_client_factory): + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm" + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_add_span(self, mock_exporter_class, trace_client_factory): + client = trace_client_factory( + service_name="test-service", + endpoint="http://test-endpoint", + max_export_batch_size=2, + ) + + # Test add None + client.add_span(None) + assert len(client.queue) == 0 + + # Test add valid SpanData + span_data = SpanData( + name="test-span", + trace_id=123, + span_id=456, + parent_span_id=None, + start_time=1000, + end_time=2000, + status=Status(StatusCode.OK), + span_kind=SpanKind.INTERNAL, + ) + + mock_span = MagicMock(spec=ReadableSpan) + client.span_builder.build_span = MagicMock(return_value=mock_span) + + with patch.object(client.condition, "notify") as mock_notify: + client.add_span(span_data) + assert len(client.queue) == 1 + mock_notify.assert_not_called() + + client.add_span(span_data) + assert len(client.queue) == 2 + mock_notify.assert_called_once() + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") + def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory): + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1) + + span_data = SpanData( + name="test-span", + trace_id=123, + span_id=456, + parent_span_id=None, + start_time=1000, + end_time=2000, + status=Status(StatusCode.OK), + span_kind=SpanKind.INTERNAL, + ) + mock_span = MagicMock(spec=ReadableSpan) + client.span_builder.build_span = MagicMock(return_value=mock_span) + + client.add_span(span_data) + assert len(client.queue) == 1 + + client.add_span(span_data) + assert len(client.queue) == 1 + mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.") + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_export_batch_error(self, mock_exporter_class, trace_client_factory): + mock_exporter = mock_exporter_class.return_value + mock_exporter.export.side_effect = Exception("Export failed") + + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + mock_span = MagicMock(spec=ReadableSpan) + client.queue.append(mock_span) + + with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger: + client._export_batch() + mock_logger.warning.assert_called() + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_worker_loop(self, mock_exporter_class, trace_client_factory): + # We need to test the wait timeout in _worker + # But _worker runs in a thread. Let's mock condition.wait. + client = trace_client_factory( + service_name="test-service", + endpoint="http://test-endpoint", + schedule_delay_sec=0.1, + ) + + with patch.object(client.condition, "wait") as mock_wait: + # Let it run for a bit then shut down + time.sleep(0.2) + client.shutdown() + # mock_wait might have been called + assert mock_wait.called or client.done + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory): + mock_exporter = mock_exporter_class.return_value + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + + mock_span = MagicMock(spec=ReadableSpan) + client.queue.append(mock_span) + + client.shutdown() + # Should have called export twice (once in worker/export_batch, once in shutdown) + # or at least once if worker was waiting + assert mock_exporter.export.called + assert mock_exporter.shutdown.called + + +class TestSpanBuilder: + def test_build_span(self): + resource = MagicMock() + builder = SpanBuilder(resource) + + span_data = SpanData( + name="test-span", + trace_id=123, + span_id=456, + parent_span_id=789, + start_time=1000, + end_time=2000, + status=Status(StatusCode.OK), + span_kind=SpanKind.INTERNAL, + attributes={"attr1": "val1"}, + events=[], + links=[], + ) + + span = builder.build_span(span_data) + assert isinstance(span, ReadableSpan) + assert span.name == "test-span" + assert span.context.trace_id == 123 + assert span.context.span_id == 456 + assert span.parent.span_id == 789 + assert span.resource == resource + assert span.attributes == {"attr1": "val1"} + + def test_build_span_no_parent(self): + resource = MagicMock() + builder = SpanBuilder(resource) + + span_data = SpanData( + name="test-span", + trace_id=123, + span_id=456, + parent_span_id=None, + start_time=1000, + end_time=2000, + status=Status(StatusCode.OK), + span_kind=SpanKind.INTERNAL, + ) + + span = builder.build_span(span_data) + assert span.parent is None + + +def test_create_link(): + trace_id_str = "0123456789abcdef0123456789abcdef" + link = create_link(trace_id_str) + assert link.context.trace_id == int(trace_id_str, 16) + assert link.context.span_id == INVALID_SPAN_ID + + with pytest.raises(ValueError, match="Invalid trace ID format"): + create_link("invalid-hex") + + +def test_generate_span_id(): + # Test normal generation + span_id = generate_span_id() + assert isinstance(span_id, int) + assert span_id != INVALID_SPAN_ID + + # Test retry loop + with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand: + mock_rand.side_effect = [INVALID_SPAN_ID, 999] + span_id = generate_span_id() + assert span_id == 999 + assert mock_rand.call_count == 2 + + +def test_convert_to_trace_id(): + uid = str(uuid.uuid4()) + trace_id = convert_to_trace_id(uid) + assert trace_id == uuid.UUID(uid).int + + with pytest.raises(ValueError, match="UUID cannot be None"): + convert_to_trace_id(None) + + with pytest.raises(ValueError, match="Invalid UUID input"): + convert_to_trace_id("not-a-uuid") + + +def test_convert_string_to_id(): + assert convert_string_to_id("test") > 0 + # Test with None string + with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen: + mock_gen.return_value = 12345 + assert convert_string_to_id(None) == 12345 + + +def test_convert_to_span_id(): + uid = str(uuid.uuid4()) + span_id = convert_to_span_id(uid, "test-type") + assert isinstance(span_id, int) + + with pytest.raises(ValueError, match="UUID cannot be None"): + convert_to_span_id(None, "test") + + with pytest.raises(ValueError, match="Invalid UUID input"): + convert_to_span_id("not-a-uuid", "test") + + +def test_convert_datetime_to_nanoseconds(): + dt = datetime(2023, 1, 1, 12, 0, 0) + ns = convert_datetime_to_nanoseconds(dt) + assert ns == int(dt.timestamp() * 1e9) + assert convert_datetime_to_nanoseconds(None) is None + + +def test_build_endpoint(): + license_key = "abc" + + # CMS 2.0 endpoint + url1 = "https://log.aliyuncs.com" + assert build_endpoint(url1, license_key) == "https://log.aliyuncs.com/adapt_abc/api/v1/traces" + + # XTrace endpoint + url2 = "https://example.com" + assert build_endpoint(url2, license_key) == "https://example.com/adapt_abc/api/otlp/traces" diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py b/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py new file mode 100644 index 0000000000..2fcb927e0c --- /dev/null +++ b/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py @@ -0,0 +1,88 @@ +import pytest +from opentelemetry import trace as trace_api +from opentelemetry.sdk.trace import Event +from opentelemetry.trace import SpanKind, Status, StatusCode +from pydantic import ValidationError + +from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata + + +class TestTraceMetadata: + def test_trace_metadata_init(self): + links = [trace_api.Link(context=trace_api.SpanContext(0, 0, False))] + metadata = TraceMetadata( + trace_id=123, workflow_span_id=456, session_id="session_1", user_id="user_1", links=links + ) + assert metadata.trace_id == 123 + assert metadata.workflow_span_id == 456 + assert metadata.session_id == "session_1" + assert metadata.user_id == "user_1" + assert metadata.links == links + + +class TestSpanData: + def test_span_data_init_required_fields(self): + span_data = SpanData(trace_id=123, span_id=456, name="test_span", start_time=1000, end_time=2000) + assert span_data.trace_id == 123 + assert span_data.span_id == 456 + assert span_data.name == "test_span" + assert span_data.start_time == 1000 + assert span_data.end_time == 2000 + + # Check defaults + assert span_data.parent_span_id is None + assert span_data.attributes == {} + assert span_data.events == [] + assert span_data.links == [] + assert span_data.status.status_code == StatusCode.UNSET + assert span_data.span_kind == SpanKind.INTERNAL + + def test_span_data_with_optional_fields(self): + event = Event(name="event_1", timestamp=1500) + link = trace_api.Link(context=trace_api.SpanContext(0, 0, False)) + status = Status(StatusCode.OK) + + span_data = SpanData( + trace_id=123, + parent_span_id=111, + span_id=456, + name="test_span", + attributes={"key": "value"}, + events=[event], + links=[link], + status=status, + start_time=1000, + end_time=2000, + span_kind=SpanKind.SERVER, + ) + + assert span_data.parent_span_id == 111 + assert span_data.attributes == {"key": "value"} + assert span_data.events == [event] + assert span_data.links == [link] + assert span_data.status.status_code == status.status_code + assert span_data.span_kind == SpanKind.SERVER + + def test_span_data_missing_required_fields(self): + with pytest.raises(ValidationError): + SpanData( + trace_id=123, + # span_id missing + name="test_span", + start_time=1000, + end_time=2000, + ) + + def test_span_data_arbitrary_types_allowed(self): + # opentelemetry.trace.Status and Event are "arbitrary types" for Pydantic + # This test ensures they are accepted thanks to model_config + status = Status(StatusCode.ERROR, description="error occurred") + event = Event(name="exception", timestamp=1234, attributes={"exception.type": "ValueError"}) + + span_data = SpanData( + trace_id=123, span_id=456, name="test_span", status=status, events=[event], start_time=1000, end_time=2000 + ) + + assert span_data.status.status_code == status.status_code + assert span_data.status.description == status.description + assert span_data.events == [event] diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py b/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py new file mode 100644 index 0000000000..3961555b9a --- /dev/null +++ b/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py @@ -0,0 +1,68 @@ +from core.ops.aliyun_trace.entities.semconv import ( + ACS_ARMS_SERVICE_FEATURE, + GEN_AI_COMPLETION, + GEN_AI_FRAMEWORK, + GEN_AI_INPUT_MESSAGE, + GEN_AI_OUTPUT_MESSAGE, + GEN_AI_PROMPT, + GEN_AI_PROVIDER_NAME, + GEN_AI_REQUEST_MODEL, + GEN_AI_RESPONSE_FINISH_REASON, + GEN_AI_SESSION_ID, + GEN_AI_SPAN_KIND, + GEN_AI_USAGE_INPUT_TOKENS, + GEN_AI_USAGE_OUTPUT_TOKENS, + GEN_AI_USAGE_TOTAL_TOKENS, + GEN_AI_USER_ID, + GEN_AI_USER_NAME, + INPUT_VALUE, + OUTPUT_VALUE, + RETRIEVAL_DOCUMENT, + RETRIEVAL_QUERY, + TOOL_DESCRIPTION, + TOOL_NAME, + TOOL_PARAMETERS, + GenAISpanKind, +) + + +def test_constants(): + assert ACS_ARMS_SERVICE_FEATURE == "acs.arms.service.feature" + assert GEN_AI_SESSION_ID == "gen_ai.session.id" + assert GEN_AI_USER_ID == "gen_ai.user.id" + assert GEN_AI_USER_NAME == "gen_ai.user.name" + assert GEN_AI_SPAN_KIND == "gen_ai.span.kind" + assert GEN_AI_FRAMEWORK == "gen_ai.framework" + assert INPUT_VALUE == "input.value" + assert OUTPUT_VALUE == "output.value" + assert RETRIEVAL_QUERY == "retrieval.query" + assert RETRIEVAL_DOCUMENT == "retrieval.document" + assert GEN_AI_REQUEST_MODEL == "gen_ai.request.model" + assert GEN_AI_PROVIDER_NAME == "gen_ai.provider.name" + assert GEN_AI_USAGE_INPUT_TOKENS == "gen_ai.usage.input_tokens" + assert GEN_AI_USAGE_OUTPUT_TOKENS == "gen_ai.usage.output_tokens" + assert GEN_AI_USAGE_TOTAL_TOKENS == "gen_ai.usage.total_tokens" + assert GEN_AI_PROMPT == "gen_ai.prompt" + assert GEN_AI_COMPLETION == "gen_ai.completion" + assert GEN_AI_RESPONSE_FINISH_REASON == "gen_ai.response.finish_reason" + assert GEN_AI_INPUT_MESSAGE == "gen_ai.input.messages" + assert GEN_AI_OUTPUT_MESSAGE == "gen_ai.output.messages" + assert TOOL_NAME == "tool.name" + assert TOOL_DESCRIPTION == "tool.description" + assert TOOL_PARAMETERS == "tool.parameters" + + +def test_gen_ai_span_kind_enum(): + assert GenAISpanKind.CHAIN == "CHAIN" + assert GenAISpanKind.RETRIEVER == "RETRIEVER" + assert GenAISpanKind.RERANKER == "RERANKER" + assert GenAISpanKind.LLM == "LLM" + assert GenAISpanKind.EMBEDDING == "EMBEDDING" + assert GenAISpanKind.TOOL == "TOOL" + assert GenAISpanKind.AGENT == "AGENT" + assert GenAISpanKind.TASK == "TASK" + + # Verify iteration works (covers the class definition) + kinds = list(GenAISpanKind) + assert len(kinds) == 8 + assert "LLM" in kinds diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py new file mode 100644 index 0000000000..fac0597f5a --- /dev/null +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py @@ -0,0 +1,647 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags + +import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module +from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace +from core.ops.aliyun_trace.entities.semconv import ( + GEN_AI_COMPLETION, + GEN_AI_INPUT_MESSAGE, + GEN_AI_OUTPUT_MESSAGE, + GEN_AI_PROMPT, + GEN_AI_REQUEST_MODEL, + GEN_AI_RESPONSE_FINISH_REASON, + GEN_AI_USAGE_TOTAL_TOKENS, + RETRIEVAL_DOCUMENT, + RETRIEVAL_QUERY, + TOOL_DESCRIPTION, + TOOL_NAME, + TOOL_PARAMETERS, + GenAISpanKind, +) +from core.ops.entities.config_entity import AliyunConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey + + +class RecordingTraceClient: + def __init__(self, service_name: str = "service", endpoint: str = "endpoint"): + self.service_name = service_name + self.endpoint = endpoint + self.added_spans: list[object] = [] + + def add_span(self, span) -> None: + self.added_spans.append(span) + + def api_check(self) -> bool: + return True + + def get_project_url(self) -> str: + return "project-url" + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +def _make_link(trace_id: int = 1, span_id: int = 2) -> Link: + context = SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=False, + trace_flags=TraceFlags.SAMPLED, + ) + return Link(context) + + +def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo: + defaults = { + "workflow_id": "workflow-id", + "tenant_id": "tenant-id", + "workflow_run_id": "00000000-0000-0000-0000-000000000001", + "workflow_run_elapsed_time": 1.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"sys.query": "hello"}, + "workflow_run_outputs": {"answer": "world"}, + "workflow_run_version": "v1", + "total_tokens": 1, + "file_list": [], + "query": "hello", + "metadata": {"conversation_id": "conv", "user_id": "u", "app_id": "app"}, + "message_id": None, + "start_time": _dt(), + "end_time": _dt(), + "trace_id": "550e8400-e29b-41d4-a716-446655440000", + } + defaults.update(overrides) + return WorkflowTraceInfo(**defaults) + + +def _make_message_trace_info(**overrides) -> MessageTraceInfo: + defaults = { + "conversation_model": "chat", + "message_tokens": 1, + "answer_tokens": 2, + "total_tokens": 3, + "conversation_mode": "chat", + "metadata": {"conversation_id": "conv", "ls_model_name": "m", "ls_provider": "p"}, + "message_id": "00000000-0000-0000-0000-000000000002", + "message_data": SimpleNamespace(from_account_id="acc", from_end_user_id=None), + "inputs": {"prompt": "hi"}, + "outputs": "ok", + "start_time": _dt(), + "end_time": _dt(), + "error": None, + "trace_id": "550e8400-e29b-41d4-a716-446655440000", + } + defaults.update(overrides) + return MessageTraceInfo(**defaults) + + +def _make_dataset_retrieval_trace_info(**overrides) -> DatasetRetrievalTraceInfo: + defaults = { + "metadata": {"conversation_id": "conv", "user_id": "u"}, + "message_id": "00000000-0000-0000-0000-000000000003", + "message_data": SimpleNamespace(), + "inputs": "q", + "documents": [SimpleNamespace()], + "start_time": _dt(), + "end_time": _dt(), + "trace_id": "550e8400-e29b-41d4-a716-446655440000", + } + defaults.update(overrides) + return DatasetRetrievalTraceInfo(**defaults) + + +def _make_tool_trace_info(**overrides) -> ToolTraceInfo: + defaults = { + "tool_name": "tool", + "tool_inputs": {"x": 1}, + "tool_outputs": "out", + "tool_config": {"desc": "d"}, + "tool_parameters": {}, + "time_cost": 0.1, + "metadata": {"conversation_id": "conv", "user_id": "u"}, + "message_id": "00000000-0000-0000-0000-000000000004", + "message_data": SimpleNamespace(), + "inputs": {"i": "v"}, + "outputs": {"o": "v"}, + "start_time": _dt(), + "end_time": _dt(), + "error": None, + "trace_id": "550e8400-e29b-41d4-a716-446655440000", + } + defaults.update(overrides) + return ToolTraceInfo(**defaults) + + +def _make_suggested_question_trace_info(**overrides) -> SuggestedQuestionTraceInfo: + defaults = { + "suggested_question": ["q1", "q2"], + "level": "info", + "total_tokens": 1, + "metadata": {"conversation_id": "conv", "user_id": "u", "ls_model_name": "m", "ls_provider": "p"}, + "message_id": "00000000-0000-0000-0000-000000000005", + "inputs": {"i": 1}, + "start_time": _dt(), + "end_time": _dt(), + "error": None, + "trace_id": "550e8400-e29b-41d4-a716-446655440000", + } + defaults.update(overrides) + return SuggestedQuestionTraceInfo(**defaults) + + +@pytest.fixture +def trace_instance(monkeypatch: pytest.MonkeyPatch) -> AliyunDataTrace: + monkeypatch.setattr(aliyun_trace_module, "build_endpoint", lambda base_url, license_key: "built-endpoint") + monkeypatch.setattr(aliyun_trace_module, "TraceClient", RecordingTraceClient) + # Mock get_service_account_with_tenant to avoid DB errors + monkeypatch.setattr(AliyunDataTrace, "get_service_account_with_tenant", lambda self, app_id: MagicMock()) + + config = AliyunConfig(app_name="app", license_key="k", endpoint="https://example.com") + trace = AliyunDataTrace(config) + return trace + + +def test_init_builds_endpoint_and_client(monkeypatch: pytest.MonkeyPatch): + build_endpoint = MagicMock(return_value="built") + trace_client_cls = MagicMock() + monkeypatch.setattr(aliyun_trace_module, "build_endpoint", build_endpoint) + monkeypatch.setattr(aliyun_trace_module, "TraceClient", trace_client_cls) + + config = AliyunConfig(app_name="my-app", license_key="license", endpoint="https://example.com") + trace = AliyunDataTrace(config) + + build_endpoint.assert_called_once_with("https://example.com", "license") + trace_client_cls.assert_called_once_with(service_name="my-app", endpoint="built") + assert trace.trace_config == config + + +def test_trace_dispatches_to_correct_methods(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + workflow_trace = MagicMock() + message_trace = MagicMock() + suggested_question_trace = MagicMock() + dataset_retrieval_trace = MagicMock() + tool_trace = MagicMock() + monkeypatch.setattr(trace_instance, "workflow_trace", workflow_trace) + monkeypatch.setattr(trace_instance, "message_trace", message_trace) + monkeypatch.setattr(trace_instance, "suggested_question_trace", suggested_question_trace) + monkeypatch.setattr(trace_instance, "dataset_retrieval_trace", dataset_retrieval_trace) + monkeypatch.setattr(trace_instance, "tool_trace", tool_trace) + + trace_instance.trace(_make_workflow_trace_info()) + workflow_trace.assert_called_once() + + trace_instance.trace(_make_message_trace_info()) + message_trace.assert_called_once() + + trace_instance.trace(_make_suggested_question_trace_info()) + suggested_question_trace.assert_called_once() + + trace_instance.trace(_make_dataset_retrieval_trace_info()) + dataset_retrieval_trace.assert_called_once() + + trace_instance.trace(_make_tool_trace_info()) + tool_trace.assert_called_once() + + # Branches that do nothing but should be covered + trace_instance.trace(ModerationTraceInfo(flagged=False, action="allow", preset_response="", query="", metadata={})) + trace_instance.trace(GenerateNameTraceInfo(tenant_id="t", metadata={})) + + +def test_api_check_delegates(trace_instance: AliyunDataTrace): + trace_instance.trace_client.api_check = MagicMock(return_value=False) + assert trace_instance.api_check() is False + + +def test_get_project_url_success(trace_instance: AliyunDataTrace): + assert trace_instance.get_project_url() == "project-url" + + +def test_get_project_url_error(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(trace_instance.trace_client, "get_project_url", MagicMock(side_effect=Exception("boom"))) + logger_mock = MagicMock() + monkeypatch.setattr(aliyun_trace_module, "logger", logger_mock) + + with pytest.raises(ValueError, match=r"Aliyun get project url failed: boom"): + trace_instance.get_project_url() + logger_mock.info.assert_called() + + +def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 111) + monkeypatch.setattr( + aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"workflow": 222}.get(span_type, 0) + ) + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + + add_workflow_span = MagicMock() + get_workflow_node_executions = MagicMock(return_value=[MagicMock(), MagicMock()]) + build_workflow_node_span = MagicMock(side_effect=["span-1", "span-2"]) + monkeypatch.setattr(trace_instance, "add_workflow_span", add_workflow_span) + monkeypatch.setattr(trace_instance, "get_workflow_node_executions", get_workflow_node_executions) + monkeypatch.setattr(trace_instance, "build_workflow_node_span", build_workflow_node_span) + + trace_info = _make_workflow_trace_info( + trace_id="abcd", metadata={"conversation_id": "c", "user_id": "u", "app_id": "app"} + ) + trace_instance.workflow_trace(trace_info) + + add_workflow_span.assert_called_once() + passed_trace_metadata = add_workflow_span.call_args.args[1] + assert passed_trace_metadata.trace_id == 111 + assert passed_trace_metadata.workflow_span_id == 222 + assert passed_trace_metadata.session_id == "c" + assert passed_trace_metadata.user_id == "u" + assert passed_trace_metadata.links == [] + + assert trace_instance.trace_client.added_spans == ["span-1", "span-2"] + + +def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): + trace_info = _make_message_trace_info(message_data=None) + trace_instance.message_trace(trace_info) + assert trace_instance.trace_client.added_spans == [] + + +def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10) + monkeypatch.setattr( + aliyun_trace_module, + "convert_to_span_id", + lambda _, span_type: {"message": 20, "llm": 30}.get(span_type, 0), + ) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + monkeypatch.setattr(aliyun_trace_module, "get_user_id_from_message_data", lambda _: "user") + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) + + trace_info = _make_message_trace_info( + metadata={"conversation_id": "conv", "ls_model_name": "model", "ls_provider": "provider"}, + message_tokens=7, + answer_tokens=11, + total_tokens=18, + outputs="completion", + ) + trace_instance.message_trace(trace_info) + + assert len(trace_instance.trace_client.added_spans) == 2 + message_span, llm_span = trace_instance.trace_client.added_spans + + assert message_span.name == "message" + assert message_span.trace_id == 10 + assert message_span.parent_span_id is None + assert message_span.span_id == 20 + assert message_span.span_kind == SpanKind.SERVER + assert message_span.status == status + assert message_span.attributes["gen_ai.span.kind"] == GenAISpanKind.CHAIN + + assert llm_span.name == "llm" + assert llm_span.parent_span_id == 20 + assert llm_span.span_id == 30 + assert llm_span.status == status + assert llm_span.attributes[GEN_AI_REQUEST_MODEL] == "model" + assert llm_span.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "18" + + +def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): + trace_info = _make_dataset_retrieval_trace_info(message_data=None) + trace_instance.dataset_retrieval_trace(trace_info) + assert trace_instance.trace_client.added_spans == [] + + +def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 1) + monkeypatch.setattr( + aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 2}.get(span_type, 0) + ) + monkeypatch.setattr(aliyun_trace_module, "generate_span_id", lambda: 3) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}]) + + trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query")) + assert len(trace_instance.trace_client.added_spans) == 1 + span = trace_instance.trace_client.added_spans[0] + assert span.name == "dataset_retrieval" + assert span.attributes[RETRIEVAL_QUERY] == "query" + assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]' + + +def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): + trace_info = _make_tool_trace_info(message_data=None) + trace_instance.tool_trace(trace_info) + assert trace_instance.trace_client.added_spans == [] + + +def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10) + monkeypatch.setattr( + aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 20}.get(span_type, 0) + ) + monkeypatch.setattr(aliyun_trace_module, "generate_span_id", lambda: 30) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) + + trace_instance.tool_trace( + _make_tool_trace_info( + tool_name="my-tool", + tool_inputs={"a": 1}, + tool_config={"description": "x"}, + inputs={"i": 1}, + ) + ) + + assert len(trace_instance.trace_client.added_spans) == 1 + span = trace_instance.trace_client.added_spans[0] + assert span.name == "my-tool" + assert span.status == status + assert span.attributes[TOOL_NAME] == "my-tool" + assert span.attributes[TOOL_DESCRIPTION] == '{"description": "x"}' + + +def test_get_workflow_node_executions_requires_app_id(trace_instance: AliyunDataTrace): + trace_info = _make_workflow_trace_info(metadata={"conversation_id": "c"}) + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.get_workflow_node_executions(trace_info) + + +def test_get_workflow_node_executions_builds_repo_and_fetches( + trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch +): + trace_info = _make_workflow_trace_info(metadata={"app_id": "app", "conversation_id": "c", "user_id": "u"}) + + account = object() + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", MagicMock(return_value=account)) + monkeypatch.setattr(aliyun_trace_module, "sessionmaker", MagicMock()) + monkeypatch.setattr(aliyun_trace_module, "db", SimpleNamespace(engine="engine")) + + repo = MagicMock() + repo.get_by_workflow_run.return_value = ["node1"] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr(aliyun_trace_module, "DifyCoreRepositoryFactory", mock_factory) + + result = trace_instance.get_workflow_node_executions(trace_info) + assert result == ["node1"] + repo.get_by_workflow_run.assert_called_once_with(workflow_run_id=trace_info.workflow_run_id) + + +def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + node_execution = MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm")) + + node_execution.node_type = NodeType.LLM + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "llm" + + +def test_build_workflow_node_span_routes_knowledge_retrieval_type( + trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch +): + node_execution = MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval")) + + node_execution.node_type = NodeType.KNOWLEDGE_RETRIEVAL + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "retrieval" + + +def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + node_execution = MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool")) + + node_execution.node_type = NodeType.TOOL + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "tool" + + +def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + node_execution = MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task")) + + node_execution.node_type = NodeType.CODE + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "task" + + +def test_build_workflow_node_span_handles_errors( + trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +): + node_execution = MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom"))) + node_execution.node_type = NodeType.CODE + + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) is None + assert "Error occurred in build_workflow_node_span" in caplog.text + + +def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node-id" + node_execution.title = "title" + node_execution.inputs = {"a": 1} + node_execution.outputs = {"b": 2} + node_execution.created_at = _dt() + node_execution.finished_at = _dt() + + span = trace_instance.build_workflow_task_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span.trace_id == 1 + assert span.span_id == 9 + assert span.status.status_code == StatusCode.OK + assert span.attributes["gen_ai.span.kind"] == GenAISpanKind.TASK + + +def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()]) + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node-id" + node_execution.title = "my-tool" + node_execution.inputs = {"a": 1} + node_execution.outputs = {"b": 2} + node_execution.created_at = _dt() + node_execution.finished_at = _dt() + node_execution.metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"k": "v"}} + + span = trace_instance.build_workflow_tool_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span.attributes[TOOL_NAME] == "my-tool" + assert span.attributes[TOOL_DESCRIPTION] == '{"k": "v"}' + assert span.attributes[TOOL_PARAMETERS] == '{"a": 1}' + assert span.status.status_code == StatusCode.OK + + # Cover metadata is None and inputs is None + node_execution.metadata = None + node_execution.inputs = None + span2 = trace_instance.build_workflow_tool_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span2.attributes[TOOL_DESCRIPTION] == "{}" + assert span2.attributes[TOOL_PARAMETERS] == "{}" + + +def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) + monkeypatch.setattr( + aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else [] + ) + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node-id" + node_execution.title = "retrieval" + node_execution.inputs = {"query": "q"} + node_execution.outputs = {"result": [{"doc": "d"}]} + node_execution.created_at = _dt() + node_execution.finished_at = _dt() + + span = trace_instance.build_workflow_retrieval_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span.attributes[RETRIEVAL_QUERY] == "q" + assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"formatted": true}]' + + # Cover empty inputs/outputs + node_execution.inputs = None + node_execution.outputs = None + span2 = trace_instance.build_workflow_retrieval_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span2.attributes[RETRIEVAL_QUERY] == "" + assert span2.attributes[RETRIEVAL_DOCUMENT] == "[]" + + +def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) + monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in") + monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out") + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node-id" + node_execution.title = "llm" + node_execution.process_data = { + "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, + "prompts": ["p"], + "model_name": "m", + "model_provider": "p1", + } + node_execution.outputs = {"text": "t", "finish_reason": "stop"} + node_execution.created_at = _dt() + node_execution.finished_at = _dt() + + span = trace_instance.build_workflow_llm_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "3" + assert span.attributes[GEN_AI_REQUEST_MODEL] == "m" + assert span.attributes[GEN_AI_PROMPT] == '["p"]' + assert span.attributes[GEN_AI_COMPLETION] == "t" + assert span.attributes[GEN_AI_RESPONSE_FINISH_REASON] == "stop" + assert span.attributes[GEN_AI_INPUT_MESSAGE] == "in" + assert span.attributes[GEN_AI_OUTPUT_MESSAGE] == "out" + + # Cover usage from outputs if not in process_data + node_execution.process_data = {"prompts": []} + node_execution.outputs = {"usage": {"total_tokens": 10}, "text": ""} + span2 = trace_instance.build_workflow_llm_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span2.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "10" + + +def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 20}.get(span_type, 0) + ) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + + # CASE 1: With message_id + trace_info = _make_workflow_trace_info( + message_id="msg-1", workflow_run_inputs={"sys.query": "hi"}, workflow_run_outputs={"ans": "ok"} + ) + trace_instance.add_workflow_span(trace_info, trace_metadata) + + assert len(trace_instance.trace_client.added_spans) == 2 + message_span = trace_instance.trace_client.added_spans[0] + workflow_span = trace_instance.trace_client.added_spans[1] + + assert message_span.name == "message" + assert message_span.span_kind == SpanKind.SERVER + assert message_span.parent_span_id is None + + assert workflow_span.name == "workflow" + assert workflow_span.span_kind == SpanKind.INTERNAL + assert workflow_span.parent_span_id == 20 + + trace_instance.trace_client.added_spans.clear() + + # CASE 2: Without message_id + trace_info_no_msg = _make_workflow_trace_info(message_id=None) + trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata) + assert len(trace_instance.trace_client.added_spans) == 1 + span = trace_instance.trace_client.added_spans[0] + assert span.name == "workflow" + assert span.span_kind == SpanKind.SERVER + assert span.parent_span_id is None + + +def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10) + monkeypatch.setattr( + aliyun_trace_module, + "convert_to_span_id", + lambda _, span_type: {"message": 20, "suggested_question": 21}.get(span_type, 0), + ) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) + + trace_info = _make_suggested_question_trace_info(suggested_question=["how?"]) + trace_instance.suggested_question_trace(trace_info) + + assert len(trace_instance.trace_client.added_spans) == 1 + span = trace_instance.trace_client.added_spans[0] + assert span.name == "suggested_question" + assert span.attributes[GEN_AI_COMPLETION] == '["how?"]' diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py new file mode 100644 index 0000000000..763fc90710 --- /dev/null +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py @@ -0,0 +1,275 @@ +import json +from unittest.mock import MagicMock + +from opentelemetry.trace import Link, StatusCode + +from core.ops.aliyun_trace.entities.semconv import ( + GEN_AI_FRAMEWORK, + GEN_AI_SESSION_ID, + GEN_AI_SPAN_KIND, + GEN_AI_USER_ID, + INPUT_VALUE, + OUTPUT_VALUE, +) +from core.ops.aliyun_trace.utils import ( + create_common_span_attributes, + create_links_from_trace_id, + create_status_from_error, + extract_retrieval_documents, + format_input_messages, + format_output_messages, + format_retrieval_documents, + get_user_id_from_message_data, + get_workflow_node_status, + serialize_json_data, +) +from core.rag.models.document import Document +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import WorkflowNodeExecutionStatus +from models import EndUser + + +def test_get_user_id_from_message_data_no_end_user(monkeypatch): + message_data = MagicMock() + message_data.from_account_id = "account_id" + message_data.from_end_user_id = None + + assert get_user_id_from_message_data(message_data) == "account_id" + + +def test_get_user_id_from_message_data_with_end_user(monkeypatch): + message_data = MagicMock() + message_data.from_account_id = "account_id" + message_data.from_end_user_id = "end_user_id" + + end_user_data = MagicMock(spec=EndUser) + end_user_data.session_id = "session_id" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = end_user_data + + mock_session = MagicMock() + mock_session.query.return_value = mock_query + + from core.ops.aliyun_trace.utils import db + + monkeypatch.setattr(db, "session", mock_session) + + assert get_user_id_from_message_data(message_data) == "session_id" + + +def test_get_user_id_from_message_data_end_user_not_found(monkeypatch): + message_data = MagicMock() + message_data.from_account_id = "account_id" + message_data.from_end_user_id = "end_user_id" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = None + + mock_session = MagicMock() + mock_session.query.return_value = mock_query + + from core.ops.aliyun_trace.utils import db + + monkeypatch.setattr(db, "session", mock_session) + + assert get_user_id_from_message_data(message_data) == "account_id" + + +def test_create_status_from_error(): + # Case OK + status_ok = create_status_from_error(None) + assert status_ok.status_code == StatusCode.OK + + # Case Error + status_err = create_status_from_error("some error") + assert status_err.status_code == StatusCode.ERROR + assert status_err.description == "some error" + + +def test_get_workflow_node_status(): + node_execution = MagicMock(spec=WorkflowNodeExecution) + + # SUCCEEDED + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + status = get_workflow_node_status(node_execution) + assert status.status_code == StatusCode.OK + + # FAILED + node_execution.status = WorkflowNodeExecutionStatus.FAILED + node_execution.error = "node fail" + status = get_workflow_node_status(node_execution) + assert status.status_code == StatusCode.ERROR + assert status.description == "node fail" + + # EXCEPTION + node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION + node_execution.error = "node exception" + status = get_workflow_node_status(node_execution) + assert status.status_code == StatusCode.ERROR + assert status.description == "node exception" + + # UNSET/OTHER + node_execution.status = WorkflowNodeExecutionStatus.RUNNING + status = get_workflow_node_status(node_execution) + assert status.status_code == StatusCode.UNSET + + +def test_create_links_from_trace_id(monkeypatch): + # Mock create_link + mock_link = MagicMock(spec=Link) + import core.ops.aliyun_trace.data_exporter.traceclient + + monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) + + # Trace ID None + assert create_links_from_trace_id(None) == [] + + # Trace ID Present + links = create_links_from_trace_id("trace_id") + assert len(links) == 1 + assert links[0] == mock_link + + +def test_extract_retrieval_documents(): + doc1 = MagicMock(spec=Document) + doc1.page_content = "content1" + doc1.metadata = {"dataset_id": "ds1", "doc_id": "di1", "document_id": "dd1", "score": 0.9} + + doc2 = MagicMock(spec=Document) + doc2.page_content = "content2" + doc2.metadata = {"dataset_id": "ds2"} # Missing some keys + + documents = [doc1, doc2] + extracted = extract_retrieval_documents(documents) + + assert len(extracted) == 2 + assert extracted[0]["content"] == "content1" + assert extracted[0]["metadata"]["dataset_id"] == "ds1" + assert extracted[0]["score"] == 0.9 + + assert extracted[1]["content"] == "content2" + assert extracted[1]["metadata"]["dataset_id"] == "ds2" + assert extracted[1]["metadata"]["doc_id"] is None + assert extracted[1]["score"] is None + + +def test_serialize_json_data(): + data = {"a": 1} + # Test ensure_ascii default (False) + assert serialize_json_data(data) == json.dumps(data, ensure_ascii=False) + # Test ensure_ascii True + assert serialize_json_data(data, ensure_ascii=True) == json.dumps(data, ensure_ascii=True) + + +def test_create_common_span_attributes(): + attrs = create_common_span_attributes( + session_id="s1", user_id="u1", span_kind="kind1", framework="fw1", inputs="in1", outputs="out1" + ) + assert attrs[GEN_AI_SESSION_ID] == "s1" + assert attrs[GEN_AI_USER_ID] == "u1" + assert attrs[GEN_AI_SPAN_KIND] == "kind1" + assert attrs[GEN_AI_FRAMEWORK] == "fw1" + assert attrs[INPUT_VALUE] == "in1" + assert attrs[OUTPUT_VALUE] == "out1" + + +def test_format_retrieval_documents(): + # Not a list + assert format_retrieval_documents("not a list") == [] + + # Valid list + docs = [ + {"metadata": {"score": 0.8, "document_id": "doc1", "source": "src1"}, "content": "c1", "title": "t1"}, + { + "metadata": {"_source": "src2", "doc_metadata": {"extra": "val"}}, + "content": "c2", + # Missing title + }, + "not a dict", # Should be skipped + ] + formatted = format_retrieval_documents(docs) + + assert len(formatted) == 2 + assert formatted[0]["document"]["content"] == "c1" + assert formatted[0]["document"]["metadata"]["title"] == "t1" + assert formatted[0]["document"]["metadata"]["source"] == "src1" + assert formatted[0]["document"]["score"] == 0.8 + assert formatted[0]["document"]["id"] == "doc1" + + assert formatted[1]["document"]["content"] == "c2" + assert formatted[1]["document"]["metadata"]["source"] == "src2" + assert formatted[1]["document"]["metadata"]["extra"] == "val" + assert "title" not in formatted[1]["document"]["metadata"] + assert formatted[1]["document"]["score"] == 0.0 # Default + + # Exception handling + # We can trigger an exception by passing something that causes an error in the loop logic, + # but the try/except covers the whole function. + # Passing a list that contains something that throws when calling .get() - though dicts won't. + # Let's mock a dict that raises on get. + class BadDict: + def get(self, *args, **kwargs): + raise Exception("boom") + + assert format_retrieval_documents([BadDict()]) == [] + + +def test_format_input_messages(): + # Not a dict + assert format_input_messages(None) == serialize_json_data([]) + + # No prompts + assert format_input_messages({}) == serialize_json_data([]) + + # Valid prompts + process_data = { + "prompts": [ + {"role": "user", "text": "hello"}, + {"role": "assistant", "text": "hi"}, + {"role": "system", "text": "be helpful"}, + {"role": "tool", "text": "result"}, + {"role": "invalid", "text": "skip me"}, + "not a dict", + {"role": "user", "text": ""}, # Empty text, should be skipped? Code says `if text: message = ...` + ] + } + result = format_input_messages(process_data) + result_list = json.loads(result) + + assert len(result_list) == 4 + assert result_list[0]["role"] == "user" + assert result_list[0]["parts"][0]["content"] == "hello" + assert result_list[1]["role"] == "assistant" + assert result_list[2]["role"] == "system" + assert result_list[3]["role"] == "tool" + + # Exception path + assert format_input_messages({"prompts": [None]}) == serialize_json_data([]) + + +def test_format_output_messages(): + # Not a dict + assert format_output_messages(None) == serialize_json_data([]) + + # No text + assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([]) + + # Valid + outputs = {"text": "done", "finish_reason": "length"} + result = format_output_messages(outputs) + result_list = json.loads(result) + assert len(result_list) == 1 + assert result_list[0]["role"] == "assistant" + assert result_list[0]["parts"][0]["content"] == "done" + assert result_list[0]["finish_reason"] == "length" + + # Invalid finish reason + outputs2 = {"text": "done", "finish_reason": "unknown"} + result2 = format_output_messages(outputs2) + result_list2 = json.loads(result2) + assert result_list2[0]["finish_reason"] == "stop" + + # Exception path + # Trigger exception in serialize_json_data by passing non-serializable + assert format_output_messages({"text": MagicMock()}) == serialize_json_data([]) diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py new file mode 100644 index 0000000000..1cee2f5b68 --- /dev/null +++ b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -0,0 +1,398 @@ +from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock, patch + +import pytest +from opentelemetry.sdk.trace import Tracer +from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes +from opentelemetry.trace import StatusCode + +from core.ops.arize_phoenix_trace.arize_phoenix_trace import ( + ArizePhoenixDataTrace, + datetime_to_nanos, + error_to_string, + safe_json_dumps, + set_span_status, + setup_tracer, + wrap_span_metadata, +) +from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) + +# --- Helpers --- + + +def _dt(): + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +def _make_workflow_info(**kwargs): + defaults = { + "workflow_id": "w1", + "tenant_id": "t1", + "workflow_run_id": "r1", + "workflow_run_elapsed_time": 1.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"in": "val"}, + "workflow_run_outputs": {"out": "val"}, + "workflow_run_version": "1.0", + "total_tokens": 10, + "file_list": ["f1"], + "query": "hi", + "metadata": {"app_id": "app1"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(kwargs) + return WorkflowTraceInfo(**defaults) + + +def _make_message_info(**kwargs): + defaults = { + "conversation_model": "chat", + "message_tokens": 5, + "answer_tokens": 5, + "total_tokens": 10, + "conversation_mode": "chat", + "metadata": {"app_id": "app1"}, + "inputs": {"in": "val"}, + "outputs": "val", + "start_time": _dt(), + "end_time": _dt(), + "message_id": "m1", + } + defaults.update(kwargs) + return MessageTraceInfo(**defaults) + + +# --- Utility Function Tests --- + + +def test_datetime_to_nanos(): + dt = _dt() + expected = int(dt.timestamp() * 1_000_000_000) + assert datetime_to_nanos(dt) == expected + + with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt: + mock_now = MagicMock() + mock_now.timestamp.return_value = 1704110400.0 + mock_dt.now.return_value = mock_now + assert datetime_to_nanos(None) == 1704110400000000000 + + +def test_error_to_string(): + try: + raise ValueError("boom") + except ValueError as e: + err = e + + res = error_to_string(err) + assert "ValueError: boom" in res + assert "traceback" in res.lower() or "line" in res.lower() + + assert error_to_string("str error") == "str error" + assert error_to_string(None) == "Empty Stack Trace" + + +def test_set_span_status(): + span = MagicMock() + # OK + set_span_status(span, None) + span.set_status.assert_called() + assert span.set_status.call_args[0][0].status_code == StatusCode.OK + + # Error Exception + span.reset_mock() + set_span_status(span, ValueError("fail")) + assert span.set_status.call_args[0][0].status_code == StatusCode.ERROR + span.record_exception.assert_called() + + # Error String + span.reset_mock() + set_span_status(span, "fail-str") + assert span.set_status.call_args[0][0].status_code == StatusCode.ERROR + span.add_event.assert_called() + + # repr branch + class SilentError: + def __str__(self): + return "" + + def __repr__(self): + return "SilentErrorRepr" + + span.reset_mock() + set_span_status(span, SilentError()) + assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr" + + +def test_safe_json_dumps(): + assert safe_json_dumps({"a": _dt()}) == '{"a": "2024-01-01 00:00:00+00:00"}' + + +def test_wrap_span_metadata(): + res = wrap_span_metadata({"a": 1}, b=2) + assert res == {"a": 1, "b": 2, "created_from": "Dify"} + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter") +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +def test_setup_tracer_arize(mock_provider, mock_exporter): + config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p") + setup_tracer(config) + mock_exporter.assert_called_once() + assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1" + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter") +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +def test_setup_tracer_phoenix(mock_provider, mock_exporter): + config = PhoenixConfig(endpoint="http://p.com", project="p") + setup_tracer(config) + mock_exporter.assert_called_once() + assert mock_exporter.call_args[1]["endpoint"] == "http://p.com/v1/traces" + + +def test_setup_tracer_exception(): + config = ArizeConfig(endpoint="http://a.com", project="p") + with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): + with pytest.raises(Exception, match="boom"): + setup_tracer(config) + + +# --- ArizePhoenixDataTrace Class Tests --- + + +@pytest.fixture +def trace_instance(): + with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup: + mock_tracer = MagicMock(spec=Tracer) + mock_processor = MagicMock() + mock_setup.return_value = (mock_tracer, mock_processor) + config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p") + return ArizePhoenixDataTrace(config) + + +def test_trace_dispatch(trace_instance): + with ( + patch.object(trace_instance, "workflow_trace") as m1, + patch.object(trace_instance, "message_trace") as m2, + patch.object(trace_instance, "moderation_trace") as m3, + patch.object(trace_instance, "suggested_question_trace") as m4, + patch.object(trace_instance, "dataset_retrieval_trace") as m5, + patch.object(trace_instance, "tool_trace") as m6, + patch.object(trace_instance, "generate_name_trace") as m7, + ): + trace_instance.trace(_make_workflow_info()) + m1.assert_called() + + trace_instance.trace(_make_message_info()) + m2.assert_called() + + trace_instance.trace(ModerationTraceInfo(flagged=True, action="a", preset_response="p", query="q", metadata={})) + m3.assert_called() + + trace_instance.trace(SuggestedQuestionTraceInfo(suggested_question=[], total_tokens=0, level="i", metadata={})) + m4.assert_called() + + trace_instance.trace(DatasetRetrievalTraceInfo(metadata={})) + m5.assert_called() + + trace_instance.trace( + ToolTraceInfo( + tool_name="t", + tool_inputs={}, + tool_outputs="o", + metadata={}, + tool_config={}, + time_cost=1, + tool_parameters={}, + ) + ) + m6.assert_called() + + trace_instance.trace(GenerateNameTraceInfo(tenant_id="t", metadata={})) + m7.assert_called() + + +def test_trace_exception(trace_instance): + with patch.object(trace_instance, "workflow_trace", side_effect=RuntimeError("fail")): + with pytest.raises(RuntimeError): + trace_instance.trace(_make_workflow_info()) + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker") +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + node1 = MagicMock() + node1.node_type = "llm" + node1.status = "succeeded" + node1.inputs = {"q": "hi"} + node1.outputs = {"a": "bye", "usage": {"total_tokens": 5}} + node1.created_at = _dt() + node1.elapsed_time = 1.0 + node1.process_data = { + "prompts": [{"role": "user", "content": "hi"}], + "model_provider": "openai", + "model_name": "gpt-4", + } + node1.metadata = {"k": "v"} + node1.title = "title" + node1.id = "n1" + node1.error = None + + repo.get_by_workflow_run.return_value = [node1] + + with patch.object(trace_instance, "get_service_account_with_tenant"): + trace_instance.workflow_trace(info) + + assert trace_instance.tracer.start_span.call_count >= 2 + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +def test_workflow_trace_no_app_id(mock_db, trace_instance): + mock_db.engine = MagicMock() + info = _make_workflow_info() + info.metadata = {} + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(info) + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +def test_message_trace_success(mock_db, trace_instance): + mock_db.engine = MagicMock() + info = _make_message_info() + info.message_data = MagicMock() + info.message_data.from_account_id = "acc1" + info.message_data.from_end_user_id = None + info.message_data.query = "q" + info.message_data.answer = "a" + info.message_data.status = "s" + info.message_data.model_id = "m" + info.message_data.model_provider = "p" + info.message_data.message_metadata = "{}" + info.message_data.error = None + info.error = None + + trace_instance.message_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +def test_message_trace_with_error(mock_db, trace_instance): + mock_db.engine = MagicMock() + info = _make_message_info() + info.message_data = MagicMock() + info.message_data.from_account_id = "acc1" + info.message_data.from_end_user_id = None + info.message_data.query = "q" + info.message_data.answer = "a" + info.message_data.status = "s" + info.message_data.model_id = "m" + info.message_data.model_provider = "p" + info.message_data.message_metadata = "{}" + info.message_data.error = "processing failed" + info.error = "message error" + + trace_instance.message_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_trace_methods_return_early_with_no_message_data(trace_instance): + info = MagicMock() + info.message_data = None + + trace_instance.moderation_trace(info) + trace_instance.suggested_question_trace(info) + trace_instance.dataset_retrieval_trace(info) + trace_instance.tool_trace(info) + trace_instance.generate_name_trace(info) + + assert trace_instance.tracer.start_span.call_count == 0 + + +def test_moderation_trace_ok(trace_instance): + info = ModerationTraceInfo(flagged=True, action="a", preset_response="p", query="q", metadata={}) + info.message_data = MagicMock() + info.message_data.error = None + trace_instance.moderation_trace(info) + # root span (1) + moderation span (1) = 2 + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_suggested_question_trace_ok(trace_instance): + info = SuggestedQuestionTraceInfo(suggested_question=["?"], total_tokens=1, level="i", metadata={}) + info.message_data = MagicMock() + info.error = None + trace_instance.suggested_question_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_dataset_retrieval_trace_ok(trace_instance): + info = DatasetRetrievalTraceInfo(documents=[], metadata={}) + info.message_data = MagicMock() + info.error = None + trace_instance.dataset_retrieval_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_tool_trace_ok(trace_instance): + info = ToolTraceInfo( + tool_name="t", tool_inputs={}, tool_outputs="o", metadata={}, tool_config={}, time_cost=1, tool_parameters={} + ) + info.message_data = MagicMock() + info.error = None + trace_instance.tool_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_generate_name_trace_ok(trace_instance): + info = GenerateNameTraceInfo(tenant_id="t", metadata={}) + info.message_data = MagicMock() + info.message_data.error = None + trace_instance.generate_name_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_get_project_url_phoenix(trace_instance): + trace_instance.arize_phoenix_config = PhoenixConfig(endpoint="http://p.com", project="p") + assert "p.com/projects/" in trace_instance.get_project_url() + + +def test_set_attribute_none_logic(trace_instance): + # Test role can be None + attrs = trace_instance._construct_llm_attributes([{"role": None, "content": "hi"}]) + assert "llm.input_messages.0.message.role" not in attrs + + # Test tool call id can be None + tool_call_none_id = {"id": None, "function": {"name": "f1"}} + attrs = trace_instance._construct_llm_attributes([{"role": "assistant", "tool_calls": [tool_call_none_id]}]) + assert "llm.input_messages.0.message.tool_calls.0.tool_call.id" not in attrs + + +def test_construct_llm_attributes_dict_branch(trace_instance): + attrs = trace_instance._construct_llm_attributes({"prompt": "hi"}) + assert '"prompt": "hi"' in attrs["llm.input_messages.0.message.content"] + assert attrs["llm.input_messages.0.message.role"] == "user" + + +def test_api_check_success(trace_instance): + assert trace_instance.api_check() is True + + +def test_ensure_root_span_basic(trace_instance): + trace_instance.ensure_root_span("tid") + assert "tid" in trace_instance.dify_trace_ids diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py new file mode 100644 index 0000000000..8e036e4b52 --- /dev/null +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -0,0 +1,698 @@ +import collections +import logging +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.ops.entities.config_entity import LangfuseConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + TraceTaskName, + WorkflowTraceInfo, +) +from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( + LangfuseGeneration, + LangfuseSpan, + LangfuseTrace, + LevelEnum, + UnitEnum, +) +from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace +from dify_graph.enums import NodeType +from models import EndUser +from models.enums import MessageStatus + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +@pytest.fixture +def langfuse_config(): + return LangfuseConfig(public_key="pk-123", secret_key="sk-123", host="https://cloud.langfuse.com") + + +@pytest.fixture +def trace_instance(langfuse_config, monkeypatch): + # Mock Langfuse client to avoid network calls + mock_client = MagicMock() + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client) + + instance = LangFuseDataTrace(langfuse_config) + return instance + + +def test_init(langfuse_config, monkeypatch): + mock_langfuse = MagicMock() + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse) + monkeypatch.setenv("FILES_URL", "http://test.url") + + instance = LangFuseDataTrace(langfuse_config) + + mock_langfuse.assert_called_once_with( + public_key=langfuse_config.public_key, + secret_key=langfuse_config.secret_key, + host=langfuse_config.host, + ) + assert instance.file_base_url == "http://test.url" + + +def test_trace_dispatch(trace_instance, monkeypatch): + methods = [ + "workflow_trace", + "message_trace", + "moderation_trace", + "suggested_question_trace", + "dataset_retrieval_trace", + "tool_trace", + "generate_name_trace", + ] + mocks = {method: MagicMock() for method in methods} + for method, m in mocks.items(): + monkeypatch.setattr(trace_instance, method, m) + + # WorkflowTraceInfo + info = MagicMock(spec=WorkflowTraceInfo) + trace_instance.trace(info) + mocks["workflow_trace"].assert_called_once_with(info) + + # MessageTraceInfo + info = MagicMock(spec=MessageTraceInfo) + trace_instance.trace(info) + mocks["message_trace"].assert_called_once_with(info) + + # ModerationTraceInfo + info = MagicMock(spec=ModerationTraceInfo) + trace_instance.trace(info) + mocks["moderation_trace"].assert_called_once_with(info) + + # SuggestedQuestionTraceInfo + info = MagicMock(spec=SuggestedQuestionTraceInfo) + trace_instance.trace(info) + mocks["suggested_question_trace"].assert_called_once_with(info) + + # DatasetRetrievalTraceInfo + info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_instance.trace(info) + mocks["dataset_retrieval_trace"].assert_called_once_with(info) + + # ToolTraceInfo + info = MagicMock(spec=ToolTraceInfo) + trace_instance.trace(info) + mocks["tool_trace"].assert_called_once_with(info) + + # GenerateNameTraceInfo + info = MagicMock(spec=GenerateNameTraceInfo) + trace_instance.trace(info) + mocks["generate_name_trace"].assert_called_once_with(info) + + +def test_workflow_trace_with_message_id(trace_instance, monkeypatch): + # Setup trace info + trace_info = WorkflowTraceInfo( + workflow_id="wf-1", + tenant_id="tenant-1", + workflow_run_id="run-1", + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={"input": "hi"}, + workflow_run_outputs={"output": "hello"}, + workflow_run_version="1.0", + message_id="msg-1", + conversation_id="conv-1", + total_tokens=100, + file_list=[], + query="hi", + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id="trace-1", + metadata={"app_id": "app-1", "user_id": "user-1"}, + workflow_app_log_id="log-1", + error="", + ) + + # Mock DB and Repositories + mock_session = MagicMock() + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + + # Mock node executions + node_llm = MagicMock() + node_llm.id = "node-llm" + node_llm.title = "LLM Node" + node_llm.node_type = NodeType.LLM + node_llm.status = "succeeded" + node_llm.process_data = { + "model_mode": "chat", + "model_name": "gpt-4", + "model_provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 20}, + } + node_llm.inputs = {"prompts": "p"} + node_llm.outputs = {"text": "t"} + node_llm.created_at = _dt() + node_llm.elapsed_time = 0.5 + node_llm.metadata = {"foo": "bar"} + + node_other = MagicMock() + node_other.id = "node-other" + node_other.title = "Other Node" + node_other.node_type = NodeType.CODE + node_other.status = "failed" + node_other.process_data = None + node_other.inputs = {"code": "print"} + node_other.outputs = {"result": "ok"} + node_other.created_at = None # Trigger datetime.now() branch + node_other.elapsed_time = 0.2 + node_other.metadata = None + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node_llm, node_other] + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + # Track calls to add_trace, add_span, add_generation + trace_instance.add_trace = MagicMock() + trace_instance.add_span = MagicMock() + trace_instance.add_generation = MagicMock() + + trace_instance.workflow_trace(trace_info) + + # Verify add_trace (Workflow Level) + trace_instance.add_trace.assert_called_once() + trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"] + assert trace_data.id == "trace-1" + assert trace_data.name == TraceTaskName.MESSAGE_TRACE + assert "message" in trace_data.tags + assert "workflow" in trace_data.tags + + # Verify add_span (Workflow Run Span) + assert trace_instance.add_span.call_count >= 1 + # First span should be workflow run span because message_id is present + workflow_span = trace_instance.add_span.call_args_list[0][1]["langfuse_span_data"] + assert workflow_span.id == "run-1" + assert workflow_span.name == TraceTaskName.WORKFLOW_TRACE + + # Verify Generation for LLM node + trace_instance.add_generation.assert_called_once() + gen_data = trace_instance.add_generation.call_args[1]["langfuse_generation_data"] + assert gen_data.id == "node-llm" + assert gen_data.usage.input == 10 + assert gen_data.usage.output == 20 + + # Verify normal span for Other node + # Second add_span call + other_span = trace_instance.add_span.call_args_list[1][1]["langfuse_span_data"] + assert other_span.id == "node-other" + assert other_span.level == LevelEnum.ERROR + + +def test_workflow_trace_no_message_id(trace_instance, monkeypatch): + trace_info = WorkflowTraceInfo( + workflow_id="wf-1", + tenant_id="tenant-1", + workflow_run_id="run-1", + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1.0", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="conv-1", + start_time=_dt(), + end_time=_dt(), + trace_id=None, # Should fallback to workflow_run_id + metadata={"app_id": "app-1"}, + workflow_app_log_id="log-1", + error="", + ) + + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + repo = MagicMock() + repo.get_by_workflow_run.return_value = [] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.workflow_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"] + assert trace_data.id == "run-1" + assert trace_data.name == TraceTaskName.WORKFLOW_TRACE + + +def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): + trace_info = WorkflowTraceInfo( + workflow_id="wf-1", + tenant_id="tenant-1", + workflow_run_id="run-1", + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1.0", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="conv-1", + start_time=_dt(), + end_time=_dt(), + metadata={}, # Missing app_id + workflow_app_log_id="log-1", + error="", + ) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(trace_info) + + +def test_message_trace_basic(trace_instance, monkeypatch): + message_data = MagicMock() + message_data.id = "msg-1" + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = None + message_data.provider_response_latency = 0.5 + message_data.conversation_id = "conv-1" + message_data.total_price = 0.01 + message_data.model_id = "gpt-4" + message_data.answer = "hello" + message_data.status = MessageStatus.NORMAL + message_data.error = None + + trace_info = MessageTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs={"query": "hi"}, + outputs={"answer": "hello"}, + message_tokens=10, + answer_tokens=20, + total_tokens=30, + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id="trace-1", + metadata={"foo": "bar"}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=[], + error=None, + ) + + trace_instance.add_trace = MagicMock() + trace_instance.add_generation = MagicMock() + + trace_instance.message_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_instance.add_generation.assert_called_once() + + gen_data = trace_instance.add_generation.call_args[0][0] + assert gen_data.name == "llm" + assert gen_data.usage.total == 30 + + +def test_message_trace_with_end_user(trace_instance, monkeypatch): + message_data = MagicMock() + message_data.id = "msg-1" + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = "end-user-1" + message_data.conversation_id = "conv-1" + message_data.status = MessageStatus.NORMAL + message_data.model_id = "gpt-4" + message_data.error = "" + message_data.answer = "hello" + message_data.total_price = 0.0 + message_data.provider_response_latency = 0.1 + + trace_info = MessageTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs={}, + outputs={}, + message_tokens=0, + answer_tokens=0, + total_tokens=0, + start_time=_dt(), + end_time=_dt(), + metadata={}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=[], + error=None, + ) + + # Mock DB session for EndUser lookup + mock_end_user = MagicMock(spec=EndUser) + mock_end_user.session_id = "session-id-123" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = mock_end_user + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.query", lambda model: mock_query) + + trace_instance.add_trace = MagicMock() + trace_instance.add_generation = MagicMock() + + trace_instance.message_trace(trace_info) + + trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"] + assert trace_data.user_id == "session-id-123" + assert trace_data.metadata["user_id"] == "session-id-123" + + +def test_message_trace_none_data(trace_instance): + trace_info = SimpleNamespace(message_data=None, file_list=[], metadata={}) + trace_instance.add_trace = MagicMock() + trace_instance.message_trace(trace_info) + trace_instance.add_trace.assert_not_called() + + +def test_moderation_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + + trace_info = ModerationTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs={"q": "hi"}, + action="stop", + flagged=True, + preset_response="blocked", + start_time=None, + end_time=None, + metadata={"foo": "bar"}, + trace_id="trace-1", + query="hi", + ) + + trace_instance.add_span = MagicMock() + trace_instance.moderation_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"] + assert span_data.name == TraceTaskName.MODERATION_TRACE + assert span_data.output["flagged"] is True + + +def test_suggested_question_trace(trace_instance): + message_data = MagicMock() + message_data.status = MessageStatus.NORMAL + message_data.error = None + + trace_info = SuggestedQuestionTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs="hi", + suggested_question=["q1"], + total_tokens=10, + level="info", + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_generation = MagicMock() + trace_instance.suggested_question_trace(trace_info) + + trace_instance.add_generation.assert_called_once() + gen_data = trace_instance.add_generation.call_args[1]["langfuse_generation_data"] + assert gen_data.name == TraceTaskName.SUGGESTED_QUESTION_TRACE + assert gen_data.usage.unit == UnitEnum.CHARACTERS + + +def test_dataset_retrieval_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = DatasetRetrievalTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs="query", + documents=[{"id": "doc1"}], + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_span = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"] + assert span_data.name == TraceTaskName.DATASET_RETRIEVAL_TRACE + assert span_data.output["documents"] == [{"id": "doc1"}] + + +def test_tool_trace(trace_instance): + trace_info = ToolTraceInfo( + message_id="msg-1", + message_data=MagicMock(), + inputs={}, + outputs={}, + tool_name="my_tool", + tool_inputs={"a": 1}, + tool_outputs="result_string", + time_cost=0.1, + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="trace-1", + tool_config={}, + tool_parameters={}, + error="some error", + ) + + trace_instance.add_span = MagicMock() + trace_instance.tool_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"] + assert span_data.name == "my_tool" + assert span_data.level == LevelEnum.ERROR + + +def test_generate_name_trace(trace_instance): + trace_info = GenerateNameTraceInfo( + inputs={"q": "hi"}, + outputs={"name": "new"}, + tenant_id="tenant-1", + conversation_id="conv-1", + start_time=_dt(), + end_time=_dt(), + metadata={"m": 1}, + ) + + trace_instance.add_trace = MagicMock() + trace_instance.add_span = MagicMock() + + trace_instance.generate_name_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_instance.add_span.assert_called_once() + + trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"] + assert trace_data.name == TraceTaskName.GENERATE_NAME_TRACE + assert trace_data.user_id == "tenant-1" + + span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"] + assert span_data.trace_id == "conv-1" + + +def test_add_trace_success(trace_instance): + data = LangfuseTrace(id="t1", name="trace") + trace_instance.add_trace(data) + trace_instance.langfuse_client.trace.assert_called_once() + + +def test_add_trace_error(trace_instance): + trace_instance.langfuse_client.trace.side_effect = Exception("error") + data = LangfuseTrace(id="t1", name="trace") + with pytest.raises(ValueError, match="LangFuse Failed to create trace: error"): + trace_instance.add_trace(data) + + +def test_add_span_success(trace_instance): + data = LangfuseSpan(id="s1", name="span", trace_id="t1") + trace_instance.add_span(data) + trace_instance.langfuse_client.span.assert_called_once() + + +def test_add_span_error(trace_instance): + trace_instance.langfuse_client.span.side_effect = Exception("error") + data = LangfuseSpan(id="s1", name="span", trace_id="t1") + with pytest.raises(ValueError, match="LangFuse Failed to create span: error"): + trace_instance.add_span(data) + + +def test_update_span(trace_instance): + span = MagicMock() + data = LangfuseSpan(id="s1", name="span", trace_id="t1") + trace_instance.update_span(span, data) + span.end.assert_called_once() + + +def test_add_generation_success(trace_instance): + data = LangfuseGeneration(id="g1", name="gen", trace_id="t1") + trace_instance.add_generation(data) + trace_instance.langfuse_client.generation.assert_called_once() + + +def test_add_generation_error(trace_instance): + trace_instance.langfuse_client.generation.side_effect = Exception("error") + data = LangfuseGeneration(id="g1", name="gen", trace_id="t1") + with pytest.raises(ValueError, match="LangFuse Failed to create generation: error"): + trace_instance.add_generation(data) + + +def test_update_generation(trace_instance): + gen = MagicMock() + data = LangfuseGeneration(id="g1", name="gen", trace_id="t1") + trace_instance.update_generation(gen, data) + gen.end.assert_called_once() + + +def test_api_check_success(trace_instance): + trace_instance.langfuse_client.auth_check.return_value = True + assert trace_instance.api_check() is True + + +def test_api_check_error(trace_instance): + trace_instance.langfuse_client.auth_check.side_effect = Exception("fail") + with pytest.raises(ValueError, match="LangFuse API check failed: fail"): + trace_instance.api_check() + + +def test_get_project_key_success(trace_instance): + mock_data = MagicMock() + mock_data.id = "proj-1" + trace_instance.langfuse_client.client.projects.get.return_value = MagicMock(data=[mock_data]) + assert trace_instance.get_project_key() == "proj-1" + + +def test_get_project_key_error(trace_instance): + trace_instance.langfuse_client.client.projects.get.side_effect = Exception("fail") + with pytest.raises(ValueError, match="LangFuse get project key failed: fail"): + trace_instance.get_project_key() + + +def test_moderation_trace_none(trace_instance): + trace_info = ModerationTraceInfo( + message_id="m", + message_data=None, + inputs={}, + action="s", + flagged=False, + preset_response="", + query="", + metadata={}, + ) + trace_instance.add_span = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_suggested_question_trace_none(trace_instance): + trace_info = SuggestedQuestionTraceInfo( + message_id="m", message_data=None, inputs={}, suggested_question=[], total_tokens=0, level="i", metadata={} + ) + trace_instance.add_generation = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.add_generation.assert_not_called() + + +def test_dataset_retrieval_trace_none(trace_instance): + trace_info = DatasetRetrievalTraceInfo(message_id="m", message_data=None, inputs={}, documents=[], metadata={}) + trace_instance.add_span = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_langfuse_trace_entity_with_list_dict_input(): + # To cover lines 29-31 in langfuse_trace_entity.py + # We need to mock replace_text_with_content or just check if it works + # Actually replace_text_with_content is imported from core.ops.utils + data = LangfuseTrace(id="t1", name="n", input=[{"text": "hello"}]) + assert isinstance(data.input, list) + assert data.input[0]["content"] == "hello" + + +def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypatch, caplog): + # Setup trace info to trigger LLM node usage extraction + trace_info = WorkflowTraceInfo( + workflow_id="wf-1", + tenant_id="t", + workflow_run_id="r", + workflow_run_elapsed_time=1.0, + workflow_run_status="s", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="c", + start_time=_dt(), + end_time=_dt(), + metadata={"app_id": "app-1"}, + workflow_app_log_id="l", + error="", + ) + + node = MagicMock() + node.id = "n1" + node.title = "LLM Node" + node.node_type = NodeType.LLM + node.status = "succeeded" + + class BadDict(collections.UserDict): + def get(self, key, default=None): + if key == "usage": + raise Exception("Usage extraction failed") + return super().get(key, default) + + node.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]}) + node.created_at = _dt() + node.elapsed_time = 0.1 + node.metadata = {} + node.outputs = {} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.add_generation = MagicMock() + + with caplog.at_level(logging.ERROR): + trace_instance.workflow_trace(trace_info) + + assert "Failed to extract usage" in caplog.text + trace_instance.add_generation.assert_called_once() diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py new file mode 100644 index 0000000000..98f9dd00cf --- /dev/null +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -0,0 +1,608 @@ +import collections +from datetime import datetime, timedelta +from unittest.mock import MagicMock + +import pytest + +from core.ops.entities.config_entity import LangSmithConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + TraceTaskName, + WorkflowTraceInfo, +) +from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( + LangSmithRunModel, + LangSmithRunType, + LangSmithRunUpdateModel, +) +from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey +from models import EndUser + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0) + + +@pytest.fixture +def langsmith_config(): + return LangSmithConfig(api_key="ls-123", project="default", endpoint="https://api.smith.langchain.com") + + +@pytest.fixture +def trace_instance(langsmith_config, monkeypatch): + # Mock LangSmith client + mock_client = MagicMock() + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client) + + instance = LangSmithDataTrace(langsmith_config) + return instance + + +def test_init(langsmith_config, monkeypatch): + mock_client_class = MagicMock() + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class) + monkeypatch.setenv("FILES_URL", "http://test.url") + + instance = LangSmithDataTrace(langsmith_config) + + mock_client_class.assert_called_once_with(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint) + assert instance.langsmith_key == langsmith_config.api_key + assert instance.project_name == langsmith_config.project + assert instance.file_base_url == "http://test.url" + + +def test_trace_dispatch(trace_instance, monkeypatch): + methods = [ + "workflow_trace", + "message_trace", + "moderation_trace", + "suggested_question_trace", + "dataset_retrieval_trace", + "tool_trace", + "generate_name_trace", + ] + mocks = {method: MagicMock() for method in methods} + for method, m in mocks.items(): + monkeypatch.setattr(trace_instance, method, m) + + # WorkflowTraceInfo + info = MagicMock(spec=WorkflowTraceInfo) + trace_instance.trace(info) + mocks["workflow_trace"].assert_called_once_with(info) + + # MessageTraceInfo + info = MagicMock(spec=MessageTraceInfo) + trace_instance.trace(info) + mocks["message_trace"].assert_called_once_with(info) + + # ModerationTraceInfo + info = MagicMock(spec=ModerationTraceInfo) + trace_instance.trace(info) + mocks["moderation_trace"].assert_called_once_with(info) + + # SuggestedQuestionTraceInfo + info = MagicMock(spec=SuggestedQuestionTraceInfo) + trace_instance.trace(info) + mocks["suggested_question_trace"].assert_called_once_with(info) + + # DatasetRetrievalTraceInfo + info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_instance.trace(info) + mocks["dataset_retrieval_trace"].assert_called_once_with(info) + + # ToolTraceInfo + info = MagicMock(spec=ToolTraceInfo) + trace_instance.trace(info) + mocks["tool_trace"].assert_called_once_with(info) + + # GenerateNameTraceInfo + info = MagicMock(spec=GenerateNameTraceInfo) + trace_instance.trace(info) + mocks["generate_name_trace"].assert_called_once_with(info) + + +def test_workflow_trace(trace_instance, monkeypatch): + # Setup trace info + workflow_data = MagicMock() + workflow_data.created_at = _dt() + workflow_data.finished_at = _dt() + timedelta(seconds=1) + + trace_info = WorkflowTraceInfo( + tenant_id="tenant-1", + workflow_id="wf-1", + workflow_run_id="run-1", + workflow_run_inputs={"input": "hi"}, + workflow_run_outputs={"output": "hello"}, + workflow_run_status="succeeded", + workflow_run_version="1.0", + workflow_run_elapsed_time=1.0, + total_tokens=100, + file_list=[], + query="hi", + message_id="msg-1", + conversation_id="conv-1", + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id="trace-1", + metadata={"app_id": "app-1"}, + workflow_app_log_id="log-1", + error="", + workflow_data=workflow_data, + ) + + # Mock dependencies + mock_session = MagicMock() + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + + # Mock node executions + node_llm = MagicMock() + node_llm.id = "node-llm" + node_llm.title = "LLM Node" + node_llm.node_type = NodeType.LLM + node_llm.status = "succeeded" + node_llm.process_data = { + "model_mode": "chat", + "model_name": "gpt-4", + "model_provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 20}, + } + node_llm.inputs = {"prompts": "p"} + node_llm.outputs = {"text": "t"} + node_llm.created_at = _dt() + node_llm.elapsed_time = 0.5 + node_llm.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 30} + + node_other = MagicMock() + node_other.id = "node-other" + node_other.title = "Tool Node" + node_other.node_type = NodeType.TOOL + node_other.status = "succeeded" + node_other.process_data = None + node_other.inputs = {"tool_input": "val"} + node_other.outputs = {"tool_output": "val"} + node_other.created_at = None # Trigger datetime.now() + node_other.elapsed_time = 0.2 + node_other.metadata = {} + + node_retrieval = MagicMock() + node_retrieval.id = "node-retrieval" + node_retrieval.title = "Retrieval Node" + node_retrieval.node_type = NodeType.KNOWLEDGE_RETRIEVAL + node_retrieval.status = "succeeded" + node_retrieval.process_data = None + node_retrieval.inputs = {"query": "val"} + node_retrieval.outputs = {"results": "val"} + node_retrieval.created_at = _dt() + node_retrieval.elapsed_time = 0.2 + node_retrieval.metadata = {} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node_llm, node_other, node_retrieval] + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_run = MagicMock() + + trace_instance.workflow_trace(trace_info) + + # Verify add_run calls + # 1. message run (id="msg-1") + # 2. workflow run (id="run-1") + # 3. node llm run (id="node-llm") + # 4. node other run (id="node-other") + # 5. node retrieval run (id="node-retrieval") + assert trace_instance.add_run.call_count == 5 + + call_args = [call[0][0] for call in trace_instance.add_run.call_args_list] + + assert call_args[0].id == "msg-1" + assert call_args[0].name == TraceTaskName.MESSAGE_TRACE + + assert call_args[1].id == "run-1" + assert call_args[1].name == TraceTaskName.WORKFLOW_TRACE + assert call_args[1].parent_run_id == "msg-1" + + assert call_args[2].id == "node-llm" + assert call_args[2].run_type == LangSmithRunType.llm + + assert call_args[3].id == "node-other" + assert call_args[3].run_type == LangSmithRunType.tool + + assert call_args[4].id == "node-retrieval" + assert call_args[4].run_type == LangSmithRunType.retriever + + +def test_workflow_trace_no_start_time(trace_instance, monkeypatch): + workflow_data = MagicMock() + workflow_data.created_at = _dt() + workflow_data.finished_at = _dt() + timedelta(seconds=1) + + trace_info = WorkflowTraceInfo( + tenant_id="tenant-1", + workflow_id="wf-1", + workflow_run_id="run-1", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_status="succeeded", + workflow_run_version="1.0", + workflow_run_elapsed_time=1.0, + total_tokens=10, + file_list=[], + query="hi", + message_id="msg-1", + conversation_id="conv-1", + start_time=None, + end_time=None, + trace_id="trace-1", + metadata={"app_id": "app-1"}, + workflow_app_log_id="log-1", + error="", + workflow_data=workflow_data, + ) + + mock_session = MagicMock() + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + repo = MagicMock() + repo.get_by_workflow_run.return_value = [] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_run = MagicMock() + trace_instance.workflow_trace(trace_info) + assert trace_instance.add_run.called + + +def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.trace_id = "trace-1" + trace_info.message_id = None + trace_info.workflow_run_id = "run-1" + trace_info.start_time = None + trace_info.workflow_data = MagicMock() + trace_info.workflow_data.created_at = _dt() + trace_info.metadata = {} # Empty metadata + trace_info.workflow_app_log_id = "log-1" + trace_info.file_list = [] + trace_info.total_tokens = 0 + trace_info.workflow_run_inputs = {} + trace_info.workflow_run_outputs = {} + trace_info.error = "" + + mock_session = MagicMock() + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(trace_info) + + +def test_message_trace(trace_instance, monkeypatch): + message_data = MagicMock() + message_data.id = "msg-1" + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = "end-user-1" + message_data.answer = "hello answer" + + trace_info = MessageTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs={"input": "hi"}, + outputs={"answer": "hello"}, + message_tokens=10, + answer_tokens=20, + total_tokens=30, + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id="trace-1", + metadata={"foo": "bar"}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=[], + error=None, + message_file_data=MagicMock(url="file-url"), + ) + + # Mock EndUser lookup + mock_end_user = MagicMock(spec=EndUser) + mock_end_user.session_id = "session-id-123" + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = mock_end_user + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.query", lambda model: mock_query) + + trace_instance.add_run = MagicMock() + + trace_instance.message_trace(trace_info) + + # 1. message run + # 2. llm run + assert trace_instance.add_run.call_count == 2 + + call_args = [call[0][0] for call in trace_instance.add_run.call_args_list] + assert call_args[0].id == "msg-1" + assert call_args[0].extra["metadata"]["end_user_id"] == "session-id-123" + assert call_args[1].parent_run_id == "msg-1" + assert call_args[1].name == "llm" + + +def test_message_trace_no_data(trace_instance): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.message_data = None + trace_info.file_list = [] + trace_info.message_file_data = None + trace_info.metadata = {} + trace_instance.add_run = MagicMock() + trace_instance.message_trace(trace_info) + trace_instance.add_run.assert_not_called() + + +def test_moderation_trace_no_data(trace_instance): + trace_info = MagicMock(spec=ModerationTraceInfo) + trace_info.message_data = None + trace_instance.add_run = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.add_run.assert_not_called() + + +def test_suggested_question_trace_no_data(trace_instance): + trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) + trace_info.message_data = None + trace_instance.add_run = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.add_run.assert_not_called() + + +def test_dataset_retrieval_trace_no_data(trace_instance): + trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_info.message_data = None + trace_instance.add_run = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.add_run.assert_not_called() + + +def test_moderation_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = ModerationTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs={"q": "hi"}, + action="stop", + flagged=True, + preset_response="blocked", + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + query="hi", + ) + + trace_instance.add_run = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.MODERATION_TRACE + + +def test_suggested_question_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = SuggestedQuestionTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs="hi", + suggested_question=["q1"], + total_tokens=10, + level="info", + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_run = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.SUGGESTED_QUESTION_TRACE + + +def test_dataset_retrieval_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = DatasetRetrievalTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs="query", + documents=[{"id": "doc1"}], + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_run = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.DATASET_RETRIEVAL_TRACE + + +def test_tool_trace(trace_instance): + trace_info = ToolTraceInfo( + message_id="msg-1", + message_data=MagicMock(), + inputs={}, + outputs={}, + tool_name="my_tool", + tool_inputs={"a": 1}, + tool_outputs="result", + time_cost=0.1, + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="trace-1", + tool_config={}, + tool_parameters={}, + file_url="http://file", + ) + + trace_instance.add_run = MagicMock() + trace_instance.tool_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == "my_tool" + + +def test_generate_name_trace(trace_instance): + trace_info = GenerateNameTraceInfo( + inputs={"q": "hi"}, + outputs={"name": "new"}, + tenant_id="tenant-1", + conversation_id="conv-1", + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_run = MagicMock() + trace_instance.generate_name_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.GENERATE_NAME_TRACE + + +def test_add_run_success(trace_instance): + run_data = LangSmithRunModel( + id="run-1", name="test", inputs={}, outputs={}, run_type=LangSmithRunType.tool, start_time=_dt() + ) + trace_instance.project_id = "proj-1" + trace_instance.add_run(run_data) + trace_instance.langsmith_client.create_run.assert_called_once() + args, kwargs = trace_instance.langsmith_client.create_run.call_args + assert kwargs["session_id"] == "proj-1" + + +def test_add_run_error(trace_instance): + run_data = LangSmithRunModel(id="run-1", name="test", run_type=LangSmithRunType.tool, start_time=_dt()) + trace_instance.langsmith_client.create_run.side_effect = Exception("failed") + with pytest.raises(ValueError, match="LangSmith Failed to create run: failed"): + trace_instance.add_run(run_data) + + +def test_update_run_success(trace_instance): + update_data = LangSmithRunUpdateModel(run_id="run-1", outputs={"out": "val"}) + trace_instance.update_run(update_data) + trace_instance.langsmith_client.update_run.assert_called_once() + + +def test_update_run_error(trace_instance): + update_data = LangSmithRunUpdateModel(run_id="run-1") + trace_instance.langsmith_client.update_run.side_effect = Exception("failed") + with pytest.raises(ValueError, match="LangSmith Failed to update run: failed"): + trace_instance.update_run(update_data) + + +def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, caplog): + workflow_data = MagicMock() + workflow_data.created_at = _dt() + workflow_data.finished_at = _dt() + timedelta(seconds=1) + + trace_info = WorkflowTraceInfo( + tenant_id="tenant-1", + workflow_id="wf-1", + workflow_run_id="run-1", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_status="succeeded", + workflow_run_version="1.0", + workflow_run_elapsed_time=1.0, + total_tokens=100, + file_list=[], + query="hi", + message_id="msg-1", + conversation_id="conv-1", + start_time=_dt(), + end_time=_dt(), + trace_id="trace-1", + metadata={"app_id": "app-1"}, + workflow_app_log_id="log-1", + error="", + workflow_data=workflow_data, + ) + + class BadDict(collections.UserDict): + def get(self, key, default=None): + if key == "usage": + raise Exception("Usage extraction failed") + return super().get(key, default) + + node_llm = MagicMock() + node_llm.id = "node-llm" + node_llm.title = "LLM Node" + node_llm.node_type = NodeType.LLM + node_llm.status = "succeeded" + node_llm.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]}) + node_llm.inputs = {} + node_llm.outputs = {} + node_llm.created_at = _dt() + node_llm.elapsed_time = 0.5 + node_llm.metadata = {} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node_llm] + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_run = MagicMock() + + import logging + + with caplog.at_level(logging.ERROR): + trace_instance.workflow_trace(trace_info) + + assert "Failed to extract usage" in caplog.text + + +def test_api_check_success(trace_instance): + assert trace_instance.api_check() is True + assert trace_instance.langsmith_client.create_project.called + assert trace_instance.langsmith_client.delete_project.called + + +def test_api_check_error(trace_instance): + trace_instance.langsmith_client.create_project.side_effect = Exception("error") + with pytest.raises(ValueError, match="LangSmith API check failed: error"): + trace_instance.api_check() + + +def test_get_project_url_success(trace_instance): + trace_instance.langsmith_client.get_run_url.return_value = "https://smith.langchain.com/o/org/p/proj/r/run" + url = trace_instance.get_project_url() + assert url == "https://smith.langchain.com/o/org/p/proj" + + +def test_get_project_url_error(trace_instance): + trace_instance.langsmith_client.get_run_url.side_effect = Exception("error") + with pytest.raises(ValueError, match="LangSmith get run url failed: error"): + trace_instance.get_project_url() diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py new file mode 100644 index 0000000000..0657acc1d9 --- /dev/null +++ b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py @@ -0,0 +1,1019 @@ +"""Comprehensive tests for core.ops.mlflow_trace.mlflow_trace module.""" + +from __future__ import annotations + +import json +import os +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds +from dify_graph.enums import NodeType + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo: + defaults = { + "workflow_id": "wf-id", + "tenant_id": "tenant", + "workflow_run_id": "run-1", + "workflow_run_elapsed_time": 1.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"key": "val"}, + "workflow_run_outputs": {"answer": "42"}, + "workflow_run_version": "v1", + "total_tokens": 10, + "file_list": [], + "query": "hello", + "metadata": {"user_id": "u1", "conversation_id": "c1"}, + "start_time": _dt(), + "end_time": _dt(), + } + defaults.update(overrides) + return WorkflowTraceInfo(**defaults) + + +def _make_message_trace_info(**overrides) -> MessageTraceInfo: + defaults = { + "conversation_model": "chat", + "message_tokens": 5, + "answer_tokens": 10, + "total_tokens": 15, + "conversation_mode": "chat", + "metadata": {"conversation_id": "c1", "from_account_id": "a1"}, + "message_id": "msg-1", + "message_data": SimpleNamespace( + model_provider="openai", + model_id="gpt-4", + total_price=0.01, + answer="response text", + ), + "inputs": {"prompt": "hi"}, + "outputs": "ok", + "start_time": _dt(), + "end_time": _dt(), + "error": None, + } + defaults.update(overrides) + return MessageTraceInfo(**defaults) + + +def _make_tool_trace_info(**overrides) -> ToolTraceInfo: + defaults = { + "tool_name": "my_tool", + "tool_inputs": {"x": 1}, + "tool_outputs": "output", + "tool_config": {"desc": "d"}, + "tool_parameters": {"p": "v"}, + "time_cost": 0.5, + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "inputs": {"i": "v"}, + "outputs": {"o": "v"}, + "start_time": _dt(), + "end_time": _dt(), + "error": None, + } + defaults.update(overrides) + return ToolTraceInfo(**defaults) + + +def _make_moderation_trace_info(**overrides) -> ModerationTraceInfo: + defaults = { + "flagged": False, + "action": "allow", + "preset_response": "", + "query": "test", + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + } + defaults.update(overrides) + return ModerationTraceInfo(**defaults) + + +def _make_dataset_retrieval_trace_info(**overrides) -> DatasetRetrievalTraceInfo: + defaults = { + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "message_data": SimpleNamespace(), + "inputs": "query", + "documents": [{"content": "doc"}], + "start_time": _dt(), + "end_time": _dt(), + } + defaults.update(overrides) + return DatasetRetrievalTraceInfo(**defaults) + + +def _make_suggested_question_trace_info(**overrides) -> SuggestedQuestionTraceInfo: + defaults = { + "suggested_question": ["q1", "q2"], + "level": "info", + "total_tokens": 5, + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "message_data": SimpleNamespace(created_at=_dt(), updated_at=_dt()), + "inputs": {"i": 1}, + "start_time": _dt(), + "end_time": _dt(), + "error": None, + } + defaults.update(overrides) + return SuggestedQuestionTraceInfo(**defaults) + + +def _make_generate_name_trace_info(**overrides) -> GenerateNameTraceInfo: + defaults = { + "tenant_id": "t1", + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "inputs": {"i": 1}, + "outputs": {"name": "test"}, + "start_time": _dt(), + "end_time": _dt(), + } + defaults.update(overrides) + return GenerateNameTraceInfo(**defaults) + + +def _make_node(**overrides): + """Create a mock workflow node execution row.""" + defaults = { + "id": "node-1", + "tenant_id": "t1", + "app_id": "app-1", + "title": "Node Title", + "node_type": NodeType.CODE, + "status": "succeeded", + "inputs": '{"key": "value"}', + "outputs": '{"result": "ok"}', + "created_at": _dt(), + "elapsed_time": 1.0, + "process_data": None, + "execution_metadata": None, + } + defaults.update(overrides) + return SimpleNamespace(**defaults) + + +# ── Fixtures ───────────────────────────────────────────────────────────────── + + +@pytest.fixture +def mock_mlflow(): + with patch("core.ops.mlflow_trace.mlflow_trace.mlflow") as mock: + yield mock + + +@pytest.fixture +def mock_tracing(): + """Patch all MLflow tracing functions used by the module.""" + with ( + patch("core.ops.mlflow_trace.mlflow_trace.start_span_no_context") as mock_start, + patch("core.ops.mlflow_trace.mlflow_trace.update_current_trace") as mock_update, + patch("core.ops.mlflow_trace.mlflow_trace.set_span_in_context") as mock_set, + patch("core.ops.mlflow_trace.mlflow_trace.detach_span_from_context") as mock_detach, + ): + yield { + "start": mock_start, + "update": mock_update, + "set": mock_set, + "detach": mock_detach, + } + + +@pytest.fixture +def mock_db(): + with patch("core.ops.mlflow_trace.mlflow_trace.db") as mock: + yield mock + + +@pytest.fixture +def trace_instance(mock_mlflow): + """Create an MLflowDataTrace using a basic MLflowConfig (no auth).""" + config = MLflowConfig(tracking_uri="http://localhost:5000", experiment_id="0") + return MLflowDataTrace(config) + + +# ── datetime_to_nanoseconds ───────────────────────────────────────────────── + + +class TestDatetimeToNanoseconds: + def test_none_returns_none(self): + assert datetime_to_nanoseconds(None) is None + + def test_converts_datetime(self): + dt = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + expected = int(dt.timestamp() * 1_000_000_000) + assert datetime_to_nanoseconds(dt) == expected + + +# ── __init__ / setup ───────────────────────────────────────────────────────── + + +class TestInit: + def test_mlflow_config_no_auth(self, mock_mlflow): + config = MLflowConfig(tracking_uri="http://localhost:5000", experiment_id="0") + trace = MLflowDataTrace(config) + mock_mlflow.set_tracking_uri.assert_called_with("http://localhost:5000") + mock_mlflow.set_experiment.assert_called_with(experiment_id="0") + assert trace.get_project_url() == "http://localhost:5000/#/experiments/0/traces" + assert os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] == "true" + + def test_mlflow_config_with_auth(self, mock_mlflow): + config = MLflowConfig( + tracking_uri="http://localhost:5000", + experiment_id="1", + username="user", + password="pass", + ) + MLflowDataTrace(config) + assert os.environ["MLFLOW_TRACKING_USERNAME"] == "user" + assert os.environ["MLFLOW_TRACKING_PASSWORD"] == "pass" + + def test_databricks_oauth(self, mock_mlflow): + config = DatabricksConfig( + host="https://db.com/", + experiment_id="42", + client_id="cid", + client_secret="csec", + ) + trace = MLflowDataTrace(config) + assert os.environ["DATABRICKS_HOST"] == "https://db.com/" + assert os.environ["DATABRICKS_CLIENT_ID"] == "cid" + assert os.environ["DATABRICKS_CLIENT_SECRET"] == "csec" + mock_mlflow.set_tracking_uri.assert_called_with("databricks") + # Trailing slash stripped + assert trace.get_project_url() == "https://db.com/ml/experiments/42/traces" + + def test_databricks_pat(self, mock_mlflow): + config = DatabricksConfig( + host="https://db.com", + experiment_id="1", + personal_access_token="pat", + ) + trace = MLflowDataTrace(config) + assert os.environ["DATABRICKS_TOKEN"] == "pat" + assert "db.com/ml/experiments/1/traces" in trace.get_project_url() + + def test_databricks_no_creds_raises(self, mock_mlflow): + config = DatabricksConfig(host="https://db.com", experiment_id="1") + with pytest.raises(ValueError, match="Either Databricks token"): + MLflowDataTrace(config) + + +# ── trace dispatcher ──────────────────────────────────────────────────────── + + +class TestTraceDispatcher: + def test_dispatches_workflow(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "workflow_trace") as mock_wt: + trace_instance.trace(_make_workflow_trace_info()) + mock_wt.assert_called_once() + + def test_dispatches_message(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "message_trace") as mock_mt: + trace_instance.trace(_make_message_trace_info()) + mock_mt.assert_called_once() + + def test_dispatches_tool(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "tool_trace") as mock_tt: + trace_instance.trace(_make_tool_trace_info()) + mock_tt.assert_called_once() + + def test_dispatches_moderation(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "moderation_trace") as mock_mod: + trace_instance.trace(_make_moderation_trace_info(message_data=SimpleNamespace(created_at=_dt()))) + mock_mod.assert_called_once() + + def test_dispatches_dataset_retrieval(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "dataset_retrieval_trace") as mock_dr: + trace_instance.trace(_make_dataset_retrieval_trace_info()) + mock_dr.assert_called_once() + + def test_dispatches_suggested_question(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "suggested_question_trace") as mock_sq: + trace_instance.trace(_make_suggested_question_trace_info()) + mock_sq.assert_called_once() + + def test_dispatches_generate_name(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "generate_name_trace") as mock_gn: + trace_instance.trace(_make_generate_name_trace_info()) + mock_gn.assert_called_once() + + def test_reraises_exception(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "workflow_trace", side_effect=RuntimeError("boom")): + with pytest.raises(RuntimeError, match="boom"): + trace_instance.trace(_make_workflow_trace_info()) + + +# ── workflow_trace ─────────────────────────────────────────────────────────── + + +class TestWorkflowTrace: + def test_basic_workflow_no_nodes(self, trace_instance, mock_tracing, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info(conversation_id="sess-1") + trace_instance.workflow_trace(trace_info) + + # Workflow span started and ended + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + def test_workflow_filters_sys_inputs_and_adds_query(self, trace_instance, mock_tracing, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info( + workflow_run_inputs={"sys.app_id": "x", "user_input": "hi"}, + query="hello", + ) + trace_instance.workflow_trace(trace_info) + + call_kwargs = mock_tracing["start"].call_args + inputs = call_kwargs.kwargs["inputs"] + assert "sys.app_id" not in inputs + assert inputs["user_input"] == "hi" + assert inputs["query"] == "hello" + + def test_workflow_with_llm_node(self, trace_instance, mock_tracing, mock_db): + llm_node = _make_node( + node_type=NodeType.LLM, + process_data=json.dumps( + { + "prompts": [{"role": "user", "text": "hi"}], + "model_name": "gpt-4", + "model_provider": "openai", + "finish_reason": "stop", + "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}, + } + ), + outputs='{"text": "hello world"}', + ) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [llm_node] + + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + assert mock_tracing["start"].call_count == 2 + node_span.end.assert_called_once() + workflow_span.end.assert_called_once() + + def test_workflow_with_question_classifier_node(self, trace_instance, mock_tracing, mock_db): + qc_node = _make_node( + node_type=NodeType.QUESTION_CLASSIFIER, + process_data=json.dumps( + { + "prompts": "classify this", + "model_name": "gpt-4", + "model_provider": "openai", + } + ), + ) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [qc_node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + assert mock_tracing["start"].call_count == 2 + + def test_workflow_with_http_request_node(self, trace_instance, mock_tracing, mock_db): + http_node = _make_node( + node_type=NodeType.HTTP_REQUEST, + process_data='{"url": "https://api.com"}', + ) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [http_node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + # HTTP_REQUEST uses process_data as inputs + node_start_call = mock_tracing["start"].call_args_list[1] + assert node_start_call.kwargs["inputs"] == '{"url": "https://api.com"}' + + def test_workflow_with_knowledge_retrieval_node(self, trace_instance, mock_tracing, mock_db): + kr_node = _make_node( + node_type=NodeType.KNOWLEDGE_RETRIEVAL, + outputs=json.dumps( + { + "result": [ + {"content": "doc1", "metadata": {"source": "s1"}}, + {"content": "doc2", "metadata": {}}, + ] + } + ), + ) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [kr_node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + # outputs should be parsed to Document objects + end_call = node_span.end.call_args + outputs = end_call.kwargs["outputs"] + assert len(outputs) == 2 + + def test_workflow_with_failed_node(self, trace_instance, mock_tracing, mock_db): + failed_node = _make_node(status="failed") + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [failed_node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + node_span.set_status.assert_called_once() + node_span.add_event.assert_called_once() + + def test_workflow_with_workflow_error(self, trace_instance, mock_tracing, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + workflow_span = MagicMock() + mock_tracing["start"].return_value = workflow_span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info(error="workflow failed") + trace_instance.workflow_trace(trace_info) + workflow_span.set_status.assert_called_once() + workflow_span.add_event.assert_called_once() + # Still ends the span via finally + workflow_span.end.assert_called_once() + + def test_workflow_node_no_inputs_no_outputs(self, trace_instance, mock_tracing, mock_db): + node = _make_node(inputs=None, outputs=None) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + node_call = mock_tracing["start"].call_args_list[1] + assert node_call.kwargs["inputs"] == {} + end_call = node_span.end.call_args + assert end_call.kwargs["outputs"] == {} + + def test_workflow_no_user_id_no_conversation_id(self, trace_instance, mock_tracing, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info( + metadata={}, + conversation_id=None, + ) + trace_instance.workflow_trace(trace_info) + # _set_trace_metadata still called with empty metadata + mock_tracing["update"].assert_called_once() + + def test_workflow_empty_query(self, trace_instance, mock_tracing, mock_db): + """When query is empty string, it's falsy so no query key added.""" + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info(query="") + trace_instance.workflow_trace(trace_info) + call_kwargs = mock_tracing["start"].call_args + inputs = call_kwargs.kwargs["inputs"] + assert "query" not in inputs + + +# ── _parse_llm_inputs_and_attributes ───────────────────────────────────────── + + +class TestParseLlmInputsAndAttributes: + def test_none_process_data(self, trace_instance): + node = _make_node(process_data=None) + inputs, attrs = trace_instance._parse_llm_inputs_and_attributes(node) + assert inputs == {} + assert attrs == {} + + def test_invalid_json(self, trace_instance): + node = _make_node(process_data="not json") + inputs, attrs = trace_instance._parse_llm_inputs_and_attributes(node) + assert inputs == {} + assert attrs == {} + + def test_valid_process_data_with_usage(self, trace_instance): + node = _make_node( + process_data=json.dumps( + { + "prompts": [{"role": "user", "text": "hi"}], + "model_name": "gpt-4", + "model_provider": "openai", + "finish_reason": "stop", + "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}, + } + ) + ) + inputs, attrs = trace_instance._parse_llm_inputs_and_attributes(node) + assert isinstance(inputs, list) + assert attrs["model_name"] == "gpt-4" + assert "usage" in attrs + + def test_valid_process_data_without_usage(self, trace_instance): + node = _make_node( + process_data=json.dumps( + { + "prompts": "simple prompt", + "model_name": "gpt-3.5", + } + ) + ) + inputs, attrs = trace_instance._parse_llm_inputs_and_attributes(node) + assert inputs == "simple prompt" + assert attrs["model_name"] == "gpt-3.5" + + +# ── _parse_knowledge_retrieval_outputs ─────────────────────────────────────── + + +class TestParseKnowledgeRetrievalOutputs: + def test_with_results(self, trace_instance): + outputs = {"result": [{"content": "c1", "metadata": {"s": "1"}}]} + docs = trace_instance._parse_knowledge_retrieval_outputs(outputs) + assert len(docs) == 1 + assert docs[0].page_content == "c1" + + def test_empty_result(self, trace_instance): + outputs = {"result": []} + result = trace_instance._parse_knowledge_retrieval_outputs(outputs) + assert result == outputs + + def test_no_result_key(self, trace_instance): + outputs = {"other": "data"} + result = trace_instance._parse_knowledge_retrieval_outputs(outputs) + assert result == outputs + + def test_result_not_list(self, trace_instance): + outputs = {"result": "not a list"} + result = trace_instance._parse_knowledge_retrieval_outputs(outputs) + assert result == outputs + + +# ── message_trace ──────────────────────────────────────────────────────────── + + +class TestMessageTrace: + def test_returns_early_if_no_message_data(self, trace_instance, mock_tracing, mock_db): + trace_info = _make_message_trace_info(message_data=None) + trace_instance.message_trace(trace_info) + mock_tracing["start"].assert_not_called() + + def test_basic_message_trace(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = None + + trace_instance.message_trace(_make_message_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + def test_message_trace_with_error(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = None + + trace_info = _make_message_trace_info(error="something broke") + trace_instance.message_trace(trace_info) + span.set_status.assert_called_once() + span.add_event.assert_called_once() + + def test_message_trace_with_file_data(self, trace_instance, mock_tracing, mock_db, monkeypatch): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setenv("FILES_URL", "http://files.test") + + file_data = SimpleNamespace(url="path/to/file.png") + trace_info = _make_message_trace_info( + message_file_data=file_data, + file_list=["existing_file.txt"], + ) + trace_instance.message_trace(trace_info) + call_kwargs = mock_tracing["start"].call_args + attrs = call_kwargs.kwargs["attributes"] + assert "http://files.test/path/to/file.png" in attrs["file_list"] + assert "existing_file.txt" in attrs["file_list"] + + def test_message_trace_file_list_none(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = None + + trace_info = _make_message_trace_info(file_list=None, message_file_data=None) + trace_instance.message_trace(trace_info) + mock_tracing["start"].assert_called_once() + + def test_message_trace_with_end_user(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + end_user = MagicMock() + end_user.session_id = "session-xyz" + mock_db.session.query.return_value.where.return_value.first.return_value = end_user + + trace_info = _make_message_trace_info( + metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"}, + ) + trace_instance.message_trace(trace_info) + # update_current_trace called with user id from EndUser + mock_tracing["update"].assert_called_once() + + def test_message_trace_with_no_conversation_id(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = None + + trace_info = _make_message_trace_info( + metadata={"from_account_id": "acc-1"}, + ) + trace_instance.message_trace(trace_info) + mock_tracing["update"].assert_called_once() + + +# ── _get_message_user_id ───────────────────────────────────────────────────── + + +class TestGetMessageUserId: + def test_returns_end_user_session_id(self, trace_instance, mock_db): + end_user = MagicMock() + end_user.session_id = "session-1" + mock_db.session.query.return_value.where.return_value.first.return_value = end_user + result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1"}) + assert result == "session-1" + + def test_returns_account_id_when_no_end_user(self, trace_instance, mock_db): + mock_db.session.query.return_value.where.return_value.first.return_value = None + result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1", "from_account_id": "acc-1"}) + assert result == "acc-1" + + def test_returns_account_id_when_no_end_user_id(self, trace_instance, mock_db): + result = trace_instance._get_message_user_id({"from_account_id": "acc-1"}) + assert result == "acc-1" + + def test_returns_none_when_nothing(self, trace_instance, mock_db): + result = trace_instance._get_message_user_id({}) + assert result is None + + +# ── tool_trace ─────────────────────────────────────────────────────────────── + + +class TestToolTrace: + def test_basic_tool_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.tool_trace(_make_tool_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + span.set_status.assert_not_called() + + def test_tool_trace_with_error(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.tool_trace(_make_tool_trace_info(error="tool failed")) + span.set_status.assert_called_once() + span.add_event.assert_called_once() + span.end.assert_called_once() + + +# ── moderation_trace ───────────────────────────────────────────────────────── + + +class TestModerationTrace: + def test_returns_early_if_no_message_data(self, trace_instance, mock_tracing): + trace_info = _make_moderation_trace_info(message_data=None) + trace_instance.moderation_trace(trace_info) + mock_tracing["start"].assert_not_called() + + def test_basic_moderation_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_info = _make_moderation_trace_info( + message_data=SimpleNamespace(created_at=_dt()), + start_time=_dt(), + end_time=_dt(), + ) + trace_instance.moderation_trace(trace_info) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + end_kwargs = span.end.call_args.kwargs["outputs"] + assert end_kwargs["action"] == "allow" + assert end_kwargs["flagged"] is False + + def test_moderation_uses_message_data_created_at_if_no_start_time(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_info = _make_moderation_trace_info( + message_data=SimpleNamespace(created_at=_dt()), + start_time=None, + end_time=_dt(), + ) + trace_instance.moderation_trace(trace_info) + mock_tracing["start"].assert_called_once() + + +# ── dataset_retrieval_trace ────────────────────────────────────────────────── + + +class TestDatasetRetrievalTrace: + def test_returns_early_if_no_message_data(self, trace_instance, mock_tracing): + trace_info = _make_dataset_retrieval_trace_info(message_data=None) + trace_instance.dataset_retrieval_trace(trace_info) + mock_tracing["start"].assert_not_called() + + def test_basic_dataset_retrieval_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + +# ── suggested_question_trace ───────────────────────────────────────────────── + + +class TestSuggestedQuestionTrace: + def test_returns_early_if_no_message_data(self, trace_instance, mock_tracing): + trace_info = _make_suggested_question_trace_info(message_data=None) + trace_instance.suggested_question_trace(trace_info) + mock_tracing["start"].assert_not_called() + + def test_basic_suggested_question_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.suggested_question_trace(_make_suggested_question_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + def test_suggested_question_with_error(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_info = _make_suggested_question_trace_info(error="failed") + trace_instance.suggested_question_trace(trace_info) + span.set_status.assert_called_once() + span.add_event.assert_called_once() + + def test_uses_message_data_times_when_no_start_end(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_info = _make_suggested_question_trace_info( + start_time=None, + end_time=None, + ) + trace_instance.suggested_question_trace(trace_info) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + +# ── generate_name_trace ────────────────────────────────────────────────────── + + +class TestGenerateNameTrace: + def test_basic_generate_name_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.generate_name_trace(_make_generate_name_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + +# ── _get_workflow_nodes ────────────────────────────────────────────────────── + + +class TestGetWorkflowNodes: + def test_queries_db(self, trace_instance, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = ["n1", "n2"] + result = trace_instance._get_workflow_nodes("run-1") + assert result == ["n1", "n2"] + + +# ── _get_node_span_type ───────────────────────────────────────────────────── + + +class TestGetNodeSpanType: + @pytest.mark.parametrize( + ("node_type", "expected_contains"), + [ + (NodeType.LLM, "LLM"), + (NodeType.QUESTION_CLASSIFIER, "LLM"), + (NodeType.KNOWLEDGE_RETRIEVAL, "RETRIEVER"), + (NodeType.TOOL, "TOOL"), + (NodeType.CODE, "TOOL"), + (NodeType.HTTP_REQUEST, "TOOL"), + (NodeType.AGENT, "AGENT"), + ], + ) + def test_mapped_types(self, trace_instance, node_type, expected_contains): + result = trace_instance._get_node_span_type(node_type) + assert expected_contains in str(result) + + def test_unknown_type_returns_chain(self, trace_instance): + result = trace_instance._get_node_span_type("unknown_node") + assert result == "CHAIN" + + +# ── _set_trace_metadata ───────────────────────────────────────────────────── + + +class TestSetTraceMetadata: + def test_sets_and_detaches(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["set"].return_value = "token" + + trace_instance._set_trace_metadata(span, {"key": "val"}) + mock_tracing["set"].assert_called_once_with(span) + mock_tracing["update"].assert_called_once_with(metadata={"key": "val"}) + mock_tracing["detach"].assert_called_once_with("token") + + def test_detaches_even_on_error(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["set"].return_value = "token" + mock_tracing["update"].side_effect = RuntimeError("fail") + + with pytest.raises(RuntimeError): + trace_instance._set_trace_metadata(span, {}) + mock_tracing["detach"].assert_called_once_with("token") + + def test_no_detach_when_token_is_none(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["set"].return_value = None + + trace_instance._set_trace_metadata(span, {}) + mock_tracing["detach"].assert_not_called() + + +# ── _parse_prompts ─────────────────────────────────────────────────────────── + + +class TestParsePrompts: + def test_string_input(self, trace_instance): + assert trace_instance._parse_prompts("hello") == "hello" + + def test_dict_input(self, trace_instance): + result = trace_instance._parse_prompts({"role": "user", "text": "hi"}) + assert result == {"role": "user", "content": "hi"} + + def test_list_input(self, trace_instance): + prompts = [ + {"role": "user", "text": "hi"}, + {"role": "assistant", "text": "hello"}, + ] + result = trace_instance._parse_prompts(prompts) + assert len(result) == 2 + assert result[0]["role"] == "user" + + def test_none_input(self, trace_instance): + assert trace_instance._parse_prompts(None) is None + + def test_int_passthrough(self, trace_instance): + assert trace_instance._parse_prompts(42) == 42 + + +# ── _parse_single_message ─────────────────────────────────────────────────── + + +class TestParseSingleMessage: + def test_basic_message(self, trace_instance): + result = trace_instance._parse_single_message({"role": "user", "text": "hello"}) + assert result == {"role": "user", "content": "hello"} + + def test_default_role(self, trace_instance): + result = trace_instance._parse_single_message({"text": "hello"}) + assert result["role"] == "user" + + def test_with_tool_calls(self, trace_instance): + item = { + "role": "assistant", + "text": "", + "tool_calls": [{"id": "tc1", "function": {"name": "fn"}}], + } + result = trace_instance._parse_single_message(item) + assert "tool_calls" in result + + def test_tool_role_ignores_tool_calls(self, trace_instance): + item = { + "role": "tool", + "text": "result", + "tool_calls": [{"id": "tc1"}], + } + result = trace_instance._parse_single_message(item) + assert "tool_calls" not in result + + def test_with_files(self, trace_instance): + item = {"role": "user", "text": "look", "files": ["f1.png"]} + result = trace_instance._parse_single_message(item) + assert result["files"] == ["f1.png"] + + def test_no_files(self, trace_instance): + result = trace_instance._parse_single_message({"role": "user", "text": "hi"}) + assert "files" not in result + + +# ── _resolve_tool_call_ids ─────────────────────────────────────────────────── + + +class TestResolveToolCallIds: + def test_resolves_tool_call_ids(self, trace_instance): + messages = [ + { + "role": "assistant", + "content": "", + "tool_calls": [{"id": "tc1"}, {"id": "tc2"}], + }, + {"role": "tool", "content": "result1"}, + {"role": "tool", "content": "result2"}, + ] + result = trace_instance._resolve_tool_call_ids(messages) + assert result[1]["tool_call_id"] == "tc1" + assert result[2]["tool_call_id"] == "tc2" + + def test_no_tool_calls(self, trace_instance): + messages = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + result = trace_instance._resolve_tool_call_ids(messages) + assert "tool_call_id" not in result[0] + assert "tool_call_id" not in result[1] + + def test_tool_message_no_ids_available(self, trace_instance): + """Tool message with no preceding tool_calls should not crash.""" + messages = [ + {"role": "tool", "content": "result"}, + ] + result = trace_instance._resolve_tool_call_ids(messages) + assert "tool_call_id" not in result[0] + + +# ── api_check ──────────────────────────────────────────────────────────────── + + +class TestApiCheck: + def test_success(self, trace_instance, mock_mlflow): + mock_mlflow.search_experiments.return_value = [] + assert trace_instance.api_check() is True + + def test_failure(self, trace_instance, mock_mlflow): + mock_mlflow.search_experiments.side_effect = ConnectionError("refused") + with pytest.raises(ValueError, match="MLflow connection failed"): + trace_instance.api_check() + + +# ── get_project_url ────────────────────────────────────────────────────────── + + +class TestGetProjectUrl: + def test_returns_url(self, trace_instance): + assert "experiments" in trace_instance.get_project_url() diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py new file mode 100644 index 0000000000..80a0331c4b --- /dev/null +++ b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -0,0 +1,678 @@ +import collections +import logging +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.ops.entities.config_entity import OpikConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + TraceTaskName, + WorkflowTraceInfo, +) +from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey +from models import EndUser +from models.enums import MessageStatus + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +@pytest.fixture +def opik_config(): + return OpikConfig( + project="test-project", workspace="test-workspace", url="https://cloud.opik.com/api/", api_key="api-key-123" + ) + + +@pytest.fixture +def trace_instance(opik_config, monkeypatch): + mock_client = MagicMock() + monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client) + + instance = OpikDataTrace(opik_config) + return instance + + +def test_wrap_dict(): + assert wrap_dict("input", {"a": 1}) == {"a": 1} + assert wrap_dict("input", "hello") == {"input": "hello"} + + +def test_wrap_metadata(): + assert wrap_metadata({"a": 1}, b=2) == {"a": 1, "b": 2, "created_from": "dify"} + + +def test_prepare_opik_uuid(): + # Test with valid datetime and uuid string + dt = datetime(2024, 1, 1) + uuid_str = "b3e8e918-472e-4b69-8051-12502c34fc07" + result = prepare_opik_uuid(dt, uuid_str) + assert result is not None + # We won't test the exact uuid7 value but just that it returns a string id + + # Test with None dt and uuid_str + result = prepare_opik_uuid(None, None) + assert result is not None + + +def test_init(opik_config, monkeypatch): + mock_opik = MagicMock() + monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", mock_opik) + monkeypatch.setenv("FILES_URL", "http://test.url") + + instance = OpikDataTrace(opik_config) + + mock_opik.assert_called_once_with( + project_name=opik_config.project, + workspace=opik_config.workspace, + host=opik_config.url, + api_key=opik_config.api_key, + ) + assert instance.file_base_url == "http://test.url" + assert instance.project == opik_config.project + + +def test_trace_dispatch(trace_instance, monkeypatch): + methods = [ + "workflow_trace", + "message_trace", + "moderation_trace", + "suggested_question_trace", + "dataset_retrieval_trace", + "tool_trace", + "generate_name_trace", + ] + mocks = {method: MagicMock() for method in methods} + for method, m in mocks.items(): + monkeypatch.setattr(trace_instance, method, m) + + # WorkflowTraceInfo + info = MagicMock(spec=WorkflowTraceInfo) + trace_instance.trace(info) + mocks["workflow_trace"].assert_called_once_with(info) + + # MessageTraceInfo + info = MagicMock(spec=MessageTraceInfo) + trace_instance.trace(info) + mocks["message_trace"].assert_called_once_with(info) + + # ModerationTraceInfo + info = MagicMock(spec=ModerationTraceInfo) + trace_instance.trace(info) + mocks["moderation_trace"].assert_called_once_with(info) + + # SuggestedQuestionTraceInfo + info = MagicMock(spec=SuggestedQuestionTraceInfo) + trace_instance.trace(info) + mocks["suggested_question_trace"].assert_called_once_with(info) + + # DatasetRetrievalTraceInfo + info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_instance.trace(info) + mocks["dataset_retrieval_trace"].assert_called_once_with(info) + + # ToolTraceInfo + info = MagicMock(spec=ToolTraceInfo) + trace_instance.trace(info) + mocks["tool_trace"].assert_called_once_with(info) + + # GenerateNameTraceInfo + info = MagicMock(spec=GenerateNameTraceInfo) + trace_instance.trace(info) + mocks["generate_name_trace"].assert_called_once_with(info) + + +def test_workflow_trace_with_message_id(trace_instance, monkeypatch): + # Define constants for better readability + WORKFLOW_ID = "fb05c7cd-6cec-4add-8a84-df03a408b4ce" + WORKFLOW_RUN_ID = "33c67568-7a8a-450e-8916-a5f135baeaef" + MESSAGE_ID = "04ec3956-85f3-488a-8539-1017251dc8c6" + CONVERSATION_ID = "d3d01066-23ae-4830-9ce4-eb5640b42a7e" + TRACE_ID = "bf26d929-6f15-4c2f-9abc-761c217056f3" + WORKFLOW_APP_LOG_ID = "ca0e018e-edd4-43fb-a05a-ea001ca8ef4b" + LLM_NODE_ID = "80d7dfa8-08f4-4ab7-aa37-0ca7d27207e3" + CODE_NODE_ID = "b9cd9a7b-c534-4aa9-b5da-efd454140900" + + trace_info = WorkflowTraceInfo( + workflow_id=WORKFLOW_ID, + tenant_id="tenant-1", + workflow_run_id=WORKFLOW_RUN_ID, + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={"input": "hi"}, + workflow_run_outputs={"output": "hello"}, + workflow_run_version="1.0", + message_id=MESSAGE_ID, + conversation_id=CONVERSATION_ID, + total_tokens=100, + file_list=[], + query="hi", + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id=TRACE_ID, + metadata={"app_id": "app-1", "user_id": "user-1"}, + workflow_app_log_id=WORKFLOW_APP_LOG_ID, + error="", + ) + + mock_session = MagicMock() + monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + + node_llm = MagicMock() + node_llm.id = LLM_NODE_ID + node_llm.title = "LLM Node" + node_llm.node_type = NodeType.LLM + node_llm.status = "succeeded" + node_llm.process_data = { + "model_mode": "chat", + "model_name": "gpt-4", + "model_provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + } + node_llm.inputs = {"prompts": "p"} + node_llm.outputs = {"text": "t"} + node_llm.created_at = _dt() + node_llm.elapsed_time = 0.5 + node_llm.metadata = {"foo": "bar"} + + node_other = MagicMock() + node_other.id = CODE_NODE_ID + node_other.title = "Other Node" + node_other.node_type = NodeType.CODE + node_other.status = "failed" + node_other.process_data = None + node_other.inputs = {"code": "print"} + node_other.outputs = {"result": "ok"} + node_other.created_at = None + node_other.elapsed_time = 0.2 + node_other.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS.value: 10} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node_llm, node_other] + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.add_span = MagicMock() + + trace_instance.workflow_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_data = trace_instance.add_trace.call_args[1].get("opik_trace_data", trace_instance.add_trace.call_args[0][0]) + assert trace_data["name"] == TraceTaskName.MESSAGE_TRACE + assert "message" in trace_data["tags"] + assert "workflow" in trace_data["tags"] + + assert trace_instance.add_span.call_count >= 1 + + +def test_workflow_trace_no_message_id(trace_instance, monkeypatch): + # Define constants for better readability + WORKFLOW_ID = "f0708b36-b1d7-42b3-a876-1d01b7d8f1a3" + WORKFLOW_RUN_ID = "d42ec285-c2fd-4248-8866-5c9386b101ac" + CONVERSATION_ID = "88a17f2e-9436-4472-bab9-4b1601d5af3c" + WORKFLOW_APP_LOG_ID = "41780d0d-ffba-4220-bc0c-401e4c89cdfb" + + trace_info = WorkflowTraceInfo( + workflow_id=WORKFLOW_ID, + tenant_id="tenant-1", + workflow_run_id=WORKFLOW_RUN_ID, + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1.0", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id=CONVERSATION_ID, + start_time=_dt(), + end_time=_dt(), + trace_id=None, + metadata={"app_id": "app-1"}, + workflow_app_log_id=WORKFLOW_APP_LOG_ID, + error="", + ) + + monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + repo = MagicMock() + repo.get_by_workflow_run.return_value = [] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.workflow_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + + +def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): + trace_info = WorkflowTraceInfo( + workflow_id="5745f1b8-f8e6-4859-8110-996acb6c8d6a", + tenant_id="tenant-1", + workflow_run_id="46f53304-1659-464b-bee5-116585f0bec8", + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1.0", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="83f86b89-caef-4de8-a0f9-f164eddae1ea", + start_time=_dt(), + end_time=_dt(), + metadata={}, + workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e", + error="", + ) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(trace_info) + + +def test_message_trace_basic(trace_instance, monkeypatch): + # Define constants for better readability + MESSAGE_DATA_ID = "e3a26712-8cac-4a25-94a4-a3bff21ee3ab" + CONVERSATION_ID = "9d3f3751-7521-4c19-9307-20e3cf6789a3" + MESSAGE_TRACE_ID = "710ace2f-bca8-41be-858c-54da42742a77" + OPIT_TRACE_ID = "f7dfd978-0d10-4549-8abf-00f2cbc49d2c" + + message_data = MagicMock() + message_data.id = MESSAGE_DATA_ID + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = None + message_data.provider_response_latency = 0.5 + message_data.conversation_id = CONVERSATION_ID + message_data.total_price = 0.01 + message_data.model_id = "gpt-4" + message_data.answer = "hello" + message_data.status = MessageStatus.NORMAL + message_data.error = None + + trace_info = MessageTraceInfo( + message_id=MESSAGE_TRACE_ID, + message_data=message_data, + inputs={"query": "hi"}, + outputs={"answer": "hello"}, + message_tokens=10, + answer_tokens=20, + total_tokens=30, + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id=OPIT_TRACE_ID, + metadata={"foo": "bar"}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=[], + error=None, + message_file_data=MagicMock(url="test.png"), + ) + + trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_1")) + trace_instance.add_span = MagicMock() + + trace_instance.message_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_instance.add_span.assert_called_once() + + +def test_message_trace_with_end_user(trace_instance, monkeypatch): + message_data = MagicMock() + message_data.id = "85411059-79fb-4deb-a76c-c2e215f1b97e" + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = "end-user-1" + message_data.conversation_id = "7d9f96d8-3be2-4e93-9c0e-922ff98dccc6" + message_data.status = MessageStatus.NORMAL + message_data.model_id = "gpt-4" + message_data.error = "" + message_data.answer = "hello" + message_data.total_price = 0.0 + message_data.provider_response_latency = 0.1 + + trace_info = MessageTraceInfo( + message_id="6bff35c7-33b7-4acb-ba21-44569a0327d0", + message_data=message_data, + inputs={}, + outputs={}, + message_tokens=0, + answer_tokens=0, + total_tokens=0, + start_time=_dt(), + end_time=_dt(), + metadata={}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=["url1"], + error=None, + ) + + mock_end_user = MagicMock(spec=EndUser) + mock_end_user.session_id = "session-id-123" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = mock_end_user + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.query", lambda model: mock_query) + + trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2")) + trace_instance.add_span = MagicMock() + + trace_instance.message_trace(trace_info) + + trace_data = trace_instance.add_trace.call_args[0][0] + assert trace_data["metadata"]["user_id"] == "acc-1" + assert trace_data["metadata"]["end_user_id"] == "session-id-123" + + +def test_message_trace_none_data(trace_instance): + trace_info = SimpleNamespace(message_data=None, file_list=[], message_file_data=None, metadata={}) + trace_instance.add_trace = MagicMock() + trace_instance.message_trace(trace_info) + trace_instance.add_trace.assert_not_called() + + +def test_moderation_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = ModerationTraceInfo( + message_id="489d0dfd-065c-4106-8f9c-daded296c92d", + message_data=message_data, + inputs={"q": "hi"}, + action="stop", + flagged=True, + preset_response="blocked", + start_time=None, + end_time=None, + metadata={"foo": "bar"}, + trace_id="6f16cf18-9f4b-4955-8b6b-43cfa10978fc", + query="hi", + ) + + trace_instance.add_span = MagicMock() + trace_instance.moderation_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["name"] == TraceTaskName.MODERATION_TRACE + assert span_data["output"]["flagged"] is True + + +def test_moderation_trace_none(trace_instance): + trace_info = ModerationTraceInfo( + message_id="cd732e4e-37f1-4c7e-8c64-820308bedcbf", + message_data=None, + inputs={}, + action="s", + flagged=False, + preset_response="", + query="", + metadata={}, + ) + trace_instance.add_span = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_suggested_question_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = SuggestedQuestionTraceInfo( + message_id="7de55bda-a91d-477e-98ab-85c53c438469", + message_data=message_data, + inputs="hi", + suggested_question=["q1"], + total_tokens=10, + level="info", + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="a6687292-68c7-42ba-ae51-285579944d7b", + ) + + trace_instance.add_span = MagicMock() + trace_instance.suggested_question_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["name"] == TraceTaskName.SUGGESTED_QUESTION_TRACE + + +def test_suggested_question_trace_none(trace_instance): + trace_info = SuggestedQuestionTraceInfo( + message_id="23696fc5-7e7f-46ec-bce8-1adc3c7f297d", + message_data=None, + inputs={}, + suggested_question=[], + total_tokens=0, + level="i", + metadata={}, + ) + trace_instance.add_span = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_dataset_retrieval_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = DatasetRetrievalTraceInfo( + message_id="3e1a819f-c391-4950-adfd-96f82e5419a1", + message_data=message_data, + inputs="query", + documents=[{"id": "doc1"}], + start_time=None, + end_time=None, + metadata={}, + trace_id="41361000-e9be-4d11-b5e4-ab27ce0817d6", + ) + + trace_instance.add_span = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["name"] == TraceTaskName.DATASET_RETRIEVAL_TRACE + + +def test_dataset_retrieval_trace_none(trace_instance): + trace_info = DatasetRetrievalTraceInfo( + message_id="35d6d44c-bccb-4e6e-8bd8-859257723ea8", message_data=None, inputs={}, documents=[], metadata={} + ) + trace_instance.add_span = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_tool_trace(trace_instance): + trace_info = ToolTraceInfo( + message_id="99db92c4-2254-496a-b5cc-18153315ce35", + message_data=MagicMock(), + inputs={}, + outputs={}, + tool_name="my_tool", + tool_inputs={"a": 1}, + tool_outputs="result_string", + time_cost=0.1, + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="a15a5fcb-7ffd-4458-8330-208f4cb1f796", + tool_config={}, + tool_parameters={}, + error="some error", + ) + + trace_instance.add_span = MagicMock() + trace_instance.tool_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["name"] == "my_tool" + + +def test_generate_name_trace(trace_instance): + trace_info = GenerateNameTraceInfo( + inputs={"q": "hi"}, + outputs={"name": "new"}, + tenant_id="tenant-1", + conversation_id="271fe28f-6b86-416b-8d6b-bbbbfa9db791", + start_time=_dt(), + end_time=_dt(), + metadata={"921f010e-6878-4831-ae6b-271bf68c56fb": 1}, + ) + + trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_3")) + trace_instance.add_span = MagicMock() + + trace_instance.generate_name_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_instance.add_span.assert_called_once() + + trace_data = trace_instance.add_trace.call_args[0][0] + assert trace_data["name"] == TraceTaskName.GENERATE_NAME_TRACE + + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["trace_id"] == "trace_id_3" + + +def test_add_trace_success(trace_instance): + trace_data = {"id": "t1", "name": "trace"} + trace_instance.opik_client.trace.return_value = MagicMock(id="t1") + trace = trace_instance.add_trace(trace_data) + trace_instance.opik_client.trace.assert_called_once() + assert trace.id == "t1" + + +def test_add_trace_error(trace_instance): + trace_instance.opik_client.trace.side_effect = Exception("error") + trace_data = {"id": "t1", "name": "trace"} + with pytest.raises(ValueError, match="Opik Failed to create trace: error"): + trace_instance.add_trace(trace_data) + + +def test_add_span_success(trace_instance): + span_data = {"id": "s1", "name": "span", "trace_id": "t1"} + trace_instance.add_span(span_data) + trace_instance.opik_client.span.assert_called_once() + + +def test_add_span_error(trace_instance): + trace_instance.opik_client.span.side_effect = Exception("error") + span_data = {"id": "s1", "name": "span", "trace_id": "t1"} + with pytest.raises(ValueError, match="Opik Failed to create span: error"): + trace_instance.add_span(span_data) + + +def test_api_check_success(trace_instance): + trace_instance.opik_client.auth_check.return_value = True + assert trace_instance.api_check() is True + + +def test_api_check_error(trace_instance): + trace_instance.opik_client.auth_check.side_effect = Exception("fail") + with pytest.raises(ValueError, match="Opik API check failed: fail"): + trace_instance.api_check() + + +def test_get_project_url_success(trace_instance): + trace_instance.opik_client.get_project_url.return_value = "http://project.url" + assert trace_instance.get_project_url() == "http://project.url" + trace_instance.opik_client.get_project_url.assert_called_once_with(project_name=trace_instance.project) + + +def test_get_project_url_error(trace_instance): + trace_instance.opik_client.get_project_url.side_effect = Exception("fail") + with pytest.raises(ValueError, match="Opik get run url failed: fail"): + trace_instance.get_project_url() + + +def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch, caplog): + trace_info = WorkflowTraceInfo( + workflow_id="86a52565-4a6b-4a1b-9bfd-98e4595e70de", + tenant_id="66e8e918-472e-4b69-8051-12502c34fc07", + workflow_run_id="8403965c-3344-4d22-a8fe-d8d55cee64d9", + workflow_run_elapsed_time=1.0, + workflow_run_status="s", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="7a02cb9d-6949-4c59-a89d-f25bbc881e0e", + start_time=_dt(), + end_time=_dt(), + metadata={"app_id": "77e8e918-472e-4b69-8051-12502c34fc07"}, + workflow_app_log_id="82268424-e193-476c-a6db-f473388ee5fe", + error="", + ) + + node = MagicMock() + node.id = "88e8e918-472e-4b69-8051-12502c34fc07" + node.title = "LLM Node" + node.node_type = NodeType.LLM + node.status = "succeeded" + + class BadDict(collections.UserDict): + def get(self, key, default=None): + if key == "usage": + raise Exception("Usage extraction failed") + return super().get(key, default) + + node.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]}) + node.created_at = _dt() + node.elapsed_time = 0.1 + node.metadata = {} + node.outputs = {} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.add_span = MagicMock() + + with caplog.at_level(logging.ERROR): + trace_instance.workflow_trace(trace_info) + + assert "Failed to extract usage" in caplog.text + assert trace_instance.add_span.call_count >= 1 + # Verify that at least one of the spans is for the LLM Node + span_names = [call.args[0]["name"] for call in trace_instance.add_span.call_args_list] + assert "LLM Node" in span_names diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py b/api/tests/unit_tests/core/ops/tencent_trace/test_client.py new file mode 100644 index 0000000000..870c18e53e --- /dev/null +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_client.py @@ -0,0 +1,583 @@ +"""Tests for the TencentTraceClient helpers that drive tracing and metrics.""" + +from __future__ import annotations + +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from opentelemetry.sdk.trace import Event +from opentelemetry.trace import Status, StatusCode + +from core.ops.tencent_trace import client as client_module +from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version +from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData + +metric_reader_instances: list[DummyMetricReader] = [] +meter_provider_instances: list[DummyMeterProvider] = [] + + +class DummyHistogram: + """Placeholder histogram type used by the stubbed metric stack.""" + + +class AggregationTemporality: + DELTA = "delta" + + +class DummyMeter: + def __init__(self) -> None: + self.created: list[tuple[dict[str, object], MagicMock]] = [] + + def create_histogram(self, **kwargs: object) -> MagicMock: + hist = MagicMock(name=f"hist-{kwargs.get('name')}") + self.created.append((kwargs, hist)) + return hist + + +class DummyMeterProvider: + def __init__(self, resource: object, metric_readers: list[object]) -> None: + self.resource = resource + self.metric_readers = metric_readers + self.meter = DummyMeter() + self.shutdown = MagicMock(name="meter_provider_shutdown") + meter_provider_instances.append(self) + + def get_meter(self, name: str, version: str) -> DummyMeter: + return self.meter + + +class DummyMetricReader: + def __init__(self, exporter: object, export_interval_millis: int) -> None: + self.exporter = exporter + self.export_interval_millis = export_interval_millis + self.shutdown = MagicMock(name="metric_reader_shutdown") + metric_reader_instances.append(self) + + +class DummyGrpcMetricExporter: + def __init__(self, **kwargs: object) -> None: + self.kwargs = kwargs + + +class DummyHttpMetricExporter: + def __init__(self, **kwargs: object) -> None: + self.kwargs = kwargs + + +class DummyJsonMetricExporter: + def __init__(self, **kwargs: object) -> None: + self.kwargs = kwargs + + +class DummyJsonMetricExporterNoTemporality: + """Exporter that rejects preferred_temporality to exercise fallback.""" + + def __init__(self, **kwargs: object) -> None: + if "preferred_temporality" in kwargs: + raise RuntimeError("unsupported preferred_temporality") + self.kwargs = kwargs + + +def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None: + """Drop fake metric modules into sys.modules so the client imports resolve.""" + + metrics_module = types.ModuleType("opentelemetry.sdk.metrics") + metrics_module.Histogram = DummyHistogram + metrics_module.MeterProvider = DummyMeterProvider + monkeypatch.setitem(sys.modules, "opentelemetry.sdk.metrics", metrics_module) + + metrics_export_module = types.ModuleType("opentelemetry.sdk.metrics.export") + metrics_export_module.AggregationTemporality = AggregationTemporality + metrics_export_module.PeriodicExportingMetricReader = DummyMetricReader + monkeypatch.setitem(sys.modules, "opentelemetry.sdk.metrics.export", metrics_export_module) + + grpc_module = types.ModuleType("opentelemetry.exporter.otlp.proto.grpc.metric_exporter") + grpc_module.OTLPMetricExporter = DummyGrpcMetricExporter + monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.proto.grpc.metric_exporter", grpc_module) + + http_module = types.ModuleType("opentelemetry.exporter.otlp.proto.http.metric_exporter") + http_module.OTLPMetricExporter = DummyHttpMetricExporter + monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.proto.http.metric_exporter", http_module) + + http_json_module = types.ModuleType("opentelemetry.exporter.otlp.http.json.metric_exporter") + http_json_module.OTLPMetricExporter = DummyJsonMetricExporter + monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.http.json.metric_exporter", http_json_module) + + legacy_json_module = types.ModuleType("opentelemetry.exporter.otlp.json.metric_exporter") + legacy_json_module.OTLPMetricExporter = DummyJsonMetricExporter + monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.json.metric_exporter", legacy_json_module) + + +@pytest.fixture(autouse=True) +def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None: + metric_reader_instances.clear() + meter_provider_instances.clear() + _add_stub_modules(monkeypatch) + + +@pytest.fixture(autouse=True) +def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]: + span_exporter = MagicMock(name="span_exporter") + monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter)) + + span_processor = MagicMock(name="span_processor") + monkeypatch.setattr(client_module, "BatchSpanProcessor", MagicMock(return_value=span_processor)) + + tracer = MagicMock(name="tracer") + span = MagicMock(name="span") + tracer.start_span.return_value = span + + tracer_provider = MagicMock(name="tracer_provider") + tracer_provider.get_tracer.return_value = tracer + tracer_provider.shutdown = MagicMock(name="tracer_provider_shutdown") + monkeypatch.setattr(client_module, "TracerProvider", MagicMock(return_value=tracer_provider)) + + resource = MagicMock(name="resource") + monkeypatch.setattr(client_module, "Resource", MagicMock(return_value=resource)) + + logger_mock = MagicMock(name="tencent_logger") + monkeypatch.setattr(client_module, "logger", logger_mock) + + trace_api_stub = SimpleNamespace( + set_span_in_context=MagicMock(name="set_span_in_context", return_value="trace-context"), + NonRecordingSpan=MagicMock(name="non_recording_span", side_effect=lambda ctx: f"non-{ctx}"), + ) + monkeypatch.setattr(client_module, "trace_api", trace_api_stub) + + fake_config = SimpleNamespace( + project=SimpleNamespace(version="test"), + COMMIT_SHA="sha", + DEPLOY_ENV="dev", + EDITION="cloud", + ) + monkeypatch.setattr(client_module, "dify_config", fake_config) + + monkeypatch.setattr(client_module.socket, "gethostname", lambda: "fake-host") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "") + + return { + "span_exporter": span_exporter, + "span_processor": span_processor, + "tracer": tracer, + "span": span, + "tracer_provider": tracer_provider, + "logger": logger_mock, + "trace_api": trace_api_stub, + } + + +def _build_client() -> TencentTraceClient: + return TencentTraceClient( + service_name="service", + endpoint="https://trace.example.com:4317", + token="token", + ) + + +def test_get_opentelemetry_sdk_version_reads_install(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(client_module, "version", lambda pkg: "2.0.0") + assert _get_opentelemetry_sdk_version() == "2.0.0" + + +def test_get_opentelemetry_sdk_version_falls_back(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(client_module, "version", MagicMock(side_effect=RuntimeError("boom"))) + assert _get_opentelemetry_sdk_version() == "1.27.0" + + +@pytest.mark.parametrize( + ("endpoint", "expected"), + [ + ( + "https://example.com:9090", + ("example.com:9090", False, "example.com", 9090), + ), + ( + "http://localhost", + ("localhost:4317", True, "localhost", 4317), + ), + ( + "example.com:bad", + ("example.com:4317", False, "example.com", 4317), + ), + ], +) +def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[str, bool, str, int]) -> None: + assert TencentTraceClient._resolve_grpc_target(endpoint) == expected + + +def test_resolve_grpc_target_handles_errors() -> None: + assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317) + + +@pytest.mark.parametrize( + ("method", "attr_name", "args"), + [ + ("record_llm_duration", "hist_llm_duration", (0.3, {"foo": object()})), + ("record_token_usage", "hist_token_usage", (5, "input", "chat", "gpt", "gpt", "addr", "dify")), + ("record_time_to_first_token", "hist_time_to_first_token", (0.4, "dify", "gpt")), + ("record_time_to_generate", "hist_time_to_generate", (0.6, "dify", "gpt")), + ("record_trace_duration", "hist_trace_duration", (1.0, {"meta": object()})), + ], +) +def test_record_methods_call_histograms(method: str, attr_name: str, args: tuple[object, ...]) -> None: + client = _build_client() + hist_mock = MagicMock(name=attr_name) + setattr(client, attr_name, hist_mock) + + getattr(client, method)(*args) + hist_mock.record.assert_called_once() + + +def test_record_methods_skip_when_histogram_missing() -> None: + client = _build_client() + client.hist_llm_duration = None + client.record_llm_duration(0.1) + + client.hist_token_usage = None + client.record_token_usage(1, "go", "chat", "model", "model", "addr", "provider") + + client.hist_time_to_first_token = None + client.record_time_to_first_token(0.2, "prov", "model") + + client.hist_time_to_generate = None + client.record_time_to_generate(0.3, "prov", "model") + + client.hist_trace_duration = None + client.record_trace_duration(0.5) + + +def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None: + client = _build_client() + client.hist_llm_duration = MagicMock(name="hist_llm_duration") + client.hist_llm_duration.record.side_effect = RuntimeError("boom") + + client.record_llm_duration(0.2) + logger = patch_core_components["logger"] + logger.debug.assert_called() + + +def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None: + client = _build_client() + span = patch_core_components["span"] + span.get_span_context.return_value = "ctx" + + data = SpanData( + trace_id=1, + parent_span_id=None, + span_id=2, + name="span", + attributes={"key": "value"}, + events=[Event(name="evt", attributes={"k": "v"}, timestamp=123)], + status=Status(StatusCode.OK), + start_time=10, + end_time=20, + ) + + client._create_and_export_span(data) + span.set_attributes.assert_called_once() + span.add_event.assert_called_once() + span.set_status.assert_called_once() + span.end.assert_called_once_with(end_time=20) + assert client.span_contexts[2] == "ctx" + + +def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None: + client = _build_client() + client.span_contexts[10] = "existing" + span = patch_core_components["span"] + span.get_span_context.return_value = "child" + + data = SpanData( + trace_id=1, + parent_span_id=10, + span_id=11, + name="span", + attributes={}, + events=[], + start_time=0, + end_time=1, + ) + + client._create_and_export_span(data) + trace_api = patch_core_components["trace_api"] + trace_api.NonRecordingSpan.assert_called_once_with("existing") + trace_api.set_span_in_context.assert_called_once() + + +def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None: + client = _build_client() + span = patch_core_components["span"] + span.get_span_context.return_value = "ctx" + client.tracer.start_span.side_effect = RuntimeError("boom") + + client._create_and_export_span( + SpanData( + trace_id=1, + parent_span_id=None, + span_id=2, + name="span", + attributes={}, + events=[], + start_time=0, + end_time=1, + ) + ) + logger = patch_core_components["logger"] + logger.exception.assert_called_once() + + +def test_api_check_connects_successfully(monkeypatch: pytest.MonkeyPatch) -> None: + client = _build_client() + + monkeypatch.setattr( + TencentTraceClient, + "_resolve_grpc_target", + MagicMock(return_value=("host:123", False, "host", 123)), + ) + + socket_mock = MagicMock() + socket_instance = MagicMock() + socket_instance.connect_ex.return_value = 0 + socket_mock.return_value = socket_instance + monkeypatch.setattr(client_module.socket, "socket", socket_mock) + + assert client.api_check() + socket_instance.connect_ex.assert_called_once() + + +def test_api_check_returns_false_and_handles_local(monkeypatch: pytest.MonkeyPatch) -> None: + client = _build_client() + + monkeypatch.setattr( + TencentTraceClient, + "_resolve_grpc_target", + MagicMock(return_value=("host:123", False, "host", 123)), + ) + + socket_mock = MagicMock() + socket_instance = MagicMock() + socket_instance.connect_ex.return_value = 1 + socket_mock.return_value = socket_instance + monkeypatch.setattr(client_module.socket, "socket", socket_mock) + + assert not client.api_check() + + monkeypatch.setattr( + TencentTraceClient, + "_resolve_grpc_target", + MagicMock(return_value=("localhost:4317", True, "localhost", 4317)), + ) + socket_instance.connect_ex.return_value = 1 + assert client.api_check() + + +def test_api_check_handles_exceptions(monkeypatch: pytest.MonkeyPatch) -> None: + client = TencentTraceClient("svc", "https://localhost", "token") + + monkeypatch.setattr(client_module.socket, "socket", MagicMock(side_effect=RuntimeError("boom"))) + assert client.api_check() + + +def test_get_project_url() -> None: + client = _build_client() + assert client.get_project_url() == "https://console.cloud.tencent.com/apm" + + +def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None: + client = _build_client() + span_processor = patch_core_components["span_processor"] + tracer_provider = patch_core_components["tracer_provider"] + + client.shutdown() + span_processor.force_flush.assert_called_once() + span_processor.shutdown.assert_called_once() + tracer_provider.shutdown.assert_called_once() + + meter_provider = meter_provider_instances[-1] + metric_reader = metric_reader_instances[-1] + meter_provider.shutdown.assert_called_once() + metric_reader.shutdown.assert_called_once() + + +def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None: + client = _build_client() + meter_provider = meter_provider_instances[-1] + meter_provider.shutdown.side_effect = RuntimeError("boom") + client.metric_reader.shutdown.side_effect = RuntimeError("boom") + + client.shutdown() + logger = patch_core_components["logger"] + logger.debug.assert_any_call( + "[Tencent APM] Error shutting down meter provider", + exc_info=True, + ) + logger.debug.assert_any_call( + "[Tencent APM] Error shutting down metric reader", + exc_info=True, + ) + + +def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(DummyMeterProvider, "__init__", MagicMock(side_effect=RuntimeError("err"))) + client = _build_client() + + assert client.meter is None + assert client.meter_provider is None + assert client.hist_llm_duration is None + assert client.hist_token_usage is None + assert client.hist_time_to_first_token is None + assert client.hist_time_to_generate is None + assert client.hist_trace_duration is None + assert client.metric_reader is None + + +def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None: + client = _build_client() + monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom"))) + + client.add_span( + SpanData( + trace_id=1, + parent_span_id=None, + span_id=2, + name="span", + attributes={}, + events=[], + start_time=0, + end_time=1, + ) + ) + + logger = patch_core_components["logger"] + logger.exception.assert_called_once() + + +def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None: + client = _build_client() + span = patch_core_components["span"] + span.get_span_context.return_value = "ctx" + + data = SpanData.model_construct( + trace_id=1, + parent_span_id=None, + span_id=2, + name="span", + attributes={"num": 5, "flag": True, "pi": 3.14, "text": "value"}, + events=[], + links=[], + status=Status(StatusCode.OK), + start_time=0, + end_time=1, + ) + + client._create_and_export_span(data) + (attrs,) = span.set_attributes.call_args.args + assert attrs["num"] == 5 + assert attrs["flag"] is True + assert attrs["pi"] == 3.14 + assert attrs["text"] == "value" + + +def test_record_llm_duration_converts_attributes() -> None: + client = _build_client() + hist_mock = MagicMock(name="hist_llm_duration") + client.hist_llm_duration = hist_mock + + client.record_llm_duration(0.3, {"foo": object(), "bar": 2}) + _, attrs = hist_mock.record.call_args.args + assert isinstance(attrs["foo"], str) + assert attrs["bar"] == 2 + + +def test_record_trace_duration_converts_attributes() -> None: + client = _build_client() + hist_mock = MagicMock(name="hist_trace_duration") + client.hist_trace_duration = hist_mock + + client.record_trace_duration(1.0, {"meta": object(), "ok": True}) + _, attrs = hist_mock.record.call_args.args + assert isinstance(attrs["meta"], str) + assert attrs["ok"] is True + + +@pytest.mark.parametrize( + ("method", "attr_name", "args"), + [ + ("record_token_usage", "hist_token_usage", (5, "input", "chat", "gpt", "gpt", "addr", "dify")), + ("record_time_to_first_token", "hist_time_to_first_token", (0.4, "dify", "gpt")), + ("record_time_to_generate", "hist_time_to_generate", (0.6, "dify", "gpt")), + ("record_trace_duration", "hist_trace_duration", (1.0, {"meta": object()})), + ], +) +def test_record_methods_handle_exceptions( + method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object] +) -> None: + client = _build_client() + hist_mock = MagicMock(name=attr_name) + hist_mock.record.side_effect = RuntimeError("boom") + setattr(client, attr_name, hist_mock) + + getattr(client, method)(*args) + logger = patch_core_components["logger"] + logger.debug.assert_called() + + +def test_metrics_initializes_grpc_metric_exporter() -> None: + client = _build_client() + metric_reader = metric_reader_instances[-1] + + assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter) + assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 + assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317" + assert metric_reader.exporter.kwargs["insecure"] is False + assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" + + +def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf") + client = _build_client() + metric_reader = metric_reader_instances[-1] + + assert isinstance(metric_reader.exporter, DummyHttpMetricExporter) + assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 + assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint + assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" + + +def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json") + client = _build_client() + metric_reader = metric_reader_instances[-1] + + assert isinstance(metric_reader.exporter, DummyJsonMetricExporter) + assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 + assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint + assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" + assert "preferred_temporality" in metric_reader.exporter.kwargs + + +def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json") + exporter_module = sys.modules["opentelemetry.exporter.otlp.http.json.metric_exporter"] + monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality) + _ = _build_client() + metric_reader = metric_reader_instances[-1] + + assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality) + assert "preferred_temporality" not in metric_reader.exporter.kwargs + + +def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json") + + def _fail_import(mod_path: str) -> types.ModuleType: + raise ModuleNotFoundError(mod_path) + + monkeypatch.setattr(client_module.importlib, "import_module", _fail_import) + + _ = _build_client() + metric_reader = metric_reader_instances[-1] + assert isinstance(metric_reader.exporter, DummyHttpMetricExporter) diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py new file mode 100644 index 0000000000..a0b6d52720 --- /dev/null +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py @@ -0,0 +1,359 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +from opentelemetry.trace import StatusCode + +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + MessageTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from core.ops.tencent_trace.entities.semconv import ( + GEN_AI_IS_ENTRY, + GEN_AI_IS_STREAMING_REQUEST, + GEN_AI_MODEL_NAME, + GEN_AI_SPAN_KIND, + GEN_AI_USAGE_INPUT_TOKENS, + INPUT_VALUE, + RETRIEVAL_DOCUMENT, + RETRIEVAL_QUERY, + TOOL_DESCRIPTION, + TOOL_NAME, + TOOL_PARAMETERS, + GenAISpanKind, +) +from core.ops.tencent_trace.span_builder import TencentSpanBuilder +from core.rag.models.document import Document +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus + + +class TestTencentSpanBuilder: + def test_get_time_nanoseconds(self): + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert: + mock_convert.return_value = 123456789 + dt = datetime.now() + result = TencentSpanBuilder._get_time_nanoseconds(dt) + assert result == 123456789 + mock_convert.assert_called_once_with(dt) + + def test_build_workflow_spans(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.workflow_run_id = "run_id" + trace_info.error = None + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.workflow_run_inputs = {"sys.query": "hello"} + trace_info.workflow_run_outputs = {"answer": "world"} + trace_info.metadata = {"conversation_id": "conv_id"} + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.side_effect = [1, 2] # workflow_span_id, message_span_id + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") + + assert len(spans) == 2 + assert spans[0].name == "message" + assert spans[0].span_id == 2 + assert spans[1].name == "workflow" + assert spans[1].span_id == 1 + assert spans[1].parent_span_id == 2 + + def test_build_workflow_spans_no_message(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.workflow_run_id = "run_id" + trace_info.error = "some error" + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.workflow_run_inputs = {} + trace_info.workflow_run_outputs = {} + trace_info.metadata = {} # No conversation_id + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 1 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") + + assert len(spans) == 1 + assert spans[0].name == "workflow" + assert spans[0].status.status_code == StatusCode.ERROR + assert spans[0].status.description == "some error" + assert spans[0].attributes[GEN_AI_IS_ENTRY] == "true" + + def test_build_workflow_llm_span(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {"conversation_id": "conv_id"} + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.process_data = { + "model_name": "gpt-4", + "model_provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, "time_to_first_token": 0.5}, + "prompts": ["hello"], + } + node_execution.outputs = {"text": "world"} + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 456 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) + + assert span.name == "GENERATION" + assert span.attributes[GEN_AI_MODEL_NAME] == "gpt-4" + assert span.attributes[GEN_AI_IS_STREAMING_REQUEST] == "true" + assert span.attributes[GEN_AI_USAGE_INPUT_TOKENS] == "10" + + def test_build_workflow_llm_span_usage_in_outputs(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {} + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.process_data = {} + node_execution.outputs = { + "text": "world", + "usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40}, + } + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 456 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) + + assert span.attributes[GEN_AI_USAGE_INPUT_TOKENS] == "15" + assert GEN_AI_IS_STREAMING_REQUEST not in span.attributes + + def test_build_message_span_standalone(self): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.message_id = "msg_id" + trace_info.error = None + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.inputs = {"q": "hi"} + trace_info.outputs = "hello" + trace_info.metadata = {"conversation_id": "conv_id"} + trace_info.is_streaming_request = True + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 789 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") + + assert span.name == "message" + assert span.attributes[GEN_AI_IS_STREAMING_REQUEST] == "true" + assert span.attributes[INPUT_VALUE] == str(trace_info.inputs) + + def test_build_message_span_standalone_with_error(self): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.message_id = "msg_id" + trace_info.error = "some error" + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.inputs = None + trace_info.outputs = None + trace_info.metadata = {} + trace_info.is_streaming_request = False + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 789 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") + + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "some error" + assert span.attributes[INPUT_VALUE] == "" + + def test_build_tool_span(self): + trace_info = MagicMock(spec=ToolTraceInfo) + trace_info.message_id = "msg_id" + trace_info.tool_name = "search" + trace_info.error = "tool error" + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.tool_parameters = {"p": 1} + trace_info.tool_inputs = {"i": 2} + trace_info.tool_outputs = "result" + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 101 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1) + + assert span.name == "search" + assert span.status.status_code == StatusCode.ERROR + assert span.attributes[TOOL_NAME] == "search" + + def test_build_retrieval_span(self): + trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_info.message_id = "msg_id" + trace_info.inputs = "query" + trace_info.error = None + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + + doc = Document( + page_content="content", metadata={"dataset_id": "d1", "doc_id": "di1", "document_id": "du1", "score": 0.9} + ) + trace_info.documents = [doc] + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 202 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1) + + assert span.name == "retrieval" + assert span.attributes[RETRIEVAL_QUERY] == "query" + assert "content" in span.attributes[RETRIEVAL_DOCUMENT] + + def test_build_retrieval_span_with_error(self): + trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_info.message_id = "msg_id" + trace_info.inputs = "" + trace_info.error = "retrieval failed" + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.documents = [] + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 202 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1) + + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "retrieval failed" + + def test_get_workflow_node_status(self): + node = MagicMock(spec=WorkflowNodeExecution) + + node.status = WorkflowNodeExecutionStatus.SUCCEEDED + assert TencentSpanBuilder._get_workflow_node_status(node).status_code == StatusCode.OK + + node.status = WorkflowNodeExecutionStatus.FAILED + node.error = "fail" + status = TencentSpanBuilder._get_workflow_node_status(node) + assert status.status_code == StatusCode.ERROR + assert status.description == "fail" + + node.status = WorkflowNodeExecutionStatus.EXCEPTION + node.error = "exc" + status = TencentSpanBuilder._get_workflow_node_status(node) + assert status.status_code == StatusCode.ERROR + assert status.description == "exc" + + node.status = WorkflowNodeExecutionStatus.RUNNING + assert TencentSpanBuilder._get_workflow_node_status(node).status_code == StatusCode.UNSET + + def test_build_workflow_retrieval_span(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {"conversation_id": "conv_id"} + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.title = "my retrieval" + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.inputs = {"query": "q1"} + node_execution.outputs = {"result": [{"content": "c1"}]} + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 303 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution) + + assert span.name == "my retrieval" + assert span.attributes[RETRIEVAL_QUERY] == "q1" + assert "c1" in span.attributes[RETRIEVAL_DOCUMENT] + + def test_build_workflow_retrieval_span_empty(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {} + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.title = "my retrieval" + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.inputs = {} + node_execution.outputs = {} + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 303 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution) + + assert span.attributes[RETRIEVAL_QUERY] == "" + assert span.attributes[RETRIEVAL_DOCUMENT] == "" + + def test_build_workflow_tool_span(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.title = "my tool" + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"info": "some"}} + node_execution.inputs = {"param": "val"} + node_execution.outputs = {"res": "ok"} + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 404 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution) + + assert span.name == "my tool" + assert span.attributes[TOOL_NAME] == "my tool" + assert "some" in span.attributes[TOOL_DESCRIPTION] + + def test_build_workflow_tool_span_no_metadata(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.title = "my tool" + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.metadata = None + node_execution.inputs = None + node_execution.outputs = {"res": "ok"} + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 404 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution) + + assert span.attributes[TOOL_DESCRIPTION] == "{}" + assert span.attributes[TOOL_PARAMETERS] == "{}" + + def test_build_workflow_task_span(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {"conversation_id": "conv_id"} + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.title = "my task" + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.inputs = {"in": 1} + node_execution.outputs = {"out": 2} + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 505 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution) + + assert span.name == "my task" + assert span.attributes[GEN_AI_SPAN_KIND] == GenAISpanKind.TASK.value diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py new file mode 100644 index 0000000000..077a92d866 --- /dev/null +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py @@ -0,0 +1,647 @@ +import logging +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.config_entity import TencentConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from core.ops.tencent_trace.tencent_trace import TencentDataTrace +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import NodeType +from models import Account, App, TenantAccountJoin + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def tencent_config(): + return TencentConfig(service_name="test-service", endpoint="https://test-endpoint", token="test-token") + + +@pytest.fixture +def mock_trace_client(): + with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock: + yield mock + + +@pytest.fixture +def mock_span_builder(): + with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock: + yield mock + + +@pytest.fixture +def mock_trace_utils(): + with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock: + yield mock + + +@pytest.fixture +def tencent_data_trace(tencent_config, mock_trace_client): + return TencentDataTrace(tencent_config) + + +class TestTencentDataTrace: + def test_init(self, tencent_config, mock_trace_client): + trace = TencentDataTrace(tencent_config) + mock_trace_client.assert_called_once_with( + service_name=tencent_config.service_name, + endpoint=tencent_config.endpoint, + token=tencent_config.token, + metrics_export_interval_sec=5, + ) + assert trace.trace_client == mock_trace_client.return_value + + def test_trace_dispatch(self, tencent_data_trace): + methods = [ + ( + WorkflowTraceInfo( + workflow_id="wf", + tenant_id="t", + workflow_run_id="run", + workflow_run_elapsed_time=1.0, + workflow_run_status="s", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="v", + total_tokens=0, + file_list=[], + query="", + metadata={}, + ), + "workflow_trace", + ), + ( + MessageTraceInfo( + message_id="msg", + message_data={}, + inputs={}, + outputs={}, + start_time=None, + end_time=None, + conversation_mode="chat", + conversation_model="gpt-3.5-turbo", + message_tokens=0, + answer_tokens=0, + total_tokens=0, + metadata={}, + ), + "message_trace", + ), + ( + ModerationTraceInfo( + flagged=False, action="a", preset_response="p", query="q", metadata={}, message_id="m" + ), + None, + ), # Pass + ( + SuggestedQuestionTraceInfo( + suggested_question=[], + level="l", + total_tokens=0, + metadata={}, + message_id="m", + message_data={}, + inputs={}, + start_time=None, + end_time=None, + ), + "suggested_question_trace", + ), + ( + DatasetRetrievalTraceInfo( + metadata={}, + message_id="m", + message_data={}, + inputs={}, + documents=[], + start_time=None, + end_time=None, + ), + "dataset_retrieval_trace", + ), + ( + ToolTraceInfo( + tool_name="t", + tool_inputs={}, + tool_outputs="", + tool_config={}, + tool_parameters={}, + time_cost=0, + metadata={}, + message_id="m", + inputs={}, + outputs={}, + start_time=None, + end_time=None, + ), + "tool_trace", + ), + ( + GenerateNameTraceInfo( + tenant_id="t", metadata={}, message_id="m", inputs={}, outputs={}, start_time=None, end_time=None + ), + None, + ), # Pass + ] + + for trace_info, method_name in methods: + if method_name: + with patch.object(tencent_data_trace, method_name) as mock_method: + tencent_data_trace.trace(trace_info) + mock_method.assert_called_once_with(trace_info) + else: + tencent_data_trace.trace(trace_info) + + def test_api_check(self, tencent_data_trace): + tencent_data_trace.trace_client.api_check.return_value = True + assert tencent_data_trace.api_check() is True + tencent_data_trace.trace_client.api_check.assert_called_once() + + def test_get_project_url(self, tencent_data_trace): + tencent_data_trace.trace_client.get_project_url.return_value = "http://url" + assert tencent_data_trace.get_project_url() == "http://url" + tencent_data_trace.trace_client.get_project_url.assert_called_once() + + def test_workflow_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.workflow_run_id = "run-id" + trace_info.trace_id = "parent-trace-id" + + mock_trace_utils.convert_to_trace_id.return_value = 123 + mock_trace_utils.create_link.return_value = "link" + + with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"): + with patch.object(tencent_data_trace, "_process_workflow_nodes") as mock_proc: + with patch.object(tencent_data_trace, "_record_workflow_trace_duration") as mock_dur: + mock_span_builder.build_workflow_spans.return_value = [MagicMock(), MagicMock()] + + tencent_data_trace.workflow_trace(trace_info) + + mock_trace_utils.convert_to_trace_id.assert_called_once_with("run-id") + mock_trace_utils.create_link.assert_called_once_with("parent-trace-id") + mock_span_builder.build_workflow_spans.assert_called_once() + assert tencent_data_trace.trace_client.add_span.call_count == 2 + mock_proc.assert_called_once_with(trace_info, 123) + mock_dur.assert_called_once_with(trace_info) + + def test_workflow_trace_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.workflow_run_id = "run-id" + + with patch( + "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + ): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.workflow_trace(trace_info) + mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace") + + def test_message_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.message_id = "msg-id" + trace_info.trace_id = "parent-trace-id" + + mock_trace_utils.convert_to_trace_id.return_value = 123 + mock_trace_utils.create_link.return_value = "link" + + with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"): + with patch.object(tencent_data_trace, "_record_message_llm_metrics") as mock_metrics: + with patch.object(tencent_data_trace, "_record_message_trace_duration") as mock_dur: + mock_span_builder.build_message_span.return_value = MagicMock() + + tencent_data_trace.message_trace(trace_info) + + mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id") + mock_trace_utils.create_link.assert_called_once_with("parent-trace-id") + mock_span_builder.build_message_span.assert_called_once() + tencent_data_trace.trace_client.add_span.assert_called_once() + mock_metrics.assert_called_once_with(trace_info) + mock_dur.assert_called_once_with(trace_info) + + def test_message_trace_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + + with patch( + "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + ): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.message_trace(trace_info) + mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace") + + def test_tool_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder): + trace_info = MagicMock(spec=ToolTraceInfo) + trace_info.message_id = "msg-id" + + mock_trace_utils.convert_to_span_id.return_value = 456 + mock_trace_utils.convert_to_trace_id.return_value = 123 + + tencent_data_trace.tool_trace(trace_info) + + mock_trace_utils.convert_to_span_id.assert_called_once_with("msg-id", "message") + mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id") + mock_span_builder.build_tool_span.assert_called_once_with(trace_info, 123, 456) + tencent_data_trace.trace_client.add_span.assert_called_once() + + def test_tool_trace_no_msg_id(self, tencent_data_trace): + trace_info = MagicMock(spec=ToolTraceInfo) + trace_info.message_id = None + + tencent_data_trace.tool_trace(trace_info) + tencent_data_trace.trace_client.add_span.assert_not_called() + + def test_tool_trace_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=ToolTraceInfo) + trace_info.message_id = "msg-id" + + with patch( + "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + ): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.tool_trace(trace_info) + mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace") + + def test_dataset_retrieval_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder): + trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_info.message_id = "msg-id" + + mock_trace_utils.convert_to_span_id.return_value = 456 + mock_trace_utils.convert_to_trace_id.return_value = 123 + + tencent_data_trace.dataset_retrieval_trace(trace_info) + + mock_trace_utils.convert_to_span_id.assert_called_once_with("msg-id", "message") + mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id") + mock_span_builder.build_retrieval_span.assert_called_once_with(trace_info, 123, 456) + tencent_data_trace.trace_client.add_span.assert_called_once() + + def test_dataset_retrieval_trace_no_msg_id(self, tencent_data_trace): + trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_info.message_id = None + + tencent_data_trace.dataset_retrieval_trace(trace_info) + tencent_data_trace.trace_client.add_span.assert_not_called() + + def test_dataset_retrieval_trace_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_info.message_id = "msg-id" + + with patch( + "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + ): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.dataset_retrieval_trace(trace_info) + mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace") + + def test_suggested_question_trace(self, tencent_data_trace): + trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) + with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log: + tencent_data_trace.suggested_question_trace(trace_info) + mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace") + + def test_suggested_question_trace_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) + with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.suggested_question_trace(trace_info) + mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace") + + def test_process_workflow_nodes(self, tencent_data_trace, mock_trace_utils): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.workflow_run_id = "run-id" + mock_trace_utils.convert_to_span_id.return_value = 111 + + node1 = MagicMock(spec=WorkflowNodeExecution) + node1.id = "n1" + node1.node_type = NodeType.LLM + node2 = MagicMock(spec=WorkflowNodeExecution) + node2.id = "n2" + node2.node_type = NodeType.TOOL + + with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]): + with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]): + with patch.object(tencent_data_trace, "_record_llm_metrics") as mock_metrics: + tencent_data_trace._process_workflow_nodes(trace_info, 123) + + assert tencent_data_trace.trace_client.add_span.call_count == 2 + mock_metrics.assert_called_once_with(node1) + + def test_process_workflow_nodes_node_exception(self, tencent_data_trace, mock_trace_utils): + trace_info = MagicMock(spec=WorkflowTraceInfo) + mock_trace_utils.convert_to_span_id.return_value = 111 + + node = MagicMock(spec=WorkflowNodeExecution) + node.id = "n1" + + with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]): + with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace._process_workflow_nodes(trace_info, 123) + # The exception should be caught by the outer handler since convert_to_span_id is called first + mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") + + def test_process_workflow_nodes_exception(self, tencent_data_trace, mock_trace_utils): + trace_info = MagicMock(spec=WorkflowTraceInfo) + mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error") + + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace._process_workflow_nodes(trace_info, 123) + mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") + + def test_build_workflow_node_span(self, tencent_data_trace, mock_span_builder): + trace_info = MagicMock(spec=WorkflowTraceInfo) + + nodes = [ + (NodeType.LLM, mock_span_builder.build_workflow_llm_span), + (NodeType.KNOWLEDGE_RETRIEVAL, mock_span_builder.build_workflow_retrieval_span), + (NodeType.TOOL, mock_span_builder.build_workflow_tool_span), + (NodeType.CODE, mock_span_builder.build_workflow_task_span), + ] + + for node_type, builder_method in nodes: + node = MagicMock(spec=WorkflowNodeExecution) + node.node_type = node_type + builder_method.return_value = "span" + + result = tencent_data_trace._build_workflow_node_span(node, 123, trace_info, 456) + + assert result == "span" + builder_method.assert_called_once_with(123, 456, trace_info, node) + + def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder): + node = MagicMock(spec=WorkflowNodeExecution) + node.node_type = NodeType.LLM + node.id = "n1" + mock_span_builder.build_workflow_llm_span.side_effect = Exception("error") + + with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456) + assert result is None + mock_log.assert_called_once() + + def test_get_workflow_node_executions(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {"app_id": "app-1"} + trace_info.workflow_run_id = "run-1" + + app = MagicMock(spec=App) + app.id = "app-1" + app.created_by = "user-1" + + account = MagicMock(spec=Account) + account.id = "user-1" + + tenant_join = MagicMock(spec=TenantAccountJoin) + tenant_join.tenant_id = "tenant-1" + + mock_executions = [MagicMock()] + + with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + mock_db.engine = "engine" + with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: + session = mock_session_ctx.return_value.__enter__.return_value + session.scalar.side_effect = [app, account] + session.query.return_value.filter_by.return_value.first.return_value = tenant_join + + with patch( + "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" + ) as mock_repo: + mock_repo.return_value.get_by_workflow_run.return_value = mock_executions + + results = tencent_data_trace._get_workflow_node_executions(trace_info) + + assert results == mock_executions + account.set_tenant_id.assert_called_once_with("tenant-1") + + def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {} + + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + results = tencent_data_trace._get_workflow_node_executions(trace_info) + assert results == [] + mock_log.assert_called_once() + + def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {"app_id": "app-1"} + + with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + mock_db.init_app = MagicMock() # Ensure init_app is mocked + mock_db.engine = "engine" + with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: + session = mock_session_ctx.return_value.__enter__.return_value + session.scalar.return_value = None + + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + results = tencent_data_trace._get_workflow_node_executions(trace_info) + assert results == [] + mock_log.assert_called_once() + + def test_get_user_id_workflow(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.tenant_id = "tenant-1" + trace_info.metadata = {"user_id": "user-1"} + + with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")): + with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + mock_db.init_app = MagicMock() + mock_db.engine = MagicMock() + + user_id = tencent_data_trace._get_user_id(trace_info) + assert user_id == "unknown" + + def test_get_user_id_only_user_id(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.metadata = {"user_id": "user-1"} + + user_id = tencent_data_trace._get_user_id(trace_info) + assert user_id == "user-1" + + def test_get_user_id_anonymous(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.metadata = {} + + user_id = tencent_data_trace._get_user_id(trace_info) + assert user_id == "anonymous" + + def test_get_user_id_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.tenant_id = "t" + trace_info.metadata = {"user_id": "u"} + + with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + user_id = tencent_data_trace._get_user_id(trace_info) + assert user_id == "unknown" + mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID") + + def test_record_llm_metrics_usage_in_process_data(self, tencent_data_trace): + node = MagicMock(spec=WorkflowNodeExecution) + node.process_data = { + "usage": { + "latency": 2.5, + "time_to_first_token": 0.5, + "time_to_generate": 2.0, + "prompt_tokens": 10, + "completion_tokens": 20, + }, + "model_provider": "openai", + "model_name": "gpt-4", + "model_mode": "chat", + } + node.outputs = {} + + tencent_data_trace._record_llm_metrics(node) + + tencent_data_trace.trace_client.record_llm_duration.assert_called_once() + tencent_data_trace.trace_client.record_time_to_first_token.assert_called_once() + tencent_data_trace.trace_client.record_time_to_generate.assert_called_once() + assert tencent_data_trace.trace_client.record_token_usage.call_count == 2 + + def test_record_llm_metrics_usage_in_outputs(self, tencent_data_trace): + node = MagicMock(spec=WorkflowNodeExecution) + node.process_data = {} + node.outputs = {"usage": {"latency": 1.0, "prompt_tokens": 5}} + + tencent_data_trace._record_llm_metrics(node) + tencent_data_trace.trace_client.record_llm_duration.assert_called_once() + tencent_data_trace.trace_client.record_token_usage.assert_called_once() + + def test_record_llm_metrics_exception(self, tencent_data_trace): + node = MagicMock(spec=WorkflowNodeExecution) + node.process_data = None + node.outputs = None + + with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + tencent_data_trace._record_llm_metrics(node) + # Should not crash + + def test_record_message_llm_metrics(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.metadata = {"ls_provider": "openai", "ls_model_name": "gpt-4"} + trace_info.message_data = {"provider_response_latency": 1.1} + trace_info.is_streaming_request = True + trace_info.gen_ai_server_time_to_first_token = 0.2 + trace_info.llm_streaming_time_to_generate = 0.9 + trace_info.message_tokens = 15 + trace_info.answer_tokens = 25 + + tencent_data_trace._record_message_llm_metrics(trace_info) + + tencent_data_trace.trace_client.record_llm_duration.assert_called_once() + tencent_data_trace.trace_client.record_time_to_first_token.assert_called_once() + tencent_data_trace.trace_client.record_time_to_generate.assert_called_once() + assert tencent_data_trace.trace_client.record_token_usage.call_count == 2 + + def test_record_message_llm_metrics_object_data(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.metadata = {} + msg_data = MagicMock() + msg_data.provider_response_latency = 1.1 + msg_data.model_provider = "anthropic" + msg_data.model_id = "claude" + trace_info.message_data = msg_data + trace_info.is_streaming_request = False + + tencent_data_trace._record_message_llm_metrics(trace_info) + tencent_data_trace.trace_client.record_llm_duration.assert_called_once() + + def test_record_message_llm_metrics_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.metadata = None + + with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + tencent_data_trace._record_message_llm_metrics(trace_info) + # Should not crash + + def test_record_workflow_trace_duration(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + from datetime import datetime, timedelta + + now = datetime.now() + trace_info.start_time = now + trace_info.end_time = now + timedelta(seconds=3) + trace_info.workflow_run_status = "succeeded" + trace_info.conversation_id = "conv-1" + + # Mock the record_trace_duration method to capture arguments + with patch.object(tencent_data_trace.trace_client, "record_trace_duration") as mock_record: + tencent_data_trace._record_workflow_trace_duration(trace_info) + + # Assert the method was called once + mock_record.assert_called_once() + + # Extract arguments passed to the method + args, kwargs = mock_record.call_args + + # Validate the duration argument + assert args[0] == 3.0 + + # Validate the attributes dict in kwargs + attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {} + assert attributes["conversation_mode"] == "workflow" + assert attributes["has_conversation"] == "true" + + def test_record_workflow_trace_duration_fallback(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.start_time = None + trace_info.workflow_run_elapsed_time = 4.5 + trace_info.workflow_run_status = "failed" + trace_info.conversation_id = None + + with patch.object(tencent_data_trace.trace_client, "record_trace_duration") as mock_record: + tencent_data_trace._record_workflow_trace_duration(trace_info) + mock_record.assert_called_once() + args, kwargs = mock_record.call_args + assert args[0] == 4.5 + # Check attributes dict (either in kwargs or as second positional arg) + attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {} + assert attributes["has_conversation"] == "false" + + def test_record_workflow_trace_duration_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right + + with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + tencent_data_trace._record_workflow_trace_duration(trace_info) + + def test_record_message_trace_duration(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + from datetime import datetime, timedelta + + now = datetime.now() + trace_info.start_time = now + trace_info.end_time = now + timedelta(seconds=2) + trace_info.conversation_mode = "chat" + trace_info.is_streaming_request = True + + tencent_data_trace._record_message_trace_duration(trace_info) + tencent_data_trace.trace_client.record_trace_duration.assert_called_once_with( + 2.0, {"conversation_mode": "chat", "stream": "true"} + ) + + def test_record_message_trace_duration_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.start_time = None + + with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + tencent_data_trace._record_message_trace_duration(trace_info) + + def test_del(self, tencent_data_trace): + client = tencent_data_trace.trace_client + tencent_data_trace.__del__() + client.shutdown.assert_called_once() + + def test_del_exception(self, tencent_data_trace): + tencent_data_trace.trace_client.shutdown.side_effect = Exception("error") + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.__del__() + mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup") diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py new file mode 100644 index 0000000000..ef28d18e20 --- /dev/null +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py @@ -0,0 +1,106 @@ +"""Unit tests for Tencent APM tracing utilities.""" + +from __future__ import annotations + +import hashlib +import uuid +from datetime import UTC, datetime +from unittest.mock import patch + +import pytest +from opentelemetry.trace import Link, TraceFlags + +from core.ops.tencent_trace.utils import TencentTraceUtils + + +def test_convert_to_trace_id_with_valid_uuid() -> None: + uuid_str = "12345678-1234-5678-1234-567812345678" + assert TencentTraceUtils.convert_to_trace_id(uuid_str) == uuid.UUID(uuid_str).int + + +def test_convert_to_trace_id_uses_uuid4_when_none() -> None: + expected_uuid = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: + assert TencentTraceUtils.convert_to_trace_id(None) == expected_uuid.int + uuid4_mock.assert_called_once() + + +def test_convert_to_trace_id_raises_value_error_for_invalid_uuid() -> None: + with pytest.raises(ValueError, match=r"^Invalid UUID input:"): + TencentTraceUtils.convert_to_trace_id("not-a-uuid") + + +def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None: + uuid_str = "12345678-1234-5678-1234-567812345678" + span_type = "llm" + + uuid_obj = uuid.UUID(uuid_str) + combined_key = f"{uuid_obj.hex}-{span_type}" + hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest() + expected = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False) + + assert TencentTraceUtils.convert_to_span_id(uuid_str, span_type) == expected + assert TencentTraceUtils.convert_to_span_id(uuid_str, "other") != expected + + +def test_convert_to_span_id_uses_uuid4_when_none() -> None: + expected_uuid = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb") + with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: + span_id = TencentTraceUtils.convert_to_span_id(None, "workflow") + assert isinstance(span_id, int) + uuid4_mock.assert_called_once() + + +def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None: + with pytest.raises(ValueError, match=r"^Invalid UUID input:"): + TencentTraceUtils.convert_to_span_id("bad-uuid", "span") + + +def test_generate_span_id_skips_invalid_span_id() -> None: + with patch( + "core.ops.tencent_trace.utils.random.getrandbits", + side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42], + ) as bits_mock: + assert TencentTraceUtils.generate_span_id() == 42 + assert bits_mock.call_count == 2 + + +def test_convert_datetime_to_nanoseconds_accepts_datetime() -> None: + start_time = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + expected = int(start_time.timestamp() * 1e9) + assert TencentTraceUtils.convert_datetime_to_nanoseconds(start_time) == expected + + +def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None: + fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC) + expected = int(fixed.timestamp() * 1e9) + + with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock: + datetime_mock.now.return_value = fixed + assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected + datetime_mock.now.assert_called_once() + + +@pytest.mark.parametrize( + ("trace_id_str", "expected_trace_id"), + [ + ("0" * 31 + "1", int("0" * 31 + "1", 16)), + (str(uuid.UUID("cccccccc-cccc-cccc-cccc-cccccccccccc")), uuid.UUID("cccccccc-cccc-cccc-cccc-cccccccccccc").int), + ], +) +def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: int) -> None: + link = TencentTraceUtils.create_link(trace_id_str) + assert isinstance(link, Link) + assert link.context.trace_id == expected_trace_id + assert link.context.span_id == TencentTraceUtils.INVALID_SPAN_ID + assert link.context.is_remote is False + assert link.context.trace_flags == TraceFlags(TraceFlags.SAMPLED) + + +@pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None]) +def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None: + fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd") + with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock: + link = TencentTraceUtils.create_link(trace_id_str) # type: ignore[arg-type] + assert link.context.trace_id == fallback_uuid.int + uuid4_mock.assert_called_once() diff --git a/api/tests/unit_tests/core/ops/test_base_trace_instance.py b/api/tests/unit_tests/core/ops/test_base_trace_instance.py new file mode 100644 index 0000000000..a8bee7dfa7 --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_base_trace_instance.py @@ -0,0 +1,112 @@ +from unittest.mock import MagicMock + +import pytest +from sqlalchemy.orm import Session + +from core.ops.base_trace_instance import BaseTraceInstance +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.entities.trace_entity import BaseTraceInfo +from models import Account, App, TenantAccountJoin + + +class ConcreteTraceInstance(BaseTraceInstance): + def __init__(self, trace_config: BaseTracingConfig): + super().__init__(trace_config) + + def trace(self, trace_info: BaseTraceInfo): + super().trace(trace_info) + + +@pytest.fixture +def mock_db_session(monkeypatch): + mock_session = MagicMock(spec=Session) + mock_session.__enter__.return_value = mock_session + mock_session.__exit__.return_value = None + + mock_session_class = MagicMock(return_value=mock_session) + + monkeypatch.setattr("core.ops.base_trace_instance.Session", mock_session_class) + monkeypatch.setattr("core.ops.base_trace_instance.db", MagicMock()) + return mock_session + + +def test_get_service_account_with_tenant_app_not_found(mock_db_session): + mock_db_session.scalar.return_value = None + + config = MagicMock(spec=BaseTracingConfig) + instance = ConcreteTraceInstance(config) + + with pytest.raises(ValueError, match="App with id some_app_id not found"): + instance.get_service_account_with_tenant("some_app_id") + + +def test_get_service_account_with_tenant_no_creator(mock_db_session): + mock_app = MagicMock(spec=App) + mock_app.id = "some_app_id" + mock_app.created_by = None + mock_db_session.scalar.return_value = mock_app + + config = MagicMock(spec=BaseTracingConfig) + instance = ConcreteTraceInstance(config) + + with pytest.raises(ValueError, match="App with id some_app_id has no creator"): + instance.get_service_account_with_tenant("some_app_id") + + +def test_get_service_account_with_tenant_creator_not_found(mock_db_session): + mock_app = MagicMock(spec=App) + mock_app.id = "some_app_id" + mock_app.created_by = "creator_id" + + # First call to scalar returns app, second returns None (for account) + mock_db_session.scalar.side_effect = [mock_app, None] + + config = MagicMock(spec=BaseTracingConfig) + instance = ConcreteTraceInstance(config) + + with pytest.raises(ValueError, match="Creator account with id creator_id not found for app some_app_id"): + instance.get_service_account_with_tenant("some_app_id") + + +def test_get_service_account_with_tenant_tenant_not_found(mock_db_session): + mock_app = MagicMock(spec=App) + mock_app.id = "some_app_id" + mock_app.created_by = "creator_id" + + mock_account = MagicMock(spec=Account) + mock_account.id = "creator_id" + + mock_db_session.scalar.side_effect = [mock_app, mock_account] + + # session.query(TenantAccountJoin).filter_by(...).first() returns None + mock_db_session.query.return_value.filter_by.return_value.first.return_value = None + + config = MagicMock(spec=BaseTracingConfig) + instance = ConcreteTraceInstance(config) + + with pytest.raises(ValueError, match="Current tenant not found for account creator_id"): + instance.get_service_account_with_tenant("some_app_id") + + +def test_get_service_account_with_tenant_success(mock_db_session): + mock_app = MagicMock(spec=App) + mock_app.id = "some_app_id" + mock_app.created_by = "creator_id" + + mock_account = MagicMock(spec=Account) + mock_account.id = "creator_id" + mock_account.set_tenant_id = MagicMock() + + mock_db_session.scalar.side_effect = [mock_app, mock_account] + + mock_tenant_join = MagicMock(spec=TenantAccountJoin) + mock_tenant_join.tenant_id = "tenant_id" + mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_tenant_join + + config = MagicMock(spec=BaseTracingConfig) + instance = ConcreteTraceInstance(config) + + result = instance.get_service_account_with_tenant("some_app_id") + + assert result == mock_account + mock_account.set_tenant_id.assert_called_once_with("tenant_id") diff --git a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py new file mode 100644 index 0000000000..2d325ccb0e --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py @@ -0,0 +1,576 @@ +import contextlib +import json +import queue +from datetime import datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.ops.ops_trace_manager import ( + OpsTraceManager, + TraceQueueManager, + TraceTask, + TraceTaskName, +) + + +class DummyConfig: + def __init__(self, **kwargs): + self._data = kwargs + + def model_dump(self): + return dict(self._data) + + +class DummyTraceInstance: + instances: list["DummyTraceInstance"] = [] + + def __init__(self, config): + self.config = config + DummyTraceInstance.instances.append(self) + + def api_check(self): + return True + + def get_project_key(self): + return "fake-key" + + def get_project_url(self): + return "https://project.fake" + + +FAKE_PROVIDER_ENTRY = { + "config_class": DummyConfig, + "secret_keys": ["secret_value"], + "other_keys": ["other_value"], + "trace_instance": DummyTraceInstance, +} + + +class FakeProviderMap: + def __init__(self, data): + self._data = data + + def __getitem__(self, key): + if key in self._data: + return self._data[key] + raise KeyError(f"Unsupported tracing provider: {key}") + + +class DummyTimer: + def __init__(self, interval, function): + self.interval = interval + self.function = function + self.name = "" + self.daemon = False + self.started = False + + def start(self): + self.started = True + + def is_alive(self): + return False + + +class FakeMessageFile: + def __init__(self): + self.url = "path/to/file" + self.id = "file-id" + self.type = "document" + self.created_by_role = "role" + self.created_by = "user" + + +def make_message_data(**overrides): + created_at = datetime(2025, 2, 20, 12, 0, 0) + base = { + "id": "msg-id", + "conversation_id": "conv-id", + "created_at": created_at, + "updated_at": created_at + timedelta(seconds=3), + "message": "hello", + "provider_response_latency": 1, + "message_tokens": 5, + "answer_tokens": 7, + "answer": "world", + "error": "", + "status": "complete", + "model_provider": "provider", + "model_id": "model", + "from_end_user_id": "end-user", + "from_account_id": "account", + "agent_based": False, + "workflow_run_id": "workflow-run", + "from_source": "source", + "message_metadata": json.dumps({"usage": {"time_to_first_token": 1, "time_to_generate": 2}}), + "agent_thoughts": [], + "query": "sample-query", + "inputs": "sample-input", + } + base.update(overrides) + + class MessageData: + def __init__(self, data): + self.__dict__.update(data) + + def to_dict(self): + return dict(self.__dict__) + + return MessageData(base) + + +def make_agent_thought(tool_name, created_at): + return SimpleNamespace( + tools=[tool_name], + created_at=created_at, + tool_meta={ + tool_name: { + "tool_config": {"foo": "bar"}, + "time_cost": 5, + "error": "", + "tool_parameters": {"x": 1}, + } + }, + ) + + +def make_workflow_run(): + return SimpleNamespace( + workflow_id="wf-1", + tenant_id="tenant", + id="run-id", + elapsed_time=10, + status="finished", + inputs_dict={"sys.file": ["f1"], "query": "search"}, + outputs_dict={"out": "value"}, + version="3", + error=None, + total_tokens=12, + workflow_run_id="run-id", + created_at=datetime(2025, 2, 20, 10, 0, 0), + finished_at=datetime(2025, 2, 20, 10, 0, 5), + triggered_from="user", + app_id="app-id", + to_dict=lambda self=None: {"run": "value"}, + ) + + +def configure_db_query(session, *, message_file=None, workflow_app_log=None): + def _side_effect(model): + query = MagicMock() + query.filter_by.return_value.first.return_value = None + if message_file and model.__name__ == "MessageFile": + query.filter_by.return_value.first.return_value = message_file + if workflow_app_log and model.__name__ == "WorkflowAppLog": + query.filter_by.return_value.first.return_value = workflow_app_log + return query + + session.query.side_effect = _side_effect + + +class DummySessionContext: + scalar_values = [] + + def __init__(self, engine): + self._values = list(self.scalar_values) + self._index = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def scalar(self, *args, **kwargs): + if self._index >= len(self._values): + return None + value = self._values[self._index] + self._index += 1 + return value + + +@pytest.fixture(autouse=True) +def patch_provider_map(monkeypatch): + monkeypatch.setattr( + "core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({"dummy": FAKE_PROVIDER_ENTRY}) + ) + OpsTraceManager.ops_trace_instances_cache.clear() + OpsTraceManager.decrypted_configs_cache.clear() + + +@pytest.fixture(autouse=True) +def patch_timer_and_current_app(monkeypatch): + monkeypatch.setattr("core.ops.ops_trace_manager.threading.Timer", DummyTimer) + monkeypatch.setattr("core.ops.ops_trace_manager.trace_manager_queue", queue.Queue()) + monkeypatch.setattr("core.ops.ops_trace_manager.trace_manager_timer", None) + + class FakeApp: + def app_context(self): + return contextlib.nullcontext() + + fake_current = MagicMock() + fake_current._get_current_object.return_value = FakeApp() + monkeypatch.setattr("core.ops.ops_trace_manager.current_app", fake_current) + + +@pytest.fixture(autouse=True) +def patch_sqlalchemy_session(monkeypatch): + monkeypatch.setattr("core.ops.ops_trace_manager.Session", DummySessionContext) + + +@pytest.fixture +def encryption_mocks(monkeypatch): + encrypt_mock = MagicMock(side_effect=lambda tenant, value: f"enc-{value}") + batch_decrypt_mock = MagicMock(side_effect=lambda tenant, values: [f"dec-{value}" for value in values]) + obfuscate_mock = MagicMock(side_effect=lambda value: f"ob-{value}") + monkeypatch.setattr("core.ops.ops_trace_manager.encrypt_token", encrypt_mock) + monkeypatch.setattr("core.ops.ops_trace_manager.batch_decrypt_token", batch_decrypt_mock) + monkeypatch.setattr("core.ops.ops_trace_manager.obfuscated_token", obfuscate_mock) + return encrypt_mock, batch_decrypt_mock, obfuscate_mock + + +@pytest.fixture +def mock_db(monkeypatch): + session = MagicMock() + session.scalars.return_value.all.return_value = ["chat"] + db_mock = MagicMock() + db_mock.session = session + db_mock.engine = MagicMock() + monkeypatch.setattr("core.ops.ops_trace_manager.db", db_mock) + return session + + +@pytest.fixture +def workflow_repo_fixture(monkeypatch): + repo = MagicMock() + repo.get_workflow_run_by_id_without_tenant.return_value = make_workflow_run() + monkeypatch.setattr(TraceTask, "_get_workflow_run_repo", classmethod(lambda cls: repo)) + return repo + + +@pytest.fixture +def trace_task_message(monkeypatch, mock_db): + message_data = make_message_data() + monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data) + configure_db_query(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id")) + return message_data + + +def test_encrypt_tracing_config_handles_star_and_encrypt(encryption_mocks): + encrypted = OpsTraceManager.encrypt_tracing_config( + "tenant", + "dummy", + {"secret_value": "value", "other_value": "info"}, + current_trace_config={"secret_value": "keep"}, + ) + assert encrypted["secret_value"] == "enc-value" + assert encrypted["other_value"] == "info" + + +def test_encrypt_tracing_config_preserves_star(encryption_mocks): + encrypted = OpsTraceManager.encrypt_tracing_config( + "tenant", + "dummy", + {"secret_value": "*", "other_value": "info"}, + current_trace_config={"secret_value": "keep"}, + ) + assert encrypted["secret_value"] == "keep" + + +def test_decrypt_tracing_config_caches(encryption_mocks): + _, decrypt_mock, _ = encryption_mocks + payload = {"secret_value": "enc", "other_value": "info"} + first = OpsTraceManager.decrypt_tracing_config("tenant", "dummy", payload) + second = OpsTraceManager.decrypt_tracing_config("tenant", "dummy", payload) + assert first == second + assert decrypt_mock.call_count == 1 + + +def test_obfuscated_decrypt_token(encryption_mocks): + _, _, obfuscate_mock = encryption_mocks + result = OpsTraceManager.obfuscated_decrypt_token("dummy", {"secret_value": "value", "other_value": "info"}) + assert "secret_value" in result + assert result["secret_value"] == "ob-value" + obfuscate_mock.assert_called_once() + + +def test_get_decrypted_tracing_config_returns_config(encryption_mocks, mock_db): + trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc", "other_value": "info"}) + mock_db.query.return_value.where.return_value.first.return_value = trace_config_data + app = SimpleNamespace(id="app-id", tenant_id="tenant") + mock_db.scalar.return_value = app + + decrypted = OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") + assert decrypted["other_value"] == "info" + + +def test_get_decrypted_tracing_config_missing_trace_config(mock_db): + mock_db.query.return_value.where.return_value.first.return_value = None + assert OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") is None + + +def test_get_decrypted_tracing_config_raises_for_missing_app(mock_db): + trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc"}) + mock_db.query.return_value.where.return_value.first.return_value = trace_config_data + mock_db.scalar.return_value = None + with pytest.raises(ValueError, match="App not found"): + OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") + + +def test_get_decrypted_tracing_config_raises_for_none_config(mock_db): + trace_config_data = SimpleNamespace(tracing_config=None) + mock_db.query.return_value.where.return_value.first.return_value = trace_config_data + mock_db.scalar.return_value = SimpleNamespace(tenant_id="tenant") + with pytest.raises(ValueError, match="Tracing config cannot be None"): + OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") + + +def test_get_ops_trace_instance_handles_none_app(mock_db): + mock_db.query.return_value.where.return_value.first.return_value = None + assert OpsTraceManager.get_ops_trace_instance("app-id") is None + + +def test_get_ops_trace_instance_returns_none_when_disabled(mock_db, monkeypatch): + app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": False})) + mock_db.query.return_value.where.return_value.first.return_value = app + assert OpsTraceManager.get_ops_trace_instance("app-id") is None + + +def test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch): + app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "missing"})) + mock_db.query.return_value.where.return_value.first.return_value = app + monkeypatch.setattr("core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({})) + assert OpsTraceManager.get_ops_trace_instance("app-id") is None + + +def test_get_ops_trace_instance_success(monkeypatch, mock_db): + app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "dummy"})) + mock_db.query.return_value.where.return_value.first.return_value = app + monkeypatch.setattr( + "core.ops.ops_trace_manager.OpsTraceManager.get_decrypted_tracing_config", + classmethod(lambda cls, aid, provider: {"secret_value": "decrypted", "other_value": "info"}), + ) + instance = OpsTraceManager.get_ops_trace_instance("app-id") + assert instance is not None + cached_instance = OpsTraceManager.get_ops_trace_instance("app-id") + assert instance is cached_instance + + +def test_get_app_config_through_message_id_returns_none(mock_db): + mock_db.scalar.return_value = None + assert OpsTraceManager.get_app_config_through_message_id("m") is None + + +def test_get_app_config_through_message_id_prefers_override(mock_db): + message = SimpleNamespace(conversation_id="conv") + conversation = SimpleNamespace(app_model_config_id=None, override_model_configs={"foo": "bar"}) + app_config = SimpleNamespace(id="config-id") + mock_db.scalar.side_effect = [message, conversation] + result = OpsTraceManager.get_app_config_through_message_id("m") + assert result == {"foo": "bar"} + + +def test_get_app_config_through_message_id_app_model_config(mock_db): + message = SimpleNamespace(conversation_id="conv") + conversation = SimpleNamespace(app_model_config_id="cfg", override_model_configs=None) + mock_db.scalar.side_effect = [message, conversation, SimpleNamespace(id="cfg")] + result = OpsTraceManager.get_app_config_through_message_id("m") + assert result.id == "cfg" + + +def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch): + mock_db.query.return_value.where.return_value.first.return_value = None + with pytest.raises(ValueError, match="Invalid tracing provider"): + OpsTraceManager.update_app_tracing_config("app", True, "bad") + with pytest.raises(ValueError, match="App not found"): + OpsTraceManager.update_app_tracing_config("app", True, None) + + +def test_update_app_tracing_config_success(mock_db): + app = SimpleNamespace(id="app-id", tracing="{}") + mock_db.query.return_value.where.return_value.first.return_value = app + OpsTraceManager.update_app_tracing_config("app-id", True, "dummy") + assert app.tracing is not None + mock_db.commit.assert_called_once() + + +def test_get_app_tracing_config_errors_when_missing(mock_db): + mock_db.query.return_value.where.return_value.first.return_value = None + with pytest.raises(ValueError, match="App not found"): + OpsTraceManager.get_app_tracing_config("app") + + +def test_get_app_tracing_config_returns_defaults(mock_db): + mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=None) + assert OpsTraceManager.get_app_tracing_config("app-id") == {"enabled": False, "tracing_provider": None} + + +def test_get_app_tracing_config_returns_payload(mock_db): + payload = {"enabled": True, "tracing_provider": "dummy"} + mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=json.dumps(payload)) + assert OpsTraceManager.get_app_tracing_config("app-id") == payload + + +def test_check_and_project_helpers(monkeypatch): + monkeypatch.setattr( + "core.ops.ops_trace_manager.provider_config_map", + FakeProviderMap( + { + "dummy": { + "config_class": DummyConfig, + "trace_instance": type( + "Trace", + (), + { + "__init__": lambda self, cfg: None, + "api_check": lambda self: True, + "get_project_key": lambda self: "key", + "get_project_url": lambda self: "url", + }, + ), + "secret_keys": [], + "other_keys": [], + } + } + ), + ) + assert OpsTraceManager.check_trace_config_is_effective({}, "dummy") + assert OpsTraceManager.get_trace_config_project_key({}, "dummy") == "key" + assert OpsTraceManager.get_trace_config_project_url({}, "dummy") == "url" + + +def test_trace_task_conversation_and_extract(monkeypatch): + task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE, message_id="msg") + assert task.conversation_trace(foo="bar") == {"foo": "bar"} + assert task._extract_streaming_metrics(make_message_data(message_metadata="not json")) == {} + + +def test_trace_task_message_trace(trace_task_message, mock_db): + task = TraceTask(trace_type=TraceTaskName.MESSAGE_TRACE, message_id="msg-id") + result = task.message_trace("msg-id") + assert result.message_id == "msg-id" + + +def test_trace_task_workflow_trace(workflow_repo_fixture, mock_db): + DummySessionContext.scalar_values = ["wf-app-log", "message-ref"] + execution = SimpleNamespace(id_="run-id") + task = TraceTask( + trace_type=TraceTaskName.WORKFLOW_TRACE, workflow_execution=execution, conversation_id="conv", user_id="user" + ) + result = task.workflow_trace(workflow_run_id="run-id", conversation_id="conv", user_id="user") + assert result.workflow_run_id == "run-id" + assert result.workflow_id == "wf-1" + + +def test_trace_task_moderation_trace(trace_task_message): + task = TraceTask(trace_type=TraceTaskName.MODERATION_TRACE, message_id="msg-id") + moderation_result = SimpleNamespace(action="block", preset_response="no", query="q", flagged=True) + timer = {"start": 1, "end": 2} + result = task.moderation_trace("msg-id", timer, moderation_result=moderation_result, inputs={"src": "payload"}) + assert result.flagged is True + assert result.message_id == "log-id" + + +def test_trace_task_suggested_question_trace(trace_task_message): + task = TraceTask(trace_type=TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id="msg-id") + timer = {"start": 1, "end": 2} + result = task.suggested_question_trace("msg-id", timer, suggested_question=["q1"]) + assert result.message_id == "log-id" + assert "suggested_question" in result.__dict__ + + +def test_trace_task_dataset_retrieval_trace(trace_task_message): + task = TraceTask(trace_type=TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id="msg-id") + timer = {"start": 1, "end": 2} + mock_doc = SimpleNamespace(model_dump=lambda: {"doc": "value"}) + result = task.dataset_retrieval_trace("msg-id", timer, documents=[mock_doc]) + assert result.documents == [{"doc": "value"}] + + +def test_trace_task_tool_trace(monkeypatch, mock_db): + custom_message = make_message_data(agent_thoughts=[make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 0))]) + monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _: custom_message) + configure_db_query(mock_db, message_file=FakeMessageFile()) + task = TraceTask(trace_type=TraceTaskName.TOOL_TRACE, message_id="msg-id") + timer = {"start": 1, "end": 5} + result = task.tool_trace("msg-id", timer, tool_name="tool-a", tool_inputs={"foo": 1}, tool_outputs="result") + assert result.tool_name == "tool-a" + assert result.time_cost == 5 + + +def test_trace_task_generate_name_trace(): + task = TraceTask(trace_type=TraceTaskName.GENERATE_NAME_TRACE, conversation_id="conv-id") + timer = {"start": 1, "end": 2} + assert task.generate_name_trace("conv-id", timer, tenant_id=None) == {} + result = task.generate_name_trace( + "conv-id", timer, tenant_id="tenant", generate_conversation_name="name", inputs="q" + ) + assert result.outputs == "name" + assert result.tenant_id == "tenant" + + +def test_extract_streaming_metrics_invalid_json(): + task = TraceTask(trace_type=TraceTaskName.MESSAGE_TRACE, message_id="msg-id") + fake_message = make_message_data(message_metadata="invalid") + assert task._extract_streaming_metrics(fake_message) == {} + + +def test_trace_queue_manager_add_and_collect(monkeypatch): + monkeypatch.setattr( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True) + ) + manager = TraceQueueManager(app_id="app-id", user_id="user") + task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE) + manager.add_trace_task(task) + tasks = manager.collect_tasks() + assert tasks == [task] + + +def test_trace_queue_manager_run_invokes_send(monkeypatch): + monkeypatch.setattr( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True) + ) + manager = TraceQueueManager(app_id="app-id", user_id="user") + task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE) + called = {} + + def fake_collect(): + return [task] + + def fake_send(tasks): + called["tasks"] = tasks + + monkeypatch.setattr(TraceQueueManager, "collect_tasks", lambda self: fake_collect()) + monkeypatch.setattr(TraceQueueManager, "send_to_celery", lambda self, t: fake_send(t)) + manager.run() + assert called["tasks"] == [task] + + +def test_trace_queue_manager_send_to_celery(monkeypatch): + monkeypatch.setattr( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True) + ) + storage_save = MagicMock() + process_delay = MagicMock() + monkeypatch.setattr("core.ops.ops_trace_manager.storage.save", storage_save) + monkeypatch.setattr("core.ops.ops_trace_manager.process_trace_tasks.delay", process_delay) + monkeypatch.setattr("core.ops.ops_trace_manager.uuid4", MagicMock(return_value=SimpleNamespace(hex="file-123"))) + + manager = TraceQueueManager(app_id="app-id", user_id="user") + + class DummyTraceInfo: + def model_dump(self): + return {"trace": "info"} + + class DummyTask: + def __init__(self): + self.app_id = "app-id" + + def execute(self): + return DummyTraceInfo() + + task = DummyTask() + manager.send_to_celery([task]) + storage_save.assert_called_once() + process_delay.assert_called_once_with({"file_id": "file-123", "app_id": "app-id"}) diff --git a/api/tests/unit_tests/core/ops/test_utils.py b/api/tests/unit_tests/core/ops/test_utils.py index e1084001b7..8a89422782 100644 --- a/api/tests/unit_tests/core/ops/test_utils.py +++ b/api/tests/unit_tests/core/ops/test_utils.py @@ -1,9 +1,20 @@ import re from datetime import datetime +from unittest.mock import MagicMock, patch import pytest -from core.ops.utils import generate_dotted_order, validate_project_name, validate_url, validate_url_with_path +from core.ops.utils import ( + filter_none_values, + generate_dotted_order, + get_message_data, + measure_time, + replace_text_with_content, + validate_integer_id, + validate_project_name, + validate_url, + validate_url_with_path, +) class TestValidateUrl: @@ -187,3 +198,92 @@ class TestGenerateDottedOrder: result = generate_dotted_order(run_id, start_time, None) assert "." not in result + + def test_dotted_order_with_string_start_time(self): + """Test dotted_order generation with string start_time.""" + start_time = "2025-12-23T04:19:55.111000" + run_id = "test-run-id" + result = generate_dotted_order(run_id, start_time) + + assert result == "20251223T041955111000Ztest-run-id" + + +class TestFilterNoneValues: + """Test cases for filter_none_values function""" + + def test_filter_none_values(self): + data = {"a": 1, "b": None, "c": "test", "d": datetime(2025, 1, 1, 12, 0, 0)} + result = filter_none_values(data) + assert result == {"a": 1, "c": "test", "d": "2025-01-01T12:00:00"} + + def test_filter_none_values_empty(self): + assert filter_none_values({}) == {} + + +class TestGetMessageData: + """Test cases for get_message_data function""" + + @patch("core.ops.utils.db") + @patch("core.ops.utils.Message") + @patch("core.ops.utils.select") + def test_get_message_data(self, mock_select, mock_message, mock_db): + mock_scalar = mock_db.session.scalar + mock_msg_instance = MagicMock() + mock_scalar.return_value = mock_msg_instance + + result = get_message_data("message-id") + + assert result == mock_msg_instance + mock_select.assert_called_once() + mock_scalar.assert_called_once() + + +class TestMeasureTime: + """Test cases for measure_time function""" + + def test_measure_time(self): + with measure_time() as timing_info: + assert "start" in timing_info + assert isinstance(timing_info["start"], datetime) + assert timing_info["end"] is None + + assert timing_info["end"] is not None + assert isinstance(timing_info["end"], datetime) + assert timing_info["end"] >= timing_info["start"] + + +class TestReplaceTextWithContent: + """Test cases for replace_text_with_content function""" + + def test_replace_text_with_content_dict(self): + data = {"text": "hello", "other": "world"} + assert replace_text_with_content(data) == {"content": "hello", "other": "world"} + + def test_replace_text_with_content_nested(self): + data = {"text": "v1", "nested": {"text": "v2", "list": [{"text": "v3"}]}} + expected = {"content": "v1", "nested": {"content": "v2", "list": [{"content": "v3"}]}} + assert replace_text_with_content(data) == expected + + def test_replace_text_with_content_list(self): + data = [{"text": "v1"}, "v2"] + assert replace_text_with_content(data) == [{"content": "v1"}, "v2"] + + def test_replace_text_with_content_primitive(self): + assert replace_text_with_content(123) == 123 + assert replace_text_with_content("text") == "text" + + +class TestValidateIntegerId: + """Test cases for validate_integer_id function""" + + def test_valid_integer_id(self): + assert validate_integer_id("123") == "123" + assert validate_integer_id(" 456 ") == "456" + + def test_invalid_integer_id_raises_error(self): + with pytest.raises(ValueError, match="ID must be a valid integer"): + validate_integer_id("abc") + + def test_empty_integer_id_raises_error(self): + with pytest.raises(ValueError, match="ID must be a valid integer"): + validate_integer_id("") diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py new file mode 100644 index 0000000000..cdd97d5369 --- /dev/null +++ b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -0,0 +1,1196 @@ +"""Comprehensive tests for core.ops.weave_trace.weave_trace module.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from weave.trace_server.trace_server_interface import TraceStatus + +from core.ops.entities.config_entity import WeaveConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + TraceTaskName, + WorkflowTraceInfo, +) +from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel +from core.ops.weave_trace.weave_trace import WeaveDataTrace +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +def _make_weave_config(**overrides) -> WeaveConfig: + defaults = { + "api_key": "wv-api-key", + "project": "my-project", + "entity": "my-entity", + "host": None, + } + defaults.update(overrides) + return WeaveConfig(**defaults) + + +def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo: + defaults = { + "workflow_id": "wf-id", + "tenant_id": "tenant-1", + "workflow_run_id": "run-1", + "workflow_run_elapsed_time": 1.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"key": "val"}, + "workflow_run_outputs": {"answer": "42"}, + "workflow_run_version": "v1", + "total_tokens": 10, + "file_list": [], + "query": "hello", + "metadata": {"user_id": "u1", "app_id": "app-1"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(overrides) + return WorkflowTraceInfo(**defaults) + + +def _make_message_trace_info(**overrides) -> MessageTraceInfo: + msg_data = MagicMock() + msg_data.id = "msg-1" + msg_data.from_account_id = "acc-1" + msg_data.from_end_user_id = None + defaults = { + "conversation_model": "chat", + "message_tokens": 5, + "answer_tokens": 10, + "total_tokens": 15, + "conversation_mode": "chat", + "metadata": {"conversation_id": "c1"}, + "message_id": "msg-1", + "message_data": msg_data, + "inputs": {"prompt": "hi"}, + "outputs": "ok", + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + "error": None, + } + defaults.update(overrides) + return MessageTraceInfo(**defaults) + + +def _make_moderation_trace_info(**overrides) -> ModerationTraceInfo: + defaults = { + "flagged": False, + "action": "allow", + "preset_response": "", + "query": "test", + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + } + defaults.update(overrides) + return ModerationTraceInfo(**defaults) + + +def _make_suggested_question_trace_info(**overrides) -> SuggestedQuestionTraceInfo: + defaults = { + "suggested_question": ["q1", "q2"], + "level": "info", + "total_tokens": 5, + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "message_data": SimpleNamespace(created_at=_dt(), updated_at=_dt()), + "inputs": {"i": 1}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + "error": None, + } + defaults.update(overrides) + return SuggestedQuestionTraceInfo(**defaults) + + +def _make_dataset_retrieval_trace_info(**overrides) -> DatasetRetrievalTraceInfo: + msg_data = MagicMock() + msg_data.created_at = _dt() + msg_data.updated_at = _dt() + defaults = { + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "message_data": msg_data, + "inputs": "query", + "documents": [{"content": "doc"}], + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(overrides) + return DatasetRetrievalTraceInfo(**defaults) + + +def _make_tool_trace_info(**overrides) -> ToolTraceInfo: + defaults = { + "tool_name": "my_tool", + "tool_inputs": {"x": 1}, + "tool_outputs": "output", + "tool_config": {"desc": "d"}, + "tool_parameters": {"p": "v"}, + "time_cost": 0.5, + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "inputs": {"i": "v"}, + "outputs": {"o": "v"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + "error": None, + } + defaults.update(overrides) + return ToolTraceInfo(**defaults) + + +def _make_generate_name_trace_info(**overrides) -> GenerateNameTraceInfo: + defaults = { + "tenant_id": "t1", + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "inputs": {"i": 1}, + "outputs": {"name": "test"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(overrides) + return GenerateNameTraceInfo(**defaults) + + +def _make_node(**overrides): + """Create a mock workflow node execution object.""" + defaults = { + "id": "node-1", + "title": "Node Title", + "node_type": NodeType.CODE, + "status": "succeeded", + "inputs": {"key": "value"}, + "outputs": {"result": "ok"}, + "created_at": _dt(), + "elapsed_time": 1.0, + "process_data": None, + "metadata": {}, + } + defaults.update(overrides) + return SimpleNamespace(**defaults) + + +# ── Fixtures ───────────────────────────────────────────────────────────────── + + +@pytest.fixture +def mock_wandb(): + with patch("core.ops.weave_trace.weave_trace.wandb") as mock: + mock.login.return_value = True + yield mock + + +@pytest.fixture +def mock_weave(): + with patch("core.ops.weave_trace.weave_trace.weave") as mock: + client = MagicMock() + client.entity = "my-entity" + client.project = "my-project" + mock.init.return_value = client + yield mock, client + + +@pytest.fixture +def trace_instance(mock_wandb, mock_weave): + """Create a WeaveDataTrace instance with mocked wandb/weave.""" + _, weave_client = mock_weave + config = _make_weave_config() + instance = WeaveDataTrace(config) + return instance + + +@pytest.fixture +def trace_instance_with_host(mock_wandb, mock_weave): + """Create a WeaveDataTrace instance with host configured.""" + _, weave_client = mock_weave + config = _make_weave_config(host="https://my.wandb.host") + instance = WeaveDataTrace(config) + return instance + + +# ── TestInit ───────────────────────────────────────────────────────────────── + + +class TestInit: + def test_init_without_host(self, mock_wandb, mock_weave): + """Test __init__ calls wandb.login without host.""" + mock_w, weave_client = mock_weave + config = _make_weave_config(host=None) + instance = WeaveDataTrace(config) + + mock_wandb.login.assert_called_once_with(key="wv-api-key", verify=True, relogin=True) + mock_w.init.assert_called_once_with(project_name="my-entity/my-project") + assert instance.weave_api_key == "wv-api-key" + assert instance.project_name == "my-project" + assert instance.entity == "my-entity" + assert instance.calls == {} + + def test_init_with_host(self, mock_wandb, mock_weave): + """Test __init__ calls wandb.login with host.""" + config = _make_weave_config(host="https://my.wandb.host") + instance = WeaveDataTrace(config) + + mock_wandb.login.assert_called_once_with( + key="wv-api-key", verify=True, relogin=True, host="https://my.wandb.host" + ) + assert instance.host == "https://my.wandb.host" + + def test_init_without_entity(self, mock_wandb, mock_weave): + """Test __init__ initializes weave without entity prefix when entity is None.""" + mock_w, weave_client = mock_weave + config = _make_weave_config(entity=None) + instance = WeaveDataTrace(config) + + mock_w.init.assert_called_once_with(project_name="my-project") + + def test_init_login_failure_raises(self, mock_wandb, mock_weave): + """Test __init__ raises ValueError when wandb.login returns False.""" + mock_wandb.login.return_value = False + config = _make_weave_config() + + with pytest.raises(ValueError, match="Weave login failed"): + WeaveDataTrace(config) + + def test_init_files_url_from_env(self, mock_wandb, mock_weave, monkeypatch): + """Test FILES_URL is read from environment.""" + monkeypatch.setenv("FILES_URL", "http://files.example.com") + config = _make_weave_config() + instance = WeaveDataTrace(config) + assert instance.file_base_url == "http://files.example.com" + + def test_init_files_url_default(self, mock_wandb, mock_weave, monkeypatch): + """Test FILES_URL defaults to http://127.0.0.1:5001.""" + monkeypatch.delenv("FILES_URL", raising=False) + config = _make_weave_config() + instance = WeaveDataTrace(config) + assert instance.file_base_url == "http://127.0.0.1:5001" + + def test_project_id_set_correctly(self, trace_instance): + """Test that project_id is set from weave_client entity/project.""" + assert trace_instance.project_id == "my-entity/my-project" + + +# ── TestGetProjectUrl ───────────────────────────────────────────────────────── + + +class TestGetProjectUrl: + def test_get_project_url_with_entity(self, trace_instance): + """Returns wandb URL with entity/project.""" + url = trace_instance.get_project_url() + assert url == "https://wandb.ai/my-entity/my-project" + + def test_get_project_url_without_entity(self, mock_wandb, mock_weave): + """Returns wandb URL with project only when entity is None.""" + config = _make_weave_config(entity=None) + instance = WeaveDataTrace(config) + url = instance.get_project_url() + assert url == "https://wandb.ai/my-project" + + def test_get_project_url_exception_raises(self, trace_instance, monkeypatch): + """Raises ValueError when exception occurs in get_project_url.""" + monkeypatch.setattr(trace_instance, "entity", None) + monkeypatch.setattr(trace_instance, "project_name", None) + # Force an error by making string formatting fail + with patch("core.ops.weave_trace.weave_trace.logger") as mock_logger: + # Simulate exception via property + original_entity = trace_instance.entity + trace_instance.entity = None + trace_instance.project_name = None + url = trace_instance.get_project_url() + assert "https://wandb.ai/" in url + + +# ── TestTraceDispatcher ───────────────────────────────────────────────────── + + +class TestTraceDispatcher: + def test_dispatches_workflow_trace(self, trace_instance): + with patch.object(trace_instance, "workflow_trace") as mock_wt: + trace_instance.trace(_make_workflow_trace_info()) + mock_wt.assert_called_once() + + def test_dispatches_message_trace(self, trace_instance): + with patch.object(trace_instance, "message_trace") as mock_mt: + trace_instance.trace(_make_message_trace_info()) + mock_mt.assert_called_once() + + def test_dispatches_moderation_trace(self, trace_instance): + with patch.object(trace_instance, "moderation_trace") as mock_mod: + msg_data = MagicMock() + msg_data.created_at = _dt() + trace_instance.trace(_make_moderation_trace_info(message_data=msg_data)) + mock_mod.assert_called_once() + + def test_dispatches_suggested_question_trace(self, trace_instance): + with patch.object(trace_instance, "suggested_question_trace") as mock_sq: + trace_instance.trace(_make_suggested_question_trace_info()) + mock_sq.assert_called_once() + + def test_dispatches_dataset_retrieval_trace(self, trace_instance): + with patch.object(trace_instance, "dataset_retrieval_trace") as mock_dr: + trace_instance.trace(_make_dataset_retrieval_trace_info()) + mock_dr.assert_called_once() + + def test_dispatches_tool_trace(self, trace_instance): + with patch.object(trace_instance, "tool_trace") as mock_tool: + trace_instance.trace(_make_tool_trace_info()) + mock_tool.assert_called_once() + + def test_dispatches_generate_name_trace(self, trace_instance): + with patch.object(trace_instance, "generate_name_trace") as mock_gn: + trace_instance.trace(_make_generate_name_trace_info()) + mock_gn.assert_called_once() + + +# ── TestNormalizeTime ───────────────────────────────────────────────────────── + + +class TestNormalizeTime: + def test_none_returns_utc_now(self, trace_instance): + now_before = datetime.now(UTC) + result = trace_instance._normalize_time(None) + now_after = datetime.now(UTC) + assert result.tzinfo is not None + assert now_before <= result <= now_after + + def test_naive_datetime_gets_utc(self, trace_instance): + naive = datetime(2024, 6, 15, 12, 0, 0) + result = trace_instance._normalize_time(naive) + assert result.tzinfo == UTC + assert result.year == 2024 + assert result.month == 6 + + def test_aware_datetime_unchanged(self, trace_instance): + aware = datetime(2024, 6, 15, 12, 0, 0, tzinfo=UTC) + result = trace_instance._normalize_time(aware) + assert result == aware + assert result.tzinfo == UTC + + +# ── TestStartCall ───────────────────────────────────────────────────────────── + + +class TestStartCall: + def test_start_call_basic(self, trace_instance): + """Test basic start_call stores call metadata.""" + run = WeaveTraceModel( + id="run-1", + op="test-op", + inputs={"key": "val"}, + attributes={"trace_id": "t-1", "start_time": _dt()}, + ) + trace_instance.start_call(run) + + assert "run-1" in trace_instance.calls + assert trace_instance.calls["run-1"]["trace_id"] == "t-1" + assert trace_instance.calls["run-1"]["parent_id"] is None + trace_instance.weave_client.server.call_start.assert_called_once() + + def test_start_call_with_parent(self, trace_instance): + """Test start_call records parent_run_id.""" + run = WeaveTraceModel( + id="child-1", + op="child-op", + inputs={}, + attributes={"trace_id": "t-1", "start_time": _dt()}, + ) + trace_instance.start_call(run, parent_run_id="parent-1") + + assert trace_instance.calls["child-1"]["parent_id"] == "parent-1" + + def test_start_call_none_inputs_becomes_empty_dict(self, trace_instance): + """Test that None inputs is normalized to {}.""" + run = WeaveTraceModel( + id="run-2", + op="op", + inputs=None, + attributes={"trace_id": "t-2", "start_time": _dt()}, + ) + trace_instance.start_call(run) + call_args = trace_instance.weave_client.server.call_start.call_args + req = call_args[0][0] + assert req.start.inputs == {} + + def test_start_call_non_dict_inputs_becomes_str_dict(self, trace_instance): + """Test that non-dict inputs is wrapped as string.""" + run = WeaveTraceModel( + id="run-3", + op="op", + inputs="some string input", + attributes={"trace_id": "t-3", "start_time": _dt()}, + ) + trace_instance.start_call(run) + call_args = trace_instance.weave_client.server.call_start.call_args + req = call_args[0][0] + # String inputs gets converted by validator to a dict + assert isinstance(req.start.inputs, dict) + + def test_start_call_none_attributes_becomes_empty_dict(self, trace_instance): + """Test that None attributes is handled properly.""" + run = WeaveTraceModel( + id="run-4", + op="op", + inputs={}, + attributes=None, + ) + trace_instance.start_call(run) + # trace_id should fall back to run_data.id + assert trace_instance.calls["run-4"]["trace_id"] == "run-4" + + def test_start_call_non_dict_attributes_becomes_dict(self, trace_instance): + """Test that non-dict attributes is wrapped.""" + run = WeaveTraceModel( + id="run-5", + op="op", + inputs={}, + attributes=None, + ) + # Manually override after construction + run.attributes = "some-attr-string" + trace_instance.start_call(run) + call_args = trace_instance.weave_client.server.call_start.call_args + req = call_args[0][0] + assert isinstance(req.start.attributes, dict) + assert req.start.attributes == {"attributes": "some-attr-string"} + + def test_start_call_trace_id_falls_back_to_run_id(self, trace_instance): + """When trace_id not in attributes, falls back to run_data.id.""" + run = WeaveTraceModel( + id="run-6", + op="op", + inputs={}, + attributes={"start_time": _dt()}, + ) + trace_instance.start_call(run) + assert trace_instance.calls["run-6"]["trace_id"] == "run-6" + + +# ── TestFinishCall ────────────────────────────────────────────────────────── + + +class TestFinishCall: + def _setup_call(self, trace_instance, run_id="run-1", trace_id="t-1"): + """Helper: register a call so finish_call can find it.""" + trace_instance.calls[run_id] = {"trace_id": trace_id, "parent_id": None} + + def test_finish_call_success(self, trace_instance): + """Test finish_call sends call_end with SUCCESS status.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + outputs={"result": "ok"}, + attributes={"start_time": _dt(), "end_time": _dt() + timedelta(seconds=1)}, + exception=None, + ) + trace_instance.finish_call(run) + trace_instance.weave_client.server.call_end.assert_called_once() + call_args = trace_instance.weave_client.server.call_end.call_args + req = call_args[0][0] + assert req.end.summary["status_counts"][TraceStatus.SUCCESS] == 1 + assert req.end.summary["status_counts"][TraceStatus.ERROR] == 0 + assert req.end.exception is None + + def test_finish_call_with_error(self, trace_instance): + """Test finish_call sends call_end with ERROR status when exception is set.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + outputs={}, + attributes={"start_time": _dt(), "end_time": _dt() + timedelta(seconds=1)}, + exception="Something broke", + ) + trace_instance.finish_call(run) + call_args = trace_instance.weave_client.server.call_end.call_args + req = call_args[0][0] + assert req.end.summary["status_counts"][TraceStatus.ERROR] == 1 + assert req.end.summary["status_counts"][TraceStatus.SUCCESS] == 0 + assert req.end.exception == "Something broke" + + def test_finish_call_missing_id_raises(self, trace_instance): + """Test finish_call raises ValueError when call id not found.""" + run = WeaveTraceModel( + id="nonexistent", + op="op", + inputs={}, + ) + with pytest.raises(ValueError, match="Call with id nonexistent not found"): + trace_instance.finish_call(run) + + def test_finish_call_elapsed_negative_clamped_to_zero(self, trace_instance): + """Test that negative elapsed time is clamped to 0.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + attributes={ + "start_time": _dt() + timedelta(seconds=5), + "end_time": _dt(), # end before start + }, + ) + trace_instance.finish_call(run) + call_args = trace_instance.weave_client.server.call_end.call_args + req = call_args[0][0] + assert req.end.summary["weave"]["latency_ms"] == 0 + + def test_finish_call_none_attributes(self, trace_instance): + """Test finish_call handles None attributes.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + attributes=None, + ) + trace_instance.finish_call(run) + trace_instance.weave_client.server.call_end.assert_called_once() + + def test_finish_call_non_dict_attributes(self, trace_instance): + """Test finish_call handles non-dict attributes.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + attributes=None, + ) + run.attributes = "some string attr" + trace_instance.finish_call(run) + trace_instance.weave_client.server.call_end.assert_called_once() + + +# ── TestWorkflowTrace ───────────────────────────────────────────────────────── + + +class TestWorkflowTrace: + def _setup_repo(self, monkeypatch, nodes=None): + """Helper to patch session/repo dependencies.""" + if nodes is None: + nodes = [] + + repo = MagicMock() + repo.get_by_workflow_run.return_value = nodes + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + + monkeypatch.setattr("core.ops.weave_trace.weave_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + return repo + + def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch): + """Workflow trace with no nodes and no message_id.""" + self._setup_repo(monkeypatch, nodes=[]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # Only workflow run: start_call and finish_call each called once + assert trace_instance.start_call.call_count == 1 + assert trace_instance.finish_call.call_count == 1 + + def test_workflow_trace_with_message_id(self, trace_instance, monkeypatch): + """Workflow trace with message_id creates both message and workflow runs.""" + self._setup_repo(monkeypatch, nodes=[]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id="msg-1") + trace_instance.workflow_trace(trace_info) + + # message run + workflow run = 2 start_call / finish_call + assert trace_instance.start_call.call_count == 2 + assert trace_instance.finish_call.call_count == 2 + + def test_workflow_trace_with_node_execution(self, trace_instance, monkeypatch): + """Workflow trace iterates node executions and creates node runs.""" + node = _make_node( + id="node-1", + node_type=NodeType.CODE, + inputs={"k": "v"}, + outputs={"r": "ok"}, + elapsed_time=0.5, + created_at=_dt(), + metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 5}, + ) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # workflow run + node run = 2 calls + assert trace_instance.start_call.call_count == 2 + + def test_workflow_trace_with_llm_node(self, trace_instance, monkeypatch): + """LLM node uses process_data prompts as inputs.""" + node = _make_node( + node_type=NodeType.LLM, + process_data={ + "prompts": [{"role": "user", "content": "hi"}], + "model_mode": "chat", + "model_provider": "openai", + "model_name": "gpt-4", + }, + inputs={"key": "val"}, + ) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # Check node start_call was called with prompts input + node_call_args = trace_instance.start_call.call_args_list[-1] + node_run = node_call_args[0][0] + # WeaveTraceModel validator wraps list prompts into {"messages": [...]} + # The key "messages" should be present (validator transforms the list) + assert "messages" in node_run.inputs + + def test_workflow_trace_with_non_llm_node_uses_inputs(self, trace_instance, monkeypatch): + """Non-LLM node uses node_execution.inputs directly.""" + node = _make_node( + node_type=NodeType.TOOL, + inputs={"tool_input": "val"}, + process_data=None, + ) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # node run inputs should be from node.inputs; validator adds usage_metadata + file_list + node_call_args = trace_instance.start_call.call_args_list[-1] + node_run = node_call_args[0][0] + assert node_run.inputs.get("tool_input") == "val" + + def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch): + """Raises ValueError when app_id is missing from metadata.""" + monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + + trace_info = _make_workflow_trace_info( + message_id=None, + metadata={"user_id": "u1"}, # no app_id + ) + + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(trace_info) + + def test_workflow_trace_start_time_none_defaults_to_now(self, trace_instance, monkeypatch): + """start_time defaults to datetime.now() when None.""" + self._setup_repo(monkeypatch, nodes=[]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None, start_time=None) + trace_instance.workflow_trace(trace_info) + + assert trace_instance.start_call.call_count == 1 + + def test_workflow_trace_node_created_at_none(self, trace_instance, monkeypatch): + """Node with created_at=None uses datetime.now().""" + node = _make_node(created_at=None, elapsed_time=0.5) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + assert trace_instance.start_call.call_count == 2 + + def test_workflow_trace_chat_mode_llm_node_adds_provider(self, trace_instance, monkeypatch): + """Chat mode LLM node adds ls_provider and ls_model_name to attributes.""" + node = _make_node( + node_type=NodeType.LLM, + process_data={"model_mode": "chat", "model_provider": "openai", "model_name": "gpt-4", "prompts": []}, + ) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + start_calls = [] + + def capture_start(run, parent_run_id=None): + start_calls.append((run, parent_run_id)) + + trace_instance.start_call = capture_start + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # Last start call is the node run + node_run, _ = start_calls[-1] + assert node_run.attributes.get("ls_provider") == "openai" + assert node_run.attributes.get("ls_model_name") == "gpt-4" + + def test_workflow_trace_nodes_sorted_by_created_at(self, trace_instance, monkeypatch): + """Nodes are sorted by created_at before processing.""" + node1 = _make_node(id="node-b", created_at=_dt() + timedelta(seconds=2)) + node2 = _make_node(id="node-a", created_at=_dt()) + self._setup_repo(monkeypatch, nodes=[node1, node2]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + processed_ids = [] + + def capture_start(run, parent_run_id=None): + processed_ids.append(run.id) + + trace_instance.start_call = capture_start + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # First call = workflow run, then node-a, then node-b + assert processed_ids[1] == "node-a" + assert processed_ids[2] == "node-b" + + +# ── TestMessageTrace ────────────────────────────────────────────────────────── + + +class TestMessageTrace: + def test_returns_early_when_no_message_data(self, trace_instance): + """message_trace returns early when message_data is None.""" + trace_info = _make_message_trace_info(message_data=None) + trace_instance.start_call = MagicMock() + trace_instance.message_trace(trace_info) + trace_instance.start_call.assert_not_called() + + def test_basic_message_trace(self, trace_instance, monkeypatch): + """message_trace creates message run and llm child run.""" + monkeypatch.setattr( + "core.ops.weave_trace.weave_trace.db.session.query", + lambda model: MagicMock(where=lambda: MagicMock(first=lambda: None)), + ) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_message_trace_info() + trace_instance.message_trace(trace_info) + + # message run + llm child run + assert trace_instance.start_call.call_count == 2 + assert trace_instance.finish_call.call_count == 2 + + def test_message_trace_with_file_data(self, trace_instance, monkeypatch): + """message_trace appends file URL to file_list.""" + file_data = MagicMock() + file_data.url = "path/to/file.png" + trace_instance.file_base_url = "http://files.test" + + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_message_trace_info( + message_file_data=file_data, + file_list=["existing.txt"], + ) + trace_instance.message_trace(trace_info) + + # The first start_call arg (the message run) should have file in outputs or inputs + message_run = trace_instance.start_call.call_args_list[0][0][0] + assert "http://files.test/path/to/file.png" in message_run.file_list + + def test_message_trace_with_end_user(self, trace_instance, monkeypatch): + """message_trace looks up end user and sets end_user_id attribute.""" + end_user = MagicMock() + end_user.session_id = "session-xyz" + + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = end_user + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + msg_data = MagicMock() + msg_data.id = "msg-1" + msg_data.from_account_id = "acc-1" + msg_data.from_end_user_id = "eu-1" + + trace_info = _make_message_trace_info(message_data=msg_data) + trace_instance.message_trace(trace_info) + + message_run = trace_instance.start_call.call_args_list[0][0][0] + assert message_run.attributes.get("end_user_id") == "session-xyz" + + def test_message_trace_no_end_user(self, trace_instance, monkeypatch): + """message_trace handles when from_end_user_id is None.""" + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + msg_data = MagicMock() + msg_data.id = "msg-1" + msg_data.from_account_id = "acc-1" + msg_data.from_end_user_id = None + + trace_info = _make_message_trace_info(message_data=msg_data) + trace_instance.message_trace(trace_info) + assert trace_instance.start_call.call_count == 2 + + def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch): + """trace_id falls back to message_id when trace_id is None.""" + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_message_trace_info(trace_id=None) + trace_instance.message_trace(trace_info) + + message_run = trace_instance.start_call.call_args_list[0][0][0] + assert message_run.id == "msg-1" + + def test_message_trace_file_list_none(self, trace_instance, monkeypatch): + """message_trace handles file_list=None gracefully.""" + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_message_trace_info(file_list=None, message_file_data=None) + trace_instance.message_trace(trace_info) + assert trace_instance.start_call.call_count == 2 + + +# ── TestModerationTrace ─────────────────────────────────────────────────────── + + +class TestModerationTrace: + def test_returns_early_when_no_message_data(self, trace_instance): + """moderation_trace returns early when message_data is None.""" + trace_info = _make_moderation_trace_info(message_data=None) + trace_instance.start_call = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.start_call.assert_not_called() + + def test_basic_moderation_trace(self, trace_instance): + """moderation_trace creates a run with correct outputs.""" + msg_data = MagicMock() + msg_data.created_at = _dt() + msg_data.updated_at = _dt() + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_moderation_trace_info( + message_data=msg_data, + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + action="block", + flagged=True, + preset_response="blocked", + ) + trace_instance.moderation_trace(trace_info) + + trace_instance.start_call.assert_called_once() + trace_instance.finish_call.assert_called_once() + + run = trace_instance.start_call.call_args[0][0] + assert run.outputs["action"] == "block" + assert run.outputs["flagged"] is True + + def test_moderation_trace_with_no_times_uses_message_data_times(self, trace_instance): + """When start/end times are None, uses message_data created_at/updated_at.""" + msg_data = MagicMock() + msg_data.created_at = _dt() + msg_data.updated_at = _dt() + timedelta(seconds=1) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_moderation_trace_info( + message_data=msg_data, + start_time=None, + end_time=None, + ) + trace_instance.moderation_trace(trace_info) + trace_instance.start_call.assert_called_once() + + def test_moderation_trace_trace_id_fallback(self, trace_instance): + """trace_id falls back to message_id when trace_id is None.""" + msg_data = MagicMock() + msg_data.created_at = _dt() + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_moderation_trace_info( + message_data=msg_data, + trace_id=None, + ) + trace_instance.moderation_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "msg-1" + + +# ── TestSuggestedQuestionTrace ──────────────────────────────────────────────── + + +class TestSuggestedQuestionTrace: + def test_returns_early_when_no_message_data(self, trace_instance): + """suggested_question_trace returns early when message_data is None.""" + trace_info = _make_suggested_question_trace_info(message_data=None) + trace_instance.start_call = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.start_call.assert_not_called() + + def test_basic_suggested_question_trace(self, trace_instance): + """suggested_question_trace creates a run parented to trace_id.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_suggested_question_trace_info(trace_id="t-1") + trace_instance.suggested_question_trace(trace_info) + + trace_instance.start_call.assert_called_once() + trace_instance.finish_call.assert_called_once() + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "t-1" + + def test_suggested_question_trace_trace_id_fallback(self, trace_instance): + """trace_id falls back to message_id when trace_id is None.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_suggested_question_trace_info(trace_id=None) + trace_instance.suggested_question_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "msg-1" + + +# ── TestDatasetRetrievalTrace ───────────────────────────────────────────────── + + +class TestDatasetRetrievalTrace: + def test_returns_early_when_no_message_data(self, trace_instance): + """dataset_retrieval_trace returns early when message_data is None.""" + trace_info = _make_dataset_retrieval_trace_info(message_data=None) + trace_instance.start_call = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.start_call.assert_not_called() + + def test_basic_dataset_retrieval_trace(self, trace_instance): + """dataset_retrieval_trace creates a run with documents as outputs.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_dataset_retrieval_trace_info( + documents=[{"id": "d1"}, {"id": "d2"}], + trace_id="t-1", + ) + trace_instance.dataset_retrieval_trace(trace_info) + + run = trace_instance.start_call.call_args[0][0] + # WeaveTraceModel validator injects usage_metadata/file_list into dict outputs + assert run.outputs.get("documents") == [{"id": "d1"}, {"id": "d2"}] + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "t-1" + + def test_dataset_retrieval_trace_trace_id_fallback(self, trace_instance): + """trace_id falls back to message_id when trace_id is None.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_dataset_retrieval_trace_info(trace_id=None) + trace_instance.dataset_retrieval_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "msg-1" + + +# ── TestToolTrace ───────────────────────────────────────────────────────────── + + +class TestToolTrace: + def test_basic_tool_trace(self, trace_instance): + """tool_trace creates a run with correct op as tool_name.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(trace_id="t-1") + trace_instance.tool_trace(trace_info) + + run = trace_instance.start_call.call_args[0][0] + assert run.op == "my_tool" + # WeaveTraceModel validator injects usage_metadata/file_list into dict inputs + assert run.inputs.get("x") == 1 + + def test_tool_trace_with_file_url(self, trace_instance): + """tool_trace adds file_url to file_list when provided.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(file_url="http://files/file.pdf") + trace_instance.tool_trace(trace_info) + + run = trace_instance.start_call.call_args[0][0] + assert "http://files/file.pdf" in run.file_list + + def test_tool_trace_without_file_url(self, trace_instance): + """tool_trace uses empty file_list when file_url is None.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(file_url=None) + trace_instance.tool_trace(trace_info) + + run = trace_instance.start_call.call_args[0][0] + assert run.file_list == [] + + def test_tool_trace_trace_id_from_message_id(self, trace_instance): + """trace_id uses message_id fallback.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(trace_id=None) + trace_instance.tool_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "msg-1" + + def test_tool_trace_message_id_none_uses_conversation_id(self, trace_instance): + """When message_id is None, tries conversation_id attribute.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(trace_id=None, message_id=None) + trace_instance.tool_trace(trace_info) + + # No crash; parent_run_id is None since no fallback + _, kwargs = trace_instance.start_call.call_args + # parent_run_id should be None when no message_id and no trace_id + assert kwargs.get("parent_run_id") is None + + +# ── TestGenerateNameTrace ───────────────────────────────────────────────────── + + +class TestGenerateNameTrace: + def test_basic_generate_name_trace(self, trace_instance): + """generate_name_trace creates a run with correct op.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_generate_name_trace_info() + trace_instance.generate_name_trace(trace_info) + + trace_instance.start_call.assert_called_once() + trace_instance.finish_call.assert_called_once() + + run = trace_instance.start_call.call_args[0][0] + assert run.op == str(TraceTaskName.GENERATE_NAME_TRACE) + + def test_generate_name_trace_no_parent(self, trace_instance): + """generate_name_trace has no parent run (no parent_run_id).""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_generate_name_trace_info() + trace_instance.generate_name_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + # No parent_run_id passed to generate_name start_call + assert kwargs == {} or kwargs.get("parent_run_id") is None + + +# ── TestApiCheck ────────────────────────────────────────────────────────────── + + +class TestApiCheck: + def test_api_check_success_without_host(self, trace_instance, mock_wandb): + """api_check returns True on successful login without host.""" + trace_instance.host = None + mock_wandb.login.return_value = True + + result = trace_instance.api_check() + + assert result is True + mock_wandb.login.assert_called_with(key=trace_instance.weave_api_key, verify=True, relogin=True) + + def test_api_check_success_with_host(self, trace_instance, mock_wandb): + """api_check returns True on successful login with host.""" + trace_instance.host = "https://my.wandb.host" + mock_wandb.login.return_value = True + + result = trace_instance.api_check() + + assert result is True + mock_wandb.login.assert_called_with( + key=trace_instance.weave_api_key, verify=True, relogin=True, host="https://my.wandb.host" + ) + + def test_api_check_login_failure_raises(self, trace_instance, mock_wandb): + """api_check raises ValueError when login returns False.""" + trace_instance.host = None + mock_wandb.login.return_value = False + + with pytest.raises(ValueError, match="Weave API check failed"): + trace_instance.api_check() + + def test_api_check_exception_raises_value_error(self, trace_instance, mock_wandb): + """api_check raises ValueError when wandb.login raises exception.""" + trace_instance.host = None + mock_wandb.login.side_effect = Exception("network error") + + with pytest.raises(ValueError, match="Weave API check failed: network error"): + trace_instance.api_check()