Gradient-based Hyperparameter Optimization with Reversible Learning
Dougal Maclaurin, David Duvenaud, Ryan Adams

Transcript of the talk slides: ...duvenaud/talks/hypergrad-talk.pdf

Page 1:

Gradient-based Hyperparameter Optimization with Reversible Learning

Dougal Maclaurin, David Duvenaud, Ryan Adams

Page 2:

Motivation

• Hyperparameters are everywhere
  • sometimes hidden!
• Gradient-free optimization is hard
• Validation loss is a function of hyperparameters
• Why not take gradients?

Page 3:

Optimizing optimization

x_final = SGD(x_init, learn rate, momentum, ∇Loss(x, reg, Data))

[Figure: SGD trajectory in weight space (Weight 1 vs. Weight 2), starting from the initial weights; meta-iteration 1]

Page 4:

Optimizing optimization

x_final = SGD(x_init, learn rate, momentum, ∇Loss(x, reg, Data))

[Figure: same weight-space plot, now showing meta-iterations 1 and 2]

Page 5:

Optimizing optimization

x_final = SGD(x_init, learn rate, momentum, ∇Loss(x, reg, Data))

[Figure: same weight-space plot, showing meta-iterations 1, 2, and 3]

Page 6:

A pretty scary function to differentiate

J = Loss(D_val, SGD(x_init, α, β, ∇Loss(D_train, x, reg)))

Stochastic Gradient Descent

1: input: initial x_1, decay β, learning rate α, regularization params θ, loss function L(x, θ, t)
2: initialize v_1 = 0
3: for t = 1 to T do
4:    g_t = ∇_x L(x_t, θ, t)                ▷ evaluate gradient
5:    v_{t+1} = β_t v_t − (1 − β_t) g_t     ▷ update velocity
6:    x_{t+1} = x_t + α_t v_t               ▷ update position
7: output trained parameters x_T
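As a concrete reference, here is a minimal Python sketch of the algorithm box above; the quadratic loss, step counts, and constants are illustrative assumptions, not the talk's setup:

import numpy as np

def sgd_momentum(x, grad_L, alphas, betas):
    # Momentum SGD, following the algorithm box above.
    v = np.zeros_like(x)                     # line 2: v_1 = 0
    for alpha, beta in zip(alphas, betas):   # line 3: for t = 1 to T
        g = grad_L(x)                        # line 4: evaluate gradient
        v = beta * v - (1 - beta) * g        # line 5: update velocity
        x = x + alpha * v                    # line 6: update position
                                             # (here using the freshly updated
                                             # velocity; variants differ on this)
    return x                                 # line 7: trained parameters x_T

# Toy example: L(x) = 0.5 * ||x||^2, so grad_L(x) = x.
x_T = sgd_momentum(np.array([1.0, -2.0]), lambda x: x,
                   alphas=[0.1] * 100, betas=[0.9] * 100)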

• Each gradient evaluation in SGD requires forward and backprop through the model
• The entire learning procedure looks like a 1000-layer deep net

Page 7:

Autograd: Automatic Differentiation

• github.com/HIPS/autograd

• Works with (almost) arbitrary Python/Numpy code
• Can take gradients of gradients (of gradients...)

Page 8:

Autograd Example

import autograd.numpy as np
import matplotlib.pyplot as plt
from autograd import grad

# Taylor approximation to the sin function
def fun(x):
    curr = x
    ans = curr
    for i in range(1000):
        curr = -curr * x**2 / ((2*i+3)*(2*i+2))
        ans = ans + curr
        if np.abs(curr) < 0.2:
            break
    return ans

d_fun = grad(fun)       # first derivative
dd_fun = grad(d_fun)    # second derivative

x = np.linspace(-10, 10, 100)
plt.plot(x, [fun(xi) for xi in x],
         x, [d_fun(xi) for xi in x],
         x, [dd_fun(xi) for xi in x])
plt.show()

Page 9:

Most Numpy functions implemented

Complex & Fourier   Array         Misc        Linear Algebra   Stats
imag                atleast_1d    logsumexp   inv              std
conjugate           atleast_2d    where       norm             mean
angle               atleast_3d    einsum      det              var
real_if_close       full          sort        eigh             prod
real                repeat        partition   solve            sum
fabs                split         clip        trace            cumsum
fft                 concatenate   outer       diag
fftshift            roll          dot         tril
fft2                transpose     tensordot   triu
ifftn               reshape       rot90
ifftshift           squeeze
ifft2               ravel
ifft                expand_dims

Page 10:

Technical Challenge: Memory

• Reverse-mode differentiation needs access to the entire learning trajectory
• i.e. 10^7 parameters × 10^5 training iterations
• Only need access in reverse order...
• Could we recompute the learning trajectory backwards by running reverse SGD?
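To make the cost concrete (assuming single-precision storage, a detail the slide doesn't state): 10^7 parameters × 10^5 iterations × 4 bytes ≈ 4 × 10^12 bytes, i.e. roughly 4 TB just to remember the trajectory.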

Page 11:

SGD with momentum is reversible

Forward update rule:

x_{t+1} ← x_t + α v_t
v_{t+1} ← β v_t − ∇L(x_{t+1})

Reverse update rule:

v_t ← (v_{t+1} + ∇L(x_{t+1})) / β
x_t ← x_{t+1} − α v_t
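A minimal sketch of this forward/reverse pair, using a toy quadratic loss so that ∇L(x) = x (the loss and constants are assumptions for illustration):

import numpy as np

def grad_L(x):
    return x                     # toy loss L(x) = 0.5 * ||x||^2

def forward(x, v, alpha, beta):
    x = x + alpha * v            # x_{t+1} = x_t + α v_t
    v = beta * v - grad_L(x)     # v_{t+1} = β v_t − ∇L(x_{t+1})
    return x, v

def reverse(x, v, alpha, beta):
    v = (v + grad_L(x)) / beta   # v_t = (v_{t+1} + ∇L(x_{t+1})) / β
    x = x - alpha * v            # x_t = x_{t+1} − α v_t
    return x, v

x0, v0 = np.array([1.0, -2.0]), np.zeros(2)
x1, v1 = forward(x0, v0, alpha=0.1, beta=0.9)
xr, vr = reverse(x1, v1, alpha=0.1, beta=0.9)
# Exactly recovers (x0, v0) in real arithmetic; in floating point the
# division by β slowly leaks bits, which is the problem addressed below.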

Page 12:

Reverse-mode differentiation of SGD

Stochastic Gradient Descent
1: input: initial x_1, decays β, learning rates α, loss function L(x, θ, t)
2: initialize v_1 = 0
3: for t = 1 to T do
4:    g_t = ∇_x L(x_t, θ, t)                  ▷ evaluate gradient
5:    v_{t+1} = β_t v_t − (1 − β_t) g_t       ▷ update velocity
6:    x_{t+1} = x_t + α_t v_t                 ▷ update position
7: output trained parameters x_T

Reverse-Mode Gradient of SGD
1: input: x_T, v_T, β, α, train loss L(x, θ, t), loss f(x)
2: initialize dv = 0, dθ = 0, dα = 0, dβ = 0
3: initialize dx = ∇_x f(x_T)
4: for t = T counting down to 1 do
5:    dα_t = dxᵀ v_t
6:    x_{t−1} = x_t − α_t v_t                 ▷ downdate position
7:    g_t = ∇_x L(x_t, θ, t)                  ▷ evaluate gradient
8:    v_{t−1} = [v_t + (1 − β_t) g_t] / β_t   ▷ downdate velocity
9:    dv = dv + α_t dx
10:   dβ_t = dvᵀ (v_t + g_t)
11:   dx = dx − (1 − β_t) dv ∇_x ∇_x L(x_t, θ, t)
12:   dθ = dθ − (1 − β_t) dv ∇_θ ∇_x L(x_t, θ, t)
13:   dv = β_t dv
14: output gradient of f(x_T) w.r.t. x_1, v_1, β, α, and θ

• Outputs gradients with respect to all hypers.
• Reversing SGD avoids storing the learning trajectory.
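For contrast, here is a naive sketch that computes the same kind of hypergradient without the reversal trick: unroll training and let autograd run reverse-mode through the stored trajectory. This is not the algorithm above; it pays exactly the memory cost that reversing SGD avoids. The toy losses and constants are illustrative assumptions:

import autograd.numpy as np
from autograd import grad

def train_then_validate(alpha, beta=0.9, T=100):
    # Train on a toy quadratic loss L(x) = 0.5 * ||x||^2 (so ∇L(x) = x),
    # then report a toy validation loss on the trained weights.
    x, v = np.array([1.0, -2.0]), np.zeros(2)
    for _ in range(T):
        x = x + alpha * v        # forward rule from the previous slide
        v = beta * v - x
    target = np.array([0.1, 0.2])
    return np.sum((x - target) ** 2)

hypergrad = grad(train_then_validate)   # d(validation loss)/d(learning rate)
print(hypergrad(0.05))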

Page 13:

Naive reversal

[Figure: weight-space trajectory (Weight 1 vs. Weight 2) and its naive reversal]

Page 14:

Naive reversal ... Fails!

[Figure: the naively reversed trajectory (Weight 1 vs. Weight 2) drifts away from the forward one]

Page 15:

A closer look at reverse SGD

Forward update rule:

x_{t+1} ← x_t + α v_t
v_{t+1} ← β v_t − ∇L(x_{t+1})

Destroys log₂(1/β) bits per parameter per iteration

Reverse update rule:

v_t ← (v_{t+1} + ∇L(x_{t+1})) / β
x_t ← x_{t+1} − α v_t

Needs log₂(1/β) bits per parameter per iteration
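Concretely, for β = 0.9 this is log₂(1/0.9) ≈ 0.152 bits per parameter per iteration. Against a naive baseline of storing a full 32-bit value per parameter per iteration (an assumed baseline precision), that is a factor of 32 / 0.152 ≈ 210, matching the roughly 200× saving quoted on the next slide.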

Page 16:

How to store the lost bits?

• Switch to fixed-point arithmetic for exact addition
• Express β as a rational number
• Push/pop remainders from an information buffer

def rational_multiply(x, n, d, bitbuffer):
    # Exactly multiply the fixed-point integer x by the rational n/d,
    # banking the destroyed remainder so the step can be undone later.
    bitbuffer.push(x % d, d)   # save the bits that the division destroys
    x //= d                    # exact integer division
    x *= n
    x += bitbuffer.pop(n)      # absorb previously banked bits
    return x
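The slide leaves the buffer itself implicit; a minimal sketch of one possible implementation is a single arbitrary-precision integer used as a stack of mixed-radix digits. This BitBuffer is a hypothetical illustration, not the paper's code:

class BitBuffer:
    # One big integer, treated as a LIFO stack of digits in varying bases.
    def __init__(self):
        self.store = 0

    def push(self, r, base):
        # Append digit r (with 0 <= r < base) at the least-significant end.
        self.store = self.store * base + r

    def pop(self, base):
        # Remove and return the most recently pushed base-`base` digit.
        self.store, r = divmod(self.store, base)
        return r

# Round trip: multiply by 9/10, then undo it by multiplying by 10/9.
buf = BitBuffer()
x = rational_multiply(103, 9, 10, buf)   # 103 * 9/10 -> 93, remainder banked
x = rational_multiply(x, 10, 9, buf)     # exactly recovers 103
assert x == 103 and buf.store == 0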

• 200× memory savings when β = 0.9
• Now we have scalable gradients of hypers, only twice as slow as the original!

Page 17:

Part 2: A Garden of Delights

Page 18:

Learning rate gradients

∂Loss(D_val, x_init, α, β, D_train, reg) / ∂α

[Figure: per-layer learning-rate schedules (top) and the corresponding learning-rate gradients (bottom), plotted against schedule index 0–100 for layers 1–4]

Page 19:

Optimized learning rates

• Used SGD to optimize SGD
• 4-layer NN on MNIST
• Top layer learns early on; slowdown at the end

[Figure: optimized per-layer learning-rate schedules (learning rate vs. schedule index 0–100) for layers 1–4]

Page 20:

Optimizing initialization scales

∂Loss(D_val, x_init, α, β, D_train, reg) / ∂x_init

[Figure: optimized initial scales for biases (left) and weights (right) of layers 1–4 over 50 meta-iterations; the weight scales settle near 1/√784 and 1/√50]

Page 21:

Optimizing regularization

∂Loss(D_val, x_init, α, β, D_train, reg) / ∂reg

Optimized L2 hypers for each weight in logistic regression:

[Figure: optimized per-weight L2 regularization strengths visualized as images; color scale 0 to 0.03]

Page 22:

Optimizing architecture

• Architecture = tying weights or setting them to zero
• i.e. convnets, recurrent nets, multi-task
• Tying can be enforced by L2 regularization (a sketch follows below)
• L2 regularization is differentiable
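As a sketch of the last two bullets: a soft L2 tying penalty between two tasks' weight matrices, differentiable in its strength lam so that lam can itself be tuned by hypergradients (the names and the exact penalty form are illustrative assumptions):

import numpy as np

def tying_penalty(W_a, W_b, lam):
    # lam = 0 leaves the two weight matrices independent;
    # a large lam effectively ties them together.
    return lam * np.sum((W_a - W_b) ** 2)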

[Figure: example digit images for the two tasks, original and rotated]

Page 23:

Optimizing regularization

Matrices enforce weight sharing between tasks

                     Train error   Test error
Separate networks       0.61          1.34
Tied weights            0.90          1.25
Learned sharing         0.60          1.13

[Figure: input, middle, and output weight-sharing matrices for each variant]

Page 24:

Optimizing training data

∂Loss(D_val, x_init, α, β, D_train, reg) / ∂D_train

• Training set of size 10 with fixed labels on MNIST
• Started from blank images (see the sketch below)
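A minimal sketch of the idea, shrunk to a linear model so it stays short; every name and constant here is an illustrative assumption, not the talk's MNIST setup:

import autograd.numpy as np
from autograd import grad

# Fixed validation set for a toy 1-D regression task.
X_val = np.linspace(-1.0, 1.0, 20).reshape(-1, 1)
y_val = 2.0 * X_val[:, 0]

def validation_loss(X_train, y_train=np.ones(10), alpha=0.1, T=100):
    # Train a linear model by plain gradient descent on the synthetic data...
    w = np.zeros(1)
    for _ in range(T):
        resid = np.dot(X_train, w) - y_train
        w = w - alpha * np.dot(X_train.T, resid) / len(y_train)
    # ...then score it on the fixed validation set.
    return np.mean((np.dot(X_val, w) - y_val) ** 2)

# Gradient of validation loss with respect to the training inputs themselves.
d_loss_d_data = grad(validation_loss)

X_train = np.zeros((10, 1))             # start from "blank" data, as on the slide
for _ in range(50):                     # meta-optimize the dataset
    X_train = X_train - 0.5 * d_loss_d_data(X_train)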

[Figure: the ten optimized training images; pixel values range from −0.30 to 0.33]

Page 25:

Limitations: Chaotic learning dynamics

[Figure: training loss (top) and its gradient (bottom) as a function of learning rate, over the range 1.0–2.0]

Page 26:

Summary

• Can compute gradients of learning procedures
• Reversing learning saves memory
• Can optimize thousands of hyperparameters

Thanks!
