A PyTorch implementation of the generalized Newton's method, first proposed in [1].
The generalized Newton's method is a learning rate scheduler that uses second-order derivative data.
As a concrete example, suppose that our objective is to minimize a loss function $\mathcal{L}(\theta)$ using vanilla gradient descent, whose update is $\theta \leftarrow \theta - \alpha g$, where $g = \nabla_\theta \mathcal{L}(\theta)$ is the gradient and $\alpha$ is the learning rate. Towards the objective of minimizing $\mathcal{L}$, the generalized Newton's method selects the learning rate that minimizes the second-order Taylor approximation of $\mathcal{L}(\theta - \alpha g)$, viewed as a function of $\alpha$. When $g^\top H g > 0$, where $H = \nabla^2_\theta \mathcal{L}(\theta)$ is the Hessian, this gives $\alpha^* = \frac{g^\top g}{g^\top H g}$. This choice of $\alpha^*$ produces the largest possible decrease of the second-order approximation at each step.
More theory and implementation notes can be found in this blog post.
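For intuition, here is a minimal, self-contained sketch (plain PyTorch autograd, not this package's API) that computes $\alpha^* = g^\top g / (g^\top H g)$ for a toy quadratic loss; all variable names are illustrative.

```python
import torch

# Toy quadratic loss L(theta) = 0.5 * theta^T A theta with a fixed symmetric matrix A.
A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
theta = torch.tensor([1.0, -2.0], requires_grad=True)
loss = 0.5 * theta @ A @ theta

# Gradient g = A theta, keeping the graph so we can differentiate again.
(g,) = torch.autograd.grad(loss, theta, create_graph=True)

# Hessian-vector product H g: differentiate g . v with v = g held constant.
(Hg,) = torch.autograd.grad(g @ g.detach(), theta)

# Optimal learning rate along the negative gradient direction.
alpha_star = (g @ g) / (g @ Hg)
print(alpha_star.item())
```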
Currently only the "exact version" of the method is implemented; a future release will add the "approximate version" as well. The approximate version trades accuracy for efficiency, since it avoids materializing the Hessian-vector products that the exact version requires.
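One way to estimate the optimal learning rate without Hessian-vector products is to probe the loss along the update direction and fit a quadratic in the step size. The sketch below illustrates this idea only; it is not this package's API. `quadratic_fit_lr` is a hypothetical helper, and `loss_fn` is assumed to be a zero-argument closure over the current batch.

```python
import torch

def quadratic_fit_lr(loss_fn, params, grads, delta=1e-3):
    """Hypothetical helper: estimate the optimal learning rate by probing the loss
    at theta and theta -/+ delta * g, then fitting a quadratic in the step size."""
    with torch.no_grad():
        loss_0 = loss_fn()
        for p, g in zip(params, grads):
            p.add_(g, alpha=-delta)       # move to theta - delta * g
        loss_minus = loss_fn()
        for p, g in zip(params, grads):
            p.add_(g, alpha=2.0 * delta)  # move to theta + delta * g
        loss_plus = loss_fn()
        for p, g in zip(params, grads):
            p.add_(g, alpha=-delta)       # restore theta
    first_order = (loss_plus - loss_minus) / (2.0 * delta)              # ~ g^T g
    second_order = (loss_plus + loss_minus - 2.0 * loss_0) / delta**2   # ~ g^T H g
    return first_order / second_order if second_order > 0 else None
```

The estimate only needs extra forward passes, which is where the efficiency gain over the exact version comes from.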
git clone https://github.com/dscamiss/generalized-newtons-method
pip install generalized-newtons-method
- Import the package and create your model and loss criterion:
import generalized_newtons_method as gen
model = MyModel()
criterion = MyLossCriterion()
- Call `make_gen_optimizer()` to make a wrapped version of your desired optimizer:
optimizer = gen.make_gen_optimizer(torch.optim.AdamW, model.parameters())
- Create the learning rate scheduler:
lr_min, lr_max = 0.0, 1e-3 # Clamp learning rate between `lr_min` and `lr_max`
scheduler = gen.ExactGen(optimizer, model, criterion, lr_min, lr_max)
- Run standard training loop:
for x, y in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    scheduler.step(x, y)  # <-- Note additional arguments
    optimizer.step()
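To monitor the learning rate selected for the current step, it should be readable from the wrapped optimizer's parameter groups, as with any PyTorch optimizer (this assumes the scheduler writes the selected rate into `param_groups`, which is how standard PyTorch schedulers behave):

```python
current_lr = optimizer.param_groups[0]["lr"]
print(f"selected learning rate: {current_lr:.3e}")
```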
- Add test cases to verify second-order coefficients
- Add "approximate version"
- Add shallow CNN training example
[1] Zhiqi Bu and Shiyun Xu, Automatic gradient descent with generalized Newton's method, arXiv:2407.02772