Deep Learning

Intro to JAX for Machine Learning

December 1, 2022
11 min read

JAX is an up-and-coming library in the machine learning space that aims to make machine learning code simple yet efficient. JAX is still a Google and DeepMind research project rather than an official Google product, but it has been used extensively internally and adopted by external ML researchers. This article offers an introduction to JAX, how to install it, and its advantages and capabilities.

What is JAX for Machine Learning?

JAX is a Python library designed for high-performance numerical computing, especially machine learning research. Its API for numerical functions is based on NumPy, the library of functions widely used in scientific computing. JAX accelerates machine learning workloads by using XLA to compile NumPy-style functions for GPUs, and it uses autograd-style automatic differentiation to differentiate native Python and NumPy functions for gradient-based optimization. JAX can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives with ease, all with GPU acceleration. JAX supports both reverse-mode differentiation (backpropagation) and forward-mode differentiation.
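As a quick taste of the autograd side, here is a minimal sketch using a toy polynomial of our own choosing (f and its derivatives below are illustrative, not from the original article); jax.grad can be chained to obtain higher-order derivatives:

import jax

def f(x):
    # Toy scalar function: f(x) = x**3 + 2*x
    return x**3 + 2.0 * x

df = jax.grad(f)     # first derivative: 3*x**2 + 2
d2f = jax.grad(df)   # second derivative: 6*x
d3f = jax.grad(d2f)  # third derivative: 6

print(df(2.0), d2f(2.0), d3f(2.0))  # 14.0 12.0 6.0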

JAX offers superior performance when using GPUs to run your code and a just-in-time (JIT) compilation option to easily speed up large projects, which we will delve into later in this article. 

Think of JAX as a Python Library that modifies NumPy and Python code with function transformations to enable accelerated machine learning. As a general rule, you should use JAX whenever planning to train with GPUs, compute gradients (autograd), or use JIT code compiling.

Why use JAX?

In addition to running on ordinary CPUs, JAX's main strength is that it works seamlessly with accelerators such as GPUs (and TPUs). This gives JAX a significant advantage over similar packages, because GPU parallelization delivers far higher throughput than a CPU for image and vector processing.

This matters because NumPy users routinely build matrices of exceptional size, and GPUs are far more time-efficient at processing data in that form.

That difference is how the JAX library can exceed plain NumPy by 100 times or more in speed, thanks to a few key mechanisms:

  • Vectorization - processing many pieces of data with a single instruction, which provides large speedups for linear algebra computations and machine learning (see the vmap sketch after this list)
  • Code Parallelization - taking serial code that would run on a single processor and distributing it across many. GPUs are preferred here since they have numerous processors specialized for such computations.
  • Automatic Differentiation - simple, composable differentiation that can be chained to evaluate higher-order derivatives with ease.
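Here is a minimal sketch of the vectorization point above; sq_dist is a toy function of our own, and jax.vmap maps it over a leading batch dimension without any hand-written loop:

import jax
import jax.numpy as jnp

def sq_dist(a, b):
    # Squared Euclidean distance between two vectors
    return jnp.sum((a - b) ** 2)

# vmap turns the single-example function into a batched one
batched_sq_dist = jax.vmap(sq_dist)

a = jnp.ones((128, 3))
b = jnp.zeros((128, 3))
print(batched_sq_dist(a, b).shape)  # (128,)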

How to Install JAX

To install the CPU-only version of JAX, which can be useful for local development on a laptop, run:

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

On Linux, it is often necessary to first update pip to a version that supports manylinux2014 wheels.
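Once installed, a quick way to confirm that JAX imports correctly and to see which devices it has found (a CPU device, for this CPU-only install) is:

python -c "import jax; print(jax.devices())"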

pip installation: GPU (CUDA)

To install JAX with both CPU and NVIDIA GPU support, you must first install CUDA and cuDNN if they are not already installed. Unlike many other popular deep learning frameworks, JAX does not bundle CUDA or cuDNN as part of the pip package.

JAX provides pre-built CUDA-compatible wheels for Linux only, requiring CUDA 11.1 or newer and cuDNN 8.0.5 or newer. Other combinations of operating system, CUDA, and cuDNN are possible but require building from source.

  • CUDA 11.1 or newer is required
    • You may be able to use older CUDA versions if you build from source, but there are known bugs in all CUDA versions older than 11.1, so prebuilt binaries are not shipped for them.
  • The supported cuDNN versions for the prebuilt wheels are:
    • cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN installation is new enough since it supports additional functionality.
    • cuDNN 8.0.5 or newer.
  • You must use an NVIDIA driver version that is at least as new as your CUDA toolkit's corresponding driver version. For example, if you have CUDA 11.4 update 4 installed, you must use NVIDIA driver 470.82.01 or newer on Linux. This is a strict requirement because JAX relies on JIT-compiling code; older drivers may lead to failures. You can check your installed driver with nvidia-smi, as shown below.
    • If you need to use a newer CUDA toolkit with an older driver, for example on a cluster where you cannot update the NVIDIA driver easily, you may be able to use the CUDA forward compatibility packages that NVIDIA provides for this purpose.
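You can check the driver currently installed on the machine with nvidia-smi, which prints the driver version in its header:

nvidia-smi

With a suitable driver and CUDA toolkit in place, install the CUDA-enabled wheel: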
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels are only available on Linux.
pip install --upgrade "jax[cuda]" https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

The jaxlib version must correspond to the version of the existing CUDA installation you want to use. You can specify a particular CUDA and cuDNN version for jaxlib explicitly:

pip install --upgrade pip
# Installs the wheel compatible with CUDA >= 11.4 and cuDNN >= 8.2
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Installs the wheel compatible with CUDA >= 11.1 and cuDNN >= 8.0.5
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

You can find your CUDA version with the command:

nvcc --version

Some GPU functionality expects the CUDA installation to be at /usr/local/cuda-X.X, where X.X should be replaced with the CUDA version number (e.g. cuda-11.1). If CUDA is installed elsewhere on your system, you can create a symlink:

sudo ln -s /path/to/cuda /usr/local/cuda-X.X

Comparing JAX to NumPy

Since JAX is essentially an augmented NumPy, their syntax is very similar, and the two can often be used interchangeably. In smaller projects the acceleration JAX provides is negligible in time saved, and plain NumPy may be the simpler choice; the larger your models and data become, the more you should consider JAX.
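As a small sketch of that interchangeability, most NumPy calls have a drop-in jax.numpy counterpart, and arrays can be converted in either direction (the variable names here are just illustrative):

import numpy as np
import jax.numpy as jnp

a_np = np.arange(6.0).reshape(2, 3)   # plain NumPy array on the CPU
a_jax = jnp.asarray(a_np)             # copy onto the default JAX device

print(np.mean(a_np, axis=0))          # same API on both sides
print(jnp.mean(a_jax, axis=0))

back = np.asarray(a_jax)              # JAX arrays convert back to NumPy when needed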

Multiplying two Matrices Using JAX vs NumPy

To clearly illustrate the speed difference between these two libraries, we will use both to multiply two matrices by each other and then check the performance differences between CPU only and GPU. We will also check the performance boost that is caused by the JIT compiler.

To follow along with this tutorial, install and import the JAX and NumPy libraries (from the previous step). You can test your code on sites such as Kaggle or Google Colab. As with any library, you should import JAX by writing the following lines at the beginning of your code:

import jax.numpy as jnp
from jax import random

You can also import the NumPy library in a similar manner:

import numpy as np

Next, we will compare the performance of JAX and NumPy on the CPU and GPU by multiplying two matrices together in Python. For these benchmarks, lower is better.

NumPy on CPU

To begin, we will create a matrix of 5,000 by 5,000 using NumPy and test its performance speed-wise.

import numpy as np

size = 5000
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)

785 ms per loop

A single iteration of this matrix multiplication in NumPy took around 785 ms.

JAX on CPU

Now let's run the same computation, but this time using the JAX library on the CPU.

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
size = 5000
# jax.random requires an explicit PRNG key; the result is float32 by default
x = random.normal(key, (size, size), dtype=jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch so we time the actual work
%timeit jnp.dot(x, x.T).block_until_ready()

1.43 sec per loop

As you can see, comparing JAX and NumPy CPU-only performance shows that NumPy is the faster option. While JAX may not provide the best performance with normal CPUs, it does provide much better performance with GPUs.

JAX with GPU

Now, let's try to create the same 5,000 by 5,000 matrix, this time using JAX with a GPU instead of the regular CPU:

import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
size = 5000

x = random.normal(key, (size, size)).astype(jnp.float32)
%time x_jax = jax.device_put(x)
%time jnp.dot(x_jax, x_jax.T).block_until_ready()
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready()

80.6 ms per loop

As clearly shown, running JAX on a GPU instead of a CPU brings the time down to around 80 ms per loop, roughly 10 times faster than NumPy on the CPU and nearly 18 times faster than JAX on the CPU. The gap grows even larger with bigger matrices.

Just-in-Time Compilation (JIT)

Using the jit transformation, our code is compiled by the XLA compiler so that our functions execute more efficiently.

XLA, short for Accelerated Linear Algebra, is the compiler used by libraries such as JAX and TensorFlow to compile and run code on accelerators like GPUs with greater efficiency. In short, XLA is a domain-specific linear algebra compiler that fuses operations and generates fast device code.
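In practice, jit can be applied either as a decorator or by wrapping an existing function. Here is a minimal sketch with a toy function of our own (scaled_softplus is illustrative, not from the original article):

from jax import jit
import jax.numpy as jnp

@jit
def scaled_softplus(x, scale=2.0):
    # Traced once per input shape/dtype, then compiled by XLA and cached
    return scale * jnp.log1p(jnp.exp(x))

x = jnp.linspace(-3.0, 3.0, 1_000_000)
y = scaled_softplus(x)  # first call compiles; later calls reuse the compiled kernel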

We will test this using a SELU (Scaled Exponential Linear Unit) activation function, implemented as selu_np with NumPy and selu_jax with jax.numpy, and compare the timings of NumPy on a normal CPU against JAX running on a GPU with JIT.

def selu_np(x, alpha=1.67, lmbda=1.05):
  return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

def selu_jax(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
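As a quick sanity check (our own small addition, not part of the original benchmark), the two implementations should agree numerically on the same input:

import numpy as np

x_small = np.random.normal(size=(10,)).astype(np.float32)
print(np.allclose(selu_np(x_small), selu_jax(x_small)))  # True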

NumPy on CPU 

To start with, we will create a vector of size 1,000,000 using the NumPy library.

import numpy as np

x = np.random.normal(size=(1000000,)).astype(np.float32)
%timeit selu_np(x)

8.3 ms per loop

JAX on GPU with JIT

Now we will test our code while using JAX and JIT on a GPU.

import jax
import jax.numpy as jnp
from jax import random
from jax import grad, jit

key = random.PRNGKey(0)

def selu_np(x, alpha=1.67, lmbda=1.05):
  return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

def selu_jax(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))

selu_jax_jit = jit(selu_jax)
%time x_jax = jax.device_put(x) 
%time selu_jax_jit(x_jax).block_until_ready() 
%timeit selu_jax_jit(x_jax).block_until_ready() 

153 µs per loop (0.153 milliseconds per loop)

Finally, using the JIT compiler on a GPU gives far better performance than either un-jitted version. The difference is stark: going from NumPy on the CPU (8.3 ms) to JAX with JIT on the GPU (153 µs) is roughly a 50x speedup, an increase of nearly 5000%.

Think of JAX as a modified NumPy that enables accelerated machine learning on GPUs. Since NumPy itself only runs on the CPU, JAX is the faster option whenever you can execute code on a GPU. As a general rule, use JAX whenever you plan to run NumPy-style code on GPUs or to use JIT compilation.

Note: the examples in this tutorial are adapted from the original article linked here: original code.

JAX Limitation: Pure Functions

JAX transformations and compilation are designed for Python functions that are functionally pure. A pure function does not change program state by reading or writing variables outside its own scope, and it has no side effects such as writing to input/output streams (for example, print()).

When side effects are present, they only run while the function is being traced, so consecutive runs of a transformed function will not behave as you might expect. If you are not careful, untracked side effects can silently throw off the correctness of your computations.
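Here is a minimal sketch of the kind of surprise this causes (impure_scale is an illustrative name of our own): the print side effect runs only while JAX traces the function, not on subsequent calls that reuse the compiled version:

import jax
import jax.numpy as jnp

def impure_scale(x):
    print("tracing...")   # side effect: runs only during tracing
    return 2.0 * x

scaled = jax.jit(impure_scale)

print(scaled(jnp.arange(3.0)))  # prints "tracing..." once, then [0. 2. 4.]
print(scaled(jnp.arange(3.0)))  # no "tracing..." this time; the cached compiled code runs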

Using Google’s JAX 

In this article, we explained what JAX is capable of, the advantages it brings over NumPy, and how to install the library for machine learning work.

We then imported JAX and NumPy and compared the two (NumPy being the most well-known alternative) on regular CPUs and GPUs, along with some JIT tests, and saw drastic speed improvements.

If you are an advanced machine/deep learning practitioner, adding a library such as JAX to your arsenal, with its GPU/TPU acceleration and efficient JIT compiler, will definitely make life much easier.
