diff --git a/glue/cirq/stimcirq/__init__.py b/glue/cirq/stimcirq/__init__.py index 1e5543696..daef84e16 100644 --- a/glue/cirq/stimcirq/__init__.py +++ b/glue/cirq/stimcirq/__init__.py @@ -1,6 +1,7 @@ __version__ = '1.13.dev0' from ._cirq_to_stim import cirq_circuit_to_stim_circuit from ._cx_swap_gate import CXSwapGate +from ._cz_swap_gate import CZSwapGate from ._det_annotation import DetAnnotation from ._obs_annotation import CumulativeObservableAnnotation from ._shift_coords_annotation import ShiftCoordsAnnotation @@ -20,5 +21,6 @@ "SweepPauli": SweepPauli, "TwoQubitAsymmetricDepolarizingChannel": TwoQubitAsymmetricDepolarizingChannel, "CXSwapGate": CXSwapGate, + "CZSwapGate": CZSwapGate, } JSON_RESOLVER = JSON_RESOLVERS_DICT.get diff --git a/glue/cirq/stimcirq/_cz_swap_gate.py b/glue/cirq/stimcirq/_cz_swap_gate.py new file mode 100644 index 000000000..6f9ded68c --- /dev/null +++ b/glue/cirq/stimcirq/_cz_swap_gate.py @@ -0,0 +1,46 @@ +from typing import Any, Dict, List + +import cirq +import stim + + +@cirq.value_equality +class CZSwapGate(cirq.Gate): + """Handles explaining stim's CZSWAP gates to cirq.""" + + def _num_qubits_(self) -> int: + return 2 + + def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> List[str]: + return ['ZSWAP', 'ZSWAP'] + + def _value_equality_values_(self): + return () + + def _decompose_(self, qubits): + a, b = qubits + yield cirq.SWAP(a, b) + yield cirq.CZ(a, b) + + def _stim_conversion_(self, edit_circuit: stim.Circuit, targets: List[int], **kwargs): + edit_circuit.append_operation('CZSWAP', targets) + + def __pow__(self, power: int) -> 'CZSwapGate': + if power == +1: + return self + if power == -1: + return self + return NotImplemented + + def __str__(self) -> str: + return 'CZSWAP' + + def __repr__(self): + return f'stimcirq.CZSwapGate()' + + @staticmethod + def _json_namespace_() -> str: + return '' + + def _json_dict_(self) -> Dict[str, Any]: + return {} diff --git a/glue/cirq/stimcirq/_cz_swap_test.py b/glue/cirq/stimcirq/_cz_swap_test.py new file mode 100644 index 000000000..34c2af445 --- /dev/null +++ b/glue/cirq/stimcirq/_cz_swap_test.py @@ -0,0 +1,72 @@ +import cirq +import stim +import stimcirq + + +def test_stim_conversion(): + a, b, c = cirq.LineQubit.range(3) + + cirq_circuit = cirq.Circuit( + stimcirq.CZSwapGate().on(a, b), + stimcirq.CZSwapGate().on(b, c), + ) + stim_circuit = stim.Circuit( + """ + CZSWAP 0 1 + TICK + CZSWAP 1 2 + TICK + """ + ) + assert stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit) == stim_circuit + assert stimcirq.stim_circuit_to_cirq_circuit(stim_circuit) == cirq_circuit + + +def test_diagram(): + a, b = cirq.LineQubit.range(2) + cirq.testing.assert_has_diagram( + cirq.Circuit( + stimcirq.CZSwapGate()(a, b), + stimcirq.CZSwapGate()(a, b), + ), + """ +0: ---ZSWAP---ZSWAP--- + | | +1: ---ZSWAP---ZSWAP--- + """, + use_unicode_characters=False, + ) + + +def test_inverse(): + a = stimcirq.CZSwapGate() + assert a**+1 == a + assert a**-1 == a + + +def test_repr(): + val = stimcirq.CZSwapGate() + assert eval(repr(val), {"stimcirq": stimcirq}) == val + + +def test_equality(): + eq = cirq.testing.EqualsTester() + eq.add_equality_group(stimcirq.CZSwapGate(), stimcirq.CZSwapGate()) + + +def test_json_serialization(): + a, b, d = cirq.LineQubit.range(3) + c = cirq.Circuit( + stimcirq.CZSwapGate()(a, b), + stimcirq.CZSwapGate()(b, d), + ) + json = cirq.to_json(c) + c2 = cirq.read_json(json_text=json, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) + assert c == c2 + + +def test_json_backwards_compat_exact(): + raw = stimcirq.CZSwapGate() + packed = '{\n "cirq_type": "CZSwapGate"\n}' + assert cirq.to_json(raw) == packed + assert cirq.read_json(json_text=packed, resolvers=[*cirq.DEFAULT_RESOLVERS, stimcirq.JSON_RESOLVER]) == raw diff --git a/glue/cirq/stimcirq/_stim_to_cirq.py b/glue/cirq/stimcirq/_stim_to_cirq.py index 43bf516d0..af8edeb8f 100644 --- a/glue/cirq/stimcirq/_stim_to_cirq.py +++ b/glue/cirq/stimcirq/_stim_to_cirq.py @@ -16,6 +16,7 @@ import stim from ._cx_swap_gate import CXSwapGate +from ._cz_swap_gate import CZSwapGate from ._det_annotation import DetAnnotation from ._measure_and_or_reset_gate import MeasureAndOrResetGate from ._obs_annotation import CumulativeObservableAnnotation @@ -424,6 +425,7 @@ def handler( measure=False, reset=True, basis='X', invert_measure=False, key='' ) ), + "CZSWAP": gate(CZSwapGate()), "CXSWAP": gate(CXSwapGate(inverted=False)), "SWAPCX": gate(CXSwapGate(inverted=True)), "RY": gate(