@@ -541,6 +541,99 @@ std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
   return std::make_tuple(at::stack(results), 0);
 }
 
+std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Scalar& value) {
+
+  if (!index_bdim) {
+    // Handle scalar tensors... self can be a scalar tensor
+    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+    auto self_ = moveBatchDimToFront(self, self_bdim);
+    if (self_logical_rank == 0) {
+      self_ = self_.unsqueeze(-1);
+    }
+    dim = maybe_wrap_dim(dim, self_logical_rank);
+
+    optional<int64_t> out_bdim = nullopt;
+    if (self_bdim) {
+      const auto batch_size = self.size(*self_bdim);
+      self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+      dim = dim + 1;
+      out_bdim = 0;
+    }
+
+    auto result = self_.index_fill(dim, index, value);
+    if (self_logical_rank == 0) {
+      result = result.squeeze(-1);
+    }
+    return std::make_tuple(result, out_bdim);
+  }
+
+  // SAME AS FOR index_add
+  // Index is batched. For-loop and stack is the best thing I can come up with
+  // right now. We really want a generalized index_fill kernel in PyTorch.
+  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
+  std::vector<Tensor> results;
+  results.reserve(batch_size);
+  for (const auto i : c10::irange(0, batch_size)) {
+    const auto& self_slice = self_bdim.has_value() ?
+      self.select(*self_bdim, i) : self;
+    const auto& index_slice = index_bdim.has_value() ?
+      index.select(*index_bdim, i) : index;
+    results.push_back(at::index_fill(self_slice, dim, index_slice, value));
+  }
+  return std::make_tuple(at::stack(results), 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Tensor& value, optional<int64_t> value_bdim) {
+
+  if (!index_bdim && !value_bdim) {
+    // Handle scalar tensors... self can be a scalar tensor
+    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+    auto self_ = moveBatchDimToFront(self, self_bdim);
+    if (self_logical_rank == 0) {
+      self_ = self_.unsqueeze(-1);
+    }
+    dim = maybe_wrap_dim(dim, self_logical_rank);
+
+    optional<int64_t> out_bdim = nullopt;
+    if (self_bdim) {
+      const auto batch_size = self.size(*self_bdim);
+      self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+      dim = dim + 1;
+      out_bdim = 0;
+    }
+    auto result = self_.index_fill(dim, index, value);
+    if (self_logical_rank == 0) {
+      result = result.squeeze(-1);
+    }
+    return std::make_tuple(result, out_bdim);
+  }
+
+  // SAME AS FOR index_add
+  // Index is batched. For-loop and stack is the best thing I can come up with
+  // right now. We really want a generalized index_fill kernel in PyTorch.
+  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
+  std::vector<Tensor> results;
+  results.reserve(batch_size);
+  for (const auto i : c10::irange(0, batch_size)) {
+    const auto& self_slice = self_bdim.has_value() ?
+      self.select(*self_bdim, i) : self;
+    const auto& index_slice = index_bdim.has_value() ?
+      index.select(*index_bdim, i) : index;
+    const auto& value_slice = value_bdim.has_value() ?
+      value.select(*value_bdim, i) : value;
+    results.push_back(at::index_fill(self_slice, dim, index_slice, value_slice));
+  }
+  return std::make_tuple(at::stack(results), 0);
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("index.Tensor", index_plumbing);
   m.impl("index_put_", index_put__plumbing);
@@ -550,6 +643,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("index_copy", index_copy_decomp);
   m.impl("index_select", index_select_decomp);
   VMAP_SUPPORT("index_add", index_add_batch_rule);
+  VMAP_SUPPORT("index_fill.int_Scalar", index_fill_int_scalar_batch_rule);
+  VMAP_SUPPORT("index_fill.int_Tensor", index_fill_int_tensor_batch_rule);
   VMAP_SUPPORT("diagonal_scatter", diagonal_scatter_batch_rule);
   VMAP_SUPPORT("gather", gather_batch_rule);
   VMAP_SUPPORT("gather_backward", gather_backward_batch_rule);