 import json
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple, cast
+import logging
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast

 from langchain_core.runnables import RunnableConfig
 from langgraph.checkpoint.base import (
 )
 from langgraph.constants import TASKS
 from redis import Redis
+from redis.cluster import RedisCluster
 from redisvl.index import SearchIndex
 from redisvl.query import FilterQuery
 from redisvl.query.filter import Num, Tag
 )
 from langgraph.checkpoint.redis.version import __lib_name__, __version__

+logger = logging.getLogger(__name__)

-class RedisSaver(BaseRedisSaver[Redis, SearchIndex]):
+
+class RedisSaver(BaseRedisSaver[Union[Redis, RedisCluster], SearchIndex]):
     """Standard Redis implementation for checkpoint saving."""

+    _redis: Union[Redis, RedisCluster]  # Support both standalone and cluster clients
+    # Whether to assume the Redis server is a cluster; None triggers auto-detection
+    cluster_mode: Optional[bool] = None
+
     def __init__(
         self,
         redis_url: Optional[str] = None,
         *,
-        redis_client: Optional[Redis] = None,
+        redis_client: Optional[Union[Redis, RedisCluster]] = None,
         connection_args: Optional[Dict[str, Any]] = None,
         ttl: Optional[Dict[str, Any]] = None,
     ) -> None:
@@ -54,7 +62,7 @@ def __init__(
     def configure_client(
         self,
         redis_url: Optional[str] = None,
-        redis_client: Optional[Redis] = None,
+        redis_client: Optional[Union[Redis, RedisCluster]] = None,
         connection_args: Optional[Dict[str, Any]] = None,
     ) -> None:
         """Configure the Redis client."""
@@ -74,6 +82,27 @@ def create_indexes(self) -> None:
             self.SCHEMAS[2], redis_client=self._redis
         )

+    def setup(self) -> None:
+        """Initialize the indices in Redis and detect cluster mode."""
+        self._detect_cluster_mode()
+        super().setup()
+
+    def _detect_cluster_mode(self) -> None:
+        """Detect if the Redis client is a cluster client by inspecting its class."""
+        if self.cluster_mode is not None:
+            logger.info(
+                f"Redis cluster_mode explicitly set to {self.cluster_mode}, skipping detection."
+            )
+            return
+
+        # Determine cluster mode based on client class
+        if isinstance(self._redis, RedisCluster):
+            logger.info("Redis client is a cluster client")
+            self.cluster_mode = True
+        else:
+            logger.info("Redis client is a standalone client")
+            self.cluster_mode = False
+
     def list(
         self,
         config: Optional[RunnableConfig],
@@ -458,7 +487,7 @@ def from_conn_string(
         cls,
         redis_url: Optional[str] = None,
         *,
-        redis_client: Optional[Redis] = None,
+        redis_client: Optional[Union[Redis, RedisCluster]] = None,
         connection_args: Optional[Dict[str, Any]] = None,
         ttl: Optional[Dict[str, Any]] = None,
     ) -> Iterator[RedisSaver]:
@@ -592,8 +621,8 @@ def delete_thread(self, thread_id: str) -> None:

         checkpoint_results = self.checkpoints_index.search(checkpoint_query)

-        # Delete all checkpoint-related keys
-        pipeline = self._redis.pipeline()
+        # Collect all keys to delete
+        keys_to_delete = []

         for doc in checkpoint_results.docs:
             checkpoint_ns = getattr(doc, "checkpoint_ns", "")
@@ -603,7 +632,7 @@ def delete_thread(self, thread_id: str) -> None:
             checkpoint_key = BaseRedisSaver._make_redis_checkpoint_key(
                 storage_safe_thread_id, checkpoint_ns, checkpoint_id
             )
-            pipeline.delete(checkpoint_key)
+            keys_to_delete.append(checkpoint_key)

         # Delete all blobs for this thread
         blob_query = FilterQuery(
@@ -622,7 +651,7 @@ def delete_thread(self, thread_id: str) -> None:
             blob_key = BaseRedisSaver._make_redis_checkpoint_blob_key(
                 storage_safe_thread_id, checkpoint_ns, channel, version
             )
-            pipeline.delete(blob_key)
+            keys_to_delete.append(blob_key)

         # Delete all writes for this thread
         writes_query = FilterQuery(
@@ -642,10 +671,19 @@ def delete_thread(self, thread_id: str) -> None:
             write_key = BaseRedisSaver._make_redis_checkpoint_writes_key(
                 storage_safe_thread_id, checkpoint_ns, checkpoint_id, task_id, idx
             )
-            pipeline.delete(write_key)
+            keys_to_delete.append(write_key)

-        # Execute all deletions
-        pipeline.execute()
+        # Execute all deletions based on cluster mode
+        if self.cluster_mode:
+            # For cluster mode, delete keys individually
+            for key in keys_to_delete:
+                self._redis.delete(key)
+        else:
+            # For non-cluster mode, use pipeline for efficiency
+            pipeline = self._redis.pipeline()
+            for key in keys_to_delete:
+                pipeline.delete(key)
+            pipeline.execute()


 __all__ = [
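
For context, a minimal usage sketch of the cluster-aware saver follows. It assumes RedisSaver is importable from the langgraph.checkpoint.redis package and that a Redis Cluster node is reachable at localhost:7000 (both assumptions, not part of this diff); it only illustrates that passing a RedisCluster client and calling setup() should leave cluster_mode set to True via _detect_cluster_mode(), which in turn switches delete_thread() to per-key deletions.

# Sketch only: assumes a reachable cluster node at localhost:7000 (hypothetical)
# and that RedisSaver is exposed from the langgraph.checkpoint.redis package.
from redis.cluster import RedisCluster
from langgraph.checkpoint.redis import RedisSaver

cluster_client = RedisCluster(host="localhost", port=7000)  # hypothetical endpoint
saver = RedisSaver(redis_client=cluster_client)
saver.setup()  # creates the indices and runs _detect_cluster_mode()

# A RedisCluster instance is detected, so cluster_mode becomes True and
# delete_thread() will issue individual DELETEs instead of a pipeline.
assert saver.cluster_mode is True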