Skip to content

Commit

Permalink
Merge pull request #2494 from naefjo/feature/online-learning-improvem…
Browse files Browse the repository at this point in the history
…ents

Bug: Exploit Structure in get_fantasy_strategy
  • Loading branch information
jacobrgardner authored Jun 20, 2024
2 parents 9551eba + e09674d commit 2e7959d
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions gpytorch/models/exact_prediction_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
AddedDiagLinearOperator,
BatchRepeatLinearOperator,
ConstantMulLinearOperator,
DenseLinearOperator,
InterpolatedLinearOperator,
LinearOperator,
LowRankRootAddedDiagLinearOperator,
Expand Down Expand Up @@ -211,8 +210,8 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_

# now update the root and root inverse
new_lt = self.lik_train_train_covar.cat_rows(fant_train_covar, fant_fant_covar)
new_root = new_lt.root_decomposition().root.to_dense()
new_covar_cache = new_lt.root_inv_decomposition().root.to_dense()
new_root = new_lt.root_decomposition().root
new_covar_cache = new_lt.root_inv_decomposition().root

# Expand inputs accordingly if necessary (for fantasies at the same points)
if full_inputs[0].dim() <= full_targets.dim():
Expand All @@ -222,7 +221,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
full_inputs = [fi.expand(fant_batch_shape + fi.shape) for fi in full_inputs]
full_mean = full_mean.expand(fant_batch_shape + full_mean.shape)
full_covar = BatchRepeatLinearOperator(full_covar, repeat_shape)
new_root = BatchRepeatLinearOperator(DenseLinearOperator(new_root), repeat_shape)
new_root = BatchRepeatLinearOperator(new_root, repeat_shape)
# no need to repeat the covar cache, broadcasting will do the right thing

if isinstance(full_output, MultitaskMultivariateNormal):
Expand All @@ -238,7 +237,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
inv_root=new_covar_cache,
)
add_to_cache(fant_strat, "mean_cache", fant_mean_cache)
add_to_cache(fant_strat, "covar_cache", new_covar_cache)
add_to_cache(fant_strat, "covar_cache", new_covar_cache.to_dense())
return fant_strat

@property
Expand Down

0 comments on commit 2e7959d

Please sign in to comment.