Skip to content

Commit b8ff197

Browse files
committed
Merge remote-tracking branch 'bdashore3/main'
2 parents 945ee7b + c3cc17e commit b8ff197

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

.github/workflows/build-wheels.yml

+20-11
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ on:
44
workflow_dispatch:
55
inputs:
66
version:
7-
description: 'Version tag of flash-attention to build: v2.3.4'
7+
description: 'Version tag of flash-attention to build: (format: v2.3.4/v2.3.4.post1)'
88
default: 'v2.3.4'
99
required: true
1010
type: string
1111
workflow_call:
1212
inputs:
1313
version:
14-
description: 'Version tag of flash-attention to build: v2.3.4'
14+
description: 'Version tag of flash-attention to build (format: v2.3.4/v2.3.4.post1)'
1515
default: 'v2.3.4'
1616
required: true
1717
type: string
@@ -22,12 +22,12 @@ permissions:
2222
jobs:
2323
build_wheels:
2424
name: Build wheels for Python ${{ matrix.pyver }}, CUDA ${{ matrix.cuda }}, and Torch ${{ matrix.torchver }}
25-
runs-on: windows-latest
25+
runs-on: windows-2022
2626
strategy:
2727
matrix:
2828
pyver: ["3.10", "3.11"]
2929
cuda: ["12.2.2"]
30-
torchver: ["2.1.2", "2.2.0"]
30+
torchver: ["2.2.2", "2.3.1"]
3131
defaults:
3232
run:
3333
shell: pwsh
@@ -44,6 +44,10 @@ jobs:
4444
with:
4545
python-version: ${{ matrix.pyver }}
4646

47+
- name: Install VS2022 BuildTools 17.9.7
48+
run: choco install -y visualstudio2022buildtools --version=117.9.7.0 --params "--add Microsoft.VisualStudio.Component.VC.Tools.x86.x64 --installChannelUri https://aka.ms/vs/17/release/180911598_-255012421/channel"
49+
if: runner.os == 'Windows'
50+
4751
- name: Setup Mamba
4852
uses: conda-incubator/setup-miniconda@v2.2.0
4953
with:
@@ -57,20 +61,25 @@ jobs:
5761

5862
- name: Install Dependencies
5963
run: |
60-
$cudaVersion = $env:CUDAVER
61-
$cudaVersionPytorch = $env:CUDAVER.Remove($env:CUDAVER.LastIndexOf('.')).Replace('.','')
62-
$cudaChannels = ''
63-
$cudaNum = [int]$cudaVersion.substring($cudaVersion.LastIndexOf('.')+1)
64-
while ($cudaNum -ge 0) { $cudaChannels += '-c nvidia/label/cuda-' + $cudaVersion.Remove($cudaVersion.LastIndexOf('.')+1) + $cudaNum + ' '; $cudaNum-- }
65-
mamba install -y 'cuda' $cudaChannels.TrimEnd().Split()
64+
$cudaVersion = '${{ matrix.cuda }}'
65+
$cudaVersionPytorch = '${{ matrix.cuda }}'.Remove('${{ matrix.cuda }}'.LastIndexOf('.')).Replace('.','')
66+
67+
mamba install -y -c nvidia/label/cuda-$cudaVersion cuda-toolkit cuda-runtime
6668
if (!(mamba list cuda)[-1].contains('cuda')) {sleep -s 10; mamba install -y 'cuda' $cudaChannels.TrimEnd().Split()}
6769
if (!(mamba list cuda)[-1].contains('cuda')) {throw 'CUDA Toolkit failed to install!'}
6870
69-
python -m pip install --upgrade build setuptools wheel packaging ninja psutil torch==${{ matrix.torchver }} --extra-index-url "https://download.pytorch.org/whl/cu121"
71+
python -m pip install --upgrade build setuptools wheel packaging ninja torch==${{ matrix.torchver }} psutil --extra-index-url "https://download.pytorch.org/whl/cu121"
7072
7173
- name: Build Wheel
7274
id: build-wheel
7375
run: |
76+
# --- Spawn the VS shell
77+
if ($IsWindows) {
78+
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
79+
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools' -DevCmdArguments '-arch=x64 -host_arch=x64'
80+
$env:DISTUTILS_USE_SDK=1
81+
}
82+
7483
$cudaVersion = $env:CUDAVER.Remove($env:CUDAVER.LastIndexOf('.')).Replace('.','')
7584
$packageVersion = $env:PCKGVER.TrimStart('v')
7685

0 commit comments

Comments
 (0)