Add mx_fp8_bf16 kernel #1637
Conversation
nice! if CI is green - looks good! I think this should have at least one numerical test though. Can be a follow-up PR if needed.
Very cool!
Stacked PRs:
Add mx_fp8_bf16 kernel
Will flesh out more, but this moves the kernel over from here: https://github.com/drisspg/driss_torch/blob/2813322f0b0f9a0f0fc8d382090ad0aaecf3468a/src/mx_fp8_bf16.cu#L162
This does an fp8 x fp8 matmul with E8M0 scales and a group_size hard-coded to 32. The scale format is the same one cuBLASLt expects. I have created a PyTorch function that converts the [n_rows, n_cols//32] scales into that expected format:
https://github.com/drisspg/transformer_nuggets/blob/382cb0f19a5f615827174289b8ef552419d51fea/transformer_nuggets/mx/to_blocked.py#L11
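Roughly, the conversion pads the [n_rows, n_cols//32] scale matrix to whole 128x4 tiles and swizzles each tile into the chunked layout cuBLASLt consumes. Below is a minimal sketch of what that rearrangement can look like; the 128x4 tiling and 32x16 inner swizzle follow the cuBLASLt block-scaling-factor layout docs, but treat the linked to_blocked.py as the authoritative version.

```python
import torch


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def to_blocked_sketch(scales: torch.Tensor) -> torch.Tensor:
    """Rearrange [n_rows, n_cols//32] E8M0 scales into a blocked layout.

    Sketch only: pads to multiples of (128, 4), views the matrix as 128x4
    tiles, and swizzles each tile into a 32x16 chunk, per the cuBLASLt
    block-scaling-factor layout documentation.
    """
    rows, cols = scales.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Pad up to a whole number of (128, 4) tiles.
    padded = scales
    if (rows, cols) != (n_row_blocks * 128, n_col_blocks * 4):
        padded = scales.new_zeros(n_row_blocks * 128, n_col_blocks * 4)
        padded[:rows, :cols] = scales

    # View as (n_row_blocks, n_col_blocks, 128, 4) tiles...
    tiles = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    # ...and swizzle each tile into a 32x16 chunk (four 32-row groups
    # interleaved with the 4 scale columns), then flatten.
    swizzled = tiles.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
    return swizzled.flatten()
```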
This was surprisingly hard fought and would not have been possible without @albanD 😊
This allows PR #1625 to avoid any dependency on PyTorch core updates while we add the required dtypes and cuBLAS bindings in pytorch/pytorch#145562.
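For context, a hedged usage sketch of the new op (not the PR's actual test): the op name mx_fp8_bf16 comes from this PR, but the Python binding path, the e4m3 fp8 input dtype, storing E8M0 scales as raw uint8 exponent bytes, and the operand layouts are all assumptions on my part.

```python
import torch
import torchao  # assumption: the op is registered under torch.ops.torchao after this PR

M, K, N = 128, 256, 128
group_size = 32  # hard-coded in the kernel

# fp8 operands (any row-/column-major layout requirements are not shown here).
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)

# One E8M0 scale per 32 elements along K, stored here as raw uint8 exponent bytes
# (127 == exponent 0, i.e. scale 1.0), rearranged with the blocked-layout sketch above.
a_scales = to_blocked_sketch(
    torch.full((M, K // group_size), 127, device="cuda", dtype=torch.uint8)
)
b_scales = to_blocked_sketch(
    torch.full((N, K // group_size), 127, device="cuda", dtype=torch.uint8)
)

# Assumed call signature: fp8 inputs + blocked E8M0 scales -> bf16 output of shape (M, N).
out = torch.ops.torchao.mx_fp8_bf16(a, b, a_scales, b_scales)
```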
Follow up
Config needs more tuning