15
15
import threading
16
16
import pprint
17
17
from collections import deque
18
- import queue
19
18
import hashlib
20
19
from dataclasses import dataclass
20
+ # import xxhash
21
21
# from line_profiler import profile
22
22
23
23
# TODO:
@@ -34,6 +34,17 @@ def _tensor_blake2b_checksum(tensor: torch.Tensor, prev_hash: bytes | None) -> b
34
34
hasher .update (tensor .numpy ().tobytes ())
35
35
return hasher .digest ()
36
36
37
+ # xxhasher = xxhash.xxh128()
38
+ # def _tensor_xxhash128_checksum(tensor: torch.Tensor, prev_hash: bytes | None) -> bytes:
39
+ # global xxhasher
40
+ # xxhasher.reset()
41
+ # if prev_hash is not None:
42
+ # xxhasher.update(prev_hash)
43
+ # xxhasher.update(tensor.numpy().tobytes())
44
+ # return xxhasher.digest()
45
+
46
+ _tensor_hash_checksum = _tensor_blake2b_checksum
47
+
37
48
_uniquehash = 0
38
49
def _randomhash ():
39
50
global _uniquehash
@@ -700,7 +711,7 @@ def ___validate_cache(self):
700
711
# if ids.shape[-1] > 0:
701
712
# assert page.prev_hash == prev_hash, "bad prev_hash " + str(job) + " -> " + str(page)
702
713
# if ids.shape[-1] == self.page_size:
703
- # phash = _tensor_blake2b_checksum (ids, prev_hash)
714
+ # phash = _tensor_hash_checksum (ids, prev_hash)
704
715
# assert page.phash == phash, "bad phash " + str(job) + " -> " + str(page)
705
716
# prev_hash = phash
706
717
# spos = spos2
@@ -1624,7 +1635,7 @@ def receive_sample(
1624
1635
# assert page.sequence.shape[-1] == self.generator.page_size
1625
1636
# assert torch.all(page_ids == page.sequence)
1626
1637
# assert page_ids.shape[-1] == self.generator.page_size
1627
- new_hash = _tensor_blake2b_checksum (page_ids , last_hash )
1638
+ new_hash = _tensor_hash_checksum (page_ids , last_hash )
1628
1639
1629
1640
# If another referenced page has the same hash, switch to referencing that instead
1630
1641
@@ -1883,7 +1894,7 @@ def prepare_for_queue(self, generator, serial_number: int):
1883
1894
for i in range (context_pages ):
1884
1895
page_ids = seq .sequence_ids .torch_slice (i * page_size , (i + 1 ) * page_size )
1885
1896
assert page_ids .shape [- 1 ] == self .generator .page_size
1886
- r_hash = _tensor_blake2b_checksum (page_ids , r_hash )
1897
+ r_hash = _tensor_hash_checksum (page_ids , r_hash )
1887
1898
seq .page_hashes .append (r_hash )
1888
1899
all_unique_hashes .add (r_hash )
1889
1900
0 commit comments