mirror of
https://github.com/langgenius/dify.git
synced 2026-06-06 16:10:07 +08:00
fix(chat): close streaming LLM generator when stop response is triggered (#36227)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
f5d664887b
commit
7d45335a32
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AppGenerateEntity,
|
||||
EasyUIBasedAppGenerateEntity,
|
||||
@@ -292,46 +293,51 @@ class AppRunner:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
text = ""
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
if not agent:
|
||||
queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
else:
|
||||
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
try:
|
||||
for result in invoke_result:
|
||||
if not agent:
|
||||
queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
else:
|
||||
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
message = result.delta.message
|
||||
if isinstance(message.content, str):
|
||||
text += message.content
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if isinstance(content, str):
|
||||
text += content
|
||||
elif isinstance(content, TextPromptMessageContent):
|
||||
text += content.data
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
if message_id and user_id and tenant_id:
|
||||
try:
|
||||
self._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except Exception:
|
||||
_logger.exception("Failed to handle multimodal image output")
|
||||
message = result.delta.message
|
||||
if isinstance(message.content, str):
|
||||
text += message.content
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if isinstance(content, str):
|
||||
text += content
|
||||
elif isinstance(content, TextPromptMessageContent):
|
||||
text += content.data
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
if message_id and user_id and tenant_id:
|
||||
try:
|
||||
self._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except Exception:
|
||||
_logger.exception("Failed to handle multimodal image output")
|
||||
else:
|
||||
_logger.warning("Received multimodal output but missing required parameters")
|
||||
else:
|
||||
_logger.warning("Received multimodal output but missing required parameters")
|
||||
else:
|
||||
text += content.data if hasattr(content, "data") else str(content)
|
||||
text += content.data if hasattr(content, "data") else str(content)
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
if not model:
|
||||
model = result.model
|
||||
|
||||
if not prompt_messages:
|
||||
prompt_messages = list(result.prompt_messages)
|
||||
if not prompt_messages:
|
||||
prompt_messages = list(result.prompt_messages)
|
||||
|
||||
if result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
if result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
except GenerateTaskStoppedError:
|
||||
# Explicitly close provider stream to stop in-flight token generation ASAP.
|
||||
invoke_result.close()
|
||||
raise
|
||||
|
||||
if usage is None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
@@ -12,6 +12,7 @@ from core.app.app_config.entities import (
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
@@ -41,6 +42,23 @@ class _QueueRecorder:
|
||||
self.events.append(event)
|
||||
|
||||
|
||||
class _ClosableStream:
|
||||
def __init__(self, chunks: list[LLMResultChunk]) -> None:
|
||||
self._chunks = chunks
|
||||
self.closed = False
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if not self._chunks:
|
||||
raise StopIteration
|
||||
return self._chunks.pop(0)
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
|
||||
class TestAppRunner:
|
||||
def test_recalc_llm_max_tokens_updates_parameters(self, monkeypatch: pytest.MonkeyPatch):
|
||||
runner = AppRunner()
|
||||
@@ -331,6 +349,28 @@ class TestAppRunner:
|
||||
assert queue.events[-1].llm_result.usage == usage
|
||||
exception_logger.assert_called_once()
|
||||
|
||||
def test_handle_invoke_result_stream_closes_generator_when_stopped(self):
    """A stop raised while publishing must close the stream and propagate.

    ``publish`` is mocked to raise ``GenerateTaskStoppedError`` on every call,
    so the error surfaces while the runner is consuming the stream. The test
    expects the error to reach the caller and the stream's ``close()`` to
    have been invoked.
    """
    # Queue manager whose publish always signals that the task was stopped.
    stopping_publish = MagicMock(side_effect=GenerateTaskStoppedError("stopped"))
    queue_manager = SimpleNamespace(publish=stopping_publish)

    # A single-chunk stream is enough to reach the first publish call.
    delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content="a"))
    only_chunk = LLMResultChunk(
        model="stream-model",
        prompt_messages=[AssistantPromptMessage(content="prompt")],
        delta=delta,
    )
    stream = _ClosableStream([only_chunk])

    runner = AppRunner()
    with pytest.raises(GenerateTaskStoppedError):
        runner._handle_invoke_result_stream(
            invoke_result=stream,
            queue_manager=queue_manager,
            agent=False,
        )

    assert stream.closed is True
|
||||
|
||||
def test_handle_multimodal_image_content_fallback_return_branch(self, monkeypatch: pytest.MonkeyPatch):
|
||||
runner = AppRunner()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user