From 2921b97fd8e106b73295b8dd3b4ae93b4dbce315 Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Wed, 5 Mar 2025 04:54:57 +0000 Subject: [PATCH] Fix Tokenizer issue for long running checkpointing config --- MaxText/metric_logger.py | 4 +- benchmarks/maxtext_trillium_model_configs.py | 20 +++- benchmarks/maxtext_xpk_runner.py | 4 +- .../recipes/pw_mcjax_benchmark_recipe.py | 94 ++++++++++++++++--- 4 files changed, 104 insertions(+), 18 deletions(-) diff --git a/MaxText/metric_logger.py b/MaxText/metric_logger.py index 9be1e5eb7..b3f02b40d 100644 --- a/MaxText/metric_logger.py +++ b/MaxText/metric_logger.py @@ -70,8 +70,8 @@ def write_metrics(self, running_gcs_metrics, metrics, step, is_training=True): if self.config.enable_tensorboard: self.write_metrics_to_tensorboard(metrics_to_write, steps_to_write, is_training) - if self.config.metrics_file: - self.write_metrics_locally(metrics_to_write, steps_to_write) + # if self.config.metrics_file: + # self.write_metrics_locally(metrics_to_write, steps_to_write) if self.config.gcs_metrics and jax.process_index() == 0: running_gcs_metrics = self.write_metrics_for_gcs(metrics_to_write, steps_to_write, running_gcs_metrics, is_training) diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index 064ba7a01..3ca6706c8 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -31,7 +31,7 @@ "checkpoint_storage_use_zarr3": False, "enable_pathways_goodput": True, "enable_single_controller": True, - "metrics_file": "metrics.txt", + # "metrics_file": "metrics.txt", "goodput_upload_interval_seconds": 30, } @@ -1022,7 +1022,8 @@ def _add_to_model_dictionary( "enable_checkpointing": True, "async_checkpointing": True, "checkpoint_period": 20, - "enable_checkpoint_cloud_logger": True, + # "enable_checkpoint_cloud_logger": True, + "enable_checkpoint_cloud_logger": False, "sa_block_q": 2048, "sa_block_kv": 2048, "sa_block_kv_compute": 2048, @@ -1036,6 +1037,7 @@ def _add_to_model_dictionary( "profiler": "xplane", "skip_first_n_steps_for_profiler": 10, "profiler_steps": 5, + "tokenizer_path": "assets/tokenizer_llama3.tiktoken", }, xla_flags=( xla_flags_library.DENSE_VMEM_LIMIT_FLAG @@ -1081,10 +1083,12 @@ def _add_to_model_dictionary( "use_iota_embed": True, "dataset_path": "gs://max-datasets-rogue", "dataset_type": "synthetic", - "enable_checkpointing": True, + # "enable_checkpointing": True, + "enable_checkpointing": False, "async_checkpointing": True, "checkpoint_period": 20, - "enable_checkpoint_cloud_logger": True, + # "enable_checkpoint_cloud_logger": True, + "enable_checkpoint_cloud_logger": False, "sa_block_q": 2048, "sa_block_kv": 2048, "sa_block_kv_compute": 2048, @@ -1098,6 +1102,14 @@ def _add_to_model_dictionary( "profiler": "xplane", "skip_first_n_steps_for_profiler": 10, "profiler_steps": 5, + "tokenizer_path": "assets/tokenizer_llama3.tiktoken", + }, + pathways_tuning_params={ + "enable_pathways_goodput": False, + "monitor_goodput": False, + "enable_tensorboard": False, + "use_vertex_tensorboard": False, + "enable_checkpoint_cloud_logger": False, }, xla_flags=( xla_flags_library.DENSE_VMEM_LIMIT_FLAG diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index 046d93072..b0be401c7 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -59,6 +59,7 @@ class PathwaysConfig: server_flags: str = '' proxy_flags: str = '' worker_flags: str = '' + pathways_gcs_location: str = '' # TODO(@vbarr): 
Split out parameters related to XPK workload and a General workload @@ -79,6 +80,7 @@ class WorkloadConfig: xpk_path: str = '~/xpk' pathways_config: PathwaysConfig = None run_name: str = None + use_gcsfuse: bool = False @dataclasses.dataclass @@ -513,7 +515,7 @@ def _get_pathways_specific_flags(wl_config: WorkloadConfig): f' {proxy_server_image_flag} ' f' {remote_python_sidecar_image_flag} ' f' --termination-grace-period-seconds=300 ' - f' --pathways-gcs-location={wl_config.base_output_directory} ' + f' --pathways-gcs-location={pw_config.pathways_gcs_location if pw_config.pathways_gcs_location is not None else wl_config.base_output_directory} ' f' --custom-pathways-server-args="{server_flags}" ' f' --custom-pathways-proxy-server-args="{proxy_flags}" ' f' --custom-pathways-worker-args="{worker_flags}" ' diff --git a/benchmarks/recipes/pw_mcjax_benchmark_recipe.py b/benchmarks/recipes/pw_mcjax_benchmark_recipe.py index 611536d8a..1ae8a9ede 100644 --- a/benchmarks/recipes/pw_mcjax_benchmark_recipe.py +++ b/benchmarks/recipes/pw_mcjax_benchmark_recipe.py @@ -41,11 +41,67 @@ # Other parameters (MUST BE SET BY USER) XPK_PATH = "../xpk" # We're running this script from the maxtext directory USER = os.environ["USER"] + +MAX_RESTARTS = 1_000 +BENCHMARK_STEPS = 10_000 +# USE_GCSFUSE = True +USE_GCSFUSE = False + +num_slices_list = [ + 48 +] + BASE_OUTPUT_DIRECTORY = ( f"gs://{USER}-{PROJECT}-{COUNTRY}/pw_mcjax_benchmarking/" ) +GCSFUSE_BASE_OUTPUT_DIRECTORY = "" + +if USE_GCSFUSE: + GCSFUSE_BASE_OUTPUT_DIRECTORY = BASE_OUTPUT_DIRECTORY + BASE_OUTPUT_DIRECTORY = ( + f"/tmp/gscfuse/pw_mcjax_benchmarking/" + ) + +################################################################################ + +PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/sujinesh/unsanitized_proxy_server:latest" +SERVER_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/sujinesh/unsanitized_server:latest" +RUNNER = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest" +RUNNER = "gcr.io/cloud-tpu-multipod-dev/sujinesh_latest:latest" + +# Sustained Capacity Cluster +CLUSTER = "bodaborg-v6e-256-dnd-yucmhab" +PROJECT = "tpu-prod-env-one-vm" +ZONE = "us-east5-b" +COUNTRY = "us" +DEVICE_TYPE = "v6e-256" + +# Debug Cluster +CLUSTER = "bodaborg-v6e-16-debug" +PROJECT = "tpu-prod-env-one-vm" +ZONE = "us-east5-b" +COUNTRY = "us" +DEVICE_TYPE = "v6e-16" + +# High scale cluster +CLUSTER = "bodaborg-v6e-256-ts" +PROJECT = "tpu-prod-env-multipod" +ZONE = "us-west1-c" +COUNTRY = "us" +DEVICE_TYPE = "v6e-256" + +BASE_OUTPUT_DIRECTORY = ( + f"gs://trillium-scale-tests-q1-25-west/pw_mcjax_benchmarking/{USER}/suspend_resume_test/" +) +GCSFUSE_BASE_OUTPUT_DIRECTORY = "" -BENCHMARK_STEPS = 20 +if USE_GCSFUSE: + GCSFUSE_BASE_OUTPUT_DIRECTORY = BASE_OUTPUT_DIRECTORY + BASE_OUTPUT_DIRECTORY = ( + f"/tmp/gscfuse/pw_mcjax_benchmarking/" + ) + +################################################################################ def main() -> int: @@ -71,12 +127,16 @@ def main() -> int: # model_configs.llama3_1_70b_8192, # model_configs.llama3_1_405b_8192_fsdp_dcn, # model_configs.llama2_70b_4096_real_data_long_run, + model_configs.llama3_1_70b_8192_iter_real_data_and_checkpointing_tfds, + # model_configs.llama3_1_70b_8192_iter_synth_data_and_checkpointing, ], "pathways": [ - model_configs.llama3_1_8b_8192, + # model_configs.llama3_1_8b_8192, # model_configs.llama3_1_70b_8192, # model_configs.llama3_1_405b_8192_fsdp_dcn, # model_configs.llama2_70b_4096_real_data_long_run, + # 
model_configs.llama3_1_70b_8192_iter_real_data_and_checkpointing_tfds, + # model_configs.llama3_1_70b_8192_iter_synth_data_and_checkpointing, ] } pathways_config = mxr.PathwaysConfig( @@ -85,13 +145,14 @@ def main() -> int: runner_image=RUNNER, # User can add additional flags here. - server_flags="", - proxy_flags="", - worker_flags="", + server_flags="--enable_metrics_collection=true", + proxy_flags="--enable_metrics_collection=true", + worker_flags="--enable_metrics_collection=true", + + pathways_gcs_location=( + GCSFUSE_BASE_OUTPUT_DIRECTORY if USE_GCSFUSE else None + ), ) - num_slices_list = [ - 2 - ] xpk_workload_cmds = [] xpk_workload_names = [] @@ -109,6 +170,13 @@ def main() -> int: model.tuning_params["vertex_tensorboard_project"] = PROJECT model.tuning_params["vertex_tensorboard_region"] = REGION + # model.tuning_params["monitor_goodput"] = False + # model.tuning_params["enable_tensorboard"] = False + # model.tuning_params["use_vertex_tensorboard"] = False + + if USE_GCSFUSE: + model.tuning_params["dataset_path"] = "/tmp/dataset" + # Run workloads in the following slice configurations for num_slices in num_slices_list: wl_config = mxr.WorkloadConfig( @@ -117,14 +185,15 @@ def main() -> int: device_type=cluster_config.device_type, base_output_directory=BASE_OUTPUT_DIRECTORY + f"{infra}_{num_slices}_slice_{DEVICE_TYPE}_{model.model_name}/", - max_restarts=0, + max_restarts=MAX_RESTARTS, libtpu_type=None, libtpu_nightly_version="", base_docker_image=RUNNER if infra == "mcjax" else None, pathways_config=pathways_config if infra == "pathways" else None, xpk_path=XPK_PATH, num_steps=BENCHMARK_STEPS, - priority="low", + priority="high", + use_gcsfuse=USE_GCSFUSE, ) command, name = mxr.generate_xpk_workload_cmd( cluster_config=cluster_config, wl_config=wl_config @@ -140,7 +209,10 @@ def main() -> int: xpk_workload_names, xpk_workload_cmds ): timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - print(f"[{timestamp}] Running workload: {xpk_workload_name} with command: {xpk_workload_cmd}") + print( + f"[{timestamp}] Running workload: {xpk_workload_name} with command:" + f" {xpk_workload_cmd}" + ) return_code = mxr.run_command_with_updates( xpk_workload_cmd, xpk_workload_name )
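
Notes for reviewers: below is a minimal, self-contained sketch of the flag-resolution behaviour this patch introduces in _get_pathways_specific_flags and exercises from the recipe. It is not part of the diff: the dataclasses are reduced to the fields the patch touches, resolve_pathways_gcs_location is a hypothetical helper written only for illustration, and the bucket and mount paths are placeholders. The idea shown is that --pathways-gcs-location now prefers PathwaysConfig.pathways_gcs_location and falls back to WorkloadConfig.base_output_directory; because the new field defaults to '', a truthiness check (used in the sketch) also falls back when the field is left unset, whereas the patch's `is not None` check would pass an empty string through.

    import dataclasses


    @dataclasses.dataclass
    class PathwaysConfig:
        # Only the fields touched by this patch; the real class has more.
        server_flags: str = ''
        proxy_flags: str = ''
        worker_flags: str = ''
        pathways_gcs_location: str = ''


    @dataclasses.dataclass
    class WorkloadConfig:
        # Reduced stand-in for benchmarks/maxtext_xpk_runner.py's WorkloadConfig.
        base_output_directory: str = ''
        use_gcsfuse: bool = False
        pathways_config: PathwaysConfig = None


    def resolve_pathways_gcs_location(wl_config: WorkloadConfig) -> str:
        """Illustrative helper (not in the patch): pick the value passed to
        --pathways-gcs-location, falling back to the pre-patch behaviour of
        reusing the workload's base output directory."""
        pw_config = wl_config.pathways_config
        # Truthiness check: falls back both when the field is None and when it
        # is left at its '' default. The patch itself checks `is not None`.
        if pw_config is not None and pw_config.pathways_gcs_location:
            return pw_config.pathways_gcs_location
        return wl_config.base_output_directory


    if __name__ == "__main__":
        # GCSFuse-style run (placeholder paths): MaxText output goes to the
        # local FUSE mount, while Pathways still receives a real gs:// path.
        pw = PathwaysConfig(
            pathways_gcs_location="gs://example-bucket/pw_mcjax_benchmarking/"
        )
        wl = WorkloadConfig(
            base_output_directory="/tmp/gcsfuse/pw_mcjax_benchmarking/",
            use_gcsfuse=True,
            pathways_config=pw,
        )
        print(resolve_pathways_gcs_location(wl))
        # -> gs://example-bucket/pw_mcjax_benchmarking/

        # Plain run with no explicit Pathways location: fall back to the
        # workload's base output directory, as before this patch.
        wl_plain = WorkloadConfig(
            base_output_directory="gs://example-bucket/pw_mcjax_benchmarking/",
            pathways_config=PathwaysConfig(),
        )
        print(resolve_pathways_gcs_location(wl_plain))
        # -> gs://example-bucket/pw_mcjax_benchmarking/

This mirrors how the recipe wires things up when USE_GCSFUSE is set: the original gs:// directory is kept for pathways_gcs_location, while base_output_directory (and dataset_path) are redirected to the local FUSE mount.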