Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove simple .data from torch/nn #31482

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions torch/nn/modules/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,8 +506,9 @@ def convert_sync_batchnorm(cls, module, process_group=None):
module.track_running_stats,
process_group)
if module.affine:
module_output.weight.data = module.weight.data.clone(memory_format=torch.preserve_format).detach()
module_output.bias.data = module.bias.data.clone(memory_format=torch.preserve_format).detach()
with torch.no_grad():
module_output.weight.copy_(module.weight)
module_output.bias.copy_(module.bias)
# keep requires_grad unchanged
module_output.weight.requires_grad = module.weight.requires_grad
module_output.bias.requires_grad = module.bias.requires_grad
Expand Down
4 changes: 2 additions & 2 deletions torch/nn/modules/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def extra_repr(self):
size_str = 'x'.join(str(size) for size in p.size())
device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
parastr = 'Parameter containing: [{} of size {}{}]'.format(
torch.typename(p.data), size_str, device_str)
torch.typename(p), size_str, device_str)
child_lines.append(' (' + str(k) + '): ' + parastr)
tmpstr = '\n'.join(child_lines)
return tmpstr
Expand Down Expand Up @@ -586,7 +586,7 @@ def extra_repr(self):
size_str = 'x'.join(str(size) for size in p.size())
device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
parastr = 'Parameter containing: [{} of size {}{}]'.format(
torch.typename(p.data), size_str, device_str)
torch.typename(p), size_str, device_str)
child_lines.append(' (' + k + '): ' + parastr)
tmpstr = '\n'.join(child_lines)
return tmpstr
12 changes: 5 additions & 7 deletions torch/nn/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,10 +663,10 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
"""
for name, param in self._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.data
destination[prefix + name] = param if keep_vars else param.detach()
for name, buf in self._buffers.items():
if buf is not None:
destination[prefix + name] = buf if keep_vars else buf.data
destination[prefix + name] = buf if keep_vars else buf.detach()

def state_dict(self, destination=None, prefix='', keep_vars=False):
r"""Returns a dictionary containing a whole state of the module.
Expand Down Expand Up @@ -745,7 +745,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
local_state = {k: v.data for k, v in local_name_params if v is not None}
local_state = {k: v for k, v in local_name_params if v is not None}

for name, param in local_state.items():
key = prefix + name
Expand All @@ -763,11 +763,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
.format(key, input_param.shape, param.shape))
continue

if isinstance(input_param, Parameter):
# backwards compatibility for serialized parameters
input_param = input_param.data
try:
param.copy_(input_param)
with torch.no_grad():
param.copy_(input_param)
except Exception:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
Expand Down