Skip to content

Commit 4bed9ae

Browse files
authored
Update requirements-jax-cuda.txt (keras-team#2252)
* Update requirements-jax-cuda.txt Upgrading jax version. For CI tests for keras-hub on jax GPU * Update build.sh For CI tests for JAX GPU * Update build.sh * Update build.sh * Update requirements-jax-cuda.txt * Update requirements-jax-cuda.txt * Update requirements-jax-cuda.txt * Update requirements-jax-cuda.txt * Update build.sh * Update build.sh * Update build.sh
1 parent a037e32 commit 4bed9ae

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

.kokoro/github/ubuntu/gpu/build.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,16 @@ fi
1616
set -x
1717
cd "${KOKORO_ROOT}/"
1818

19-
PYTHON_BINARY="/usr/bin/python3.9"
19+
PYTHON_BINARY="/usr/bin/python3.10"
2020

2121
"${PYTHON_BINARY}" -m venv venv
2222
source venv/bin/activate
2323
# Check the python version
2424
python --version
2525
python3 --version
2626

27-
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:"
27+
# setting the LD_LIBRARY_PATH manually is causing segmentation fault
28+
# export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:"
2829
# Check cuda
2930
nvidia-smi
3031
nvcc --version

requirements-jax-cuda.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ torchvision>=0.16.0
1010
# Jax with cuda support.
1111
# Keep same version as Keras repo.
1212
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
13-
jax[cuda12_pip]==0.4.28
13+
jax[cuda12]==0.6.0
1414

1515
-r requirements-common.txt

0 commit comments

Comments
 (0)