diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 0ed9ec352c..7519971ce7 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -39,6 +39,7 @@ class AdvancedPromptTransform(PromptTransform): prompt_messages = self._get_completion_model_prompt_messages( prompt_template_entity=prompt_template_entity, inputs=inputs, + query=query, files=files, context=context, memory=memory, @@ -60,6 +61,7 @@ class AdvancedPromptTransform(PromptTransform): def _get_completion_model_prompt_messages(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, + query: Optional[str], files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], @@ -86,6 +88,9 @@ class AdvancedPromptTransform(PromptTransform): model_config=model_config ) + if query: + prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) + prompt = prompt_template.format( prompt_inputs ) @@ -147,21 +152,30 @@ class AdvancedPromptTransform(PromptTransform): else: prompt_messages.append(UserPromptMessage(content=query)) elif files: - # get last message - last_message = prompt_messages[-1] if prompt_messages else None - if last_message and last_message.role == PromptMessageRole.USER: - # get last user message content and add files - prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] - for file in files: - prompt_message_contents.append(file.prompt_message_content) + if not query: + # get last message + last_message = prompt_messages[-1] if prompt_messages else None + if last_message and last_message.role == PromptMessageRole.USER: + # get last user message content and add files + prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] + for file in files: + prompt_message_contents.append(file.prompt_message_content) - last_message.content = prompt_message_contents + last_message.content = prompt_message_contents + else: + prompt_message_contents = [TextPromptMessageContent(data='')] # not for query + for file in files: + prompt_message_contents.append(file.prompt_message_content) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: - prompt_message_contents = [TextPromptMessageContent(data='')] # not for query + prompt_message_contents = [TextPromptMessageContent(data=query)] for file in files: prompt_message_contents.append(file.prompt_message_content) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + elif query: + prompt_messages.append(UserPromptMessage(content=query)) return prompt_messages @@ -210,4 +224,4 @@ class AdvancedPromptTransform(PromptTransform): else: prompt_inputs['#histories#'] = '' - return prompt_inputs + return prompt_inputs diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 65a160a8e5..95f1e30b44 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -50,6 +50,7 @@ def test__get_completion_model_prompt_messages(): prompt_messages = prompt_transform._get_completion_model_prompt_messages( prompt_template_entity=prompt_template_entity, inputs=inputs, + query=None, files=files, context=context, memory=memory,