Posted to commits@tvm.apache.org by lu...@apache.org on 2023/07/28 07:58:03 UTC

[tvm] branch main updated: [CI] Bump Flax and Jaxlib versions to fix Jaxlib install error (#15421)

This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 64ac43a243 [CI] Bump Flax and Jaxlib versions to fix Jaxlib install error (#15421)
64ac43a243 is described below

commit 64ac43a2436b591ae009b1ced09c14363a717710
Author: Liam Sturge <50...@users.noreply.github.com>
AuthorDate: Fri Jul 28 08:57:57 2023 +0100

    [CI] Bump Flax and Jaxlib versions to fix Jaxlib install error (#15421)
    
    Bump Flax and Jax versions to fix install error
    
    The Flax dependency Orbax (v0.1.8) has deprecated installing Orbax as
    a standalone package; going forward, the package orbax-checkpoint
    should be installed instead.
    
    Flax v0.6.8 does not recognize this and attempts to install Orbax
    rather than orbax-checkpoint, so the installation fails with an error.
    
    To resolve the Jax installation issues, bump Flax to at least 0.6.9.
    Flax >= 0.6.9 does not pin the version of orbax-checkpoint that it
    installs, and the latest orbax-checkpoint requires Jax >= 0.4.9, so
    Flax and Jax must be updated together.
---
 docker/install/ubuntu_install_jax.sh | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/docker/install/ubuntu_install_jax.sh b/docker/install/ubuntu_install_jax.sh
index 87cb6f7dbe..1914990916 100644
--- a/docker/install/ubuntu_install_jax.sh
+++ b/docker/install/ubuntu_install_jax.sh
@@ -23,13 +23,13 @@ set -o pipefail
 # Install jax and jaxlib
 if [ "$1" == "cuda" ]; then
     pip3 install --upgrade \
-        jaxlib==0.4.7 \
-        "jax[cuda11_pip]==0.4.7" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+        jaxlib~=0.4.9 \
+        "jax[cuda11_pip]~=0.4.9" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 else
     pip3 install --upgrade \
-        jaxlib==0.4.7 \
-        "jax[cpu]==0.4.7"
+        jaxlib~=0.4.9 \
+        "jax[cpu]~=0.4.9"
 fi
 
 # Install flax
-pip3 install flax==0.6.8
+pip3 install flax~=0.6.9
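
Note on the new specifiers: the patch replaces exact pins (==) with the PEP 440 "compatible release" operator (~=). For a three-part version, "~=0.4.9" is equivalent to ">=0.4.9, ==0.4.*", so pip may pick up later 0.4.x patch releases but not 0.5.0. The following is a minimal sketch of that rule, not part of ubuntu_install_jax.sh. The function name is hypothetical, it only handles plain numeric versions such as 0.4.9 (no pre-release or local suffixes), and it assumes a sort that supports -V (GNU coreutils, as on the Ubuntu images).

```shell
#!/usr/bin/env bash
# Hypothetical helper for illustration only; pip does the real resolution.
# Returns 0 if version $1 satisfies "~=$2" for plain X.Y.Z versions.
satisfies_compatible_release() {
    local candidate="$1" base="$2"
    local prefix="${base%.*}"    # e.g. 0.4.9 -> 0.4

    # "== 0.4.*": the candidate must share every component but the last.
    [[ "$candidate" == "$prefix".* ]] || return 1

    # ">= 0.4.9": under version sort, the base must come first (or tie).
    local lowest
    lowest="$(printf '%s\n%s\n' "$base" "$candidate" | sort -V | head -n 1)"
    [[ "$lowest" == "$base" ]]
}
```

Under this rule, jaxlib 0.4.13 would be accepted by "jaxlib~=0.4.9", while 0.4.7 (the previous pin) and 0.5.0 would be rejected.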