fix annotation bugs

This commit is contained in:
takatost
2024-03-18 17:57:10 +08:00
parent 0ea233edbe
commit 09cfbe117e
4 changed files with 27 additions and 152 deletions

View File

@@ -136,9 +136,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
annotation = self._handle_annotation_reply(event)
if annotation:
self._task_state.answer = annotation.content
self._handle_annotation_reply(event)
elif isinstance(event, QueueWorkflowStartedEvent):
self._handle_workflow_start()
elif isinstance(event, QueueNodeStartedEvent):
@@ -148,7 +146,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(event)
if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value:
if workflow_run and workflow_run.status == WorkflowRunStatus.FAILED.value:
raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))
# handle output moderation
@@ -249,21 +247,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(event)
if workflow_run:
if workflow_run.status == WorkflowRunStatus.FAILED.value:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event))
break
if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event))
break
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
if isinstance(event, QueueStopEvent):
# Save message
self._save_message()
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
yield self._message_end_to_stream_response()
else:
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
@@ -277,9 +281,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
annotation = self._handle_annotation_reply(event)
if annotation:
self._task_state.answer = annotation.content
self._handle_annotation_reply(event)
# elif isinstance(event, QueueMessageFileEvent):
# response = self._message_file_to_stream_response(event)
# if response:

View File

@@ -1,18 +1,13 @@
import logging
import time
from typing import Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.entities.app_invoke_entities import (
AppGenerateEntity,
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent
from core.moderation.base import ModerationException
from core.moderation.input_moderation import InputModeration
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
@@ -50,15 +45,6 @@ class WorkflowAppRunner:
inputs = application_generate_entity.inputs
files = application_generate_entity.files
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs
):
return
db.session.close()
# RUN WORKFLOW
@@ -92,79 +78,3 @@ class WorkflowAppRunner:
# return workflow
return workflow
def handle_input_moderation(self, queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: WorkflowAppGenerateEntity,
inputs: dict) -> bool:
"""
Handle input moderation
:param queue_manager: application queue manager
:param app_record: app record
:param app_generate_entity: application generate entity
:param inputs: inputs
:return:
"""
try:
# process sensitive_word_avoidance
moderation_feature = InputModeration()
_, inputs, query = moderation_feature.check(
app_id=app_record.id,
tenant_id=app_generate_entity.app_config.tenant_id,
app_config=app_generate_entity.app_config,
inputs=inputs,
query=''
)
except ModerationException as e:
if app_generate_entity.stream:
self._stream_output(
queue_manager=queue_manager,
text=str(e),
)
queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION),
PublishFrom.APPLICATION_MANAGER
)
return True
return False
def _stream_output(self, queue_manager: AppQueueManager,
text: str) -> None:
"""
Direct output
:param queue_manager: application queue manager
:param text: text
:return:
"""
index = 0
for token in text:
queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.APPLICATION_MANAGER
)
index += 1
time.sleep(0.01)
def moderation_for_inputs(self, app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: dict) -> tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
:param tenant_id: tenant id
:param app_generate_entity: app generate entity
:param inputs: inputs
:return:
"""
moderation_feature = InputModeration()
return moderation_feature.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=inputs,
query=''
)

View File

@@ -2,7 +2,7 @@ import logging
from collections.abc import Generator
from typing import Any, Union
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
@@ -114,11 +114,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(event)
# handle output moderation
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
# save workflow app log
self._save_workflow_app_log(workflow_run)
@@ -186,10 +181,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(event)
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
yield self._text_replace_to_stream_response(output_moderation_answer)
# save workflow app log
self._save_workflow_app_log(workflow_run)
@@ -202,11 +193,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if delta_text is None:
continue
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
if should_direct_answer:
continue
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(delta_text)
elif isinstance(event, QueueMessageReplaceEvent):
@@ -268,29 +254,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
task_id=self._application_generate_entity.task_id,
text=TextReplaceStreamResponse.Data(text=text)
)
def _handle_output_moderation_chunk(self, text: str) -> bool:
"""
Handle output moderation chunk.
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
QueueTextChunkEvent(
text=self._task_state.answer
), PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
)
return True
else:
self._output_moderation_handler.append_new_token(text)
return False

View File

@@ -425,8 +425,11 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
-> WorkflowRun:
-> Optional[WorkflowRun]:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
if not workflow_run:
return None
if isinstance(event, QueueStopEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,