ReduceLROnPlateau with a naive Backtracking #2478
@soumith: I created a subclass to do this as follows. It works as I described above.
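(The subclass code itself was not captured in this extract. Below is a minimal sketch of what such a subclass could look like; this is a reconstruction, not the commenter's code, and it leans on the scheduler's internal `is_better`/`best` members, which are implementation details that may change across versions.)

```python
import copy

from torch.optim.lr_scheduler import ReduceLROnPlateau


class BacktrackingPlateau(ReduceLROnPlateau):
    """Hypothetical sketch: snapshot the best weights and restore them
    whenever the learning rate is reduced."""

    def __init__(self, optimizer, model, **kwargs):
        super(BacktrackingPlateau, self).__init__(optimizer, **kwargs)
        self.model = model        # needs the module itself: optimizer state
        self.best_state = None    # does not cover buffers like BN statistics

    def step(self, metrics):
        # Snapshot *before* the parent updates self.best; deep-copy so later
        # in-place optimizer updates cannot clobber the snapshot.
        if self.best_state is None or self.is_better(metrics, self.best):
            self.best_state = copy.deepcopy(self.model.state_dict())
        old_lrs = [group['lr'] for group in self.optimizer.param_groups]
        super(BacktrackingPlateau, self).step(metrics)
        new_lrs = [group['lr'] for group in self.optimizer.param_groups]
        if new_lrs != old_lrs:    # the LR was just reduced, so backtrack
            self.model.load_state_dict(self.best_state)
```

One caveat: restoring the weights leaves the optimizer's running state (e.g. Adam's moment estimates) aligned with the abandoned trajectory; whether to reset or also snapshot it is left open here, and it is part of the cost mentioned further down the thread.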
Hi, I am the author of PR #1370.
@soumith any idea? Edit: I believe that `optim.state_dict()` does not contain all the parameters of a model (e.g. BN's running mean and std). Therefore we still need access to the `nn.Module` even without this issue.
@Taha-Bahadori Would you please make it a PR? You can see what happens by running this snippet:

```python
import torch

m = torch.nn.Linear(1, 2)
optim = torch.optim.Adam(m.parameters())
state_dict = m.state_dict()
print(state_dict)
m.state_dict()['weight'][0] = 1000  # in-place write through a fresh state_dict
print(state_dict)                   # the earlier "snapshot" shows 1000 too
```
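What the snippet demonstrates, stated explicitly: `state_dict()` returns references to the live parameter tensors, so the dict saved above is an alias rather than a snapshot, and both prints show the modified weight. To take an actual snapshot, the tensors have to be copied first; a minimal sketch:

```python
import copy

snapshot = copy.deepcopy(m.state_dict())  # decoupled from the live tensors
m.state_dict()['weight'][0] = 1000        # in-place write to the model
print(snapshot['weight'][0])              # the snapshot is unaffected
m.load_state_dict(snapshot)               # genuinely restores the old values
```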
@Jiaming-Liu I think there is a mistake in your code snippet. Here is an example that shows the above saving and loading:

```python
import torch
import torch.nn as nn


class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.mem = nn.Parameter(torch.zeros(1))


m = M()
print(m)
print(m.state_dict())
# Saving the current state_dict
sd = m.state_dict()
# Changing the value of the parameter
m.mem.data = torch.ones(1)
print(m.state_dict())
# Now loading the original zero parameter back
m.load_state_dict(sd)
print(m.state_dict())
```
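A note on why this example behaves differently from the previous one (my reading; the thread leaves it implicit): `m.mem.data = torch.ones(1)` rebinds `.data` to a brand-new tensor, so the tensor that `sd` still references is never touched, and loading `sd` restores the zeros. An in-place update, which is what optimizer steps actually perform, mutates the shared tensor instead:

```python
m.mem.data.fill_(1.0)  # in-place: sd['mem'] now reads 1.0 as well,
                       # so load_state_dict(sd) can no longer restore zeros
```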
@Taha-Bahadori I think this snippet is closer to the real use case. Note that `optim.step()` updates the parameters in place, whereas your example rebinds `.data` to a new tensor:

```python
import torch

net = torch.nn.Linear(1, 2)
optim = torch.optim.Adam(net.parameters())
state_dict = net.state_dict()
print(state_dict)

x = torch.autograd.Variable(torch.FloatTensor([[1], [2]]))
y = torch.autograd.Variable(torch.FloatTensor([[0, 1], [2, 3]]))
loss = torch.nn.functional.mse_loss(net(x), y)
loss.backward()
optim.step()  # changes the parameter values in place

net.load_state_dict(state_dict)  # loads from the aliased dict ...
print(net.state_dict())          # ... so the post-step values remain
```
@vincentqb can you make a call on this?
Given the cost associated with backtracking, this mechanism should be left under the user's control. The scheduler could have a flag that indicates its status, though, so the user can save/load based on that flag, see comment.
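A sketch of the user-side pattern this suggests, assuming hypothetical user-defined `train_one_epoch`/`evaluate` helpers and detecting the LR drop by comparing param-group LRs (the released scheduler exposes no such status flag at the time of writing):

```python
import copy
import torch

best_loss, best_state = float('inf'), None
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer)   # hypothetical user helpers
    val_loss = evaluate(model)
    if val_loss < best_loss:
        best_loss = val_loss
        best_state = copy.deepcopy(model.state_dict())
    lrs = [g['lr'] for g in optimizer.param_groups]
    scheduler.step(val_loss)
    if [g['lr'] for g in optimizer.param_groups] != lrs and best_state is not None:
        model.load_state_dict(best_state)   # backtrack, under user control
```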
This is really a nice issue. So, what is the final decision on the "needs research" label?
Is it possible to implement a simple backtracking for the ReduceLROnPlateau module? That is, store the best model coefficients and reload them upon rate reduction.
In my experiments, this helps speed up learning, though it might be expensive for very large models.
cc @vincentqb