From 892d5c762277e42e3eb6bd9da40a48cdc02f1598 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 8 Jun 2025 02:52:52 +0200 Subject: [PATCH 1/2] Support marginalising through a MinibatchRandomVariable --- pymc_extras/model/marginal/graph_analysis.py | 4 ++++ tests/model/marginal/test_graph_analysis.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/pymc_extras/model/marginal/graph_analysis.py b/pymc_extras/model/marginal/graph_analysis.py index 422177dd4..e619159c3 100644 --- a/pymc_extras/model/marginal/graph_analysis.py +++ b/pymc_extras/model/marginal/graph_analysis.py @@ -5,6 +5,7 @@ from pymc import SymbolicRandomVariable from pymc.model.fgraph import ModelVar +from pymc.variational.minibatch_rv import MinibatchRandomVariable from pytensor.graph import Variable, ancestors from pytensor.graph.basic import io_toposort from pytensor.tensor import TensorType, TensorVariable @@ -140,6 +141,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) elif isinstance(node.op, ModelVar): var_dims[node.outputs[0]] = inputs_dims[0] + elif isinstance(node.op, MinibatchRandomVariable): + var_dims[node.outputs[0]] = inputs_dims[0] + elif isinstance(node.op, DimShuffle): [input_dims] = inputs_dims output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order) diff --git a/tests/model/marginal/test_graph_analysis.py b/tests/model/marginal/test_graph_analysis.py index 57affd0ee..e835a4d03 100644 --- a/tests/model/marginal/test_graph_analysis.py +++ b/tests/model/marginal/test_graph_analysis.py @@ -2,6 +2,7 @@ import pytest from pymc.distributions import CustomDist +from pymc.variational.minibatch_rv import create_minibatch_rv from pytensor.tensor.type_other import NoneTypeT from pymc_extras.model.marginal.graph_analysis import ( @@ -160,6 +161,13 @@ def test_random_variable(self): with pytest.raises(ValueError, match="Use of known dimensions"): subgraph_batch_dim_connection(inp, [invalid_out]) + def test_minibatched_random_variable(self): + inp = pt.tensor(shape=(4, 3, 2)) + out1 = pt.random.normal(loc=inp) + out2 = create_minibatch_rv(out1, total_size=(10, 10, 10)) + [dims1] = subgraph_batch_dim_connection(inp, [out2]) + assert dims1 == (0, 1, 2) + def test_symbolic_random_variable(self): inp = pt.tensor(shape=(4, 3, 2)) From e515214bed99a5550b35882ff368551a89a90e3a Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Mon, 9 Jun 2025 16:19:31 +0200 Subject: [PATCH 2/2] Moved to last case --- pymc_extras/model/marginal/graph_analysis.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc_extras/model/marginal/graph_analysis.py b/pymc_extras/model/marginal/graph_analysis.py index e619159c3..6a7a7f874 100644 --- a/pymc_extras/model/marginal/graph_analysis.py +++ b/pymc_extras/model/marginal/graph_analysis.py @@ -141,9 +141,6 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) elif isinstance(node.op, ModelVar): var_dims[node.outputs[0]] = inputs_dims[0] - elif isinstance(node.op, MinibatchRandomVariable): - var_dims[node.outputs[0]] = inputs_dims[0] - elif isinstance(node.op, DimShuffle): [input_dims] = inputs_dims output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order) @@ -317,6 +314,9 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) var_dims[node.outputs[0]] = output_dims + elif isinstance(node.op, MinibatchRandomVariable): + var_dims[node.outputs[0]] = inputs_dims[0] + else: raise NotImplementedError(f"Marginalization through operation {node} not supported.")