You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
* update Gemma attention for TPU
* add default fallback for GPU and CPU
* add fallback option if not running with JAX and TPU
* address review comments
* check input signature
* remove checking q length
* code reformat
* handle case when soft cap support is not needed
* fix format
* add tests for FA calls
* fix test
* update tests
* fix code format
* address review comments
* Update requirements-jax-cuda.txt
* Update gemma_causal_lm_test.py
* Update requirements-jax-cuda.txt
0 commit comments