20
20
import arviz as az
21
21
from meridian import constants
22
22
import numpy as np
23
+ import pandas as pd
23
24
import tensorflow as tf
24
25
import tensorflow_probability as tfp
25
26
@@ -463,7 +464,8 @@ def __call__(
463
464
464
465
states = []
465
466
traces = []
466
- for n_chains_batch in n_chains_list :
467
+ print ("\n \n . NEW SEED VERSION\n \n " )
468
+ for num , n_chains_batch in enumerate (n_chains_list ):
467
469
try :
468
470
mcmc = _xla_windowed_adaptive_nuts (
469
471
n_draws = n_burnin + n_keep ,
@@ -480,6 +482,7 @@ def __call__(
480
482
seed = seed ,
481
483
** pins ,
482
484
)
485
+ seed = [x + 1 for x in (seed or [0 , 0 ])]
483
486
except tf .errors .ResourceExhaustedError as error :
484
487
raise MCMCOOMError (
485
488
"ERROR: Out of memory. Try reducing `n_keep` or pass a list of"
@@ -488,6 +491,37 @@ def __call__(
488
491
) from error
489
492
states .append (mcmc .all_states ._asdict ())
490
493
traces .append (mcmc .trace )
494
+ # TODO
495
+ print (f"\n \n \n BATCH { num } " )
496
+ tmp_state = mcmc .all_states ._asdict ()
497
+ tmp_mcmc_states = {
498
+ k : tf .einsum (
499
+ "ij...->ji..." ,
500
+ tmp_state [k ][n_burnin :, ...],
501
+ )
502
+ for k in tmp_state .keys ()
503
+ if k not in constants .UNSAVED_PARAMETERS
504
+ }
505
+ rhat = pd .DataFrame ()
506
+ vis_mcmc_states = {k : v .values for k , v in tmp_mcmc_states .items ()}
507
+ for k , v in tfp .mcmc .potential_scale_reduction (
508
+ {k : tf .einsum ("ij...->ji..." , v ) for k , v in vis_mcmc_states .items ()}
509
+ ).items ():
510
+ rhat_temp = v .numpy ().flatten ()
511
+ rhat = pd .concat ([
512
+ rhat ,
513
+ pd .DataFrame ({
514
+ constants .PARAMETER : np .repeat (k , len (rhat_temp )),
515
+ constants .RHAT : rhat_temp ,
516
+ }),
517
+ ])
518
+
519
+ # If the MCMC sampling fails, the r-hat value calculated will be very large.
520
+ if (rhat [constants .RHAT ] > 1e10 ).any ():
521
+ max_rhat = max (rhat [constants .RHAT ])
522
+ print (f" MAX RHAT: { max_rhat } " )
523
+ print (f" ALL RHATS: { rhat [constants .RHAT ]} " )
524
+ # TODO
491
525
492
526
mcmc_states = {
493
527
k : tf .einsum (
0 commit comments