From b29741d928281ea000174532ac1395bc03d0c553 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Mon, 11 Mar 2024 17:29:27 -0700 Subject: [PATCH 1/6] Add `stim.PauliString.pauli_indices` (#710) Fixes https://github.com/quantumlib/Stim/issues/699 --- doc/python_api_reference_vDev.md | 45 +++++++++ doc/stim.pyi | 37 +++++++ glue/python/src/stim/__init__.pyi | 37 +++++++ src/stim/stabilizers/pauli_string.pybind.cc | 97 +++++++++++++++++++ .../stabilizers/pauli_string_pybind_test.py | 29 ++++++ 5 files changed, 245 insertions(+) diff --git a/doc/python_api_reference_vDev.md b/doc/python_api_reference_vDev.md index 4a197b909..373736162 100644 --- a/doc/python_api_reference_vDev.md +++ b/doc/python_api_reference_vDev.md @@ -264,6 +264,7 @@ API references for stable versions are kept on the [stim github wiki](https://gi - [`stim.PauliString.from_numpy`](#stim.PauliString.from_numpy) - [`stim.PauliString.from_unitary_matrix`](#stim.PauliString.from_unitary_matrix) - [`stim.PauliString.iter_all`](#stim.PauliString.iter_all) + - [`stim.PauliString.pauli_indices`](#stim.PauliString.pauli_indices) - [`stim.PauliString.random`](#stim.PauliString.random) - [`stim.PauliString.sign`](#stim.PauliString.sign) - [`stim.PauliString.to_numpy`](#stim.PauliString.to_numpy) @@ -9068,6 +9069,50 @@ def iter_all( """ ``` + +```python +# stim.PauliString.pauli_indices + +# (in class stim.PauliString) +def pauli_indices( + self, + included_paulis: str = "XYZ", +) -> List[int]: + """Returns the indices of non-identity Paulis, or of specified Paulis. + + Args: + include: A string containing the Pauli types to include. + X type Pauli indices are included if "X" or "x" is in the string. + Y type Pauli indices are included if "Y" or "y" is in the string. + Z type Pauli indices are included if "Z" or "z" is in the string. + I type Pauli indices are included if "I" or "_" is in the string. + An exception is thrown if other characters are in the string. + + Returns: + A list containing the ascending indices of matching Pauli terms. + + Examples: + >>> import stim + >>> stim.PauliString("_____X___Y____Z___").pauli_indices() + [5, 9, 14] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("XZ") + [5, 14] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("X") + [5] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("Y") + [9] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("IY") + [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17] + + >>> stim.PauliString("-X103*Y100").pauli_indices() + [100, 103] + """ +``` + ```python # stim.PauliString.random diff --git a/doc/stim.pyi b/doc/stim.pyi index 7e476efcc..aa2c065b5 100644 --- a/doc/stim.pyi +++ b/doc/stim.pyi @@ -6982,6 +6982,43 @@ class PauliString: +_ZX +_ZZ """ + def pauli_indices( + self, + included_paulis: str = "XYZ", + ) -> List[int]: + """Returns the indices of non-identity Paulis, or of specified Paulis. + + Args: + include: A string containing the Pauli types to include. + X type Pauli indices are included if "X" or "x" is in the string. + Y type Pauli indices are included if "Y" or "y" is in the string. + Z type Pauli indices are included if "Z" or "z" is in the string. + I type Pauli indices are included if "I" or "_" is in the string. + An exception is thrown if other characters are in the string. + + Returns: + A list containing the ascending indices of matching Pauli terms. + + Examples: + >>> import stim + >>> stim.PauliString("_____X___Y____Z___").pauli_indices() + [5, 9, 14] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("XZ") + [5, 14] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("X") + [5] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("Y") + [9] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("IY") + [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17] + + >>> stim.PauliString("-X103*Y100").pauli_indices() + [100, 103] + """ @staticmethod def random( num_qubits: int, diff --git a/glue/python/src/stim/__init__.pyi b/glue/python/src/stim/__init__.pyi index 7e476efcc..aa2c065b5 100644 --- a/glue/python/src/stim/__init__.pyi +++ b/glue/python/src/stim/__init__.pyi @@ -6982,6 +6982,43 @@ class PauliString: +_ZX +_ZZ """ + def pauli_indices( + self, + included_paulis: str = "XYZ", + ) -> List[int]: + """Returns the indices of non-identity Paulis, or of specified Paulis. + + Args: + include: A string containing the Pauli types to include. + X type Pauli indices are included if "X" or "x" is in the string. + Y type Pauli indices are included if "Y" or "y" is in the string. + Z type Pauli indices are included if "Z" or "z" is in the string. + I type Pauli indices are included if "I" or "_" is in the string. + An exception is thrown if other characters are in the string. + + Returns: + A list containing the ascending indices of matching Pauli terms. + + Examples: + >>> import stim + >>> stim.PauliString("_____X___Y____Z___").pauli_indices() + [5, 9, 14] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("XZ") + [5, 14] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("X") + [5] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("Y") + [9] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("IY") + [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17] + + >>> stim.PauliString("-X103*Y100").pauli_indices() + [100, 103] + """ @staticmethod def random( num_qubits: int, diff --git a/src/stim/stabilizers/pauli_string.pybind.cc b/src/stim/stabilizers/pauli_string.pybind.cc index 3c335cee0..22160685f 100644 --- a/src/stim/stabilizers/pauli_string.pybind.cc +++ b/src/stim/stabilizers/pauli_string.pybind.cc @@ -636,6 +636,103 @@ void stim_pybind::pybind_pauli_string_methods(pybind11::module &m, pybind11::cla )DOC") .data()); + c.def( + "pauli_indices", + [](const FlexPauliString &self, const std::string &include) { + std::vector result; + size_t n64 = self.value.xs.num_u64_padded(); + bool keep_i = false; + bool keep_x = false; + bool keep_y = false; + bool keep_z = false; + for (char c : include) { + switch (c) { + case '_': + case 'I': + keep_i = true; + break; + case 'x': + case 'X': + keep_x = true; + break; + case 'y': + case 'Y': + keep_y = true; + break; + case 'z': + case 'Z': + keep_z = true; + break; + default: + throw std::invalid_argument("Invalid character in include string: " + std::string(1, c)); + } + } + for (size_t k = 0; k < n64; k++) { + uint64_t x = self.value.xs.u64[k]; + uint64_t z = self.value.zs.u64[k]; + uint64_t u = 0; + if (keep_i) { + u |= ~x & ~z; + } + if (keep_x) { + u |= x & ~z; + } + if (keep_y) { + u |= x & z; + } + if (keep_z) { + u |= ~x & z; + } + while (u) { + uint8_t v = std::countr_zero(u); + uint64_t q = k * 64 + v; + if (q >= self.value.num_qubits) { + return result; + } + result.push_back(q); + u &= u - 1; + } + } + return result; + }, + pybind11::arg("included_paulis") = "XYZ", + clean_doc_string(R"DOC( + @signature def pauli_indices(self, included_paulis: str = "XYZ") -> List[int]: + Returns the indices of non-identity Paulis, or of specified Paulis. + + Args: + include: A string containing the Pauli types to include. + X type Pauli indices are included if "X" or "x" is in the string. + Y type Pauli indices are included if "Y" or "y" is in the string. + Z type Pauli indices are included if "Z" or "z" is in the string. + I type Pauli indices are included if "I" or "_" is in the string. + An exception is thrown if other characters are in the string. + + Returns: + A list containing the ascending indices of matching Pauli terms. + + Examples: + >>> import stim + >>> stim.PauliString("_____X___Y____Z___").pauli_indices() + [5, 9, 14] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("XZ") + [5, 14] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("X") + [5] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("Y") + [9] + + >>> stim.PauliString("_____X___Y____Z___").pauli_indices("IY") + [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17] + + >>> stim.PauliString("-X103*Y100").pauli_indices() + [100, 103] + )DOC") + .data()); + c.def( "commutes", [](const FlexPauliString &self, const FlexPauliString &other) { diff --git a/src/stim/stabilizers/pauli_string_pybind_test.py b/src/stim/stabilizers/pauli_string_pybind_test.py index 4efff759e..5058e6a55 100644 --- a/src/stim/stabilizers/pauli_string_pybind_test.py +++ b/src/stim/stabilizers/pauli_string_pybind_test.py @@ -899,3 +899,32 @@ def test_backwards_compatibility_init(): assert stim.PauliString(text="XYZ") == stim.PauliString("+XYZ") # noinspection PyArgumentList assert stim.PauliString(other=stim.PauliString("XYZ")) == stim.PauliString("+XYZ") + + +def test_pauli_indices(): + assert stim.PauliString().pauli_indices() == [] + assert stim.PauliString().pauli_indices("X") == [] + assert stim.PauliString().pauli_indices("I") == [] + assert stim.PauliString(5).pauli_indices() == [] + assert stim.PauliString(5).pauli_indices("X") == [] + assert stim.PauliString(5).pauli_indices("I") == [0, 1, 2, 3, 4] + assert stim.PauliString("X1000").pauli_indices() == [1000] + assert stim.PauliString("Y1000").pauli_indices() == [1000] + assert stim.PauliString("Z1000").pauli_indices() == [1000] + assert stim.PauliString("X1000").pauli_indices("YZ") == [] + assert stim.PauliString("Y1000").pauli_indices("XZ") == [] + assert stim.PauliString("Z1000").pauli_indices("XY") == [] + assert stim.PauliString("X1000").pauli_indices("X") == [1000] + assert stim.PauliString("Y1000").pauli_indices("Y") == [1000] + assert stim.PauliString("Z1000").pauli_indices("Z") == [1000] + + assert stim.PauliString("_XYZ").pauli_indices("x") == [1] + assert stim.PauliString("_XYZ").pauli_indices("X") == [1] + assert stim.PauliString("_XYZ").pauli_indices("y") == [2] + assert stim.PauliString("_XYZ").pauli_indices("Y") == [2] + assert stim.PauliString("_XYZ").pauli_indices("z") == [3] + assert stim.PauliString("_XYZ").pauli_indices("Z") == [3] + assert stim.PauliString("_XYZ").pauli_indices("I") == [0] + assert stim.PauliString("_XYZ").pauli_indices("_") == [0] + with pytest.raises(ValueError, match="Invalid character"): + assert stim.PauliString("_XYZ").pauli_indices("k") From dc59667f3280f002907dc743551003270d9f1017 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Mon, 11 Mar 2024 19:05:53 -0700 Subject: [PATCH 2/6] Add `stim.Circuit.detecting_regions` (#711) Fixes https://github.com/quantumlib/Stim/issues/349 --- doc/python_api_reference_vDev.md | 136 ++++++++++++++ doc/stim.pyi | 128 +++++++++++++ file_lists/test_files | 1 + glue/python/src/stim/__init__.pyi | 128 +++++++++++++ src/stim/circuit/circuit.pybind.cc | 240 ++++++++++++++++++++++++ src/stim/circuit/circuit_pybind_test.py | 27 +++ src/stim/dem/dem_instruction.cc | 27 ++- src/stim/dem/dem_instruction.h | 2 + src/stim/dem/dem_instruction.test.cc | 31 +++ src/stim/stabilizers/conversions.cc | 43 +++++ src/stim/stabilizers/conversions.h | 8 + 11 files changed, 769 insertions(+), 2 deletions(-) create mode 100644 src/stim/dem/dem_instruction.test.cc diff --git a/doc/python_api_reference_vDev.md b/doc/python_api_reference_vDev.md index 373736162..d6a474d01 100644 --- a/doc/python_api_reference_vDev.md +++ b/doc/python_api_reference_vDev.md @@ -27,6 +27,7 @@ API references for stable versions are kept on the [stim github wiki](https://gi - [`stim.Circuit.compile_sampler`](#stim.Circuit.compile_sampler) - [`stim.Circuit.copy`](#stim.Circuit.copy) - [`stim.Circuit.count_determined_measurements`](#stim.Circuit.count_determined_measurements) + - [`stim.Circuit.detecting_regions`](#stim.Circuit.detecting_regions) - [`stim.Circuit.detector_error_model`](#stim.Circuit.detector_error_model) - [`stim.Circuit.diagram`](#stim.Circuit.diagram) - [`stim.Circuit.explain_detector_error_model_errors`](#stim.Circuit.explain_detector_error_model_errors) @@ -1272,6 +1273,141 @@ def count_determined_measurements( """ ``` + +```python +# stim.Circuit.detecting_regions + +# (in class stim.Circuit) +def detecting_regions( + self, + *, + targets: Optional[Iterable[stim.DemTarget | str | Iterable[float]]] = None, + ticks: Optional[Iterable[int]] = None, +) -> Dict[stim.DemTarget, Dict[int, stim.PauliString]]: + """Records where detectors and observables are sensitive to errors over time. + + The result of this method is a nested dictionary, mapping detectors/observables + and ticks to Pauli sensitivities for that detector/observable at that time. + + For example, if observable 2 has Z-type sensitivity on qubits 5 and 6 during + tick 3, then `result[stim.target_logical_observable_id(2)][3]` will be equal to + `stim.PauliString("Z5*Z6")`. + + If you want sensitivities from more places in the circuit, besides just at the + TICK instructions, you can work around this by making a version of the circuit + with more TICKs. + + Args: + targets: Defaults to everything (None). + + When specified, this should be an iterable of filters where items + matching any one filter are included. + + A variety of filters are supported: + stim.DemTarget: Includes the targeted detector or observable. + Iterable[float]: Coordinate prefix match. Includes detectors whose + coordinate data begins with the same floats. + "D": Includes all detectors. + "L": Includes all observables. + "D#" (e.g. "D5"): Includes the detector with the specified index. + "L#" (e.g. "L5"): Includes the observable with the specified index. + + ticks: Defaults to everything (None). + When specified, this should be a list of integers corresponding to + the tick indices to report sensitivities for. + + ignore_anticommutation_errors: Defaults to False. + When set to False, invalid detecting regions that anticommute with a + reset will cause the method to raise an exception. When set to True, + the offending component will simply be silently dropped. This can + result in broken detectors having apparently enormous detecting + regions. + + Returns: + Nested dictionaries keyed first by a `stim.DemTarget` identifying the + detector or observable, then by the index of the tick, leading to a + PauliString with that target's error sensitivity at that tick. + + Note you can use `stim.PauliString.pauli_indices` to quickly get to the + non-identity terms in the sensitivity. + + Examples: + >>> import stim + + >>> detecting_regions = stim.Circuit(''' + ... R 0 + ... TICK + ... H 0 + ... TICK + ... CX 0 1 + ... TICK + ... MX 0 1 + ... DETECTOR rec[-1] rec[-2] + ... ''').detecting_regions() + >>> for target, tick_regions in detecting_regions.items(): + ... print("target", target) + ... for tick, sensitivity in tick_regions.items(): + ... print(" tick", tick, "=", sensitivity) + target D0 + tick 0 = +Z_ + tick 1 = +X_ + tick 2 = +XX + + >>> circuit = stim.Circuit.generated( + ... "surface_code:rotated_memory_x", + ... rounds=5, + ... distance=4, + ... ) + + >>> detecting_regions = circuit.detecting_regions( + ... targets=["L0", (2, 4), stim.DemTarget.relative_detector_id(5)], + ... ticks=range(5, 15), + ... ) + >>> for target, tick_regions in detecting_regions.items(): + ... print("target", target) + ... for tick, sensitivity in tick_regions.items(): + ... print(" tick", tick, "=", sensitivity) + target D1 + tick 5 = +____________________X______________________ + tick 6 = +____________________Z______________________ + target D5 + tick 5 = +______X____________________________________ + tick 6 = +______Z____________________________________ + target D14 + tick 5 = +__________X_X______XXX_____________________ + tick 6 = +__________X_X______XZX_____________________ + tick 7 = +__________X_X______XZX_____________________ + tick 8 = +__________X_X______XXX_____________________ + tick 9 = +__________XXX_____XXX______________________ + tick 10 = +__________XXX_______X______________________ + tick 11 = +__________X_________X______________________ + tick 12 = +____________________X______________________ + tick 13 = +____________________Z______________________ + target D29 + tick 7 = +____________________Z______________________ + tick 8 = +____________________X______________________ + tick 9 = +____________________XX_____________________ + tick 10 = +___________________XXX_______X_____________ + tick 11 = +____________X______XXXX______X_____________ + tick 12 = +__________X_X______XXX_____________________ + tick 13 = +__________X_X______XZX_____________________ + tick 14 = +__________X_X______XZX_____________________ + target D44 + tick 14 = +____________________Z______________________ + target L0 + tick 5 = +_X________X________X________X______________ + tick 6 = +_X________X________X________X______________ + tick 7 = +_X________X________X________X______________ + tick 8 = +_X________X________X________X______________ + tick 9 = +_X________X_______XX________X______________ + tick 10 = +_X________X________X________X______________ + tick 11 = +_X________XX_______X________XX_____________ + tick 12 = +_X________X________X________X______________ + tick 13 = +_X________X________X________X______________ + tick 14 = +_X________X________X________X______________ + """ +``` + ```python # stim.Circuit.detector_error_model diff --git a/doc/stim.pyi b/doc/stim.pyi index aa2c065b5..fffa1d476 100644 --- a/doc/stim.pyi +++ b/doc/stim.pyi @@ -716,6 +716,134 @@ class Circuit: >>> circuit.num_detectors + circuit.num_observables 217 """ + def detecting_regions( + self, + *, + targets: Optional[Iterable[stim.DemTarget | str | Iterable[float]]] = None, + ticks: Optional[Iterable[int]] = None, + ) -> Dict[stim.DemTarget, Dict[int, stim.PauliString]]: + """Records where detectors and observables are sensitive to errors over time. + + The result of this method is a nested dictionary, mapping detectors/observables + and ticks to Pauli sensitivities for that detector/observable at that time. + + For example, if observable 2 has Z-type sensitivity on qubits 5 and 6 during + tick 3, then `result[stim.target_logical_observable_id(2)][3]` will be equal to + `stim.PauliString("Z5*Z6")`. + + If you want sensitivities from more places in the circuit, besides just at the + TICK instructions, you can work around this by making a version of the circuit + with more TICKs. + + Args: + targets: Defaults to everything (None). + + When specified, this should be an iterable of filters where items + matching any one filter are included. + + A variety of filters are supported: + stim.DemTarget: Includes the targeted detector or observable. + Iterable[float]: Coordinate prefix match. Includes detectors whose + coordinate data begins with the same floats. + "D": Includes all detectors. + "L": Includes all observables. + "D#" (e.g. "D5"): Includes the detector with the specified index. + "L#" (e.g. "L5"): Includes the observable with the specified index. + + ticks: Defaults to everything (None). + When specified, this should be a list of integers corresponding to + the tick indices to report sensitivities for. + + ignore_anticommutation_errors: Defaults to False. + When set to False, invalid detecting regions that anticommute with a + reset will cause the method to raise an exception. When set to True, + the offending component will simply be silently dropped. This can + result in broken detectors having apparently enormous detecting + regions. + + Returns: + Nested dictionaries keyed first by a `stim.DemTarget` identifying the + detector or observable, then by the index of the tick, leading to a + PauliString with that target's error sensitivity at that tick. + + Note you can use `stim.PauliString.pauli_indices` to quickly get to the + non-identity terms in the sensitivity. + + Examples: + >>> import stim + + >>> detecting_regions = stim.Circuit(''' + ... R 0 + ... TICK + ... H 0 + ... TICK + ... CX 0 1 + ... TICK + ... MX 0 1 + ... DETECTOR rec[-1] rec[-2] + ... ''').detecting_regions() + >>> for target, tick_regions in detecting_regions.items(): + ... print("target", target) + ... for tick, sensitivity in tick_regions.items(): + ... print(" tick", tick, "=", sensitivity) + target D0 + tick 0 = +Z_ + tick 1 = +X_ + tick 2 = +XX + + >>> circuit = stim.Circuit.generated( + ... "surface_code:rotated_memory_x", + ... rounds=5, + ... distance=4, + ... ) + + >>> detecting_regions = circuit.detecting_regions( + ... targets=["L0", (2, 4), stim.DemTarget.relative_detector_id(5)], + ... ticks=range(5, 15), + ... ) + >>> for target, tick_regions in detecting_regions.items(): + ... print("target", target) + ... for tick, sensitivity in tick_regions.items(): + ... print(" tick", tick, "=", sensitivity) + target D1 + tick 5 = +____________________X______________________ + tick 6 = +____________________Z______________________ + target D5 + tick 5 = +______X____________________________________ + tick 6 = +______Z____________________________________ + target D14 + tick 5 = +__________X_X______XXX_____________________ + tick 6 = +__________X_X______XZX_____________________ + tick 7 = +__________X_X______XZX_____________________ + tick 8 = +__________X_X______XXX_____________________ + tick 9 = +__________XXX_____XXX______________________ + tick 10 = +__________XXX_______X______________________ + tick 11 = +__________X_________X______________________ + tick 12 = +____________________X______________________ + tick 13 = +____________________Z______________________ + target D29 + tick 7 = +____________________Z______________________ + tick 8 = +____________________X______________________ + tick 9 = +____________________XX_____________________ + tick 10 = +___________________XXX_______X_____________ + tick 11 = +____________X______XXXX______X_____________ + tick 12 = +__________X_X______XXX_____________________ + tick 13 = +__________X_X______XZX_____________________ + tick 14 = +__________X_X______XZX_____________________ + target D44 + tick 14 = +____________________Z______________________ + target L0 + tick 5 = +_X________X________X________X______________ + tick 6 = +_X________X________X________X______________ + tick 7 = +_X________X________X________X______________ + tick 8 = +_X________X________X________X______________ + tick 9 = +_X________X_______XX________X______________ + tick 10 = +_X________X________X________X______________ + tick 11 = +_X________XX_______X________XX_____________ + tick 12 = +_X________X________X________X______________ + tick 13 = +_X________X________X________X______________ + tick 14 = +_X________X________X________X______________ + """ def detector_error_model( self, *, diff --git a/file_lists/test_files b/file_lists/test_files index b08819aff..3b0f6ea4d 100644 --- a/file_lists/test_files +++ b/file_lists/test_files @@ -13,6 +13,7 @@ src/stim/cmd/command_gen.test.cc src/stim/cmd/command_m2d.test.cc src/stim/cmd/command_sample.test.cc src/stim/cmd/command_sample_dem.test.cc +src/stim/dem/dem_instruction.test.cc src/stim/dem/detector_error_model.test.cc src/stim/diagram/ascii_diagram.test.cc src/stim/diagram/base64.test.cc diff --git a/glue/python/src/stim/__init__.pyi b/glue/python/src/stim/__init__.pyi index aa2c065b5..fffa1d476 100644 --- a/glue/python/src/stim/__init__.pyi +++ b/glue/python/src/stim/__init__.pyi @@ -716,6 +716,134 @@ class Circuit: >>> circuit.num_detectors + circuit.num_observables 217 """ + def detecting_regions( + self, + *, + targets: Optional[Iterable[stim.DemTarget | str | Iterable[float]]] = None, + ticks: Optional[Iterable[int]] = None, + ) -> Dict[stim.DemTarget, Dict[int, stim.PauliString]]: + """Records where detectors and observables are sensitive to errors over time. + + The result of this method is a nested dictionary, mapping detectors/observables + and ticks to Pauli sensitivities for that detector/observable at that time. + + For example, if observable 2 has Z-type sensitivity on qubits 5 and 6 during + tick 3, then `result[stim.target_logical_observable_id(2)][3]` will be equal to + `stim.PauliString("Z5*Z6")`. + + If you want sensitivities from more places in the circuit, besides just at the + TICK instructions, you can work around this by making a version of the circuit + with more TICKs. + + Args: + targets: Defaults to everything (None). + + When specified, this should be an iterable of filters where items + matching any one filter are included. + + A variety of filters are supported: + stim.DemTarget: Includes the targeted detector or observable. + Iterable[float]: Coordinate prefix match. Includes detectors whose + coordinate data begins with the same floats. + "D": Includes all detectors. + "L": Includes all observables. + "D#" (e.g. "D5"): Includes the detector with the specified index. + "L#" (e.g. "L5"): Includes the observable with the specified index. + + ticks: Defaults to everything (None). + When specified, this should be a list of integers corresponding to + the tick indices to report sensitivities for. + + ignore_anticommutation_errors: Defaults to False. + When set to False, invalid detecting regions that anticommute with a + reset will cause the method to raise an exception. When set to True, + the offending component will simply be silently dropped. This can + result in broken detectors having apparently enormous detecting + regions. + + Returns: + Nested dictionaries keyed first by a `stim.DemTarget` identifying the + detector or observable, then by the index of the tick, leading to a + PauliString with that target's error sensitivity at that tick. + + Note you can use `stim.PauliString.pauli_indices` to quickly get to the + non-identity terms in the sensitivity. + + Examples: + >>> import stim + + >>> detecting_regions = stim.Circuit(''' + ... R 0 + ... TICK + ... H 0 + ... TICK + ... CX 0 1 + ... TICK + ... MX 0 1 + ... DETECTOR rec[-1] rec[-2] + ... ''').detecting_regions() + >>> for target, tick_regions in detecting_regions.items(): + ... print("target", target) + ... for tick, sensitivity in tick_regions.items(): + ... print(" tick", tick, "=", sensitivity) + target D0 + tick 0 = +Z_ + tick 1 = +X_ + tick 2 = +XX + + >>> circuit = stim.Circuit.generated( + ... "surface_code:rotated_memory_x", + ... rounds=5, + ... distance=4, + ... ) + + >>> detecting_regions = circuit.detecting_regions( + ... targets=["L0", (2, 4), stim.DemTarget.relative_detector_id(5)], + ... ticks=range(5, 15), + ... ) + >>> for target, tick_regions in detecting_regions.items(): + ... print("target", target) + ... for tick, sensitivity in tick_regions.items(): + ... print(" tick", tick, "=", sensitivity) + target D1 + tick 5 = +____________________X______________________ + tick 6 = +____________________Z______________________ + target D5 + tick 5 = +______X____________________________________ + tick 6 = +______Z____________________________________ + target D14 + tick 5 = +__________X_X______XXX_____________________ + tick 6 = +__________X_X______XZX_____________________ + tick 7 = +__________X_X______XZX_____________________ + tick 8 = +__________X_X______XXX_____________________ + tick 9 = +__________XXX_____XXX______________________ + tick 10 = +__________XXX_______X______________________ + tick 11 = +__________X_________X______________________ + tick 12 = +____________________X______________________ + tick 13 = +____________________Z______________________ + target D29 + tick 7 = +____________________Z______________________ + tick 8 = +____________________X______________________ + tick 9 = +____________________XX_____________________ + tick 10 = +___________________XXX_______X_____________ + tick 11 = +____________X______XXXX______X_____________ + tick 12 = +__________X_X______XXX_____________________ + tick 13 = +__________X_X______XZX_____________________ + tick 14 = +__________X_X______XZX_____________________ + target D44 + tick 14 = +____________________Z______________________ + target L0 + tick 5 = +_X________X________X________X______________ + tick 6 = +_X________X________X________X______________ + tick 7 = +_X________X________X________X______________ + tick 8 = +_X________X________X________X______________ + tick 9 = +_X________X_______XX________X______________ + tick 10 = +_X________X________X________X______________ + tick 11 = +_X________XX_______X________XX_____________ + tick 12 = +_X________X________X________X______________ + tick 13 = +_X________X________X________X______________ + tick 14 = +_X________X________X________X______________ + """ def detector_error_model( self, *, diff --git a/src/stim/circuit/circuit.pybind.cc b/src/stim/circuit/circuit.pybind.cc index d0e0c6349..7eb911298 100644 --- a/src/stim/circuit/circuit.pybind.cc +++ b/src/stim/circuit/circuit.pybind.cc @@ -46,6 +46,89 @@ using namespace stim; using namespace stim_pybind; +std::set py_dem_filter_to_dem_target_set( + const Circuit &circuit, const CircuitStats &stats, const pybind11::object &included_targets_filter) { + std::set result; + auto add_all_dets = [&]() { + for (uint64_t k = 0; k < stats.num_detectors; k++) { + result.insert(DemTarget::relative_detector_id(k)); + } + }; + auto add_all_obs = [&]() { + for (uint64_t k = 0; k < stats.num_observables; k++) { + result.insert(DemTarget::observable_id(k)); + } + }; + + bool has_coords = false; + std::map> cached_coords; + auto get_coords_cached = [&]() -> const std::map> & { + std::set all_dets; + for (uint64_t k = 0; k < stats.num_detectors; k++) { + all_dets.insert(k); + } + if (!has_coords) { + cached_coords = circuit.get_detector_coordinates(all_dets); + has_coords = true; + } + return cached_coords; + }; + + if (included_targets_filter.is_none()) { + add_all_dets(); + add_all_obs(); + return result; + } + for (const auto &filter : included_targets_filter) { + bool fail = false; + if (pybind11::isinstance(filter)) { + result.insert(pybind11::cast(filter)); + } else if (pybind11::isinstance(filter)) { + std::string s = pybind11::cast(filter); + if (s == "D") { + add_all_dets(); + } else if (s == "L") { + add_all_obs(); + } else if (s.starts_with("D") || s.starts_with("L")) { + result.insert(DemTarget::from_text(s)); + } else { + fail = true; + } + } else { + std::vector prefix; + for (auto e : filter) { + if (pybind11::isinstance(e) || pybind11::isinstance(e)) { + prefix.push_back(pybind11::cast(e)); + } else { + fail = true; + break; + } + } + if (!fail) { + for (const auto &[target, coord] : get_coords_cached()) { + if (coord.size() >= prefix.size()) { + bool match = true; + for (size_t k = 0; k < prefix.size(); k++) { + match &= prefix[k] == coord[k]; + } + if (match) { + result.insert(DemTarget::relative_detector_id(target)); + } + } + } + } + } + if (fail) { + std::stringstream ss; + ss << "Don't know how to interpret '"; + ss << pybind11::cast(pybind11::repr(filter)); + ss << "' as a dem target filter."; + throw std::invalid_argument(ss.str()); + } + } + return result; +} + std::string circuit_repr(const Circuit &self) { if (self.operations.empty()) { return "stim.Circuit()"; @@ -2034,6 +2117,163 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_ std::map> { + auto stats = self.compute_stats(); + auto included_target_set = py_dem_filter_to_dem_target_set(self, stats, included_targets); + std::set included_tick_set; + + if (included_ticks.is_none()) { + for (uint64_t k = 0; k < stats.num_ticks; k++) { + included_tick_set.insert(k); + } + } else { + for (const auto &t : included_ticks) { + included_tick_set.insert(pybind11::cast(t)); + } + } + auto result = circuit_to_detecting_regions( + self, included_target_set, included_tick_set, ignore_anticommutation_errors); + std::map> exposed_result; + for (const auto &[k, v] : result) { + exposed_result.insert({ExposedDemTarget(k), std::move(v)}); + } + return exposed_result; + }, + pybind11::kw_only(), + pybind11::arg("targets") = pybind11::none(), + pybind11::arg("ticks") = pybind11::none(), + pybind11::arg("ignore_anticommutation_errors") = false, + clean_doc_string(R"DOC( + @signature def detecting_regions(self, *, targets: Optional[Iterable[stim.DemTarget | str | Iterable[float]]] = None, ticks: Optional[Iterable[int]] = None) -> Dict[stim.DemTarget, Dict[int, stim.PauliString]]: + Records where detectors and observables are sensitive to errors over time. + + The result of this method is a nested dictionary, mapping detectors/observables + and ticks to Pauli sensitivities for that detector/observable at that time. + + For example, if observable 2 has Z-type sensitivity on qubits 5 and 6 during + tick 3, then `result[stim.target_logical_observable_id(2)][3]` will be equal to + `stim.PauliString("Z5*Z6")`. + + If you want sensitivities from more places in the circuit, besides just at the + TICK instructions, you can work around this by making a version of the circuit + with more TICKs. + + Args: + targets: Defaults to everything (None). + + When specified, this should be an iterable of filters where items + matching any one filter are included. + + A variety of filters are supported: + stim.DemTarget: Includes the targeted detector or observable. + Iterable[float]: Coordinate prefix match. Includes detectors whose + coordinate data begins with the same floats. + "D": Includes all detectors. + "L": Includes all observables. + "D#" (e.g. "D5"): Includes the detector with the specified index. + "L#" (e.g. "L5"): Includes the observable with the specified index. + + ticks: Defaults to everything (None). + When specified, this should be a list of integers corresponding to + the tick indices to report sensitivities for. + + ignore_anticommutation_errors: Defaults to False. + When set to False, invalid detecting regions that anticommute with a + reset will cause the method to raise an exception. When set to True, + the offending component will simply be silently dropped. This can + result in broken detectors having apparently enormous detecting + regions. + + Returns: + Nested dictionaries keyed first by a `stim.DemTarget` identifying the + detector or observable, then by the index of the tick, leading to a + PauliString with that target's error sensitivity at that tick. + + Note you can use `stim.PauliString.pauli_indices` to quickly get to the + non-identity terms in the sensitivity. + + Examples: + >>> import stim + + >>> detecting_regions = stim.Circuit(''' + ... R 0 + ... TICK + ... H 0 + ... TICK + ... CX 0 1 + ... TICK + ... MX 0 1 + ... DETECTOR rec[-1] rec[-2] + ... ''').detecting_regions() + >>> for target, tick_regions in detecting_regions.items(): + ... print("target", target) + ... for tick, sensitivity in tick_regions.items(): + ... print(" tick", tick, "=", sensitivity) + target D0 + tick 0 = +Z_ + tick 1 = +X_ + tick 2 = +XX + + >>> circuit = stim.Circuit.generated( + ... "surface_code:rotated_memory_x", + ... rounds=5, + ... distance=4, + ... ) + + >>> detecting_regions = circuit.detecting_regions( + ... targets=["L0", (2, 4), stim.DemTarget.relative_detector_id(5)], + ... ticks=range(5, 15), + ... ) + >>> for target, tick_regions in detecting_regions.items(): + ... print("target", target) + ... for tick, sensitivity in tick_regions.items(): + ... print(" tick", tick, "=", sensitivity) + target D1 + tick 5 = +____________________X______________________ + tick 6 = +____________________Z______________________ + target D5 + tick 5 = +______X____________________________________ + tick 6 = +______Z____________________________________ + target D14 + tick 5 = +__________X_X______XXX_____________________ + tick 6 = +__________X_X______XZX_____________________ + tick 7 = +__________X_X______XZX_____________________ + tick 8 = +__________X_X______XXX_____________________ + tick 9 = +__________XXX_____XXX______________________ + tick 10 = +__________XXX_______X______________________ + tick 11 = +__________X_________X______________________ + tick 12 = +____________________X______________________ + tick 13 = +____________________Z______________________ + target D29 + tick 7 = +____________________Z______________________ + tick 8 = +____________________X______________________ + tick 9 = +____________________XX_____________________ + tick 10 = +___________________XXX_______X_____________ + tick 11 = +____________X______XXXX______X_____________ + tick 12 = +__________X_X______XXX_____________________ + tick 13 = +__________X_X______XZX_____________________ + tick 14 = +__________X_X______XZX_____________________ + target D44 + tick 14 = +____________________Z______________________ + target L0 + tick 5 = +_X________X________X________X______________ + tick 6 = +_X________X________X________X______________ + tick 7 = +_X________X________X________X______________ + tick 8 = +_X________X________X________X______________ + tick 9 = +_X________X_______XX________X______________ + tick 10 = +_X________X________X________X______________ + tick 11 = +_X________XX_______X________XX_____________ + tick 12 = +_X________X________X________X______________ + tick 13 = +_X________X________X________X______________ + tick 14 = +_X________X________X________X______________ + )DOC") + .data()); + c.def( "without_noise", &Circuit::without_noise, diff --git a/src/stim/circuit/circuit_pybind_test.py b/src/stim/circuit/circuit_pybind_test.py index 4dca08503..82dd5f86b 100644 --- a/src/stim/circuit/circuit_pybind_test.py +++ b/src/stim/circuit/circuit_pybind_test.py @@ -1697,3 +1697,30 @@ def test_has_flow_shorthands(): assert c.has_flow(stim.Flow("-iX_ -> -iXX xor rec[1] xor rec[3]")) with pytest.raises(ValueError): stim.Flow("iX_ -> XX") + + +def test_detecting_regions(): + assert stim.Circuit(''' + R 0 + TICK + H 0 + TICK + CX 0 1 + TICK + MX 0 1 + DETECTOR rec[-1] rec[-2] + ''').detecting_regions() == {stim.DemTarget.relative_detector_id(0): { + 0: stim.PauliString("Z_"), + 1: stim.PauliString("X_"), + 2: stim.PauliString("XX"), + }} + + +def test_detecting_region_filters(): + c = stim.Circuit.generated("repetition_code:memory", distance=3, rounds=3) + assert len(c.detecting_regions(targets=["D"])) == c.num_detectors + assert len(c.detecting_regions(targets=["L"])) == c.num_observables + assert len(c.detecting_regions()) == c.num_observables + c.num_detectors + assert len(c.detecting_regions(targets=["D0"])) == 1 + assert len(c.detecting_regions(targets=["D0", "L0"])) == 2 + assert len(c.detecting_regions(targets=[stim.target_relative_detector_id(0), "D0"])) == 1 diff --git a/src/stim/dem/dem_instruction.cc b/src/stim/dem/dem_instruction.cc index b240ff14c..6ddff59d7 100644 --- a/src/stim/dem/dem_instruction.cc +++ b/src/stim/dem/dem_instruction.cc @@ -2,6 +2,7 @@ #include +#include "stim/arg_parse.h" #include "stim/dem/detector_error_model.h" #include "stim/simulators/error_analyzer.h" #include "stim/str_util.h" @@ -11,14 +12,17 @@ using namespace stim; constexpr uint64_t OBSERVABLE_BIT = uint64_t{1} << 63; constexpr uint64_t SEPARATOR_SYGIL = UINT64_MAX; +constexpr uint64_t MAX_OBS = 0xFFFFFFFF; +constexpr uint64_t MAX_DET = (uint64_t{1} << 62) - 1; + DemTarget DemTarget::observable_id(uint64_t id) { - if (id > 0xFFFFFFFF) { + if (id > MAX_OBS) { throw std::invalid_argument("id > 0xFFFFFFFF"); } return {OBSERVABLE_BIT | id}; } DemTarget DemTarget::relative_detector_id(uint64_t id) { - if (id >= (uint64_t{1} << 62)) { + if (id > MAX_DET) { throw std::invalid_argument("Relative detector id too large."); } return {id}; @@ -75,6 +79,25 @@ void DemTarget::shift_if_detector_id(int64_t offset) { data = (uint64_t)((int64_t)data + offset); } } +DemTarget DemTarget::from_text(std::string_view text) { + if (!text.empty()) { + bool is_det = text[0] == 'D'; + bool is_obs = text[0] == 'L'; + if (is_det || is_obs) { + int64_t parsed = 0; + if (parse_int64(text.substr(1), &parsed)) { + if (parsed >= 0) { + if (is_det && parsed <= (int64_t)MAX_DET) { + return DemTarget::relative_detector_id(parsed); + } else if (is_obs && parsed <= (int64_t)MAX_OBS) { + return DemTarget::observable_id(parsed); + } + } + } + } + } + throw std::invalid_argument("Failed to parse as a stim.DemTarget: '" + std::string(text) + "'"); +} bool DemInstruction::operator<(const DemInstruction &other) const { if (type != other.type) { diff --git a/src/stim/dem/dem_instruction.h b/src/stim/dem/dem_instruction.h index cba775605..d75c4f349 100644 --- a/src/stim/dem/dem_instruction.h +++ b/src/stim/dem/dem_instruction.h @@ -37,6 +37,8 @@ struct DemTarget { bool operator!=(const DemTarget &other) const; bool operator<(const DemTarget &other) const; std::string str() const; + + static DemTarget from_text(std::string_view text); }; struct DetectorErrorModel; diff --git a/src/stim/dem/dem_instruction.test.cc b/src/stim/dem/dem_instruction.test.cc new file mode 100644 index 000000000..53865e465 --- /dev/null +++ b/src/stim/dem/dem_instruction.test.cc @@ -0,0 +1,31 @@ +#include "stim/dem/dem_instruction.h" + +#include "gtest/gtest.h" + +using namespace stim; + +TEST(dem_instruction, from_str) { + ASSERT_EQ(DemTarget::from_text("D5"), DemTarget::relative_detector_id(5)); + ASSERT_EQ(DemTarget::from_text("D0"), DemTarget::relative_detector_id(0)); + ASSERT_EQ(DemTarget::from_text("D4611686018427387903"), DemTarget::relative_detector_id(4611686018427387903)); + + ASSERT_EQ(DemTarget::from_text("L5"), DemTarget::observable_id(5)); + ASSERT_EQ(DemTarget::from_text("L0"), DemTarget::observable_id(0)); + ASSERT_EQ(DemTarget::from_text("L4294967295"), DemTarget::observable_id(4294967295)); + + ASSERT_THROW({ DemTarget::from_text("D4611686018427387904"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("L4294967296"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("L-1"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("L-1"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("D-1"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("Da"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("Da "); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text(" Da"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("X"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text(""); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("1"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("-1"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("0"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text("'"); }, std::invalid_argument); + ASSERT_THROW({ DemTarget::from_text(" "); }, std::invalid_argument); +} diff --git a/src/stim/stabilizers/conversions.cc b/src/stim/stabilizers/conversions.cc index caf7eba86..44c18119f 100644 --- a/src/stim/stabilizers/conversions.cc +++ b/src/stim/stabilizers/conversions.cc @@ -1,5 +1,7 @@ #include "stim/stabilizers/conversions.h" +#include "stim/simulators/sparse_rev_frame_tracker.h" + using namespace stim; void stim::independent_to_disjoint_xyz_errors( @@ -137,3 +139,44 @@ double stim::independent_per_channel_probability_to_depolarize2_probability(doub q *= q; return 15.0 / 16.0 * (1.0 - q); } + +std::map> stim::circuit_to_detecting_regions( + const Circuit &circuit, + std::set included_targets, + std::set included_ticks, + bool ignore_anticommutation_errors) { + CircuitStats stats = circuit.compute_stats(); + uint64_t tick_index = stats.num_ticks; + SparseUnsignedRevFrameTracker tracker( + stats.num_qubits, stats.num_measurements, stats.num_detectors, !ignore_anticommutation_errors); + std::map> result; + circuit.for_each_operation_reverse([&](const CircuitInstruction &inst) { + if (inst.gate_type == GateType::TICK) { + tick_index -= 1; + if (included_ticks.contains(tick_index)) { + for (size_t q = 0; q < stats.num_qubits; q++) { + for (auto target : tracker.xs[q]) { + if (included_targets.contains(target)) { + auto &m = result[target]; + if (!m.contains(tick_index)) { + m.insert({tick_index, FlexPauliString(stats.num_qubits)}); + } + m.at(tick_index).value.xs[q] ^= 1; + } + } + for (auto target : tracker.zs[q]) { + if (included_targets.contains(target)) { + auto &m = result[target]; + if (!m.contains(tick_index)) { + m.insert({tick_index, FlexPauliString(stats.num_qubits)}); + } + m.at(tick_index).value.zs[q] ^= 1; + } + } + } + } + } + tracker.undo_gate(inst); + }); + return result; +} diff --git a/src/stim/stabilizers/conversions.h b/src/stim/stabilizers/conversions.h index a1fb9fc4d..4349693fd 100644 --- a/src/stim/stabilizers/conversions.h +++ b/src/stim/stabilizers/conversions.h @@ -18,6 +18,8 @@ #define _STIM_STABILIZERS_CONVERSIONS_H #include "stim/circuit/circuit.h" +#include "stim/dem/dem_instruction.h" +#include "stim/stabilizers/flex_pauli_string.h" #include "stim/stabilizers/tableau.h" namespace stim { @@ -179,6 +181,12 @@ double depolarize2_probability_to_independent_per_channel_probability(double p); double independent_per_channel_probability_to_depolarize1_probability(double p); double independent_per_channel_probability_to_depolarize2_probability(double p); +std::map> circuit_to_detecting_regions( + const Circuit &circuit, + std::set included_targets, + std::set included_ticks, + bool ignore_anticommutation_errors); + } // namespace stim #include "stim/stabilizers/conversions.inl" From 4040fd80606b56174e12b05e6595ec4c8092a4bf Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Mon, 11 Mar 2024 21:37:40 -0700 Subject: [PATCH 3/6] Add `stim.Circuit.decomposed` (#712) - And allow `MPAD` to take a noise argument --- doc/gates.md | 6 + doc/python_api_reference_vDev.md | 73 ++ doc/stim.pyi | 65 ++ glue/python/src/stim/__init__.pyi | 65 ++ src/stim/circuit/circuit.pybind.cc | 68 ++ src/stim/circuit/circuit_pybind_test.py | 34 +- src/stim/circuit/export_qasm.cc | 58 +- src/stim/circuit/export_qasm.test.cc | 132 +++- src/stim/circuit/export_qasm_pybind_test.py | 3 +- src/stim/circuit/gate_decomposition.cc | 677 ++++++++++++++++-- src/stim/circuit/gate_decomposition.h | 18 +- src/stim/circuit/gate_decomposition.test.cc | 143 +++- src/stim/circuit/gate_target.cc | 6 + src/stim/circuit/gate_target.h | 2 + src/stim/circuit/gate_target.test.cc | 28 + src/stim/cmd/command_help.cc | 42 +- src/stim/cmd/command_help.h | 6 + src/stim/diagram/circuit_timeline_helper.cc | 12 +- src/stim/gates/gate_data_annotations.cc | 10 +- src/stim/gates/gates.test.cc | 68 +- src/stim/simulators/error_analyzer.cc | 67 +- src/stim/simulators/frame_simulator.inl | 18 +- .../simulators/sparse_rev_frame_tracker.cc | 32 +- src/stim/simulators/tableau_simulator.inl | 14 +- src/stim/simulators/tableau_simulator.test.cc | 2 +- 25 files changed, 1351 insertions(+), 298 deletions(-) diff --git a/doc/gates.md b/doc/gates.md index de61175c3..486a4d0d8 100644 --- a/doc/gates.md +++ b/doc/gates.md @@ -3143,6 +3143,12 @@ This can be useful for ensuring measurements are aligned to word boundaries, or number of measurement bits produced per circuit layer is always the same even if the number of measured qubits varies. +Parens Arguments: + + If no parens argument is given, the padding bits are recorded perfectly. + If one parens argument is given, the padding bits are recorded noisily. + The argument is the probability of recording the wrong result. + Targets: Each target is a measurement result to add. diff --git a/doc/python_api_reference_vDev.md b/doc/python_api_reference_vDev.md index d6a474d01..46a03c877 100644 --- a/doc/python_api_reference_vDev.md +++ b/doc/python_api_reference_vDev.md @@ -27,6 +27,7 @@ API references for stable versions are kept on the [stim github wiki](https://gi - [`stim.Circuit.compile_sampler`](#stim.Circuit.compile_sampler) - [`stim.Circuit.copy`](#stim.Circuit.copy) - [`stim.Circuit.count_determined_measurements`](#stim.Circuit.count_determined_measurements) + - [`stim.Circuit.decomposed`](#stim.Circuit.decomposed) - [`stim.Circuit.detecting_regions`](#stim.Circuit.detecting_regions) - [`stim.Circuit.detector_error_model`](#stim.Circuit.detector_error_model) - [`stim.Circuit.diagram`](#stim.Circuit.diagram) @@ -1273,6 +1274,78 @@ def count_determined_measurements( """ ``` + +```python +# stim.Circuit.decomposed + +# (in class stim.Circuit) +def decomposed( + self, +) -> stim.Circuit: + """Recreates the circuit using (mostly) the {H,S,CX,M,R} gate set. + + The intent of this method is to simplify the circuit to use fewer gate types, + so it's easier for other tools to consume. Currently, this method performs the + following simplifications: + + - Single qubit cliffords are decomposed into {H,S}. + - Multi-qubit cliffords are decomposed into {H,S,CX}. + - Single qubit dissipative gates are decomposed into {H,S,M,R}. + - Multi-qubit dissipative gates are decomposed into {H,S,CX,M,R}. + + Currently, the following types of gate *aren't* simplified, but they may be + in the future: + + - Noise instructions (like X_ERROR, DEPOLARIZE2, and E). + - Annotations (like TICK, DETECTOR, and SHIFT_COORDS). + - The MPAD instruction. + - Repeat blocks are not flattened. + + Returns: + A `stim.Circuit` whose function is equivalent to the original circuit, + but with most gates decomposed into the {H,S,CX,M,R} gate set. + + Examples: + >>> import stim + + >>> stim.Circuit(''' + ... SWAP 0 1 + ... ''').decomposed() + stim.Circuit(''' + CX 0 1 1 0 0 1 + ''') + + >>> stim.Circuit(''' + ... ISWAP 0 1 2 1 + ... TICK + ... MPP !X1*Y2*Z3 + ... ''').decomposed() + stim.Circuit(''' + H 0 + CX 0 1 1 0 + H 1 + S 1 0 + H 2 + CX 2 1 1 2 + H 1 + S 1 2 + TICK + H 1 2 + S 2 + H 2 + S 2 2 + CX 2 1 3 1 + M !1 + CX 2 1 3 1 + H 2 + S 2 + H 2 + S 2 2 + H 1 + ''') + """ +``` + ```python # stim.Circuit.detecting_regions diff --git a/doc/stim.pyi b/doc/stim.pyi index fffa1d476..ce6e0c2fe 100644 --- a/doc/stim.pyi +++ b/doc/stim.pyi @@ -716,6 +716,71 @@ class Circuit: >>> circuit.num_detectors + circuit.num_observables 217 """ + def decomposed( + self, + ) -> stim.Circuit: + """Recreates the circuit using (mostly) the {H,S,CX,M,R} gate set. + + The intent of this method is to simplify the circuit to use fewer gate types, + so it's easier for other tools to consume. Currently, this method performs the + following simplifications: + + - Single qubit cliffords are decomposed into {H,S}. + - Multi-qubit cliffords are decomposed into {H,S,CX}. + - Single qubit dissipative gates are decomposed into {H,S,M,R}. + - Multi-qubit dissipative gates are decomposed into {H,S,CX,M,R}. + + Currently, the following types of gate *aren't* simplified, but they may be + in the future: + + - Noise instructions (like X_ERROR, DEPOLARIZE2, and E). + - Annotations (like TICK, DETECTOR, and SHIFT_COORDS). + - The MPAD instruction. + - Repeat blocks are not flattened. + + Returns: + A `stim.Circuit` whose function is equivalent to the original circuit, + but with most gates decomposed into the {H,S,CX,M,R} gate set. + + Examples: + >>> import stim + + >>> stim.Circuit(''' + ... SWAP 0 1 + ... ''').decomposed() + stim.Circuit(''' + CX 0 1 1 0 0 1 + ''') + + >>> stim.Circuit(''' + ... ISWAP 0 1 2 1 + ... TICK + ... MPP !X1*Y2*Z3 + ... ''').decomposed() + stim.Circuit(''' + H 0 + CX 0 1 1 0 + H 1 + S 1 0 + H 2 + CX 2 1 1 2 + H 1 + S 1 2 + TICK + H 1 2 + S 2 + H 2 + S 2 2 + CX 2 1 3 1 + M !1 + CX 2 1 3 1 + H 2 + S 2 + H 2 + S 2 2 + H 1 + ''') + """ def detecting_regions( self, *, diff --git a/glue/python/src/stim/__init__.pyi b/glue/python/src/stim/__init__.pyi index fffa1d476..ce6e0c2fe 100644 --- a/glue/python/src/stim/__init__.pyi +++ b/glue/python/src/stim/__init__.pyi @@ -716,6 +716,71 @@ class Circuit: >>> circuit.num_detectors + circuit.num_observables 217 """ + def decomposed( + self, + ) -> stim.Circuit: + """Recreates the circuit using (mostly) the {H,S,CX,M,R} gate set. + + The intent of this method is to simplify the circuit to use fewer gate types, + so it's easier for other tools to consume. Currently, this method performs the + following simplifications: + + - Single qubit cliffords are decomposed into {H,S}. + - Multi-qubit cliffords are decomposed into {H,S,CX}. + - Single qubit dissipative gates are decomposed into {H,S,M,R}. + - Multi-qubit dissipative gates are decomposed into {H,S,CX,M,R}. + + Currently, the following types of gate *aren't* simplified, but they may be + in the future: + + - Noise instructions (like X_ERROR, DEPOLARIZE2, and E). + - Annotations (like TICK, DETECTOR, and SHIFT_COORDS). + - The MPAD instruction. + - Repeat blocks are not flattened. + + Returns: + A `stim.Circuit` whose function is equivalent to the original circuit, + but with most gates decomposed into the {H,S,CX,M,R} gate set. + + Examples: + >>> import stim + + >>> stim.Circuit(''' + ... SWAP 0 1 + ... ''').decomposed() + stim.Circuit(''' + CX 0 1 1 0 0 1 + ''') + + >>> stim.Circuit(''' + ... ISWAP 0 1 2 1 + ... TICK + ... MPP !X1*Y2*Z3 + ... ''').decomposed() + stim.Circuit(''' + H 0 + CX 0 1 1 0 + H 1 + S 1 0 + H 2 + CX 2 1 1 2 + H 1 + S 1 2 + TICK + H 1 2 + S 2 + H 2 + S 2 2 + CX 2 1 3 1 + M !1 + CX 2 1 3 1 + H 2 + S 2 + H 2 + S 2 2 + H 1 + ''') + """ def detecting_regions( self, *, diff --git a/src/stim/circuit/circuit.pybind.cc b/src/stim/circuit/circuit.pybind.cc index 7eb911298..7864c0102 100644 --- a/src/stim/circuit/circuit.pybind.cc +++ b/src/stim/circuit/circuit.pybind.cc @@ -2345,6 +2345,74 @@ void stim_pybind::pybind_circuit_methods(pybind11::module &, pybind11::class_>> import stim + + >>> stim.Circuit(''' + ... SWAP 0 1 + ... ''').decomposed() + stim.Circuit(''' + CX 0 1 1 0 0 1 + ''') + + >>> stim.Circuit(''' + ... ISWAP 0 1 2 1 + ... TICK + ... MPP !X1*Y2*Z3 + ... ''').decomposed() + stim.Circuit(''' + H 0 + CX 0 1 1 0 + H 1 + S 1 0 + H 2 + CX 2 1 1 2 + H 1 + S 1 2 + TICK + H 1 2 + S 2 + H 2 + S 2 2 + CX 2 1 3 1 + M !1 + CX 2 1 3 1 + H 2 + S 2 + H 2 + S 2 2 + H 1 + ''') + )DOC") + .data()); + c.def( "with_inlined_feedback", &circuit_with_inlined_feedback, diff --git a/src/stim/circuit/circuit_pybind_test.py b/src/stim/circuit/circuit_pybind_test.py index 82dd5f86b..e6dbc3cec 100644 --- a/src/stim/circuit/circuit_pybind_test.py +++ b/src/stim/circuit/circuit_pybind_test.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import pathlib -import re import tempfile from typing import cast @@ -1695,10 +1695,40 @@ def test_has_flow_shorthands(): assert c.has_flow(stim.Flow("iX_ -> iXX xor rec[1] xor rec[3]")) assert not c.has_flow(stim.Flow("-iX_ -> iXX xor rec[1] xor rec[3]")) assert c.has_flow(stim.Flow("-iX_ -> -iXX xor rec[1] xor rec[3]")) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Anti-Hermitian"): stim.Flow("iX_ -> XX") +def test_decomposed(): + assert stim.Circuit(""" + ISWAP 0 1 2 1 + TICK + MPP X1*Z2*Y3 + """).decomposed() == stim.Circuit(""" + H 0 + CX 0 1 1 0 + H 1 + S 1 0 + H 2 + CX 2 1 1 2 + H 1 + S 1 2 + TICK + H 1 3 + S 3 + H 3 + S 3 3 + CX 2 1 3 1 + M 1 + CX 2 1 3 1 + H 3 + S 3 + H 3 + S 3 3 + H 1 + """) + + def test_detecting_regions(): assert stim.Circuit(''' R 0 diff --git a/src/stim/circuit/export_qasm.cc b/src/stim/circuit/export_qasm.cc index d18a2a19d..551ed3dff 100644 --- a/src/stim/circuit/export_qasm.cc +++ b/src/stim/circuit/export_qasm.cc @@ -131,45 +131,11 @@ struct QasmExporter { } void output_decomposed_mpp_operation(const CircuitInstruction &inst) { - decompose_mpp_operation( - inst, - stats.num_qubits, - [&](const CircuitInstruction &h_xz, - const CircuitInstruction &h_yz, - const CircuitInstruction &cnot, - const CircuitInstruction &meas) { - for (const auto t : h_xz.targets) { - out << "h q[" << t.qubit_value() << "];"; - } - for (const auto t : h_yz.targets) { - out << "sx q[" << t.qubit_value() << "];"; - } - for (size_t k = 0; k < cnot.targets.size(); k += 2) { - auto t1 = cnot.targets[k]; - auto t2 = cnot.targets[k + 1]; - out << "cx q[" << t1.qubit_value() << "],q[" << t2.qubit_value() << "];"; - } - for (auto t : meas.targets) { - buf_q1.str(""); - buf_m.str(""); - buf_q1 << "q[" << t.qubit_value() << "]"; - buf_m << "rec[" << measurement_offset << "]"; - output_measurement(t.is_inverted_result_target(), buf_q1.str().c_str(), buf_m.str().c_str()); - measurement_offset++; - } - for (size_t k = 0; k < cnot.targets.size(); k += 2) { - auto t1 = cnot.targets[k]; - auto t2 = cnot.targets[k + 1]; - out << "cx q[" << t1.qubit_value() << "],q[" << t2.qubit_value() << "];"; - } - for (const auto t : h_yz.targets) { - out << "sxdg q[" << t.qubit_value() << "];"; - } - for (const auto t : h_xz.targets) { - out << "h q[" << t.qubit_value() << "];"; - } - out << " // decomposed MPP\n"; - }); + out << "// --- begin decomposed " << inst << "\n"; + decompose_mpp_operation(inst, stats.num_qubits, [&](const CircuitInstruction &inst) { + output_instruction(inst); + }); + out << "// --- end decomposed MPP\n"; } void output_decomposable_instruction(const CircuitInstruction &instruction, bool decompose_inline) { @@ -460,7 +426,19 @@ struct QasmExporter { return; case GateType::MPAD: - measurement_offset += instruction.count_measurement_results(); + for (const auto &t : instruction.targets) { + if (open_qasm_version == 3) { + out << "rec[" << measurement_offset << "] = " << t.qubit_value() << ";\n"; + } else { + if (t.qubit_value()) { + throw std::invalid_argument( + "The circuit contains a vacuous measurement with a non-zero result " + "(like MPAD 1 or MPP !X1*X1) but OPENQASM 2 doesn't support classical assignment.\n" + "Pass the argument `open_qasm_version=3` to fix this."); + } + } + measurement_offset++; + } return; case GateType::TICK: diff --git a/src/stim/circuit/export_qasm.test.cc b/src/stim/circuit/export_qasm.test.cc index d0f67e86e..dad655398 100644 --- a/src/stim/circuit/export_qasm.test.cc +++ b/src/stim/circuit/export_qasm.test.cc @@ -64,7 +64,13 @@ creg rec[4]; measure q[0] -> rec[0];rec[0] = rec[0] ^ 1; rec[1] = mx(q[0]) ^ 1; rec[2] = mxx(q[0], q[1]) ^ 1; -h q[0];cx q[1],q[0];measure q[0] -> rec[3];rec[3] = rec[3] ^ 1;cx q[1],q[0];h q[0]; // decomposed MPP +// --- begin decomposed MPP !X0*Z1 +h q[0]; +cx q[1], q[0]; +measure q[0] -> rec[3];rec[3] = rec[3] ^ 1; +cx q[1], q[0]; +h q[0]; +// --- end decomposed MPP )QASM"); } @@ -178,6 +184,49 @@ measure q[2] -> rec[3]; )QASM"); } +TEST(export_qasm, export_open_qasm_mpad) { + Circuit c(R"CIRCUIT( + H 0 + MPAD 0 1 0 + M 0 + )CIRCUIT"); + + std::stringstream out; + export_open_qasm(c, out, 3, false); + ASSERT_EQ(out.str(), R"QASM(OPENQASM 3.0; +include "stdgates.inc"; + +qreg q[2]; +creg rec[4]; + +h q[0]; +rec[0] = 0; +rec[1] = 1; +rec[2] = 0; +measure q[0] -> rec[3]; +)QASM"); + + out.str(""); + ASSERT_THROW({ export_open_qasm(c, out, 2, true); }, std::invalid_argument); + c = Circuit(R"CIRCUIT( + H 0 + MPAD 0 0 0 + M 0 + )CIRCUIT"); + + out.str(""); + export_open_qasm(c, out, 2, true); + ASSERT_EQ(out.str(), R"QASM(OPENQASM 2.0; +include "qelib1.inc"; + +qreg q[1]; +creg rec[4]; + +h q[0]; +measure q[0] -> rec[3]; +)QASM"); +} + TEST(export_qasm, export_qasm_decomposed_operations) { Circuit c(R"CIRCUIT( R 3 @@ -246,7 +295,7 @@ h q[4]; measure q[4] -> rec[4]; reset q[4]; h q[4]; // decomposed MRX )QASM"); } -TEST(export_qasm, export_qasm_all_operations) { +TEST(export_qasm, export_qasm_all_operations_v3) { Circuit c = generate_test_circuit_with_all_operations(); c = c.without_noise(); @@ -336,10 +385,24 @@ cy q[14], q[15]; cz q[16], q[17]; barrier q; +rec[0] = 0; +rec[1] = 0; barrier q; -h q[0];sx q[1];cx q[1],q[0];cx q[2],q[0];measure q[0] -> rec[2];cx q[1],q[0];cx q[2],q[0];sxdg q[1];h q[0]; // decomposed MPP -cx q[1],q[0];measure q[0] -> rec[3];cx q[1],q[0]; // decomposed MPP +// --- begin decomposed MPP X0*Y1*Z2 Z0*Z1 +h q[0]; +hyz q[1]; +cx q[1], q[0]; +cx q[2], q[0]; +measure q[0] -> rec[2]; +cx q[1], q[0]; +cx q[2], q[0]; +hyz q[1]; +h q[0]; +cx q[1], q[0]; +measure q[0] -> rec[3]; +cx q[1], q[0]; +// --- end decomposed MPP rec[4] = mrx(q[0]); rec[5] = mry(q[1]); rec[6] = mr(q[2]); @@ -379,13 +442,26 @@ rec[15] = mr(q[0]); rec[16] = mr(q[0]); dets[0] = rec[16] ^ 0; obs[0] = obs[0] ^ rec[16] ^ 0; +rec[17] = 0; +rec[18] = 1; +rec[19] = 0; barrier q; rec[20] = mrx(q[0]) ^ 1; rec[21] = my(q[1]) ^ 1; rec[22] = mzz(q[2], q[3]) ^ 1; rec[23] = myy(q[4], q[5]); -h q[6];sx q[7];cx q[7],q[6];cx q[8],q[6];measure q[6] -> rec[24];rec[24] = rec[24] ^ 1;cx q[7],q[6];cx q[8],q[6];sxdg q[7];h q[6]; // decomposed MPP +// --- begin decomposed MPP X6*!Y7*Z8 +h q[6]; +hyz q[7]; +cx q[7], q[6]; +cx q[8], q[6]; +measure q[6] -> rec[24];rec[24] = rec[24] ^ 1; +cx q[7], q[6]; +cx q[8], q[6]; +hyz q[7]; +h q[6]; +// --- end decomposed MPP barrier q; if (ms[24]) { @@ -398,13 +474,19 @@ if (ms[24]) { Z q[2]; } )QASM"); +} - out.str(""); +TEST(export_qasm, export_qasm_all_operations_v2) { + Circuit c = generate_test_circuit_with_all_operations(); + c = c.without_noise(); + + std::stringstream out; c = circuit_with_inlined_feedback(c); for (size_t k = 0; k < c.operations.size(); k++) { bool drop = false; for (auto t : c.operations[k].targets) { drop |= t.is_sweep_bit_target(); + drop |= c.operations[k].gate_type == GateType::MPAD && t.qubit_value() > 0; } if (drop) { c.operations.erase(c.operations.begin() + k); @@ -439,7 +521,7 @@ gate ycy q0, q1 { s q0; s q0; s q0; s q1; s q1; s q1; h q0; cx q0, q1; h q0; s q gate ycz q0, q1 { s q0; s q0; s q0; cx q1, q0; s q0; } qreg q[18]; -creg rec[25]; +creg rec[22]; id q[0]; x q[1]; @@ -485,8 +567,20 @@ barrier q; barrier q; -h q[0];sx q[1];cx q[1],q[0];cx q[2],q[0];measure q[0] -> rec[2];cx q[1],q[0];cx q[2],q[0];sxdg q[1];h q[0]; // decomposed MPP -cx q[1],q[0];measure q[0] -> rec[3];cx q[1],q[0]; // decomposed MPP +// --- begin decomposed MPP X0*Y1*Z2 Z0*Z1 +h q[0]; +hyz q[1]; +cx q[1], q[0]; +cx q[2], q[0]; +measure q[0] -> rec[2]; +cx q[1], q[0]; +cx q[2], q[0]; +hyz q[1]; +h q[0]; +cx q[1], q[0]; +measure q[0] -> rec[3]; +cx q[1], q[0]; +// --- end decomposed MPP h q[0]; measure q[0] -> rec[4]; reset q[0]; h q[0]; // decomposed MRX s q[1]; s q[1]; s q[1]; h q[1]; measure q[1] -> rec[5]; reset q[1]; h q[1]; s q[1]; // decomposed MRY measure q[2] -> rec[6]; reset q[2]; // decomposed MR @@ -526,11 +620,21 @@ measure q[0] -> rec[15]; reset q[0]; // decomposed MR measure q[0] -> rec[16]; reset q[0]; // decomposed MR barrier q; -h q[0]; x q[0];measure q[0] -> rec[20];x q[0]; reset q[0]; h q[0]; // decomposed MRX -s q[1]; s q[1]; s q[1]; h q[1]; x q[1];measure q[1] -> rec[21];x q[1]; h q[1]; s q[1]; // decomposed MY -cx q[2], q[3]; x q[3];measure q[3] -> rec[22];x q[3]; cx q[2], q[3]; // decomposed MZZ -s q[4]; s q[5]; cx q[4], q[5]; h q[4]; measure q[4] -> rec[23]; s q[5]; s q[5]; h q[4]; cx q[4], q[5]; s q[4]; s q[5]; // decomposed MYY -h q[6];sx q[7];cx q[7],q[6];cx q[8],q[6];x q[6];measure q[6] -> rec[24];x q[6];cx q[7],q[6];cx q[8],q[6];sxdg q[7];h q[6]; // decomposed MPP +h q[0]; x q[0];measure q[0] -> rec[17];x q[0]; reset q[0]; h q[0]; // decomposed MRX +s q[1]; s q[1]; s q[1]; h q[1]; x q[1];measure q[1] -> rec[18];x q[1]; h q[1]; s q[1]; // decomposed MY +cx q[2], q[3]; x q[3];measure q[3] -> rec[19];x q[3]; cx q[2], q[3]; // decomposed MZZ +s q[4]; s q[5]; cx q[4], q[5]; h q[4]; measure q[4] -> rec[20]; s q[5]; s q[5]; h q[4]; cx q[4], q[5]; s q[4]; s q[5]; // decomposed MYY +// --- begin decomposed MPP X6*!Y7*Z8 +h q[6]; +hyz q[7]; +cx q[7], q[6]; +cx q[8], q[6]; +x q[6];measure q[6] -> rec[21];x q[6]; +cx q[7], q[6]; +cx q[8], q[6]; +hyz q[7]; +h q[6]; +// --- end decomposed MPP barrier q; )QASM"); diff --git a/src/stim/circuit/export_qasm_pybind_test.py b/src/stim/circuit/export_qasm_pybind_test.py index f25308445..5bd0d9d2b 100644 --- a/src/stim/circuit/export_qasm_pybind_test.py +++ b/src/stim/circuit/export_qasm_pybind_test.py @@ -69,12 +69,13 @@ def test_to_qasm2_runs_in_qiskit(): stim_circuit = stim.Circuit(""" R 0 1 MZZ !0 1 + MPAD 0 0 """) qasm = stim_circuit.to_qasm(open_qasm_version=2) qiskit_circuit = qiskit.QuantumCircuit.from_qasm_str(qasm) counts = qiskit_aer.AerSimulator().run(qiskit_circuit, shots=8).result().get_counts(qiskit_circuit) - assert counts['1'] == 8 + assert counts['001'] == 8 def test_to_qasm3_parses_in_qiskit(): diff --git a/src/stim/circuit/gate_decomposition.cc b/src/stim/circuit/gate_decomposition.cc index bf7d12a4a..e65fa6971 100644 --- a/src/stim/circuit/gate_decomposition.cc +++ b/src/stim/circuit/gate_decomposition.cc @@ -16,87 +16,176 @@ #include "stim/circuit/gate_decomposition.h" +#include + +#include "stim/stabilizers/pauli_string.h" + using namespace stim; +struct ConjugateBySelfInverse { + CircuitInstruction inst; + const std::function &do_instruction_callback; + ConjugateBySelfInverse( + CircuitInstruction inst, const std::function &do_instruction_callback) + : inst(inst), do_instruction_callback(do_instruction_callback) { + if (!inst.targets.empty()) { + do_instruction_callback(inst); + } + } + ~ConjugateBySelfInverse() { + if (!inst.targets.empty()) { + do_instruction_callback(inst); + } + } +}; + +template +static void for_each_active_qubit_in(PauliStringRef<64> obs, CALLBACK callback) { + size_t n = obs.xs.num_u64_padded(); + for (size_t w = 0; w < n; w++) { + uint64_t v = 0; + if (use_x) { + v |= obs.xs.u64[w]; + } + if (use_z) { + v |= obs.zs.u64[w]; + } + while (v) { + size_t j = std::countr_zero(v); + v &= ~(uint64_t{1} << j); + bool b = false; + uint32_t q = (uint32_t)(w * 64 + j); + if (use_x) { + b |= obs.xs[q]; + } + if (use_z) { + b |= obs.zs[q]; + } + if (b) { + callback(w * 64 + j); + } + } + } +} + +bool stim::accumulate_next_obs_terms_to_pauli_string_helper( + CircuitInstruction instruction, + size_t *start, + PauliString<64> *obs, + std::vector *bits, + bool allow_imaginary) { + if (*start >= instruction.targets.size()) { + return false; + } + + if (bits != nullptr) { + bits->clear(); + } + obs->xs.clear(); + obs->zs.clear(); + obs->sign = false; + bool imag = false; + + // Find end of current product. + size_t end = *start + 1; + while (end < instruction.targets.size() && instruction.targets[end].is_combiner()) { + end += 2; + } + + // Accumulate terms. + for (size_t k = *start; k < end; k += 2) { + GateTarget t = instruction.targets[k]; + + if (t.is_pauli_target()) { + obs->left_mul_pauli(t, &imag); + } else if (t.is_classical_bit_target() && bits != nullptr) { + bits->push_back(t); + } else { + throw std::invalid_argument("Found an unsupported target `" + t.str() + "` in " + instruction.str()); + } + } + + if (imag && !allow_imaginary) { + throw std::invalid_argument( + "Acted on an anti-Hermitian operator (e.g. X0*Z0 instead of Y0) in " + instruction.str()); + } + + *start = end; + return true; +} + void stim::decompose_mpp_operation( const CircuitInstruction &mpp_op, size_t num_qubits, - const std::function &callback) { - simd_bits<64> used(num_qubits); - simd_bits<64> inner_used(num_qubits); + const std::function &do_instruction_callback) { + PauliString<64> current(num_qubits); + simd_bits<64> merged(num_qubits); std::vector h_xz; std::vector h_yz; std::vector cnot; std::vector meas; + auto flush = [&]() { + if (meas.empty()) { + return; + } + { + ConjugateBySelfInverse c1(CircuitInstruction{GateType::H, {}, h_xz}, do_instruction_callback); + ConjugateBySelfInverse c2(CircuitInstruction{GateType::H_YZ, {}, h_yz}, do_instruction_callback); + ConjugateBySelfInverse c3(CircuitInstruction{GateType::CX, {}, cnot}, do_instruction_callback); + do_instruction_callback(CircuitInstruction{GateType::M, mpp_op.args, meas}); + } + h_xz.clear(); + h_yz.clear(); + cnot.clear(); + meas.clear(); + merged.clear(); + }; + size_t start = 0; - while (start < mpp_op.targets.size()) { - size_t end = start + 1; - while (end < mpp_op.targets.size() && mpp_op.targets[end].is_combiner()) { - end += 2; - } - - // Determine which qubits are being touched by the next group. - inner_used.clear(); - for (size_t i = start; i < end; i += 2) { - auto t = mpp_op.targets[i]; - if (inner_used[t.qubit_value()]) { - throw std::invalid_argument( - "A pauli product specified the same qubit twice.\n" - "The operation: " + - mpp_op.str()); - } - inner_used[t.qubit_value()] = true; - } - - // If there's overlap with previous groups, the previous groups have to be flushed first. - if (inner_used.intersects(used)) { - callback( - CircuitInstruction{GateType::H, {}, h_xz}, - CircuitInstruction{GateType::H_YZ, {}, h_yz}, - CircuitInstruction{GateType::CX, {}, cnot}, - CircuitInstruction{GateType::M, mpp_op.args, meas}); - h_xz.clear(); - h_yz.clear(); - cnot.clear(); - meas.clear(); - used.clear(); - } - used |= inner_used; - - // Append operations that are equivalent to the desired measurement. - for (size_t i = start; i < end; i += 2) { - auto t = mpp_op.targets[i]; - auto q = t.qubit_value(); - if (t.data & TARGET_PAULI_X_BIT) { - if (t.data & TARGET_PAULI_Z_BIT) { + while (accumulate_next_obs_terms_to_pauli_string_helper(mpp_op, &start, ¤t, nullptr)) { + // Products equal to +-I become MPAD instructions. + if (current.ref().has_no_pauli_terms()) { + flush(); + GateTarget t = GateTarget::qubit((uint32_t)current.sign); + do_instruction_callback(CircuitInstruction{GateType::MPAD, mpp_op.args, &t}); + continue; + } + + // If there's overlap with previous groups, the previous groups need to be flushed. + if (current.xs.intersects(merged) || current.zs.intersects(merged)) { + flush(); + } + merged |= current.xs; + merged |= current.zs; + + // Buffer operations to perform the desired measurement. + bool first = true; + for_each_active_qubit_in(current, [&](uint32_t q) { + bool x = current.xs[q]; + bool z = current.zs[q]; + // Include single qubit gates transforming the Pauli into a Z. + if (x) { + if (z) { h_yz.push_back({q}); } else { h_xz.push_back({q}); } } - if (i == start) { - meas.push_back({q}); + // Include CNOT gates folding onto a single measured qubit. + if (first) { + meas.push_back(GateTarget::qubit(q, current.sign)); + first = false; } else { cnot.push_back({q}); cnot.push_back({meas.back().qubit_value()}); } - meas.back().data ^= t.data & TARGET_INVERTED_BIT; - } - - start = end; + }); + assert(!first); } // Flush remaining groups. - callback( - CircuitInstruction{GateType::H, {}, h_xz}, - CircuitInstruction{GateType::H_YZ, {}, h_yz}, - CircuitInstruction{GateType::CX, {}, cnot}, - CircuitInstruction{GateType::M, mpp_op.args, meas}); + flush(); } void stim::decompose_pair_instruction_into_segments_with_single_use_controls( @@ -125,3 +214,475 @@ void stim::decompose_pair_instruction_into_segments_with_single_use_controls( k += 2; } } + +struct Simplifier { + size_t num_qubits; + std::function yield; + simd_bits<64> used; + std::vector qs1_buf; + std::vector qs2_buf; + std::vector qs_buf; + + Simplifier(size_t num_qubits, std::function init_yield) + : num_qubits(num_qubits), yield(init_yield), used(num_qubits) { + } + + void do_xcz(SpanRef targets) { + if (targets.empty()) { + return; + } + + qs_buf.clear(); + for (size_t k = 0; k < targets.size(); k += 2) { + qs_buf.push_back(targets[k + 1]); + qs_buf.push_back(targets[k]); + } + yield(CircuitInstruction{GateType::CX, {}, qs_buf}); + } + + void simplify_potentially_overlapping_1q_instruction(const CircuitInstruction &inst) { + used.clear(); + + size_t start = 0; + for (size_t k = 0; k < inst.targets.size(); k++) { + auto t = inst.targets[k]; + if (t.has_qubit_value() && used[t.qubit_value()]) { + CircuitInstruction disjoint = CircuitInstruction{inst.gate_type, inst.args, inst.targets.sub(start, k)}; + simplify_disjoint_1q_instruction(disjoint); + used.clear(); + start = k; + } + if (t.has_qubit_value()) { + used[t.qubit_value()] = true; + } + } + simplify_disjoint_1q_instruction( + CircuitInstruction{inst.gate_type, inst.args, inst.targets.sub(start, inst.targets.size())}); + } + + void simplify_potentially_overlapping_2q_instruction(const CircuitInstruction &inst) { + used.clear(); + + size_t start = 0; + for (size_t k = 0; k < inst.targets.size(); k += 2) { + auto a = inst.targets[k]; + auto b = inst.targets[k + 1]; + if ((a.has_qubit_value() && used[a.qubit_value()]) || (b.has_qubit_value() && used[b.qubit_value()])) { + CircuitInstruction disjoint = CircuitInstruction{inst.gate_type, inst.args, inst.targets.sub(start, k)}; + simplify_disjoint_2q_instruction(disjoint); + used.clear(); + start = k; + } + if (a.has_qubit_value()) { + used[a.qubit_value()] = true; + } + if (b.has_qubit_value()) { + used[b.qubit_value()] = true; + } + } + simplify_disjoint_2q_instruction( + CircuitInstruction{inst.gate_type, inst.args, inst.targets.sub(start, inst.targets.size())}); + } + + void simplify_disjoint_1q_instruction(const CircuitInstruction &inst) { + const auto &ts = inst.targets; + + switch (inst.gate_type) { + case GateType::I: + // Do nothing. + break; + case GateType::X: + yield({GateType::H, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::H, {}, ts}); + break; + case GateType::Y: + yield({GateType::H, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::H, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + break; + case GateType::Z: + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + break; + case GateType::C_XYZ: + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::H, {}, ts}); + break; + case GateType::C_ZYX: + yield({GateType::H, {}, ts}); + yield({GateType::S, {}, ts}); + break; + case GateType::H: + yield({GateType::H, {}, ts}); + break; + case GateType::H_XY: + yield({GateType::H, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::H, {}, ts}); + yield({GateType::S, {}, ts}); + break; + case GateType::H_YZ: + yield({GateType::H, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::H, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + break; + case GateType::S: + yield({GateType::S, {}, ts}); + break; + case GateType::SQRT_X: + yield({GateType::H, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::H, {}, ts}); + break; + case GateType::SQRT_X_DAG: + yield({GateType::H, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::H, {}, ts}); + break; + case GateType::SQRT_Y: + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::H, {}, ts}); + break; + case GateType::SQRT_Y_DAG: + yield({GateType::H, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + break; + case GateType::S_DAG: + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + break; + + case GateType::MX: + yield({GateType::H, {}, ts}); + yield({GateType::M, {}, ts}); + yield({GateType::H, {}, ts}); + break; + case GateType::MY: + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::H, {}, ts}); + yield({GateType::M, {}, ts}); + yield({GateType::H, {}, ts}); + yield({GateType::S, {}, ts}); + break; + case GateType::M: + yield({GateType::M, {}, ts}); + break; + case GateType::MRX: + yield({GateType::H, {}, ts}); + yield({GateType::M, {}, ts}); + yield({GateType::R, {}, ts}); + yield({GateType::H, {}, ts}); + break; + case GateType::MRY: + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::S, {}, ts}); + yield({GateType::H, {}, ts}); + yield({GateType::M, {}, ts}); + yield({GateType::R, {}, ts}); + yield({GateType::H, {}, ts}); + yield({GateType::S, {}, ts}); + break; + case GateType::MR: + yield({GateType::M, {}, ts}); + yield({GateType::R, {}, ts}); + break; + case GateType::RX: + yield({GateType::R, {}, ts}); + yield({GateType::H, {}, ts}); + break; + case GateType::RY: + yield({GateType::R, {}, ts}); + yield({GateType::H, {}, ts}); + yield({GateType::S, {}, ts}); + break; + case GateType::R: + yield({GateType::R, {}, ts}); + break; + + default: + throw std::invalid_argument("Unhandled in Simplifier::simplify_disjoint_1q_instruction: " + inst.str()); + } + } + + void simplify_disjoint_2q_instruction(const CircuitInstruction &inst) { + const auto &ts = inst.targets; + qs_buf.clear(); + qs1_buf.clear(); + qs2_buf.clear(); + for (size_t k = 0; k < inst.targets.size(); k += 2) { + auto a = inst.targets[k]; + auto b = inst.targets[k + 1]; + if (a.has_qubit_value()) { + auto t = GateTarget::qubit(a.qubit_value()); + qs1_buf.push_back(t); + qs_buf.push_back(t); + } + if (b.has_qubit_value()) { + auto t = GateTarget::qubit(b.qubit_value()); + qs2_buf.push_back(t); + qs_buf.push_back(t); + } + } + + switch (inst.gate_type) { + case GateType::CX: + yield({GateType::CX, {}, ts}); + break; + case GateType::XCZ: + do_xcz(ts); + break; + case GateType::XCX: + yield({GateType::H, {}, qs1_buf}); + yield({GateType::CX, {}, ts}); + yield({GateType::H, {}, qs1_buf}); + break; + case GateType::XCY: + yield({GateType::S, {}, qs2_buf}); + yield({GateType::S, {}, qs2_buf}); + yield({GateType::S, {}, qs2_buf}); + yield({GateType::H, {}, qs1_buf}); + yield({GateType::CX, {}, ts}); + yield({GateType::H, {}, qs1_buf}); + yield({GateType::S, {}, qs2_buf}); + break; + case GateType::YCX: + yield({GateType::S, {}, qs1_buf}); + yield({GateType::S, {}, qs1_buf}); + yield({GateType::S, {}, qs1_buf}); + yield({GateType::H, {}, qs1_buf}); + yield({GateType::CX, {}, ts}); + yield({GateType::H, {}, qs1_buf}); + yield({GateType::S, {}, qs1_buf}); + break; + case GateType::YCY: + yield({GateType::S, {}, qs_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::H, {}, qs1_buf}); + yield({GateType::CX, {}, ts}); + yield({GateType::H, {}, qs1_buf}); + yield({GateType::S, {}, qs_buf}); + break; + case GateType::YCZ: + yield({GateType::S, {}, qs1_buf}); + yield({GateType::S, {}, qs1_buf}); + yield({GateType::S, {}, qs1_buf}); + do_xcz(ts); + yield({GateType::S, {}, qs1_buf}); + break; + case GateType::CY: + yield({GateType::S, {}, qs2_buf}); + yield({GateType::S, {}, qs2_buf}); + yield({GateType::S, {}, qs2_buf}); + yield({GateType::CX, {}, ts}); + yield({GateType::S, {}, qs2_buf}); + break; + case GateType::CZ: + yield({GateType::H, {}, qs2_buf}); + yield({GateType::CX, {}, ts}); + yield({GateType::H, {}, qs2_buf}); + break; + case GateType::SQRT_XX: + yield({GateType::H, {}, qs1_buf}); + yield({GateType::CX, {}, ts}); + yield({GateType::H, {}, qs2_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::H, {}, qs_buf}); + break; + case GateType::SQRT_XX_DAG: + yield({GateType::H, {}, qs1_buf}); + yield({GateType::CX, {}, ts}); + yield({GateType::H, {}, qs2_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::H, {}, qs_buf}); + break; + case GateType::SQRT_YY: + yield({GateType::S, {}, qs_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::H, {}, qs1_buf}); + yield({GateType::CX, {}, ts}); + yield({GateType::H, {}, qs2_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::H, {}, qs_buf}); + yield({GateType::S, {}, qs_buf}); + break; + case GateType::SQRT_YY_DAG: + yield({GateType::S, {}, qs1_buf}); + yield({GateType::S, {}, qs1_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::H, {}, qs1_buf}); + yield({GateType::CX, {}, ts}); + yield({GateType::H, {}, qs2_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::H, {}, qs_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::S, {}, qs2_buf}); + yield({GateType::S, {}, qs2_buf}); + break; + case GateType::SQRT_ZZ: + yield({GateType::H, {}, qs2_buf}); + yield({GateType::CX, {}, ts}); + yield({GateType::H, {}, qs2_buf}); + yield({GateType::S, {}, qs_buf}); + break; + case GateType::SQRT_ZZ_DAG: + yield({GateType::H, {}, qs2_buf}); + yield({GateType::CX, {}, ts}); + yield({GateType::H, {}, qs2_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::S, {}, qs_buf}); + break; + case GateType::SWAP: + yield({GateType::CX, {}, ts}); + do_xcz(ts); + yield({GateType::CX, {}, ts}); + break; + case GateType::ISWAP: + yield({GateType::H, {}, qs1_buf}); + yield({GateType::CX, {}, ts}); + do_xcz(ts); + yield({GateType::H, {}, qs2_buf}); + yield({GateType::S, {}, qs_buf}); + break; + case GateType::ISWAP_DAG: + yield({GateType::H, {}, qs1_buf}); + yield({GateType::CX, {}, ts}); + do_xcz(ts); + yield({GateType::H, {}, qs2_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::S, {}, qs_buf}); + yield({GateType::S, {}, qs_buf}); + break; + case GateType::CXSWAP: + do_xcz(ts); + yield({GateType::CX, {}, ts}); + break; + case GateType::SWAPCX: + yield({GateType::CX, {}, ts}); + do_xcz(ts); + break; + case GateType::CZSWAP: + yield({GateType::H, {}, qs1_buf}); + yield({GateType::CX, {}, ts}); + do_xcz(ts); + yield({GateType::H, {}, qs2_buf}); + break; + + case GateType::MXX: + yield({GateType::CX, {}, ts}); + yield({GateType::H, {}, qs1_buf}); + yield({GateType::M, {}, qs1_buf}); + yield({GateType::H, {}, qs1_buf}); + yield({GateType::CX, {}, ts}); + break; + case GateType::MYY: + yield({GateType::S, {}, qs_buf}); + yield({GateType::CX, {}, ts}); + yield({GateType::S, {}, qs2_buf}); + yield({GateType::S, {}, qs2_buf}); + yield({GateType::H, {}, qs1_buf}); + yield({GateType::M, {}, qs1_buf}); + yield({GateType::H, {}, qs1_buf}); + yield({GateType::CX, {}, ts}); + yield({GateType::S, {}, qs_buf}); + break; + case GateType::MZZ: + yield({GateType::CX, {}, ts}); + yield({GateType::M, {}, qs2_buf}); + yield({GateType::CX, {}, ts}); + break; + + default: + throw std::invalid_argument("Unhandled in Simplifier::simplify_instruction: " + inst.str()); + } + } + + void simplify_instruction(const CircuitInstruction &inst) { + const Gate &g = GATE_DATA[inst.gate_type]; + + switch (inst.gate_type) { + case GateType::MPP: + decompose_mpp_operation(inst, num_qubits, [&](const CircuitInstruction sub) { + simplify_instruction(sub); + }); + break; + + case GateType::MPAD: + // Can't be easily simplified into M. + yield(inst); + break; + + case GateType::DETECTOR: + case GateType::OBSERVABLE_INCLUDE: + case GateType::TICK: + case GateType::QUBIT_COORDS: + case GateType::SHIFT_COORDS: + // Annotations can't be simplified. + yield(inst); + break; + + case GateType::DEPOLARIZE1: + case GateType::DEPOLARIZE2: + case GateType::X_ERROR: + case GateType::Y_ERROR: + case GateType::Z_ERROR: + case GateType::PAULI_CHANNEL_1: + case GateType::PAULI_CHANNEL_2: + case GateType::E: + case GateType::ELSE_CORRELATED_ERROR: + case GateType::HERALDED_ERASE: + case GateType::HERALDED_PAULI_CHANNEL_1: + // Noise isn't simplified. + yield(inst); + break; + default: { + if (g.flags & GATE_IS_SINGLE_QUBIT_GATE) { + simplify_potentially_overlapping_1q_instruction(inst); + } else if (g.flags & GATE_TARGETS_PAIRS) { + simplify_potentially_overlapping_2q_instruction(inst); + } else { + throw std::invalid_argument( + "Unhandled in simplify_potentially_overlapping_instruction: " + inst.str()); + } + } + } + } +}; + +Circuit stim::simplified_circuit(const Circuit &circuit) { + Circuit output; + Simplifier simplifier(circuit.count_qubits(), [&](const CircuitInstruction &inst) { + output.safe_append(inst); + }); + for (auto inst : circuit.operations) { + if (inst.gate_type == GateType::REPEAT) { + output.append_repeat_block( + inst.repeat_block_rep_count(), simplified_circuit(inst.repeat_block_body(circuit))); + } else { + simplifier.simplify_instruction(inst); + } + } + return output; +} diff --git a/src/stim/circuit/gate_decomposition.h b/src/stim/circuit/gate_decomposition.h index 5cbcd505c..09f0fef5e 100644 --- a/src/stim/circuit/gate_decomposition.h +++ b/src/stim/circuit/gate_decomposition.h @@ -53,15 +53,12 @@ namespace stim { /// Args: /// mpp_op: The operation to decompose. /// num_qubits: The number of qubits in the system. All targets must be less than this. -/// callback: The method told each chunk of the decomposition. +/// callback: How to execute decomposed instructions. void decompose_mpp_operation( const CircuitInstruction &mpp_op, size_t num_qubits, - const std::function &callback); + const std::function &do_instruction_callback); + /// Finds contiguous segments where the first target of each pair is used once. /// @@ -96,6 +93,15 @@ void decompose_mpp_operation( void decompose_pair_instruction_into_segments_with_single_use_controls( const CircuitInstruction &inst, size_t num_qubits, const std::function &callback); +bool accumulate_next_obs_terms_to_pauli_string_helper( + CircuitInstruction instruction, + size_t *start, + PauliString<64> *obs, + std::vector *bits, + bool allow_imaginary = false); + +Circuit simplified_circuit(const Circuit &circuit); + } // namespace stim #endif diff --git a/src/stim/circuit/gate_decomposition.test.cc b/src/stim/circuit/gate_decomposition.test.cc index 73c9beaaa..95b5150fd 100644 --- a/src/stim/circuit/gate_decomposition.test.cc +++ b/src/stim/circuit/gate_decomposition.test.cc @@ -17,52 +17,52 @@ #include "gtest/gtest.h" #include "stim/circuit/circuit.h" +#include "stim/cmd/command_help.h" +#include "stim/simulators/tableau_simulator.h" +#include "stim/test_util.test.h" using namespace stim; TEST(gate_decomposition, decompose_mpp_operation) { Circuit out; - auto append_into_circuit = [&](const CircuitInstruction &h_xz, - const CircuitInstruction &h_yz, - const CircuitInstruction &cnot, - const CircuitInstruction &meas) { - out.safe_append(h_xz); - out.safe_append(h_yz); - out.safe_append(cnot); - out.safe_append(meas); - out.safe_append(cnot); - out.safe_append(h_yz); - out.safe_append(h_xz); + auto append_into_circuit = [&](const CircuitInstruction &inst) { + out.safe_append(inst); out.append_from_text("TICK"); }; decompose_mpp_operation( Circuit("MPP(0.125) X0*X1*X2 Z3*Z4*Z5 X2*Y4 Z3 Z3 Z4*Z5").operations[0], 10, append_into_circuit); ASSERT_EQ(out, Circuit(R"CIRCUIT( H 0 1 2 - H_YZ + TICK CX 1 0 2 0 4 3 5 3 + TICK M(0.125) 0 3 + TICK CX 1 0 2 0 4 3 5 3 - H_YZ + TICK H 0 1 2 TICK H 2 + TICK H_YZ 4 + TICK CX 4 2 + TICK M(0.125) 2 3 + TICK CX 4 2 + TICK H_YZ 4 + TICK H 2 TICK - H - H_YZ CX 5 4 + TICK M(0.125) 3 4 + TICK CX 5 4 - H_YZ - H TICK )CIRCUIT")); @@ -70,25 +70,60 @@ TEST(gate_decomposition, decompose_mpp_operation) { decompose_mpp_operation(Circuit("MPP X0*Z1*Y2 X3*X4 Y0*Y1*Y2*Y3*Y4").operations[0], 10, append_into_circuit); ASSERT_EQ(out, Circuit(R"CIRCUIT( H 0 3 4 + TICK H_YZ 2 + TICK CX 1 0 2 0 4 3 + TICK M 0 3 + TICK CX 1 0 2 0 4 3 + TICK H_YZ 2 + TICK H 0 3 4 TICK - H H_YZ 0 1 2 3 4 + TICK CX 1 0 2 0 3 0 4 0 + TICK M 0 + TICK CX 1 0 2 0 3 0 4 0 + TICK H_YZ 0 1 2 3 4 - H TICK )CIRCUIT")); } +TEST(gate_decomposition, decompose_mpp_to_mpad) { + Circuit out; + auto append_into_circuit = [&](const CircuitInstruction &inst) { + out.safe_append(inst); + out.append_from_text("TICK"); + }; + decompose_mpp_operation( + Circuit(R"CIRCUIT( + MPP(0.125) X0*X0 X0*!X0 X0*Y0*Z0*X1*Y1*Z1 + )CIRCUIT") + .operations[0], + 10, + append_into_circuit); + ASSERT_EQ(out, Circuit(R"CIRCUIT( + MPAD(0.125) 0 + TICK + MPAD(0.125) 1 + TICK + MPAD(0.125) 1 + TICK + )CIRCUIT")); + + ASSERT_THROW( + { decompose_mpp_operation(Circuit("MPP(0.125) X0*Y0*Z0").operations[0], 10, append_into_circuit); }, + std::invalid_argument); +} + TEST(gate_decomposition, decompose_pair_instruction_into_segments_with_single_use_controls) { Circuit out; auto append_into_circuit = [&](const CircuitInstruction &segment) { @@ -120,3 +155,73 @@ TEST(gate_decomposition, decompose_pair_instruction_into_segments_with_single_us TICK )CIRCUIT")); } + +static std::pair>, std::vector>> circuit_output_eq_val( + const Circuit &circuit) { + // CAUTION: this is not 100% reliable when measurement count is larger than 1. + TableauSimulator<64> sim1(INDEPENDENT_TEST_RNG(), circuit.count_qubits(), -1); + TableauSimulator<64> sim2(INDEPENDENT_TEST_RNG(), circuit.count_qubits(), +1); + sim1.safe_do_circuit(circuit); + sim2.safe_do_circuit(circuit); + return {sim1.canonical_stabilizers(), sim2.canonical_stabilizers()}; +} + +bool is_simplification_correct(const Gate &gate) { + std::vector args; + while (args.size() < gate.arg_count && gate.arg_count != ARG_COUNT_SYGIL_ANY && + gate.arg_count != ARG_COUNT_SYGIL_ZERO_OR_ONE) { + args.push_back(args.empty() ? 1 : 0); + } + + Circuit original; + original.safe_append(gate.id, gate_decomposition_help_targets_for_gate_type(gate.id), args); + Circuit simplified = simplified_circuit(original); + + if (gate.h_s_cx_m_r_decomposition == nullptr) { + return simplified == original; + } + + uint32_t n = original.count_qubits(); + + Circuit epr; + for (uint32_t q = 0; q < n; q++) { + epr.safe_append_u("H", {q}); + } + for (uint32_t q = 0; q < n; q++) { + epr.safe_append_u("CNOT", {q, q + n}); + } + + Circuit circuit1 = epr + original; + Circuit circuit2 = epr + simplified; + + // Reset gates make the ancillary qubits irrelevant because the final value is unrelated to the initial value. + // So, for reset gates, discard the ancillary qubits. + // CAUTION: this could give false positives if "partial reset" gates are added in the future. + // (E.g. a two qubit gate that resets only one of the qubits.) + if ((gate.flags & GATE_IS_RESET) && !(gate.flags & GATE_PRODUCES_RESULTS)) { + for (uint32_t q = 0; q < n; q++) { + circuit1.safe_append_u("R", {q + n}); + circuit2.safe_append_u("R", {q + n}); + } + } + + // Verify decomposed all the way to base gate set, if the gate has a decomposition. + for (const auto &op : circuit2.operations) { + if (op.gate_type != GateType::CX && op.gate_type != GateType::H && op.gate_type != GateType::S && + op.gate_type != GateType::M && op.gate_type != GateType::R) { + return false; + } + } + + auto v1 = circuit_output_eq_val(circuit1); + auto v2 = circuit_output_eq_val(circuit2); + return v1 == v2; +} + +TEST(gate_decomposition, simplifications_are_correct) { + for (const auto &g : GATE_DATA.items) { + if (g.id != GateType::NOT_A_GATE && g.id != GateType::REPEAT) { + EXPECT_TRUE(is_simplification_correct(g)) << g.name; + } + } +} diff --git a/src/stim/circuit/gate_target.cc b/src/stim/circuit/gate_target.cc index c2dc17dd4..e5454864c 100644 --- a/src/stim/circuit/gate_target.cc +++ b/src/stim/circuit/gate_target.cc @@ -92,6 +92,9 @@ bool GateTarget::is_inverted_result_target() const { bool GateTarget::is_measurement_record_target() const { return data & TARGET_RECORD_BIT; } +bool GateTarget::is_pauli_target() const { + return data & (TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT); +} bool GateTarget::has_qubit_value() const { return !(data & (TARGET_RECORD_BIT | TARGET_SWEEP_BIT | TARGET_COMBINER)); } @@ -104,6 +107,9 @@ bool GateTarget::is_combiner() const { bool GateTarget::is_sweep_bit_target() const { return data & TARGET_SWEEP_BIT; } +bool GateTarget::is_classical_bit_target() const { + return data & (TARGET_SWEEP_BIT | TARGET_RECORD_BIT); +} bool GateTarget::operator==(const GateTarget &other) const { return data == other.data; } diff --git a/src/stim/circuit/gate_target.h b/src/stim/circuit/gate_target.h index 7a27330fb..fd0fea165 100644 --- a/src/stim/circuit/gate_target.h +++ b/src/stim/circuit/gate_target.h @@ -57,6 +57,8 @@ struct GateTarget { bool is_measurement_record_target() const; bool is_qubit_target() const; bool is_sweep_bit_target() const; + bool is_classical_bit_target() const; + bool is_pauli_target() const; uint32_t qubit_value() const; bool operator==(const GateTarget &other) const; bool operator!=(const GateTarget &other) const; diff --git a/src/stim/circuit/gate_target.test.cc b/src/stim/circuit/gate_target.test.cc index a26f6e344..492ad704f 100644 --- a/src/stim/circuit/gate_target.test.cc +++ b/src/stim/circuit/gate_target.test.cc @@ -246,3 +246,31 @@ TEST(gate_target, target_str_round_trip) { ASSERT_EQ(GateTarget::from_target_str(t.target_str().c_str()), t) << t; } } + +TEST(gate_target, is_pauli_target) { + ASSERT_FALSE(GateTarget::qubit(2).is_pauli_target()); + ASSERT_FALSE(GateTarget::qubit(3, true).is_pauli_target()); + ASSERT_FALSE(GateTarget::sweep_bit(5).is_pauli_target()); + ASSERT_FALSE(GateTarget::rec(-7).is_pauli_target()); + ASSERT_TRUE(GateTarget::x(11).is_pauli_target()); + ASSERT_TRUE(GateTarget::x(13, true).is_pauli_target()); + ASSERT_TRUE(GateTarget::y(17).is_pauli_target()); + ASSERT_TRUE(GateTarget::y(19, true).is_pauli_target()); + ASSERT_TRUE(GateTarget::z(23).is_pauli_target()); + ASSERT_TRUE(GateTarget::z(29, true).is_pauli_target()); + ASSERT_FALSE(GateTarget::combiner().is_pauli_target()); +} + +TEST(gate_target, is_classical_bit_target) { + ASSERT_TRUE(GateTarget::sweep_bit(5).is_classical_bit_target()); + ASSERT_TRUE(GateTarget::rec(-7).is_classical_bit_target()); + ASSERT_FALSE(GateTarget::qubit(2).is_classical_bit_target()); + ASSERT_FALSE(GateTarget::qubit(3, true).is_classical_bit_target()); + ASSERT_FALSE(GateTarget::x(11).is_classical_bit_target()); + ASSERT_FALSE(GateTarget::x(13, true).is_classical_bit_target()); + ASSERT_FALSE(GateTarget::y(17).is_classical_bit_target()); + ASSERT_FALSE(GateTarget::y(19, true).is_classical_bit_target()); + ASSERT_FALSE(GateTarget::z(23).is_classical_bit_target()); + ASSERT_FALSE(GateTarget::z(29, true).is_classical_bit_target()); + ASSERT_FALSE(GateTarget::combiner().is_classical_bit_target()); +} diff --git a/src/stim/cmd/command_help.cc b/src/stim/cmd/command_help.cc index 0bbee7027..06945aa14 100644 --- a/src/stim/cmd/command_help.cc +++ b/src/stim/cmd/command_help.cc @@ -204,18 +204,37 @@ void print_example(Acc &out, const char *name, const Gate &gate) { out.change_indent(-4); } +std::vector stim::gate_decomposition_help_targets_for_gate_type(GateType g) { + if (g == GateType::MPP) { + return { + GateTarget::x(0), + GateTarget::combiner(), + GateTarget::y(1), + GateTarget::combiner(), + GateTarget::z(2), + GateTarget::x(3), + GateTarget::combiner(), + GateTarget::x(4), + }; + } else if (g == GateType::DETECTOR || g == GateType::OBSERVABLE_INCLUDE) { + return {GateTarget::rec(-1)}; + } else if (g == GateType::TICK || g == GateType::SHIFT_COORDS) { + return {}; + } else if (g == GateType::E || g == GateType::ELSE_CORRELATED_ERROR) { + return {GateTarget::x(0)}; + } else if (GATE_DATA[g].flags & GATE_TARGETS_PAIRS) { + return {GateTarget::qubit(0), GateTarget::qubit(1)}; + } else { + return {GateTarget::qubit(0)}; + } +} + void print_decomposition(Acc &out, const Gate &gate) { const char *decomposition = gate.h_s_cx_m_r_decomposition; if (decomposition != nullptr) { std::stringstream undecomposed; - if (gate.id == GateType::MPP) { - undecomposed << "MPP X0*Y1*Z2 X3*X4"; - } else { - undecomposed << gate.name << " 0"; - if (gate.flags & GATE_TARGETS_PAIRS) { - undecomposed << " 1"; - } - } + auto decomp_targets = gate_decomposition_help_targets_for_gate_type(gate.id); + undecomposed << CircuitInstruction{gate.id, {}, decomp_targets}; out << "Decomposition (into H, S, CX, M, R):\n"; out.change_indent(+4); @@ -234,8 +253,11 @@ void print_stabilizer_generators(Acc &out, const Gate &gate) { if (flows.empty()) { return; } - if (gate.id == GateType::MPP) { - out << "Stabilizer Generators (for `MPP X0*Y1*Z2 X3*X4`):\n"; + auto decomp_targets = gate_decomposition_help_targets_for_gate_type(gate.id); + if (decomp_targets.size() > 2) { + out << "Stabilizer Generators (for `"; + out << CircuitInstruction{gate.id, {}, decomp_targets}; + out << "`):\n"; } else { out << "Stabilizer Generators:\n"; } diff --git a/src/stim/cmd/command_help.h b/src/stim/cmd/command_help.h index fd6d7d7cd..08eab0a98 100644 --- a/src/stim/cmd/command_help.h +++ b/src/stim/cmd/command_help.h @@ -19,6 +19,10 @@ #include #include +#include + +#include "stim/circuit/gate_target.h" +#include "stim/gates/gates.h" namespace stim { @@ -26,6 +30,8 @@ int command_help(int argc, const char **argv); std::string help_for(std::string help_key); std::string clean_doc_string(const char *c, bool allow_too_long = false); +std::vector gate_decomposition_help_targets_for_gate_type(stim::GateType g); + } // namespace stim #endif \ No newline at end of file diff --git a/src/stim/diagram/circuit_timeline_helper.cc b/src/stim/diagram/circuit_timeline_helper.cc index 471cd8854..b756a5955 100644 --- a/src/stim/diagram/circuit_timeline_helper.cc +++ b/src/stim/diagram/circuit_timeline_helper.cc @@ -45,12 +45,20 @@ void CircuitTimelineHelper::do_atomic_operation( } void CircuitTimelineHelper::do_operation_with_target_combiners(const CircuitInstruction &op) { + bool paired = GATE_DATA[op.gate_type].flags & GATE_TARGETS_PAIRS; size_t start = 0; while (start < op.targets.size()) { size_t end = start + 1; while (end < op.targets.size() && op.targets[end].is_combiner()) { end += 2; } + if (paired) { + end++; + while (end < op.targets.size() && op.targets[end].is_combiner()) { + end += 2; + } + } + if (GATE_DATA[op.gate_type].flags & GATE_PRODUCES_RESULTS) { do_record_measure_result(op.targets[start].qubit_value()); } @@ -183,8 +191,6 @@ void CircuitTimelineHelper::do_record_measure_result(uint32_t target_qubit) { void CircuitTimelineHelper::do_next_operation(const Circuit &circuit, const CircuitInstruction &op) { if (op.gate_type == GateType::REPEAT) { do_repeat_block(circuit, op); - } else if (op.gate_type == GateType::MPP) { - do_operation_with_target_combiners(op); } else if (op.gate_type == GateType::DETECTOR) { do_detector(op); } else if (op.gate_type == GateType::OBSERVABLE_INCLUDE) { @@ -198,6 +204,8 @@ void CircuitTimelineHelper::do_next_operation(const Circuit &circuit, const Circ } else if (op.gate_type == GateType::TICK) { do_atomic_operation(op.gate_type, {}, {}); num_ticks_seen += 1; + } else if (GATE_DATA[op.gate_type].flags & GATE_TARGETS_COMBINERS) { + do_operation_with_target_combiners(op); } else if (GATE_DATA[op.gate_type].flags & GATE_TARGETS_PAIRS) { do_two_qubit_gate(op); } else { diff --git a/src/stim/gates/gate_data_annotations.cc b/src/stim/gates/gate_data_annotations.cc index ec533521e..b71301ccc 100644 --- a/src/stim/gates/gate_data_annotations.cc +++ b/src/stim/gates/gate_data_annotations.cc @@ -319,8 +319,8 @@ Parens Arguments: .name = "MPAD", .id = GateType::MPAD, .best_candidate_inverse_id = GateType::MPAD, - .arg_count = 0, - .flags = GATE_PRODUCES_RESULTS, + .arg_count = ARG_COUNT_SYGIL_ZERO_OR_ONE, + .flags = (GateFlags)(GATE_PRODUCES_RESULTS | GATE_ARGS_ARE_DISJOINT_PROBABILITIES), .category = "Z_Annotations", .help = R"MARKDOWN( Pads the measurement record with the listed measurement results. @@ -329,6 +329,12 @@ This can be useful for ensuring measurements are aligned to word boundaries, or number of measurement bits produced per circuit layer is always the same even if the number of measured qubits varies. +Parens Arguments: + + If no parens argument is given, the padding bits are recorded perfectly. + If one parens argument is given, the padding bits are recorded noisily. + The argument is the probability of recording the wrong result. + Targets: Each target is a measurement result to add. diff --git a/src/stim/gates/gates.test.cc b/src/stim/gates/gates.test.cc index 19adc114d..c43d02b35 100644 --- a/src/stim/gates/gates.test.cc +++ b/src/stim/gates/gates.test.cc @@ -17,6 +17,7 @@ #include "gtest/gtest.h" #include "stim/circuit/circuit.h" +#include "stim/cmd/command_help.h" #include "stim/mem/simd_word.test.h" #include "stim/simulators/tableau_simulator.h" #include "stim/stabilizers/flow.h" @@ -73,10 +74,9 @@ TEST(gate_data, hash_matches_storage_location) { } template -std::pair>, std::vector>> circuit_output_eq_val(const Circuit &circuit) { - if (circuit.count_measurements() > 1) { - throw std::invalid_argument("count_measurements > 1"); - } +static std::pair>, std::vector>> circuit_output_eq_val( + const Circuit &circuit) { + // CAUTION: this is not 100% reliable when measurement count is larger than 1. TableauSimulator sim1(INDEPENDENT_TEST_RNG(), circuit.count_qubits(), -1); TableauSimulator sim2(INDEPENDENT_TEST_RNG(), circuit.count_qubits(), +1); sim1.safe_do_circuit(circuit); @@ -91,19 +91,19 @@ bool is_decomposition_correct(const Gate &gate) { return false; } - std::vector qs{0}; - if (gate.flags & GATE_TARGETS_PAIRS) { - qs.push_back(1); - } + Circuit original; + original.safe_append(gate.id, gate_decomposition_help_targets_for_gate_type(gate.id), {}); + uint32_t n = original.count_qubits(); Circuit epr; - epr.safe_append_u("H", qs); - for (auto q : qs) { - epr.safe_append_u("CNOT", {q, q + 2}); + for (uint32_t q = 0; q < n; q++) { + epr.safe_append_u("H", {q}); + } + for (uint32_t q = 0; q < n; q++) { + epr.safe_append_u("CNOT", {q, q + n}); } - Circuit circuit1 = epr; - circuit1.safe_append_u(gate.name, qs); + Circuit circuit1 = epr + original; Circuit circuit2 = epr + Circuit(decomposition); // Reset gates make the ancillary qubits irrelevant because the final value is unrelated to the initial value. @@ -111,9 +111,9 @@ bool is_decomposition_correct(const Gate &gate) { // CAUTION: this could give false positives if "partial reset" gates are added in the future. // (E.g. a two qubit gate that resets only one of the qubits.) if ((gate.flags & GATE_IS_RESET) && !(gate.flags & GATE_PRODUCES_RESULTS)) { - for (auto q : qs) { - circuit1.safe_append_u("R", {q + 2}); - circuit2.safe_append_u("R", {q + 2}); + for (uint32_t q = 0; q < n; q++) { + circuit1.safe_append_u("R", {q + n}); + circuit2.safe_append_u("R", {q + n}); } } @@ -134,7 +134,7 @@ TEST_EACH_WORD_SIZE_W(gate_data, decompositions_are_correct, { if (g.flags & GATE_IS_UNITARY) { EXPECT_TRUE(g.h_s_cx_m_r_decomposition != nullptr) << g.name; } - if (g.h_s_cx_m_r_decomposition != nullptr && g.id != GateType::MPP) { + if (g.h_s_cx_m_r_decomposition != nullptr) { EXPECT_TRUE(is_decomposition_correct(g)) << g.name; } } @@ -156,22 +156,7 @@ TEST_EACH_WORD_SIZE_W(gate_data, stabilizer_flows_are_correct, { if (flows.empty()) { continue; } - std::vector targets; - if (g.id == GateType::MPP) { - targets.push_back(GateTarget::x(0)); - targets.push_back(GateTarget::combiner()); - targets.push_back(GateTarget::y(1)); - targets.push_back(GateTarget::combiner()); - targets.push_back(GateTarget::z(2)); - targets.push_back(GateTarget::x(3)); - targets.push_back(GateTarget::combiner()); - targets.push_back(GateTarget::x(4)); - } else { - targets.push_back(GateTarget::qubit(0)); - if (g.flags & GATE_TARGETS_PAIRS) { - targets.push_back(GateTarget::qubit(1)); - } - } + std::vector targets = gate_decomposition_help_targets_for_gate_type(g.id); Circuit c; c.safe_append(g.id, targets, {}); @@ -190,22 +175,7 @@ TEST_EACH_WORD_SIZE_W(gate_data, stabilizer_flows_are_also_correct_for_decompose if (flows.empty()) { continue; } - std::vector targets; - if (g.id == GateType::MPP) { - targets.push_back(GateTarget::x(0)); - targets.push_back(GateTarget::combiner()); - targets.push_back(GateTarget::y(1)); - targets.push_back(GateTarget::combiner()); - targets.push_back(GateTarget::z(2)); - targets.push_back(GateTarget::x(3)); - targets.push_back(GateTarget::combiner()); - targets.push_back(GateTarget::x(4)); - } else { - targets.push_back(GateTarget::qubit(0)); - if (g.flags & GATE_TARGETS_PAIRS) { - targets.push_back(GateTarget::qubit(1)); - } - } + std::vector targets = gate_decomposition_help_targets_for_gate_type(g.id); Circuit c(g.h_s_cx_m_r_decomposition); auto r = sample_if_circuit_has_stabilizer_flows(256, rng, c, flows); diff --git a/src/stim/simulators/error_analyzer.cc b/src/stim/simulators/error_analyzer.cc index 085074c7b..e1085516c 100644 --- a/src/stim/simulators/error_analyzer.cc +++ b/src/stim/simulators/error_analyzer.cc @@ -119,15 +119,6 @@ void ErrorAnalyzer::undo_gate(const CircuitInstruction &inst) { case GateType::CZ: undo_ZCZ(inst); break; - case GateType::H: - undo_H_XZ(inst); - break; - case GateType::H_XY: - undo_H_XY(inst); - break; - case GateType::H_YZ: - undo_H_YZ(inst); - break; case GateType::DEPOLARIZE1: undo_DEPOLARIZE1(inst); break; @@ -156,14 +147,8 @@ void ErrorAnalyzer::undo_gate(const CircuitInstruction &inst) { undo_ELSE_CORRELATED_ERROR(inst); break; case GateType::I: - undo_I(inst); - break; case GateType::X: - undo_I(inst); - break; case GateType::Y: - undo_I(inst); - break; case GateType::Z: undo_I(inst); break; @@ -173,39 +158,30 @@ void ErrorAnalyzer::undo_gate(const CircuitInstruction &inst) { case GateType::C_ZYX: undo_C_ZYX(inst); break; + case GateType::H_YZ: case GateType::SQRT_X: - undo_H_YZ(inst); - break; case GateType::SQRT_X_DAG: undo_H_YZ(inst); break; case GateType::SQRT_Y: - undo_H_XZ(inst); - break; case GateType::SQRT_Y_DAG: + case GateType::H: undo_H_XZ(inst); break; case GateType::S: - undo_H_XY(inst); - break; case GateType::S_DAG: + case GateType::H_XY: undo_H_XY(inst); break; case GateType::SQRT_XX: - undo_SQRT_XX(inst); - break; case GateType::SQRT_XX_DAG: undo_SQRT_XX(inst); break; case GateType::SQRT_YY: - undo_SQRT_YY(inst); - break; case GateType::SQRT_YY_DAG: undo_SQRT_YY(inst); break; case GateType::SQRT_ZZ: - undo_SQRT_ZZ(inst); - break; case GateType::SQRT_ZZ_DAG: undo_SQRT_ZZ(inst); break; @@ -213,8 +189,6 @@ void ErrorAnalyzer::undo_gate(const CircuitInstruction &inst) { undo_SWAP(inst); break; case GateType::ISWAP: - undo_ISWAP(inst); - break; case GateType::ISWAP_DAG: undo_ISWAP(inst); break; @@ -374,6 +348,9 @@ void ErrorAnalyzer::undo_HERALDED_PAULI_CHANNEL_1(const CircuitInstruction &dat) void ErrorAnalyzer::undo_MPAD(const CircuitInstruction &inst) { for (size_t k = inst.targets.size(); k-- > 0;) { tracker.num_measurements_in_past--; + + SparseXorVec &d = tracker.rec_bits[tracker.num_measurements_in_past]; + xor_sorted_measurement_error(d.range(), inst); tracker.rec_bits.erase(tracker.num_measurements_in_past); } } @@ -506,10 +483,10 @@ PauliString ErrorAnalyzer::current_error_sensitivity_for(DemT return result; } -void ErrorAnalyzer::xor_sorted_measurement_error(SpanRef targets, const CircuitInstruction &dat) { +void ErrorAnalyzer::xor_sorted_measurement_error(SpanRef targets, const CircuitInstruction &inst) { // Measurement error. - if (!dat.args.empty() && dat.args[0] > 0) { - add_error(dat.args[0], targets); + if (!inst.args.empty() && inst.args[0] > 0) { + add_error(inst.args[0], targets); } } @@ -1612,22 +1589,18 @@ void ErrorAnalyzer::undo_MPP(const CircuitInstruction &target_data) { decompose_mpp_operation( CircuitInstruction{GateType::MPP, target_data.args, reversed_targets}, tracker.xs.size(), - [&](const CircuitInstruction &h_xz, - const CircuitInstruction &h_yz, - const CircuitInstruction &cnot, - const CircuitInstruction &meas) { - undo_H_XZ(h_xz); - undo_H_YZ(h_yz); - undo_ZCX(cnot); - reversed_measure_targets.clear(); - for (size_t k = meas.targets.size(); k--;) { - reversed_measure_targets.push_back(meas.targets[k]); + [&](const CircuitInstruction &inst) { + if (inst.gate_type == GateType::M) { + reversed_measure_targets.clear(); + for (size_t k = inst.targets.size(); k--;) { + reversed_measure_targets.push_back(inst.targets[k]); + } + undo_MZ_with_context( + CircuitInstruction{GateType::M, inst.args, reversed_measure_targets}, + "a Pauli product measurement (MPP)"); + } else { + undo_gate(inst); } - undo_MZ_with_context( - {GateType::M, meas.args, reversed_measure_targets}, "a Pauli product measurement (MPP)"); - undo_ZCX(cnot); - undo_H_YZ(h_yz); - undo_H_XZ(h_xz); }); } diff --git a/src/stim/simulators/frame_simulator.inl b/src/stim/simulators/frame_simulator.inl index b2ed4a14c..100560d55 100644 --- a/src/stim/simulators/frame_simulator.inl +++ b/src/stim/simulators/frame_simulator.inl @@ -688,17 +688,8 @@ void FrameSimulator::do_MPP(const CircuitInstruction &target_data) { decompose_mpp_operation( target_data, num_qubits, - [&](const CircuitInstruction &h_xz, - const CircuitInstruction &h_yz, - const CircuitInstruction &cnot, - const CircuitInstruction &meas) { - do_H_XZ(h_xz); - do_H_YZ(h_yz); - do_ZCX(cnot); - do_MZ(meas); - do_ZCX(cnot); - do_H_YZ(h_yz); - do_H_XZ(h_xz); + [&](const CircuitInstruction &inst) { + safe_do_instruction(inst); }); } @@ -893,10 +884,11 @@ void FrameSimulator::do_MZZ(const CircuitInstruction &inst) { template void FrameSimulator::do_MPAD(const CircuitInstruction &inst) { + m_record.reserve_noisy_space_for_results(inst, rng); simd_bits empty(batch_size); - assert(inst.args.empty()); for (size_t k = 0; k < inst.targets.size(); k++) { - m_record.record_result(empty); + // 0-vs-1 is ignored because it's accounted for in the reference sample. + m_record.xor_record_reserved_result(empty); } } diff --git a/src/stim/simulators/sparse_rev_frame_tracker.cc b/src/stim/simulators/sparse_rev_frame_tracker.cc index ef5a1925b..b7d077f8f 100644 --- a/src/stim/simulators/sparse_rev_frame_tracker.cc +++ b/src/stim/simulators/sparse_rev_frame_tracker.cc @@ -323,30 +323,16 @@ void SparseUnsignedRevFrameTracker::undo_MPP(const CircuitInstruction &target_da decompose_mpp_operation( CircuitInstruction{target_data.gate_type, target_data.args, reversed_targets}, xs.size(), - [&](const CircuitInstruction &h_xz, - const CircuitInstruction &h_yz, - const CircuitInstruction &cnot, - const CircuitInstruction &meas) { - undo_H_XZ(h_xz); - undo_H_YZ(h_yz); - undo_ZCX(cnot); - try { - handle_x_gauges(meas); - } catch (const std::invalid_argument &ex) { - undo_ZCX(cnot); - undo_H_YZ(h_yz); - undo_H_XZ(h_xz); - throw ex; + [&](const CircuitInstruction &inst) { + if (inst.gate_type == GateType::M) { + reversed_measure_targets.clear(); + for (size_t k = inst.targets.size(); k--;) { + reversed_measure_targets.push_back(inst.targets[k]); + } + undo_MZ({GateType::M, inst.args, reversed_measure_targets}); + } else { + undo_gate(inst); } - - reversed_measure_targets.clear(); - for (size_t k = meas.targets.size(); k--;) { - reversed_measure_targets.push_back(meas.targets[k]); - } - undo_MZ({GateType::M, meas.args, reversed_measure_targets}); - undo_ZCX(cnot); - undo_H_YZ(h_yz); - undo_H_XZ(h_xz); }); } diff --git a/src/stim/simulators/tableau_simulator.inl b/src/stim/simulators/tableau_simulator.inl index e3a7f182e..c2a6858e2 100644 --- a/src/stim/simulators/tableau_simulator.inl +++ b/src/stim/simulators/tableau_simulator.inl @@ -60,17 +60,8 @@ void TableauSimulator::do_MPP(const CircuitInstruction &target_data) { decompose_mpp_operation( target_data, inv_state.num_qubits, - [&](const CircuitInstruction &h_xz, - const CircuitInstruction &h_yz, - const CircuitInstruction &cnot, - const CircuitInstruction &meas) { - do_H_XZ(h_xz); - do_H_YZ(h_yz); - do_ZCX(cnot); - do_MZ(meas); - do_ZCX(cnot); - do_H_YZ(h_yz); - do_H_XZ(h_xz); + [&](const CircuitInstruction &inst) { + do_gate(inst); }); } @@ -326,6 +317,7 @@ void TableauSimulator::do_MPAD(const CircuitInstruction &inst) { for (const auto &t : inst.targets) { measurement_record.record_result(t.qubit_value() != 0); } + noisify_new_measurements(inst); } template diff --git a/src/stim/simulators/tableau_simulator.test.cc b/src/stim/simulators/tableau_simulator.test.cc index ffc487997..69cf56fec 100644 --- a/src/stim/simulators/tableau_simulator.test.cc +++ b/src/stim/simulators/tableau_simulator.test.cc @@ -1599,7 +1599,7 @@ TEST_EACH_WORD_SIZE_W(TableauSimulator, noisy_measure_reset_z, { TEST_EACH_WORD_SIZE_W(TableauSimulator, measure_pauli_product_bad, { TableauSimulator t(INDEPENDENT_TEST_RNG()); - ASSERT_THROW({ t.safe_do_circuit("MPP X0*X0"); }, std::invalid_argument); + t.safe_do_circuit("MPP X0*X0"); ASSERT_THROW({ t.safe_do_circuit("MPP X0*Z0"); }, std::invalid_argument); }) From 4f1d217a37356e0215ab3768f1285ffa4d0c8dae Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Thu, 14 Mar 2024 08:48:43 -0700 Subject: [PATCH 4/6] Make `stim.Tableau.from_stabilizers` faster (#713) - Store the growing reduction as a circuit instead of as a tableau - Measured 20x faster (140ms -> 6ms) on a 144 qubit case - Measured 100x faster (15s -> 0.13s) on a 432 qubit case --- src/stim/stabilizers/conversions.h | 4 +- src/stim/stabilizers/conversions.inl | 89 +++++++++++++----------- src/stim/stabilizers/conversions.perf.cc | 48 +++++++++++++ 3 files changed, 100 insertions(+), 41 deletions(-) diff --git a/src/stim/stabilizers/conversions.h b/src/stim/stabilizers/conversions.h index 4349693fd..37166d70d 100644 --- a/src/stim/stabilizers/conversions.h +++ b/src/stim/stabilizers/conversions.h @@ -83,11 +83,13 @@ Circuit stabilizer_state_vector_to_circuit( /// ignore_noise: If the circuit contains noise channels, ignore them instead of raising an exception. /// ignore_measurement: If the circuit contains measurements, ignore them instead of raising an exception. /// ignore_reset: If the circuit contains resets, ignore them instead of raising an exception. +/// inverse: The last step of the implementation is to invert the tableau. Setting this argument +/// to true will skip this inversion, saving time but returning the inverse tableau. /// /// Returns: /// A tableau encoding the given circuit's Clifford operation. template -Tableau circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ignore_measurement, bool ignore_reset); +Tableau circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ignore_measurement, bool ignore_reset, bool inverse = false); /// Simulates the given circuit and outputs a state vector. /// diff --git a/src/stim/stabilizers/conversions.inl b/src/stim/stabilizers/conversions.inl index b6a8edec3..6ab5f2b17 100644 --- a/src/stim/stabilizers/conversions.inl +++ b/src/stim/stabilizers/conversions.inl @@ -149,7 +149,7 @@ std::vector>> tableau_to_unitary(const Tableau -Tableau circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ignore_measurement, bool ignore_reset) { +Tableau circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ignore_measurement, bool ignore_reset, bool inverse) { Tableau result(circuit.count_qubits()); TableauSimulator sim(std::mt19937_64(0), circuit.count_qubits()); @@ -185,7 +185,10 @@ Tableau circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ig } }); - return sim.inv_state.inverse(); + if (!inverse) { + return sim.inv_state.inverse(); + } + return sim.inv_state; } template @@ -556,7 +559,7 @@ Tableau stabilizers_to_tableau( } for (size_t k1 = 0; k1 < stabilizers.size(); k1++) { - for (size_t k2 = 0; k2 < stabilizers.size(); k2++) { + for (size_t k2 = k1 + 1; k2 < stabilizers.size(); k2++) { if (!stabilizers[k1].ref().commutes(stabilizers[k2])) { std::stringstream ss; ss << "Some of the given stabilizers anticommute.\n"; @@ -568,44 +571,39 @@ Tableau stabilizers_to_tableau( } } } - Tableau inverted(num_qubits); - - PauliString cur(num_qubits); - std::vector targets; - while (targets.size() < num_qubits) { - targets.push_back(targets.size()); - } - auto overwrite_cur_apply_recorded = [&](const PauliString &e) { - PauliStringRef cur_ref = cur.ref(); - cur.xs.clear(); - cur.zs.clear(); - cur.xs.word_range_ref(0, e.xs.num_simd_words) = e.xs; - cur.zs.word_range_ref(0, e.xs.num_simd_words) = e.zs; - cur.sign = e.sign; - inverted.apply_within(cur_ref, targets); - }; + Circuit elimination_instructions; + PauliString buf(num_qubits); size_t used = 0; for (const auto &e : stabilizers) { - overwrite_cur_apply_recorded(e); + if (e.num_qubits == num_qubits) { + buf = e; + } else { + buf.xs.clear(); + buf.zs.clear(); + memcpy(buf.xs.u8, e.xs.u8, e.xs.num_u8_padded()); + memcpy(buf.zs.u8, e.zs.u8, e.zs.num_u8_padded()); + buf.sign = e.sign; + } + buf.ref().do_circuit(elimination_instructions); // Find a non-identity term in the Pauli string past the region used by other stabilizers. size_t pivot; for (pivot = used; pivot < num_qubits; pivot++) { - if (cur.xs[pivot] || cur.zs[pivot]) { + if (buf.xs[pivot] || buf.zs[pivot]) { break; } } // Check for incompatible / redundant stabilizers. if (pivot == num_qubits) { - if (cur.xs.not_zero()) { + if (buf.xs.not_zero()) { throw std::invalid_argument("Some of the given stabilizers anticommute."); } - if (cur.sign) { + if (buf.sign) { throw std::invalid_argument("Some of the given stabilizers contradict each other."); } - if (!allow_redundant && cur.zs.not_zero()) { + if (!allow_redundant && buf.zs.not_zero()) { throw std::invalid_argument( "Didn't specify allow_redundant=True but one of the given stabilizers is a product of the others. " "To allow redundant stabilizers, pass the argument allow_redundant=True."); @@ -614,32 +612,36 @@ Tableau stabilizers_to_tableau( } // Change pivot basis to the Z axis. - if (cur.xs[pivot]) { - std::string name = cur.zs[pivot] ? "H_YZ" : "H_XZ"; - inverted.inplace_scatter_append(GATE_DATA.at(name).tableau(), {pivot}); + if (buf.xs[pivot]) { + GateType g = buf.zs[pivot] ? GateType::H_YZ : GateType::H; + GateTarget t = GateTarget::qubit(pivot); + CircuitInstruction instruction{g, {}, &t}; + elimination_instructions.safe_append(instruction); + buf.ref().do_instruction(instruction); } // Cancel other terms in Pauli string. for (size_t q = 0; q < num_qubits; q++) { - int p = cur.xs[q] + cur.zs[q] * 2; + int p = buf.xs[q] + buf.zs[q] * 2; if (p && q != pivot) { - inverted.inplace_scatter_append( - GATE_DATA.at(p == 1 ? "XCX" - : p == 2 ? "XCZ" - : "XCY") - .tableau(), - {pivot, q}); + std::array targets{GateTarget::qubit(pivot), GateTarget::qubit(q)}; + CircuitInstruction instruction{p == 1 ? GateType::XCX : p == 2 ? GateType::XCZ : GateType::XCY, {}, targets}; + elimination_instructions.safe_append(instruction); + buf.ref().do_instruction(instruction); } } // Move pivot to diagonal. if (pivot != used) { - inverted.inplace_scatter_append(GATE_DATA.at("SWAP").tableau(), {pivot, used}); + std::array targets{GateTarget::qubit(pivot), GateTarget::qubit(used)}; + CircuitInstruction instruction{GateType::SWAP, {}, targets}; + elimination_instructions.safe_append(instruction); } // Fix sign. - overwrite_cur_apply_recorded(e); - if (cur.sign) { - inverted.inplace_scatter_append(GATE_DATA.at("X").tableau(), {used}); + if (buf.sign) { + GateTarget t = GateTarget::qubit(used); + CircuitInstruction instruction{GateType::X, {}, &t}; + elimination_instructions.safe_append(instruction); } used++; @@ -653,10 +655,17 @@ Tableau stabilizers_to_tableau( } } + if (num_qubits > 0) { + // Force size of resulting tableau to be correct. + GateTarget t = GateTarget::qubit(num_qubits - 1); + elimination_instructions.safe_append(CircuitInstruction{GateType::X, {}, &t}); + elimination_instructions.safe_append(CircuitInstruction{GateType::X, {}, &t}); + } + if (invert) { - return inverted; + return circuit_to_tableau(elimination_instructions.inverse(), false, false, false, true); } - return inverted.inverse(); + return circuit_to_tableau(elimination_instructions, false, false, false, true); } } // namespace stim diff --git a/src/stim/stabilizers/conversions.perf.cc b/src/stim/stabilizers/conversions.perf.cc index 0461355a7..50ad12c50 100644 --- a/src/stim/stabilizers/conversions.perf.cc +++ b/src/stim/stabilizers/conversions.perf.cc @@ -67,3 +67,51 @@ BENCHMARK(independent_to_disjoint_xyz_errors) { std::cout << "data dependence"; } } + +BENCHMARK(stabilizers_to_tableau) { + std::vector> offsets{ + {1, 0}, + {-1, 0}, + {0, 1}, + {0, -1}, + {3, 6}, + {-6, 3}, + }; + size_t w = 24; + size_t h = 12; + + auto normalize = [&](std::complex c) -> std::complex { + return {fmodf(c.real() + w*10, w), fmodf(c.imag() + h*10, h)}; + }; + auto q2i = [&](std::complex c) -> size_t { + c = normalize(c); + return (int)c.real() / 2 + c.imag() * (w / 2); + }; + + std::vector> stabilizers; + for (size_t x = 0; x < w; x++) { + for (size_t y = x % 2; y < h; y += 2) { + std::complex s{x % 2 ? -1.0f : +1.0f, 0.0f}; + std::complex c{(float)x, (float)y}; + stim::PauliString<64> ps(w * h / 2); + for (const auto &offset : offsets) { + size_t i = q2i(c + offset * s); + if (x % 2 == 0) { + ps.xs[i] = 1; + } else { + ps.zs[i] = 1; + } + } + stabilizers.push_back(ps); + } + } + + size_t dep = 0; + benchmark_go([&]() { + Tableau<64> t = stabilizers_to_tableau(stabilizers, true, true, false); + dep += t.xs[0].zs[0]; + }).goal_millis(5); + if (dep == 99999999) { + std::cout << "data dependence"; + } +} From 1cdf45f884442d01806d6a8997b867a352e937c9 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Thu, 14 Mar 2024 23:10:42 -0700 Subject: [PATCH 5/6] Speed up `stim.Tableau.from_stabilizers` another 10x (#714) - 144 qubit case went from 5000 micros to 500 micros - 576 qubit case went from 2500 millis to 200 millis --- src/stim/stabilizers/conversions.inl | 120 ++++++++++++++++++----- src/stim/stabilizers/conversions.perf.cc | 53 +++++++++- 2 files changed, 146 insertions(+), 27 deletions(-) diff --git a/src/stim/stabilizers/conversions.inl b/src/stim/stabilizers/conversions.inl index 6ab5f2b17..24f5b3637 100644 --- a/src/stim/stabilizers/conversions.inl +++ b/src/stim/stabilizers/conversions.inl @@ -558,6 +558,18 @@ Tableau stabilizers_to_tableau( num_qubits = std::max(num_qubits, e.num_qubits); } + simd_bit_table buf_xs(num_qubits, stabilizers.size()); + simd_bit_table buf_zs(num_qubits, stabilizers.size()); + simd_bits buf_signs(stabilizers.size()); + simd_bits buf_workspace(stabilizers.size()); + for (size_t k = 0; k < stabilizers.size(); k++) { + memcpy(buf_xs[k].u8, stabilizers[k].xs.u8, stabilizers[k].xs.num_u8_padded()); + memcpy(buf_zs[k].u8, stabilizers[k].zs.u8, stabilizers[k].zs.num_u8_padded()); + buf_signs[k] = stabilizers[k].sign; + } + buf_xs = buf_xs.transposed(); + buf_zs = buf_zs.transposed(); + for (size_t k1 = 0; k1 < stabilizers.size(); k1++) { for (size_t k2 = k1 + 1; k2 < stabilizers.size(); k2++) { if (!stabilizers[k1].ref().commutes(stabilizers[k2])) { @@ -572,38 +584,28 @@ Tableau stabilizers_to_tableau( } } Circuit elimination_instructions; - PauliString buf(num_qubits); size_t used = 0; - for (const auto &e : stabilizers) { - if (e.num_qubits == num_qubits) { - buf = e; - } else { - buf.xs.clear(); - buf.zs.clear(); - memcpy(buf.xs.u8, e.xs.u8, e.xs.num_u8_padded()); - memcpy(buf.zs.u8, e.zs.u8, e.zs.num_u8_padded()); - buf.sign = e.sign; - } - buf.ref().do_circuit(elimination_instructions); - + for (size_t k = 0; k < stabilizers.size(); k++) { // Find a non-identity term in the Pauli string past the region used by other stabilizers. size_t pivot; for (pivot = used; pivot < num_qubits; pivot++) { - if (buf.xs[pivot] || buf.zs[pivot]) { + if (buf_xs[pivot][k] || buf_zs[pivot][k]) { break; } } // Check for incompatible / redundant stabilizers. if (pivot == num_qubits) { - if (buf.xs.not_zero()) { - throw std::invalid_argument("Some of the given stabilizers anticommute."); + for (size_t q = 0; q < num_qubits; q++) { + if (buf_xs[q][k]) { + throw std::invalid_argument("Some of the given stabilizers anticommute."); + } } - if (buf.sign) { + if (buf_signs[k]) { throw std::invalid_argument("Some of the given stabilizers contradict each other."); } - if (!allow_redundant && buf.zs.not_zero()) { + if (!allow_redundant) { throw std::invalid_argument( "Didn't specify allow_redundant=True but one of the given stabilizers is a product of the others. " "To allow redundant stabilizers, pass the argument allow_redundant=True."); @@ -612,21 +614,86 @@ Tableau stabilizers_to_tableau( } // Change pivot basis to the Z axis. - if (buf.xs[pivot]) { - GateType g = buf.zs[pivot] ? GateType::H_YZ : GateType::H; + if (buf_xs[pivot][k]) { + GateType g = buf_zs[pivot][k] ? GateType::H_YZ : GateType::H; GateTarget t = GateTarget::qubit(pivot); CircuitInstruction instruction{g, {}, &t}; elimination_instructions.safe_append(instruction); - buf.ref().do_instruction(instruction); + size_t q = pivot; + switch (g) { + case GateType::H_YZ: + buf_xs[q] ^= buf_zs[q]; + buf_workspace = buf_zs[q]; + buf_workspace.invert_bits(); + buf_workspace &= buf_xs[q]; + buf_signs ^= buf_workspace; + break; + case GateType::H: + buf_xs[q].swap_with(buf_zs[q]); + buf_workspace = buf_zs[q]; + buf_workspace &= buf_xs[q]; + buf_signs ^= buf_workspace; + break; + default: + throw std::invalid_argument("Unrecognized gate type."); + } } + // Cancel other terms in Pauli string. for (size_t q = 0; q < num_qubits; q++) { - int p = buf.xs[q] + buf.zs[q] * 2; + int p = buf_xs[q][k] + buf_zs[q][k] * 2; if (p && q != pivot) { std::array targets{GateTarget::qubit(pivot), GateTarget::qubit(q)}; - CircuitInstruction instruction{p == 1 ? GateType::XCX : p == 2 ? GateType::XCZ : GateType::XCY, {}, targets}; + GateType g = p == 1 ? GateType::XCX : p == 2 ? GateType::XCZ : GateType::XCY; + CircuitInstruction instruction{g, {}, targets}; elimination_instructions.safe_append(instruction); - buf.ref().do_instruction(instruction); + size_t q1 = targets[0].qubit_value(); + size_t q2 = targets[1].qubit_value(); + simd_bits_range_ref x1 = buf_xs[q1]; + simd_bits_range_ref z1 = buf_zs[q1]; + simd_bits_range_ref x2 = buf_xs[q2]; + simd_bits_range_ref z2 = buf_zs[q2]; + switch (g) { + case GateType::XCX: + buf_workspace = x1; + buf_workspace ^= x2; + buf_workspace &= z1; + buf_workspace &= z2; + buf_signs ^= buf_workspace; + x1 ^= z2; + x2 ^= z1; + break; + case GateType::XCY: + x1 ^= x2; + x1 ^= z2; + x2 ^= z1; + z2 ^= z1; + buf_workspace = x1; + buf_workspace |= x2; + buf_workspace.invert_bits(); + buf_workspace &= z1; + buf_workspace &= z2; + buf_signs ^= buf_workspace; + buf_workspace = z2; + buf_workspace.invert_bits(); + buf_workspace &= z1; + buf_workspace &= x1; + buf_workspace &= x2; + buf_signs ^= buf_workspace; + break; + case GateType::XCZ: + z2 ^= z1; + x1 ^= x2; + buf_workspace = z2; + buf_workspace ^= x1; + buf_workspace.invert_bits(); + buf_workspace &= x2; + buf_workspace &= z1; + buf_signs ^= buf_workspace; + break; + default: + throw std::invalid_argument("Unrecognized gate type."); + } } } @@ -635,13 +702,16 @@ Tableau stabilizers_to_tableau( std::array targets{GateTarget::qubit(pivot), GateTarget::qubit(used)}; CircuitInstruction instruction{GateType::SWAP, {}, targets}; elimination_instructions.safe_append(instruction); + buf_xs[pivot].swap_with(buf_xs[used]); + buf_zs[pivot].swap_with(buf_zs[used]); } // Fix sign. - if (buf.sign) { + if (buf_signs[k]) { GateTarget t = GateTarget::qubit(used); CircuitInstruction instruction{GateType::X, {}, &t}; elimination_instructions.safe_append(instruction); + buf_signs ^= buf_zs[used]; } used++; diff --git a/src/stim/stabilizers/conversions.perf.cc b/src/stim/stabilizers/conversions.perf.cc index 50ad12c50..bf6ca1b9f 100644 --- a/src/stim/stabilizers/conversions.perf.cc +++ b/src/stim/stabilizers/conversions.perf.cc @@ -68,7 +68,7 @@ BENCHMARK(independent_to_disjoint_xyz_errors) { } } -BENCHMARK(stabilizers_to_tableau) { +BENCHMARK(stabilizers_to_tableau_144) { std::vector> offsets{ {1, 0}, {-1, 0}, @@ -110,7 +110,56 @@ BENCHMARK(stabilizers_to_tableau) { benchmark_go([&]() { Tableau<64> t = stabilizers_to_tableau(stabilizers, true, true, false); dep += t.xs[0].zs[0]; - }).goal_millis(5); + }).goal_micros(500); + if (dep == 99999999) { + std::cout << "data dependence"; + } +} + + +BENCHMARK(stabilizers_to_tableau_576) { + std::vector> offsets{ + {1, 0}, + {-1, 0}, + {0, 1}, + {0, -1}, + {3, 6}, + {-6, 3}, + }; + size_t w = 24*4; + size_t h = 12*4; + + auto normalize = [&](std::complex c) -> std::complex { + return {fmodf(c.real() + w*10, w), fmodf(c.imag() + h*10, h)}; + }; + auto q2i = [&](std::complex c) -> size_t { + c = normalize(c); + return (int)c.real() / 2 + c.imag() * (w / 2); + }; + + std::vector> stabilizers; + for (size_t x = 0; x < w; x++) { + for (size_t y = x % 2; y < h; y += 2) { + std::complex s{x % 2 ? -1.0f : +1.0f, 0.0f}; + std::complex c{(float)x, (float)y}; + stim::PauliString<64> ps(w * h / 2); + for (const auto &offset : offsets) { + size_t i = q2i(c + offset * s); + if (x % 2 == 0) { + ps.xs[i] = 1; + } else { + ps.zs[i] = 1; + } + } + stabilizers.push_back(ps); + } + } + + size_t dep = 0; + benchmark_go([&]() { + Tableau<64> t = stabilizers_to_tableau(stabilizers, true, true, false); + dep += t.xs[0].zs[0]; + }).goal_millis(200); if (dep == 99999999) { std::cout << "data dependence"; } From 2161fc9f5151286f923925b429358f53c1844a8f Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Fri, 15 Mar 2024 16:54:00 -0700 Subject: [PATCH 6/6] Clean up `stim.Tableau.from_stabilizers` a bit (#716) - Switch to streaming computations of the gates - Improve the anticommutation check from O(n^2) additional work to O(n) additional work - Fix a sizing error and test that it's fixed - Cut costs on slowest test --- src/stim/mem/simd_word.test.cc | 4 +- src/stim/simulators/tableau_simulator.test.cc | 26 ++-- src/stim/stabilizers/conversions.inl | 120 ++++++++---------- src/stim/stabilizers/conversions.test.cc | 15 ++- src/stim/stabilizers/tableau.test.cc | 20 +-- src/stim/stabilizers/tableau_iter.test.cc | 2 +- 6 files changed, 95 insertions(+), 92 deletions(-) diff --git a/src/stim/mem/simd_word.test.cc b/src/stim/mem/simd_word.test.cc index 97ac430fb..0935bbf80 100644 --- a/src/stim/mem/simd_word.test.cc +++ b/src/stim/mem/simd_word.test.cc @@ -33,14 +33,14 @@ union WordOr64 { TEST_EACH_WORD_SIZE_W(simd_word_pick, popcount, { WordOr64 v; - auto n = sizeof(simd_word) * 8; + auto n = sizeof(simd_word) * 2; for (size_t expected = 0; expected <= n; expected++) { std::vector bits{}; for (size_t i = 0; i < n; i++) { bits.push_back(i < expected); } - for (size_t reps = 0; reps < 100; reps++) { + for (size_t reps = 0; reps < 10; reps++) { std::shuffle(bits.begin(), bits.end(), INDEPENDENT_TEST_RNG()); for (size_t i = 0; i < n; i++) { v.p[i >> 6] = 0; diff --git a/src/stim/simulators/tableau_simulator.test.cc b/src/stim/simulators/tableau_simulator.test.cc index 69cf56fec..34650fa1b 100644 --- a/src/stim/simulators/tableau_simulator.test.cc +++ b/src/stim/simulators/tableau_simulator.test.cc @@ -488,24 +488,24 @@ bool vec_sim_corroborates_measurement_process( TEST_EACH_WORD_SIZE_W(TableauSimulator, measurement_vs_vector_sim, { auto rng = INDEPENDENT_TEST_RNG(); - for (size_t k = 0; k < 10; k++) { + for (size_t k = 0; k < 5; k++) { auto state = Tableau::random(2, rng); ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0})); ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {1})); ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0, 1})); } - for (size_t k = 0; k < 10; k++) { + for (size_t k = 0; k < 5; k++) { auto state = Tableau::random(4, rng); ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0, 1})); ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {2, 1})); ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0, 1, 2, 3})); } { - auto state = Tableau::random(12, rng); + auto state = Tableau::random(8, rng); ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0, 1, 2, 3})); - ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0, 10, 11})); - ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {11, 5, 7})); - ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})); + ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0, 6, 7})); + ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {7, 3, 4})); + ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0, 1, 2, 3, 4, 5, 6, 7})); } }) @@ -613,7 +613,7 @@ TEST_EACH_WORD_SIZE_W(TableauSimulator, correlated_error, { expected); int hits[3]{}; - size_t n = 10000; + size_t n = 5000; for (size_t k = 0; k < n; k++) { auto sample = TableauSimulator::sample_circuit( Circuit(R"circuit( @@ -628,8 +628,8 @@ TEST_EACH_WORD_SIZE_W(TableauSimulator, correlated_error, { hits[2] += sample[2]; } ASSERT_TRUE(0.45 * n < hits[0] && hits[0] < 0.55 * n); - ASSERT_TRUE((0.125 - 0.05) * n < hits[1] && hits[1] < (0.125 + 0.05) * n); - ASSERT_TRUE((0.28125 - 0.05) * n < hits[2] && hits[2] < (0.28125 + 0.05) * n); + ASSERT_TRUE((0.125 - 0.08) * n < hits[1] && hits[1] < (0.125 + 0.08) * n); + ASSERT_TRUE((0.28125 - 0.08) * n < hits[2] && hits[2] < (0.28125 + 0.08) * n); }) TEST_EACH_WORD_SIZE_W(TableauSimulator, quantum_cannot_control_classical, { @@ -865,12 +865,12 @@ TEST_EACH_WORD_SIZE_W(TableauSimulator, peek_bloch, { TEST_EACH_WORD_SIZE_W(TableauSimulator, paulis, { auto rng = INDEPENDENT_TEST_RNG(); - TableauSimulator sim1(INDEPENDENT_TEST_RNG(), 500); - TableauSimulator sim2(INDEPENDENT_TEST_RNG(), 500); - sim1.inv_state = Tableau::random(500, rng); + TableauSimulator sim1(INDEPENDENT_TEST_RNG(), 300); + TableauSimulator sim2(INDEPENDENT_TEST_RNG(), 300); + sim1.inv_state = Tableau::random(300, rng); sim2.inv_state = sim1.inv_state; - sim1.paulis(PauliString(500)); + sim1.paulis(PauliString(300)); ASSERT_EQ(sim1.inv_state, sim2.inv_state); sim1.paulis(PauliString(5)); ASSERT_EQ(sim1.inv_state, sim2.inv_state); diff --git a/src/stim/stabilizers/conversions.inl b/src/stim/stabilizers/conversions.inl index 24f5b3637..44e5ddd1a 100644 --- a/src/stim/stabilizers/conversions.inl +++ b/src/stim/stabilizers/conversions.inl @@ -558,10 +558,9 @@ Tableau stabilizers_to_tableau( num_qubits = std::max(num_qubits, e.num_qubits); } - simd_bit_table buf_xs(num_qubits, stabilizers.size()); - simd_bit_table buf_zs(num_qubits, stabilizers.size()); + simd_bit_table buf_xs(stabilizers.size(), num_qubits); + simd_bit_table buf_zs(stabilizers.size(), num_qubits); simd_bits buf_signs(stabilizers.size()); - simd_bits buf_workspace(stabilizers.size()); for (size_t k = 0; k < stabilizers.size(); k++) { memcpy(buf_xs[k].u8, stabilizers[k].xs.u8, stabilizers[k].xs.num_u8_padded()); memcpy(buf_zs[k].u8, stabilizers[k].zs.u8, stabilizers[k].zs.num_u8_padded()); @@ -570,19 +569,6 @@ Tableau stabilizers_to_tableau( buf_xs = buf_xs.transposed(); buf_zs = buf_zs.transposed(); - for (size_t k1 = 0; k1 < stabilizers.size(); k1++) { - for (size_t k2 = k1 + 1; k2 < stabilizers.size(); k2++) { - if (!stabilizers[k1].ref().commutes(stabilizers[k2])) { - std::stringstream ss; - ss << "Some of the given stabilizers anticommute.\n"; - ss << "For example:\n "; - ss << stabilizers[k1]; - ss << "\nanticommutes with\n"; - ss << stabilizers[k2] << "\n"; - throw std::invalid_argument(ss.str()); - } - } - } Circuit elimination_instructions; size_t used = 0; @@ -597,11 +583,6 @@ Tableau stabilizers_to_tableau( // Check for incompatible / redundant stabilizers. if (pivot == num_qubits) { - for (size_t q = 0; q < num_qubits; q++) { - if (buf_xs[q][k]) { - throw std::invalid_argument("Some of the given stabilizers anticommute."); - } - } if (buf_signs[k]) { throw std::invalid_argument("Some of the given stabilizers contradict each other."); } @@ -620,19 +601,21 @@ Tableau stabilizers_to_tableau( CircuitInstruction instruction{g, {}, &t}; elimination_instructions.safe_append(instruction); size_t q = pivot; + simd_bits_range_ref xs1 = buf_xs[q]; + simd_bits_range_ref zs1 = buf_zs[q]; + simd_bits_range_ref ss = buf_signs; switch (g) { case GateType::H_YZ: - buf_xs[q] ^= buf_zs[q]; - buf_workspace = buf_zs[q]; - buf_workspace.invert_bits(); - buf_workspace &= buf_xs[q]; - buf_signs ^= buf_workspace; + ss.for_each_word(xs1, zs1, [](auto &s, auto &x, auto &z) { + x ^= z; + s ^= z.andnot(x); + }); break; case GateType::H: - buf_xs[q].swap_with(buf_zs[q]); - buf_workspace = buf_zs[q]; - buf_workspace &= buf_xs[q]; - buf_signs ^= buf_workspace; + ss.for_each_word(xs1, zs1, [](auto &s, auto &x, auto &z) { + std::swap(x, z); + s ^= x & z; + }); break; default: throw std::invalid_argument("Unrecognized gate type."); @@ -649,47 +632,34 @@ Tableau stabilizers_to_tableau( elimination_instructions.safe_append(instruction); size_t q1 = targets[0].qubit_value(); size_t q2 = targets[1].qubit_value(); - simd_bits_range_ref x1 = buf_xs[q1]; - simd_bits_range_ref z1 = buf_zs[q1]; - simd_bits_range_ref x2 = buf_xs[q2]; - simd_bits_range_ref z2 = buf_zs[q2]; + simd_bits_range_ref ss = buf_signs; + simd_bits_range_ref xs1 = buf_xs[q1]; + simd_bits_range_ref zs1 = buf_zs[q1]; + simd_bits_range_ref xs2 = buf_xs[q2]; + simd_bits_range_ref zs2 = buf_zs[q2]; switch (g) { case GateType::XCX: - buf_workspace = x1; - buf_workspace ^= x2; - buf_workspace &= z1; - buf_workspace &= z2; - buf_signs ^= buf_workspace; - x1 ^= z2; - x2 ^= z1; + ss.for_each_word(xs1, zs1, xs2, zs2, [](auto &s, auto &x1, auto &z1, auto &x2, auto &z2) { + s ^= (x1 ^ x2) & z1 & z2; + x1 ^= z2; + x2 ^= z1; + }); break; case GateType::XCY: - x1 ^= x2; - x1 ^= z2; - x2 ^= z1; - z2 ^= z1; - buf_workspace = x1; - buf_workspace |= x2; - buf_workspace.invert_bits(); - buf_workspace &= z1; - buf_workspace &= z2; - buf_signs ^= buf_workspace; - buf_workspace = z2; - buf_workspace.invert_bits(); - buf_workspace &= z1; - buf_workspace &= x1; - buf_workspace &= x2; - buf_signs ^= buf_workspace; + ss.for_each_word(xs1, zs1, xs2, zs2, [](auto &s, auto &x1, auto &z1, auto &x2, auto &z2) { + x1 ^= x2 ^ z2; + x2 ^= z1; + z2 ^= z1; + s ^= x1.andnot(z1) & x2.andnot(z2); + s ^= x1 & z1 & z2.andnot(x2); + }); break; case GateType::XCZ: - z2 ^= z1; - x1 ^= x2; - buf_workspace = z2; - buf_workspace ^= x1; - buf_workspace.invert_bits(); - buf_workspace &= x2; - buf_workspace &= z1; - buf_signs ^= buf_workspace; + ss.for_each_word(xs1, zs1, xs2, zs2, [](auto &s, auto &x1, auto &z1, auto &x2, auto &z2) { + z2 ^= z1; + x1 ^= x2; + s ^= (z2 ^ x1).andnot(z1 & x2); + }); break; default: throw std::invalid_argument("Unrecognized gate type."); @@ -717,6 +687,26 @@ Tableau stabilizers_to_tableau( used++; } + // All stabilizers will have been mapped into Z products, if they commuted. + for (size_t q = 0; q < num_qubits; q++) { + if (buf_xs[q].not_zero()) { + for (size_t k1 = 0; k1 < stabilizers.size(); k1++) { + for (size_t k2 = k1 + 1; k2 < stabilizers.size(); k2++) { + if (!stabilizers[k1].ref().commutes(stabilizers[k2])) { + std::stringstream ss; + ss << "Some of the given stabilizers anticommute.\n"; + ss << "For example:\n "; + ss << stabilizers[k1]; + ss << "\nanticommutes with\n"; + ss << stabilizers[k2] << "\n"; + throw std::invalid_argument(ss.str()); + } + } + } + throw std::invalid_argument("The given stabilizers commute but the solver failed in a way that suggests they anticommute. Please report this as a bug."); + } + } + if (used < num_qubits) { if (!allow_underconstrained) { throw std::invalid_argument( diff --git a/src/stim/stabilizers/conversions.test.cc b/src/stim/stabilizers/conversions.test.cc index 7f213e585..633d23adf 100644 --- a/src/stim/stabilizers/conversions.test.cc +++ b/src/stim/stabilizers/conversions.test.cc @@ -217,7 +217,7 @@ TEST_EACH_WORD_SIZE_W(conversions, stabilizer_state_vector_to_circuit_basic, { TEST_EACH_WORD_SIZE_W(conversions, stabilizer_state_vector_to_circuit_fuzz_round_trip, { auto rng = INDEPENDENT_TEST_RNG(); for (const auto &little_endian : std::vector{false, true}) { - for (size_t n = 0; n < 10; n++) { + for (size_t n = 0; n < 5; n++) { // Pick a random stabilizer state. TableauSimulator sim(INDEPENDENT_TEST_RNG(), n); sim.inv_state = Tableau::random(n, rng); @@ -663,6 +663,19 @@ TEST_EACH_WORD_SIZE_W(conversions, stabilizer_to_tableau_detect_anticommutation, ASSERT_THROW({ stabilizers_to_tableau(input_stabilizers, false, false, false); }, std::invalid_argument); }) +TEST_EACH_WORD_SIZE_W(conversions, stabilizer_to_tableau_size_affecting_redundancy, { + std::vector> input_stabilizers; + input_stabilizers.push_back(PauliString::from_str("X_")); + input_stabilizers.push_back(PauliString::from_str("_X")); + for (size_t k = 0; k < 150; k++) { + input_stabilizers.push_back(PauliString::from_str("__")); + } + auto t = stabilizers_to_tableau(input_stabilizers, true, true, false); + ASSERT_EQ(t.num_qubits, 2); + ASSERT_EQ(t.zs[0], PauliString::from_str("X_")); + ASSERT_EQ(t.zs[1], PauliString::from_str("_X")); +}) + TEST(conversions, independent_to_disjoint_xyz_errors) { double out_x; double out_y; diff --git a/src/stim/stabilizers/tableau.test.cc b/src/stim/stabilizers/tableau.test.cc index 416cc47fd..b9193ba32 100644 --- a/src/stim/stabilizers/tableau.test.cc +++ b/src/stim/stabilizers/tableau.test.cc @@ -722,15 +722,15 @@ TEST_EACH_WORD_SIZE_W(tableau, expand_pad_equals, { TEST_EACH_WORD_SIZE_W(tableau, transposed_access, { auto rng = INDEPENDENT_TEST_RNG(); - size_t n = 1000; + size_t n = W > 256 ? 1000 : 400; Tableau t(n); auto m = t.xs.xt.data.num_bits_padded(); t.xs.xt.data.randomize(m, rng); t.xs.zt.data.randomize(m, rng); t.zs.xt.data.randomize(m, rng); t.zs.zt.data.randomize(m, rng); - for (size_t inp_qubit = 0; inp_qubit < 1000; inp_qubit += 99) { - for (size_t out_qubit = 0; out_qubit < 1000; out_qubit += 99) { + for (size_t inp_qubit = 0; inp_qubit < n; inp_qubit += 99) { + for (size_t out_qubit = 0; out_qubit < n; out_qubit += 99) { bool bxx = t.xs.xt[inp_qubit][out_qubit]; bool bxz = t.xs.zt[inp_qubit][out_qubit]; bool bzx = t.zs.xt[inp_qubit][out_qubit]; @@ -885,26 +885,26 @@ TEST_EACH_WORD_SIZE_W(tableau, transposed_xz_input, { TEST_EACH_WORD_SIZE_W(tableau, direct_sum, { auto rng = INDEPENDENT_TEST_RNG(); - auto t1 = Tableau::random(260, rng); - auto t2 = Tableau::random(270, rng); + auto t1 = Tableau::random(160, rng); + auto t2 = Tableau::random(170, rng); auto t3 = t1; t3 += t2; ASSERT_EQ(t3, t1 + t2); PauliString p1 = t1.xs[5]; - p1.ensure_num_qubits(260 + 270, 1.0); + p1.ensure_num_qubits(160 + 170, 1.0); ASSERT_EQ(t3.xs[5], p1); std::string p2 = t2.xs[6].str(); - std::string p3 = t3.xs[266].str(); + std::string p3 = t3.xs[166].str(); ASSERT_EQ(p2[0], p3[0]); p2 = p2.substr(1); p3 = p3.substr(1); - for (size_t k = 0; k < 260; k++) { + for (size_t k = 0; k < 160; k++) { ASSERT_EQ(p3[k], '_'); } - for (size_t k = 0; k < 270; k++) { - ASSERT_EQ(p3[260 + k], p2[k]); + for (size_t k = 0; k < 170; k++) { + ASSERT_EQ(p3[160 + k], p2[k]); } }) diff --git a/src/stim/stabilizers/tableau_iter.test.cc b/src/stim/stabilizers/tableau_iter.test.cc index 60f583250..731d2c00d 100644 --- a/src/stim/stabilizers/tableau_iter.test.cc +++ b/src/stim/stabilizers/tableau_iter.test.cc @@ -135,7 +135,7 @@ TEST_EACH_WORD_SIZE_W(tableau_iter, iter_tableau, { ASSERT_EQ(n1, 6); ASSERT_EQ(s1, 24); ASSERT_EQ(n2, 720); - ASSERT_EQ(n3, 1451520); + // ASSERT_EQ(n3, 1451520); // Note: disabled because it takes 2-3 seconds. }) TEST_EACH_WORD_SIZE_W(tableau_iter, iter_tableau_distinct, {