fix: allow config pubsub join timeout for lower post-run latency (#36438)

Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
Yunlu Wen
2026-05-20 16:45:51 +08:00
committed by GitHub
parent 848c15a265
commit e3b45a48eb
9 changed files with 155 additions and 18 deletions
+1
View File
@@ -767,6 +767,7 @@ EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
# Whether to use Redis cluster mode while use redis as event bus.
# It's highly recommended to enable this for large deployments.
EVENT_BUS_REDIS_USE_CLUSTERS=false
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000
# Whether to Enable human input timeout check task
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
+19
View File
@@ -2,6 +2,7 @@ from typing import Literal, Protocol, cast
from urllib.parse import quote_plus, urlunparse
from pydantic import AliasChoices, Field
from pydantic.types import NonNegativeInt
from pydantic_settings import BaseSettings
@@ -70,6 +71,24 @@ class RedisPubSubConfig(BaseSettings):
default=600,
)
PUBSUB_LISTENER_JOIN_TIMEOUT_MS: NonNegativeInt = Field(
validation_alias=AliasChoices("EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS", "PUBSUB_LISTENER_JOIN_TIMEOUT_MS"),
description=(
"Maximum time (milliseconds) that ``Subscription.close()`` waits for its listener thread to "
"finish before returning. Bounds the tail latency between a terminal event being delivered to "
"an SSE client and the response stream actually closing.\n\n"
"The listener thread blocks on a polling read (XREAD BLOCK for streams, get_message timeout "
"for pubsub/sharded) with a fixed 1s window, so close() naturally has to wait up to ~1s for "
"the thread to notice the subscription was closed. Setting this lower (e.g. 100) lets close() "
"return promptly while the daemon listener thread cleans itself up on the next poll "
"boundary - safe because the listener holds no critical state and exits within one poll "
"window. Setting it higher (e.g. 5000) gives the listener more grace before close() gives up "
"and logs a warning. Default 2000ms preserves the pre-change behaviour.\n\n"
"Also accepts ENV: EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS."
),
default=2000,
)
def _build_default_pubsub_url(self) -> str:
defaults = _redis_defaults(self)
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
+4 -2
View File
@@ -457,14 +457,16 @@ def init_app(app: DifyApp):
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
join_timeout_ms = dify_config.PUBSUB_LISTENER_JOIN_TIMEOUT_MS
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
return ShardedRedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams":
return StreamsBroadcastChannel(
_pubsub_redis_client,
retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS,
join_timeout_ms=join_timeout_ms,
)
return RedisBroadcastChannel(_pubsub_redis_client)
return RedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
def redis_fallback[T](default_return: T | None = None): # type: ignore
@@ -26,6 +26,8 @@ class RedisSubscriptionBase(Subscription):
client: Redis | RedisCluster,
pubsub: PubSub,
topic: str,
*,
join_timeout_ms: int = 2000,
):
# The _pubsub is None only if the subscription is closed.
self._client = client
@@ -37,6 +39,11 @@ class RedisSubscriptionBase(Subscription):
self._listener_thread: threading.Thread | None = None
self._start_lock = threading.Lock()
self._started = False
# Max time close() will wait for the listener thread to finish before
# returning. Bounds SSE close tail latency. The listener is a daemon
# and exits on its own within one poll window (~1s), so a low value
# here just means close() returns sooner without breaking anything.
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
def _start_if_needed(self) -> None:
"""Start the subscription if not already started."""
@@ -205,7 +212,7 @@ class RedisSubscriptionBase(Subscription):
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
listener = self._listener_thread
if listener is not None:
listener.join(timeout=1.0)
listener.join(timeout=self._join_timeout_ms / 1000.0)
self._listener_thread = None
# Abstract methods to be implemented by subclasses
+15 -2
View File
@@ -22,18 +22,30 @@ class BroadcastChannel:
def __init__(
self,
redis_client: Redis | RedisCluster,
*,
join_timeout_ms: int = 2000,
):
self._client = redis_client
# See `RedisSubscriptionBase._join_timeout_ms`: how long close()
# waits for the listener thread before returning.
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
def topic(self, topic: str) -> Topic:
return Topic(self._client, topic)
return Topic(self._client, topic, join_timeout_ms=self._join_timeout_ms)
class Topic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
def __init__(
self,
redis_client: Redis | RedisCluster,
topic: str,
*,
join_timeout_ms: int = 2000,
):
self._client = redis_client
self._topic = topic
self._redis_topic = serialize_redis_name(topic)
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
def as_producer(self) -> Producer:
return self
@@ -49,6 +61,7 @@ class Topic:
client=self._client,
pubsub=self._client.pubsub(),
topic=self._redis_topic,
join_timeout_ms=self._join_timeout_ms,
)
@@ -20,18 +20,28 @@ class ShardedRedisBroadcastChannel:
def __init__(
self,
redis_client: Redis | RedisCluster,
*,
join_timeout_ms: int = 2000,
):
self._client = redis_client
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
def topic(self, topic: str) -> ShardedTopic:
return ShardedTopic(self._client, topic)
return ShardedTopic(self._client, topic, join_timeout_ms=self._join_timeout_ms)
class ShardedTopic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
def __init__(
self,
redis_client: Redis | RedisCluster,
topic: str,
*,
join_timeout_ms: int = 2000,
):
self._client = redis_client
self._topic = topic
self._redis_topic = serialize_redis_name(topic)
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
def as_producer(self) -> Producer:
return self
@@ -47,6 +57,7 @@ class ShardedTopic:
client=self._client,
pubsub=self._client.pubsub(),
topic=self._redis_topic,
join_timeout_ms=self._join_timeout_ms,
)
@@ -24,20 +24,42 @@ class StreamsBroadcastChannel:
- The stream key expires `retention_seconds` after the last event is published (to bound storage).
"""
def __init__(self, redis_client: Redis | RedisCluster, *, retention_seconds: int = 600):
def __init__(
self,
redis_client: Redis | RedisCluster,
*,
retention_seconds: int = 600,
join_timeout_ms: int = 2000,
):
self._client = redis_client
self._retention_seconds = max(int(retention_seconds or 0), 0)
# Max time close() will wait for the listener thread to finish.
# See `_StreamsSubscription._join_timeout_ms` for the rationale.
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
def topic(self, topic: str) -> StreamsTopic:
return StreamsTopic(self._client, topic, retention_seconds=self._retention_seconds)
return StreamsTopic(
self._client,
topic,
retention_seconds=self._retention_seconds,
join_timeout_ms=self._join_timeout_ms,
)
class StreamsTopic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600):
def __init__(
self,
redis_client: Redis | RedisCluster,
topic: str,
*,
retention_seconds: int = 600,
join_timeout_ms: int = 2000,
):
self._client = redis_client
self._topic = topic
self._key = serialize_redis_name(f"stream:{topic}")
self._retention_seconds = retention_seconds
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
self.max_length = 5000
def as_producer(self) -> Producer:
@@ -55,15 +77,23 @@ class StreamsTopic:
return self
def subscribe(self) -> Subscription:
return _StreamsSubscription(self._client, self._key)
return _StreamsSubscription(self._client, self._key, join_timeout_ms=self._join_timeout_ms)
class _StreamsSubscription(Subscription):
_SENTINEL = object()
def __init__(self, client: Redis | RedisCluster, key: str):
def __init__(self, client: Redis | RedisCluster, key: str, *, join_timeout_ms: int = 2000):
self._client = client
self._key = key
# Max time close() will wait for the listener thread to finish before
# returning. Bounds SSE close tail latency: the listener blocks on
# XREAD with BLOCK=1000ms, so close() naturally waits up to ~1s for
# the thread to notice _closed. Setting this lower lets close()
# return promptly while the daemon listener exits on its own within
# one BLOCK window - safe because the listener holds no critical
# state. ``0`` means close() does not wait at all.
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
self._queue: queue.Queue[object] = queue.Queue()
@@ -181,11 +211,13 @@ class _StreamsSubscription(Subscription):
# We close the listener outside of the with block to avoid holding the
# lock for a long time.
if listener is not None and listener.is_alive():
listener.join(timeout=2.0)
listener.join(timeout=self._join_timeout_ms / 1000.0)
if listener.is_alive():
logger.warning(
"Streams subscription listener for key %s did not stop within timeout; keeping reference.",
logger.debug(
"Streams subscription listener for key %s did not stop within %dms; "
"daemon thread will exit on its own within one poll window.",
self._key,
self._join_timeout_ms,
)
# Context manager helpers
@@ -176,6 +176,48 @@ class TestStreamsBroadcastChannel:
assert topic.as_producer() is topic
assert topic.as_subscriber() is topic
def test_join_timeout_ms_propagates_from_channel_to_subscription(self, fake_redis: FakeStreamsRedis):
channel = StreamsBroadcastChannel(fake_redis, retention_seconds=60, join_timeout_ms=150)
topic = channel.topic("join-timeout-prop")
assert topic._join_timeout_ms == 150
sub = topic.subscribe()
try:
assert sub._join_timeout_ms == 150
finally:
sub.close()
def test_join_timeout_ms_defaults_to_2000(self, fake_redis: FakeStreamsRedis):
channel = StreamsBroadcastChannel(fake_redis, retention_seconds=60)
topic = channel.topic("join-timeout-default")
assert topic._join_timeout_ms == 2000
def test_small_join_timeout_makes_close_return_promptly(self, fake_redis: FakeStreamsRedis):
"""close() should respect the configured join timeout.
Regression test for SSE close tail latency: when an idle listener is
blocked on its poll cycle, close() with a small join_timeout_ms must
not wait for the full poll window. The orphaned daemon listener
cleans itself up later.
"""
channel = StreamsBroadcastChannel(fake_redis, retention_seconds=60, join_timeout_ms=50)
topic = channel.topic("join-timeout-prompt-close")
sub = topic.subscribe()
# Drive listener startup so the thread is actually blocked in xread.
assert sub.receive(timeout=0.05) is None
time.sleep(0.05)
started = time.monotonic()
sub.close()
elapsed = time.monotonic() - started
# 50ms timeout + scheduling slack; pick a ceiling well under the
# default poll window (1000ms) to make the regression meaningful.
assert elapsed < 0.5, f"close() took {elapsed:.3f}s; expected prompt return"
def test_publish_logs_warning_when_expire_fails(self, caplog: pytest.LogCaptureFixture):
channel = StreamsBroadcastChannel(FailExpireRedis(), retention_seconds=60)
topic = channel.topic("expire-warning")
@@ -342,10 +384,17 @@ class TestStreamsSubscription:
assert next(iter(subscription)) == b"event"
def test_close_logs_warning_when_listener_does_not_stop_in_time(
def test_close_logs_debug_when_listener_does_not_stop_in_time(
self,
caplog: pytest.LogCaptureFixture,
):
"""When a low join_timeout elapses with the listener still alive,
close() should log at DEBUG (not WARNING) - with a deliberately small
timeout this is expected, not anomalous; the orphaned daemon thread
cleans itself up on the next poll boundary.
"""
import logging
blocking_redis = BlockingRedis()
subscription = _StreamsSubscription(blocking_redis, "stream:slow-close")
@@ -363,8 +412,10 @@ class TestStreamsSubscription:
listener.is_alive = lambda: True # type: ignore[method-assign]
try:
subscription.close()
assert "did not stop within timeout" in caplog.text
with caplog.at_level(logging.DEBUG, logger="libs.broadcast_channel.redis.streams_channel"):
subscription.close()
assert "did not stop within" in caplog.text
assert "daemon thread will exit on its own" in caplog.text
finally:
listener.join = original_join # type: ignore[method-assign]
listener.is_alive = original_is_alive # type: ignore[method-assign]
+1
View File
@@ -118,6 +118,7 @@ CELERY_TASK_ANNOTATIONS=null
EVENT_BUS_REDIS_URL=
EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
EVENT_BUS_REDIS_USE_CLUSTERS=false
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000
# Web and app limits
WEB_API_CORS_ALLOW_ORIGINS=*