diff --git a/pymc_extras/model/marginal/graph_analysis.py b/pymc_extras/model/marginal/graph_analysis.py index 422177dd4..6a7a7f874 100644 --- a/pymc_extras/model/marginal/graph_analysis.py +++ b/pymc_extras/model/marginal/graph_analysis.py @@ -5,6 +5,7 @@ from pymc import SymbolicRandomVariable from pymc.model.fgraph import ModelVar +from pymc.variational.minibatch_rv import MinibatchRandomVariable from pytensor.graph import Variable, ancestors from pytensor.graph.basic import io_toposort from pytensor.tensor import TensorType, TensorVariable @@ -313,6 +314,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) var_dims[node.outputs[0]] = output_dims + elif isinstance(node.op, MinibatchRandomVariable): + var_dims[node.outputs[0]] = inputs_dims[0] + else: raise NotImplementedError(f"Marginalization through operation {node} not supported.") diff --git a/tests/model/marginal/test_graph_analysis.py b/tests/model/marginal/test_graph_analysis.py index 57affd0ee..e835a4d03 100644 --- a/tests/model/marginal/test_graph_analysis.py +++ b/tests/model/marginal/test_graph_analysis.py @@ -2,6 +2,7 @@ import pytest from pymc.distributions import CustomDist +from pymc.variational.minibatch_rv import create_minibatch_rv from pytensor.tensor.type_other import NoneTypeT from pymc_extras.model.marginal.graph_analysis import ( @@ -160,6 +161,13 @@ def test_random_variable(self): with pytest.raises(ValueError, match="Use of known dimensions"): subgraph_batch_dim_connection(inp, [invalid_out]) + def test_minibatched_random_variable(self): + inp = pt.tensor(shape=(4, 3, 2)) + out1 = pt.random.normal(loc=inp) + out2 = create_minibatch_rv(out1, total_size=(10, 10, 10)) + [dims1] = subgraph_batch_dim_connection(inp, [out2]) + assert dims1 == (0, 1, 2) + def test_symbolic_random_variable(self): inp = pt.tensor(shape=(4, 3, 2))