Skip to content

Commit

Permalink
Add No Mutex
Browse files Browse the repository at this point in the history
  • Loading branch information
chengduoZH committed Jun 21, 2018
1 parent 13de723 commit c99fca5
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 16 deletions.
57 changes: 42 additions & 15 deletions paddle/fluid/framework/details/broadcast_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,23 +103,50 @@ void BroadcastOpHandle::RunImpl() {
});
}

this->RunAndRecordEvent([&] {
{
platform::NCCLGroupGuard guard;
for (auto &call : broadcast_calls) {
call();
// FIXME(zcd): a temporary fix for some language model that has sparse
// parameter.
bool use_mutex = true;
if (in_var->IsType<paddle::framework::SelectedRows>()) {
use_mutex = false;
}
if (use_mutex) {
this->RunAndRecordEvent([&] {
{
platform::NCCLGroupGuard guard;
for (auto &call : broadcast_calls) {
call();
}
}
}

if (!out_handle->IsTheSameVar(*in_var_handle)) {
auto out_var = var_scopes.at(in_var_handle->scope_idx_)
->FindVar(out_var_handles[0]->name_);
paddle::framework::TensorCopy(
in_tensor, in_var_handle->place_,
*(dev_ctxes_.at(in_var_handle->place_)),
&VariableVisitor::GetMutableTensor(out_var));
}
});
if (!out_handle->IsTheSameVar(*in_var_handle)) {
auto out_var = var_scopes.at(in_var_handle->scope_idx_)
->FindVar(out_var_handles[0]->name_);
paddle::framework::TensorCopy(
in_tensor, in_var_handle->place_,
*(dev_ctxes_.at(in_var_handle->place_)),
&VariableVisitor::GetMutableTensor(out_var));
}
});
} else {
this->RunAndRecordEventNoMutex([&] {
{
platform::NCCLGroupGuard guard;
for (auto &call : broadcast_calls) {
call();
}
}

if (!out_handle->IsTheSameVar(*in_var_handle)) {
auto out_var = var_scopes.at(in_var_handle->scope_idx_)
->FindVar(out_var_handles[0]->name_);
paddle::framework::TensorCopy(
in_tensor, in_var_handle->place_,
*(dev_ctxes_.at(in_var_handle->place_)),
&VariableVisitor::GetMutableTensor(out_var));
}
});
}

#else
PADDLE_THROW("CUDA is not enabled.");
#endif
Expand Down
23 changes: 23 additions & 0 deletions paddle/fluid/framework/details/op_handle_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,29 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#endif
}

void OpHandleBase::RunAndRecordEventNoMutex(
const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event
std::function<void()> method = callback;

for (auto &p : dev_ctxes_) {
method = [method, p, this]() {
static_cast<platform::CUDADeviceContext *>(p.second)
->RecordEventNoMutex(
events_.at(boost::get<platform::CUDAPlace>(p.first).device),
method);
};
}
method();
} else {
#endif
callback();
#ifdef PADDLE_WITH_CUDA
}
#endif
}

void OpHandleBase::RunAndRecordEvent(platform::Place p,
const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/details/op_handle_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ class OpHandleBase {
protected:
void RunAndRecordEvent(const std::function<void()> &callback);

// FIXME(zcd): A temporary fix for some language model that has sparse
// parameter.
void RunAndRecordEventNoMutex(const std::function<void()> &callback);

void RunAndRecordEvent(platform::Place p,
const std::function<void()> &callback);

Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/framework/details/reduce_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ void ReduceOpHandle::RunImpl() {
}

if (pre_in_var->IsType<framework::SelectedRows>()) {
this->RunAndRecordEvent([&] {
// FIXME(zcd): A temporary fix for some language model that has sparse
// parameter.
this->RunAndRecordEventNoMutex([&] {
std::vector<const SelectedRows *> in_selected_rows =
GetInputValues<SelectedRows>(in_var_handles, var_scopes);
GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,14 @@ class CUDADeviceContext : public DeviceContext {
PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
}

// FIXME(zcd): A temporary fix for some language model that has sparse
// parameter.
template <typename Callback>
void RecordEventNoMutex(cudaEvent_t ev, Callback callback) {
callback();
PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
}

private:
CUDAPlace place_;

Expand Down

0 comments on commit c99fca5

Please sign in to comment.