diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py index a0efaec0750..d59e3025dec 100644 --- a/pymc/backends/mcbackend.py +++ b/pymc/backends/mcbackend.py @@ -221,9 +221,13 @@ def make_runmeta_and_point_fn( (-1 if s is None else s) for s in (shape or []) ] + dt = np.dtype(dtype).name + # Object types will be pickled by the ChainRecordAdapter! + if dt == "object": + dt = "str" svar = mcb.Variable( name=sname, - dtype=np.dtype(dtype).name, + dtype=dt, shape=sshape, undefined_ndim=shape is None, ) diff --git a/tests/backends/test_mcbackend.py b/tests/backends/test_mcbackend.py index 2989030fa77..aa2e26ba01c 100644 --- a/tests/backends/test_mcbackend.py +++ b/tests/backends/test_mcbackend.py @@ -120,6 +120,19 @@ def test_make_runmeta_and_point_fn(simple_model): assert not vars["vector_interval__"].is_deterministic assert vars["matrix"].is_deterministic assert len(rmeta.sample_stats) == len(step.stats_dtypes[0]) + + with simple_model: + step = pm.NUTS() + rmeta, point_fn = make_runmeta_and_point_fn( + initial_point=simple_model.initial_point(), + step=step, + model=simple_model, + ) + assert isinstance(rmeta, mcb.RunMeta) + svars = {s.name: s for s in rmeta.sample_stats} + # Unbeknownst to McBackend, object stats are pickled to str + assert "sampler_0__warning" in svars + assert svars["sampler_0__warning"].dtype == "str" pass