4
4
import pickle
5
5
import threading
6
6
from datetime import date , datetime , timezone
7
+ from enum import Enum
7
8
from time import time
8
9
from typing import Any
9
10
32
33
# Debounce our JSON validation a bit in order to not cause too much additional
# load everywhere
# Timestamp (seconds, from time()) of the most recent validation-failure log;
# None until the first failure is logged.
_last_validation_log: float | None = None

# TODO type Pipeline instead of using Any here
# Alias for a Redis client pipeline object; the concrete type differs between
# the redis-cluster and rb clients, hence the Any placeholder for now.
Pipeline = Any
35
38
36
39
37
40
def _validate_json_roundtrip (value : dict [str , Any ], model : type [models .Model ]) -> None :
@@ -49,6 +52,13 @@ def _validate_json_roundtrip(value: dict[str, Any], model: type[models.Model]) -
49
52
logger .exception ("buffer.invalid_value" , extra = {"value" : value , "model" : model })
50
53
51
54
55
class RedisOperation(Enum):
    """Redis commands used by the buffer, keyed by intent.

    Each member's ``value`` is the literal client method name; it is
    dispatched dynamically via ``getattr`` on a pipeline, so these strings
    must match the redis client API exactly.
    """

    SET_ADD = "sadd"  # add member(s) to a set
    SET_GET = "smembers"  # read all members of a set
    HASH_ADD = "hset"  # write a single hash field
    HASH_GET_ALL = "hgetall"  # read the entire hash
60
+
61
+
52
62
class PendingBuffer :
53
63
def __init__ (self , size : int ):
54
64
assert size > 0
@@ -208,6 +218,48 @@ def get(
208
218
col : (int (results [i ]) if results [i ] is not None else 0 ) for i , col in enumerate (columns )
209
219
}
210
220
221
+ def get_redis_connection (self , key : str ) -> Pipeline | None :
222
+ if is_instance_redis_cluster (self .cluster , self .is_redis_cluster ):
223
+ conn = self .cluster
224
+ elif is_instance_rb_cluster (self .cluster , self .is_redis_cluster ):
225
+ conn = self .cluster .get_local_client_for_key (key )
226
+ else :
227
+ raise AssertionError ("unreachable" )
228
+
229
+ pipe = conn .pipeline ()
230
+ return pipe
231
+
232
+ def _execute_redis_operation (self , key : str , operation : RedisOperation , * args : Any ) -> Any :
233
+ pending_key = self ._make_pending_key_from_key (key )
234
+ pipe = self .get_redis_connection (pending_key )
235
+ if pipe :
236
+ getattr (pipe , operation .value )(key , * args )
237
+ if args :
238
+ pipe .expire (key , self .key_expire )
239
+ return pipe .execute ()
240
+
241
+ def push_to_set (self , key : str , value : list [int ] | int ) -> None :
242
+ self ._execute_redis_operation (key , RedisOperation .SET_ADD , value )
243
+
244
+ def get_set (self , key : str ) -> list [set [int ]]:
245
+ return self ._execute_redis_operation (key , RedisOperation .SET_GET )
246
+
247
+ def push_to_hash (
248
+ self ,
249
+ model : type [models .Model ],
250
+ filters : dict [str , models .Model | str | int ],
251
+ field : str ,
252
+ value : int ,
253
+ ) -> None :
254
+ key = self ._make_key (model , filters )
255
+ self ._execute_redis_operation (key , RedisOperation .HASH_ADD , field , value )
256
+
257
+ def get_hash (
258
+ self , model : type [models .Model ], field : dict [str , models .Model | str | int ]
259
+ ) -> dict [str , str ]:
260
+ key = self ._make_key (model , field )
261
+ return self ._execute_redis_operation (key , RedisOperation .HASH_GET_ALL )
262
+
211
263
def incr (
212
264
self ,
213
265
model : type [models .Model ],
@@ -226,19 +278,13 @@ def incr(
226
278
- Perform a set on signal_only (only if True)
227
279
- Add hashmap key to pending flushes
228
280
"""
229
-
230
281
key = self ._make_key (model , filters )
231
282
pending_key = self ._make_pending_key_from_key (key )
232
283
# We can't use conn.map() due to wanting to support multiple pending
233
284
# keys (one per Redis partition)
234
- if is_instance_redis_cluster (self .cluster , self .is_redis_cluster ):
235
- conn = self .cluster
236
- elif is_instance_rb_cluster (self .cluster , self .is_redis_cluster ):
237
- conn = self .cluster .get_local_client_for_key (key )
238
- else :
239
- raise AssertionError ("unreachable" )
240
-
241
- pipe = conn .pipeline ()
285
+ pipe = self .get_redis_connection (key )
286
+ if not pipe :
287
+ return
242
288
pipe .hsetnx (key , "m" , f"{ model .__module__ } .{ model .__name__ } " )
243
289
_validate_json_roundtrip (filters , model )
244
290
0 commit comments