
step_prefix cannot contain _: CheckpointManager does not recognize multiple _ #1499

Open
scott-yj-yang opened this issue Jan 14, 2025 · 2 comments
Labels: checkpoint, type:bug (Something isn't working)


scott-yj-yang commented Jan 14, 2025

Bug Description:

When I create a CheckpointManager with options like the following:

import orbax.checkpoint as ocp

options = ocp.CheckpointManagerOptions(step_prefix="ppo_networks")
with ocp.CheckpointManager(
    ".../model_checkpoints/7358284e-a603-453f-9024-f69a27a293c4",
    options=options,
) as mngr:
    mngr.restore(0)

and my checkpoint directory looks like this:

[Screenshot of the checkpoint directory]

instantiating the manager object raises the following ValueError:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[13], line 5
      1 import orbax.checkpoint as ocp
      4 options = ocp.CheckpointManagerOptions(step_prefix="ppo_networks")
----> 5 with ocp.CheckpointManager(
      6     "/root/vast/scott-yang/track-mjx/model_checkpoints/7358284e-a603-453f-9024-f69a27a293c4",
      7     options=options,
      8 ) as mngr:
      9     mngr.restore(0)

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py:685, in CheckpointManager.__init__(self, directory, checkpointers, options, metadata, item_names, item_handlers, logger, handler_registry)
    675   self._cleanup_tmp_directories()
    677 self._step_name_format = (
    678     self._options.step_name_format
    679     or step_lib.standard_name_format(
   (...)
    682     )
    683 )
--> 685 self._checkpoints = self._load_checkpoint_infos()
    687 self._metadata_checkpointer = Checkpointer(
    688     JsonCheckpointHandler(
    689         multiprocessing_options=self._multiprocessing_options
   (...)
    694     temporary_path_class=self._options.temporary_path_class,
    695 )
    696 if self._options.read_only and not self._metadata_path().exists():

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py:1431, in CheckpointManager._load_checkpoint_infos(self)
   1423 """Loads a list of CheckpointInfo for existing checkpoints.
   1424 
   1425 If none are present, returns empty list.
   (...)
   1428   a list of CheckpointInfo, sorted by increasing step.
   1429 """
   1430 start = time.time()
-> 1431 steps = utils.checkpoint_steps(
   1432     self.directory, self._options.single_host_load_and_broadcast
   1433 )
   1434 steps.sort()  # Prefer in-place sort.
   1436 if not steps:

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:698, in checkpoint_steps(checkpoint_dir, single_host_load_and_broadcast)
    696   padded_step_list = multihost.broadcast_one_to_all(padded_step_list)
    697   return [step for step in padded_step_list if step >= 0]
--> 698 return _checkpoint_steps(checkpoint_dir)

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:682, in checkpoint_steps.<locals>._checkpoint_steps(path)
    681 def _checkpoint_steps(path: epath.Path) -> List[int]:
--> 682   return [
    683       step_from_checkpoint_name(s.name) for s in checkpoint_steps_paths(path)
    684   ]

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:683, in <listcomp>(.0)
    681 def _checkpoint_steps(path: epath.Path) -> List[int]:
    682   return [
--> 683       step_from_checkpoint_name(s.name) for s in checkpoint_steps_paths(path)
    684   ]

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:645, in step_from_checkpoint_name(name)
    643 elif tmp_match := re.match(TMP_DIR_STEP_PATTERN, name):
    644   return int(tmp_match.group(1))
--> 645 raise ValueError(f'Unrecognized name format: {name}.')

ValueError: Unrecognized name format: ppo_networks_1024000.

Looking specifically at step_from_checkpoint_name in orbax/checkpoint/path/step.py:

def step_from_checkpoint_name(name: str) -> int:
  """Returns the step from a checkpoint name. Also works for tmp checkpoints."""
  if name.isdigit():
    return int(name)
  elif name.split('_')[-1].isdigit():
    split = name.split('_')
    if len(split) == 2 and split[0]:
      return int(split[-1])
  elif tmp_match := re.match(TMP_DIR_STEP_PATTERN, name):
    return int(tmp_match.group(1))
  raise ValueError(f'Unrecognized name format: {name}.')
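The non-temporary branches of the function above can be reproduced in isolation to show why ppo_networks_1024000 is rejected while pponetworks_1024000 is accepted. This is a standalone sketch, not the orbax module itself: the real TMP_DIR_STEP_PATTERN regex is not shown in this issue, so the temporary-directory branch is left out.

```python
def step_from_checkpoint_name(name: str) -> int:
    """Mirrors the plain-integer and '<prefix>_<step>' branches quoted above.

    The tmp-directory branch (TMP_DIR_STEP_PATTERN) is omitted because its
    regex is not shown in the issue.
    """
    if name.isdigit():
        return int(name)
    elif name.split('_')[-1].isdigit():
        split = name.split('_')
        # Only names with exactly two '_'-separated fields are accepted.
        if len(split) == 2 and split[0]:
            return int(split[-1])
    # Three or more fields fall through to here instead of returning a step.
    raise ValueError(f'Unrecognized name format: {name}.')


print(step_from_checkpoint_name('pponetworks_1024000'))  # 1024000
print('ppo_networks_1024000'.split('_'))  # ['ppo', 'networks', '1024000']
try:
    step_from_checkpoint_name('ppo_networks_1024000')
except ValueError as e:
    print(e)  # Unrecognized name format: ppo_networks_1024000.
```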
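step_from_checkpoint_name assumes that splitting the name on _ yields exactly two parts. With step_prefix="ppo_networks", the step directory is named ppo_networks_1024000, which splits into ['ppo', 'networks', '1024000'], so the name falls through to the ValueError. The prefix needs input validation.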
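One way that validation could look is sketched below, assuming it would run when CheckpointManagerOptions is constructed so the error surfaces before any checkpoint is written. validate_step_prefix is a hypothetical helper for illustration, not part of the orbax API.

```python
from typing import Optional


def validate_step_prefix(step_prefix: Optional[str]) -> None:
    """Hypothetical check: reject prefixes the current parser cannot round-trip."""
    if step_prefix is None:
        return  # No prefix configured; nothing to validate.
    if '_' in step_prefix:
        # '<prefix>_<step>' is parsed by splitting on '_' and expecting exactly
        # two fields, so an underscore inside the prefix breaks restoring.
        raise ValueError(
            f"step_prefix={step_prefix!r} must not contain '_': checkpoint "
            "names are parsed by splitting on '_' into exactly two fields."
        )


validate_step_prefix('pponetworks')  # passes silently
validate_step_prefix(None)           # passes silently
try:
    validate_step_prefix('ppo_networks')
except ValueError as e:
    print(e)
```

Failing fast here turns the confusing "Unrecognized name format" error at restore time into an explicit message at configuration time.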

niketkumar (Collaborator):

Thank you for reporting this issue. We are working on the fix.

As a workaround, would renaming the prefix to something without an underscore, such as pponetworks, work for you?
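For comparison, a parser that tolerates underscores in the prefix could take the step from the last underscore-separated field only. The function below, step_from_checkpoint_name_rsplit, is a hypothetical sketch and not necessarily how the maintainers will fix this; like the reproduction of the quoted code, it handles only plain-integer and <prefix>_<step> names, not temporary directories.

```python
def step_from_checkpoint_name_rsplit(name: str) -> int:
    """Hypothetical parser: the step is whatever follows the last '_'."""
    if name.isdigit():
        return int(name)
    # rpartition splits on the LAST '_', so the prefix may contain '_' itself.
    prefix, sep, step = name.rpartition('_')
    # Require a separator, a non-empty prefix, and an all-digit step field.
    if sep and prefix and step.isdigit():
        return int(step)
    raise ValueError(f'Unrecognized name format: {name}.')


print(step_from_checkpoint_name_rsplit('ppo_networks_1024000'))  # 1024000
print(step_from_checkpoint_name_rsplit('pponetworks_1024000'))   # 1024000
```

A stricter variant could also check that the parsed prefix equals the configured step_prefix, so unrelated directories such as foo_123 are not mistaken for checkpoints.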

niketkumar self-assigned this on Feb 6, 2025
scott-yj-yang (Author):

Thank you for your reply. Yes, I am currently using PPONetworks as the prefix to work around this.

selamw1 added the type:bug (Something isn't working) and checkpoint labels on Feb 11, 2025