Implement backend-agnostic rpc._wait_all_workers() utility (pytorch#32190)

Summary:
Pull Request resolved: pytorch#32190

We need a backend-agnostic mechanism to perform a barrier-like operation before locally destroying the RRef context and shutting down the RPC agent.
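At the call site, the barrier simply runs right before shutting down; a minimal usage sketch mirroring the tests in this diff (the worker name, rank, and world size are placeholders):

```python
import torch.distributed.rpc as rpc

rpc.init_rpc(name="worker0", rank=0, world_size=2)

# ... issue RPCs, create RRefs, do work ...

# Barrier: block until every worker in the group has reached this point.
rpc.api._wait_all_workers()

# Each worker can now tear down its RPC agent locally, without a graceful join.
rpc.shutdown(graceful=False)
```

The barrier itself uses a leader-based protocol: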

- Sort the worker names.
- Elect the first name in the sorted order as the leader.
- Followers report their intent to synchronize to the leader.
- The leader also reports to itself when `_wait_all_workers()` is called.
- Once all workers have reported their intent to proceed, the leader sends the command to everyone to proceed (sketched below).
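The sketch below illustrates this protocol only; it is not the actual `torch.distributed.rpc.api` implementation, and the module-level sets/events and the `_report_done`/`_release` helpers are hypothetical names:

```python
import threading

import torch.distributed.rpc as rpc

_worker_names = set()               # filled at init time with every worker's name
_reported = set()                   # leader side: workers that have reported so far
_all_reported = threading.Event()   # leader side: set once every worker has reported
_proceed = threading.Event()        # every worker: set when the leader says "go"


def _report_done(worker_name):
    # Runs on the leader: record the reporting worker and, once every worker
    # (including the leader itself) has reported, wake the leader up.
    _reported.add(worker_name)
    if _reported == _worker_names:
        _all_reported.set()


def _release():
    # Runs on each follower: the leader calls this to let the follower proceed.
    _proceed.set()


def _wait_all_workers():
    self_name = rpc.get_worker_info().name
    leader_name = sorted(_worker_names)[0]   # first name in sorted order leads

    if self_name == leader_name:
        _report_done(self_name)              # the leader reports to itself locally
        _all_reported.wait()                 # wait until every worker has reported
        # Tell every follower to proceed; RPC to self is not supported, so skip it.
        futs = [
            rpc.rpc_async(name, _release)
            for name in sorted(_worker_names)
            if name != leader_name
        ]
        for fut in futs:
            fut.wait()
    else:
        rpc.rpc_async(leader_name, _report_done, args=(self_name,))
        _proceed.wait()                      # block until the leader releases us
```

Electing the lexicographically smallest worker name keeps the leader choice deterministic on every worker without any extra communication.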
ghstack-source-id: 96693296

Test Plan:
# Unit tests

```
buck test mode/dev-nosan //caffe2/test:rpc_fork

buck-out/gen/caffe2/test/rpc_fork\#binary.par -r test_wait_all_workers
buck-out/gen/caffe2/test/rpc_fork\#binary.par -r test_rref_leak
```

```
buck test mode/dev-nosan //caffe2/test:rpc_spawn

buck-out/gen/caffe2/test/rpc_spawn\#binary.par -r test_wait_all_workers
buck-out/gen/caffe2/test/rpc_spawn\#binary.par -r test_rref_leak
```

```
buck test mode/dev-nosan //caffe2/test:rpc_fork_thrift

buck-out/gen/caffe2/test/rpc_fork_thrift\#binary.par -r test_wait_all_workers
buck-out/gen/caffe2/test/rpc_fork_thrift\#binary.par -r test_worker_id
```

# Stress runs
```
buck test mode/dev-nosan //caffe2/test:rpc_fork_thrift -- test_stress_light_rpc --stress-runs 10
```

```
buck test mode/dev-nosan //caffe2/test:rpc_spawn_thrift -- test_stress_light_rpc --stress-runs 10
```

```
buck test mode/dev-nosan //caffe2/test:rpc_fork_thrift -- test_stress_heavy_rpc --stress-runs 10
```

```
buck test mode/dev-nosan //caffe2/test:rpc_spawn_thrift -- test_stress_heavy_rpc --stress-runs 10
```

Differential Revision: D19399908

fbshipit-source-id: 1dee607cd49adafe88534621a1c85e2736e2f595
xush6528 authored and ttumiel committed Mar 4, 2020
1 parent 18f99d2 commit 6cc6c9e
Showing 5 changed files with 171 additions and 78 deletions.
2 changes: 1 addition & 1 deletion test/dist_autograd_test.py
@@ -1372,7 +1372,7 @@ def test_clean_context_during_backward(self):
# receive gradients from the node that received an error (and as a
# result it didn't execute the rest of the graph).
dist.barrier()
rpc.shutdown()
rpc.shutdown(graceful=False)
sys.exit(0)

@classmethod
56 changes: 0 additions & 56 deletions test/dist_utils.py
@@ -1,6 +1,5 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import threading
import time
from functools import partial, wraps

@@ -26,30 +25,6 @@ def __init__(self, *args, **kwargs):
INIT_METHOD_TEMPLATE = "file://{file_name}"


MASTER_RANK = 0
_ALL_NODE_NAMES = set()
_DONE_NODE_NAMES = set()
_TERMINATION_SIGNAL = threading.Event()


def on_master_follower_report_done(worker_name):
assert (
worker_name in _ALL_NODE_NAMES
), "{worker_name} is not expected by master.".format(worker_name=worker_name)
assert (
worker_name not in _DONE_NODE_NAMES
), "{worker_name} report done twice.".format(worker_name=worker_name)
_DONE_NODE_NAMES.add(worker_name)
if _ALL_NODE_NAMES != _DONE_NODE_NAMES:
return
set_termination_signal()


def set_termination_signal():
assert not _TERMINATION_SIGNAL.is_set(), "Termination signal got set twice."
_TERMINATION_SIGNAL.set()


def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True):
"""
We use this decorator for setting up and tearing down state since
@@ -97,37 +72,6 @@ def new_test_method(self, *arg, **kwargs):
return_value = old_test_method(self, *arg, **kwargs)

if setup_rpc:
if clean_shutdown:
# Follower reports done.
if self.rank == MASTER_RANK:
on_master_follower_report_done("worker{}".format(MASTER_RANK))
else:
rpc.rpc_async(
"worker{}".format(MASTER_RANK),
on_master_follower_report_done,
args=("worker{}".format(self.rank),),
)

# Master waits for followers to report done.
# Follower waits for master's termination command.
_TERMINATION_SIGNAL.wait()
if self.rank == MASTER_RANK:
# Master sends termination command.
futs = []
for dst_rank in range(self.world_size):
# torch.distributed.rpc module does not support sending to self.
if dst_rank == MASTER_RANK:
continue
dst_name = "worker{}".format(dst_rank)
fut = rpc.rpc_async(dst_name, set_termination_signal, args=())
futs.append(fut)
for fut in futs:
assert fut.wait() is None, "Sending termination signal failed."

# Close RPC. Need to do this even if we don't have a clean shutdown
# since we need to shutdown the RPC agent. If we don't shutdown the
# RPC agent, tests would fail since RPC agent threads, locks and
# condition variables are not properly terminated.
rpc.shutdown(graceful=clean_shutdown)

return return_value
78 changes: 65 additions & 13 deletions test/rpc_test.py
@@ -33,6 +33,19 @@ def decorator(old_func):
DONE_FUTURE = concurrent.futures.Future()


class StubRpcAgent:
def __init__(self, world_size):
self.world_size = world_size

def get_worker_infos(self):
return {
rpc.WorkerInfo(
name="worker{}".format(rank),
id=rank,
) for rank in range(self.world_size)
}


def _stub_construct_rpc_backend_options_handler(
**kwargs
):
@@ -42,7 +55,7 @@ def _stub_construct_rpc_backend_options_handler(
def _stub_start_rpc_backend_handler(
store, name, rank, world_size, rpc_backend_options
):
return mock.Mock() # RpcAgent.
return StubRpcAgent(world_size=world_size)


def set_value(value):
@@ -361,7 +374,6 @@ def test_duplicate_name(self):
world_size=self.world_size,
rpc_backend_options=self.rpc_backend_options,
)
rpc.shutdown()

@dist_init(setup_rpc=False)
def test_reinit(self):
@@ -496,8 +508,51 @@ def test_shutdown(self):
args=(torch.ones(n, n), torch.ones(n, n)),
)

# it's safe to call shutdown() multiple times
rpc.shutdown()
def test_wait_all_workers(self):
rpc.init_rpc(
name="worker%d" % self.rank,
backend=self.rpc_backend,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=self.rpc_backend_options,
)

# worker0 drives and waits for worker1 and worker2
# throughout the test.
if self.rank == 0:
self.assertTrue(self.world_size >= 3)

num_repeat = 30

# Phase 1: Only worker1 has workload.
dst = "worker1"
futs = []
for _ in range(num_repeat):
fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
futs.append(fut)

for fut in futs:
fut.wait()
self.assertEqual(fut.wait(), 0)

# Phase 2: Only worker2 has workload.
# If join is not correctly implemented,
# worker2 should be closed by now.
dst = "worker2"
futs = []
for _ in range(num_repeat):
fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
futs.append(fut)

for fut in futs:
fut.wait()
self.assertEqual(fut.wait(), 0)

# worker0 calls this at the end after waiting for RPC responses.
# worker1/2 call this immediately and have some work after it.
# worker3 calls this immediately and has no more work.
rpc.api._wait_all_workers()
rpc.shutdown(graceful=False)

@dist_init
def test_expected_src(self):
@@ -768,10 +823,10 @@ def test_asymmetric_load_with_join(self):
assert self.world_size >= 3

num_repeat = 100
futs = []

# Phase 1: Only worker1 has workload.
dst = "worker1"
futs = []
for _ in range(num_repeat):
fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
futs.append(fut)
@@ -784,6 +839,7 @@ def test_asymmetric_load_with_join(self):
# If join is not correctly implemented,
# worker2 should be closed by now.
dst = "worker2"
futs = []
for _ in range(num_repeat):
fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
futs.append(fut)
@@ -1324,9 +1380,7 @@ def test_local_shutdown(self):
# without sending any messages.
rpc.init_rpc(
name="worker%d" % self.rank,
backend=rpc.backend_registry.BackendType[
dist_utils.TEST_CONFIG.rpc_backend_name
],
backend=self.rpc_backend,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=self.rpc_backend_options,
@@ -1396,9 +1450,7 @@ def test_local_shutdown_with_rpc(self):
# test that we can start RPC, send RPCs, and then run local shutdown.
rpc.init_rpc(
name="worker%d" % self.rank,
backend=rpc.backend_registry.BackendType[
dist_utils.TEST_CONFIG.rpc_backend_name
],
backend=self.rpc_backend,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=self.rpc_backend_options,
@@ -1426,15 +1478,15 @@ def test_wait_all_workers_and_shutdown(self):
# multiple times.
rpc.init_rpc(
name="worker%d" % self.rank,
backend=rpc.backend_registry.BackendType[dist_utils.TEST_CONFIG.rpc_backend_name],
backend=self.rpc_backend,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=self.rpc_backend_options
)
from torch.distributed.rpc.api import _wait_all_workers
# intentional call to internal _wait_all_workers.
_wait_all_workers()
rpc.shutdown()
rpc.shutdown(graceful=False)

@dist_init(setup_rpc=False)
def test_get_rpc_timeout(self):
1 change: 0 additions & 1 deletion torch/csrc/distributed/rpc/rpc_agent.h
@@ -146,7 +146,6 @@ class TORCH_API RpcAgent {

protected:
const WorkerInfo workerInfo_;
const std::string workerName_;
const std::unique_ptr<RequestCallback> cb_;
std::atomic<std::chrono::milliseconds> rpcTimeout_;
