diff --git a/test/dist_autograd_test.py b/test/dist_autograd_test.py
index 0b0c1c4f2124b3..3cce604127586c 100644
--- a/test/dist_autograd_test.py
+++ b/test/dist_autograd_test.py
@@ -1372,7 +1372,7 @@ def test_clean_context_during_backward(self):
         # receive gradients from the node that received an error (and as a
         # result it didn't execute the rest of the graph).
         dist.barrier()
-        rpc.shutdown()
+        rpc.shutdown(graceful=False)
         sys.exit(0)

     @classmethod
diff --git a/test/dist_utils.py b/test/dist_utils.py
index 243a6ba1f722f5..a7efe4395be632 100644
--- a/test/dist_utils.py
+++ b/test/dist_utils.py
@@ -1,6 +1,5 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

-import threading
 import time
 from functools import partial, wraps

@@ -26,30 +25,6 @@ def __init__(self, *args, **kwargs):

 INIT_METHOD_TEMPLATE = "file://{file_name}"

-MASTER_RANK = 0
-_ALL_NODE_NAMES = set()
-_DONE_NODE_NAMES = set()
-_TERMINATION_SIGNAL = threading.Event()
-
-
-def on_master_follower_report_done(worker_name):
-    assert (
-        worker_name in _ALL_NODE_NAMES
-    ), "{worker_name} is not expected by master.".format(worker_name=worker_name)
-    assert (
-        worker_name not in _DONE_NODE_NAMES
-    ), "{worker_name} report done twice.".format(worker_name=worker_name)
-    _DONE_NODE_NAMES.add(worker_name)
-    if _ALL_NODE_NAMES != _DONE_NODE_NAMES:
-        return
-    set_termination_signal()
-
-
-def set_termination_signal():
-    assert not _TERMINATION_SIGNAL.is_set(), "Termination signal got set twice."
-    _TERMINATION_SIGNAL.set()
-
-
 def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True):
     """
     We use this decorator for setting up and tearing down state since
@@ -97,37 +72,6 @@ def new_test_method(self, *arg, **kwargs):
         return_value = old_test_method(self, *arg, **kwargs)

         if setup_rpc:
-            if clean_shutdown:
-                # Follower reports done.
-                if self.rank == MASTER_RANK:
-                    on_master_follower_report_done("worker{}".format(MASTER_RANK))
-                else:
-                    rpc.rpc_async(
-                        "worker{}".format(MASTER_RANK),
-                        on_master_follower_report_done,
-                        args=("worker{}".format(self.rank),),
-                    )
-
-                # Master waits for followers to report done.
-                # Follower waits for master's termination command.
-                _TERMINATION_SIGNAL.wait()
-                if self.rank == MASTER_RANK:
-                    # Master sends termination command.
-                    futs = []
-                    for dst_rank in range(self.world_size):
-                        # torch.distributed.rpc module does not support sending to self.
-                        if dst_rank == MASTER_RANK:
-                            continue
-                        dst_name = "worker{}".format(dst_rank)
-                        fut = rpc.rpc_async(dst_name, set_termination_signal, args=())
-                        futs.append(fut)
-                    for fut in futs:
-                        assert fut.wait() is None, "Sending termination signal failed."
-
-            # Close RPC. Need to do this even if we don't have a clean shutdown
-            # since we need to shutdown the RPC agent. If we don't shutdown the
-            # RPC agent, tests would fail since RPC agent threads, locks and
-            # condition variables are not properly terminated.
             rpc.shutdown(graceful=clean_shutdown)

         return return_value
diff --git a/test/rpc_test.py b/test/rpc_test.py
index 906d3d36b2d5f7..8b1a25339b27cb 100644
--- a/test/rpc_test.py
+++ b/test/rpc_test.py
@@ -33,6 +33,19 @@ def decorator(old_func):

 DONE_FUTURE = concurrent.futures.Future()

+class StubRpcAgent:
+    def __init__(self, world_size):
+        self.world_size = world_size
+
+    def get_worker_infos(self):
+        return {
+            rpc.WorkerInfo(
+                name="worker{}".format(rank),
+                id=rank,
+            ) for rank in range(self.world_size)
+        }
+
+
 def _stub_construct_rpc_backend_options_handler(
     **kwargs
 ):
@@ -42,7 +55,7 @@ def _stub_construct_rpc_backend_options_handler(
 def _stub_start_rpc_backend_handler(
     store, name, rank, world_size, rpc_backend_options
 ):
-    return mock.Mock()  # RpcAgent.
+    return StubRpcAgent(world_size=world_size)


 def set_value(value):
@@ -361,7 +374,6 @@ def test_duplicate_name(self):
                 world_size=self.world_size,
                 rpc_backend_options=self.rpc_backend_options,
             )
-        rpc.shutdown()

     @dist_init(setup_rpc=False)
     def test_reinit(self):
@@ -496,8 +508,51 @@ def test_shutdown(self):
                 args=(torch.ones(n, n), torch.ones(n, n)),
             )

-        # it's safe to call shutdown() multiple times
-        rpc.shutdown()
+    def test_wait_all_workers(self):
+        rpc.init_rpc(
+            name="worker%d" % self.rank,
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+
+        # worker0 drives and waits for worker1 and worker2
+        # throughout the test.
+        if self.rank == 0:
+            self.assertTrue(self.world_size >= 3)
+
+            num_repeat = 30
+
+            # Phase 1: Only worker1 has workload.
+            dst = "worker1"
+            futs = []
+            for _ in range(num_repeat):
+                fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
+                futs.append(fut)
+
+            for fut in futs:
+                fut.wait()
+                self.assertEqual(fut.wait(), 0)
+
+            # Phase 2: Only worker2 has workload.
+            # If join is not correctly implemented,
+            # worker2 should be closed by now.
+            dst = "worker2"
+            futs = []
+            for _ in range(num_repeat):
+                fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
+                futs.append(fut)
+
+            for fut in futs:
+                fut.wait()
+                self.assertEqual(fut.wait(), 0)
+
+        # worker0 calls this at the end after waiting for RPC responses.
+        # worker1/2 calls this immediately and has some work after it.
+        # worker3 calls this immediately and has no more work.
+        rpc.api._wait_all_workers()
+        rpc.shutdown(graceful=False)

     @dist_init
     def test_expected_src(self):
@@ -768,10 +823,10 @@ def test_asymmetric_load_with_join(self):
         assert self.world_size >= 3

         num_repeat = 100
-        futs = []

         # Phase 1: Only worker1 has workload.
         dst = "worker1"
+        futs = []
         for _ in range(num_repeat):
             fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
             futs.append(fut)
@@ -784,6 +839,7 @@ def test_asymmetric_load_with_join(self):
         # If join is not correctly implemented,
         # worker2 should be closed by now.
         dst = "worker2"
+        futs = []
         for _ in range(num_repeat):
             fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
             futs.append(fut)
@@ -1324,9 +1380,7 @@ def test_local_shutdown(self):
         # without sending any messages.
         rpc.init_rpc(
             name="worker%d" % self.rank,
-            backend=rpc.backend_registry.BackendType[
-                dist_utils.TEST_CONFIG.rpc_backend_name
-            ],
+            backend=self.rpc_backend,
             rank=self.rank,
             world_size=self.world_size,
             rpc_backend_options=self.rpc_backend_options,
@@ -1396,9 +1450,7 @@ def test_local_shutdown_with_rpc(self):
         # test that we can start RPC, send RPCs, and then run local shutdown.
         rpc.init_rpc(
             name="worker%d" % self.rank,
-            backend=rpc.backend_registry.BackendType[
-                dist_utils.TEST_CONFIG.rpc_backend_name
-            ],
+            backend=self.rpc_backend,
             rank=self.rank,
             world_size=self.world_size,
             rpc_backend_options=self.rpc_backend_options,
@@ -1426,7 +1478,7 @@ def test_wait_all_workers_and_shutdown(self):
         # multiple times.
         rpc.init_rpc(
             name="worker%d" % self.rank,
-            backend=rpc.backend_registry.BackendType[dist_utils.TEST_CONFIG.rpc_backend_name],
+            backend=self.rpc_backend,
             rank=self.rank,
             world_size=self.world_size,
             rpc_backend_options=self.rpc_backend_options
@@ -1434,7 +1486,7 @@ def test_wait_all_workers_and_shutdown(self):
         )
         from torch.distributed.rpc.api import _wait_all_workers
         # intentional call to internal _wait_all_workers.
         _wait_all_workers()
-        rpc.shutdown()
+        rpc.shutdown(graceful=False)

     @dist_init(setup_rpc=False)
     def test_get_rpc_timeout(self):
diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h
index e276783607c4d8..a6b87748b26d10 100644
--- a/torch/csrc/distributed/rpc/rpc_agent.h
+++ b/torch/csrc/distributed/rpc/rpc_agent.h
@@ -146,7 +146,6 @@ class TORCH_API RpcAgent {
 protected:
   const WorkerInfo workerInfo_;
-  const std::string workerName_;
   const std::unique_ptr<RequestCallback> cb_;
   std::atomic<std::chrono::milliseconds> rpcTimeout_;
diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py
index 7d923c7ee71b39..382180b4f7af0a 100644
--- a/torch/distributed/rpc/api.py
+++ b/torch/distributed/rpc/api.py
@@ -7,6 +7,7 @@
     _invoke_remote_python_udf,
     _invoke_rpc_builtin,
     _invoke_rpc_python_udf,
+    _set_rpc_timeout,
     _start_rpc_agent,
     backend_registry,
 )
@@ -18,12 +19,17 @@
 )

 import contextlib
+from datetime import timedelta
 import functools
 import numbers
 import sys
+import logging
+import threading

 import torch
 import torch.distributed as dist

+logging.basicConfig()
+logger = logging.getLogger(__name__)
 _agent = None
 # NB: Ignoring RRef leaks during shutdown. Without this, applications have to
@@ -63,6 +69,42 @@ def wrapper(*args, **kwargs):
     return wrapper


+# States used by `def _wait_all_workers()`.
+# `_ALL_WORKER_NAMES` is initialized on initializing the RPC layer.
+_ALL_WORKER_NAMES = None
+# `_SHUTDOWN_INTENT_WORKER_NAMES` is an empty set at the beginning.
+# It is only used by the leader worker. The leader worker is elected as the
+# first worker in a sorted worker name list.
+# Whenever a worker reports its shutdown intent to the leader, by calling
+# `_wait_all_workers()`, the leader adds this worker's name to the set.
+# The leader also adds its own name to the set on calling
+# `_wait_all_workers()`. We need this because we confine `_wait_all_workers()`
+# to be called only once, by examining whether the leader's name is in the set.
+_SHUTDOWN_INTENT_WORKER_NAMES = set()
+# Once `_SHUTDOWN_INTENT_WORKER_NAMES == _ALL_WORKER_NAMES`, we flip
+# `_SHUTDOWN_PROCEED_SIGNAL` on the leader, and the leader sends RPCs to
+# follower workers to flip their `_SHUTDOWN_PROCEED_SIGNAL`s.
+_SHUTDOWN_PROCEED_SIGNAL = threading.Event()
+
+
+def _on_leader_follower_report_shutdown_intent(worker_name):
+    assert (
+        worker_name in _ALL_WORKER_NAMES
+    ), "{worker_name} is not expected by leader.".format(worker_name=worker_name)
+    assert (
+        worker_name not in _SHUTDOWN_INTENT_WORKER_NAMES
+    ), "{worker_name} reported intent twice.".format(worker_name=worker_name)
+    _SHUTDOWN_INTENT_WORKER_NAMES.add(worker_name)
+    if _ALL_WORKER_NAMES == _SHUTDOWN_INTENT_WORKER_NAMES:
+        _set_proceed_shutdown_signal()
+
+
+def _set_proceed_shutdown_signal():
+    assert not _SHUTDOWN_PROCEED_SIGNAL.is_set(), "Termination signal got set twice."
+    _SHUTDOWN_PROCEED_SIGNAL.set()
+
+
+@_require_initialized
 def _wait_all_workers():
     r"""
     Block until all local and remote RPC processes reach this method and wait
@@ -71,11 +113,55 @@ def _wait_all_workers():
     terminate the RPC framework, and there is no guarantee that the RPC
     framework will work after this method returns.
     """
-    global _agent
+    assert (
+        _ALL_WORKER_NAMES is not None
+    ), "`_ALL_WORKER_NAMES` is not initialized for `def _wait_all_workers`."
+    leader_worker_name = sorted(_ALL_WORKER_NAMES)[0]
+
+    self_worker_name = _agent.get_worker_info().name
+    assert (
+        self_worker_name not in _SHUTDOWN_INTENT_WORKER_NAMES
+    ), "Can not call `_wait_all_workers()` twice."
+
+    is_leader_worker = leader_worker_name == self_worker_name
+
+    # Phase 1: Followers send intents.
+    # All followers report intents to the leader.
+    if is_leader_worker:
+        _on_leader_follower_report_shutdown_intent(self_worker_name)
+    else:
+        rpc_sync(
+            leader_worker_name,
+            _on_leader_follower_report_shutdown_intent,
+            args=(self_worker_name,),
+        )
+
+    _SHUTDOWN_PROCEED_SIGNAL.wait()
+
+    # Phase 2: Leader asks followers to proceed.
+    # Leader's signal is the first to be unblocked,
+    # after receiving all followers' intents.
+    if is_leader_worker:
+        # The leader sends out proceed signals to all followers.
+        timeout = timedelta(seconds=5)
+        _set_rpc_timeout(timeout)
+        worker_name_to_response_future_dict = dict()
+        for follower_worker_name in _ALL_WORKER_NAMES - {leader_worker_name}:
+            fut = rpc_async(follower_worker_name, _set_proceed_shutdown_signal, args=())
+            worker_name_to_response_future_dict[follower_worker_name] = fut
+        for follower_worker_name, fut in worker_name_to_response_future_dict.items():
+            try:
+                fut.wait()
+            except RuntimeError as ex:
+                logger.error(
+                    "{worker_name} failed to respond to 'Shutdown Proceed.' request in {timeout}".format(
+                        worker_name=follower_worker_name,
+                        timeout=timeout,
+                    )
+                )

-    if _agent:
-        _agent.join()

+@_require_initialized
 def shutdown(graceful=True):
     r"""
     Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
@@ -118,18 +204,25 @@ def shutdown(graceful=True):
         >>> rpc.shutdown()
     """
     global _agent
-    if _agent:
-        if graceful:
-            _wait_all_workers()
+
+    if graceful:
+        _wait_all_workers()
+    _agent.join()
+    try:
+        # This raises a `TORCH_CHECK()` exception if an RRef leak is detected.
         _destroy_rref_context(_ignore_rref_leak)
+    finally:
         _agent.shutdown()
         # clean up python rpc handler in shutdown(), see comments in
         # PythonRpcHandler::cleanup(), call it in python API because the
         # cleanup() function has python dependency, it assumes python
-        # interpreter exists
+        # interpreter exists.
+        # Whether or not an RRef leak exception is raised, this clean-up
+        # code must run to avoid a destruction segfault in Python 3.5.
         _cleanup_python_rpc_handler()
         _agent = None
+

 # TODO: add a context manager to wrap _init_rpc_backend and shutdown
 def _init_rpc_backend(
     backend=backend_registry.BackendType.PROCESS_GROUP,
@@ -159,6 +252,11 @@
         world_size=world_size,
         rpc_backend_options=rpc_backend_options,
     )
+
+    worker_infos = _agent.get_worker_infos()
+    global _ALL_WORKER_NAMES
+    _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}
+
     _start_rpc_agent(_agent)