Mirror of https://github.com/langgenius/dify.git (synced 2026-02-15 05:04:02 +00:00)

Compare commits: `refactor/r...yanli/add-` (4 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 688b2cfc48 |  |
|  | 2b787bff72 |  |
|  | 289b59208a |  |
|  | 5c51e0a9ae |  |
In `core.model_runtime.entities.message_entities`:

```diff
@@ -5,7 +5,7 @@ from collections.abc import Mapping, Sequence
 from enum import StrEnum, auto
 from typing import Annotated, Any, Literal, Union
 
-from pydantic import BaseModel, Field, field_serializer, field_validator
+from pydantic import BaseModel, Field, JsonValue, field_serializer, field_validator
 
 
 class PromptMessageRole(StrEnum):
```
```diff
@@ -69,6 +69,7 @@ class PromptMessageContent(ABC, BaseModel):
     """
 
     type: PromptMessageContentType
+    opaque_body: JsonValue | None = None
 
 
 class TextPromptMessageContent(PromptMessageContent):
```
```diff
@@ -244,6 +245,7 @@ class AssistantPromptMessage(PromptMessage):
 
     role: PromptMessageRole = PromptMessageRole.ASSISTANT
     tool_calls: list[ToolCall] = []
+    opaque_body: JsonValue | None = None
 
     def is_empty(self) -> bool:
         """
```
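The new `opaque_body` field is typed as pydantic's `JsonValue`, which admits any JSON-serializable value. A minimal sketch of the round-trip behavior, using a standalone model rather than dify's actual classes:

```python
# Sketch only: `Message` is a hypothetical stand-in for AssistantPromptMessage,
# assuming pydantic v2 (which exports JsonValue).
from pydantic import BaseModel, JsonValue


class Message(BaseModel):
    content: str
    opaque_body: JsonValue | None = None


m = Message(content="hello", opaque_body={"provider_message_id": "msg_123"})

# The field survives serialization untouched, so providers can attach
# metadata the runtime never needs to interpret.
assert Message.model_validate_json(m.model_dump_json()).opaque_body == {"provider_message_id": "msg_123"}
```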
In `core.model_runtime.model_providers.__base.large_language_model`:

```diff
@@ -164,6 +164,7 @@ class LargeLanguageModel(AIModel):
         usage = LLMUsage.empty_usage()
         system_fingerprint = None
         tools_calls: list[AssistantPromptMessage.ToolCall] = []
+        assistant_opaque_body = None
 
         for chunk in result:
             if isinstance(chunk.delta.message.content, str):
```
```diff
@@ -172,6 +173,8 @@ class LargeLanguageModel(AIModel):
                 content_list.extend(chunk.delta.message.content)
             if chunk.delta.message.tool_calls:
                 _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
+            if assistant_opaque_body is None and chunk.delta.message.opaque_body is not None:
+                assistant_opaque_body = chunk.delta.message.opaque_body
 
             usage = chunk.delta.usage or LLMUsage.empty_usage()
             system_fingerprint = chunk.system_fingerprint
```
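The accumulation rule is "first non-None wins": once a chunk supplies an `opaque_body`, later chunks cannot overwrite it. A minimal sketch, with plain dicts standing in for chunk deltas:

```python
# Illustration of the first-non-None policy from the hunk above; the dicts
# are stand-ins for LLMResultChunkDelta.message, not dify objects.
deltas = [
    {"opaque_body": None},
    {"opaque_body": {"provider_message_id": "msg_123"}},
    {"opaque_body": {"provider_message_id": "msg_456"}},  # ignored once one is captured
]

assistant_opaque_body = None
for delta in deltas:
    if assistant_opaque_body is None and delta["opaque_body"] is not None:
        assistant_opaque_body = delta["opaque_body"]

assert assistant_opaque_body == {"provider_message_id": "msg_123"}
```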
```diff
@@ -183,6 +186,7 @@ class LargeLanguageModel(AIModel):
                 message=AssistantPromptMessage(
                     content=content or content_list,
                     tool_calls=tools_calls,
+                    opaque_body=assistant_opaque_body,
                 ),
                 usage=usage,
                 system_fingerprint=system_fingerprint,
```
```diff
@@ -261,6 +265,8 @@ class LargeLanguageModel(AIModel):
         usage = None
         system_fingerprint = None
         real_model = model
+        assistant_opaque_body = None
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
 
         def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None):
             if not content:
```
```diff
@@ -294,6 +300,10 @@ class LargeLanguageModel(AIModel):
                 )
 
                 _update_message_content(chunk.delta.message.content)
+                if chunk.delta.message.tool_calls:
+                    _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
+                if assistant_opaque_body is None and chunk.delta.message.opaque_body is not None:
+                    assistant_opaque_body = chunk.delta.message.opaque_body
 
                 real_model = chunk.model
                 if chunk.delta.usage:
```
```diff
@@ -304,7 +314,11 @@ class LargeLanguageModel(AIModel):
         except Exception as e:
             raise self._transform_invoke_error(e)
 
-        assistant_message = AssistantPromptMessage(content=message_content)
+        assistant_message = AssistantPromptMessage(
+            content=message_content,
+            tool_calls=tools_calls,
+            opaque_body=assistant_opaque_body,
+        )
         self._trigger_after_invoke_callbacks(
             model=model,
             result=LLMResult(
```
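The hunks above funnel tool-call deltas through `_increase_tool_call`, whose implementation is not part of this diff. The sketch below is a hypothetical reimplementation consistent with the streaming test's assertions further down: a delta with an empty `id` continues the previous call, so fragmented JSON arguments are concatenated onto a single call.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for AssistantPromptMessage.ToolCall and its
# ToolCallFunction; the real ones are pydantic models in message_entities.py.
@dataclass
class Function:
    name: str
    arguments: str


@dataclass
class ToolCall:
    id: str
    type: str
    function: Function


def increase_tool_call(new_calls: list[ToolCall], merged: list[ToolCall]) -> None:
    """Sketch: a non-empty id starts a new call; an empty id appends its
    argument fragment to the call currently being streamed."""
    for call in new_calls:
        if call.id:
            merged.append(call)
        elif merged:
            merged[-1].function.arguments += call.function.arguments


merged: list[ToolCall] = []
increase_tool_call([ToolCall("1", "function", Function("func_foo", '{"arg1": '))], merged)
increase_tool_call([ToolCall("", "", Function("", '"value"}'))], merged)
assert len(merged) == 1
assert merged[0].function.arguments == '{"arg1": "value"}'
```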
New test file (`@@ -0,0 +1,145 @@`):

```python
from __future__ import annotations

from datetime import datetime
from typing import Any
from unittest.mock import patch

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity


class _CaptureAfterInvokeCallback(Callback):
    after_result: LLMResult | None

    def __init__(self) -> None:
        self.after_result = None

    def on_before_invoke(self, **kwargs: Any) -> None:  # noqa: ANN401
        return None

    def on_new_chunk(self, **kwargs: Any) -> None:  # noqa: ANN401
        return None

    def on_after_invoke(self, result: LLMResult, **kwargs: Any) -> None:  # noqa: ANN401
        self.after_result = result

    def on_invoke_error(self, **kwargs: Any) -> None:  # noqa: ANN401
        return None


def _build_llm_instance() -> LargeLanguageModel:
    declaration = ProviderEntity(
        provider="test",
        label=I18nObject(en_US="test"),
        supported_model_types=[ModelType.LLM],
        configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
    )
    plugin_model_provider = PluginModelProviderEntity(
        id="pmp_1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        provider="test",
        tenant_id="tenant_1",
        plugin_unique_identifier="test/plugin",
        plugin_id="test/plugin",
        declaration=declaration,
    )

    return LargeLanguageModel(
        tenant_id="tenant_1",
        plugin_id="test/plugin",
        provider_name="test",
        plugin_model_provider=plugin_model_provider,
    )


def test_invoke_non_stream_preserves_assistant_opaque_body() -> None:
    llm = _build_llm_instance()
    prompt_messages: list[PromptMessage] = [UserPromptMessage(content="hi")]

    chunk = LLMResultChunk(
        model="gpt-test",
        prompt_messages=[],
        delta=LLMResultChunkDelta(
            index=0,
            message=AssistantPromptMessage(content="hello", opaque_body={"provider_message_id": "msg_123"}),
        ),
    )

    def _mock_invoke_llm(self, **kwargs: Any):  # noqa: ANN001, ANN401
        yield chunk

    with patch("core.plugin.impl.model.PluginModelClient.invoke_llm", new=_mock_invoke_llm):
        result = llm.invoke(
            model="gpt-test",
            credentials={},
            prompt_messages=prompt_messages,
            model_parameters={},
            stream=False,
        )

    assert isinstance(result, LLMResult)
    assert result.message.opaque_body == {"provider_message_id": "msg_123"}
    assert list(result.prompt_messages) == prompt_messages


def test_invoke_stream_preserves_assistant_opaque_body_in_after_callback() -> None:
    llm = _build_llm_instance()
    prompt_messages: list[PromptMessage] = [UserPromptMessage(content="hi")]
    callback = _CaptureAfterInvokeCallback()

    tool_call_1 = AssistantPromptMessage.ToolCall(
        id="1",
        type="function",
        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": '),
    )
    tool_call_2 = AssistantPromptMessage.ToolCall(
        id="",
        type="",
        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='"value"}'),
    )

    chunk1 = LLMResultChunk(
        model="gpt-test",
        prompt_messages=[],
        delta=LLMResultChunkDelta(
            index=0,
            message=AssistantPromptMessage(content="h", tool_calls=[tool_call_1], opaque_body={"provider_message_id": "msg_123"}),
        ),
    )
    chunk2 = LLMResultChunk(
        model="gpt-test",
        prompt_messages=[],
        delta=LLMResultChunkDelta(
            index=0,
            message=AssistantPromptMessage(content="i", tool_calls=[tool_call_2]),
        ),
    )

    def _mock_invoke_llm(self, **kwargs: Any):  # noqa: ANN001, ANN401
        yield chunk1
        yield chunk2

    with patch("core.plugin.impl.model.PluginModelClient.invoke_llm", new=_mock_invoke_llm):
        gen = llm.invoke(
            model="gpt-test",
            credentials={},
            prompt_messages=prompt_messages,
            model_parameters={},
            stream=True,
            callbacks=[callback],
        )
        chunks = list(gen)

    assert chunks[0].prompt_messages == prompt_messages
    assert callback.after_result is not None
    assert callback.after_result.message.opaque_body == {"provider_message_id": "msg_123"}
    assert len(callback.after_result.message.tool_calls) == 1
    assert callback.after_result.message.tool_calls[0].function.arguments == '{"arg1": "value"}'
```
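One mocking detail worth noting: `patch(..., new=_mock_invoke_llm)` swaps a plain function in as the unbound method, so `self` arrives as the first positional argument and the generator it returns feeds both the stream and non-stream paths. A self-contained sketch of the same pattern (the class and names here are illustrative only, not from dify):

```python
from unittest.mock import patch


class Client:
    def fetch(self):
        return "real"


def _fake_fetch(self):  # plain function standing in for the bound method
    return "mocked"


with patch.object(Client, "fetch", new=_fake_fetch):
    assert Client().fetch() == "mocked"

# Outside the context manager the original method is restored.
assert Client().fetch() == "real"
```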