
Commit 1701c05

[skip ci] WIP on index_fill batch rule

1 parent: 1af1ae2
3 files changed, +96 -5 lines

functorch/csrc/BatchRulesScatterOps.cpp (+95)

@@ -541,6 +541,99 @@ std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
   return std::make_tuple(at::stack(results), 0);
 }
 
+std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Scalar& value) {
+
+  if (!index_bdim) {
+    // Handle scalar tensors... self, other can be scalar tensors
+    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+    auto self_ = moveBatchDimToFront(self, self_bdim);
+    if (self_logical_rank == 0) {
+      self_ = self_.unsqueeze(-1);
+    }
+    dim = maybe_wrap_dim(dim, self_logical_rank);
+
+    optional<int64_t> out_bdim = nullopt;
+    if (self_bdim) {
+      const auto batch_size = self.size(*self_bdim);
+      self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+      dim = dim + 1;
+      out_bdim = 0;
+    }
+
+    auto result = self_.index_fill(dim, index, value);
+    if (self_logical_rank == 0) {
+      result = result.squeeze(-1);
+    }
+    return std::make_tuple(result, out_bdim);
+  }
+
+  // SAME AS FOR index_add
+  // Index is batched. For-loop and stack is the best thing I can come up with
+  // right now. We really want a generalized index_fill kernel in PyTorch.
+  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
+  std::vector<Tensor> results;
+  results.reserve(batch_size);
+  for (const auto i : c10::irange(0, batch_size)) {
+    const auto& self_slice = self_bdim.has_value() ?
+      self.select(*self_bdim, i) : self;
+    const auto& index_slice = index_bdim.has_value() ?
+      index.select(*index_bdim, i) : index;
+    results.push_back(at::index_fill(self_slice, dim, index_slice, value));
+  }
+  return std::make_tuple(at::stack(results), 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Tensor& value, optional<int64_t> value_bdim) {
+
+  if (!index_bdim && !value_bdim) {
+    // Handle scalar tensors... self, other can be scalar tensors
+    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+    auto self_ = moveBatchDimToFront(self, self_bdim);
+    if (self_logical_rank == 0) {
+      self_ = self_.unsqueeze(-1);
+    }
+    dim = maybe_wrap_dim(dim, self_logical_rank);
+
+    optional<int64_t> out_bdim = nullopt;
+    if (self_bdim) {
+      const auto batch_size = self.size(*self_bdim);
+      self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+      dim = dim + 1;
+      out_bdim = 0;
+    }
+    auto result = self_.index_fill(dim, index, value);
+    if (self_logical_rank == 0) {
+      result = result.squeeze(-1);
+    }
+    return std::make_tuple(result, out_bdim);
+  }
+
+  // SAME AS FOR index_add
+  // Index is batched. For-loop and stack is the best thing I can come up with
+  // right now. We really want a generalized index_fill kernel in PyTorch.
+  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
+  std::vector<Tensor> results;
+  results.reserve(batch_size);
+  for (const auto i : c10::irange(0, batch_size)) {
+    const auto& self_slice = self_bdim.has_value() ?
+      self.select(*self_bdim, i) : self;
+    const auto& index_slice = index_bdim.has_value() ?
+      index.select(*index_bdim, i) : index;
+    const auto& value_slice = value_bdim.has_value() ?
+      value.select(*value_bdim, i) : value;
+    results.push_back(at::index_fill(self_slice, dim, index_slice, value_slice));
+  }
+  return std::make_tuple(at::stack(results), 0);
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("index.Tensor", index_plumbing);
   m.impl("index_put_", index_put__plumbing);
@@ -550,6 +643,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("index_copy", index_copy_decomp);
   m.impl("index_select", index_select_decomp);
   VMAP_SUPPORT("index_add", index_add_batch_rule);
+  VMAP_SUPPORT("index_fill.int_Scalar", index_fill_int_scalar_batch_rule);
+  VMAP_SUPPORT("index_fill.int_Tensor", index_fill_int_tensor_batch_rule);
   VMAP_SUPPORT("diagonal_scatter", diagonal_scatter_batch_rule);
   VMAP_SUPPORT("gather", gather_batch_rule);
   VMAP_SUPPORT("gather_backward", gather_backward_batch_rule);

test/test_ops.py (+1, -4)

@@ -513,7 +513,6 @@ def vjp_of_vjp(*args_and_cotangents):
     xfail('fmax'),
     xfail('fmin'),
     xfail('index_copy'),
-    xfail('index_fill'),
     xfail('linalg.det', ''),
     xfail('linalg.eigh'),
     xfail('linalg.householder_product'),
@@ -595,7 +594,6 @@ def test_vmapvjp(self, device, dtype, op):
     xfail('block_diag'),  # TODO: We expect this to fail in core, but it doesn't
     xfail('index_copy'),
     xfail('index_put'),
-    xfail('index_fill'),
     xfail('masked_fill'),
     xfail('masked_scatter'),
@@ -701,7 +699,6 @@ def test_vmapjvp(self, device, dtype, op):
     xfail('max', 'binary'),
     xfail('nn.functional.gaussian_nll_loss'),
     xfail('min', 'binary'),
-    xfail('index_fill'),
     xfail('index_put'),
     xfail('std_mean'),
     xfail('double', 'channels_last'),
@@ -760,7 +757,7 @@ def test_vmapjvpall(self, device, dtype, op):
     xfail('fmax'),
     xfail('fmin'),
     xfail('index_copy'),
-    xfail('index_fill'),
+    xfail('index_fill'),  # RuntimeError: aten::_unique hit the vmap fallback which is currently disabled
     xfail('linalg.cholesky'),
     xfail('linalg.cholesky_ex'),
     xfail('linalg.det'),
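
The xfail kept above records the one composition that still fails. A hedged repro sketch follows; vmap and jvp are real functorch transforms, but that this exact composition is what triggers aten::_unique is an assumption based on the xfail comment:

import torch
from functorch import vmap, jvp

def f(x):
    return x.index_fill(0, torch.tensor([0]), 1.0)

primals = torch.randn(3, 5)
tangents = torch.randn(3, 5)
# Expected, per the xfail comment, to raise:
# RuntimeError: aten::_unique hit the vmap fallback which is currently disabled
vmap(lambda p, t: jvp(f, (p,), (t,))[1])(primals, tangents)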

test/test_vmap.py (-1)

@@ -3181,7 +3181,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
     xfail('gradient'),
     xfail('histogram'),
     xfail('hsplit'),
-    xfail('index_fill'),
     xfail('index_put'),
     xfail('isin'),
     xfail('linalg.cholesky'),
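
With the xfail above removed, index_fill now runs through test_vmap_exhaustive. A minimal sketch (not from the commit) of the int_Tensor overload with a batched 0-dim value, which exercises the get_bdim_size3 loop-and-stack path added in this commit:

import torch
from functorch import vmap

B, N = 3, 5
x = torch.randn(B, N)
index = torch.tensor([1, 3])
values = torch.arange(B, dtype=torch.float)  # one 0-dim value per sample

out = vmap(lambda s, v: s.index_fill(0, index, v))(x, values)
ref = torch.stack([x[i].index_fill(0, index, values[i]) for i in range(B)])
assert torch.allclose(out, ref)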
