Skip to content

Commit 6575208

Browse files
committed
Update warning for transformed variables that depend on marginalized variable
Warning is now only issued for IntervalTransforms
1 parent 61a63b2 commit 6575208

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

pymc_experimental/marginal_model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
import pytensor.tensor as pt
66
from pymc import SymbolicRandomVariable
77
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
8+
from pymc.distributions.transforms import Chain
89
from pymc.logprob.abstract import _get_measurable_outputs, _logprob
910
from pymc.logprob.joint_logprob import factorized_joint_logprob
11+
from pymc.logprob.transforms import IntervalTransform
1012
from pymc.model import Model
1113
from pymc.pytensorf import constant_fold, inputvars
1214
from pytensor import Mode
@@ -171,11 +173,19 @@ def _marginalize(self, user_warnings=False):
171173
if old_rv in self.basic_RVs:
172174
self._transfer_rv_mappings(old_rv, new_rv)
173175
if user_warnings:
176+
# Interval transforms for dependent variable won't work for non-constant bounds because
177+
# the RV inputs are now different and may depend on another RV that also depends on the
178+
# same marginalized RV
174179
transform = self.rvs_to_transforms[new_rv]
175-
if transform is not None:
180+
if isinstance(transform, IntervalTransform) or (
181+
isinstance(transform, Chain)
182+
and any(
183+
isinstance(tr, IntervalTransform) for tr in transform.transform_list
184+
)
185+
):
176186
warnings.warn(
177-
"Transforms for variables that depend on marginalized RVs are currently not working, "
178-
f"rv={new_rv}, transform={transform}",
187+
f"The transform {transform} for the variable {old_rv}, which depends on the "
188+
f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.",
179189
UserWarning,
180190
)
181191
return self

pymc_experimental/tests/test_marginal_model.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -346,16 +346,19 @@ def test_not_supported_marginalized_deterministic_and_potential():
346346
"transform, expected_warning",
347347
(
348348
(None, does_not_warn()),
349-
pytest.param(
350-
UNSET,
349+
(UNSET, does_not_warn()),
350+
(transforms.log, does_not_warn()),
351+
(transforms.Chain([transforms.log, transforms.logodds]), does_not_warn()),
352+
(
353+
transforms.Interval(0, 1),
351354
pytest.warns(
352-
UserWarning, match="Transforms for variables that depend on marginalized RVs"
355+
UserWarning, match="which depends on the marginalized idx may no longer work"
353356
),
354357
),
355-
pytest.param(
356-
transforms.log,
358+
(
359+
transforms.Chain([transforms.log, transforms.Interval(0, 1)]),
357360
pytest.warns(
358-
UserWarning, match="Transforms for variables that depend on marginalized RVs"
361+
UserWarning, match="which depends on the marginalized idx may no longer work"
359362
),
360363
),
361364
),
@@ -398,5 +401,9 @@ def test_marginalized_transforms(transform, expected_warning):
398401

399402
ip = m.initial_point()
400403
if transform is not None:
401-
assert "sigma_log__" in ip
404+
if transform is UNSET:
405+
transform_name = "log"
406+
else:
407+
transform_name = transform.name
408+
assert f"sigma_{transform_name}__" in ip
402409
np.testing.assert_allclose(m.compile_logp()(ip), m_ref.compile_logp()(ip))

0 commit comments

Comments
 (0)