
No batching rules for aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet cause significant slowdown for per-sample gradients with torch.linalg.slogdet #984


TL;DR - torch.linalg.slogdet is over an order of magnitude slower at computing per-sample gradients on the latest nightly PyTorch/FuncTorch (1.13.0.dev20220721 / 0.3.0a0+e8a68f4) than on a previous PyTorch/FuncTorch (1.12.0a0+git7c2103a / 0.2.0a0+9d6ee76) compiled from source. This seems to be due to the lack of batching rules for aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet.

Thanks! :)

Hi All,

I've recently noticed that my code slowed down significantly (by around an order of magnitude) when moving from PyTorch 1.12 to 1.13, and I've made a minimal reproducible example (MRE) to highlight the issue. For reference, this started from #979, which has some more info; that issue has since been resolved and this new one was opened as per @vfdev-5's suggestion.

The MRE below computes per-sample gradients, with respect to the parameters, of the Laplacian of a model w.r.t. its inputs. The script first computes the per-sample gradients for N inputs from 1 to 6 and prints the walltime, and then uses torch.profiler.profile to give a clearer benchmark for N = 4.

I've benchmarked two versions of PyTorch/FuncTorch. The first was built from source (and can be found here); the only thing that was changed is the slogdet_backward formula, which you can find here. The full version info for this "old-source" build is:

PyTorch version:    1.12.0a0+git7c2103a
CUDA version:       11.6
FuncTorch version:  0.2.0a0+9d6ee76

The other version is the latest nightly (hereafter referred to as "nightly"). The full version info for this "nightly" build is:

PyTorch version:    1.13.0.dev20220721
CUDA version:       11.6
FuncTorch version:  0.3.0a0+e8a68f4

A comparison of walltimes (in seconds) as N increases from 1 to 6 is as follows:

N | old-source (s) | nightly (s)
--- | --- | ---
1 | 0.5719 | 2.4907
1 | 0.0133 | 2.0593
2 | 0.0870 | 2.4496
3 | 0.1153 | 2.9293
4 | 0.1129 | 3.3715
5 | 0.1576 | 3.8302
6 | 0.2059 | 4.2622

(N = 1 appears twice: the first call is slow for reasons I'm not sure of, warm-up perhaps, hence the repeat in the loop below.)

The torch.profiler output for the N = 4 case on the "old-source" build is shown below, sorted by cuda_time_total:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::matmul         0.38%     417.000us         4.25%       4.694ms      82.351us       0.000us         0.00%     127.220ms       2.232ms            57  
                                               aten::mm         0.69%     765.000us         3.77%       4.169ms      64.138us      25.971ms        24.41%     117.336ms       1.805ms            65  
                                              aten::bmm         0.45%     493.000us         1.06%       1.168ms      43.259us      64.058ms        60.20%      87.447ms       3.239ms            27  
       autograd::engine::evaluate_function: MmBackward0         0.12%     131.000us         1.37%       1.513ms     168.111us       0.000us         0.00%      42.798ms       4.755ms             9  
                                            MmBackward0         0.04%      40.000us         1.22%       1.347ms     149.667us       0.000us         0.00%      42.630ms       4.737ms             9  
                                   volta_dgemm_64x64_nt         0.00%       0.000us         0.00%       0.000us       0.000us      41.116ms        38.64%      41.116ms       4.112ms            10  
    autograd::engine::evaluate_function: AddmmBackward0         0.09%     103.000us         1.71%       1.890ms     189.000us       0.000us         0.00%      21.334ms       2.133ms            10  
                                  volta_dgemm_128x64_nt         0.00%       0.000us         0.00%       0.000us       0.000us      20.883ms        19.62%      20.883ms       3.481ms             6  
                                         AddmmBackward0         0.04%      39.000us         1.13%       1.254ms     125.400us       0.000us         0.00%      19.551ms       1.955ms            10  
                                   volta_dgemm_64x64_tn         0.00%       0.000us         0.00%       0.000us       0.000us      14.590ms        13.71%      14.590ms       2.432ms             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 110.510ms
Self CUDA time total: 106.412ms

However, on the latest "nightly" build the MRE slows down significantly and the profiler output is dominated by aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                 aten::_linalg_solve_ex        14.79%        1.125s       141.49%       10.763s     212.108us       0.000us         0.00%        2.966s      58.452us         50741  
                                     aten::linalg_solve         0.09%       7.174ms       117.89%        8.967s     560.456ms       0.000us         0.00%        2.284s     142.775ms            16  
                                  aten::linalg_solve_ex         0.00%      37.000us        75.04%        5.708s     475.648ms       0.000us         0.00%        1.513s     126.102ms            12  
autograd::engine::evaluate_function: LinalgSolveExBa...         0.00%     122.000us        62.86%        4.781s     683.040ms       0.000us         0.00%        1.275s     182.126ms             7  
                                 LinalgSolveExBackward0         0.00%      76.000us        62.85%        4.781s     683.008ms       0.000us         0.00%        1.275s     182.124ms             7  
                                  aten::linalg_lu_solve         9.34%     710.739ms        35.82%        2.724s      55.427us     643.114ms        33.69%     831.927ms      16.926us         49152  
                              aten::linalg_lu_factor_ex         7.28%     553.883ms        20.95%        1.593s      27.784us     661.250ms        34.64%     721.665ms      12.585us         57344  
                                  aten::_linalg_slogdet         4.59%     349.122ms        72.77%        5.535s     658.539us       0.000us         0.00%     677.273ms      80.580us          8405  
void getf2_cta_32x32<double, double>(int, int, int, ...         0.00%       0.000us         0.00%       0.000us       0.000us     540.579ms        28.32%     540.579ms       9.427us         57344  
void trsm_batch_left_lower_kernel<double>(cublasTrsm...         0.00%       0.000us         0.00%       0.000us       0.000us     277.594ms        14.54%     277.594ms       5.648us         49152  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.607s
Self CUDA time total: 1.909s
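
For what it's worth, my understanding (an assumption on my part, not something the profiler tells me directly) is that this lands squarely on the slogdet backward: the gradient of log|det(A)| w.r.t. A is A^{-T}, so whether that is materialised as an explicit inverse (as in the old-source build) or routed through linalg_solve (as the nightly appears to do, given LinalgSolveExBackward0 above), that op sits on the hot path of every per-sample gradient, and without a batching rule it gets unrolled sample by sample. A quick sanity check of the identity:

```python
import torch

# Minimal sanity check of the identity d(log|det A|)/dA = A^{-T}.
# This is only to illustrate why inverse/solve ops dominate the backward;
# it is not taken from the MRE below.
A = torch.randn(5, 5, dtype=torch.float64, requires_grad=True)
_, logabs = torch.linalg.slogdet(A)
logabs.backward()
assert torch.allclose(A.grad, torch.linalg.inv(A).transpose(-2, -1))
```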

functorch also prompts me with a UserWarning that batching rules do not exist for aten::_linalg_solve_ex, aten::linalg_solve, aten::linalg_solve_ex, and aten::_linalg_slogdet, and that it falls back to a for-loop, which affects performance.

~/pytorch_nightly/debug/per-sample-elocal.py:49: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_linalg_slogdet. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-hjjdrhz_/functorch/csrc/BatchedFallback.cpp:83.)
  sgn, logabs = torch.linalg.slogdet(mat)
  
~/anaconda3/envs/pytorch_nightly/lib/python3.9/site-packages/torch/autograd/__init__.py:294: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::linalg_solve. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-hjjdrhz_/functorch/csrc/BatchedFallback.cpp:83.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  
~/anaconda3/envs/pytorch_nightly/lib/python3.9/site-packages/torch/autograd/__init__.py:294: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_linalg_solve_ex. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-hjjdrhz_/functorch/csrc/BatchedFallback.cpp:83.)
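
These warnings can be triggered without the full MRE. A minimal sketch (my own reduction, not part of the script below) that just takes per-sample gradients of slogdet is enough to hit the fallback on the nightly build:

```python
import torch
from functorch import vmap, grad

torch.set_default_dtype(torch.float64)

def logabsdet(mat):
    # slogdet of a single (N, N) matrix; grad needs a scalar output
    _, logabs = torch.linalg.slogdet(mat)
    return logabs

mats = torch.randn(4096, 4, 4)

# On the nightly build this emits the "batching rule not implemented"
# warnings for aten::_linalg_slogdet / aten::linalg_solve and the backward
# runs in a per-sample loop; on the old-source build it stays batched.
per_sample_grads = vmap(grad(logabsdet))(mats)
print(per_sample_grads.shape)  # torch.Size([4096, 4, 4])
```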

The full script to reproduce this slowdown can be found below.

import torch
import torch.nn as nn
from torch import Tensor
from torch.profiler import profile, record_function, ProfilerActivity

import functorch
from functorch import jacrev, jacfwd, hessian, make_functional, vmap, grad

import time 

_ = torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

#version info
print("PyTorch version:   ", torch.__version__)
print("CUDA version:      ", torch.version.cuda)
print("FuncTorch version: ", functorch.__version__)

#time with torch synchronization
def sync_time() -> float:
  torch.cuda.synchronize()
  return time.perf_counter()

class model(nn.Module):

  def __init__(self, num_inputs, num_hidden):
    super(model, self).__init__()
    
    self.num_inputs=num_inputs
    self.func = nn.Tanh()
    
    self.fc1 = nn.Linear(2, num_hidden)
    self.fc2 = nn.Linear(num_hidden, num_inputs)
  
  def forward(self, x):
    """
    Takes x in [B,A,1] and maps it to sign/logabsdet value in Tuple([B,], [B,])
    """
    
    idx=len(x.shape)             #creates args for repeat if vmap is used or not
    rep=[1 for _ in range(idx)]
    rep[-2] = self.num_inputs
    g = x.mean(dim=(idx-2), keepdim=True).repeat(*rep)
    f = torch.cat((x,g), dim=-1)

    h = self.func(self.fc1(f))
    
    mat = self.fc2(h)
    sgn, logabs = torch.linalg.slogdet(mat)
    return sgn, logabs

#=================================================================================================#
#Profile code for N=1 to 6
#=================================================================================================#


B=4096 #batch
N=2    #input nodes
H=128  #number of hidden nodes
device=torch.device("cuda")

for N in [1,1,2,3,4,5,6]:

  net = model(N, H)
  net = net.to(device)

  x = torch.randn(B,N,1,device=device) #input data
  fnet, params = make_functional(net)

  def logabs(params, x):
    _, logabs = fnet(params, x)
    return logabs

  def kinetic_functorch(params, X):
    #do once, and re-use via has_aux?
    calc_jacobian = jacrev(logabs, argnums=1) 
    #can only use jacrev for back-compatibility in PyTorch-1.12 for torch.linalg.slogdet
    calc_hessian = jacrev(jacrev(logabs, argnums=1), argnums=1) 

    return -0.5*torch.sum(calc_hessian(params, X).squeeze(-3).squeeze(-1).diagonal(0,-2,-1) + calc_jacobian(params, X).squeeze(-1).pow(2), dim=-1)


  #per-sample gradients for local energy w.r.t params via FuncTorch
  t1=sync_time()
  elocal_grad_ft = vmap(grad(kinetic_functorch, argnums=0), in_dims=(None, 0))(params, x)
  t2=sync_time()

  print("N: %2i | Walltime: %6.4f (s)" % (N, t2-t1))

#=================================================================================================#
#Profile code for N=4
#=================================================================================================#

N=4
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
  net = model(N, H)
  net = net.to(device)

  x = torch.randn(B,N,1,device=device) #input data
  fnet, params = make_functional(net)

  def logabs(params, x):
    _, logabs = fnet(params, x)
    return logabs

  def kinetic_functorch(params, X):
    #do once, and re-use via has_aux?
    calc_jacobian = jacrev(logabs, argnums=1) 
    #can only use jacrev for back-compatibility in PyTorch-1.12 for torch.linalg.slogdet
    calc_hessian = jacrev(jacrev(logabs, argnums=1), argnums=1) 

    return -0.5*torch.sum(calc_hessian(params, X).squeeze(-3).squeeze(-1).diagonal(0,-2,-1) + calc_jacobian(params, X).squeeze(-1).pow(2), dim=-1)


  #per-sample gradients for local energy w.r.t params via FuncTorch
  t1=sync_time()
  elocal_grad_ft = vmap(grad(kinetic_functorch, argnums=0), in_dims=(None, 0))(params, x)
  t2=sync_time()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

Thanks in advance! :)
