Skip to content

Commit b4fad1c

Browse files
authored
fix: Support cluster clients in saver (#56)
* Support cluster clients in saver * Standardize on _apply_ttl_to_keys function
1 parent 21a34fd commit b4fad1c

File tree

8 files changed

+612
-128
lines changed

8 files changed

+612
-128
lines changed

langgraph/checkpoint/redis/__init__.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import json
44
from contextlib import contextmanager
5-
from typing import Any, Dict, Iterator, List, Optional, Tuple, cast
5+
import logging
6+
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
67

78
from langchain_core.runnables import RunnableConfig
89
from langgraph.checkpoint.base import (
@@ -14,6 +15,7 @@
1415
)
1516
from langgraph.constants import TASKS
1617
from redis import Redis
18+
from redis.cluster import RedisCluster
1719
from redisvl.index import SearchIndex
1820
from redisvl.query import FilterQuery
1921
from redisvl.query.filter import Num, Tag
@@ -32,15 +34,21 @@
3234
)
3335
from langgraph.checkpoint.redis.version import __lib_name__, __version__
3436

37+
logger = logging.getLogger(__name__)
3538

36-
class RedisSaver(BaseRedisSaver[Redis, SearchIndex]):
39+
40+
class RedisSaver(BaseRedisSaver[Union[Redis, RedisCluster], SearchIndex]):
3741
"""Standard Redis implementation for checkpoint saving."""
3842

43+
_redis: Union[Redis, RedisCluster] # Support both standalone and cluster clients
44+
# Whether to assume the Redis server is a cluster; None triggers auto-detection
45+
cluster_mode: Optional[bool] = None
46+
3947
def __init__(
4048
self,
4149
redis_url: Optional[str] = None,
4250
*,
43-
redis_client: Optional[Redis] = None,
51+
redis_client: Optional[Union[Redis, RedisCluster]] = None,
4452
connection_args: Optional[Dict[str, Any]] = None,
4553
ttl: Optional[Dict[str, Any]] = None,
4654
) -> None:
@@ -54,7 +62,7 @@ def __init__(
5462
def configure_client(
5563
self,
5664
redis_url: Optional[str] = None,
57-
redis_client: Optional[Redis] = None,
65+
redis_client: Optional[Union[Redis, RedisCluster]] = None,
5866
connection_args: Optional[Dict[str, Any]] = None,
5967
) -> None:
6068
"""Configure the Redis client."""
@@ -74,6 +82,27 @@ def create_indexes(self) -> None:
7482
self.SCHEMAS[2], redis_client=self._redis
7583
)
7684

85+
def setup(self) -> None:
86+
"""Initialize the indices in Redis and detect cluster mode."""
87+
self._detect_cluster_mode()
88+
super().setup()
89+
90+
def _detect_cluster_mode(self) -> None:
91+
"""Detect if the Redis client is a cluster client by inspecting its class."""
92+
if self.cluster_mode is not None:
93+
logger.info(
94+
f"Redis cluster_mode explicitly set to {self.cluster_mode}, skipping detection."
95+
)
96+
return
97+
98+
# Determine cluster mode based on client class
99+
if isinstance(self._redis, RedisCluster):
100+
logger.info("Redis client is a cluster client")
101+
self.cluster_mode = True
102+
else:
103+
logger.info("Redis client is a standalone client")
104+
self.cluster_mode = False
105+
77106
def list(
78107
self,
79108
config: Optional[RunnableConfig],
@@ -458,7 +487,7 @@ def from_conn_string(
458487
cls,
459488
redis_url: Optional[str] = None,
460489
*,
461-
redis_client: Optional[Redis] = None,
490+
redis_client: Optional[Union[Redis, RedisCluster]] = None,
462491
connection_args: Optional[Dict[str, Any]] = None,
463492
ttl: Optional[Dict[str, Any]] = None,
464493
) -> Iterator[RedisSaver]:
@@ -592,8 +621,8 @@ def delete_thread(self, thread_id: str) -> None:
592621

593622
checkpoint_results = self.checkpoints_index.search(checkpoint_query)
594623

595-
# Delete all checkpoint-related keys
596-
pipeline = self._redis.pipeline()
624+
# Collect all keys to delete
625+
keys_to_delete = []
597626

598627
for doc in checkpoint_results.docs:
599628
checkpoint_ns = getattr(doc, "checkpoint_ns", "")
@@ -603,7 +632,7 @@ def delete_thread(self, thread_id: str) -> None:
603632
checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
604633
storage_safe_thread_id, checkpoint_ns, checkpoint_id
605634
)
606-
pipeline.delete(checkpoint_key)
635+
keys_to_delete.append(checkpoint_key)
607636

608637
# Delete all blobs for this thread
609638
blob_query = FilterQuery(
@@ -622,7 +651,7 @@ def delete_thread(self, thread_id: str) -> None:
622651
blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key(
623652
storage_safe_thread_id, checkpoint_ns, channel, version
624653
)
625-
pipeline.delete(blob_key)
654+
keys_to_delete.append(blob_key)
626655

627656
# Delete all writes for this thread
628657
writes_query = FilterQuery(
@@ -642,10 +671,19 @@ def delete_thread(self, thread_id: str) -> None:
642671
write_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
643672
storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
644673
)
645-
pipeline.delete(write_key)
674+
keys_to_delete.append(write_key)
646675

647-
# Execute all deletions
648-
pipeline.execute()
676+
# Execute all deletions based on cluster mode
677+
if self.cluster_mode:
678+
# For cluster mode, delete keys individually
679+
for key in keys_to_delete:
680+
self._redis.delete(key)
681+
else:
682+
# For non-cluster mode, use pipeline for efficiency
683+
pipeline = self._redis.pipeline()
684+
for key in keys_to_delete:
685+
pipeline.delete(key)
686+
pipeline.execute()
649687

650688

651689
__all__ = [

0 commit comments

Comments
 (0)