mirror of
https://github.com/langgenius/dify.git
synced 2026-01-08 07:14:14 +00:00
fix bugs
This commit is contained in:
@@ -39,6 +39,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
prompt_messages = self._get_completion_model_prompt_messages(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
@@ -60,6 +61,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
def _get_completion_model_prompt_messages(self,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict,
|
||||
query: Optional[str],
|
||||
files: list[FileObj],
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
@@ -86,6 +88,9 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
model_config=model_config
|
||||
)
|
||||
|
||||
if query:
|
||||
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
|
||||
|
||||
prompt = prompt_template.format(
|
||||
prompt_inputs
|
||||
)
|
||||
@@ -147,21 +152,30 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
elif files:
|
||||
# get last message
|
||||
last_message = prompt_messages[-1] if prompt_messages else None
|
||||
if last_message and last_message.role == PromptMessageRole.USER:
|
||||
# get last user message content and add files
|
||||
prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
|
||||
for file in files:
|
||||
prompt_message_contents.append(file.prompt_message_content)
|
||||
if not query:
|
||||
# get last message
|
||||
last_message = prompt_messages[-1] if prompt_messages else None
|
||||
if last_message and last_message.role == PromptMessageRole.USER:
|
||||
# get last user message content and add files
|
||||
prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
|
||||
for file in files:
|
||||
prompt_message_contents.append(file.prompt_message_content)
|
||||
|
||||
last_message.content = prompt_message_contents
|
||||
last_message.content = prompt_message_contents
|
||||
else:
|
||||
prompt_message_contents = [TextPromptMessageContent(data='')] # not for query
|
||||
for file in files:
|
||||
prompt_message_contents.append(file.prompt_message_content)
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_message_contents = [TextPromptMessageContent(data='')] # not for query
|
||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
||||
for file in files:
|
||||
prompt_message_contents.append(file.prompt_message_content)
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
elif query:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
@@ -210,4 +224,4 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
else:
|
||||
prompt_inputs['#histories#'] = ''
|
||||
|
||||
return prompt_inputs
|
||||
return prompt_inputs
|
||||
|
||||
@@ -50,6 +50,7 @@ def test__get_completion_model_prompt_messages():
|
||||
prompt_messages = prompt_transform._get_completion_model_prompt_messages(
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs=inputs,
|
||||
query=None,
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
|
||||
Reference in New Issue
Block a user