Skip to content

Commit 489b751

Browse files
lukmazThe Meridian Authors
authored and
The Meridian Authors
committed
test_rhat
PiperOrigin-RevId: 738878610
1 parent 82c00c6 commit 489b751

File tree

1 file changed

+35
-1
lines changed

1 file changed

+35
-1
lines changed

meridian/model/posterior_sampler.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import arviz as az
2121
from meridian import constants
2222
import numpy as np
23+
import pandas as pd
2324
import tensorflow as tf
2425
import tensorflow_probability as tfp
2526

@@ -463,7 +464,8 @@ def __call__(
463464

464465
states = []
465466
traces = []
466-
for n_chains_batch in n_chains_list:
467+
print("\n\n. NEW SEED VERSION\n\n")
468+
for num, n_chains_batch in enumerate(n_chains_list):
467469
try:
468470
mcmc = _xla_windowed_adaptive_nuts(
469471
n_draws=n_burnin + n_keep,
@@ -480,6 +482,7 @@ def __call__(
480482
seed=seed,
481483
**pins,
482484
)
485+
seed = [x + 1 for x in (seed or [0, 0])]
483486
except tf.errors.ResourceExhaustedError as error:
484487
raise MCMCOOMError(
485488
"ERROR: Out of memory. Try reducing `n_keep` or pass a list of"
@@ -488,6 +491,37 @@ def __call__(
488491
) from error
489492
states.append(mcmc.all_states._asdict())
490493
traces.append(mcmc.trace)
494+
# TODO
495+
print(f"\n\n\n BATCH {num}")
496+
tmp_state = mcmc.all_states._asdict()
497+
tmp_mcmc_states = {
498+
k: tf.einsum(
499+
"ij...->ji...",
500+
tmp_state[k][n_burnin:, ...],
501+
)
502+
for k in tmp_state.keys()
503+
if k not in constants.UNSAVED_PARAMETERS
504+
}
505+
rhat = pd.DataFrame()
506+
vis_mcmc_states = {k: v.values for k, v in tmp_mcmc_states.items()}
507+
for k, v in tfp.mcmc.potential_scale_reduction(
508+
{k: tf.einsum("ij...->ji...", v) for k, v in vis_mcmc_states.items()}
509+
).items():
510+
rhat_temp = v.numpy().flatten()
511+
rhat = pd.concat([
512+
rhat,
513+
pd.DataFrame({
514+
constants.PARAMETER: np.repeat(k, len(rhat_temp)),
515+
constants.RHAT: rhat_temp,
516+
}),
517+
])
518+
519+
# If the MCMC sampling fails, the r-hat value calculated will be very large.
520+
if (rhat[constants.RHAT] > 1e10).any():
521+
max_rhat = max(rhat[constants.RHAT])
522+
print(f" MAX RHAT: {max_rhat}")
523+
print(f" ALL RHATS: {rhat[constants.RHAT]}")
524+
# TODO
491525

492526
mcmc_states = {
493527
k: tf.einsum(

0 commit comments

Comments
 (0)