diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index 14f9f109a2134..a746eb673cbf0 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -506,8 +506,9 @@ def convert_sync_batchnorm(cls, module, process_group=None):
                                                    module.track_running_stats,
                                                    process_group)
             if module.affine:
-                module_output.weight.data = module.weight.data.clone(memory_format=torch.preserve_format).detach()
-                module_output.bias.data = module.bias.data.clone(memory_format=torch.preserve_format).detach()
+                with torch.no_grad():
+                    module_output.weight.copy_(module.weight)
+                    module_output.bias.copy_(module.bias)
                 # keep requires_grad unchanged
                 module_output.weight.requires_grad = module.weight.requires_grad
                 module_output.bias.requires_grad = module.bias.requires_grad
diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py
index d834c4372c1d7..b9007237463d2 100644
--- a/torch/nn/modules/container.py
+++ b/torch/nn/modules/container.py
@@ -446,7 +446,7 @@ def extra_repr(self):
             size_str = 'x'.join(str(size) for size in p.size())
             device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
             parastr = 'Parameter containing: [{} of size {}{}]'.format(
-                torch.typename(p.data), size_str, device_str)
+                torch.typename(p), size_str, device_str)
             child_lines.append('  (' + str(k) + '): ' + parastr)
         tmpstr = '\n'.join(child_lines)
         return tmpstr
@@ -586,7 +586,7 @@ def extra_repr(self):
             size_str = 'x'.join(str(size) for size in p.size())
             device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
             parastr = 'Parameter containing: [{} of size {}{}]'.format(
-                torch.typename(p.data), size_str, device_str)
+                torch.typename(p), size_str, device_str)
             child_lines.append('  (' + k + '): ' + parastr)
         tmpstr = '\n'.join(child_lines)
         return tmpstr
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index caf1d9f7a69b0..57d6c701c75e0 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -663,10 +663,10 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         """
         for name, param in self._parameters.items():
             if param is not None:
-                destination[prefix + name] = param if keep_vars else param.data
+                destination[prefix + name] = param if keep_vars else param.detach()
         for name, buf in self._buffers.items():
             if buf is not None:
-                destination[prefix + name] = buf if keep_vars else buf.data
+                destination[prefix + name] = buf if keep_vars else buf.detach()

     def state_dict(self, destination=None, prefix='', keep_vars=False):
         r"""Returns a dictionary containing a whole state of the module.
@@ -745,7 +745,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
             hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

         local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
-        local_state = {k: v.data for k, v in local_name_params if v is not None}
+        local_state = {k: v for k, v in local_name_params if v is not None}

         for name, param in local_state.items():
             key = prefix + name
@@ -763,11 +763,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                                       .format(key, input_param.shape, param.shape))
                     continue

-                if isinstance(input_param, Parameter):
-                    # backwards compatibility for serialized parameters
-                    input_param = input_param.data
                 try:
-                    param.copy_(input_param)
+                    with torch.no_grad():
+                        param.copy_(input_param)
                 except Exception:
                     error_msgs.append('While copying the parameter named "{}", '
                                       'whose dimensions in the model are {} and '
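Note (not part of the diff): a minimal sketch of the semantics this change relies on, using made-up tensor names (`param`, `new_values`, `snapshot`). `param.detach()` shares storage with the parameter but is excluded from autograd, which is why it can stand in for `param.data` in `_save_to_state_dict`, and wrapping an in-place `copy_()` in `torch.no_grad()` is the supported way to overwrite a leaf parameter without touching `.data`.

import torch
import torch.nn as nn

param = nn.Parameter(torch.ones(3))   # leaf tensor with requires_grad=True
new_values = torch.full((3,), 2.0)

# Old style (what the diff removes): mutate through .data, invisible to autograd.
#   param.data.copy_(new_values)
# New style (what the diff adds): an explicit no-grad block makes the in-place
# copy legal on a leaf that requires grad, without any .data access.
with torch.no_grad():
    param.copy_(new_values)

# What _save_to_state_dict now stores: detach() shares storage like .data did,
# but in-place changes to it are still reported to autograd's version tracking.
snapshot = param.detach()
assert not snapshot.requires_grad
assert snapshot.data_ptr() == param.data_ptr()  # shares the same storage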