Compare commits


4 Commits

Author      SHA1        Date                        Message
Yanli 盐粒  688b2cfc48  2026-01-27 01:09:33 +08:00  remove the agent notes
Yanli 盐粒  2b787bff72  2026-01-25 23:00:34 +08:00  add the agent-note for message_entities
Yanli 盐粒  289b59208a  2026-01-25 23:00:13 +08:00  make the model_runtime support reading and parsing the opaque_body from plugin LLM call (and fix the tool call parsing in streaming mode)
Yanli 盐粒  5c51e0a9ae  2026-01-25 22:35:22 +08:00  add message level and content level opaque_body
4 changed files with 163 additions and 2 deletions
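
Taken together, the commits let a plugin LLM attach an arbitrary JSON payload (opaque_body) to the assistant message and to individual content parts, and make model_runtime carry that payload through invoke so callers can read it back. A minimal caller-side sketch, assuming an already-constructed LargeLanguageModel instance named llm and placeholder model/credential values (neither is part of this change set):

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import UserPromptMessage

# llm is assumed to be a LargeLanguageModel built elsewhere; the model name and
# credentials below are placeholders for illustration only.
result = llm.invoke(
    model="gpt-test",
    credentials={},
    prompt_messages=[UserPromptMessage(content="hi")],
    model_parameters={},
    stream=False,
)
assert isinstance(result, LLMResult)
# Whatever JSON the plugin attached to the assistant message is preserved,
# e.g. {"provider_message_id": "msg_123"} in the tests at the end of this diff.
print(result.message.opaque_body)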

View File

@@ -5,7 +5,7 @@ from collections.abc import Mapping, Sequence
 from enum import StrEnum, auto
 from typing import Annotated, Any, Literal, Union

-from pydantic import BaseModel, Field, field_serializer, field_validator
+from pydantic import BaseModel, Field, JsonValue, field_serializer, field_validator


 class PromptMessageRole(StrEnum):
@@ -69,6 +69,7 @@ class PromptMessageContent(ABC, BaseModel):
     """

     type: PromptMessageContentType
+    opaque_body: JsonValue | None = None


 class TextPromptMessageContent(PromptMessageContent):
@@ -244,6 +245,7 @@ class AssistantPromptMessage(PromptMessage):
     role: PromptMessageRole = PromptMessageRole.ASSISTANT
     tool_calls: list[ToolCall] = []
+    opaque_body: JsonValue | None = None

     def is_empty(self) -> bool:
         """

View File

@@ -164,6 +164,7 @@ class LargeLanguageModel(AIModel):
         usage = LLMUsage.empty_usage()
         system_fingerprint = None
         tools_calls: list[AssistantPromptMessage.ToolCall] = []
+        assistant_opaque_body = None

         for chunk in result:
             if isinstance(chunk.delta.message.content, str):
@@ -172,6 +173,8 @@ class LargeLanguageModel(AIModel):
                 content_list.extend(chunk.delta.message.content)
             if chunk.delta.message.tool_calls:
                 _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
+            if assistant_opaque_body is None and chunk.delta.message.opaque_body is not None:
+                assistant_opaque_body = chunk.delta.message.opaque_body

             usage = chunk.delta.usage or LLMUsage.empty_usage()
             system_fingerprint = chunk.system_fingerprint
@@ -183,6 +186,7 @@ class LargeLanguageModel(AIModel):
             message=AssistantPromptMessage(
                 content=content or content_list,
                 tool_calls=tools_calls,
+                opaque_body=assistant_opaque_body,
             ),
             usage=usage,
             system_fingerprint=system_fingerprint,
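
In the non-streaming path the rule is first-non-None wins: the opaque_body from the earliest chunk that carries one is kept, and later chunks cannot overwrite it. A standalone illustration of that accumulation rule (not the runtime code itself):

from pydantic import JsonValue

def first_opaque_body(bodies: list[JsonValue | None]) -> JsonValue | None:
    # Mirrors the loop in the diff: keep the first non-None value seen.
    accumulated: JsonValue | None = None
    for body in bodies:
        if accumulated is None and body is not None:
            accumulated = body
    return accumulated

# The first chunk's payload wins; later payloads are ignored.
assert first_opaque_body([None, {"a": 1}, {"b": 2}]) == {"a": 1}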
@@ -261,6 +265,8 @@ class LargeLanguageModel(AIModel):
         usage = None
         system_fingerprint = None
         real_model = model
+        assistant_opaque_body = None
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []

         def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None):
             if not content:
@@ -294,6 +300,10 @@ class LargeLanguageModel(AIModel):
                 )
                 _update_message_content(chunk.delta.message.content)
+                if chunk.delta.message.tool_calls:
+                    _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
+                if assistant_opaque_body is None and chunk.delta.message.opaque_body is not None:
+                    assistant_opaque_body = chunk.delta.message.opaque_body

                 real_model = chunk.model
                 if chunk.delta.usage:
@@ -304,7 +314,11 @@ class LargeLanguageModel(AIModel):
         except Exception as e:
             raise self._transform_invoke_error(e)

-        assistant_message = AssistantPromptMessage(content=message_content)
+        assistant_message = AssistantPromptMessage(
+            content=message_content,
+            tool_calls=tools_calls,
+            opaque_body=assistant_opaque_body,
+        )

         self._trigger_after_invoke_callbacks(
             model=model,
             result=LLMResult(
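
_increase_tool_call, defined elsewhere in this module, merges streamed tool-call deltas into complete tool calls; the streaming path now applies it as well. A rough, hypothetical sketch of the merging behavior the tests below rely on (a delta carrying an id starts a new tool call, a delta with an empty id appends its argument fragment to the last one):

from core.model_runtime.entities.message_entities import AssistantPromptMessage

def merge_tool_call_deltas(
    deltas: list[AssistantPromptMessage.ToolCall],
    merged: list[AssistantPromptMessage.ToolCall],
) -> None:
    # Simplified illustration only; the real _increase_tool_call may differ.
    for delta in deltas:
        if delta.id and not any(tc.id == delta.id for tc in merged):
            # New id: start a new tool call from this delta.
            merged.append(delta.model_copy(deep=True))
        elif merged:
            # Empty id: treat as a continuation and append the argument fragment.
            merged[-1].function.arguments += delta.function.arguments

Fed the two deltas from the streaming test below (argument fragments '{"arg1": ' and '"value"}'), this yields a single tool call whose arguments concatenate to '{"arg1": "value"}'.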

View File

@@ -0,0 +1,145 @@
from __future__ import annotations

from datetime import datetime
from typing import Any
from unittest.mock import patch

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity


class _CaptureAfterInvokeCallback(Callback):
    after_result: LLMResult | None

    def __init__(self) -> None:
        self.after_result = None

    def on_before_invoke(self, **kwargs: Any) -> None:  # noqa: ANN401
        return None

    def on_new_chunk(self, **kwargs: Any) -> None:  # noqa: ANN401
        return None

    def on_after_invoke(self, result: LLMResult, **kwargs: Any) -> None:  # noqa: ANN401
        self.after_result = result

    def on_invoke_error(self, **kwargs: Any) -> None:  # noqa: ANN401
        return None


def _build_llm_instance() -> LargeLanguageModel:
    declaration = ProviderEntity(
        provider="test",
        label=I18nObject(en_US="test"),
        supported_model_types=[ModelType.LLM],
        configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
    )
    plugin_model_provider = PluginModelProviderEntity(
        id="pmp_1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        provider="test",
        tenant_id="tenant_1",
        plugin_unique_identifier="test/plugin",
        plugin_id="test/plugin",
        declaration=declaration,
    )
    return LargeLanguageModel(
        tenant_id="tenant_1",
        plugin_id="test/plugin",
        provider_name="test",
        plugin_model_provider=plugin_model_provider,
    )


def test_invoke_non_stream_preserves_assistant_opaque_body() -> None:
    llm = _build_llm_instance()
    prompt_messages: list[PromptMessage] = [UserPromptMessage(content="hi")]

    chunk = LLMResultChunk(
        model="gpt-test",
        prompt_messages=[],
        delta=LLMResultChunkDelta(
            index=0,
            message=AssistantPromptMessage(content="hello", opaque_body={"provider_message_id": "msg_123"}),
        ),
    )

    def _mock_invoke_llm(self, **kwargs: Any):  # noqa: ANN001, ANN401
        yield chunk

    with patch("core.plugin.impl.model.PluginModelClient.invoke_llm", new=_mock_invoke_llm):
        result = llm.invoke(
            model="gpt-test",
            credentials={},
            prompt_messages=prompt_messages,
            model_parameters={},
            stream=False,
        )

    assert isinstance(result, LLMResult)
    assert result.message.opaque_body == {"provider_message_id": "msg_123"}
    assert list(result.prompt_messages) == prompt_messages


def test_invoke_stream_preserves_assistant_opaque_body_in_after_callback() -> None:
    llm = _build_llm_instance()
    prompt_messages: list[PromptMessage] = [UserPromptMessage(content="hi")]
    callback = _CaptureAfterInvokeCallback()

    tool_call_1 = AssistantPromptMessage.ToolCall(
        id="1",
        type="function",
        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": '),
    )
    tool_call_2 = AssistantPromptMessage.ToolCall(
        id="",
        type="",
        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='"value"}'),
    )

    chunk1 = LLMResultChunk(
        model="gpt-test",
        prompt_messages=[],
        delta=LLMResultChunkDelta(
            index=0,
            message=AssistantPromptMessage(content="h", tool_calls=[tool_call_1], opaque_body={"provider_message_id": "msg_123"}),
        ),
    )
    chunk2 = LLMResultChunk(
        model="gpt-test",
        prompt_messages=[],
        delta=LLMResultChunkDelta(
            index=0,
            message=AssistantPromptMessage(content="i", tool_calls=[tool_call_2]),
        ),
    )

    def _mock_invoke_llm(self, **kwargs: Any):  # noqa: ANN001, ANN401
        yield chunk1
        yield chunk2

    with patch("core.plugin.impl.model.PluginModelClient.invoke_llm", new=_mock_invoke_llm):
        gen = llm.invoke(
            model="gpt-test",
            credentials={},
            prompt_messages=prompt_messages,
            model_parameters={},
            stream=True,
            callbacks=[callback],
        )
        chunks = list(gen)

    assert chunks[0].prompt_messages == prompt_messages
    assert callback.after_result is not None
    assert callback.after_result.message.opaque_body == {"provider_message_id": "msg_123"}
    assert len(callback.after_result.message.tool_calls) == 1
    assert callback.after_result.message.tool_calls[0].function.arguments == '{"arg1": "value"}'