patch_for_mlperf_trining_v1.1_by_samsung.patch
From 15632b0762b52aa12e175045096bd7f6ea756e3d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=EB=AA=85=EC=9A=B0=20=EA=B9=80?= <[email protected]>
Date: Fri, 22 Oct 2021 18:16:53 +0900
Subject: [PATCH] Modify detach_() to detach() to enable Pytorch DDP bucket
view
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: 명우 김 <[email protected]>
---
apex/amp/_process_optimizer.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/apex/amp/_process_optimizer.py b/apex/amp/_process_optimizer.py
index 471289b..b6369d5 100644
--- a/apex/amp/_process_optimizer.py
+++ b/apex/amp/_process_optimizer.py
@@ -370,11 +370,11 @@ def _process_optimizer(optimizer, properties):
             # Zero the model grads.
             for param in stash.all_fp16_params:
                 if param.grad is not None:
-                    param.grad.detach_()
+                    param.grad.detach()
                     param.grad.zero_()
             for param in stash.all_fp32_from_fp32_params:
                 if param.grad is not None:
-                    param.grad.detach_()
+                    param.grad.detach()
                     param.grad.zero_()
             # Clear the master grads that are independent of model grads
             for param in self._amp_stash.all_fp32_from_fp16_params:
--
2.17.1
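
Note (not part of the patch): the change only swaps the in-place detach_() for the out-of-place detach() in apex's zero-grad path. Below is a minimal, self-contained sketch of why that matters, assuming a recent PyTorch (1.7 or later), where DistributedDataParallel accepts gradient_as_bucket_view=True and parameter gradients become views into flat communication buckets. The names bucket and grad_view are illustrative stand-ins, not apex or DDP internals.

# Sketch: detach_() fails on a gradient that is a view into a DDP bucket,
# while detach() does not. "bucket" and "grad_view" are hypothetical names
# standing in for DDP's flat gradient bucket and a per-parameter grad view.
import torch

bucket = torch.zeros(8)        # stand-in for DDP's flat communication bucket
grad_view = bucket[2:6]        # a parameter's .grad as a view into that bucket

try:
    grad_view.detach_()        # in-place detach of a view; recent PyTorch raises
                               # "Can't detach views in-place. Use detach() instead."
except RuntimeError as err:
    print("detach_() rejected:", err)

grad_view.detach()             # out-of-place detach is allowed; its result is
                               # discarded, which is all the zero-grad path needs
grad_view.zero_()              # zeroing through the view also clears the bucket slice
print(bucket)                  # elements 2..5 are now zero

In the hunk above, the return value of detach() is simply discarded, so param.grad keeps pointing into the DDP bucket and the following zero_() clears the bucket slice in place. That is what the commit subject means by enabling the PyTorch DDP bucket view (DistributedDataParallel(..., gradient_as_bucket_view=True)), which the original in-place detach_() would reject with the RuntimeError shown in the sketch.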