From e3b45a48ebac88b8e804060a0e87e4a2840a9934 Mon Sep 17 00:00:00 2001 From: Yunlu Wen Date: Wed, 20 May 2026 16:45:51 +0800 Subject: [PATCH] fix: allow config pubsub join timeout for lower post-run latency (#36438) Co-authored-by: QuantumGhost --- api/.env.example | 1 + .../middleware/cache/redis_pubsub_config.py | 19 +++++++ api/extensions/ext_redis.py | 6 +- .../broadcast_channel/redis/_subscription.py | 9 ++- api/libs/broadcast_channel/redis/channel.py | 17 +++++- .../redis/sharded_channel.py | 15 ++++- .../redis/streams_channel.py | 48 +++++++++++++--- .../redis/test_streams_channel_unit_tests.py | 57 ++++++++++++++++++- docker/.env.example | 1 + 9 files changed, 155 insertions(+), 18 deletions(-) diff --git a/api/.env.example b/api/.env.example index 34be400e87..833d83797d 100644 --- a/api/.env.example +++ b/api/.env.example @@ -767,6 +767,7 @@ EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub # Whether to use Redis cluster mode while use redis as event bus. # It's highly recommended to enable this for large deployments. 
EVENT_BUS_REDIS_USE_CLUSTERS=false +EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000 # Whether to Enable human input timeout check task ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true diff --git a/api/configs/middleware/cache/redis_pubsub_config.py b/api/configs/middleware/cache/redis_pubsub_config.py index 0a166818b3..d465f2e93c 100644 --- a/api/configs/middleware/cache/redis_pubsub_config.py +++ b/api/configs/middleware/cache/redis_pubsub_config.py @@ -2,6 +2,7 @@ from typing import Literal, Protocol, cast from urllib.parse import quote_plus, urlunparse from pydantic import AliasChoices, Field +from pydantic.types import NonNegativeInt from pydantic_settings import BaseSettings @@ -70,6 +71,24 @@ class RedisPubSubConfig(BaseSettings): default=600, ) + PUBSUB_LISTENER_JOIN_TIMEOUT_MS: NonNegativeInt = Field( + validation_alias=AliasChoices("EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS", "PUBSUB_LISTENER_JOIN_TIMEOUT_MS"), + description=( + "Maximum time (milliseconds) that ``Subscription.close()`` waits for its listener thread to " + "finish before returning. Bounds the tail latency between a terminal event being delivered to " + "an SSE client and the response stream actually closing.\n\n" + "The listener thread blocks on a polling read (XREAD BLOCK for streams, get_message timeout " + "for pubsub/sharded) with a fixed 1s window, so close() naturally has to wait up to ~1s for " + "the thread to notice the subscription was closed. Setting this lower (e.g. 100) lets close() " + "return promptly while the daemon listener thread cleans itself up on the next poll " + "boundary - safe because the listener holds no critical state and exits within one poll " + "window. Setting it higher (e.g. 5000) gives the listener more grace before close() gives up " + "(streams logs it at DEBUG). Default 2000ms keeps the old streams wait; pubsub/sharded used 1000ms.\n\n" + "Also accepts ENV: EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS."
+ ), + default=2000, + ) + def _build_default_pubsub_url(self) -> str: defaults = _redis_defaults(self) if not defaults.REDIS_HOST or not defaults.REDIS_PORT: diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 9f7f73765e..af0d77411b 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -457,14 +457,16 @@ def init_app(app: DifyApp): def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol: assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here." + join_timeout_ms = dify_config.PUBSUB_LISTENER_JOIN_TIMEOUT_MS if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded": - return ShardedRedisBroadcastChannel(_pubsub_redis_client) + return ShardedRedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms) if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams": return StreamsBroadcastChannel( _pubsub_redis_client, retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS, + join_timeout_ms=join_timeout_ms, ) - return RedisBroadcastChannel(_pubsub_redis_client) + return RedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms) def redis_fallback[T](default_return: T | None = None): # type: ignore diff --git a/api/libs/broadcast_channel/redis/_subscription.py b/api/libs/broadcast_channel/redis/_subscription.py index 4db79a15a9..9fe50445e4 100644 --- a/api/libs/broadcast_channel/redis/_subscription.py +++ b/api/libs/broadcast_channel/redis/_subscription.py @@ -26,6 +26,8 @@ class RedisSubscriptionBase(Subscription): client: Redis | RedisCluster, pubsub: PubSub, topic: str, + *, + join_timeout_ms: int = 2000, ): # The _pubsub is None only if the subscription is closed. self._client = client @@ -37,6 +39,11 @@ class RedisSubscriptionBase(Subscription): self._listener_thread: threading.Thread | None = None self._start_lock = threading.Lock() self._started = False + # Max time close() will wait for the listener thread to finish before + # returning. 
Bounds SSE close tail latency. The listener is a daemon + # and exits on its own within one poll window (~1s), so a low value + # here just means close() returns sooner without breaking anything. + self._join_timeout_ms = max(int(join_timeout_ms or 0), 0) def _start_if_needed(self) -> None: """Start the subscription if not already started.""" @@ -205,7 +212,7 @@ class RedisSubscriptionBase(Subscription): # Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread. listener = self._listener_thread if listener is not None: - listener.join(timeout=1.0) + listener.join(timeout=self._join_timeout_ms / 1000.0) self._listener_thread = None # Abstract methods to be implemented by subclasses diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py index b76a23eb3c..7f13ebaabc 100644 --- a/api/libs/broadcast_channel/redis/channel.py +++ b/api/libs/broadcast_channel/redis/channel.py @@ -22,18 +22,30 @@ class BroadcastChannel: def __init__( self, redis_client: Redis | RedisCluster, + *, + join_timeout_ms: int = 2000, ): self._client = redis_client + # See `RedisSubscriptionBase._join_timeout_ms`: how long close() + # waits for the listener thread before returning. 
+ self._join_timeout_ms = max(int(join_timeout_ms or 0), 0) def topic(self, topic: str) -> Topic: - return Topic(self._client, topic) + return Topic(self._client, topic, join_timeout_ms=self._join_timeout_ms) class Topic: - def __init__(self, redis_client: Redis | RedisCluster, topic: str): + def __init__( + self, + redis_client: Redis | RedisCluster, + topic: str, + *, + join_timeout_ms: int = 2000, + ): self._client = redis_client self._topic = topic self._redis_topic = serialize_redis_name(topic) + self._join_timeout_ms = max(int(join_timeout_ms or 0), 0) def as_producer(self) -> Producer: return self @@ -49,6 +61,7 @@ class Topic: client=self._client, pubsub=self._client.pubsub(), topic=self._redis_topic, + join_timeout_ms=self._join_timeout_ms, ) diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index 919d8d622e..02dc987107 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -20,18 +20,28 @@ class ShardedRedisBroadcastChannel: def __init__( self, redis_client: Redis | RedisCluster, + *, + join_timeout_ms: int = 2000, ): self._client = redis_client + self._join_timeout_ms = max(int(join_timeout_ms or 0), 0) def topic(self, topic: str) -> ShardedTopic: - return ShardedTopic(self._client, topic) + return ShardedTopic(self._client, topic, join_timeout_ms=self._join_timeout_ms) class ShardedTopic: - def __init__(self, redis_client: Redis | RedisCluster, topic: str): + def __init__( + self, + redis_client: Redis | RedisCluster, + topic: str, + *, + join_timeout_ms: int = 2000, + ): self._client = redis_client self._topic = topic self._redis_topic = serialize_redis_name(topic) + self._join_timeout_ms = max(int(join_timeout_ms or 0), 0) def as_producer(self) -> Producer: return self @@ -47,6 +57,7 @@ class ShardedTopic: client=self._client, pubsub=self._client.pubsub(), topic=self._redis_topic, + 
join_timeout_ms=self._join_timeout_ms, ) diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py index 55ff6cd4f9..985b253c7c 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -24,20 +24,42 @@ class StreamsBroadcastChannel: - The stream key expires `retention_seconds` after the last event is published (to bound storage). """ - def __init__(self, redis_client: Redis | RedisCluster, *, retention_seconds: int = 600): + def __init__( + self, + redis_client: Redis | RedisCluster, + *, + retention_seconds: int = 600, + join_timeout_ms: int = 2000, + ): self._client = redis_client self._retention_seconds = max(int(retention_seconds or 0), 0) + # Max time close() will wait for the listener thread to finish. + # See `_StreamsSubscription._join_timeout_ms` for the rationale. + self._join_timeout_ms = max(int(join_timeout_ms or 0), 0) def topic(self, topic: str) -> StreamsTopic: - return StreamsTopic(self._client, topic, retention_seconds=self._retention_seconds) + return StreamsTopic( + self._client, + topic, + retention_seconds=self._retention_seconds, + join_timeout_ms=self._join_timeout_ms, + ) class StreamsTopic: - def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600): + def __init__( + self, + redis_client: Redis | RedisCluster, + topic: str, + *, + retention_seconds: int = 600, + join_timeout_ms: int = 2000, + ): self._client = redis_client self._topic = topic self._key = serialize_redis_name(f"stream:{topic}") self._retention_seconds = retention_seconds + self._join_timeout_ms = max(int(join_timeout_ms or 0), 0) self.max_length = 5000 def as_producer(self) -> Producer: @@ -55,15 +77,23 @@ class StreamsTopic: return self def subscribe(self) -> Subscription: - return _StreamsSubscription(self._client, self._key) + return _StreamsSubscription(self._client, self._key, 
join_timeout_ms=self._join_timeout_ms) class _StreamsSubscription(Subscription): _SENTINEL = object() - def __init__(self, client: Redis | RedisCluster, key: str): + def __init__(self, client: Redis | RedisCluster, key: str, *, join_timeout_ms: int = 2000): self._client = client self._key = key + # Max time close() will wait for the listener thread to finish before + # returning. Bounds SSE close tail latency: the listener blocks on + # XREAD with BLOCK=1000ms, so close() naturally waits up to ~1s for + # the thread to notice _closed. Setting this lower lets close() + # return promptly while the daemon listener exits on its own within + # one BLOCK window - safe because the listener holds no critical + # state. ``0`` means close() does not wait at all. + self._join_timeout_ms = max(int(join_timeout_ms or 0), 0) self._queue: queue.Queue[object] = queue.Queue() @@ -181,11 +211,13 @@ class _StreamsSubscription(Subscription): # We close the listener outside of the with block to avoid holding the # lock for a long time. 
if listener is not None and listener.is_alive(): - listener.join(timeout=2.0) + listener.join(timeout=self._join_timeout_ms / 1000.0) if listener.is_alive(): - logger.warning( - "Streams subscription listener for key %s did not stop within timeout; keeping reference.", + logger.debug( + "Streams subscription listener for key %s did not stop within %dms; " + "daemon thread will exit on its own within one poll window.", self._key, + self._join_timeout_ms, ) # Context manager helpers diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py index c6f57c7e59..95085eaf67 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py @@ -176,6 +176,48 @@ class TestStreamsBroadcastChannel: assert topic.as_producer() is topic assert topic.as_subscriber() is topic + def test_join_timeout_ms_propagates_from_channel_to_subscription(self, fake_redis: FakeStreamsRedis): + channel = StreamsBroadcastChannel(fake_redis, retention_seconds=60, join_timeout_ms=150) + topic = channel.topic("join-timeout-prop") + + assert topic._join_timeout_ms == 150 + + sub = topic.subscribe() + try: + assert sub._join_timeout_ms == 150 + finally: + sub.close() + + def test_join_timeout_ms_defaults_to_2000(self, fake_redis: FakeStreamsRedis): + channel = StreamsBroadcastChannel(fake_redis, retention_seconds=60) + topic = channel.topic("join-timeout-default") + + assert topic._join_timeout_ms == 2000 + + def test_small_join_timeout_makes_close_return_promptly(self, fake_redis: FakeStreamsRedis): + """close() should respect the configured join timeout. + + Regression test for SSE close tail latency: when an idle listener is + blocked on its poll cycle, close() with a small join_timeout_ms must + not wait for the full poll window. 
The orphaned daemon listener + cleans itself up later. + """ + channel = StreamsBroadcastChannel(fake_redis, retention_seconds=60, join_timeout_ms=50) + topic = channel.topic("join-timeout-prompt-close") + sub = topic.subscribe() + + # Drive listener startup so the thread is actually blocked in xread. + assert sub.receive(timeout=0.05) is None + time.sleep(0.05) + + started = time.monotonic() + sub.close() + elapsed = time.monotonic() - started + + # 50ms timeout + scheduling slack; pick a ceiling well under the + # default poll window (1000ms) to make the regression meaningful. + assert elapsed < 0.5, f"close() took {elapsed:.3f}s; expected prompt return" + def test_publish_logs_warning_when_expire_fails(self, caplog: pytest.LogCaptureFixture): channel = StreamsBroadcastChannel(FailExpireRedis(), retention_seconds=60) topic = channel.topic("expire-warning") @@ -342,10 +384,17 @@ class TestStreamsSubscription: assert next(iter(subscription)) == b"event" - def test_close_logs_warning_when_listener_does_not_stop_in_time( + def test_close_logs_debug_when_listener_does_not_stop_in_time( self, caplog: pytest.LogCaptureFixture, ): + """When a low join_timeout elapses with the listener still alive, + close() should log at DEBUG (not WARNING) - with a deliberately small + timeout this is expected, not anomalous; the orphaned daemon thread + cleans itself up on the next poll boundary. 
+ """ + import logging + blocking_redis = BlockingRedis() subscription = _StreamsSubscription(blocking_redis, "stream:slow-close") @@ -363,8 +412,10 @@ class TestStreamsSubscription: listener.is_alive = lambda: True # type: ignore[method-assign] try: - subscription.close() - assert "did not stop within timeout" in caplog.text + with caplog.at_level(logging.DEBUG, logger="libs.broadcast_channel.redis.streams_channel"): + subscription.close() + assert "did not stop within" in caplog.text + assert "daemon thread will exit on its own" in caplog.text finally: listener.join = original_join # type: ignore[method-assign] listener.is_alive = original_is_alive # type: ignore[method-assign] diff --git a/docker/.env.example b/docker/.env.example index c708a40c15..c723d25d9b 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -118,6 +118,7 @@ CELERY_TASK_ANNOTATIONS=null EVENT_BUS_REDIS_URL= EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub EVENT_BUS_REDIS_USE_CLUSTERS=false +EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000 # Web and app limits WEB_API_CORS_ALLOW_ORIGINS=*