|
5 | 5 | import pytensor.tensor as pt
|
6 | 6 | from pymc import SymbolicRandomVariable
|
7 | 7 | from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
|
| 8 | +from pymc.distributions.transforms import Chain |
8 | 9 | from pymc.logprob.abstract import _get_measurable_outputs, _logprob
|
9 | 10 | from pymc.logprob.joint_logprob import factorized_joint_logprob
|
| 11 | +from pymc.logprob.transforms import IntervalTransform |
10 | 12 | from pymc.model import Model
|
11 | 13 | from pymc.pytensorf import constant_fold, inputvars
|
12 | 14 | from pytensor import Mode
|
@@ -171,11 +173,19 @@ def _marginalize(self, user_warnings=False):
|
171 | 173 | if old_rv in self.basic_RVs:
|
172 | 174 | self._transfer_rv_mappings(old_rv, new_rv)
|
173 | 175 | if user_warnings:
|
| 176 | + # Interval transforms for dependent variable won't work for non-constant bounds because |
| 177 | + # the RV inputs are now different and may depend on another RV that also depends on the |
| 178 | + # same marginalized RV |
174 | 179 | transform = self.rvs_to_transforms[new_rv]
|
175 |
| - if transform is not None: |
| 180 | + if isinstance(transform, IntervalTransform) or ( |
| 181 | + isinstance(transform, Chain) |
| 182 | + and any( |
| 183 | + isinstance(tr, IntervalTransform) for tr in transform.transform_list |
| 184 | + ) |
| 185 | + ): |
176 | 186 | warnings.warn(
|
177 |
| - "Transforms for variables that depend on marginalized RVs are currently not working, " |
178 |
| - f"rv={new_rv}, transform={transform}", |
| 187 | + f"The transform {transform} for the variable {old_rv}, which depends on the " |
| 188 | + f"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.", |
179 | 189 | UserWarning,
|
180 | 190 | )
|
181 | 191 | return self
|
|
0 commit comments