Skip to content

Commit 009b5ac

Browse files
authored
Support marginalising through a MinibatchRandomVariable
1 parent 6b2aa67 commit 009b5ac

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

pymc_extras/model/marginal/graph_analysis.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pymc import SymbolicRandomVariable
77
from pymc.model.fgraph import ModelVar
8+
from pymc.variational.minibatch_rv import MinibatchRandomVariable
89
from pytensor.graph import Variable, ancestors
910
from pytensor.graph.basic import io_toposort
1011
from pytensor.tensor import TensorType, TensorVariable
@@ -313,6 +314,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars)
313314

314315
var_dims[node.outputs[0]] = output_dims
315316

317+
elif isinstance(node.op, MinibatchRandomVariable):
318+
var_dims[node.outputs[0]] = inputs_dims[0]
319+
316320
else:
317321
raise NotImplementedError(f"Marginalization through operation {node} not supported.")
318322

tests/model/marginal/test_graph_analysis.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
from pymc.distributions import CustomDist
5+
from pymc.variational.minibatch_rv import create_minibatch_rv
56
from pytensor.tensor.type_other import NoneTypeT
67

78
from pymc_extras.model.marginal.graph_analysis import (
@@ -160,6 +161,13 @@ def test_random_variable(self):
160161
with pytest.raises(ValueError, match="Use of known dimensions"):
161162
subgraph_batch_dim_connection(inp, [invalid_out])
162163

164+
def test_minibatched_random_variable(self):
165+
inp = pt.tensor(shape=(4, 3, 2))
166+
out1 = pt.random.normal(loc=inp)
167+
out2 = create_minibatch_rv(out1, total_size=(10, 10, 10))
168+
[dims1] = subgraph_batch_dim_connection(inp, [out2])
169+
assert dims1 == (0, 1, 2)
170+
163171
def test_symbolic_random_variable(self):
164172
inp = pt.tensor(shape=(4, 3, 2))
165173

0 commit comments

Comments
 (0)