Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
Strilanc committed Jan 31, 2025
1 parent 51a1e42 commit f29dff1
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
1 change: 1 addition & 0 deletions dev/util_gen_stub_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
}
skip = {
"__firstlineno__",
"__replace__",
"__builtins__",
"__cached__",
"__getstate__",
Expand Down
13 changes: 7 additions & 6 deletions src/stim/simulators/frame_simulator.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,23 @@ static void generate_biased_samples_bit_packed_contiguous(uint8_t *out, size_t n
uintptr_t aligned64_start = start & ~0b111ULL;
uintptr_t aligned64_end = end & ~0b111ULL;
if (aligned64_start != start) {
aligned64_start += 1;
aligned64_start += 8;
}
biased_randomize_bits(p, (uint64_t *)aligned64_start, (uint64_t *)aligned64_end, rng);
if (start != aligned64_start) {
if (start < aligned64_start) {
uint64_t pad;
biased_randomize_bits(p, &pad, &pad + 1, rng);
for (size_t k = 0; k < aligned64_start - start; k++) {
out[k] = (uint8_t)(pad & 0xFF);
pad >>= 8;
}
}
if (end != aligned64_end) {
if (aligned64_end < end) {
uint64_t pad;
biased_randomize_bits(p, &pad, &pad + 1, rng);
for (size_t k = 0; k < end - aligned64_end; k++) {
((uint8_t *)aligned64_end)[k] = (uint8_t)(pad & 0xFF);
while (aligned64_end < end) {
*(uint8_t *)aligned64_end = (uint8_t)(pad & 0xFF);
aligned64_end++;
pad >>= 8;
}
}
Expand Down Expand Up @@ -119,7 +120,7 @@ pybind11::object generate_bernoulli_samples(FrameSimulator<W> &self, size_t num_
auto stride = buf.strides(0);
void *start_of_data = (void *)buf.mutable_data();

if (stride == 1 && num_bytes > 0) {
if (stride == 1) {
generate_biased_samples_bit_packed_contiguous(
(uint8_t *)start_of_data,
num_bytes,
Expand Down
12 changes: 11 additions & 1 deletion src/stim/simulators/frame_simulator_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,17 @@ def test_generate_bernoulli_samples():
assert np.sum(np.unpackbits(v, count=1001, bitorder='little')) == 1001
assert np.sum(np.unpackbits(v, count=1008, bitorder='little')) == 1001

v = sim.generate_bernoulli_samples(256, p=0, bit_packed=True)
assert np.all(v == 0)

sim.generate_bernoulli_samples(256 - 101, p=1, bit_packed=True, out=v[1:-11])
for k in v:
print(k)
assert np.all(v[1:-12] == 0xFF)
assert v[-12] == 7
assert np.all(v[-11:] == 0)
assert np.all(v[:1] == 0)

v = sim.generate_bernoulli_samples(2**16, p=0.25, bit_packed=True)
assert abs(np.sum(np.unpackbits(v, count=2**16)) - 2**16*0.25) < 2**12

Expand All @@ -591,7 +602,6 @@ def test_generate_bernoulli_samples():
assert np.all(v[:-1] == 0xFF)
assert v[-1] == 0x7F


v[:] = 0
sim.generate_bernoulli_samples(2**15, p=1, bit_packed=True, out=v[::2])
assert np.all(v[0::2] == 0xFF)
Expand Down

0 comments on commit f29dff1

Please sign in to comment.