4
4
workflow_dispatch :
5
5
inputs :
6
6
version :
7
- description : ' Version tag of flash-attention to build: v2.3.4'
7
+ description : ' Version tag of flash-attention to build: (format: v2.3.4/v2.3.4.post1) '
8
8
default : ' v2.3.4'
9
9
required : true
10
10
type : string
11
11
workflow_call :
12
12
inputs :
13
13
version :
14
- description : ' Version tag of flash-attention to build: v2.3.4'
14
+ description : ' Version tag of flash-attention to build (format : v2.3.4/v2.3.4.post1) '
15
15
default : ' v2.3.4'
16
16
required : true
17
17
type : string
@@ -22,12 +22,12 @@ permissions:
22
22
jobs :
23
23
build_wheels :
24
24
name : Build wheels for Python ${{ matrix.pyver }}, CUDA ${{ matrix.cuda }}, and Torch ${{ matrix.torchver }}
25
- runs-on : windows-latest
25
+ runs-on : windows-2022
26
26
strategy :
27
27
matrix :
28
28
pyver : ["3.10", "3.11"]
29
29
cuda : ["12.2.2"]
30
- torchver : ["2.1 .2", "2.2.0 "]
30
+ torchver : ["2.2 .2", "2.3.1 "]
31
31
defaults :
32
32
run :
33
33
shell : pwsh
44
44
with :
45
45
python-version : ${{ matrix.pyver }}
46
46
47
+ - name : Install VS2022 BuildTools 17.9.7
48
+ run : choco install -y visualstudio2022buildtools --version=117.9.7.0 --params "--add Microsoft.VisualStudio.Component.VC.Tools.x86.x64 --installChannelUri https://aka.ms/vs/17/release/180911598_-255012421/channel"
49
+ if : runner.os == 'Windows'
50
+
47
51
- name : Setup Mamba
48
52
uses : conda-incubator/setup-miniconda@v2.2.0
49
53
with :
@@ -57,20 +61,25 @@ jobs:
57
61
58
62
- name : Install Dependencies
59
63
run : |
60
- $cudaVersion = $env:CUDAVER
61
- $cudaVersionPytorch = $env:CUDAVER.Remove($env:CUDAVER.LastIndexOf('.')).Replace('.','')
62
- $cudaChannels = ''
63
- $cudaNum = [int]$cudaVersion.substring($cudaVersion.LastIndexOf('.')+1)
64
- while ($cudaNum -ge 0) { $cudaChannels += '-c nvidia/label/cuda-' + $cudaVersion.Remove($cudaVersion.LastIndexOf('.')+1) + $cudaNum + ' '; $cudaNum-- }
65
- mamba install -y 'cuda' $cudaChannels.TrimEnd().Split()
64
+ $cudaVersion = '${{ matrix.cuda }}'
65
+ $cudaVersionPytorch = '${{ matrix.cuda }}'.Remove('${{ matrix.cuda }}'.LastIndexOf('.')).Replace('.','')
66
+
67
+ mamba install -y -c nvidia/label/cuda-$cudaVersion cuda-toolkit cuda-runtime
66
68
if (!(mamba list cuda)[-1].contains('cuda')) {sleep -s 10; mamba install -y 'cuda' $cudaChannels.TrimEnd().Split()}
67
69
if (!(mamba list cuda)[-1].contains('cuda')) {throw 'CUDA Toolkit failed to install!'}
68
70
69
- python -m pip install --upgrade build setuptools wheel packaging ninja psutil torch==${{ matrix.torchver }} --extra-index-url "https://download.pytorch.org/whl/cu121"
71
+ python -m pip install --upgrade build setuptools wheel packaging ninja torch==${{ matrix.torchver }} psutil --extra-index-url "https://download.pytorch.org/whl/cu121"
70
72
71
73
- name : Build Wheel
72
74
id : build-wheel
73
75
run : |
76
+ # --- Spawn the VS shell
77
+ if ($IsWindows) {
78
+ Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
79
+ Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools' -DevCmdArguments '-arch=x64 -host_arch=x64'
80
+ $env:DISTUTILS_USE_SDK=1
81
+ }
82
+
74
83
$cudaVersion = $env:CUDAVER.Remove($env:CUDAVER.LastIndexOf('.')).Replace('.','')
75
84
$packageVersion = $env:PCKGVER.TrimStart('v')
76
85
0 commit comments