Skip to content

Commit 24dcf3b

Browse files
committed
Add Windows workflows
1 parent 2c3baba commit 24dcf3b

File tree

2 files changed

+149
-0
lines changed

2 files changed

+149
-0
lines changed
+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
name: Batch Build flash-attention Wheels for Windows
2+
3+
on:
4+
workflow_dispatch:
5+
inputs:
6+
versions:
7+
description: 'Version tag of flash-attention to build: v2.3.4'
8+
default: 'v2.3.2,v2.3.3,v2.3.4'
9+
required: true
10+
type: string
11+
12+
permissions:
13+
contents: write
14+
15+
jobs:
16+
define_matrix:
17+
name: Define Workflow Matrix
18+
runs-on: ubuntu-latest
19+
outputs:
20+
matrix: ${{ steps.set-matrix.outputs.matrix }}
21+
defaults:
22+
run:
23+
shell: pwsh
24+
env:
25+
PCKGVERS: ${{ inputs.versions }}
26+
27+
steps:
28+
- uses: actions/checkout@v4
29+
30+
- name: Define Job Output
31+
id: set-matrix
32+
run: |
33+
$x = ConvertTo-Json $env:PCKGVERS.Split(',').Trim() -Compress
34+
Write-Output ('matrix=' + $x) >> $env:GITHUB_OUTPUT
35+
36+
run_workflows:
37+
name: Build ${{ matrix.version }} Wheels
38+
needs: define_matrix
39+
strategy:
40+
max-parallel: 1
41+
matrix:
42+
version: ${{ fromJSON(needs.define_matrix.outputs.matrix) }}
43+
uses: ./.github/workflows/build-wheels.yml
44+
with:
45+
version: ${{ matrix.version }}

.github/workflows/build-wheels.yml

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
name: Build flash-attention Wheels for Windows
2+
3+
on:
4+
workflow_dispatch:
5+
inputs:
6+
version:
7+
description: 'Version tag of flash-attention to build: v2.3.4'
8+
default: 'v2.3.4'
9+
required: true
10+
type: string
11+
workflow_call:
12+
inputs:
13+
version:
14+
description: 'Version tag of flash-attention to build: v2.3.4'
15+
default: 'v2.3.4'
16+
required: true
17+
type: string
18+
19+
permissions:
20+
contents: write
21+
22+
jobs:
23+
build_wheels:
24+
name: Build wheels for Python ${{ matrix.pyver }} and CUDA ${{ matrix.cuda }}
25+
runs-on: windows-latest
26+
strategy:
27+
matrix:
28+
pyver: ["3.8", "3.9", "3.10", "3.11"]
29+
cuda: ["12.1.1"]
30+
defaults:
31+
run:
32+
shell: pwsh
33+
env:
34+
CUDAVER: ${{ matrix.cuda }}
35+
PCKGVER: ${{ inputs.version }}
36+
37+
steps:
38+
- uses: actions/checkout@v4
39+
with:
40+
repository: 'Dao-AILab/flash-attention'
41+
ref: ${{ inputs.version }}
42+
submodules: 'recursive'
43+
44+
- uses: actions/setup-python@v4
45+
with:
46+
python-version: ${{ matrix.pyver }}
47+
48+
- name: Setup Mamba
49+
uses: conda-incubator/setup-miniconda@v2.2.0
50+
with:
51+
activate-environment: "build"
52+
python-version: ${{ matrix.pyver }}
53+
miniforge-variant: Mambaforge
54+
miniforge-version: latest
55+
use-mamba: true
56+
add-pip-as-python-dependency: true
57+
auto-activate-base: false
58+
59+
- name: Install Dependencies
60+
run: |
61+
$cudaVersion = $env:CUDAVER
62+
$cudaVersionPytorch = $env:CUDAVER.Remove($env:CUDAVER.LastIndexOf('.')).Replace('.','')
63+
$cudaChannels = ''
64+
$cudaNum = [int]$cudaVersion.substring($cudaVersion.LastIndexOf('.')+1)
65+
while ($cudaNum -ge 0) { $cudaChannels += '-c nvidia/label/cuda-' + $cudaVersion.Remove($cudaVersion.LastIndexOf('.')+1) + $cudaNum + ' '; $cudaNum-- }
66+
mamba install -y 'cuda' $cudaChannels.TrimEnd().Split()
67+
if (!(mamba list cuda)[-1].contains('cuda')) {sleep -s 10; mamba install -y 'cuda' $cudaChannels.TrimEnd().Split()}
68+
if (!(mamba list cuda)[-1].contains('cuda')) {throw 'CUDA Toolkit failed to install!'}
69+
70+
python -m pip install --upgrade build setuptools wheel packaging ninja torch==2.1.0 --extra-index-url "https://download.pytorch.org/whl/cu$cudaVersionPytorch"
71+
72+
- name: Build Wheel
73+
id: build-wheel
74+
run: |
75+
$cudaVersion = $env:CUDAVER.Remove($env:CUDAVER.LastIndexOf('.')).Replace('.','')
76+
$packageVersion = $env:PCKGVER.TrimStart('v')
77+
78+
$env:CUDA_PATH = $env:CONDA_PREFIX
79+
$env:CUDA_HOME = $env:CONDA_PREFIX
80+
81+
$env:MAX_JOBS = '1'
82+
$env:FLASH_ATTENTION_FORCE_BUILD = 'TRUE'
83+
$env:FLASH_ATTENTION_FORCE_SINGLE_THREAD = 'TRUE'
84+
85+
python -m build -n --wheel
86+
87+
$wheel = (gi '.\dist\*.whl')[0]
88+
$wheelname = $wheel.name.replace("flash_attn-$packageVersion-","flash_attn-$packageVersion+cu$cudaVersion"+"torch2.1cxx11abiFALSE-")
89+
Move-Item $wheel.fullname ".\dist\$wheelname"
90+
91+
- uses: actions/upload-artifact@v3
92+
with:
93+
name: 'windows-wheels'
94+
path: ./dist/*.whl
95+
96+
- name: Upload files to a GitHub release
97+
uses: svenstaro/upload-release-action@2.6.1
98+
continue-on-error: true
99+
with:
100+
file: ./dist/*.whl
101+
tag: ${{ inputs.version }}
102+
file_glob: true
103+
overwrite: true
104+
release_name: ${{ inputs.version }}

0 commit comments

Comments
 (0)