Description

TL;DR - `torch.linalg.slogdet` is over an order of magnitude slower at computing per-sample gradients in the latest nightly version of PyTorch/FuncTorch (`1.13.0.dev20220721` / `0.3.0a0+e8a68f4`) than in a previous version of PyTorch/FuncTorch (`1.12.0a0+git7c2103a` / `0.2.0a0+9d6ee76`) compiled from source. This seems to be due to the lack of batching rules for `aten::_linalg_solve_ex`, `aten::linalg_solve`, `aten::linalg_solve_ex`, and `aten::_linalg_slogdet`.
Thanks! :)
Hi All,
I've recently noticed that my code slowed down significantly (by around an order of magnitude) when moving from PyTorch 1.12 to 1.13, and I've made a minimal reproducible example to highlight this. For reference, this started from #979, which has some more info; that issue has since been resolved and this new issue was opened per @vfdev-5's suggestion.
The MRE below computes per-sample gradients, with respect to the parameters, of the Laplacian of a model w.r.t. its inputs. The script computes the per-sample gradients for `N` inputs from 1 to 6 and reports the walltime, and then uses `torch.profiler.profile` to give a clearer benchmark for `N=4`.
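For orientation, the per-sample quantity is essentially `-0.5 * (tr(d²logabs/dx²) + |d logabs/dx|²)`, differentiated with respect to the parameters and `vmap`'d over the batch. Below is a minimal, self-contained sketch of that composition using a toy `logabs` of my own (not the MRE's model); the full script is at the end of this post:

```python
import torch
from functorch import grad, jacrev, vmap

torch.set_default_dtype(torch.float64)
N = 4
W = torch.randn(N, N)          # stands in for the model parameters
batch_x = torch.randn(128, N)  # one flat input vector per sample

def logabs(params, x):
    # toy scalar function of (params, x); the MRE uses a small network + slogdet
    return torch.linalg.slogdet(params + torch.diag_embed(x))[1]

def kinetic(params, x):
    jac = jacrev(logabs, argnums=1)(params, x)                      # [N]
    hess = jacrev(jacrev(logabs, argnums=1), argnums=1)(params, x)  # [N, N]
    # -0.5 * (trace of the Hessian w.r.t. x + squared norm of the gradient w.r.t. x)
    return -0.5 * (hess.diagonal(0, -2, -1).sum(-1) + jac.pow(2).sum(-1))

# per-sample gradients: one gradient of `kinetic` w.r.t. the parameters per sample
per_sample_grads = vmap(grad(kinetic, argnums=0), in_dims=(None, 0))(W, batch_x)
print(per_sample_grads.shape)  # torch.Size([128, 4, 4])
```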
I've benchmarked two versions of PyTorch/FuncTorch. The first version was built from source (and can be found here); the only thing changed is the `slogdet_backward` formula, which you can find here. (A small sketch of why the choice of backward formula matters is included after the version info below.) The full version info for this "old-source" build is:
PyTorch version: 1.12.0a0+git7c2103a
CUDA version: 11.6
FuncTorch version: 0.2.0a0+9d6ee76
The other version is the latest nightly (hereafter referred to as "nightly"). The full version info for this "nightly" build is:
PyTorch version: 1.13.0.dev20220721
CUDA version: 11.6
FuncTorch version: 0.3.0a0+e8a68f4
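For context on why the backward formula matters here: the gradient of `logabsdet(A)` with respect to an invertible square matrix `A` is `A^{-T}`, which can be formed either from an explicit inverse or from a linear solve. The sketch below is my own illustration of that identity, not the actual derivative formula used in either build; the point is that whichever `aten::` ops the backward is expressed in determines what `vmap` has to batch:

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)

# Autograd's gradient of logabsdet w.r.t. A ...
sgn, logabs = torch.linalg.slogdet(A)
(g_autograd,) = torch.autograd.grad(logabs, A)

# ... equals A^{-T}, formed here via an explicit inverse ...
g_inv = torch.linalg.inv(A).transpose(-2, -1)

# ... or via a linear solve: solve(A^T, I) = A^{-T}.
g_solve = torch.linalg.solve(A.transpose(-2, -1), torch.eye(4, dtype=torch.float64))

print(torch.allclose(g_autograd, g_inv), torch.allclose(g_autograd, g_solve))
```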
A comparison of walltime (in seconds) as `N` increases from 1 to 6:

| N | old-source | nightly |
|--:|-----------:|--------:|
| 1 | 0.5719 | 2.4907 |
| 1 | 0.0133 | 2.0593 |
| 2 | 0.0870 | 2.4496 |
| 3 | 0.1153 | 2.9293 |
| 4 | 0.1129 | 3.3715 |
| 5 | 0.1576 | 3.8302 |
| 6 | 0.2059 | 4.2622 |

(The first `N=1` row is slower; warm-up on the first call, perhaps?)
The `torch.profiler.profile` output for `N=4` with the "old-source" version is shown below, sorted by `cuda_time_total`:
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::matmul 0.38% 417.000us 4.25% 4.694ms 82.351us 0.000us 0.00% 127.220ms 2.232ms 57
aten::mm 0.69% 765.000us 3.77% 4.169ms 64.138us 25.971ms 24.41% 117.336ms 1.805ms 65
aten::bmm 0.45% 493.000us 1.06% 1.168ms 43.259us 64.058ms 60.20% 87.447ms 3.239ms 27
autograd::engine::evaluate_function: MmBackward0 0.12% 131.000us 1.37% 1.513ms 168.111us 0.000us 0.00% 42.798ms 4.755ms 9
MmBackward0 0.04% 40.000us 1.22% 1.347ms 149.667us 0.000us 0.00% 42.630ms 4.737ms 9
volta_dgemm_64x64_nt 0.00% 0.000us 0.00% 0.000us 0.000us 41.116ms 38.64% 41.116ms 4.112ms 10
autograd::engine::evaluate_function: AddmmBackward0 0.09% 103.000us 1.71% 1.890ms 189.000us 0.000us 0.00% 21.334ms 2.133ms 10
volta_dgemm_128x64_nt 0.00% 0.000us 0.00% 0.000us 0.000us 20.883ms 19.62% 20.883ms 3.481ms 6
AddmmBackward0 0.04% 39.000us 1.13% 1.254ms 125.400us 0.000us 0.00% 19.551ms 1.955ms 10
volta_dgemm_64x64_tn 0.00% 0.000us 0.00% 0.000us 0.000us 14.590ms 13.71% 14.590ms 2.432ms 6
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 110.510ms
Self CUDA time total: 106.412ms
However, with the latest "nightly" version the MRE slows down significantly, and the `torch.profiler.profile` output is dominated by `aten::_linalg_solve_ex`, `aten::linalg_solve`, `aten::linalg_solve_ex`, and `aten::_linalg_slogdet`:
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::_linalg_solve_ex 14.79% 1.125s 141.49% 10.763s 212.108us 0.000us 0.00% 2.966s 58.452us 50741
aten::linalg_solve 0.09% 7.174ms 117.89% 8.967s 560.456ms 0.000us 0.00% 2.284s 142.775ms 16
aten::linalg_solve_ex 0.00% 37.000us 75.04% 5.708s 475.648ms 0.000us 0.00% 1.513s 126.102ms 12
autograd::engine::evaluate_function: LinalgSolveExBa... 0.00% 122.000us 62.86% 4.781s 683.040ms 0.000us 0.00% 1.275s 182.126ms 7
LinalgSolveExBackward0 0.00% 76.000us 62.85% 4.781s 683.008ms 0.000us 0.00% 1.275s 182.124ms 7
aten::linalg_lu_solve 9.34% 710.739ms 35.82% 2.724s 55.427us 643.114ms 33.69% 831.927ms 16.926us 49152
aten::linalg_lu_factor_ex 7.28% 553.883ms 20.95% 1.593s 27.784us 661.250ms 34.64% 721.665ms 12.585us 57344
aten::_linalg_slogdet 4.59% 349.122ms 72.77% 5.535s 658.539us 0.000us 0.00% 677.273ms 80.580us 8405
void getf2_cta_32x32<double, double>(int, int, int, ... 0.00% 0.000us 0.00% 0.000us 0.000us 540.579ms 28.32% 540.579ms 9.427us 57344
void trsm_batch_left_lower_kernel<double>(cublasTrsm... 0.00% 0.000us 0.00% 0.000us 0.000us 277.594ms 14.54% 277.594ms 5.648us 49152
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 7.607s
Self CUDA time total: 1.909s
`functorch` also emits `UserWarning`s that batching rules do not exist for `aten::_linalg_solve_ex`, `aten::linalg_solve`, `aten::linalg_solve_ex`, and `aten::_linalg_slogdet`, and that it falls back to a for-loop, which hurts performance (a sketch of what that fallback means in practice follows the warnings below):
~/pytorch_nightly/debug/per-sample-elocal.py:49: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_linalg_slogdet. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-hjjdrhz_/functorch/csrc/BatchedFallback.cpp:83.)
sgn, logabs = torch.linalg.slogdet(mat)
~/anaconda3/envs/pytorch_nightly/lib/python3.9/site-packages/torch/autograd/__init__.py:294: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::linalg_solve. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-hjjdrhz_/functorch/csrc/BatchedFallback.cpp:83.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
~/anaconda3/envs/pytorch_nightly/lib/python3.9/site-packages/torch/autograd/__init__.py:294: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_linalg_solve_ex. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /tmp/pip-req-build-hjjdrhz_/functorch/csrc/BatchedFallback.cpp:83.)
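To make the warning concrete: when a batching rule is missing, `vmap` falls back to running the op once per sample in a serial loop. The sketch below is my own illustration of what that costs, not functorch's actual fallback code:

```python
import torch
from functorch import vmap

mats = torch.randn(4096, 4, 4, dtype=torch.float64)

# Native batched call: one (batched) kernel invocation for the whole batch.
sgn, logabs = torch.linalg.slogdet(mats)

# What the fallback amounts to: roughly one tiny call per sample (~4096 of them).
outs = [torch.linalg.slogdet(m)[1] for m in mats]

# The vmap'd call, which should hit the fallback (and the warning) on the nightly build.
logabs_v = vmap(lambda m: torch.linalg.slogdet(m)[1])(mats)
```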
The full script to reproduce this error can be found below.
import torch
import torch.nn as nn
from torch import Tensor
from torch.profiler import profile, record_function, ProfilerActivity

import functorch
from functorch import jacrev, jacfwd, hessian, make_functional, vmap, grad

import time

_ = torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

# version info
print("PyTorch version: ", torch.__version__)
print("CUDA version: ", torch.version.cuda)
print("FuncTorch version: ", functorch.__version__)

# time with torch synchronization
def sync_time() -> float:
    torch.cuda.synchronize()
    return time.perf_counter()

class model(nn.Module):

    def __init__(self, num_inputs, num_hidden):
        super(model, self).__init__()
        self.num_inputs = num_inputs
        self.func = nn.Tanh()
        self.fc1 = nn.Linear(2, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_inputs)

    def forward(self, x):
        """
        Takes x in [B,A,1] and maps it to sign/logabsdet value in Tuple([B,], [B,])
        """
        idx = len(x.shape)  # creates args for repeat, whether or not vmap is used
        rep = [1 for _ in range(idx)]
        rep[-2] = self.num_inputs
        g = x.mean(dim=(idx - 2), keepdim=True).repeat(*rep)
        f = torch.cat((x, g), dim=-1)
        h = self.func(self.fc1(f))
        mat = self.fc2(h)
        sgn, logabs = torch.linalg.slogdet(mat)
        return sgn, logabs

#=================================================================================================#
# Profile code for N=1 to 6
#=================================================================================================#

B = 4096  # batch size
N = 2     # number of input nodes
H = 128   # number of hidden nodes

device = torch.device("cuda")

for N in [1, 1, 2, 3, 4, 5, 6]:
    net = model(N, H)
    net = net.to(device)

    x = torch.randn(B, N, 1, device=device)  # input data

    fnet, params = make_functional(net)

    def logabs(params, x):
        _, logabs = fnet(params, x)
        return logabs

    def kinetic_functorch(params, X):
        # do once, and re-use via has_aux?
        calc_jacobian = jacrev(logabs, argnums=1)
        # can only use jacrev for back-compatibility in PyTorch-1.12 for torch.linalg.slogdet
        calc_hessian = jacrev(jacrev(logabs, argnums=1), argnums=1)
        return -0.5 * torch.sum(calc_hessian(params, X).squeeze(-3).squeeze(-1).diagonal(0, -2, -1)
                                + calc_jacobian(params, X).squeeze(-1).pow(2), dim=-1)

    # per-sample gradients for local energy w.r.t. params via FuncTorch
    t1 = sync_time()
    elocal_grad_ft = vmap(grad(kinetic_functorch, argnums=0), in_dims=(None, 0))(params, x)
    t2 = sync_time()

    print("N: %2i | Walltime: %6.4f (s)" % (N, t2 - t1))

#=================================================================================================#
# Profile code for N=4
#=================================================================================================#

N = 4

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    net = model(N, H)
    net = net.to(device)

    x = torch.randn(B, N, 1, device=device)  # input data

    fnet, params = make_functional(net)

    def logabs(params, x):
        _, logabs = fnet(params, x)
        return logabs

    def kinetic_functorch(params, X):
        # do once, and re-use via has_aux?
        calc_jacobian = jacrev(logabs, argnums=1)
        # can only use jacrev for back-compatibility in PyTorch-1.12 for torch.linalg.slogdet
        calc_hessian = jacrev(jacrev(logabs, argnums=1), argnums=1)
        return -0.5 * torch.sum(calc_hessian(params, X).squeeze(-3).squeeze(-1).diagonal(0, -2, -1)
                                + calc_jacobian(params, X).squeeze(-1).pow(2), dim=-1)

    # per-sample gradients for local energy w.r.t. params via FuncTorch
    t1 = sync_time()
    elocal_grad_ft = vmap(grad(kinetic_functorch, argnums=0), in_dims=(None, 0))(params, x)
    t2 = sync_time()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
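If it helps with triage, a smaller snippet that isolates just the `vmap` + `grad` + `slogdet` path with the same batch size is below. This is my own reduction rather than part of the MRE above, and on the nightly build I'd expect it to hit the same missing-batching-rule fallbacks:

```python
import time
import torch
from functorch import vmap, grad

torch.set_default_dtype(torch.float64)
device = torch.device("cuda")

B, N = 4096, 4
mats = torch.randn(B, N, N, device=device)

def logabs_only(m):
    # scalar per sample, so grad() applies directly
    return torch.linalg.slogdet(m)[1]

torch.cuda.synchronize()
t0 = time.perf_counter()
per_sample = vmap(grad(logabs_only))(mats)  # d logabsdet / dm for every sample
torch.cuda.synchronize()
print("walltime: %.4f (s)" % (time.perf_counter() - t0))
```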
Thanks in advance! :)