@@ -2,6 +2,15 @@
 
 Utilities and classes here are used by the Matcher
 """
+import threading
+from collections import defaultdict
+from collections.abc import Callable
+from concurrent.futures import ThreadPoolExecutor, wait, Future
+from functools import partial
+from typing import Any
+
+from cachetools import TTLCache
+
 from DIRAC import S_OK, S_ERROR
 from DIRAC import gLogger
 
@@ -12,10 +21,109 @@
 from DIRAC.WorkloadManagementSystem.Client import JobStatus
 
 
+class TwoLevelCache:
+    """A two-level caching system with soft and hard time-to-live (TTL) expiration.
+
+    This cache implements a two-tier caching mechanism to allow for background refresh
+    of cached values. Entries newer than the soft TTL are served directly; entries past
+    the soft TTL but still within the hard TTL are served stale while a background
+    task refreshes them, reducing latency while keeping data reasonably fresh.
+
+    Attributes:
+        soft_cache (TTLCache): A cache with a shorter TTL for quick access.
+        hard_cache (TTLCache): A cache with a longer TTL as a fallback.
+        locks (defaultdict): Per-key locks that serialise cache population.
+        futures (dict): Ongoing asynchronous population tasks, keyed by cache key.
+        pool (ThreadPoolExecutor): Thread pool for executing cache population tasks.
+
+    Args:
+        soft_ttl (int): Time-to-live in seconds for the soft cache.
+        hard_ttl (int): Time-to-live in seconds for the hard cache.
+        max_workers (int): Maximum number of workers in the thread pool.
+        max_items (int): Maximum number of items in each cache.
+
+    Example:
+        >>> cache = TwoLevelCache(soft_ttl=60, hard_ttl=300)
+        >>> def populate_func():
+        ...     return "cached_value"
+        >>> value = cache.get("key", populate_func)
+
+    Note:
+        Cache population runs on a ThreadPoolExecutor whose size is set by
+        ``max_workers`` (10 by default).
+    """
+
+    def __init__(self, soft_ttl: int, hard_ttl: int, *, max_workers: int = 10, max_items: int = 1_000_000):
+        """Initialize the TwoLevelCache with specified TTLs."""
+        self.soft_cache = TTLCache(max_items, soft_ttl)
+        self.hard_cache = TTLCache(max_items, hard_ttl)
+        self.locks = defaultdict(threading.Lock)
+        self.futures: dict[str, Future] = {}
+        self.pool = ThreadPoolExecutor(max_workers=max_workers)
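+        # NB: the executor is never shut down explicitly; instances are
+        # expected to live for the lifetime of the process.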
+
+    def get(self, key: str, populate_func: Callable[[], Any]) -> Any:
+        """Retrieve a value from the cache, populating it if necessary.
+
+        This method first checks the soft cache for the key. If not found,
+        it checks the hard cache while initiating a background refresh.
+        If the key is not in either cache, it waits for the populate_func
+        to complete and stores the result in both caches.
+
+        Locks are used to ensure there is never more than one concurrent
+        population task for a given key.
+
+        Args:
+            key (str): The cache key to retrieve or populate.
+            populate_func (Callable[[], Any]): A function to call to populate the cache
+                if the key is not found.
+
+        Returns:
+            Any: The cached value associated with the key.
+
+        Note:
+            This method is thread-safe and handles concurrent requests for the same key.
+        """
+        if result := self.soft_cache.get(key):
+            return result
+        with self.locks[key]:
+            if key not in self.futures:
+                self.futures[key] = self.pool.submit(self._work, key, populate_func)
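+            # A hard-cache hit is served immediately, possibly stale; the
+            # population task (just submitted or already in flight) will
+            # repopulate both caches in the background.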
+            if result := self.hard_cache.get(key):
+                self.soft_cache[key] = result
+                return result
+        # It is critical that ``future`` is waited for outside of the lock as
+        # _work acquires the lock before filling the caches. This also means
+        # we can guarantee that the future has not yet been removed from the
+        # futures dict.
+        future = self.futures[key]
+        wait([future])
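+        # NB: populate_func is assumed not to raise; a raising populate_func
+        # would leave its future in self.futures and make this lookup fail.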
+        return self.hard_cache[key]
+
+    def _work(self, key: str, populate_func: Callable[[], Any]) -> None:
+        """Internal method to execute the populate_func and update caches.
+
+        This method is intended to be run in a separate thread. It calls the
+        populate_func, stores the result in both caches, and cleans up the
+        associated future.
+
+        Args:
+            key (str): The cache key to populate.
+            populate_func (Callable[[], Any]): The function to call to get the value.
+
+        Note:
+            This method is not intended to be called directly by users of the class.
+        """
+        result = populate_func()
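+        # Publish the result and drop the future while holding the key's lock
+        # so get() never observes the future gone before the caches are filled.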
+        with self.locks[key]:
+            self.futures.pop(key)
+            self.hard_cache[key] = result
+            self.soft_cache[key] = result
+
+
 class Limiter:
     # static variables shared between all instances of this class
     csDictCache = DictCache()
     condCache = DictCache()
+    newCache = TwoLevelCache(10, 300)
     delayMem = {}
 
     def __init__(self, jobDB=None, opsHelper=None, pilotRef=None):
@@ -177,19 +285,7 @@ def __getRunningCondition(self, siteName, gridCE=None):
             if attName not in self.jobDB.jobAttributeNames:
                 self.log.error("Attribute does not exist", f"({attName}). Check the job limits")
                 continue
-            cK = f"Running:{siteName}:{attName}"
-            data = self.condCache.get(cK)
-            if not data:
-                result = self.jobDB.getCounters(
-                    "Jobs",
-                    [attName],
-                    {"Site": siteName, "Status": [JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED]},
-                )
-                if not result["OK"]:
-                    return result
-                data = result["Value"]
-                data = {k[0][attName]: k[1] for k in data}
-                self.condCache.add(cK, 10, data)
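+            # The per-key condCache logic above is replaced by the shared
+            # TwoLevelCache, which may serve a slightly stale count while a
+            # background refresh runs.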
+            data = self.newCache.get(f"Running:{siteName}:{attName}", partial(self._countsByJobType, siteName, attName))
             for attValue in limitsDict[attName]:
                 limit = limitsDict[attName][attValue]
                 running = data.get(attValue, 0)
@@ -249,3 +345,15 @@ def __getDelayCondition(self, siteName):
                 negCond[attName] = []
             negCond[attName].append(attValue)
         return S_OK(negCond)
+
+    def _countsByJobType(self, siteName, attName):
+        result = self.jobDB.getCounters(
+            "Jobs",
+            [attName],
+            {"Site": siteName, "Status": [JobStatus.RUNNING, JobStatus.MATCHED, JobStatus.STALLED]},
+        )
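+        # NB: on failure the S_ERROR dict itself is returned and cached, so
+        # callers doing data.get(attValue, 0) will read zero running jobs
+        # until the cache entry expires.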
+        if not result["OK"]:
+            return result
+        data = result["Value"]
+        data = {k[0][attName]: k[1] for k in data}
+        return data