From 9b787864acd905fa6b3d3190d242cc60cb867a75 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 14 Dec 2021 23:58:46 -0600 Subject: [PATCH] Add an xfail for newer JAX versions that change sampler size behavior --- tests/link/test_jax.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/link/test_jax.py b/tests/link/test_jax.py index e8228f6217..6360071f26 100644 --- a/tests/link/test_jax.py +++ b/tests/link/test_jax.py @@ -1207,6 +1207,10 @@ def test_extra_ops_omni(): compare_jax_and_py(fgraph, []) +@pytest.mark.xfail( + version_parse(jax.__version__) >= version_parse("0.2.26"), + reason="JAX samplers require concrete/static shape values?", +) @pytest.mark.parametrize( "at_dist, dist_params, rng, size", [