Skip to content

Commit

Permalink
remove simple .data from torch/nn
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#31482

Test Plan: Imported from OSS

Differential Revision: D19303185

Pulled By: albanD

fbshipit-source-id: 610eae096bab24a7b9f651b9af2e3ecd19df55b0
  • Loading branch information
albanD authored and facebook-github-bot committed Jan 14, 2020
1 parent 2e77629 commit 2e38958
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
14 changes: 8 additions & 6 deletions torch/autograd/gradcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
from itertools import product
import warnings


def zero_gradients(x):
    """Recursively zero out the ``.grad`` attribute of a tensor or of every
    tensor inside an iterable of tensors.

    Tensors whose ``grad`` is ``None`` (no gradient accumulated yet) are left
    untouched.  Non-tensor, non-iterable inputs are silently ignored.
    """
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            # Detach in place first so that zeroing the gradient is not
            # recorded by autograd, then zero it directly.  (The old
            # ``x.grad.data.zero_()`` form is deprecated: ``.data`` bypasses
            # autograd's version counter and is being removed.)
            x.grad.detach_()
            x.grad.zero_()
    elif isinstance(x, container_abcs.Iterable):
        for elem in x:
            zero_gradients(elem)
Expand Down Expand Up @@ -63,8 +62,6 @@ def get_numerical_jacobian(fn, input, target=None, eps=1e-3):

# TODO: compare structure
for x_tensor, d_tensor in zip(x_tensors, j_tensors):
# need data here to get around the version check because without .data,
# the following code updates version but doesn't change content
if x_tensor.is_sparse:
def get_stride(size):
dim = len(size)
Expand All @@ -78,9 +75,12 @@ def get_stride(size):
x_nnz = x_tensor._nnz()
x_size = list(x_tensor.size())
x_indices = x_tensor._indices().t()
x_values = x_tensor._values().data
x_values = x_tensor._values()
x_stride = get_stride(x_size)

# Use .data here to get around the version check
x_values = x_values.data

for i in range(x_nnz):
x_value = x_values[i]
for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
Expand All @@ -95,10 +95,11 @@ def get_stride(size):
r = (outb - outa) / (2 * eps)
d_tensor[d_idx] = r.detach().reshape(-1)
elif x_tensor.layout == torch._mkldnn:
# Use .data here to get around the version check
x_tensor = x_tensor.data
if len(input) != 1:
raise ValueError('gradcheck currently only supports functions with 1 input, but got: ',
len(input))
x_tensor = x_tensor.data
for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
# this is really inefficient, but without indexing implemented, there's
# not really a better way than converting back and forth
Expand All @@ -116,6 +117,7 @@ def get_stride(size):
r = (outb - outa) / (2 * eps)
d_tensor[d_idx] = r.detach().reshape(-1)
else:
# Use .data here to get around the version check
x_tensor = x_tensor.data
for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
orig = x_tensor[x_idx].item()
Expand Down
5 changes: 3 additions & 2 deletions torch/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def _unravel_index(index, shape):
rtol, atol, list(index), actual[index].item(), expected[index].item(),
count - 1, 100 * count / actual.numel()))


def make_non_contiguous(tensor):
if tensor.numel() <= 1: # can't make non-contiguous
return tensor.clone()
Expand All @@ -83,7 +82,9 @@ def make_non_contiguous(tensor):
input = input.narrow(i, bounds, tensor.size(i))

input.copy_(tensor)
return input

# Use .data here to hide the view relation between input and other temporary Tensors
return input.data


def get_all_dtypes():
Expand Down

0 comments on commit 2e38958

Please sign in to comment.