From 33662d1dfd2d96514ff475dde7db7200028be30d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joshua=20N=C3=A4f?= Date: Tue, 23 Jan 2024 22:08:04 +0100 Subject: [PATCH 1/3] Pass linear operator to fantasy strategy constructor instead of tensor --- gpytorch/models/exact_prediction_strategies.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 2b716d73f..ed72ee0f9 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -211,8 +211,8 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_ # now update the root and root inverse new_lt = self.lik_train_train_covar.cat_rows(fant_train_covar, fant_fant_covar) - new_root = new_lt.root_decomposition().root.to_dense() - new_covar_cache = new_lt.root_inv_decomposition().root.to_dense() + new_root = new_lt.root_decomposition().root + new_covar_cache = new_lt.root_inv_decomposition().root # Expand inputs accordingly if necessary (for fantasies at the same points) if full_inputs[0].dim() <= full_targets.dim(): @@ -238,7 +238,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_ inv_root=new_covar_cache, ) add_to_cache(fant_strat, "mean_cache", fant_mean_cache) - add_to_cache(fant_strat, "covar_cache", new_covar_cache) + add_to_cache(fant_strat, "covar_cache", new_covar_cache.to_dense()) return fant_strat @property From e19a57ec6fb818e0a14f9054616ca8dcf1ebab0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joshua=20N=C3=A4f?= Date: Tue, 19 Mar 2024 11:52:51 +0100 Subject: [PATCH 2/3] Fix failing test DenseLinearOperator expects a torch.tensor so we convert the linear operator to dense. --- gpytorch/models/exact_prediction_strategies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index ed72ee0f9..6ec872636 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -222,7 +222,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_ full_inputs = [fi.expand(fant_batch_shape + fi.shape) for fi in full_inputs] full_mean = full_mean.expand(fant_batch_shape + full_mean.shape) full_covar = BatchRepeatLinearOperator(full_covar, repeat_shape) - new_root = BatchRepeatLinearOperator(DenseLinearOperator(new_root), repeat_shape) + new_root = BatchRepeatLinearOperator(DenseLinearOperator(new_root.to_dense()), repeat_shape) # no need to repeat the covar cache, broadcasting will do the right thing if isinstance(full_output, MultitaskMultivariateNormal): From 93d87cdae12e2cea1096a7040aa0298ec6011f75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joshua=20N=C3=A4f?= Date: Tue, 19 Mar 2024 12:02:35 +0100 Subject: [PATCH 3/3] Remove conversion of new_root to dense operator in get_fantasy_strategy Directly construct a BatchRepeatLinearOperator from new_root instead of converting it to a dense operator first. --- gpytorch/models/exact_prediction_strategies.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 6ec872636..2e95e2162 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -10,7 +10,6 @@ AddedDiagLinearOperator, BatchRepeatLinearOperator, ConstantMulLinearOperator, - DenseLinearOperator, InterpolatedLinearOperator, LinearOperator, LowRankRootAddedDiagLinearOperator, @@ -222,7 +221,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_ full_inputs = [fi.expand(fant_batch_shape + fi.shape) for fi in full_inputs] full_mean = full_mean.expand(fant_batch_shape + full_mean.shape) full_covar = BatchRepeatLinearOperator(full_covar, repeat_shape) - new_root = BatchRepeatLinearOperator(DenseLinearOperator(new_root.to_dense()), repeat_shape) + new_root = BatchRepeatLinearOperator(new_root, repeat_shape) # no need to repeat the covar cache, broadcasting will do the right thing if isinstance(full_output, MultitaskMultivariateNormal):