Skip to content

Commit

Permalink
fix bug for eager mode distributed training (#41841)
Browse files Browse the repository at this point in the history
  • Loading branch information
lilong12 authored Apr 18, 2022
1 parent f3531c7 commit 34f30f7
Show file tree
Hide file tree
Showing 13 changed files with 74 additions and 56 deletions.
5 changes: 3 additions & 2 deletions paddle/fluid/distributed/collective/ProcessGroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ bool ProcessGroup::Task::Wait(std::chrono::milliseconds timeout) {

void ProcessGroup::Task::Synchronize() {}

ProcessGroup::ProcessGroup(int rank, int size, int gid)
: rank_(rank), size_(size), gid_(gid) {
ProcessGroup::ProcessGroup(int rank, int size, const platform::Place& place,
int gid)
: rank_(rank), size_(size), place_(place), gid_(gid) {
if (gid != IGNORE_ID) {
auto map = ProcessGroupMapFromGid::getInstance();
map->insert(gid_, this);
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/collective/ProcessGroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class ProcessGroup {
bool is_completed_ = false;
};

explicit ProcessGroup(int rank, int size, int gid);
explicit ProcessGroup(int rank, int size, const platform::Place& place,
int gid);
virtual ~ProcessGroup() {}

int GetRank() const { return rank_; }
Expand Down Expand Up @@ -145,6 +146,7 @@ class ProcessGroup {
protected:
const int rank_;
const int size_;
const platform::Place place_;
const int gid_;
};

Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/distributed/collective/ProcessGroupGloo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,9 @@ ProcessGroupGloo::GlooTask::GlooTask(

ProcessGroupGloo::ProcessGroupGloo(
const std::shared_ptr<distributed::Store>& store, int rank, int world_size,
int gid, const std::shared_ptr<GlooOptions> options)
: ProcessGroup(rank, world_size, gid),
const platform::Place& place, int gid,
const std::shared_ptr<GlooOptions> options)
: ProcessGroup(rank, world_size, place, gid),
_tag(0),
_store(new GlooStore(store)) {
_context = std::make_shared<gloo::rendezvous::Context>(rank, world_size);
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/collective/ProcessGroupGloo.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ class ProcessGroupGloo : public ProcessGroup {

explicit ProcessGroupGloo(
const std::shared_ptr<paddle::distributed::Store>& store, int rank,
int world_size, int gid, std::shared_ptr<GlooOptions> options);
int world_size, const platform::Place& place, int gid,
std::shared_ptr<GlooOptions> options);

~ProcessGroupGloo() = default;

Expand Down
8 changes: 6 additions & 2 deletions paddle/fluid/distributed/collective/ProcessGroupHCCL.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/fluid/distributed/collective/HCCLTools.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device/npu/hccl_helper.h"
#include "paddle/fluid/platform/device/npu/npu_info.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/include/api.h"
Expand Down Expand Up @@ -97,8 +98,11 @@ bool ProcessGroupHCCL::HCCLTask::Wait(std::chrono::milliseconds timeout) {
void ProcessGroupHCCL::HCCLTask::Synchronize() { Wait(kWaitTimeout); }

ProcessGroupHCCL::ProcessGroupHCCL(const std::shared_ptr<Store>& store,
int rank, int size, int gid)
: ProcessGroup(rank, size, gid), store_(store) {}
int rank, int size,
const platform::Place& place, int gid)
: ProcessGroup(rank, size, place, gid), store_(store) {
platform::SetNPUDeviceId(place_.device);
}

void ProcessGroupHCCL::BroadcastUniqueHCCLID(
std::vector<HcclRootInfo>& hccl_ids) { // NOLINT
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/collective/ProcessGroupHCCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class ProcessGroupHCCL : public ProcessGroup {
};

ProcessGroupHCCL(const std::shared_ptr<Store>& store, int rank, int size,
int gid);
const platform::Place& place, int gid);

const std::string GetBackendName() const override {
return std::string(HCCL_BACKEND_NAME);
Expand Down
20 changes: 9 additions & 11 deletions paddle/fluid/distributed/collective/ProcessGroupHeter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,11 @@ bool ProcessGroupHeter::HeterTask::Wait(std::chrono::milliseconds timeout) {
return true;
}

ProcessGroupHeter::ProcessGroupHeter(const std::shared_ptr<Store>& store,
int rank, int size, int gid,
int local_rank, int local_size,
int gloo_rank, int gloo_size,
bool with_switch,
std::string switch_endpoint)
: ProcessGroup(rank, size, gid),
ProcessGroupHeter::ProcessGroupHeter(
const std::shared_ptr<Store>& store, int rank, int size,
const platform::Place& place, int gid, int local_rank, int local_size,
int gloo_rank, int gloo_size, bool with_switch, std::string switch_endpoint)
: ProcessGroup(rank, size, place, gid),
store_(store),
local_rank_(local_rank),
local_size_(local_size),
Expand All @@ -60,19 +58,19 @@ ProcessGroupHeter::ProcessGroupHeter(const std::shared_ptr<Store>& store,
switch_endpoint_(switch_endpoint) {
#if defined(PADDLE_WITH_NCCL)
inner_pg_ = std::make_shared<ProcessGroupNCCL>(store, local_rank, local_size,
IGNORE_ID);
place_, IGNORE_ID);
#elif defined(PADDLE_WITH_ASCEND_CL)
inner_pg_ = std::make_shared<ProcessGroupHCCL>(store, local_rank, local_size,
IGNORE_ID);
place_, IGNORE_ID);
#else
// Reached only when neither PADDLE_WITH_NCCL nor PADDLE_WITH_ASCEND_CL is
// defined, i.e. no inner process group backend was compiled in.
// Fix: the original statement was missing the closing parenthesis of
// PADDLE_THROW(...), which made this preprocessor branch ill-formed.
PADDLE_THROW(platform::errors::Fatal(
    "ProcessGroupHeter only supports NCCL and HCCL now."));
#endif
if (local_rank_ == 0 && !with_switch_) {
auto opts = ProcessGroupGloo::GlooOptions::create();
opts->device = ProcessGroupGloo::createDefaultDevice();
inter_pg_ = std::make_shared<ProcessGroupGloo>(store, gloo_rank_,
gloo_size_, IGNORE_ID, opts);
inter_pg_ = std::make_shared<ProcessGroupGloo>(
store, gloo_rank_, gloo_size_, place_, IGNORE_ID, opts);
}
}

Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/collective/ProcessGroupHeter.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ class ProcessGroupHeter : public ProcessGroup {
};

ProcessGroupHeter(const std::shared_ptr<Store>& store, int rank, int size,
int gid, int local_rank, int local_size, int gloo_rank,
int gloo_size, bool with_switch,
std::string switch_endpoints);
const platform::Place& place, int gid, int local_rank,
int local_size, int gloo_rank, int gloo_size,
bool with_switch, std::string switch_endpoints);

const std::string GetBackendName() const override {
return std::string(HETER_BACKEND_NAME);
Expand Down
8 changes: 6 additions & 2 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/include/api.h"
Expand Down Expand Up @@ -103,8 +104,11 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }

ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank, int size, int gid)
: ProcessGroup(rank, size, gid), store_(store) {}
int rank, int size,
const platform::Place& place, int gid)
: ProcessGroup(rank, size, place, gid), store_(store) {
platform::SetDeviceId(place_.device);
}

void ProcessGroupNCCL::BroadcastUniqueNCCLID(
std::vector<ncclUniqueId>& nccl_ids) { // NOLINT
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/collective/ProcessGroupNCCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class ProcessGroupNCCL : public ProcessGroup {
};

ProcessGroupNCCL(const std::shared_ptr<Store>& store, int rank, int size,
int gid);
const platform::Place& place, int gid);

const std::string GetBackendName() const override {
return std::string(NCCL_BACKEND_NAME);
Expand Down
48 changes: 22 additions & 26 deletions paddle/fluid/pybind/distributed_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,49 +241,42 @@ void BindDistributed(py::module *m) {
std::shared_ptr<distributed::ProcessGroupNCCL>>(
*m, "ProcessGroupNCCL", ProcessGroup)
.def(py::init<const std::shared_ptr<distributed::Store> &, int, int,
int>(),
const platform::CUDAPlace &, int>(),
py::arg("store"), py::arg("rank"), py::arg("world_size"),
py::arg("group_id") = 0, py::call_guard<py::gil_scoped_release>());
py::arg("place"), py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>());
#endif

#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_PSCORE) && \
(defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_ASCEND_CL))
py::class_<distributed::ProcessGroupHeter,
std::shared_ptr<distributed::ProcessGroupHeter>>(
*m, "ProcessGroupHeter", ProcessGroup)
.def(py::init<const std::shared_ptr<distributed::Store> &, int, int, int,
int, int, int, int, bool, std::string>(),
.def(py::init<const std::shared_ptr<distributed::Store> &, int, int,
#if defined(PADDLE_WITH_ASCEND_CL)
const platform::NPUPlace &,
#else
const platform::CUDAPlace &,
#endif
int, int, int, int, int, bool, std::string>(),
py::arg("store"), py::arg("rank"), py::arg("world_size"),
py::arg("gid") = 0, py::arg("local_rank") = 0,
py::arg("place"), py::arg("gid") = 0, py::arg("local_rank") = 0,
py::arg("local_size") = 1, py::arg("gloo_rank") = 0,
py::arg("gloo_size") = 1, py::arg("with_switch") = false,
py::arg("switch_endpoint") = "",
py::call_guard<py::gil_scoped_release>());
#endif
#endif

#if defined(PADDLE_WITH_ASCEND_CL)
py::class_<distributed::ProcessGroupHCCL,
std::shared_ptr<distributed::ProcessGroupHCCL>>(
*m, "ProcessGroupHCCL", ProcessGroup)
.def(py::init<const std::shared_ptr<distributed::Store> &, int, int,
int>(),
py::arg("store"), py::arg("rank"), py::arg("world_size"),
py::arg("group_id") = 0, py::call_guard<py::gil_scoped_release>());

#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_PSCORE) && \
(defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_ASCEND_CL))
py::class_<distributed::ProcessGroupHeter,
std::shared_ptr<distributed::ProcessGroupHeter>>(
*m, "ProcessGroupHeter", ProcessGroup)
.def(py::init<const std::shared_ptr<distributed::Store> &, int, int, int,
int, int, int, int, bool, std::string>(),
const platform::NPUPlace &, int>(),
py::arg("store"), py::arg("rank"), py::arg("world_size"),
py::arg("gid") = 0, py::arg("local_rank") = 0,
py::arg("local_size") = 1, py::arg("gloo_rank") = 0,
py::arg("gloo_rank") = 1, py::arg("with_switch") = false,
py::arg("switch_endpoint") = "",
py::arg("place"), py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>());
#endif

#endif

py::class_<distributed::ProcessGroup::Task,
Expand All @@ -299,10 +292,12 @@ void BindDistributed(py::module *m) {
py::class_<ProcessGroupGloo, std::shared_ptr<ProcessGroupGloo>>(
*m, "ProcessGroupGloo", ProcessGroup)
.def(py::init<const std::shared_ptr<paddle::distributed::Store> &, int,
int, int, std::shared_ptr<GlooOptions> &>(),
int, const platform::CPUPlace &, int,
std::shared_ptr<GlooOptions> &>(),
py::call_guard<py::gil_scoped_release>())
.def(py::init([](const std::shared_ptr<paddle::distributed::Store> &store,
int rank, int world_size, int gid) {
int rank, int world_size,
const platform::CPUPlace &place, int gid) {
auto opts = GlooOptions::create();
char *ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
if (ifname && strlen(ifname) > 1) {
Expand All @@ -312,10 +307,11 @@ void BindDistributed(py::module *m) {
opts->device = ProcessGroupGloo::createDefaultDevice();
}
return std::make_shared<ProcessGroupGloo>(store, rank, world_size,
gid, opts);
place, gid, opts);
}),
py::arg("store"), py::arg("rank"), py::arg("world_size"),
py::arg("group_id") = 0, py::call_guard<py::gil_scoped_release>())
py::arg("place"), py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>())
.def_static("create_default_device",
&ProcessGroupGloo::createDefaultDevice);
#endif
Expand Down
16 changes: 13 additions & 3 deletions python/paddle/distributed/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,23 @@ def _new_process_group_impl(backend,
pg_options,
group_id=0):
pg = None
genv = _get_global_env()
assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
if backend == "gloo":
pg = core.ProcessGroupGloo(store, rank, world_size, group_id)
place = core.CPUPlace()
pg = core.ProcessGroupGloo(store, rank, world_size, place, group_id)
elif backend == "nccl":
pg = core.ProcessGroupNCCL(store, rank, world_size, group_id)
place = core.CUDAPlace(genv.device_id)
pg = core.ProcessGroupNCCL(store, rank, world_size, place, group_id)
elif backend == "hccl":
pg = core.ProcessGroupHCCL(store, rank, world_size, group_id)
place = core.NPUPlace(genv.device_id)
pg = core.ProcessGroupHCCL(store, rank, world_size, place, group_id)
elif backend == "heter":
place = None
if core.is_compiled_with_cuda():
place = core.CUDAPlace(genv.device_id)
elif core.is_compiled_with_npu():
place = core.NPUPlace(genv.device_id)
cluster_id = int(os.getenv("CLUSTER_ID", "-1"))
assert cluster_id >= 0, "please set the CLUSTER_ID variable."
cluster_size = os.getenv("CLUSTER_SIZE", None)
Expand All @@ -253,6 +262,7 @@ def _new_process_group_impl(backend,
store,
rank=global_rank,
world_size=global_world_size,
place=place,
gid=0,
local_rank=rank,
local_size=world_size,
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/fluid/tests/unittests/process_group_gloo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def test_create_process_group_gloo(self):
is_master = True if rank == 0 else False
store = paddle.fluid.core.TCPStore("127.0.0.1", 6272, is_master,
nranks, datetime.timedelta(0))
pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks)
place = paddle.fluid.core.CPUPlace()
pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks, place)

# test allreduce sum
# rank 0
Expand Down

0 comments on commit 34f30f7

Please sign in to comment.