PAVEL IZMAILOV, WESLEY MADDOX, POLINA KIRICHENKO, TIMUR GARIPOV, DMITRY VETROV, ANDREW GORDON WILSON
SUBSPACE INFERENCE FOR BAYESIAN DEEP LEARNING
WHY BAYESIAN INFERENCE?
‣ Combining models for better predictions 📊
‣ Uncertainty representation (crucial for decision making) 🤷
‣ Interpretably incorporate prior knowledge and domain expertise
WHY NOT?
‣ Challenging for deep NNs due to high-dimensional weight spaces 😩
SUBSPACE INFERENCE
A modular approach:
‣ Design subspace
‣ Approximate posterior over parameters in the subspace
‣ Sample from approximate posterior for Bayesian model averaging
We can approximate the posterior of a 36-million-dimensional WideResNet in a 5D subspace and get state-of-the-art results!
SUBSPACE
‣ Choose shift $\hat{w}$ and basis vectors $\{d_1, \ldots, d_K\}$
‣ Define subspace $S = \{w \mid w = \hat{w} + \underbrace{t_1 d_1 + \ldots + t_K d_K}_{Pt}\}$
‣ Likelihood $p(D \mid t) = p_M(D \mid w = \hat{w} + Pt)$
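The subspace definition can be sketched in a few lines of NumPy. This is a minimal illustration, not the speakers' code: the helper names `to_weights` and `subspace_log_likelihood`, the choice to store the basis vectors as columns of `P`, and the toy quadratic log-likelihood are all assumptions made for the example.

```python
import numpy as np

def to_weights(w_hat, P, t):
    """Map subspace coordinates t (shape (K,)) to weights w = w_hat + P t (shape (p,))."""
    return w_hat + P @ t

def subspace_log_likelihood(log_lik_full, w_hat, P, t):
    """log p(D | t) = log p_M(D | w = w_hat + P t) for any full-model log-likelihood."""
    return log_lik_full(to_weights(w_hat, P, t))

# Toy example: p = 4 weights, K = 2 basis vectors stored as the columns of P.
w_hat = np.array([1.0, 0.0, -1.0, 2.0])            # shift
d1 = np.array([1.0, 0.0, 0.0, 0.0])
d2 = np.array([0.0, 1.0, 1.0, 0.0])
P = np.stack([d1, d2], axis=1)                     # shape (p, K)

t = np.array([0.5, -2.0])
w = to_weights(w_hat, P, t)                        # w_hat + 0.5 * d1 - 2.0 * d2

# Hypothetical stand-in for log p_M(D | w); a real model would evaluate
# its data log-likelihood at the weights w here.
toy_log_lik = lambda w: -0.5 * np.sum(w ** 2)
value = subspace_log_likelihood(toy_log_lik, w_hat, P, t)
```

Only `t` (K numbers) is treated as a random variable; the full weight vector is always recomputed from it, so any inference method that works on a K-dimensional parameter can be plugged in.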
INFERENCE
‣ Approximate inference over parameters $t$
‣ MCMC, Variational Inference, Normalizing Flows, …
‣ Bayesian model averaging at test time:
$p(D^* \mid D) = \frac{1}{J} \sum_{i=1}^{J} p_M(D^* \mid w = \hat{w} + P t_i), \quad t_i \sim q(t \mid D)$
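One way to make the two steps concrete is to pair an MCMC sampler over the subspace coordinates with the model average. The sketch below uses elliptical slice sampling (ESS, the sampler named in the figure labels) and assumes a Gaussian prior $N(0, \sigma^2 I)$ on the coordinates, which the slides do not state. The function names, the toy logistic-regression model, and the chain length are illustrative assumptions rather than the authors' setup.

```python
import numpy as np

def elliptical_slice_step(t, log_lik, rng, prior_std=1.0):
    """One elliptical slice sampling update, assuming a N(0, prior_std^2 I) prior on t."""
    nu = prior_std * rng.standard_normal(t.shape)      # auxiliary draw from the prior
    log_y = log_lik(t) + np.log(rng.uniform())         # log of the slice height
    theta = rng.uniform(0.0, 2.0 * np.pi)
    lo, hi = theta - 2.0 * np.pi, theta
    while True:
        proposal = t * np.cos(theta) + nu * np.sin(theta)
        if log_lik(proposal) > log_y:
            return proposal
        # Rejected: shrink the angle bracket towards 0 (the current state) and retry.
        if theta < 0.0:
            lo = theta
        else:
            hi = theta
        theta = rng.uniform(lo, hi)

def bma_predict(predict_fn, w_hat, P, t_samples):
    """(1/J) * sum_i p_M(D* | w = w_hat + P t_i) over J samples t_i."""
    return np.mean([predict_fn(w_hat + P @ t) for t in t_samples], axis=0)

# Toy logistic-regression "network" with p = 3 weights and a K = 2 subspace.
rng = np.random.default_rng(0)
X = rng.standard_normal((40, 3))
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(float)
w_hat = np.array([1.0, 0.5, 0.0])                      # placeholder shift
P = rng.standard_normal((3, 2))                        # placeholder basis

def predict_proba(w, inputs):
    return 1.0 / (1.0 + np.exp(-(inputs @ w)))

def log_lik(t):
    p1 = np.clip(predict_proba(w_hat + P @ t, X), 1e-12, 1.0 - 1e-12)
    return np.sum(y * np.log(p1) + (1.0 - y) * np.log(1.0 - p1))

t, samples = np.zeros(2), []
for step in range(300):
    t = elliptical_slice_step(t, log_lik, rng)
    if step >= 100:                                    # discard burn-in
        samples.append(t)

X_test = rng.standard_normal((5, 3))
bma_probs = bma_predict(lambda w: predict_proba(w, X_test), w_hat, P, samples)
```

Swapping in variational inference or a normalizing flow would only change how `samples` is produced; `bma_predict` stays the same.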
TEMPERING POSTERIOR
‣ In the subspace model, # parameters << # data points
‣ ~5-10 parameters, ~50K data points
‣ Posterior over $t$ is extremely concentrated
‣ To address this issue, we utilize the tempered posterior:
$p_T(t \mid D) \propto \underbrace{p(D \mid t)^{1/T}}_{\text{likelihood}} \, \underbrace{p(t)}_{\text{prior}}$
‣ $T$ can be learned by cross-validation
‣ Heuristic: $T = \frac{\#\,\text{data points}}{\#\,\text{parameters}}$
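In log space, tempering just divides the log-likelihood by $T$ before adding the log-prior. The sketch below works through the heuristic with the rough sizes quoted on the slide; the function names and the sharply peaked toy likelihood are hypothetical stand-ins chosen for illustration.

```python
import numpy as np

def heuristic_temperature(n_data, n_params):
    """Heuristic from the slide: T = (# data points) / (# parameters)."""
    return n_data / n_params

def tempered_log_posterior(log_lik, log_prior, t, T):
    """Unnormalized log p_T(t | D) = (1 / T) * log p(D | t) + log p(t)."""
    return log_lik(t) / T + log_prior(t)

# With ~50K data points, a 5D subspace gives T = 10,000 and a 10D subspace T = 5,000.
T_5d = heuristic_temperature(50_000, 5)
T_10d = heuristic_temperature(50_000, 10)

# Hypothetical stand-ins: a sharply peaked likelihood and a standard normal prior.
log_lik = lambda t: -50_000.0 * np.sum((t - 0.3) ** 2)
log_prior = lambda t: -0.5 * np.sum(t ** 2)

t = np.zeros(5)
untempered = tempered_log_posterior(log_lik, log_prior, t, T=1.0)
tempered = tempered_log_posterior(log_lik, log_prior, t, T=T_5d)
```

With `T=1.0` the likelihood term dominates the prior by several orders of magnitude; after tempering the two terms are on a comparable scale, which is the flattening effect the slide is after.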
SUBSPACE CHOICE
We want a subspace that
‣ Contains diverse models
‣ Is cheap to construct
RANDOM SUBSPACE
‣ Directions $d_1, \ldots, d_K \sim \mathcal{N}(0, I_p)$
‣ Use pre-trained solution as shift $\hat{w}$
‣ Subspace $S = \{w \mid w = \hat{w} + Pt\}$
[Figure: posterior log-density (ESS, Random Subspace) and predictive distribution (ESS, Random Subspace)]
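Constructing a random subspace amounts to drawing a Gaussian matrix. A minimal sketch follows; the function name `random_subspace`, the optional `normalize` flag, and the random vector standing in for a pre-trained solution are assumptions made for the example.

```python
import numpy as np

def random_subspace(p, K, rng, normalize=False):
    """Return P of shape (p, K) whose columns d_1..d_K are drawn from N(0, I_p)."""
    P = rng.standard_normal((p, K))
    if normalize:
        # Optional extra, not on the slide: rescale each direction to unit length.
        P = P / np.linalg.norm(P, axis=0, keepdims=True)
    return P

rng = np.random.default_rng(0)
p, K = 1000, 5
w_hat = rng.standard_normal(p)     # placeholder for a pre-trained solution
P = random_subspace(p, K, rng)

t = rng.standard_normal(K)         # a point in the subspace
w = w_hat + P @ t                  # corresponding full weight vector
```

Construction costs one random draw of size p x K and needs no extra training, which is what makes this the cheapest of the subspaces discussed.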
PCA OF THE SGD TRAJECTORY
‣ Run SGD with high constant learning rate from a pre-trained solution
‣ Collect snapshots of weights $w_i$
‣ Use SWA solution as shift $\hat{w} = \frac{1}{T} \sum_i w_i$
‣ $\{d_1, \ldots, d_K\}$: first $K$ PCA components of vectors $\hat{w} - w_i$
[Figure: posterior log-density (ESS, PCA Subspace) and predictive distribution (ESS, PCA Subspace)]
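Given the collected snapshots, the shift and basis can be computed with one SVD. The sketch below assumes the snapshots are already flattened into rows of an array; the function name `pca_subspace` and the toy trajectory are illustrative, and returning unit-length directions alongside the singular values (rather than pre-scaled directions) is a choice made here, not something the slide specifies.

```python
import numpy as np

def pca_subspace(snapshots, K):
    """snapshots: array of shape (T, p), one flattened weight vector w_i per row.

    Returns the shift w_hat (p,), the basis P (p, K) with orthonormal columns,
    and the K leading singular values of the deviation matrix.
    """
    W = np.asarray(snapshots, dtype=float)
    w_hat = W.mean(axis=0)                        # SWA solution: average of the snapshots
    A = w_hat - W                                 # rows are the deviations w_hat - w_i
    _, singular_values, Vt = np.linalg.svd(A, full_matrices=False)
    P = Vt[:K].T                                  # leading right singular vectors as columns
    return w_hat, P, singular_values[:K]

# Toy trajectory in p = 3 dimensions that moves mostly along the first coordinate.
snapshots = np.array([
    [3.0, 5.1, 5.0],
    [4.0, 4.9, 5.0],
    [6.0, 5.1, 5.0],
    [7.0, 4.9, 5.0],
])
w_hat, P, s = pca_subspace(snapshots, K=2)
```

The sign of each principal direction is arbitrary, so using $w_i - \hat{w}$ instead of $\hat{w} - w_i$ would span the same subspace.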
CURVE SUBSPACE
‣ Garipov et al. 2018 proposed a method to find 2D subspaces containing a path of low loss between weights of two independently trained neural networks
[Figure: posterior log-density (ESS, Curve Subspace) and predictive distribution (ESS, Curve Subspace)]
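The slide does not say how the 2D plane is parameterized. As one possible sketch, assume the low-loss path is a quadratic Bezier curve with endpoints `w0`, `w1` and a single bend point `theta` that has already been trained; the whole curve then lies in the plane through those three points. The function names, the midpoint shift, and the Gram-Schmidt basis are assumptions for illustration, and the random vectors stand in for real network weights.

```python
import numpy as np

def curve_subspace(w0, w1, theta):
    """Plane through w0, w1 and theta: shift at the endpoint midpoint, orthonormal basis."""
    w_hat = 0.5 * (w0 + w1)
    d1 = w1 - w0
    d1 = d1 / np.linalg.norm(d1)
    d2 = theta - w_hat
    d2 = d2 - (d2 @ d1) * d1                  # Gram-Schmidt: remove the d1 component
    d2 = d2 / np.linalg.norm(d2)
    return w_hat, np.stack([d1, d2], axis=1)  # P has shape (p, 2)

def bezier_point(w0, theta, w1, s):
    """Point on the quadratic Bezier curve at position s in [0, 1]."""
    return (1 - s) ** 2 * w0 + 2 * s * (1 - s) * theta + s ** 2 * w1

rng = np.random.default_rng(0)
w0, w1, theta = rng.standard_normal((3, 50))  # placeholders for trained weights
w_hat, P = curve_subspace(w0, w1, theta)

# Any point on the curve is reproduced exactly from its 2D coordinates.
point = bezier_point(w0, theta, w1, 0.3)
t = P.T @ (point - w_hat)
residual = np.linalg.norm(w_hat + P @ t - point)
```

Finding `theta` requires training two networks plus the curve itself, so this subspace is considerably more expensive to build than the random or PCA ones.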
SUBSPACE COMPARISON
[Figure: posterior log-density and predictive distribution under ESS for the PCA, Random, and Curve subspaces]
SUBSPACE COMPARISON ON PRERESNET-164, CIFAR-100
[Figure: posterior log-density in three panels. Curve Subspace: ESS. PCA Subspace: ESS, SWAG 3σ region, VI 3σ region. Random Subspace: ESS, VI 3σ region.]
                SGD              Random           PCA              Curve
NLL             0.946 ± 0.001    0.686 ± 0.005    0.665 ± 0.004    0.646
Accuracy (%)    78.50 ± 0.32     80.17 ± 0.03     80.54 ± 0.13     81.28
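For reference, the two metrics in the table can be computed from predicted class probabilities as sketched below. This assumes NLL is the mean negative log-probability of the true class over the test set, which the slide does not spell out; the function names and the three-example toy input are illustrative and unrelated to the reported numbers.

```python
import numpy as np

def nll(probs, labels):
    """Mean negative log-probability assigned to the true class."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))

def accuracy_percent(probs, labels):
    """Percentage of examples whose most probable class matches the label."""
    return 100.0 * np.mean(np.argmax(probs, axis=1) == np.asarray(labels))

# Toy predictive probabilities for 3 examples and 3 classes.
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
])
labels = np.array([0, 1, 0])

toy_nll = nll(probs, labels)
toy_acc = accuracy_percent(probs, labels)
```

NLL rewards well-calibrated probabilities as well as correct predictions, so two methods with similar accuracy can still differ noticeably on it.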
TAKEAWAYS
‣ We can apply standard approximate inference methods in subspaces of parameter space
‣ More diverse subspaces => better performance: Curve Subspace > PCA Subspace > Random Subspace
‣ Subspace Inference in the PCA subspace is competitive with SWAG (Maddox et al., 2019), MC-Dropout (Gal & Ghahramani, 2016) and Temperature Scaling (Guo et al., 2017) on image classification and UCI regression