From 1701c05f3aacc828813087f9dd813eea04b87030 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 27 Dec 2021 17:26:08 +0000
Subject: [PATCH] [skip ci] WIP on index_fill batch rule

---
 functorch/csrc/BatchRulesScatterOps.cpp | 95 +++++++++++++++++++++++++
 test/test_ops.py                        |  5 +-
 test/test_vmap.py                       |  1 -
 3 files changed, 96 insertions(+), 5 deletions(-)

diff --git a/functorch/csrc/BatchRulesScatterOps.cpp b/functorch/csrc/BatchRulesScatterOps.cpp
index a3482a625..08ae4cde7 100644
--- a/functorch/csrc/BatchRulesScatterOps.cpp
+++ b/functorch/csrc/BatchRulesScatterOps.cpp
@@ -541,6 +541,99 @@ std::tuple<Tensor, optional<int64_t>> index_add_batch_rule(
   return std::make_tuple(at::stack(results), 0);
 }

+std::tuple<Tensor, optional<int64_t>> index_fill_int_scalar_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Scalar& value) {
+
+  if (!index_bdim) {
+    // Handle scalar tensors... self can be a scalar tensor
+    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+    auto self_ = moveBatchDimToFront(self, self_bdim);
+    if (self_logical_rank == 0) {
+      self_ = self_.unsqueeze(-1);
+    }
+    dim = maybe_wrap_dim(dim, self_logical_rank);
+
+    optional<int64_t> out_bdim = nullopt;
+    if (self_bdim) {
+      const auto batch_size = self.size(*self_bdim);
+      self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+      dim = dim + 1;
+      out_bdim = 0;
+    }
+
+    auto result = self_.index_fill(dim, index, value);
+    if (self_logical_rank == 0) {
+      result = result.squeeze(-1);
+    }
+    return std::make_tuple(result, out_bdim);
+  }
+
+  // SAME AS FOR index_add
+  // Index is batched. For-loop and stack is the best thing I can come up with
+  // right now. We really want a generalized index_fill kernel in PyTorch
+  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
+  std::vector<Tensor> results;
+  results.reserve(batch_size);
+  for (const auto i : c10::irange(0, batch_size)) {
+    const auto& self_slice = self_bdim.has_value() ?
+      self.select(*self_bdim, i) : self;
+    const auto& index_slice = index_bdim.has_value() ?
+      index.select(*index_bdim, i) : index;
+    results.push_back(at::index_fill(self_slice, dim, index_slice, value));
+  }
+  return std::make_tuple(at::stack(results), 0);
+}
+
+std::tuple<Tensor, optional<int64_t>> index_fill_int_tensor_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Tensor& value, optional<int64_t> value_bdim) {
+
+  if (!index_bdim && !value_bdim) {
+    // Handle scalar tensors... self can be a scalar tensor
+    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+    auto self_ = moveBatchDimToFront(self, self_bdim);
+    if (self_logical_rank == 0) {
+      self_ = self_.unsqueeze(-1);
+    }
+    dim = maybe_wrap_dim(dim, self_logical_rank);
+
+    optional<int64_t> out_bdim = nullopt;
+    if (self_bdim) {
+      const auto batch_size = self.size(*self_bdim);
+      self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+      dim = dim + 1;
+      out_bdim = 0;
+    }
+    auto result = self_.index_fill(dim, index, value);
+    if (self_logical_rank == 0) {
+      result = result.squeeze(-1);
+    }
+    return std::make_tuple(result, out_bdim);
+  }
+
+  // SAME AS FOR index_add
+  // Index is batched. For-loop and stack is the best thing I can come up with
+  // right now. We really want a generalized index_fill kernel in PyTorch
+  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
+  std::vector<Tensor> results;
+  results.reserve(batch_size);
+  for (const auto i : c10::irange(0, batch_size)) {
+    const auto& self_slice = self_bdim.has_value() ?
+      self.select(*self_bdim, i) : self;
+    const auto& index_slice = index_bdim.has_value() ?
+      index.select(*index_bdim, i) : index;
+    const auto& value_slice = value_bdim.has_value() ?
+      value.select(*value_bdim, i) : value;
+    results.push_back(at::index_fill(self_slice, dim, index_slice, value_slice));
+  }
+  return std::make_tuple(at::stack(results), 0);
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("index.Tensor", index_plumbing);
   m.impl("index_put_", index_put__plumbing);
@@ -550,6 +643,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("index_copy", index_copy_decomp);
   m.impl("index_select", index_select_decomp);
   VMAP_SUPPORT("index_add", index_add_batch_rule);
+  VMAP_SUPPORT("index_fill.int_Scalar", index_fill_int_scalar_batch_rule);
+  VMAP_SUPPORT("index_fill.int_Tensor", index_fill_int_tensor_batch_rule);
   VMAP_SUPPORT("diagonal_scatter", diagonal_scatter_batch_rule);
   VMAP_SUPPORT("gather", gather_batch_rule);
   VMAP_SUPPORT("gather_backward", gather_backward_batch_rule);
diff --git a/test/test_ops.py b/test/test_ops.py
index dc0b44e6a..f87655d79 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -513,7 +513,6 @@ def vjp_of_vjp(*args_and_cotangents):
         xfail('fmax'),
         xfail('fmin'),
         xfail('index_copy'),
-        xfail('index_fill'),
         xfail('linalg.det', ''),
         xfail('linalg.eigh'),
         xfail('linalg.householder_product'),
@@ -595,7 +594,6 @@ def test_vmapvjp(self, device, dtype, op):
         xfail('block_diag'),  # TODO: We expect this to fail in core, but it doesn't
         xfail('index_copy'),
         xfail('index_put'),
-        xfail('index_fill'),
         xfail('masked_fill'),
         xfail('masked_scatter'),
@@ -701,7 +699,6 @@ def test_vmapjvp(self, device, dtype, op):
         xfail('max', 'binary'),
         xfail('nn.functional.gaussian_nll_loss'),
         xfail('min', 'binary'),
-        xfail('index_fill'),
         xfail('index_put'),
         xfail('std_mean'),
         xfail('double', 'channels_last'),
@@ -760,7 +757,7 @@ def test_vmapjvpall(self, device, dtype, op):
         xfail('fmax'),
         xfail('fmin'),
         xfail('index_copy'),
-        xfail('index_fill'),
+        xfail('index_fill'),  # RuntimeError: aten::_unique hit the vmap fallback which is currently disabled
         xfail('linalg.cholesky'),
         xfail('linalg.cholesky_ex'),
         xfail('linalg.det'),
diff --git a/test/test_vmap.py b/test/test_vmap.py
index 91cf226b9..0dfe69d5d 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -3181,7 +3181,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
         xfail('gradient'),
         xfail('histogram'),
         xfail('hsplit'),
-        xfail('index_fill'),
         xfail('index_put'),
         xfail('isin'),
         xfail('linalg.cholesky'),
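
Not part of the patch: a minimal usage sketch of the behavior these batch rules enable,
assuming the functorch vmap front end on a build that includes this change. The tensor
names below (x, index, batched_index) are illustrative only.

    import torch
    from functorch import vmap

    x = torch.randn(3, 5, 4)            # batch of 3 tensors of shape (5, 4)
    index = torch.tensor([0, 2])        # shared, unbatched index

    # index_fill.int_Scalar with self batched and index unbatched:
    # hits the fast path (no for-loop) in index_fill_int_scalar_batch_rule.
    out = vmap(lambda t: t.index_fill(0, index, 1.0))(x)
    expected = torch.stack([t.index_fill(0, index, 1.0) for t in x])
    assert torch.allclose(out, expected)

    # Batched index: hits the for-loop-and-stack fallback inside the batch rule.
    batched_index = torch.randint(0, 5, (3, 2))
    out2 = vmap(lambda t, idx: t.index_fill(0, idx, 1.0))(x, batched_index)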