
Model marginalization doesn't work through Minibatch nodes #492


Closed
zaxtax opened this issue May 30, 2025 · 1 comment · Fixed by #498

Comments

@zaxtax
Contributor

zaxtax commented May 30, 2025

When I use Minibatch in the following model:

import pymc as pm
import numpy as np
from pymc_extras.model.marginal.marginal_model import marginalize

data = np.random.normal(size=10_000)

with pm.Model() as model:
    d = pm.Data("data", data)
    batched_data = pm.Minibatch(d, batch_size=100)
    x = pm.Normal("x", 0., 1.)
    b = pm.Bernoulli("b", 0.5, shape=(100,))
    y = pm.Normal("y", b*x, total_size=len(data), observed=batched_data)

model2 = marginalize(model, [b])

I get the following error:

NotImplementedError                       Traceback (most recent call last)
File ~/upstream/pymc-extras/pymc_extras/model/marginal/marginal_model.py:561, in replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, dependent_rvs, input_rvs)
    560 try:
--> 561     dependent_rvs_dim_connections = subgraph_batch_dim_connection(
    562         rv_to_marginalize, dependent_rvs
    563     )
    564 except (ValueError, NotImplementedError) as e:
    565     # For the perspective of the user this is a NotImplementedError

File ~/upstream/pymc-extras/pymc_extras/model/marginal/graph_analysis.py:365, in subgraph_batch_dim_connection(input_var, output_vars)
    364 var_dims = {input_var: tuple(range(input_var.type.ndim))}
--> 365 var_dims = _subgraph_batch_dim_connection(var_dims, [input_var], output_vars)
    366 ret = []

File ~/upstream/pymc-extras/pymc_extras/model/marginal/graph_analysis.py:317, in _subgraph_batch_dim_connection(var_dims, input_vars, output_vars)
    316     else:
--> 317         raise NotImplementedError(f"Marginalization through operation {node} not supported.")
    319 return var_dims

NotImplementedError: Marginalization through operation MinibatchRandomVariable(y, 10000) not supported.

The above exception was the direct cause of the following exception:

NotImplementedError                       Traceback (most recent call last)
Cell In[21], line 2
      1 from pymc_extras.model.marginal.marginal_model import marginalize
----> 2 model2 = marginalize(model, [b])

File ~/upstream/pymc-extras/pymc_extras/model/marginal/marginal_model.py:244, in marginalize(model, rvs_to_marginalize)
    237     other_direct_rv_ancestors = [
    238         rv
    239         for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
    240         if rv is not rv_to_marginalize
    241     ]
    242     input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))
--> 244     replace_finite_discrete_marginal_subgraph(fg, rv_to_marginalize, dependent_rvs, input_rvs)
    246 return model_from_fgraph(fg, mutate_fgraph=True)

File ~/upstream/pymc-extras/pymc_extras/model/marginal/marginal_model.py:566, in replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, dependent_rvs, input_rvs)
    561     dependent_rvs_dim_connections = subgraph_batch_dim_connection(
    562         rv_to_marginalize, dependent_rvs
    563     )
    564 except (ValueError, NotImplementedError) as e:
    565     # For the perspective of the user this is a NotImplementedError
--> 566     raise NotImplementedError(
    567         "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. "
    568         "You can try splitting the marginalized RV into separate components and marginalizing them separately."
    569     ) from e
    571 output_rvs = [rv_to_marginalize, *dependent_rvs]
    572 rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False)

NotImplementedError: The graph between the marginalized and dependent RVs cannot be marginalized efficiently. You can try splitting the marginalized RV into separate components and marginalizing them separately.
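For context on the op named in the traceback: passing `total_size` to `y` makes PyMC wrap it as `MinibatchRandomVariable(y, 10000)`, which passes the draws of `y` through unchanged and rescales its log-likelihood from the batch to the full dataset (a factor of 10_000 / 100 here). The NumPy sketch below illustrates only that rescaling. It ignores `b`, uses a hypothetical fixed value for `x`, and its batch-sampling scheme is illustrative, not necessarily what `pm.Minibatch` does internally.

```python
import numpy as np


def normal_logp(value, mu, sigma=1.0):
    """Elementwise log-density of Normal(mu, sigma)."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((value - mu) / sigma) ** 2


rng = np.random.default_rng(0)
data = rng.normal(size=10_000)
total_size = len(data)
batch_size = 100

# Illustrative sampling scheme: draw random row indices for one batch.
idx = rng.integers(0, total_size, size=batch_size)
batch = data[idx]

x = 0.3  # hypothetical fixed value of the mean parameter
batch_logp = normal_logp(batch, x).sum()

# total_size rescaling: the batch log-likelihood becomes an estimate
# of the full-data log-likelihood.
scaled_logp = batch_logp * total_size / batch_size
full_logp = normal_logp(data, x).sum()
print(scaled_logp, full_logp)
```

Because the wrapper only scales the log-likelihood, the shape and batch dimensions of `y` are the same before and after it.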
@ricardoV94
Member

ricardoV94 commented Jun 3, 2025

Yeah, it will need a special rule in subgraph_batch_dim_connection to understand how batch dims propagate through minibatch RVs. It should be an identity propagation: the relevant logic has already been applied to the real RV that feeds into the minibatch RV.
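The identity rule described above can be illustrated with a toy dimension-propagation pass. This is a minimal sketch, not the pymc-extras implementation: `Var`, `Node`, `propagate_batch_dims`, `broadcast_dims` and the op tags `"elemwise"`, `"rv"` and `"minibatch_rv"` are all hypothetical stand-ins for PyTensor variables, nodes and ops. Each axis of a downstream variable is labeled with the axis of the marginalized variable it depends on (or `None`). Elementwise ops and RVs merge labels under NumPy-style broadcasting, the minibatch wrapper copies its input's labels unchanged, and any other op raises, mirroring the error in the traceback.

```python
# Hypothetical, simplified model of batch-dimension propagation.
# Var/Node and the op tags are illustrative, not pymc-extras or PyTensor APIs.
from dataclasses import dataclass


@dataclass(eq=False)  # eq=False keeps identity hashing, so Vars work as dict keys
class Var:
    name: str
    ndim: int


@dataclass(eq=False)
class Node:
    op: str       # illustrative tag: "elemwise", "rv" or "minibatch_rv"
    inputs: list  # list of Var
    output: Var


def broadcast_dims(node, var_dims):
    """Right-align input axes (NumPy broadcasting) and merge their batch-dim labels."""
    out = [None] * node.output.ndim
    for inp in node.inputs:
        dims = var_dims.get(inp, (None,) * inp.ndim)
        padded = (None,) * (node.output.ndim - inp.ndim) + tuple(dims)
        for axis, d in enumerate(padded):
            if d is None:
                continue
            if out[axis] is not None and out[axis] != d:
                raise ValueError(f"Conflicting batch dims on axis {axis} of {node.output.name}")
            out[axis] = d
    return tuple(out)


def propagate_batch_dims(input_var, nodes):
    """Label each axis of every downstream variable with the input_var axis it maps to."""
    var_dims = {input_var: tuple(range(input_var.ndim))}
    for node in nodes:  # assumed to be in topological order
        if not any(inp in var_dims for inp in node.inputs):
            continue  # node does not depend on the marginalized variable
        if node.op in ("elemwise", "rv"):
            var_dims[node.output] = broadcast_dims(node, var_dims)
        elif node.op == "minibatch_rv":
            # Identity rule: the wrapper only rescales the logp of the RV it wraps,
            # so its output axes carry exactly the wrapped RV's batch dims.
            var_dims[node.output] = var_dims[node.inputs[0]]
        else:
            raise NotImplementedError(f"Marginalization through operation {node.op} not supported.")
    return var_dims


# Toy version of the model in the issue: mu = b * x, y ~ Normal(mu), then the minibatch wrapper.
b, x = Var("b", 1), Var("x", 0)
mu, y, y_mb = Var("mu", 1), Var("y", 1), Var("y_mb", 1)
nodes = [
    Node("elemwise", [b, x], mu),
    Node("rv", [mu], y),
    Node("minibatch_rv", [y], y_mb),
]
dims = propagate_batch_dims(b, nodes)
print(dims[y_mb])  # (0,)
```

With the `minibatch_rv` branch removed, the last node falls through to the `NotImplementedError`, which is the situation reported in this issue.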
