|
1 |
| -from typing import Any, Dict, List, Optional, Sequence, Union |
| 1 | +from typing import Any, List, Mapping, Optional, Sequence, Union |
2 | 2 |
|
3 | 3 | from pymc import Model
|
4 | 4 | from pymc.logprob.transforms import RVTransform
|
|
26 | 26 | from pymc_experimental.utils.pytensorf import rvs_in_graph
|
27 | 27 |
|
28 | 28 |
|
29 |
| -def observe(model: Model, vars_to_observations: Dict[Union["str", TensorVariable], Any]) -> Model: |
| 29 | +def observe( |
| 30 | + model: Model, vars_to_observations: Mapping[Union["str", TensorVariable], Any] |
| 31 | +) -> Model: |
30 | 32 | """Convert free RVs or Deterministics to observed RVs.
|
31 | 33 |
|
32 | 34 | Parameters
|
@@ -122,7 +124,9 @@ def replacement_fn(var, inner_replacements):
|
122 | 124 |
|
123 | 125 |
|
124 | 126 | def do(
|
125 |
| - model: Model, vars_to_interventions: Dict[Union["str", TensorVariable], Any], prune_vars=False |
| 127 | + model: Model, |
| 128 | + vars_to_interventions: Mapping[Union["str", TensorVariable], Any], |
| 129 | + prune_vars=False, |
126 | 130 | ) -> Model:
|
127 | 131 | """Replace model variables by intervention variables.
|
128 | 132 |
|
@@ -217,7 +221,7 @@ def do(
|
217 | 221 |
|
218 | 222 | def change_value_transforms(
|
219 | 223 | model: Model,
|
220 |
| - vars_to_transforms: Dict[ModelVariable, Union[RVTransform, None]], |
| 224 | + vars_to_transforms: Mapping[ModelVariable, Union[RVTransform, None]], |
221 | 225 | ) -> Model:
|
222 | 226 | """Change the value variables transforms in the model
|
223 | 227 |
|
|
0 commit comments