Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytype fixes #740

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions glue/cirq/stimcirq/_cirq_to_stim_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ def test_unitary_gate_conversions(gate: cirq.Gate):

def test_more_unitary_gate_conversions():
for p in [1, 1j, -1, -1j]:
assert_unitary_gate_converts_correctly(p * cirq.DensePauliString("IXYZ"))
assert_unitary_gate_converts_correctly((p * cirq.DensePauliString("IXYZ")).controlled(1))
gate: cirq.Gate = p * cirq.DensePauliString("IXYZ")
assert_unitary_gate_converts_correctly(gate)
assert_unitary_gate_converts_correctly(gate.controlled(1))

a, b = cirq.LineQubit.range(2)
c, _ = cirq_circuit_to_stim_data(
Expand Down Expand Up @@ -377,7 +378,7 @@ def test_on_tagged_loop():
repetitions=3,
).with_tags('my_tag')
)

stim_circuit = stimcirq.cirq_circuit_to_stim_circuit(c)
assert stim.CircuitRepeatBlock in {type(instr) for instr in stim_circuit}

Expand Down
7 changes: 4 additions & 3 deletions glue/cirq/stimcirq/_stim_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
from typing import Dict, List, Sequence
from typing import DefaultDict, Dict, List, Sequence
import typing

import cirq

Expand Down Expand Up @@ -35,8 +36,8 @@ def run_sweep(

# Find number of qubits (length), number of instances, and indices for each measurement key.
lengths: Dict[str, int] = {}
instances: Dict[str, int] = collections.Counter()
indices: Dict[str, int] = collections.defaultdict(list)
instances: typing.Counter[str] = collections.Counter()
indices: DefaultDict[str, List[int]] = collections.defaultdict(list)
k = 0
for key, length in key_ranges:
prev_length = lengths.get(key)
Expand Down
4 changes: 2 additions & 2 deletions glue/cirq/stimcirq/_stim_to_cirq.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _proper_transform_circuit_qubits(circuit: cirq.AbstractCircuit, remap: Dict[
class CircuitTranslationTracker:
def __init__(self, flatten: bool):
self.qubit_coords: Dict[int, cirq.Qid] = {}
self.origin: DefaultDict[float] = collections.defaultdict(float)
self.origin: DefaultDict[int, float] = collections.defaultdict(float)
self.num_measurements_seen = 0
self.full_circuit = cirq.Circuit()
self.tick_circuit = cirq.Circuit()
Expand Down Expand Up @@ -422,7 +422,7 @@ def get_handler_table() -> Dict[
noise = CircuitTranslationTracker.OneToOneNoisyGateHandler
sweep_gate = CircuitTranslationTracker.SweepableGateHandler

def not_impl(message) -> Callable[[Any], None]:
def not_impl(message) -> Callable[[CircuitTranslationTracker, stim.Circuit], None]:
def handler(
tracker: CircuitTranslationTracker, instruction: stim.CircuitInstruction
) -> None:
Expand Down
25 changes: 12 additions & 13 deletions glue/sample/src/sinter/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import dataclasses
import pathlib
from typing import Any
from typing import Callable, Iterator, Optional, Union, Iterable, List, TYPE_CHECKING, Tuple, Dict
from typing import Callable, Iterator, Optional, Union, Iterable, List, Tuple, Dict

import math
import numpy as np
Expand All @@ -11,13 +11,12 @@
from sinter._collection_options import CollectionOptions
from sinter._csv_out import CSV_HEADER
from sinter._collection_work_manager import CollectionWorkManager
from sinter._decoding_decoder_class import Decoder
from sinter._existing_data import ExistingData
from sinter._printer import ThrottledProgressPrinter
from sinter._task import Task
from sinter._task_stats import TaskStats

if TYPE_CHECKING:
import sinter


@dataclasses.dataclass(frozen=True)
class Progress:
Expand All @@ -39,8 +38,8 @@ class Progress:

def iter_collect(*,
num_workers: int,
tasks: Union[Iterator['sinter.Task'],
Iterable['sinter.Task']],
tasks: Union[Iterator[Task],
Iterable[Task]],
hint_num_tasks: Optional[int] = None,
additional_existing_data: Optional[ExistingData] = None,
max_shots: Optional[int] = None,
Expand All @@ -51,10 +50,10 @@ def iter_collect(*,
start_batch_size: Optional[int] = None,
count_observable_error_combos: bool = False,
count_detection_events: bool = False,
custom_decoders: Optional[Dict[str, 'sinter.Decoder']] = None,
custom_decoders: Optional[Dict[str, Decoder]] = None,
custom_error_count_key: Optional[str] = None,
allowed_cpu_affinity_ids: Optional[Iterable[int]] = None,
) -> Iterator['sinter.Progress']:
) -> Iterator[Progress]:
"""Iterates error correction statistics collected from worker processes.

It is important to iterate until the sequence ends, or worker processes will
Expand Down Expand Up @@ -161,7 +160,7 @@ def iter_collect(*,

if hint_num_tasks is None:
try:
# noinspection PyTypeChecker
tasks = list(tasks)
hint_num_tasks = len(tasks)
except TypeError:
pass
Expand Down Expand Up @@ -220,10 +219,10 @@ def iter_collect(*,

def collect(*,
num_workers: int,
tasks: Union[Iterator['sinter.Task'], Iterable['sinter.Task']],
tasks: Union[Iterator[Task], Iterable[Task]],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary to change? Because these methods are in the external API, I prefer to name the types how users would use them ("sinter.Task" instead of just "Task").

existing_data_filepaths: Iterable[Union[str, pathlib.Path]] = (),
save_resume_filepath: Union[None, str, pathlib.Path] = None,
progress_callback: Optional[Callable[['sinter.Progress'], None]] = None,
progress_callback: Optional[Callable[[Progress], None]] = None,
max_shots: Optional[int] = None,
max_errors: Optional[int] = None,
count_observable_error_combos: bool = False,
Expand All @@ -234,10 +233,10 @@ def collect(*,
start_batch_size: Optional[int] = None,
print_progress: bool = False,
hint_num_tasks: Optional[int] = None,
custom_decoders: Optional[Dict[str, 'sinter.Decoder']] = None,
custom_decoders: Optional[Dict[str, Decoder]] = None,
custom_error_count_key: Optional[str] = None,
allowed_cpu_affinity_ids: Optional[Iterable[int]] = None,
) -> List['sinter.TaskStats']:
) -> List[TaskStats]:
"""Collects statistics from the given tasks, using multiprocessing.

Args:
Expand Down
7 changes: 5 additions & 2 deletions glue/sample/src/sinter/_collection_work_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ def __enter__(self):

def __exit__(self, exc_type, exc_val, exc_tb):
self.shut_down_workers()
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
self.exit_stack = None
if self.exit_stack:
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
self.exit_stack = None
self.tmp_dir = None

def shut_down_workers(self) -> None:
Expand All @@ -112,6 +113,7 @@ def shut_down_workers(self) -> None:
w.join()

def fill_work_queue(self) -> bool:
assert self.queue_to_workers is not None, 'start_workers not called'
while len(self.deployed_jobs) < len(self.workers):
work = self.provide_more_work()
if work is None:
Expand All @@ -126,6 +128,7 @@ def wait_for_next_sample(self,
*,
timeout: Optional[float] = None,
) -> TaskStats:
assert self.queue_from_workers is not None, 'start_workers not called'
result = self.queue_from_workers.get(timeout=timeout)
assert isinstance(result, WorkOut)
if result.msg_error is not None:
Expand Down
2 changes: 2 additions & 0 deletions glue/sample/src/sinter/_decoding_decoder_class.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import abc
from typing import Tuple, Union

import pathlib

import numpy as np
Expand Down
14 changes: 13 additions & 1 deletion glue/sample/src/sinter/_decoding_pymatching.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import pathlib
from typing import TYPE_CHECKING

from sinter._decoding_decoder_class import Decoder, CompiledDecoder


if TYPE_CHECKING:
import numpy as np
import pymatching
import stim


class PyMatchingCompiledDecoder(CompiledDecoder):
def __init__(self, matcher: 'pymatching.Matching'):
self.matcher = matcher
Expand All @@ -10,12 +19,15 @@ def decode_shots_bit_packed(
*,
bit_packed_detection_event_data: 'np.ndarray',
) -> 'np.ndarray':
return self.matcher.decode_batch(
result = self.matcher.decode_batch(
shots=bit_packed_detection_event_data,
bit_packed_shots=True,
bit_packed_predictions=True,
return_weights=False,
)
if isinstance(result, tuple):
return result[0] # pymatching returned predictions and weights
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this ever actually executed, it would be a bug. Raise an exception, or do a cast.

return result


class PyMatchingDecoder(Decoder):
Expand Down
2 changes: 1 addition & 1 deletion glue/sample/src/sinter/_decoding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def test_decode_fails_correctly(decoder: str, force_streaming: Optional[bool]):
with open(d / 'bad_dets.b8', 'wb') as f:
f.write(b'!')

if 'vacuous' not in decoder:
if 'vacuous' not in decoder and decoder_obj is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should assert it's not none instead of disabling the testing of the decoder that wasn't.

with pytest.raises(Exception):
decoder_obj.decode_via_files(
num_shots=1,
Expand Down
16 changes: 10 additions & 6 deletions glue/sample/src/sinter/_decoding_vacuous.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import numpy as np
import pathlib
import typing

from sinter._decoding_decoder_class import Decoder, CompiledDecoder

if typing.TYPE_CHECKING:
import stim

class VacuousDecoder(Decoder):
"""An example decoder that always predicts the observables aren't flipped.
Expand All @@ -15,10 +19,10 @@ def decode_via_files(self,
num_shots: int,
num_dets: int,
num_obs: int,
dem_path: 'pathlib.Path',
dets_b8_in_path: 'pathlib.Path',
obs_predictions_b8_out_path: 'pathlib.Path',
tmp_dir: 'pathlib.Path',
dem_path: pathlib.Path,
dets_b8_in_path: pathlib.Path,
obs_predictions_b8_out_path: pathlib.Path,
tmp_dir: pathlib.Path,
) -> None:
with open(obs_predictions_b8_out_path, 'wb') as f:
f.write(b'\0' * (num_obs * num_shots))
Expand All @@ -33,6 +37,6 @@ def __init__(self, shape: int):
def decode_shots_bit_packed(
self,
*,
bit_packed_detection_event_data: 'np.ndarray',
) -> 'np.ndarray':
bit_packed_detection_event_data: np.ndarray,
) -> np.ndarray:
return np.zeros(shape=(bit_packed_detection_event_data.shape[0], self.shape), dtype=np.uint8)
2 changes: 1 addition & 1 deletion glue/sample/src/sinter/_main_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def _set_axis_scale_label_ticks(
set_ticks(major_ticks)
set_ticks(minor_ticks, minor=True)
else:
raise NotImplemented(f'{scale_name=}')
raise NotImplementedError(f'{scale_name=}')
return scale_name


Expand Down
Loading
Loading