Poc elastic training #1310
base: main
506da52 to 0c975ce
A few initial comments
MaxText/train.py
Outdated
Could you move this into a separate elastic_train.py file, similar to sft_trainer.py? See https://github.com/AI-Hypercomputer/maxtext/blob/sft-maxtext/MaxText/sft_trainer.py
MaxText/train.py
Outdated
@elasticutils.timeit
def reshard_fn(config: pyconfig.HyperParameters):
  """Reshard function."""
  while True:
Could you bound this by a maximum number of reshards instead of using an infinite loop?
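To illustrate the suggestion, here is a minimal sketch of a bounded retry loop. It is not the code from this PR: `reshard_with_limit`, `ReshardError`, and the `attempt` callable are hypothetical names chosen for illustration, and `max_reshard_retry_count` is assumed to be an integer supplied by the caller (for example from the config).

```python
from typing import Callable, TypeVar

T = TypeVar("T")


class ReshardError(RuntimeError):
  """Hypothetical error type standing in for a failed reshard attempt."""


def reshard_with_limit(attempt: Callable[[], T], max_reshard_retry_count: int) -> T:
  """Calls `attempt` until it succeeds or the retry budget is exhausted."""
  if max_reshard_retry_count < 1:
    raise ValueError("max_reshard_retry_count must be at least 1")
  last_error = None
  # A bounded for-loop replaces `while True`, so a reshard/failure cycle
  # cannot spin forever.
  for _ in range(max_reshard_retry_count):
    try:
      return attempt()
    except ReshardError as e:
      last_error = e  # Remember the failure and try again.
  # Every attempt failed: surface a hard error, chained to the last failure.
  raise RuntimeError(
      f"giving up after {max_reshard_retry_count} reshard attempts"
  ) from last_error
```

With this shape, the caller decides what a hard failure means (for example a fatal log), while the loop itself only enforces the bound.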
0c975ce to 28044f7
Adding ElasticUtils to config
Added elasticutils and gkeutils
Added a watchdog/timebomb to each step
Completely working test run. Further test runs and optimizations to follow
Fixed a bug if DATA_LOSS occurs during a save
Added timeit to reshard_fn
Updated the watchdog to repeatedly dump a stack trace at every timeout interval.
Send a fatal log if failures > max failures in slice_down() instead of in the training loop, in order to fail correctly if there is a reshard/failure loop within the reshard handler
Updated elasticutils and added a fake elasticutils
Working host memory offloading.
Added a max reshard retry count
Updated elasticutils to how it will be structured in pathwaysutils
Delete the checkpoint if we are rewinding to before it was saved.
Fixed a bug in put_array.
Added max_reshard_failure_count.
Added save nbytes log
Adding elastic trainer
28044f7 to 9fdab26
…ynchronized checkpointing might
… the number of values returned from reshard_fn
…ext manager of the old mesh when resharding up
Description
Start with a short description of what the PR does and how this is a change from
the past.
The rest of the description includes relevant details and context, examples:
If the change fixes a bug or a GitHub issue, please include a link, e.g.:
FIXES: b/123456
FIXES: #123456
Tests
Please describe how you tested this change, and include any instructions and/or
commands to reproduce.
Checklist
Before submitting this PR, please make sure (put X in square brackets):