Commit b10bed7

feat (WMS): Improve caching performance of Limiter
1 parent cb84f65

File tree

1 file changed: +121 -13 lines

  • src/DIRAC/WorkloadManagementSystem/Client/Limiter.py
src/DIRAC/WorkloadManagementSystem/Client/Limiter.py

+121 -13
@@ -2,6 +2,15 @@
 
 Utilities and classes here are used by the Matcher
 """
+import threading
+from collections import defaultdict
+from collections.abc import Callable
+from concurrent.futures import ThreadPoolExecutor, wait, Future
+from functools import partial
+from typing import Any
+
+from cachetools import TTLCache
+
 from DIRAC import S_OK, S_ERROR
 from DIRAC import gLogger
 
@@ -12,10 +21,109 @@
 from DIRAC.WorkloadManagementSystem.Client import JobStatus
 
 
+class TwoLevelCache:
+    """A two-level caching system with soft and hard time-to-live (TTL) expiration.
+
+    This cache implements a two-tier caching mechanism to allow for background refresh
+    of cached values. It uses a soft TTL for quick access and a hard TTL as a fallback,
+    which helps in reducing latency and maintaining data freshness.
+
+    Attributes:
+        soft_cache (TTLCache): A cache with a shorter TTL for quick access.
+        hard_cache (TTLCache): A cache with a longer TTL as a fallback.
+        locks (defaultdict): Thread-safe locks for each cache key.
+        futures (dict): Stores ongoing asynchronous population tasks.
+        pool (ThreadPoolExecutor): Thread pool for executing cache population tasks.
+
+    Args:
+        soft_ttl (int): Time-to-live in seconds for the soft cache.
+        hard_ttl (int): Time-to-live in seconds for the hard cache.
+        max_workers (int): Maximum number of workers in the thread pool.
+        max_items (int): Maximum number of items in the cache.
+
+    Example:
+        >>> cache = TwoLevelCache(soft_ttl=60, hard_ttl=300)
+        >>> def populate_func():
+        ...     return "cached_value"
+        >>> value = cache.get("key", populate_func)
+
+    Note:
+        The cache uses a ThreadPoolExecutor with a maximum of 10 workers to
+        handle concurrent cache population requests.
+    """
+
+    def __init__(self, soft_ttl: int, hard_ttl: int, *, max_workers: int = 10, max_items: int = 1_000_000):
+        """Initialize the TwoLevelCache with specified TTLs."""
+        self.soft_cache = TTLCache(max_items, soft_ttl)
+        self.hard_cache = TTLCache(max_items, hard_ttl)
+        self.locks = defaultdict(threading.Lock)
+        self.futures: dict[str, Future] = {}
+        self.pool = ThreadPoolExecutor(max_workers=max_workers)
+
+    def get(self, key: str, populate_func: Callable[[], Any]):
+        """Retrieve a value from the cache, populating it if necessary.
+
+        This method first checks the soft cache for the key. If not found,
+        it checks the hard cache while initiating a background refresh.
+        If the key is not in either cache, it waits for the populate_func
+        to complete and stores the result in both caches.
+
+        Locks are used to ensure there is never more than one concurrent
+        population task for a given key.
+
+        Args:
+            key (str): The cache key to retrieve or populate.
+            populate_func (Callable[[], Any]): A function to call to populate the cache
+                if the key is not found.
+
+        Returns:
+            Any: The cached value associated with the key.
+
+        Note:
+            This method is thread-safe and handles concurrent requests for the same key.
+        """
+        if result := self.soft_cache.get(key):
+            return result
+        with self.locks[key]:
+            if key not in self.futures:
+                self.futures[key] = self.pool.submit(self._work, key, populate_func)
+            if result := self.hard_cache.get(key):
+                self.soft_cache[key] = result
+                return result
+        # It is critical that ``future`` is waited for outside of the lock as
+        # _work acquires the lock before filling the caches. This also means
+        # we can guarantee that the future has not yet been removed from the
+        # futures dict.
+        future = self.futures[key]
+        wait([future])
+        return self.hard_cache[key]
+
+    def _work(self, key: str, populate_func: Callable[[], Any]) -> None:
+        """Internal method to execute the populate_func and update caches.
+
+        This method is intended to be run in a separate thread. It calls the
+        populate_func, stores the result in both caches, and cleans up the
+        associated future.
+
+        Args:
+            key (str): The cache key to populate.
+            populate_func (Callable[[], Any]): The function to call to get the value.
+
+        Note:
+            This method is not intended to be called directly by users of the class.
+        """
+        result = populate_func()
+        with self.locks[key]:
+            self.futures.pop(key)
+            self.hard_cache[key] = result
+            self.soft_cache[key] = result
+
+
 class Limiter:
     # static variables shared between all instances of this class
     csDictCache = DictCache()
     condCache = DictCache()
+    newCache = TwoLevelCache(10, 300)
     delayMem = {}
 
     def __init__(self, jobDB=None, opsHelper=None, pilotRef=None):
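
The class added above gives each key three paths: a soft-cache hit returns immediately, a hard-cache hit returns the stale value while a background thread refreshes it, and a cold miss blocks until populate_func finishes. A minimal sketch of that behaviour, assuming only the TwoLevelCache above plus a hypothetical slow_lookup standing in for a database query:

    import time

    def slow_lookup():
        # Hypothetical expensive call, e.g. a JobDB query.
        time.sleep(2)
        return {"MCSimulation": 120}

    cache = TwoLevelCache(soft_ttl=10, hard_ttl=300)

    # Cold miss: blocks ~2 s while slow_lookup runs, then fills both caches.
    cache.get("Running:SomeSite:JobType", slow_lookup)

    # Within the 10 s soft TTL: answered from the soft cache, no task submitted.
    cache.get("Running:SomeSite:JobType", slow_lookup)

    # After 10 s but within 300 s: the stale value is returned straight from the
    # hard cache while a worker thread re-runs slow_lookup in the background.
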
@@ -177,19 +285,7 @@ def __getRunningCondition(self, siteName, gridCE=None):
             if attName not in self.jobDB.jobAttributeNames:
                 self.log.error("Attribute does not exist", f"({attName}). Check the job limits")
                 continue
-            cK = f"Running:{siteName}:{attName}"
-            data = self.condCache.get(cK)
-            if not data:
-                result = self.jobDB.getCounters(
-                    "Jobs",
-                    [attName],
-                    {"Site": siteName, "Status": [JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED]},
-                )
-                if not result["OK"]:
-                    return result
-                data = result["Value"]
-                data = {k[0][attName]: k[1] for k in data}
-                self.condCache.add(cK, 10, data)
+            data = self.newCache.get(f"Running:{siteName}:{attName}", partial(self._countsByJobType, siteName, attName))
             for attValue in limitsDict[attName]:
                 limit = limitsDict[attName][attValue]
                 running = data.get(attValue, 0)
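
The replacement line above hands the cache a zero-argument callable: functools.partial binds siteName and attName up front, so the JobDB query in _countsByJobType only runs if the cache actually needs to (re)populate that key. A small illustration of the binding, with a hypothetical fetch function in place of the Limiter method:

    from functools import partial

    def fetch(siteName, attName):
        return f"counts of {attName} at {siteName}"

    # Bind the arguments now, call later with no arguments.
    populate_func = partial(fetch, "SomeSite", "JobType")
    populate_func()  # -> 'counts of JobType at SomeSite'
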
@@ -249,3 +345,15 @@ def __getDelayCondition(self, siteName):
                 negCond[attName] = []
             negCond[attName].append(attValue)
         return S_OK(negCond)
+
+    def _countsByJobType(self, siteName, attName):
+        result = self.jobDB.getCounters(
+            "Jobs",
+            [attName],
+            {"Site": siteName, "Status": [JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED]},
+        )
+        if not result["OK"]:
+            return result
+        data = result["Value"]
+        data = {k[0][attName]: k[1] for k in data}
+        return data
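
The new helper reshapes the getCounters result into a plain mapping from attribute value to job count, which is exactly what the data.get(attValue, 0) lookups in __getRunningCondition consume. Assuming getCounters returns (attribute-dict, count) pairs, as the comprehension implies, and using made-up job types, the reshaping looks like:

    attName = "JobType"
    raw = [({"JobType": "MCSimulation"}, 120), ({"JobType": "DataReprocessing"}, 45)]

    data = {k[0][attName]: k[1] for k in raw}
    # data == {"MCSimulation": 120, "DataReprocessing": 45}
    data.get("MCSimulation", 0)   # 120
    data.get("UserAnalysis", 0)   # 0 -> no such jobs currently counted
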
