From 7d45335a32c101d2b4b48498c7edc51fae208df8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carmen=20Fern=C3=A1ndez=20Ruiz?= <279459669+zeus1959@users.noreply.github.com> Date: Mon, 25 May 2026 23:23:26 -1000 Subject: [PATCH] fix(chat): close streaming LLM generator when stop response is triggered (#36227) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/core/app/apps/base_app_runner.py | 76 ++++++++++--------- .../core/app/apps/test_base_app_runner.py | 40 ++++++++++ 2 files changed, 81 insertions(+), 35 deletions(-) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 1251b397e2..0ca682e87a 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( AppGenerateEntity, EasyUIBasedAppGenerateEntity, @@ -292,46 +293,51 @@ class AppRunner: prompt_messages: list[PromptMessage] = [] text = "" usage = None - for result in invoke_result: - if not agent: - queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) - else: - queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) + try: + for result in invoke_result: + if not agent: + queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) + else: + queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) - message = result.delta.message - if isinstance(message.content, str): - text += message.content - elif isinstance(message.content, list): - for content in message.content: - if isinstance(content, str): - text += content - elif isinstance(content, TextPromptMessageContent): - text += content.data - elif isinstance(content, ImagePromptMessageContent): - if message_id and user_id and tenant_id: - try: - self._handle_multimodal_image_content( - content=content, - message_id=message_id, - user_id=user_id, - tenant_id=tenant_id, - queue_manager=queue_manager, - ) - except Exception: - _logger.exception("Failed to handle multimodal image output") + message = result.delta.message + if isinstance(message.content, str): + text += message.content + elif isinstance(message.content, list): + for content in message.content: + if isinstance(content, str): + text += content + elif isinstance(content, TextPromptMessageContent): + text += content.data + elif isinstance(content, ImagePromptMessageContent): + if message_id and user_id and tenant_id: + try: + self._handle_multimodal_image_content( + content=content, + message_id=message_id, + user_id=user_id, + tenant_id=tenant_id, + queue_manager=queue_manager, + ) + except Exception: + _logger.exception("Failed to handle multimodal image output") + else: + _logger.warning("Received multimodal output but missing required parameters") else: - _logger.warning("Received multimodal output but missing required parameters") - else: - text += content.data if hasattr(content, "data") else str(content) + text += content.data if hasattr(content, "data") else str(content) - if not model: - model = result.model + if not model: + model = result.model - if not prompt_messages: - prompt_messages = list(result.prompt_messages) + if not prompt_messages: + prompt_messages = list(result.prompt_messages) - if result.delta.usage: - usage = result.delta.usage + if result.delta.usage: + usage = result.delta.usage + except GenerateTaskStoppedError: + # Explicitly close provider stream to stop in-flight token generation ASAP. + invoke_result.close() + raise if usage is None: usage = LLMUsage.empty_usage() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py index c6eedf7be7..cd1e5babf8 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -12,6 +12,7 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.app.apps.base_app_runner import AppRunner +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -41,6 +42,23 @@ class _QueueRecorder: self.events.append(event) +class _ClosableStream: + def __init__(self, chunks: list[LLMResultChunk]) -> None: + self._chunks = chunks + self.closed = False + + def __iter__(self): + return self + + def __next__(self): + if not self._chunks: + raise StopIteration + return self._chunks.pop(0) + + def close(self) -> None: + self.closed = True + + class TestAppRunner: def test_recalc_llm_max_tokens_updates_parameters(self, monkeypatch: pytest.MonkeyPatch): runner = AppRunner() @@ -331,6 +349,28 @@ class TestAppRunner: assert queue.events[-1].llm_result.usage == usage exception_logger.assert_called_once() + def test_handle_invoke_result_stream_closes_generator_when_stopped(self): + runner = AppRunner() + chunk = LLMResultChunk( + model="stream-model", + prompt_messages=[AssistantPromptMessage(content="prompt")], + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content="a")), + ) + stream = _ClosableStream([chunk]) + + queue_manager = SimpleNamespace( + publish=MagicMock(side_effect=GenerateTaskStoppedError("stopped")), + ) + + with pytest.raises(GenerateTaskStoppedError): + runner._handle_invoke_result_stream( + invoke_result=stream, + queue_manager=queue_manager, + agent=False, + ) + + assert stream.closed is True + def test_handle_multimodal_image_content_fallback_return_branch(self, monkeypatch: pytest.MonkeyPatch): runner = AppRunner()