Update logic for checking TPUs availability #6767
Changes from all commits
Changes to the TPU spawn plugin (TPUSpawnPlugin):

@@ -14,6 +14,7 @@
 import io
 import os
 import re
+import time
 from typing import Any, Dict, Iterable, List, Optional, Union

 import torch
@@ -23,11 +24,11 @@
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE
+from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
-from pytorch_lightning.utilities.apply_func import apply_to_collection

 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -39,8 +40,7 @@
     xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5

 if _OMEGACONF_AVAILABLE:
-    from omegaconf import OmegaConf
-    from omegaconf import DictConfig, ListConfig
+    from omegaconf import DictConfig, ListConfig, OmegaConf


 class TPUSpawnPlugin(DDPSpawnPlugin):
@@ -118,6 +118,9 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         self.__save_end_of_training_weights(self.lightning_module)
         self.transfer_distrib_spawn_state_on_fit_end(results)

+        if self.global_rank == 0:
+            time.sleep(2)
+
         self.barrier("end-process")

     def __save_end_of_training_weights(self, model: LightningModule) -> None:

Review comment on the added time.sleep(2): "any reason why 2?"
Reply: "Naa."
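The review exchange above leaves the choice of two seconds unexplained; it is simply a fixed grace period before the final barrier. Purely as an illustration of that pattern (not the PR's code, and with every name invented for the example), the sketch below shows a parent process sleeping for a short, fixed interval so that spawned workers can finish publishing their results before they are collected.

```python
# Hypothetical illustration only: a parent process gives spawned workers a
# short, fixed grace period (analogous to the questioned time.sleep(2))
# before collecting their results.
import time

import torch.multiprocessing as mp


def _worker(rank: int, queue) -> None:
    # Simulate ranks finishing at slightly different times.
    time.sleep(0.5 * rank)
    queue.put((rank, f"result-from-rank-{rank}"))


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    queue = ctx.SimpleQueue()
    workers = [ctx.Process(target=_worker, args=(rank, queue)) for rank in range(1, 4)]
    for proc in workers:
        proc.start()

    time.sleep(2)  # fixed grace period, mirroring the value questioned above

    for proc in workers:
        proc.join()
    while not queue.empty():
        print(queue.get())
```

A fixed sleep is simple but arbitrary, which is exactly what the reviewer is pointing at; joining the workers or synchronizing on a barrier is what actually guarantees completion.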
Changes to the XLA device utilities (XLADeviceUtils):

@@ -17,13 +17,10 @@
 import traceback
 from multiprocessing import Process, Queue

-import torch.multiprocessing as mp
-
 from pytorch_lightning.utilities.imports import _XLA_AVAILABLE

 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.xla_multiprocessing as xmp

 #: define waiting time got checking TPU available in sec
 TPU_CHECK_TIMEOUT = 25
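TPU_CHECK_TIMEOUT (25 seconds) bounds how long the availability probe may run, because querying XLA for devices can hang when no TPU runtime answers. The sketch below is not the pl_multi_process implementation (its internals are not part of this diff); it only illustrates the general "run the probe in a subprocess and give up after a timeout" pattern, and all names in it are hypothetical.

```python
# Hypothetical sketch of a timeout-guarded device probe; not PyTorch Lightning code.
import multiprocessing as mp

PROBE_TIMEOUT = 25  # seconds, mirroring TPU_CHECK_TIMEOUT above


def _run_probe(queue, probe) -> None:
    # Execute the probe in the child process and report a boolean result.
    try:
        queue.put(bool(probe()))
    except Exception:
        queue.put(False)


def probe_with_timeout(probe, timeout: float = PROBE_TIMEOUT) -> bool:
    """Run `probe` in a spawned process; treat a hang or a crash as 'unavailable'."""
    # `probe` must be picklable (e.g. a module-level function) because the
    # child process is started with the "spawn" method.
    ctx = mp.get_context("spawn")
    queue = ctx.SimpleQueue()
    child = ctx.Process(target=_run_probe, args=(queue, probe))
    child.start()
    child.join(timeout)
    if child.is_alive():
        # The probe never returned (e.g. no TPU runtime answered), so give up.
        child.terminate()
        child.join()
        return False
    return queue.get() if not queue.empty() else False
```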
@@ -64,23 +61,13 @@ class XLADeviceUtils:
     @pl_multi_process
     def _is_device_tpu() -> bool:
         """
-        Check if device is TPU
+        Check if TPU devices are available

         Return:
-            A boolean value indicating if the xla device is a TPU device or not
+            A boolean value indicating if TPU devices are available
         """
-
-        def _fn(_: int, mp_queue):
-            try:
-                device = xm.xla_device()
-                mp_queue.put(device.type == 'xla')
-            except Exception:
-                mp_queue.put(False)
-
-        smp = mp.get_context("spawn")
-        queue = smp.SimpleQueue()
-        xmp.spawn(_fn, args=(queue, ), nprocs=1)
-        return queue.get()
+        return len(xm.get_xla_supported_devices("TPU")) > 0

     @staticmethod
     def xla_available() -> bool:

Review comment on the new one-line check: "Really nice !"
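Taken on its own, the new check simply asks the XLA runtime whether it can see any TPU devices. Below is a minimal standalone sketch of that idea, assuming torch_xla is installed; the tpu_devices_available wrapper is an invented name, and only the xm.get_xla_supported_devices("TPU") call comes from the diff itself.

```python
# Standalone sketch of the new availability check; only the
# xm.get_xla_supported_devices("TPU") call is taken from the diff above.
def tpu_devices_available() -> bool:
    try:
        import torch_xla.core.xla_model as xm
    except ImportError:
        # torch_xla is not installed, so no TPU can be used.
        return False
    # An empty list of supported TPU devices means no TPU is attached.
    return len(xm.get_xla_supported_devices("TPU")) > 0


if __name__ == "__main__":
    print(f"TPU available: {tpu_devices_available()}")
```

Unlike the decorated _is_device_tpu, this sketch has no timeout guard around the query; that protection is what the pl_multi_process wrapper and TPU_CHECK_TIMEOUT provide.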
Closing review comment on the pull request: "Great !"