remove simple .data from torch/nn
Summary: Pull Request resolved: pytorch#31482

Test Plan: Imported from OSS

Differential Revision: D19303243

Pulled By: albanD

fbshipit-source-id: 5afdfeb4b8382c09b9ec65acd545148ed76d4285
albanD authored and ttumiel committed Mar 4, 2020
1 parent e904c08 commit 05916d6
Showing 3 changed files with 10 additions and 11 deletions.
5 changes: 3 additions & 2 deletions torch/nn/modules/batchnorm.py
@@ -506,8 +506,9 @@ def convert_sync_batchnorm(cls, module, process_group=None):
                                                    module.track_running_stats,
                                                    process_group)
             if module.affine:
-                module_output.weight.data = module.weight.data.clone(memory_format=torch.preserve_format).detach()
-                module_output.bias.data = module.bias.data.clone(memory_format=torch.preserve_format).detach()
+                with torch.no_grad():
+                    module_output.weight.copy_(module.weight)
+                    module_output.bias.copy_(module.bias)
                 # keep requires_grad unchanged
                 module_output.weight.requires_grad = module.weight.requires_grad
                 module_output.bias.requires_grad = module.bias.requires_grad
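The change above swaps direct .data assignment for an in-place copy_ inside a torch.no_grad() block. A minimal sketch of the same pattern outside convert_sync_batchnorm (module types and sizes here are illustrative, not part of the commit):

    import torch
    import torch.nn as nn

    src = nn.BatchNorm2d(8)
    dst = nn.SyncBatchNorm(8)

    # no_grad() lets copy_() write into leaf parameters that require grad,
    # and keeps the copy itself out of the autograd graph.
    with torch.no_grad():
        dst.weight.copy_(src.weight)
        dst.bias.copy_(src.bias)

    # requires_grad is carried over explicitly, as in the diff above.
    dst.weight.requires_grad = src.weight.requires_grad
    dst.bias.requires_grad = src.bias.requires_grad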
4 changes: 2 additions & 2 deletions torch/nn/modules/container.py
@@ -446,7 +446,7 @@ def extra_repr(self):
             size_str = 'x'.join(str(size) for size in p.size())
             device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
             parastr = 'Parameter containing: [{} of size {}{}]'.format(
-                torch.typename(p.data), size_str, device_str)
+                torch.typename(p), size_str, device_str)
             child_lines.append('  (' + str(k) + '): ' + parastr)
         tmpstr = '\n'.join(child_lines)
         return tmpstr
@@ -586,7 +586,7 @@ def extra_repr(self):
             size_str = 'x'.join(str(size) for size in p.size())
             device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
             parastr = 'Parameter containing: [{} of size {}{}]'.format(
-                torch.typename(p.data), size_str, device_str)
+                torch.typename(p), size_str, device_str)
             child_lines.append('  (' + k + '): ' + parastr)
         tmpstr = '\n'.join(child_lines)
         return tmpstr
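For these repr changes, torch.typename() resolves a Parameter to the same underlying tensor type string as its .data, so dropping .data does not change the printed output. A quick illustrative check (not part of the commit):

    import torch
    from torch.nn import Parameter, ParameterList

    p = Parameter(torch.randn(3, 4))
    print(torch.typename(p))       # e.g. 'torch.FloatTensor'
    print(torch.typename(p.data))  # same string as above

    plist = ParameterList([p])
    print(plist)  # entries render as 'Parameter containing: [torch.FloatTensor of size 3x4]'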
12 changes: 5 additions & 7 deletions torch/nn/modules/module.py
@@ -663,10 +663,10 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         """
         for name, param in self._parameters.items():
             if param is not None:
-                destination[prefix + name] = param if keep_vars else param.data
+                destination[prefix + name] = param if keep_vars else param.detach()
         for name, buf in self._buffers.items():
             if buf is not None:
-                destination[prefix + name] = buf if keep_vars else buf.data
+                destination[prefix + name] = buf if keep_vars else buf.detach()
 
     def state_dict(self, destination=None, prefix='', keep_vars=False):
         r"""Returns a dictionary containing a whole state of the module.
@@ -745,7 +745,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
             hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
 
         local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
-        local_state = {k: v.data for k, v in local_name_params if v is not None}
+        local_state = {k: v for k, v in local_name_params if v is not None}
 
         for name, param in local_state.items():
             key = prefix + name
@@ -763,11 +763,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                                       .format(key, input_param.shape, param.shape))
                     continue
 
-                if isinstance(input_param, Parameter):
-                    # backwards compatibility for serialized parameters
-                    input_param = input_param.data
                 try:
-                    param.copy_(input_param)
+                    with torch.no_grad():
+                        param.copy_(input_param)
                 except Exception:
                     error_msgs.append('While copying the parameter named "{}", '
                                       'whose dimensions in the model are {} and '
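In _load_from_state_dict, unwrapping Parameter inputs through .data is no longer needed because copy_() accepts any tensor, including a Parameter; the new torch.no_grad() block is what makes the in-place copy into a leaf parameter legal. A sketch of the behavior the guard works around (names are illustrative):

    import torch
    import torch.nn as nn

    m = nn.Linear(2, 3)
    new_weight = torch.zeros(3, 2)

    try:
        m.weight.copy_(new_weight)   # in-place op on a leaf that requires grad
    except RuntimeError as err:
        print('rejected by autograd:', err)

    with torch.no_grad():
        m.weight.copy_(new_weight)   # allowed: not recorded by autograd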
