mirror of
https://github.com/langgenius/dify.git
synced 2026-06-03 08:16:37 +08:00
fix: allow config pubsub join timeout for lower post-run latency (#36438)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
@@ -767,6 +767,7 @@ EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
|
||||
# Whether to use Redis cluster mode when Redis is used as the event bus.
|
||||
# It's highly recommended to enable this for large deployments.
|
||||
EVENT_BUS_REDIS_USE_CLUSTERS=false
|
||||
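# Maximum time (milliseconds) that Subscription.close() waits for its listener thread to finish before returning.
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000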
|
||||
|
||||
# Whether to enable the human input timeout check task
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Literal, Protocol, cast
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic.types import NonNegativeInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@@ -70,6 +71,24 @@ class RedisPubSubConfig(BaseSettings):
|
||||
default=600,
|
||||
)
|
||||
|
||||
# Upper bound (in milliseconds) on how long ``Subscription.close()`` joins the
# listener thread. Readable from either EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS or
# PUBSUB_LISTENER_JOIN_TIMEOUT_MS; ``NonNegativeInt`` rejects negative values.
PUBSUB_LISTENER_JOIN_TIMEOUT_MS: NonNegativeInt = Field(
    validation_alias=AliasChoices("EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS", "PUBSUB_LISTENER_JOIN_TIMEOUT_MS"),
    description=(
        "Maximum time (milliseconds) that ``Subscription.close()`` waits for its listener thread to "
        "finish before returning. Bounds the tail latency between a terminal event being delivered to "
        "an SSE client and the response stream actually closing.\n\n"
        "The listener thread blocks on a polling read (XREAD BLOCK for streams, get_message timeout "
        "for pubsub/sharded) with a fixed 1s window, so close() naturally has to wait up to ~1s for "
        "the thread to notice the subscription was closed. Setting this lower (e.g. 100) lets close() "
        "return promptly while the daemon listener thread cleans itself up on the next poll "
        "boundary - safe because the listener holds no critical state and exits within one poll "
        "window. Setting it higher (e.g. 5000) gives the listener more grace before close() gives up "
        "(the streams channel records this at DEBUG level). A value of 0 means close() does not wait "
        "at all. The default of 2000ms matches the previous hard-coded timeout of the streams channel; "
        "pubsub/sharded subscriptions previously waited at most 1000ms.\n\n"
        "Also accepts ENV: EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS."
    ),
    default=2000,
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = _redis_defaults(self)
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
|
||||
@@ -457,14 +457,16 @@ def init_app(app: DifyApp):
|
||||
|
||||
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
|
||||
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
|
||||
join_timeout_ms = dify_config.PUBSUB_LISTENER_JOIN_TIMEOUT_MS
|
||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
|
||||
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
|
||||
return ShardedRedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
|
||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams":
|
||||
return StreamsBroadcastChannel(
|
||||
_pubsub_redis_client,
|
||||
retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS,
|
||||
join_timeout_ms=join_timeout_ms,
|
||||
)
|
||||
return RedisBroadcastChannel(_pubsub_redis_client)
|
||||
return RedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
|
||||
|
||||
|
||||
def redis_fallback[T](default_return: T | None = None): # type: ignore
|
||||
|
||||
@@ -26,6 +26,8 @@ class RedisSubscriptionBase(Subscription):
|
||||
client: Redis | RedisCluster,
|
||||
pubsub: PubSub,
|
||||
topic: str,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
# The _pubsub is None only if the subscription is closed.
|
||||
self._client = client
|
||||
@@ -37,6 +39,11 @@ class RedisSubscriptionBase(Subscription):
|
||||
self._listener_thread: threading.Thread | None = None
|
||||
self._start_lock = threading.Lock()
|
||||
self._started = False
|
||||
# Max time close() will wait for the listener thread to finish before
|
||||
# returning. Bounds SSE close tail latency. The listener is a daemon
|
||||
# and exits on its own within one poll window (~1s), so a low value
|
||||
# here just means close() returns sooner without breaking anything.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def _start_if_needed(self) -> None:
|
||||
"""Start the subscription if not already started."""
|
||||
@@ -205,7 +212,7 @@ class RedisSubscriptionBase(Subscription):
|
||||
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
|
||||
listener = self._listener_thread
|
||||
if listener is not None:
|
||||
listener.join(timeout=1.0)
|
||||
listener.join(timeout=self._join_timeout_ms / 1000.0)
|
||||
self._listener_thread = None
|
||||
|
||||
# Abstract methods to be implemented by subclasses
|
||||
|
||||
@@ -22,18 +22,30 @@ class BroadcastChannel:
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
# See `RedisSubscriptionBase._join_timeout_ms`: how long close()
|
||||
# waits for the listener thread before returning.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def topic(self, topic: str) -> Topic:
|
||||
return Topic(self._client, topic)
|
||||
return Topic(self._client, topic, join_timeout_ms=self._join_timeout_ms)
|
||||
|
||||
|
||||
class Topic:
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
topic: str,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._redis_topic = serialize_redis_name(topic)
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
@@ -49,6 +61,7 @@ class Topic:
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._redis_topic,
|
||||
join_timeout_ms=self._join_timeout_ms,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -20,18 +20,28 @@ class ShardedRedisBroadcastChannel:
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def topic(self, topic: str) -> ShardedTopic:
|
||||
return ShardedTopic(self._client, topic)
|
||||
return ShardedTopic(self._client, topic, join_timeout_ms=self._join_timeout_ms)
|
||||
|
||||
|
||||
class ShardedTopic:
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
topic: str,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._redis_topic = serialize_redis_name(topic)
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
@@ -47,6 +57,7 @@ class ShardedTopic:
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._redis_topic,
|
||||
join_timeout_ms=self._join_timeout_ms,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -24,20 +24,42 @@ class StreamsBroadcastChannel:
|
||||
- The stream key expires `retention_seconds` after the last event is published (to bound storage).
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client: Redis | RedisCluster, *, retention_seconds: int = 600):
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
*,
|
||||
retention_seconds: int = 600,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._retention_seconds = max(int(retention_seconds or 0), 0)
|
||||
# Max time close() will wait for the listener thread to finish.
|
||||
# See `_StreamsSubscription._join_timeout_ms` for the rationale.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def topic(self, topic: str) -> StreamsTopic:
|
||||
return StreamsTopic(self._client, topic, retention_seconds=self._retention_seconds)
|
||||
return StreamsTopic(
|
||||
self._client,
|
||||
topic,
|
||||
retention_seconds=self._retention_seconds,
|
||||
join_timeout_ms=self._join_timeout_ms,
|
||||
)
|
||||
|
||||
|
||||
class StreamsTopic:
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600):
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
topic: str,
|
||||
*,
|
||||
retention_seconds: int = 600,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._key = serialize_redis_name(f"stream:{topic}")
|
||||
self._retention_seconds = retention_seconds
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
self.max_length = 5000
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
@@ -55,15 +77,23 @@ class StreamsTopic:
|
||||
return self
|
||||
|
||||
def subscribe(self) -> Subscription:
|
||||
return _StreamsSubscription(self._client, self._key)
|
||||
return _StreamsSubscription(self._client, self._key, join_timeout_ms=self._join_timeout_ms)
|
||||
|
||||
|
||||
class _StreamsSubscription(Subscription):
|
||||
_SENTINEL = object()
|
||||
|
||||
def __init__(self, client: Redis | RedisCluster, key: str):
|
||||
def __init__(self, client: Redis | RedisCluster, key: str, *, join_timeout_ms: int = 2000):
|
||||
self._client = client
|
||||
self._key = key
|
||||
# Max time close() will wait for the listener thread to finish before
|
||||
# returning. Bounds SSE close tail latency: the listener blocks on
|
||||
# XREAD with BLOCK=1000ms, so close() naturally waits up to ~1s for
|
||||
# the thread to notice _closed. Setting this lower lets close()
|
||||
# return promptly while the daemon listener exits on its own within
|
||||
# one BLOCK window - safe because the listener holds no critical
|
||||
# state. ``0`` means close() does not wait at all.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
self._queue: queue.Queue[object] = queue.Queue()
|
||||
|
||||
@@ -181,11 +211,13 @@ class _StreamsSubscription(Subscription):
|
||||
# We close the listener outside of the with block to avoid holding the
|
||||
# lock for a long time.
|
||||
if listener is not None and listener.is_alive():
|
||||
listener.join(timeout=2.0)
|
||||
listener.join(timeout=self._join_timeout_ms / 1000.0)
|
||||
if listener.is_alive():
|
||||
logger.warning(
|
||||
"Streams subscription listener for key %s did not stop within timeout; keeping reference.",
|
||||
logger.debug(
|
||||
"Streams subscription listener for key %s did not stop within %dms; "
|
||||
"daemon thread will exit on its own within one poll window.",
|
||||
self._key,
|
||||
self._join_timeout_ms,
|
||||
)
|
||||
|
||||
# Context manager helpers
|
||||
|
||||
+54
-3
@@ -176,6 +176,48 @@ class TestStreamsBroadcastChannel:
|
||||
assert topic.as_producer() is topic
|
||||
assert topic.as_subscriber() is topic
|
||||
|
||||
def test_join_timeout_ms_propagates_from_channel_to_subscription(self, fake_redis: FakeStreamsRedis):
    """A join timeout given to the channel must reach the topic and the subscription it creates."""
    expected_ms = 150
    broadcast = StreamsBroadcastChannel(fake_redis, retention_seconds=60, join_timeout_ms=expected_ms)
    stream_topic = broadcast.topic("join-timeout-prop")

    # First hop: channel -> topic.
    assert stream_topic._join_timeout_ms == expected_ms

    # Second hop: topic -> subscription. Always close the subscription, even
    # when the assertion fails.
    subscription = stream_topic.subscribe()
    try:
        assert subscription._join_timeout_ms == expected_ms
    finally:
        subscription.close()
|
||||
|
||||
def test_join_timeout_ms_defaults_to_2000(self, fake_redis: FakeStreamsRedis):
    """Without an explicit join_timeout_ms the topic carries the 2000ms default."""
    default_topic = StreamsBroadcastChannel(fake_redis, retention_seconds=60).topic("join-timeout-default")

    assert default_topic._join_timeout_ms == 2000
|
||||
|
||||
def test_small_join_timeout_makes_close_return_promptly(self, fake_redis: FakeStreamsRedis):
    """close() must honour the configured join timeout.

    Guards against SSE close tail latency: with an idle listener sitting in
    its poll cycle, a subscription created with a small join_timeout_ms has
    to return from close() without waiting out the whole poll window. The
    daemon listener left behind finishes on its own afterwards.
    """
    subscription = (
        StreamsBroadcastChannel(fake_redis, retention_seconds=60, join_timeout_ms=50)
        .topic("join-timeout-prompt-close")
        .subscribe()
    )

    # The first receive() starts the listener; the short sleep gives the
    # thread time to actually be blocked in xread before close() is timed.
    assert subscription.receive(timeout=0.05) is None
    time.sleep(0.05)

    close_began = time.monotonic()
    subscription.close()
    elapsed = time.monotonic() - close_began

    # Budget = 50ms join timeout plus scheduling slack, kept far below the
    # 1000ms default poll window so a regression is clearly visible.
    assert elapsed < 0.5, f"close() took {elapsed:.3f}s; expected prompt return"
|
||||
|
||||
def test_publish_logs_warning_when_expire_fails(self, caplog: pytest.LogCaptureFixture):
|
||||
channel = StreamsBroadcastChannel(FailExpireRedis(), retention_seconds=60)
|
||||
topic = channel.topic("expire-warning")
|
||||
@@ -342,10 +384,17 @@ class TestStreamsSubscription:
|
||||
|
||||
assert next(iter(subscription)) == b"event"
|
||||
|
||||
def test_close_logs_warning_when_listener_does_not_stop_in_time(
|
||||
def test_close_logs_debug_when_listener_does_not_stop_in_time(
|
||||
self,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
):
|
||||
"""When a low join_timeout elapses with the listener still alive,
|
||||
close() should log at DEBUG (not WARNING) - with a deliberately small
|
||||
timeout this is expected, not anomalous; the orphaned daemon thread
|
||||
cleans itself up on the next poll boundary.
|
||||
"""
|
||||
import logging
|
||||
|
||||
blocking_redis = BlockingRedis()
|
||||
subscription = _StreamsSubscription(blocking_redis, "stream:slow-close")
|
||||
|
||||
@@ -363,8 +412,10 @@ class TestStreamsSubscription:
|
||||
listener.is_alive = lambda: True # type: ignore[method-assign]
|
||||
|
||||
try:
|
||||
subscription.close()
|
||||
assert "did not stop within timeout" in caplog.text
|
||||
with caplog.at_level(logging.DEBUG, logger="libs.broadcast_channel.redis.streams_channel"):
|
||||
subscription.close()
|
||||
assert "did not stop within" in caplog.text
|
||||
assert "daemon thread will exit on its own" in caplog.text
|
||||
finally:
|
||||
listener.join = original_join # type: ignore[method-assign]
|
||||
listener.is_alive = original_is_alive # type: ignore[method-assign]
|
||||
|
||||
@@ -118,6 +118,7 @@ CELERY_TASK_ANNOTATIONS=null
|
||||
EVENT_BUS_REDIS_URL=
|
||||
EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
|
||||
EVENT_BUS_REDIS_USE_CLUSTERS=false
|
||||
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000
|
||||
|
||||
# Web and app limits
|
||||
WEB_API_CORS_ALLOW_ORIGINS=*
|
||||
|
||||
Reference in New Issue
Block a user