A Distributed Method for Fitting Laplacian Regularized ...boyd/papers/pdf/strat_models_talk.pdf ·...

A Distributed Method for Fitting Laplacian Regularized Stratified Models Jonathan Tuck Shane Barratt Stephen Boyd International Conference on Statistical Optimization and Learning Beijing Jiatong University, June 19 2019


  • A Distributed Method for Fitting Laplacian Regularized Stratified Models

    Jonathan Tuck, Shane Barratt, Stephen Boyd

    International Conference on Statistical Optimization and Learning, Beijing Jiatong University, June 19 2019

  • Outline

    Stratified models

    Data models

    Regularization graphs

    Solution method

    Examples

    Conclusions


  • Data records

    - Data records have the form $(z, x, y) \in \mathcal{Z} \times \mathcal{X} \times \mathcal{Y}$
    - $z \in \mathcal{Z} = \{1, \ldots, K\}$ are categorical features (the ones we'll stratify over)
    - $x \in \mathcal{X}$ are the other features
    - $y \in \mathcal{Y}$ is the label/dependent variable

  • Stratified model

    - Fit a model to data $(z, x, y)$
    - Stratified model: fit a different model for each value of $z$
    - $\theta_k$ is the model parameter for $z = k$
    - $\theta = (\theta_1, \ldots, \theta_K) \in \Theta \subseteq \mathbf{R}^{Kn}$ parameterizes the stratified model
    - Old idea [Kernan 99]
    - Example: stratified regression model $\hat{y} = x^T \theta_z$

  • Example

    [Figure: plot over x from 0.0 to 1.0, vertical axis from −10 to 10; legend: z = 1, z = −1]

    $\hat{y} = \theta_1 + \theta_2 x + \theta_3 z$, $z \in \{-1, 1\}$

  • Example

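    [Figure: plot over x from 0.0 to 1.0, vertical axis from −10 to 10; legend: z = 1, z = −1]

    $$
    \hat{y} = \begin{cases}
      \theta_{-1,1} + \theta_{-1,2}\, x, & z = -1 \\
      \theta_{1,1} + \theta_{1,2}\, x, & z = 1
    \end{cases}
    $$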
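
The two example slides contrast a common model that treats z as one more feature with a stratified model that fits a separate line per value of z. The sketch below illustrates that contrast on synthetic data; the data-generating parameters, noise level, and variable names are assumptions made for illustration and are not taken from the talk.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data (an assumption for illustration): each stratum has its own line.
N = 200
x = rng.uniform(0.0, 1.0, size=N)
z = rng.choice([-1, 1], size=N)
true_params = {-1: (5.0, -10.0), 1: (-5.0, 10.0)}  # (intercept, slope) per z
intercepts = np.array([true_params[int(k)][0] for k in z])
slopes = np.array([true_params[int(k)][1] for k in z])
y = intercepts + slopes * x + 0.1 * rng.standard_normal(N)


def rms(residual):
    """Root-mean-square of a residual vector."""
    return float(np.sqrt(np.mean(residual ** 2)))


# Common model: y_hat = theta_1 + theta_2 * x + theta_3 * z.
A_common = np.column_stack([np.ones(N), x, z])
theta_common, *_ = np.linalg.lstsq(A_common, y, rcond=None)
rms_common = rms(A_common @ theta_common - y)

# Stratified model: a separate (intercept, slope) pair for each value of z.
theta_strat = {}
residual = np.empty(N)
for k in (-1, 1):
    mask = z == k
    A_k = np.column_stack([np.ones(mask.sum()), x[mask]])
    theta_strat[k], *_ = np.linalg.lstsq(A_k, y[mask], rcond=None)
    residual[mask] = A_k @ theta_strat[k] - y[mask]
rms_strat = rms(residual)

print(f"common model training RMS error:     {rms_common:.3f}")
print(f"stratified model training RMS error: {rms_strat:.3f}")
```

Because the slopes differ in sign between the two strata in this synthetic setup, the common model cannot fit both groups, while the stratified model recovers each line separately.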

  • Stratified model

    - A stratified model is a simple function of $x$ (e.g., linear) and an arbitrary function of $z$
    - Examples: separate models for
      – $z \in \{\text{Male}, \text{Female}\}$
      – $z \in \{\text{Monday}, \ldots, \text{Sunday}\}$
    - If $K$ is large, there might not be enough training data to fit each $\theta_k$
    - Extreme case: no training data for some values of $z$
    - We'll add regularization so that nearby $\theta_k$'s are close

  • Stratified model with Laplacian regularization

    - Choose $\theta = (\theta_1, \ldots, \theta_K)$ to minimize
      $$
      \sum_{i=1}^{N} l(\theta_{z_i}, x_i, y_i) + \sum_{k=1}^{K} r(\theta_k) + \sum_{u,v=1}^{K} W_{uv} \|\theta_u - \theta_v\|^2
      $$
    - $l$ is the loss function, $r$ is the (local) regularization
    - The last term is the Laplacian regularization
    - $W_{uv} \geq 0$ are the edge weights of a graph with node set $\mathcal{Z}$
    - Convex problem when $l$ and $r$ are convex in $\theta$
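
To make the three terms of the objective concrete, here is a minimal sketch that evaluates it with NumPy. It assumes a square loss and a sum-of-squares local regularizer (the slide leaves $l$ and $r$ general), indexes strata from 0 rather than 1, and uses hypothetical function names. It also checks the identity $\sum_{u,v} W_{uv}\|\theta_u - \theta_v\|^2 = 2\,\mathrm{tr}(\theta^T (D - W)\theta)$ for symmetric $W$ with degree matrix $D$, which is what ties the last term to a graph Laplacian.

```python
import numpy as np


def stratified_objective(theta, X, y, z, W, lam):
    """Loss + local regularization + Laplacian regularization.

    Assumptions for this sketch: square loss, r(theta_k) = (lam/2)||theta_k||^2.
    theta: (K, n) array whose row k is theta_k (strata indexed 0..K-1 here)
    X: (N, n) features, y: (N,) labels, z: (N,) integer stratum per record
    W: (K, K) symmetric nonnegative edge weights
    """
    residual = np.einsum("ij,ij->i", X, theta[z]) - y  # x_i^T theta_{z_i} - y_i
    loss = np.sum(residual ** 2)
    local = 0.5 * lam * np.sum(theta ** 2)
    diff = theta[:, None, :] - theta[None, :, :]       # diff[u, v] = theta_u - theta_v
    laplacian = np.sum(W * np.sum(diff ** 2, axis=2))  # sum over all ordered pairs
    return loss + local + laplacian


def laplacian_term_via_matrix(theta, W):
    """Same Laplacian term written as 2 tr(theta^T (D - W) theta)."""
    lap = np.diag(W.sum(axis=1)) - W
    return 2.0 * np.trace(theta.T @ lap @ theta)


rng = np.random.default_rng(0)
K, n, N = 4, 3, 30
W = np.zeros((K, K))
for u in range(K - 1):  # path graph 0 - 1 - 2 - 3 with unit weights
    W[u, u + 1] = W[u + 1, u] = 1.0
theta = rng.standard_normal((K, n))
X = rng.standard_normal((N, n))
z = rng.integers(0, K, size=N)
y = rng.standard_normal(N)

value = stratified_objective(theta, X, y, z, W, lam=0.1)
print(f"objective value: {value:.4f}")
```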

  • Stratified model with Laplacian regularization

    - The graph encodes the prior that nearby values of $z$ should have similar models
    - Can be used to capture periodicities and other structure
    - Examples:
      – $\theta_{\text{male}}$ and $\theta_{\text{female}}$ should be close
      – $\theta_{\text{jan}}$ should be close to $\theta_{\text{feb}}$ and $\theta_{\text{dec}}$
    - The model for each value of $z$ 'borrows strength' from its neighbors
    - Works even when there's no data for some values of $z$
    - As $W_{uv} \to 0$, we get the traditional (unregularized) stratified model
    - As $W_{uv} \to \infty$ on a connected graph, we get the common model ($\theta_1 = \cdots = \theta_K$)

  • Related work

    - Network lasso [Hallac 15]
    - Pliable lasso [Tibshirani 17]
    - Varying-coefficient models [Hastie 93, Fan 08]
    - Multi-task learning [Caruana 97]

  • Section: Data models

  • Point estimate: Predict y

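    - Regression: $\mathcal{X} = \mathbf{R}^n$, $\mathcal{Y} = \mathbf{R}$
      – $l(\theta, x, y) = p(x^T\theta - y)$, where $p$ is a penalty function
      – $\hat{y} = x^T\theta_z$
    - Classification: $\mathcal{X} = \mathbf{R}^n$, $\mathcal{Y} = \{-1, 1\}$
      – $l(\theta, x, y) = p(y x^T\theta)$
      – $\hat{y} = \mathrm{sign}(x^T\theta_z)$
    - M-class classification: $\mathcal{X} = \mathbf{R}^n$, $\mathcal{Y} = \{1, \ldots, M\}$
      – $l(\theta, x, y) = p_y(x^T\theta)$, $\theta \in \mathbf{R}^{n \times M}$, where $p_y$ is the penalty function for class $y$
      – $\hat{y} = \mathrm{argmax}_j\, (x^T\theta_z)_j$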
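
The point-estimate losses above are all a penalty applied to a linear function of $\theta$. The sketch below spells that out for a few common penalty choices; the specific penalties (square, hinge, logistic) and all function names are assumptions for illustration, since the slide leaves $p$ unspecified. Class labels for the multiclass predictor are 0-based here.

```python
import numpy as np


def square_penalty(u):
    """p(u) = u^2, a common choice for regression."""
    return u ** 2


def hinge_penalty(u):
    """p(u) = max(0, 1 - u), a common choice for classification."""
    return np.maximum(0.0, 1.0 - u)


def logistic_penalty(u):
    """p(u) = log(1 + exp(-u)), computed stably."""
    return np.logaddexp(0.0, -u)


def regression_loss(theta, x, y, p=square_penalty):
    """l(theta, x, y) = p(x^T theta - y)."""
    return p(x @ theta - y)


def classification_loss(theta, x, y, p=hinge_penalty):
    """l(theta, x, y) = p(y x^T theta) with y in {-1, 1}."""
    return p(y * (x @ theta))


def predict_regression(theta_all, x, z):
    """theta_all has shape (K, n); returns x^T theta_z."""
    return x @ theta_all[z]


def predict_multiclass(theta_all, x, z):
    """theta_all has shape (K, n, M); returns argmax_j (x^T theta_z)_j."""
    return int(np.argmax(x @ theta_all[z]))


theta = np.array([1.0, -1.0])
x = np.array([2.0, 1.0])
print(regression_loss(theta, x, 0.0))     # penalty on the residual x^T theta - y
print(classification_loss(theta, x, -1))  # penalty on the margin y x^T theta
```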

  • Conditional distribution estimate: Predict prob(y | x , z)

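    - Multinomial logistic regression: $\mathcal{X} = \mathbf{R}^n$, $\mathcal{Y} = \{1, \ldots, M\}$
    - Conditional distribution:
      $$
      \mathbf{prob}(y \mid x, z) = \frac{\exp\big((x^T\theta_z)_y\big)}{\sum_{j=1}^{M} \exp\big((x^T\theta_z)_j\big)}, \quad y = 1, \ldots, M
      $$
    - Loss function (convex in $\theta$):
      $$
      l(\theta, x, y) = \log\Big(\sum_{j=1}^{M} \exp\big((x^T\theta)_j\big)\Big) - (x^T\theta)_y
      $$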
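
The multinomial logistic loss is the negative log of the conditional probability assigned to the observed class. A minimal sketch of both quantities, using SciPy's numerically stable `logsumexp` and `softmax`; the function names are hypothetical and class labels are 0-based here rather than 1-based as on the slide.

```python
import numpy as np
from scipy.special import logsumexp, softmax


def class_probabilities(theta, x):
    """theta: (n, M) parameter for one stratum; returns prob(y | x, z) for every y."""
    return softmax(x @ theta)


def multinomial_loss(theta, x, y):
    """log(sum_j exp((x^T theta)_j)) - (x^T theta)_y, with y in 0..M-1."""
    scores = x @ theta
    return logsumexp(scores) - scores[y]


rng = np.random.default_rng(0)
theta = rng.standard_normal((4, 3))  # n = 4 features, M = 3 classes
x = rng.standard_normal(4)

probs = class_probabilities(theta, x)
losses = np.array([multinomial_loss(theta, x, y) for y in range(3)])
print("probabilities:", probs)
print("losses:       ", losses)
```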

  • Distribution estimate: Predict p(y | z)

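    - Gaussian distribution: $\mathcal{Y} = \mathbf{R}^m$
    - Density:
      $$
      p(y \mid z) = (2\pi)^{-m/2} \det(\Sigma)^{-1/2} \exp\big(-\tfrac{1}{2}(y - \mu)^T \Sigma^{-1} (y - \mu)\big)
      $$
    - Use the natural parameter $\theta = (S, \nu) = (\Sigma^{-1}, \Sigma^{-1}\mu)$ (so $\Sigma = S^{-1}$, $\mu = S^{-1}\nu$)
    - Loss function (convex in $\theta$):
      $$
      l(\theta, y) = -\log\det S + y^T S y - 2 y^T \nu + \nu^T S^{-1} \nu
      $$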
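
Expanding $-2\log p(y \mid z)$ in the natural parameters shows that the loss above equals twice the negative log-likelihood minus the constant $m\log(2\pi)$. The sketch below evaluates the loss and checks that relationship against `scipy.stats.multivariate_normal`; the function names and the example numbers are assumptions for illustration.

```python
import numpy as np
from scipy.stats import multivariate_normal


def to_natural(mu, Sigma):
    """Map (mu, Sigma) to the natural parameter (S, nu) = (Sigma^{-1}, Sigma^{-1} mu)."""
    S = np.linalg.inv(Sigma)
    return S, S @ mu


def gaussian_loss(S, nu, y):
    """-log det S + y^T S y - 2 y^T nu + nu^T S^{-1} nu."""
    sign, logdet = np.linalg.slogdet(S)
    if sign <= 0:
        raise ValueError("S must be positive definite")
    return -logdet + y @ S @ y - 2.0 * (y @ nu) + nu @ np.linalg.solve(S, nu)


mu = np.array([1.0, -1.0])
Sigma = np.array([[2.0, 0.5], [0.5, 1.0]])
y = np.array([0.3, 0.2])

S, nu = to_natural(mu, Sigma)
loss = gaussian_loss(S, nu, y)
neg_log_lik = -multivariate_normal(mean=mu, cov=Sigma).logpdf(y)
print(f"loss = {loss:.6f}")
print(f"2 * NLL - m log(2 pi) = {2 * neg_log_lik - len(y) * np.log(2 * np.pi):.6f}")
```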

  • Section: Regularization graphs

  • Path graph

    [Figure: path graph on nodes 0, 1, 2, 3, 4, 5]

    - Models time, distance, …
    - Yields time-varying, distance-varying models

  • Cycle graph

    [Figure: cycle graph on the days of the week M, Tu, W, Th, F, Sa, Su]

    - Yields diurnal, weekly, seasonally-varying models

  • Tree graph

    [Figure: tree graph (organization chart) with nodes CEO, CTO, COO, SVP1, SVP2, SVP3, VP1, VP2, VP3]

    - Yields hierarchical models

  • Grid graph

    [Figure: 4 × 4 grid graph with nodes (0,0) through (3,3)]

    - Yields (2D) space-varying models

  • Products of graphs

    [Figure: product graph with nodes (M,1), (F,1), (M,2), (F,2), …, (M,99), (F,99), (M,100), (F,100)]

    - $\mathcal{Z} = \{\text{Male}, \text{Female}\} \times \{1, 2, \ldots, 99, 100\}$ (sex × age)
    - Yields sex, age-varying models
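
The path, cycle, grid, and product graphs from this section can all be built with standard networkx generators. The sketch below is one way to do so; the choice of unit edge weights and the attribute name `"weight"` are assumptions for illustration (in practice each factor of a product graph would typically get its own weight).

```python
import networkx as nx

# Path graph on nodes 0..5 (time, distance, ...).
path = nx.path_graph(6)

# Cycle graph over the days of the week.
week = nx.cycle_graph(["M", "Tu", "W", "Th", "F", "Sa", "Su"])

# 4 x 4 grid graph; nodes are (row, column) tuples.
grid = nx.grid_2d_graph(4, 4)

# Product graph: sex (a single edge) times age (a path on 1..100).
sex = nx.Graph([("M", "F")])
age = nx.path_graph(range(1, 101))
sex_age = nx.cartesian_product(sex, age)  # nodes are (sex, age) tuples

# Unit edge weights everywhere (an assumption for this sketch).
for graph in (path, week, grid, sex_age):
    nx.set_edge_attributes(graph, 1.0, "weight")

# Weighted graph Laplacian of the weekly cycle as a dense array.
lap = nx.laplacian_matrix(week).toarray()

print(sex_age.number_of_nodes(), "nodes and", sex_age.number_of_edges(), "edges in sex x age")
```

In the product graph, node (M, 30) is adjacent to (F, 30), (M, 29), and (M, 31), so each model borrows strength across both sex and neighboring ages.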

  • Section: Solution method

  • Fitting problem

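    - To fit the stratified model, minimize $\ell(\theta) + r(\theta) + \mathcal{L}(\theta)$
    - $\ell(\theta) = \sum_{k=1}^{K} \ell_k(\theta_k)$ is the loss, with $\ell_k(\theta_k) = \sum_{i : z_i = k} l(\theta_k, x_i, y_i)$
    - $r(\theta) = \sum_{k=1}^{K} r(\theta_k)$ is the (local) regularization
    - $\mathcal{L}(\theta) = \sum_{u,v=1}^{K} W_{uv} \|\theta_u - \theta_v\|^2$ is the Laplacian regularization
    - $\ell$ and $r$ are separable in $\theta_k$
    - $\mathcal{L}$ is quadratic and separable across the components of the $\theta_k$
    - We'll use an operator splitting method (ADMM)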

  • Reformulation

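    - Replicate variables:
      $$
      \begin{array}{ll}
      \text{minimize} & \ell(\theta) + r(\tilde\theta) + \mathcal{L}(\hat\theta) \\
      \text{subject to} & \theta = \tilde\theta = \hat\theta
      \end{array}
      $$
    - Augmented Lagrangian:
      $$
      L(\theta, \tilde\theta, \hat\theta, u, \tilde u) = \ell(\theta) + r(\tilde\theta) + \mathcal{L}(\hat\theta) + \frac{1}{2t}\|\theta - \hat\theta + u\|_2^2 + \frac{1}{2t}\|\tilde\theta - \hat\theta + \tilde u\|_2^2
      $$
    - $u$, $\tilde u$ are the (scaled) dual variables for the two constraints; $t > 0$ is an algorithm parameter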

  • ADMM

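    - Augmented Lagrangian:
      $$
      L(\theta, \tilde\theta, \hat\theta, u, \tilde u) = \ell(\theta) + r(\tilde\theta) + \mathcal{L}(\hat\theta) + \frac{1}{2t}\|\theta - \hat\theta + u\|_2^2 + \frac{1}{2t}\|\tilde\theta - \hat\theta + \tilde u\|_2^2
      $$
    - ADMM: for $i = 1, 2, \ldots$
      $$
      \begin{aligned}
      (\theta^{i+1}, \tilde\theta^{i+1}) &:= \operatorname*{argmin}_{\theta, \tilde\theta} L(\theta, \tilde\theta, \hat\theta^{i}, u^{i}, \tilde u^{i}) \\
      \hat\theta^{i+1} &:= \operatorname*{argmin}_{\hat\theta} L(\theta^{i+1}, \tilde\theta^{i+1}, \hat\theta, u^{i}, \tilde u^{i}) \\
      u^{i+1} &:= u^{i} + \theta^{i+1} - \hat\theta^{i+1} \\
      \tilde u^{i+1} &:= \tilde u^{i} + \tilde\theta^{i+1} - \hat\theta^{i+1}
      \end{aligned}
      $$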
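
As a concrete instance of these iterations, the following sketch runs ADMM for a stratified ridge regression, where every subproblem has a closed form. It assumes the square loss $\ell_k(\theta_k) = \|X_k\theta_k - y_k\|_2^2$, the local regularizer $r(\theta_k) = (\lambda/2)\|\theta_k\|_2^2$, and a symmetric weight matrix $W$; the three updates follow from setting the gradient of the augmented Lagrangian to zero in each block, using $\sum_{u,v} W_{uv}\|\theta_u - \theta_v\|^2 = 2\,\mathrm{tr}(\theta^T(D - W)\theta)$. The function name, dense linear algebra, and fixed iteration count are choices made for this illustration, not the authors' implementation.

```python
import numpy as np


def fit_stratified_ridge(X, y, z, W, lam=0.1, t=1.0, iters=2000):
    """ADMM sketch for stratified ridge regression (hypothetical helper).

    X: (N, n), y: (N,), z: (N,) integer strata in 0..K-1, W: (K, K) symmetric.
    Returns a (K, n) array whose row k is the fitted theta_k.
    """
    K, n = W.shape[0], X.shape[1]
    lap = np.diag(W.sum(axis=1)) - W              # graph Laplacian D - W
    hat_matrix = np.eye(K) + 2.0 * t * lap        # system for the theta_hat update

    # Per-stratum systems (I + 2t X_k^T X_k) theta_k = v + 2t X_k^T y_k.
    systems = []
    for k in range(K):
        Xk, yk = X[z == k], y[z == k]             # may be empty: then theta_k = v
        systems.append((np.eye(n) + 2.0 * t * Xk.T @ Xk, 2.0 * t * Xk.T @ yk))

    theta = np.zeros((K, n))
    theta_hat = np.zeros((K, n))
    u = np.zeros((K, n))
    u_tilde = np.zeros((K, n))
    for _ in range(iters):
        # theta and theta_tilde updates (independent across k).
        for k, (A_k, b_k) in enumerate(systems):
            theta[k] = np.linalg.solve(A_k, theta_hat[k] - u[k] + b_k)
        theta_tilde = (theta_hat - u_tilde) / (1.0 + t * lam)
        # theta_hat update: one K x K solve shared by all n components.
        average = 0.5 * (theta + u + theta_tilde + u_tilde)
        theta_hat = np.linalg.solve(hat_matrix, average)
        # Dual updates.
        u = u + theta - theta_hat
        u_tilde = u_tilde + theta_tilde - theta_hat
    return theta_hat


rng = np.random.default_rng(0)
K, n, N = 4, 2, 60
W = np.zeros((K, K))
for a in range(K - 1):                            # path graph with unit weights
    W[a, a + 1] = W[a + 1, a] = 1.0
z = rng.integers(0, K - 1, size=N)                # stratum K-1 gets no data
X = rng.standard_normal((N, n))
true_theta = rng.standard_normal((K, n))
y = np.einsum("ij,ij->i", X, true_theta[z]) + 0.1 * rng.standard_normal(N)

theta_fit = fit_stratified_ridge(X, y, z, W)
print(theta_fit)
```

The last stratum has no training records in this example, yet it still receives a parameter, pulled toward its neighbor by the Laplacian term.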

  • ADMM

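    - The first step can be expressed as
      $$
      \theta_k^{i+1} = \mathbf{prox}_{t\ell_k}(\hat\theta_k^{i} - u_k^{i}), \qquad \tilde\theta_k^{i+1} = \mathbf{prox}_{tr}(\hat\theta_k^{i} - \tilde u_k^{i}), \qquad k = 1, \ldots, K
      $$
    - The proximal operator of $tg$ is
      $$
      \mathbf{prox}_{tg}(\nu) = \operatorname*{argmin}_{\theta} \big( t g(\theta) + (1/2)\|\theta - \nu\|_2^2 \big)
      $$
    - These $2K$ proximal operators can be evaluated in parallel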

  • Loss proximal operators

    - Often has a closed-form expression or an efficient implementation
    - Square loss: $\ell_k(\theta_k) = \|X_k\theta_k - y_k\|_2^2$, where $X_k$, $y_k$ stack the records with $z_i = k$
      – $\mathbf{prox}_{t\ell_k}(\nu) = (I + 2tX_k^TX_k)^{-1}(\nu + 2tX_k^Ty_k)$
      – Cache the factorization or warm-start CG
    - Logistic loss: $\ell_k(\theta_k) = \sum_i \log(1 + \exp(-y_{ki}\theta_k^Tx_{ki}))$
      – Use L-BFGS to evaluate $\mathbf{prox}_{t\ell_k}(\nu)$
      – Warm-start
    - Many others …
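
A minimal sketch of both loss proximal operators follows. For the square loss, setting the gradient of $t\|X\theta - y\|_2^2 + (1/2)\|\theta - \nu\|_2^2$ to zero gives the linear system $(I + 2tX^TX)\theta = \nu + 2tX^Ty$, whose Cholesky factorization can be cached across ADMM iterations. For the logistic loss, the sketch uses SciPy's L-BFGS-B solver in place of the torch L-BFGS mentioned in the talk, and accepts a warm start. Class and function names are hypothetical.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve
from scipy.optimize import minimize
from scipy.special import expit


class SquareLossProx:
    """prox of t * ||X theta - y||_2^2 with a cached Cholesky factorization."""

    def __init__(self, X, y, t):
        self.t = t
        self.Xty = X.T @ y
        self.factor = cho_factor(np.eye(X.shape[1]) + 2.0 * t * X.T @ X)

    def __call__(self, nu):
        # Solves (I + 2t X^T X) theta = nu + 2t X^T y using the cached factor.
        return cho_solve(self.factor, nu + 2.0 * self.t * self.Xty)


def logistic_prox(X, y, t, nu, theta0=None):
    """prox of t * sum_i log(1 + exp(-y_i x_i^T theta)), labels y_i in {-1, 1}."""

    def objective(theta):
        margins = y * (X @ theta)
        value = t * np.sum(np.logaddexp(0.0, -margins)) + 0.5 * np.sum((theta - nu) ** 2)
        grad = -t * (X.T @ (y * expit(-margins))) + (theta - nu)
        return value, grad

    start = nu if theta0 is None else theta0   # warm start when available
    result = minimize(objective, start, jac=True, method="L-BFGS-B",
                      options={"ftol": 1e-12, "gtol": 1e-8})
    return result.x


rng = np.random.default_rng(0)
X = rng.standard_normal((20, 3))
y_real = rng.standard_normal(20)
y_sign = np.where(rng.standard_normal(20) > 0, 1.0, -1.0)
nu = rng.standard_normal(3)
t = 0.5

theta_sq = SquareLossProx(X, y_real, t)(nu)
theta_log = logistic_prox(X, y_sign, t, nu)
theta_log_warm = logistic_prox(X, y_sign, t, nu, theta0=theta_log)
print(theta_sq, theta_log)
```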

  • Regularizer proximal operators

    - Often has a closed-form expression or an efficient implementation

      Regularizer            $r(\theta_k)$                      $\mathbf{prox}_{tr}(\nu)$
      Sum of squares (ℓ2)    $(\lambda/2)\|\theta_k\|_2^2$      $\nu/(t\lambda + 1)$
      ℓ1 norm                $\lambda\|\theta_k\|_1$            $(\nu - t\lambda)_+ - (-\nu - t\lambda)_+$
      Nonnegative            $I_+(\theta_k)$                    $(\nu)_+$

    - $\lambda > 0$ is the local regularization parameter
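
Each of the three regularizer proximal operators is a one-line elementwise NumPy expression. In this sketch the function names are hypothetical, and the nonnegative case maps the input $\nu$ to its positive part, which is the Euclidean projection onto the nonnegative orthant.

```python
import numpy as np


def prox_sum_squares(nu, t, lam):
    """prox of t * (lam/2) * ||theta||_2^2: uniform shrinkage toward zero."""
    return nu / (t * lam + 1.0)


def prox_l1(nu, t, lam):
    """prox of t * lam * ||theta||_1: soft thresholding at level t * lam."""
    return np.maximum(nu - t * lam, 0.0) - np.maximum(-nu - t * lam, 0.0)


def prox_nonnegative(nu):
    """prox of the indicator of theta >= 0: projection onto the nonnegative orthant."""
    return np.maximum(nu, 0.0)


nu = np.array([3.0, -0.5, -2.0])
print(prox_sum_squares(nu, t=1.0, lam=1.0))  # every entry is halved
print(prox_l1(nu, t=1.0, lam=1.0))           # entries within [-1, 1] become zero
print(prox_nonnegative(nu))                  # negative entries become zero
```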

  • Laplacian proximal operator

    - Separable across the components of the $\theta_k$
    - For each component $i = 1, \ldots, n$, the values $(\hat\theta_k)_i$, $k = 1, \ldots, K$, are found by solving a $K \times K$ Laplacian system
    - Many efficient ways to solve, e.g., diagonally preconditioned CG
    - These $n$ systems can be solved in parallel
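
The sketch below solves the Laplacian proximal step with diagonally preconditioned conjugate gradient from SciPy. It assumes the Laplacian regularization $\mathcal{L}(\theta) = \sum_{u,v} W_{uv}\|\theta_u - \theta_v\|^2 = 2\,\mathrm{tr}(\theta^T(D - W)\theta)$ with symmetric $W$, so that minimizing $s\,\mathcal{L}(\theta) + (1/2)\|\theta - V\|_F^2$ reduces to the system $(I + 4s(D - W))\theta = V$, solved one column (component) at a time. In the ADMM iteration this corresponds to $s = t/2$ and $V$ equal to the average of $\theta + u$ and $\tilde\theta + \tilde u$. The function name and the sequential loop are illustrative choices; the $n$ solves are independent and could run in parallel.

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import cg


def laplacian_prox(W, V, s):
    """Minimize s * sum_{u,v} W_uv ||theta_u - theta_v||^2 + (1/2)||theta - V||_F^2.

    W: (K, K) symmetric nonnegative weights (dense or sparse), V: (K, n) array.
    """
    W = sp.csr_matrix(W)
    K, n = V.shape
    degrees = np.asarray(W.sum(axis=1)).ravel()
    lap = sp.diags(degrees) - W                   # graph Laplacian D - W
    A = (sp.identity(K) + 4.0 * s * lap).tocsr()
    M = sp.diags(1.0 / A.diagonal())              # diagonal (Jacobi) preconditioner

    theta = np.empty((K, n))
    for i in range(n):                            # one K x K system per component
        theta[:, i], info = cg(A, V[:, i], x0=V[:, i], M=M)
        if info != 0:
            raise RuntimeError(f"CG did not converge for component {i}")
    return theta


rng = np.random.default_rng(0)
K, n = 6, 3
W = np.zeros((K, K))
for a in range(K - 1):                            # path graph with unit weights
    W[a, a + 1] = W[a + 1, a] = 1.0
V = rng.standard_normal((K, n))

theta_hat = laplacian_prox(W, V, s=0.5)
print(theta_hat)
```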

  • Software implementation

    - Available at www.github.com/cvxgrp/strat_models
    - numpy, scipy for matrix operations
    - networkx for handling graphs and graph operations
    - torch for L-BFGS and GPU computation
    - multiprocessing for parallelism
    - model.fit(X,Y,Z,G) (writes $\theta_k$ on graph nodes)

  • Section: Examples

  • House price prediction

    - $N \approx 22000$ records $(z, x, y)$ from King County, WA
    - Split into train/test ≈ 16200/5400
    - $z$ = (latitude bin, longitude bin)
    - $x \in \mathbf{R}^{10}$ = features of the house, $y$ = log of the house price
    - Graph is a 50 × 50 grid with all edge weights the same; $K = 2500$
    - Stratified ridge regression model with two hyperparameters (one for local regularization, one for Laplacian regularization)
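
One way to set up the stratification described here is to bin latitude and longitude into 50 equal-width intervals each and use a networkx grid graph over the bins. The sketch below does this on random stand-in coordinates; the coordinate ranges, the equal-width binning rule, the helper name, and the edge attribute name `"weight"` are all assumptions, since the talk does not specify how the bins were formed.

```python
import numpy as np
import networkx as nx


def bin_index(values, bins):
    """Assign each value to one of `bins` equal-width bins, indexed 0..bins-1."""
    edges = np.linspace(values.min(), values.max(), bins + 1)
    # Using only the interior edges keeps the minimum in bin 0 and the maximum in bin bins-1.
    return np.digitize(values, edges[1:-1])


rng = np.random.default_rng(0)
lat = rng.uniform(47.2, 47.8, size=1000)      # hypothetical stand-in coordinates
lon = rng.uniform(-122.5, -121.3, size=1000)

bins = 50
lat_bin = bin_index(lat, bins)
lon_bin = bin_index(lon, bins)
z = list(zip(lat_bin.tolist(), lon_bin.tolist()))  # z = (latitude bin, longitude bin)

# 50 x 50 grid graph with all edge weights the same.
G = nx.grid_2d_graph(bins, bins)
nx.set_edge_attributes(G, 1.0, "weight")

print(G.number_of_nodes(), "strata,", G.number_of_edges(), "edges")
```

With 10 parameters per grid cell, the 2500 nodes account for the 25000 parameters of the stratified model.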

  • House price prediction: Results

    - Compare stratified, common, and random forest (50 trees) models

      Model            Parameters    RMS test error
      Stratified       25000         0.181
      Common           10            0.316
      Random forest    985888        0.184

  • House price prediction: Parameters

    [Figure: one panel per model parameter: bedrooms, bathrooms, sqft living, sqft lot, floors, waterfront, condition, grade, yr built, intercept]

  • Chicago crime prediction

    - $N \approx 7\,000\,000$ $(z, y)$ pairs for 2017 and 2018
    - Train on 2017, test on 2018
    - $y$ = number of crimes
    - $z$ = (location bin, week of year, day of week, hour of day); $K \approx 3\,500\,000$
    - Graph is the Cartesian product of a grid and three cycles; four graph edge weights
    - Stratified Poisson model with four hyperparameters (one for each graph edge weight)

  • Chicago crime prediction

    - Compare three models by average negative log likelihood on training and test data

      Model         Train    Test
      Separate      0.068    0.740
      Stratified    0.221    0.234
      Common        0.279    0.278

  • Chicago crime prediction

    Crime rate as a function of latitude/longitude

    [Figure: map with longitude from −87.85 to −87.52 and latitude from 41.64 to 42.02; color scale from 0.02 to 0.12]

  • Chicago crime prediction

    Crime rate as a function of week of year

    [Figure: horizontal axis Jan through Dec; vertical axis from 0.036 to 0.044]

  • Chicago crime prediction

    Crime rate as a function of hour of week

    [Figure: horizontal axis M, Tu, W, Th, F, Sa, Su; vertical axis from 0.020 to 0.055]


  • Conclusions

    - Stratified models combine
      – simple dependence on some features ($x$)
      – complex dependence on others ($z$)
    - Often interpretable
    - Laplacian regularization encodes a prior on values of $z$, so models can borrow strength from their neighbors
    - Effective method to build time-varying, space-varying, seasonally-varying models
    - Efficient, distributed ADMM-based implementation for large-scale data