diff --git a/tests/link/test_jax.py b/tests/link/test_jax.py index e8228f6217..6360071f26 100644 --- a/tests/link/test_jax.py +++ b/tests/link/test_jax.py @@ -1207,6 +1207,10 @@ def test_extra_ops_omni(): compare_jax_and_py(fgraph, []) +@pytest.mark.xfail( + version_parse(jax.__version__) >= version_parse("0.2.26"), + reason="JAX samplers require concrete/static shape values?", +) @pytest.mark.parametrize( "at_dist, dist_params, rng, size", [