diff --git a/WORKSPACE b/WORKSPACE index 4a53785..0061802 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -65,3 +65,11 @@ http_archive( urls = ["https://github.com/bazelbuild/platforms/archive/refs/tags/0.0.6.zip"], strip_prefix = "platforms-0.0.6", ) + +http_archive( + name = "stim_py", + build_file = "//external:stim_py.BUILD", + sha256 = "95236006859d6754be99629d4fb44788e742e962ac8c59caad421ca088f7350e", + strip_prefix = "stim-1.15.0", + urls = ["https://github.com/quantumlib/Stim/releases/download/v1.15.0/stim-1.15.0.tar.gz"], +) diff --git a/external/stim_py.BUILD b/external/stim_py.BUILD new file mode 100644 index 0000000..7c14e00 --- /dev/null +++ b/external/stim_py.BUILD @@ -0,0 +1,64 @@ +load("@pybind11_bazel//:build_defs.bzl", "pybind_library") +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +SOURCE_FILES_NO_MAIN = glob( + [ + "src/**/*.cc", + "src/**/*.h", + "src/**/*.inl", + ], + exclude = glob([ + "src/**/*.test.cc", + "src/**/*.test.h", + "src/**/*.perf.cc", + "src/**/*.perf.h", + "src/**/*.pybind.cc", + "src/**/*.pybind.h", + "src/**/main.cc", + ]), +) + +PYBIND_MODULES = [ + "src/stim/py/march.pybind.cc", + "src/stim/py/stim.pybind.cc", +] + +PYBIND_FILES_WITHOUT_MODULES = glob( + [ + "src/**/*.pybind.cc", + "src/**/*.pybind.h", + ], + exclude=PYBIND_MODULES, +) + + + +pybind_library( + name = "stim_pybind_lib", + srcs = SOURCE_FILES_NO_MAIN + PYBIND_FILES_WITHOUT_MODULES, + copts = [ + "-O3", + "-std=c++20", + "-fvisibility=hidden", + "-march=native", + "-DVERSION_INFO=0.0.dev0", + ], + includes = ["src/"], + visibility = ["//visibility:public"], +) + +pybind_extension( + name = "stim", + srcs = PYBIND_MODULES, + copts = [ + "-O3", + "-std=c++20", + "-fvisibility=hidden", + "-march=native", + "-DSTIM_PYBIND11_MODULE_NAME=stim", + "-DVERSION_INFO=0.0.dev0", + ], + deps=[":stim_pybind_lib"], + includes = ["src/"], + visibility = ["//visibility:public"], +) diff --git a/src/BUILD b/src/BUILD index fea2db0..ecbc22a 100644 --- a/src/BUILD +++ b/src/BUILD @@ -13,7 +13,7 @@ # limitations under the License. # load("@benchmark//:benchmark.bzl", "cc_benchmark") -load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library") load("@rules_python//python:defs.bzl", "py_library") package(default_visibility = ["//visibility:public"]) @@ -64,22 +64,35 @@ cc_library( ) +pybind_library( + name = "tesseract_decoder_pybind", + srcs = [ + "common.pybind.h", + ], + deps = [ + ":libcommon", + "@stim_py//:stim_pybind_lib", + ], +) + pybind_extension( name = "tesseract_decoder", srcs = [ - "common.pybind.h", "tesseract.pybind.cc", ], deps = [ - ":libcommon", + ":tesseract_decoder_pybind", + "@stim_py//:stim", ], ) py_library( name="lib_tesseract_decoder", - data=[":tesseract_decoder"], imports=["src"], + deps=[ + ":tesseract_decoder", + ], ) diff --git a/src/common.pybind.h b/src/common.pybind.h index 852c823..791a9ed 100644 --- a/src/common.pybind.h +++ b/src/common.pybind.h @@ -1,41 +1,54 @@ #ifndef TESSERACT_COMMON_PY_H #define TESSERACT_COMMON_PY_H +#include + #include #include #include -#include +#include "src/stim/dem/dem_instruction.pybind.h" +#include "stim/dem/detector_error_model_target.pybind.h" #include "common.h" namespace py = pybind11; -void add_common_module(py::module &root) { - auto m = root.def_submodule("common", "classes commonly used by the decoder"); - - // TODO: add as_dem_instruction_targets - py::class_(m, "Symptom") - .def(py::init, common::ObservablesMask>(), - py::arg("detectors") = std::vector(), - py::arg("observables") = 0) - .def_readwrite("detectors", &common::Symptom::detectors) - .def_readwrite("observables", &common::Symptom::observables) - .def("__str__", &common::Symptom::str) - .def(py::self == py::self) - .def(py::self != py::self); - - // TODO: add constructor with stim::DemInstruction. - py::class_(m, "Error") - .def_readwrite("likelihood_cost", &common::Error::likelihood_cost) - .def_readwrite("probability", &common::Error::probability) - .def_readwrite("symptom", &common::Error::symptom) - .def("__str__", &common::Error::str) - .def(py::init<>()) - .def(py::init &, common::ObservablesMask, - std::vector &>()) - .def(py::init &, common::ObservablesMask, - std::vector &>()); +void add_common_module(py::module &root) +{ + auto m = root.def_submodule("common", "classes commonly used by the decoder"); + + py::class_(m, "Symptom") + .def(py::init, common::ObservablesMask>(), + py::arg("detectors") = std::vector(), + py::arg("observables") = 0) + .def_readwrite("detectors", &common::Symptom::detectors) + .def_readwrite("observables", &common::Symptom::observables) + .def("__str__", &common::Symptom::str) + .def(py::self == py::self) + .def(py::self != py::self) + .def("as_dem_instruction_targets", [](common::Symptom s) + { + std::vector ret; + for(auto & t : s.as_dem_instruction_targets()) ret.emplace_back(t); + return ret; }); + + py::class_(m, "Error") + .def_readwrite("likelihood_cost", &common::Error::likelihood_cost) + .def_readwrite("probability", &common::Error::probability) + .def_readwrite("symptom", &common::Error::symptom) + .def("__str__", &common::Error::str) + .def(py::init<>()) + .def(py::init &, common::ObservablesMask, + std::vector &>()) + .def(py::init &, common::ObservablesMask, + std::vector &>()) + .def(py::init([](stim_pybind::ExposedDemInstruction edi) + { return new common::Error(edi.as_dem_instruction()); })); + + m.def("merge_identical_errors", &common::merge_identical_errors); + m.def("remove_zero_probability_errors", &common::remove_zero_probability_errors); + m.def("dem_from_counts", &common::dem_from_counts); } #endif diff --git a/src/py/BUILD b/src/py/BUILD index 83b0f31..240ad9b 100644 --- a/src/py/BUILD +++ b/src/py/BUILD @@ -1,5 +1,16 @@ +load("@rules_python//python:py_test.bzl", "py_test") load("@rules_python//python:pip.bzl", "compile_pip_requirements") +py_test( + name = "common_test", + srcs = ["common_test.py"], + visibility = ["//:__subpackages__"], + deps = [ + "@pypi//pytest", + "//src:lib_tesseract_decoder", + ], +) + compile_pip_requirements( name = "requirements", src = "requirements.in", diff --git a/src/py/common_test.py b/src/py/common_test.py new file mode 100644 index 0000000..a39c3c2 --- /dev/null +++ b/src/py/common_test.py @@ -0,0 +1,51 @@ +import pytest +import stim + +from src import tesseract_decoder + + +def test_as_dem_instruction_targets(): + s = tesseract_decoder.common.Symptom([1, 2], 4324) + dits = s.as_dem_instruction_targets() + assert dits == [ + stim.DemTarget("D1"), + stim.DemTarget("D2"), + stim.DemTarget("L2"), + stim.DemTarget("L5"), + stim.DemTarget("L6"), + stim.DemTarget("L7"), + stim.DemTarget("L12"), + ] + + +def test_error_from_dem_instruction(): + di = stim.DemInstruction("error", [0.125], [stim.target_logical_observable_id(3)]) + error = tesseract_decoder.common.Error(di) + + assert str(error) == "Error{cost=1.945910, symptom=Symptom{}}" + + +def test_merge_identical_errors(): + dem = stim.DetectorErrorModel() + assert isinstance( + tesseract_decoder.common.merge_identical_errors(dem), stim.DetectorErrorModel + ) + + +def test_remove_zero_probability_errors(): + dem = stim.DetectorErrorModel() + assert isinstance( + tesseract_decoder.common.remove_zero_probability_errors(dem), + stim.DetectorErrorModel, + ) + + +def test_dem_from_counts(): + dem = stim.DetectorErrorModel() + assert isinstance( + tesseract_decoder.common.dem_from_counts(dem, [], 3), stim.DetectorErrorModel + ) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/src/py/requirements.in b/src/py/requirements.in index 0fa4ec2..031a56a 100644 --- a/src/py/requirements.in +++ b/src/py/requirements.in @@ -1 +1,2 @@ stim +pytest diff --git a/src/py/requirements_lock.txt b/src/py/requirements_lock.txt index 6278058..7f96c6e 100644 --- a/src/py/requirements_lock.txt +++ b/src/py/requirements_lock.txt @@ -4,6 +4,10 @@ # # bazel run //src/py:requirements.update # +iniconfig==2.1.0 \ + --hash=sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7 \ + --hash=sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760 + # via pytest numpy==2.2.6 \ --hash=sha256:038613e9fb8c72b0a41f025a7e4c3f0b7a1b5d768ece4796b674c8f3fe13efff \ --hash=sha256:0678000bb9ac1475cd454c6b8c799206af8107e310843532b04d49649c717a47 \ @@ -61,6 +65,22 @@ numpy==2.2.6 \ --hash=sha256:fe27749d33bb772c80dcd84ae7e8df2adc920ae8297400dabec45f0dedb3f6de \ --hash=sha256:fee4236c876c4e8369388054d02d0e9bb84821feb1a64dd59e137e6511a551f8 # via stim +packaging==25.0 \ + --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ + --hash=sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f + # via pytest +pluggy==1.6.0 \ + --hash=sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3 \ + --hash=sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746 + # via pytest +pygments==2.19.1 \ + --hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \ + --hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c + # via pytest +pytest==8.4.0 \ + --hash=sha256:14d920b48472ea0dbf68e45b96cd1ffda4705f33307dcc86c676c1b5104838a6 \ + --hash=sha256:f40f825768ad76c0977cbacdf1fd37c6f7a468e460ea6a0636078f8972d4517e + # via -r src/py/requirements.in stim==1.15.0 \ --hash=sha256:0bb3757c69c9b16fd24ff7400b5cddb22017c4cae84fc4b7b73f84373cb03c00 \ --hash=sha256:190c5a3c9cecdfae3302d02057d1ed6d9ce7910d2bcc2ff375807d8f8ec5494d \ diff --git a/src/tesseract.pybind.cc b/src/tesseract.pybind.cc index 33dcc62..e78f1c5 100644 --- a/src/tesseract.pybind.cc +++ b/src/tesseract.pybind.cc @@ -3,4 +3,8 @@ #include "common.pybind.h" #include "pybind11/detail/common.h" -PYBIND11_MODULE(tesseract_py, m) { add_common_module(m); } +PYBIND11_MODULE(tesseract_decoder, m) +{ + py::module::import("stim"); + add_common_module(m); +}