# How to add a progress bar to JAX scans and loops

JAX allows you to write optimisers and samplers which are really fast if you use the `scan` or `fori_loop` functions. However, if you write them in this way it’s not obvious how to add a progress bar to your algorithm. This post explains how to make a progress bar using Python’s `print` function as well as using tqdm. After briefly setting up the sampler, we first go over how to create a basic version using Python’s `print` function, and then show how to create a nicer version using tqdm. You can find the code for the basic version here and the code for the tqdm version here.

*Update January 2023: this is now available in a pip-installable package: JAX-tqdm*

# Setup: sampling a Gaussian

We’ll use an Unadjusted Langevin Algorithm (ULA) to sample from a Gaussian to illustrate how to write the progress bar. Let’s start by defining the log-posterior of a d-dimensional Gaussian, and use JAX to get its gradient:

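```
import jax.numpy as jnp
from jax import grad, jit

@jit
def log_posterior(x):
    # log-density of a standard d-dimensional Gaussian (up to a constant)
    return -0.5 * jnp.dot(x, x)

grad_log_post = jit(grad(log_posterior))
```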
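As a quick sanity check (assuming only that JAX is installed), the gradient of `-0.5 * x·x` is `-x`, so `grad_log_post` should simply negate its input. The test vector `x` is arbitrary:

```python
import jax.numpy as jnp
from jax import grad, jit

@jit
def log_posterior(x):
    return -0.5 * jnp.dot(x, x)

grad_log_post = jit(grad(log_posterior))

# For a standard Gaussian the gradient of the log-density is just -x
x = jnp.array([1.0, -2.0, 0.5])
print(grad_log_post(x))
```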

We now define ULA using the `scan` function (see this post for an explanation of the `scan` function).

```
from functools import partial

import jax.numpy as jnp
from jax import jit, lax, random

@partial(jit, static_argnums=(2,))
def ula_kernel(key, param, grad_log_post, dt):
    # one ULA step: gradient drift plus Gaussian noise with variance 2*dt
    key, subkey = random.split(key)
    paramGrad = grad_log_post(param)
    noise_term = jnp.sqrt(2 * dt) * random.normal(key=subkey, shape=param.shape)
    param = param + dt * paramGrad + noise_term
    return key, param

@partial(jit, static_argnums=(1, 2))
def ula_sampler(key, grad_log_post, num_samples, dt, x_0):
    def ula_step(carry, iter_num):
        key, param = carry
        key, param = ula_kernel(key, param, grad_log_post, dt)
        return (key, param), param

    carry = (key, x_0)
    _, samples = lax.scan(ula_step, carry, jnp.arange(num_samples))
    return samples
```
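For reference, each call to `ula_kernel` implements the standard ULA update, with step size $\Delta t$ (`dt`) and a fresh standard-normal vector $\xi_k$ drawn using `subkey`:

$$
x_{k+1} = x_k + \Delta t \, \nabla \log \pi(x_k) + \sqrt{2\Delta t}\,\xi_k, \qquad \xi_k \sim \mathcal{N}(0, I_d)
$$

For the Gaussian above, $\nabla \log \pi(x) = -x$, so the update reduces to $x_{k+1} = (1 - \Delta t)\,x_k + \sqrt{2\Delta t}\,\xi_k$.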

If we add a `print` function in `ula_step` above, it will only be called once: when `ula_sampler` is compiled. This is because printing is a side effect, and compiled JAX functions are pure: Python side effects run while JAX traces the function, but they are not part of the compiled program that actually runs the loop.
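The sketch below makes this concrete (assuming a standard JAX install; `run_scan` is a made-up example). The scan runs for 100 iterations, but the `print` inside the step function only fires while JAX traces it during the first call, and not at all during the second call, which reuses the compiled version:

```python
import contextlib
import io

import jax.numpy as jnp
from jax import jit, lax

@jit
def run_scan(x_0):
    def step(carry, iter_num):
        print("inside step")  # Python side effect: executes while JAX traces `step`
        return carry + 1.0, carry

    final, _ = lax.scan(step, x_0, jnp.arange(100))
    return final

first_call, second_call = io.StringIO(), io.StringIO()
with contextlib.redirect_stdout(first_call):
    final = run_scan(0.0)  # traced and compiled here
with contextlib.redirect_stdout(second_call):
    run_scan(1.0)  # same input types: the cached compiled version is reused
```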

# Basic progress bar

As a workaround, the JAX team has added the `host_callback` module (which is still experimental, so things may change). This module defines functions that allow you to call Python functions from within a JAX function. Here’s how you would use the `id_tap` function to create a progress bar (from this discussion):

```
from jax import jit, lax
from jax.experimental import host_callback

def _print_consumer(arg, transform):
    # called on the host (in Python) with the values passed to id_tap
    iter_num, num_samples = arg
    print(f"Iteration {iter_num:,} / {num_samples:,}")

@jit
def progress_bar(arg, result):
    """
    Print progress of a scan/loop only if the iteration number is a multiple of the print_rate
    Usage: `carry = progress_bar((iter_num + 1, num_samples, print_rate), carry)`
    Pass in `iter_num + 1` so that counting starts at 1 and ends at `num_samples`
    """
    iter_num, num_samples, print_rate = arg
    result = lax.cond(
        iter_num % print_rate == 0,
        lambda _: host_callback.id_tap(_print_consumer, (iter_num, num_samples), result=result),
        lambda _: result,
        operand=None)
    return result
```

The `id_tap` function behaves like the identity function, so calling `host_callback.id_tap(_print_consumer, (iter_num, num_samples), result=result)` will simply return `result`. However, while doing this it will also call `_print_consumer` on the host with `(iter_num, num_samples)` as its first argument, and we’ve defined this function to print the iteration number.

The `result` argument is what makes this work: it creates a data dependency that makes sure the print function gets called at the correct time. This matters because computations in JAX are only run when they are needed, so you need to pass in a variable that changes throughout the algorithm, such as the PRNG key at that iteration.
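Newer JAX releases have deprecated `host_callback`. If you are on such a version, a similar progress bar can be sketched with `jax.debug.callback` instead. This is a swapped-in technique, not the approach used in this post, and it assumes a JAX version that provides `jax.debug.callback` and `jax.effects_barrier`. Because `jax.debug.callback` is treated as a side effect of the compiled program, in a plain `jit`-compiled scan like this one it runs without a `result` being threaded through (the JAX docs warn that transformations such as `vmap` or `grad` can drop, duplicate or reorder such callbacks). The names `make_progress_bar`, `printed` and `run` are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax import lax

printed = []  # hypothetical record of the iterations that were reported

def make_progress_bar(num_samples, print_rate):
    """Hypothetical helper: returns a function to call inside a scan body."""
    def _report(iter_num):
        # runs on the host with a concrete NumPy value for iter_num
        printed.append(int(iter_num))
        print(f"Iteration {int(iter_num):,} / {num_samples:,}")

    def progress_bar(iter_num):
        # only call back to the host when iter_num is a multiple of print_rate
        lax.cond(
            iter_num % print_rate == 0,
            lambda: jax.debug.callback(_report, iter_num),
            lambda: None,
        )

    return progress_bar

num_samples = 100
progress_bar = make_progress_bar(num_samples, print_rate=10)

@jax.jit
def run(x_0):
    def step(carry, iter_num):
        progress_bar(iter_num + 1)  # count from 1 up to num_samples
        return carry + 1.0, None

    final, _ = lax.scan(step, x_0, jnp.arange(num_samples))
    return final

final = run(0.0)
jax.effects_barrier()  # wait until all pending host callbacks have run
```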

Note also that the `_print_consumer` function takes in `arg` (which holds the current iteration number as well as the total number of iterations) and `transform`. This `transform` argument isn’t used here, but `id_tap` always passes it to the consumer (namely: the Python function that gets called), so the consumer must accept it. It holds the sequence of JAX transformations (such as `vmap`) that have been applied to the tap.

Here’s how you would use the progress bar in the ULA sampler:

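```
def ula_step(carry, iter_num):
    # defined inside `ula_sampler`, so `num_samples`, `grad_log_post`, `dt` and
    # `print_rate` (e.g. print_rate = max(1, num_samples // 10)) come from the enclosing scope
    key, param = carry
    key = progress_bar((iter_num + 1, num_samples, print_rate), key)
    key, param = ula_kernel(key, param, grad_log_post, dt)
    return (key, param), param
```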

We passed the `key` into the progress bar, which returns it unchanged. We also set the print rate to be 10% of the number of samples. Note that this would also work for `lax.fori_loop`, except that the first argument of `ula_step` would be the current iteration number.
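For `lax.fori_loop` the body function has the signature `body_fun(i, carry)`: the iteration number comes first and there is no per-step output. The sketch below shows where the progress check goes in that case. It uses `jax.debug.callback` rather than this post’s `host_callback.id_tap` (an assumption: a JAX version that provides it), and `body_fun`, `reported` and the halving update are placeholders:

```python
import jax
import jax.numpy as jnp
from jax import lax

reported = []  # hypothetical sink for reported iteration numbers

def _report(iter_num):
    reported.append(int(iter_num))

def body_fun(i, carry):
    # fori_loop passes the iteration number FIRST, then the carry
    lax.cond(
        (i + 1) % 5 == 0,
        lambda: jax.debug.callback(_report, i + 1),
        lambda: None,
    )
    return carry * 0.5  # placeholder for the real update (e.g. a ULA step)

@jax.jit
def run(x_0):
    return lax.fori_loop(0, 20, body_fun, x_0)

out = run(jnp.float32(1024.0))
jax.effects_barrier()  # make sure every host callback has finished
```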

### Put it in a decorator

We can make this even easier to use by putting the progress bar in a decorator. Note that the decorator takes in `num_samples` as an argument.

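```
def progress_bar_scan(num_samples):
    def _progress_bar_scan(func):
        # print roughly every 10% of the run (every iteration for runs shorter than 10)
        print_rate = max(1, num_samples // 10)

        def wrapper_progress_bar(carry, iter_num):
            # thread iter_num through progress_bar so the print is tied to this iteration
            iter_num = progress_bar((iter_num + 1, num_samples, print_rate), iter_num)
            return func(carry, iter_num)

        return wrapper_progress_bar
    return _progress_bar_scan
```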

Remember that writing a decorator with arguments means writing a function that returns a decorator (which itself is a function that returns a modified version of the main function you care about). See this StackOverflow question about this.
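To see the three layers in isolation, here is a plain-Python sketch with no JAX involved. The decorator factory `log_every` and the helper `python_scan` are hypothetical names; `python_scan` mimics the semantics of `lax.scan` with an ordinary loop:

```python
def log_every(n, logged):
    """Decorator factory mirroring `progress_bar_scan`:
    log_every(n, logged) returns a decorator, which returns a wrapped step."""
    def _decorator(func):
        def wrapper(carry, iter_num):
            if (iter_num + 1) % n == 0:  # count from 1, like the progress bar
                logged.append(iter_num + 1)
            return func(carry, iter_num)
        return wrapper
    return _decorator

def python_scan(f, init, xs):
    """Plain-Python reference semantics of lax.scan (illustration only)."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

logged = []

@log_every(5, logged)  # same as: step = log_every(5, logged)(step)
def step(carry, iter_num):
    return carry + iter_num, carry

final, history = python_scan(step, 0, range(20))
```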

Putting it all together, the result is very easy to use:

```
# grad_log_post and num_samples must be static (num_samples sets the scan length)
@partial(jit, static_argnums=(1, 2))
def ula_sampler_pbar(key, grad_log_post, num_samples, dt, x_0):
    "ULA sampler with progress bar"
    @progress_bar_scan(num_samples)
    def ula_step(carry, iter_num):
        key, param = carry
        key, param = ula_kernel(key, param, grad_log_post, dt)
        return (key, param), param

    carry = (key, x_0)
    _, samples = lax.scan(ula_step, carry, jnp.arange(num_samples))
    return samples
```

Now that we have a progress bar, we might also want to know when the function is compiling (which is especially useful when it takes a while to compile). Here we can use the fact that the `print` function only gets called during compilation. We can add `print("Compiling..")` at the beginning of `ula_sampler_pbar` and add `print("Running:")` at the end. Both of these will then only display the first time the function is run. You can find the code for this sampler here.

# tqdm progress bar

We’ll now use the same ideas to build a fancier progress bar: namely one that uses tqdm. We’ll need to use `host_callback.id_tap` to define a `tqdm` progress bar and then call `tqdm.update` regularly to update it. We’ll also need to close the progress bar once we’re finished, or else `tqdm` will act weirdly. To do this we’ll define a decorator that takes in arguments, just like we did in the case of the simple progress bar.

This decorator defines the tqdm progress bar at the first iteration, updates it every `print_rate` iterations, and finally closes it at the end. You can optionally pass in a message to add at the beginning of the progress bar.
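Outside of JAX, the same define/update/close life cycle looks like this in plain Python. This is a sketch assuming only that tqdm is installed; it uses `total=` rather than wrapping a `range`, and the numbers are arbitrary:

```python
from tqdm import tqdm

num_samples, print_rate = 95, 10

# define the bar once, at the start
bar = tqdm(total=num_samples, desc=f"Running for {num_samples:,} iterations")
for start in range(0, num_samples, print_rate):
    # advance by print_rate, or by whatever is left for the final chunk
    bar.update(min(print_rate, num_samples - start))
# close it at the end so tqdm finalises the display
bar.close()
```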

There are details to make sure the progress bar acts correctly in corner cases, such as when `num_samples` is 20 or fewer, or when it isn’t an exact multiple of `print_rate`. Note also that tqdm is closed at the last iteration only *after* the parameter update is done.

```
from jax import lax
from jax.experimental import host_callback
from tqdm import tqdm

def progress_bar_scan(num_samples, message=None):
    "Progress bar for a JAX scan"
    if message is None:
        message = f"Running for {num_samples:,} iterations"
    tqdm_bars = {}

    if num_samples > 20:
        print_rate = int(num_samples / 20)
    else:
        print_rate = 1  # if you run the sampler for 20 iterations or fewer
    remainder = num_samples % print_rate

    def _define_tqdm(arg, transform):
        tqdm_bars[0] = tqdm(range(num_samples))
        tqdm_bars[0].set_description(message, refresh=False)

    def _update_tqdm(arg, transform):
        tqdm_bars[0].update(arg)

    def _update_progress_bar(iter_num):
        "Updates tqdm progress bar of a JAX scan or loop"
        _ = lax.cond(
            iter_num == 0,
            lambda _: host_callback.id_tap(_define_tqdm, None, result=iter_num),
            lambda _: iter_num,
            operand=None,
        )

        _ = lax.cond(
            # update tqdm every multiple of `print_rate` except at the end
            (iter_num % print_rate == 0) & (iter_num != num_samples-remainder),
            lambda _: host_callback.id_tap(_update_tqdm, print_rate, result=iter_num),
            lambda _: iter_num,
            operand=None,
        )

        _ = lax.cond(
            # update tqdm by `remainder`
            iter_num == num_samples-remainder,
            lambda _: host_callback.id_tap(_update_tqdm, remainder, result=iter_num),
            lambda _: iter_num,
            operand=None,
        )

    def _close_tqdm(arg, transform):
        tqdm_bars[0].close()

    def close_tqdm(result, iter_num):
        return lax.cond(
            iter_num == num_samples-1,
            lambda _: host_callback.id_tap(_close_tqdm, None, result=result),
            lambda _: result,
            operand=None,
        )

    def _progress_bar_scan(func):
        """Decorator that adds a progress bar to `body_fun` used in `lax.scan`.
        Note that `body_fun` must either be looping over `jnp.arange(num_samples)`,
        or be looping over a tuple whose first element is `jnp.arange(num_samples)`.
        This means that `iter_num` is the current iteration number
        """
        def wrapper_progress_bar(carry, x):
            if type(x) is tuple:
                iter_num, *_ = x
            else:
                iter_num = x
            _update_progress_bar(iter_num)
            result = func(carry, x)
            return close_tqdm(result, iter_num)
        return wrapper_progress_bar
    return _progress_bar_scan
```
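The bookkeeping in `_update_progress_bar` above is easy to get wrong, so here is a pure-Python restatement of the same rule that can be used to check the corner cases. `tqdm_update_schedule` is a hypothetical helper, not part of the post’s code; the property to check is that, for any `num_samples`, the amounts passed to `tqdm.update` add up to exactly `num_samples`:

```python
def tqdm_update_schedule(num_samples):
    """Hypothetical helper: pure-Python restatement of the update rule in
    `progress_bar_scan`. Returns {iter_num: amount passed to tqdm.update}."""
    if num_samples > 20:
        print_rate = int(num_samples / 20)
    else:
        print_rate = 1
    remainder = num_samples % print_rate
    updates = {}
    for iter_num in range(num_samples):  # lax.scan feeds 0, 1, ..., num_samples - 1
        if iter_num % print_rate == 0 and iter_num != num_samples - remainder:
            updates[iter_num] = print_rate
        if iter_num == num_samples - remainder:
            updates[iter_num] = remainder
    return updates

# every run should advance the bar by exactly num_samples in total
for n in [1, 7, 20, 21, 100, 137, 1000]:
    schedule = tqdm_update_schedule(n)
    print(n, len(schedule), sum(schedule.values()))
```

For `num_samples=137`, for example, `print_rate` is 6: the bar advances by 6 at iterations 0, 6, ..., 126 and by the remainder of 5 at iteration 132.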

Although this progress bar is more complicated than the previous one, you use it in exactly the same way. You simply add the decorator to the step function used in `lax.scan`, with the number of samples as argument (and optionally the message to print at the beginning of the progress bar).

```
@partial(jit, static_argnums=(1, 2))
def ula_sampler_pbar(key, grad_log_post, num_samples, dt, x_0):
    "ULA sampler with progress bar"
    @progress_bar_scan(num_samples)
    def ula_step(carry, iter_num):
        key, param = carry
        key, param = ula_kernel(key, param, grad_log_post, dt)
        return (key, param), param

    carry = (key, x_0)
    _, samples = lax.scan(ula_step, carry, jnp.arange(num_samples))
    return samples
```

### Conclusion

So we’ve built two progress bars: a basic version and a nicer version that uses tqdm. The code for these is in these two gists: here and here.