Skip to content

Commit c82d2e0

Browse files
committed
Fix bug with pip-installed torch w/ nonstandard CUDA versions
1 parent 1e6dccb commit c82d2e0

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

gptorch/functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from torch.nn import functional as F
1111
import numpy as np
1212

13-
torch_version = [int(s) for s in torch.__version__.split(".")]
14-
_potri = torch.cholesky_inverse if torch_version >= [1, 1, 0] else torch.potri
13+
_torch_majorminor = [int(s) for s in torch.__version__.split(".")[:2]]
14+
_potri = torch.cholesky_inverse if _torch_majorminor >= [1, 1] else torch.potri
1515

1616

1717
TRIANGULAR_SOLVE = "triangular_solve" in dir(torch)

test/test_functions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# File: test_functions.py
2+
# File Created: Saturday, 30th October 2021 11:06:35 am
3+
# Author: Steven Atkinson (steven@atkinson.mn)
4+
5+
def test_importable():
6+
from gptorch import functions
7+

0 commit comments

Comments
 (0)