fix(chat): close streaming LLM generator when stop response is triggered (#36227)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Carmen Fernández Ruiz
2026-05-25 23:23:26 -10:00
committed by GitHub
parent f5d664887b
commit 7d45335a32
2 changed files with 81 additions and 35 deletions
+41 -35
View File
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import (
AppGenerateEntity,
EasyUIBasedAppGenerateEntity,
@@ -292,46 +293,51 @@ class AppRunner:
prompt_messages: list[PromptMessage] = []
text = ""
usage = None
for result in invoke_result:
if not agent:
queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
else:
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
try:
for result in invoke_result:
if not agent:
queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
else:
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
message = result.delta.message
if isinstance(message.content, str):
text += message.content
elif isinstance(message.content, list):
for content in message.content:
if isinstance(content, str):
text += content
elif isinstance(content, TextPromptMessageContent):
text += content.data
elif isinstance(content, ImagePromptMessageContent):
if message_id and user_id and tenant_id:
try:
self._handle_multimodal_image_content(
content=content,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
queue_manager=queue_manager,
)
except Exception:
_logger.exception("Failed to handle multimodal image output")
message = result.delta.message
if isinstance(message.content, str):
text += message.content
elif isinstance(message.content, list):
for content in message.content:
if isinstance(content, str):
text += content
elif isinstance(content, TextPromptMessageContent):
text += content.data
elif isinstance(content, ImagePromptMessageContent):
if message_id and user_id and tenant_id:
try:
self._handle_multimodal_image_content(
content=content,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
queue_manager=queue_manager,
)
except Exception:
_logger.exception("Failed to handle multimodal image output")
else:
_logger.warning("Received multimodal output but missing required parameters")
else:
_logger.warning("Received multimodal output but missing required parameters")
else:
text += content.data if hasattr(content, "data") else str(content)
text += content.data if hasattr(content, "data") else str(content)
if not model:
model = result.model
if not model:
model = result.model
if not prompt_messages:
prompt_messages = list(result.prompt_messages)
if not prompt_messages:
prompt_messages = list(result.prompt_messages)
if result.delta.usage:
usage = result.delta.usage
if result.delta.usage:
usage = result.delta.usage
except GenerateTaskStoppedError:
# Explicitly close provider stream to stop in-flight token generation ASAP.
invoke_result.close()
raise
if usage is None:
usage = LLMUsage.empty_usage()
@@ -12,6 +12,7 @@ from core.app.app_config.entities import (
PromptTemplateEntity,
)
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -41,6 +42,23 @@ class _QueueRecorder:
self.events.append(event)
class _ClosableStream:
def __init__(self, chunks: list[LLMResultChunk]) -> None:
self._chunks = chunks
self.closed = False
def __iter__(self):
return self
def __next__(self):
if not self._chunks:
raise StopIteration
return self._chunks.pop(0)
def close(self) -> None:
self.closed = True
class TestAppRunner:
def test_recalc_llm_max_tokens_updates_parameters(self, monkeypatch: pytest.MonkeyPatch):
runner = AppRunner()
@@ -331,6 +349,28 @@ class TestAppRunner:
assert queue.events[-1].llm_result.usage == usage
exception_logger.assert_called_once()
def test_handle_invoke_result_stream_closes_generator_when_stopped(self):
runner = AppRunner()
chunk = LLMResultChunk(
model="stream-model",
prompt_messages=[AssistantPromptMessage(content="prompt")],
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content="a")),
)
stream = _ClosableStream([chunk])
queue_manager = SimpleNamespace(
publish=MagicMock(side_effect=GenerateTaskStoppedError("stopped")),
)
with pytest.raises(GenerateTaskStoppedError):
runner._handle_invoke_result_stream(
invoke_result=stream,
queue_manager=queue_manager,
agent=False,
)
assert stream.closed is True
def test_handle_multimodal_image_content_fallback_return_branch(self, monkeypatch: pytest.MonkeyPatch):
runner = AppRunner()