generalized-newtons-method

A PyTorch implementation of the generalized Newton's method, first proposed in [1].

Brief background

The generalized Newton's method is a learning rate scheduler that uses second-order derivative data.

As a concrete example, suppose that our objective is to minimize the loss function $L: \Theta \to \mathbf{R}$ using vanilla SGD and a static learning rate $\alpha$. One gradient descent iteration is $\theta_{t+1} \leftarrow \theta_t - \alpha \nabla_\theta L(\theta_t)$. For this iteration, introduce the "loss per learning rate" function

$$g(\alpha) = L(\theta_t - \alpha \nabla_\theta L(\theta_t)).$$

Towards the objective of minimizing $L$, we can attempt to choose $\alpha$ such that $g$ is (approximately) minimized. Provided that $g$ is well-approximated near the origin by its second-order Taylor polynomial, and that this polynomial is strictly convex, the generalized Newton's method chooses

$$\alpha_t = \frac{d_\theta L(\theta_t) \cdot \nabla_\theta L(\theta_t)}{d_\theta^2 L(\theta_t) \cdot (\nabla_\theta L(\theta_t), \nabla_\theta L(\theta_t))}.$$

This choice of $\alpha_t$ minimizes the second-order Taylor polynomial, and therefore approximately minimizes $g$.
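
To make the formula concrete, here is a minimal sketch (plain PyTorch, not using this package) that computes $\alpha_t$ for a toy quadratic loss via a Hessian-vector product; all names in it are hypothetical:

import torch

# Toy quadratic loss L(theta) = 0.5 * theta^T A theta - b^T theta (illustrative only).
A = torch.tensor([[3.0, 0.5], [0.5, 2.0]])
b = torch.tensor([1.0, -1.0])
theta = torch.tensor([0.0, 0.0], requires_grad=True)

loss = 0.5 * theta @ A @ theta - b @ theta
(g,) = torch.autograd.grad(loss, theta, create_graph=True)  # gradient of L at theta_t

# Hessian-vector product H g via a second backward pass.
(hg,) = torch.autograd.grad(g, theta, grad_outputs=g)

alpha = (g @ g) / (g @ hg)  # alpha_t = ||g||^2 / (g^T H g)
theta_next = (theta - alpha * g).detach()  # one SGD step with the computed learning rate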

More theory and implementation notes can be found in this blog post.

Caveats

Currently, only the "exact version" of the method is implemented; a future release will add the "approximate version" as well. The two differ in that the "approximate version" trades accuracy for efficiency by not materializing the required Hessian-vector products.

Installation

git clone https://github.com/dscamiss/generalized-newtons-method
pip install ./generalized-newtons-method

Usage

Setup

  • Import the package and create your model and loss criterion:
import torch
import generalized_newtons_method as gen
model = MyModel()
criterion = MyLossCriterion()
  • Call make_gen_optimizer() to make a wrapped version of your desired optimizer:
optimizer = gen.make_gen_optimizer(torch.optim.AdamW, model.parameters())
  • Create the learning rate scheduler:
lr_min, lr_max = 0.0, 1e-3  # Clamp learning rate between `lr_min` and `lr_max`
scheduler = gen.ExactGen(optimizer, model, criterion, lr_min, lr_max)
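
Here, MyModel and MyLossCriterion are placeholders for your own model and loss; a minimal stand-in (purely illustrative, not part of this package) might look like:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    """Small fully-connected regression model (illustrative placeholder)."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

MyLossCriterion = nn.MSELoss  # any standard criterion is used the same way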

Training

  • Run a standard training loop:
for x, y in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    scheduler.step(x, y)  # <-- Note additional arguments
    optimizer.step()
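
In the loop above, dataloader is any standard PyTorch DataLoader. For a self-contained run with the illustrative model from the setup sketch, random regression data would do (again, purely illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Random regression data sized to match the illustrative 8-feature model above.
x_data = torch.randn(256, 8)
y_data = torch.randn(256, 1)
dataloader = DataLoader(TensorDataset(x_data, y_data), batch_size=32, shuffle=True)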

TODO

  • Add test cases to verify second-order coefficients
  • Add "approximate version"
  • Add shallow CNN training example

References

[1] Zhiqi Bu and Shiyun Xu, Automatic gradient descent with generalized Newton's method, arXiv:2407.02772