This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ef745c1cd [CI] Add JAX deps in Dockerfiles (#14550)
3ef745c1cd is described below

commit 3ef745c1cd12bebb3c68814e6d6fe1421791ae98
Author: Yong Wu <[email protected]>
AuthorDate: Wed Apr 12 07:18:42 2023 -0700

    [CI] Add JAX deps in Dockerfiles (#14550)
    
    * [CI] Add JAX deps in Dockerfiles
    
    * Specify jax/jaxlib/flax version for python3.7
---
 docker/Dockerfile.ci_cpu             |  4 ++++
 docker/Dockerfile.ci_gpu             |  3 +++
 docker/install/ubuntu_install_jax.sh | 35 +++++++++++++++++++++++++++++++++++
 3 files changed, 42 insertions(+)

diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu
index 788838dbe8..7498a95875 100644
--- a/docker/Dockerfile.ci_cpu
+++ b/docker/Dockerfile.ci_cpu
@@ -108,6 +108,10 @@ RUN bash /install/ubuntu_install_tensorflow.sh
 COPY install/ubuntu_install_tflite.sh /install/ubuntu_install_tflite.sh
 RUN bash /install/ubuntu_install_tflite.sh
 
+# JAX deps
+COPY install/ubuntu_install_jax.sh /install/ubuntu_install_jax.sh
+RUN bash /install/ubuntu_install_jax.sh "cpu"
+
 # Compute Library
 COPY install/ubuntu_download_arm_compute_lib_binaries.sh /install/ubuntu_download_arm_compute_lib_binaries.sh
 RUN bash /install/ubuntu_download_arm_compute_lib_binaries.sh
diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu
index 3efd81496a..91566a4b36 100644
--- a/docker/Dockerfile.ci_gpu
+++ b/docker/Dockerfile.ci_gpu
@@ -89,6 +89,9 @@ RUN bash /install/ubuntu_install_coreml.sh
 COPY install/ubuntu_install_tensorflow.sh /install/ubuntu_install_tensorflow.sh
 RUN bash /install/ubuntu_install_tensorflow.sh
 
+COPY install/ubuntu_install_jax.sh /install/ubuntu_install_jax.sh
+RUN bash /install/ubuntu_install_jax.sh "cuda"
+
 COPY install/ubuntu_install_darknet.sh /install/ubuntu_install_darknet.sh
 RUN bash /install/ubuntu_install_darknet.sh
 
diff --git a/docker/install/ubuntu_install_jax.sh b/docker/install/ubuntu_install_jax.sh
new file mode 100644
index 0000000000..a39fa2187a
--- /dev/null
+++ b/docker/install/ubuntu_install_jax.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+set -u
+set -o pipefail
+
+# Install jax/jaxlib for the backend named by $1 ("cpu" or "cuda"), then flax.
+# Versions are pinned for python3.7 compatibility.
+case "${1:?usage: ubuntu_install_jax.sh <cpu|cuda>}" in
+    cuda)
+        # The CUDA build of jaxlib is resolved through JAX's -f (find-links) index.
+        pip3 install --upgrade jaxlib==0.3.25 "jax[cuda11_pip]==0.3.25" \
+            -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ;;
+    cpu) pip3 install --upgrade jaxlib==0.3.25 "jax[cpu]==0.3.25" ;;
+    *) echo "ubuntu_install_jax.sh: unknown backend '$1' (expected cpu or cuda)" >&2; exit 1 ;;
+esac
+
+# Install flax
+pip3 install flax==0.6.4

Reply via email to