Update logic for checking TPUs availability #6767

Merged · 3 commits · Mar 31, 2021
Changes from all commits
pytorch_lightning/plugins/training_type/tpu_spawn.py (7 additions, 4 deletions)
@@ -14,6 +14,7 @@
import io
import os
import re
+import time
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
@@ -23,11 +24,11 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE
+from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
-from pytorch_lightning.utilities.apply_func import apply_to_collection

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm
@@ -39,8 +40,7 @@
    xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5

if _OMEGACONF_AVAILABLE:
-    from omegaconf import OmegaConf
-    from omegaconf import DictConfig, ListConfig
+    from omegaconf import DictConfig, ListConfig, OmegaConf


class TPUSpawnPlugin(DDPSpawnPlugin):
@@ -118,6 +118,9 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
        self.__save_end_of_training_weights(self.lightning_module)
        self.transfer_distrib_spawn_state_on_fit_end(results)

+        if self.global_rank == 0:

Contributor: Great!

+            time.sleep(2)

Member: any reason why 2?

Contributor (author): Naa.

        self.barrier("end-process")

    def __save_end_of_training_weights(self, model: LightningModule) -> None:
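For readers unfamiliar with the teardown this hunk touches, the sketch below reproduces the pattern in plain multiprocessing, with no TPU or Lightning required: rank 0 sleeps briefly so the other ranks can finish handing off their results before every process meets at the end-of-process barrier. The function and variable names are illustrative, not taken from the PR.

```python
# Minimal sketch of the rank-0 delay + end-of-process barrier pattern added
# above, using plain multiprocessing instead of torch_xla. Names are made up
# for illustration; only the control flow mirrors the PR.
import time
from multiprocessing import Barrier, Process


def new_process(rank: int, barrier) -> None:
    # ... training, weight saving, and result transfer would happen here ...
    print(f"rank {rank}: finished work")

    if rank == 0:
        # Give the other ranks a head start before the final rendezvous,
        # analogous to the `time.sleep(2)` added to TPUSpawnPlugin.new_process.
        time.sleep(2)

    barrier.wait()  # stands in for self.barrier("end-process")
    print(f"rank {rank}: passed the end-process barrier")


if __name__ == "__main__":
    world_size = 4
    barrier = Barrier(world_size)
    processes = [Process(target=new_process, args=(rank, barrier)) for rank in range(world_size)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
```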
pytorch_lightning/utilities/xla_device.py (3 additions, 16 deletions)
@@ -17,13 +17,10 @@
import traceback
from multiprocessing import Process, Queue

-import torch.multiprocessing as mp

from pytorch_lightning.utilities.imports import _XLA_AVAILABLE

if _XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.xla_multiprocessing as xmp

#: define waiting time for checking TPU available in sec
TPU_CHECK_TIMEOUT = 25
@@ -64,23 +61,13 @@ class XLADeviceUtils:
    @pl_multi_process
    def _is_device_tpu() -> bool:
        """
-        Check if device is TPU
+        Check if TPU devices are available

        Return:
-            A boolean value indicating if the xla device is a TPU device or not
+            A boolean value indicating if TPU devices are available
        """

-        def _fn(_: int, mp_queue):
-            try:
-                device = xm.xla_device()
-                mp_queue.put(device.type == 'xla')
-            except Exception:
-                mp_queue.put(False)
-
-        smp = mp.get_context("spawn")
-        queue = smp.SimpleQueue()
-        xmp.spawn(_fn, args=(queue, ), nprocs=1)
-        return queue.get()
+        return len(xm.get_xla_supported_devices("TPU")) > 0

Contributor: Really nice!
    @staticmethod
    def xla_available() -> bool:
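As a usage note, the simplified check can be exercised outside Lightning roughly as follows. This is a sketch under the assumption that torch_xla is installed; `tpu_devices_available` is a hypothetical helper name, and in Lightning the real call still runs under the `pl_multi_process` timeout (`TPU_CHECK_TIMEOUT`) because querying XLA devices can block when no TPU is reachable.

```python
# Hypothetical standalone version of the new check (not code from the PR).
# Assumes torch_xla is installed whenever a TPU runtime is expected.
def tpu_devices_available() -> bool:
    try:
        import torch_xla.core.xla_model as xm
    except ImportError:
        return False  # torch_xla missing, so no TPU devices can be visible
    # get_xla_supported_devices("TPU") lists the visible TPU devices; an empty
    # (or missing) result is treated as "no TPU". bool() is used here instead
    # of len(...) > 0 purely as a defensive choice for this sketch.
    return bool(xm.get_xla_supported_devices("TPU"))


if __name__ == "__main__":
    print("TPU available:", tpu_devices_available())
```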