TL;DR What are the best practices for progress bars within shard_map? (I.e. monitoring fori_loops over multiple GPUs.) Single-device progress bars seem to work out of the box (see below); however, ordered effects within JAX seem to break them in the multi-GPU case.
Hi all,
I have a quick question about progress bars within JAX. I've seen there's been some discussion (e.g. #13126) regarding progress bars, but only for single devices. #13126 links an example gist showing how to write a progress bar in JAX; however, it fails in the multi-device case, and I was wondering if anyone knew of a fix?
I've attached a minimal reproducible example at the bottom of this post.
The error message states it fails due to ordered effects, which were mentioned in #26087 (and fixed in #26275, at least for lax.custom_linear_solve). So perhaps this has already been solved in 0.4.51?
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "main.py", line 158, in <module>
shard_loop(xs)
ValueError: The following ordered effects are not supported for more than 1 device: [<jax._src.debugging.OrderedDebugEffect object at 0x75381f5d5930>]
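For what it's worth, the shard_map + ordered-callback combination alone seems to be enough to hit this, without any of the progress bar machinery. Here's a stripped-down sketch (not my actual code; I'd expect it to raise the same ValueError whenever more than one device is visible):
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=('batch',))

def f(x):
    # An ordered host callback inside the sharded function.
    jax.debug.callback(print, x, ordered=True)
    return x

g = shard_map(f, mesh=mesh, in_specs=P('batch'), out_specs=P('batch'),
              check_rep=False)
g(jnp.arange(jax.device_count()))  # ValueError on >1 device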
If I set ordered=False in progress, the script runs, but no progress bar is shown and it just prints the output values.
My initial attempt at a progress bar was to use io_callback or debug.callback and just print to the terminal (via the logging library); however, this significantly slows down my code (so perhaps it's syncing with the CPU?). Roughly, it looked like the sketch below.
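(A minimal sketch, not my exact code; loop_with_logging and _report are just illustrative names, and total mirrors the MRE below.)
import logging

import jax
from jax import lax

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("progress")

def _report(i, total):
    # Runs host-side; `i` arrives as a NumPy scalar.
    logger.info("step %d / %d", int(i), int(total))

@jax.jit
def loop_with_logging(x):
    total = 100_000
    def body(i, x):
        # A host callback on every iteration: presumably the source of the
        # slowdown, since each call has to ship `i` back to the host.
        jax.debug.callback(_report, i, total)
        return x + 1
    return lax.fori_loop(0, total, body, x)

loop_with_logging(0)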
What are the best practices for progress bars in JAX as of 0.4.50? I only ask because I have jobs that require multiple GPUs (and can take a few days to run), and I'd like an estimate of the walltime during the job (if that's at all possible!).
Any help is appreciated!
Here's the minimal reproducible example:
"""
Module for the JAX progress bar.
Code modified from: https://gist.github.com/sharadmv/be24b3107cf9b8bf027ea8e2f177882e
"""
import abc
import threading
import types
from typing import Any, Callable, List, Optional, Set, Tuple, Type
import os
os.environ["XLA_FLAGS"] = (
'--xla_force_host_platform_device_count=4 '
)
import jax
from jax import lax, jit, effects_barrier, pure_callback
from jax import numpy as jnp
from jax.sharding import Mesh
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P
import rich.console
import rich.live
import rich.progress
import time
_ProgressCallable = Callable[[float, str, bool], Any]
class Progress(metaclass=abc.ABCMeta):
    """Abstract class for implementing a progress bar."""

    def __init__(self, total: float, description: str, ordered: bool):
        self._total = total
        self._description = description
        self._ordered = ordered

    @abc.abstractmethod
    def _start(self):
        pass

    @abc.abstractmethod
    def _update(self, value: float):
        pass

    @abc.abstractmethod
    def _stop(self):
        pass

    def start(self):
        jax.debug.callback(self._start, ordered=self._ordered)

    def update(self, value):
        jax.debug.callback(self._update, value, ordered=self._ordered)

    def stop(self):
        jax.debug.callback(self._stop, ordered=self._ordered)

    def __enter__(self):
        self.start()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[types.TracebackType],
    ):
        self.stop()
class RichProgress(Progress):
    """Progress bar implemented using the `rich` library."""

    _live = rich.live.Live()
    _active_progress: List[rich.progress.Progress] = []

    def __init__(self, total: float, description: str, ordered: bool):
        super().__init__(total=total, description=description, ordered=ordered)
        self._progress = rich.progress.Progress(auto_refresh=False)
        self._task = self._progress.add_task(self._description, total=total)
        self._lock = threading.Lock()

    def _start(self):
        if not self._active_progress:
            self._live.start()
        self._active_progress.append(self._progress)
        self._live.update(rich.console.Group(*self._active_progress))

    def _update(self, value):
        with self._lock:
            self._progress.update(self._task, completed=value)
            self._progress.refresh()

    def _stop(self):
        self._active_progress.pop()
        self._live.update(rich.console.Group(*self._active_progress))
        if not self._active_progress:
            self._live.stop()
_progress_bars: Set[Tuple[int, _ProgressCallable]] = set()

def get_progress_bar() -> _ProgressCallable:
    # Return the registered progress bar with the highest priority.
    progress_bars = sorted(_progress_bars, key=lambda x: -x[0])
    for _, progress_bar in progress_bars:
        return progress_bar
    assert False, 'Should have hit default progress_bar'

def register_progress_bar(progress_bar, priority):
    _progress_bars.add((priority, progress_bar))
    return progress_bar

def _rich_progress(total: float, description: str, ordered: bool):
    return RichProgress(total=total, description=description, ordered=ordered)

register_progress_bar(_rich_progress, -1)

def progress(total: float,
             description: str = 'Progress:',
             ordered: bool = True):
    return get_progress_bar()(total, description, ordered)
@jit
def simple_loop(x):
    # Create a single progress bar covering all `total` iterations.
    total = 100_000
    # ordered=True triggers the multi-device error below; with ordered=False
    # the script runs but no progress bar is shown.
    with progress(total=total, description="Loop progress", ordered=True) as pbar:
        def loop_body(i, x):
            # Update the progress bar to the current iteration count.
            pbar.update(i)
            return x + 1
        # Execute the loop from 0 to `total`.
        return lax.fori_loop(0, total, loop_body, x)
# 1-device fori_loop in JAX.
out_1d = simple_loop(0)
out_1d = out_1d.block_until_ready()
print('output (1 device): ', out_1d)

# Multi-device fori_loop in JAX.
mesh_shape = (4, 1)
local_devices = jax.devices()
devices = mesh_utils.create_device_mesh(mesh_shape=mesh_shape,
                                        devices=local_devices)
devices = devices.reshape(mesh_shape)
mesh = Mesh(devices=devices, axis_names=('batch', 'model'))
xs = jnp.arange(4)

# Sharded function.
shard_loop = shard_map(simple_loop,
                       mesh=mesh,
                       in_specs=P('batch'),
                       out_specs=P('batch'),
                       check_rep=False)
out_4d = shard_loop(xs)  # Run the multi-device fori_loop here. [ERROR]
out_4d = out_4d.block_until_ready()
print('output (4 devices): ', out_4d)