Transcript of a talk by Jan-Christian (JC) Hütter (jchuetter.web@gmail)

Page 1

Page 2

• Common statistical distances are insensitive to the geometry of the underlying space

• Monge problem: for probability distributions $P, Q$ on $\mathbb{R}^d$, search for a transport map $T_0 : \mathbb{R}^d \to \mathbb{R}^d$ that solves

$\min \left\{ \int \|T(x) - x\|_2^2 \, \mathrm{d}P(x) \;:\; T_\# P = Q \right\}$,  (Monge)

where $T_\# P$ is the push-forward of $P$, i.e., the distribution of $Y = T(X)$ for $X \sim P$.

• Focus on the squared Euclidean cost $\|T(x) - x\|_2^2$, but other cost functions are also of interest

[Figure: point masses $P = \delta_a$ and $Q = \delta_b$ with transport map $T$. Here $D(P \,\|\, Q) = \int \log\frac{p(x)}{q(x)}\, p(x) \, \mathrm{d}x = \infty$ and $\mathrm{TV}(P, Q) = \frac{1}{2} \int |p(x) - q(x)| \, \mathrm{d}x = 1$, while the squared Wasserstein distance $W_2^2(P, Q) = \|a - b\|_2^2$ reflects the distance between $a$ and $b$.]

Page 3

• Distance between probability measures ($W_2$)
• Uncoupled function estimation ($T_0$)

Applications:
• Bag-of-words models (Rolet, Cuturi, Peyré, 2016)
• Multi-label prediction (Frogner et al., 2015)
• Color transfer (Rabin, Delon, Gousseau, 2010)
• Domain adaptation (Courty, Flamary, Tuia, 2017)
• Trajectory inference in scRNA-Seq (Schiebinger, Shu, Tabaka, et al., 2019)
• Wasserstein GAN (Arjovsky, Chintala, Bottou, 2017)

Page 4

• Reminder:

$\min_{T_\# P = Q} \int \|T(x) - x\|_2^2 \, \mathrm{d}P(x)$  (Monge)

• The constraint $T_\# P = Q$ is highly non-linear and asymmetric in $P$ and $Q$

• Relaxation: look for a transport plan $\gamma_0$ that solves

$\min_{\gamma \in \Gamma(P, Q)} \int \|y - x\|_2^2 \, \mathrm{d}\gamma(x, y)$,  (Kantorovich)

where $\Gamma(P, Q)$ is the set of probability measures on $\mathbb{R}^d \times \mathbb{R}^d$ with marginals $P$ and $Q$:

$\int \mathrm{d}\gamma(\cdot, y) = P, \qquad \int \mathrm{d}\gamma(x, \cdot) = Q$

[Figure: two example couplings, labeled "Independent" ($\gamma = P \otimes Q$) and "Deterministic" ($\gamma$ concentrated on the graph of a map).]

Page 5

Theorem (Brenier)

• If $P$ is an absolutely continuous measure, then $\gamma_0$ is unique and concentrated on the graph of a function $T_0$ that solves (Monge). Moreover, $T_0 = \nabla f_0$, the gradient of a convex function $f_0 : \mathbb{R}^d \to \mathbb{R} \cup \{+\infty\}$.

• Conversely, given $P$ and $T_0 = \nabla f_0$ for $f_0$ convex, $T_0$ solves (Monge) for $P$ and $Q = (T_0)_\# P$.

$T_0 \in \arg\min_{T_\# P = Q} \int \|T(x) - x\|_2^2 \, \mathrm{d}P(x)$  (Monge)

$\gamma_0 \in \arg\min_{\gamma \in \Gamma(P, Q)} \int \|y - x\|_2^2 \, \mathrm{d}\gamma(x, y)$  (Kantorovich)

Page 6

Page 7

• Sample access: observe $X_1, \dots, X_n \sim P$ and $Y_1, \dots, Y_m \sim Q$, i.i.d.

• Empirical distributions:

$\hat{P} = \frac{1}{n} \sum_{i=1}^{n} \delta_{X_i}, \qquad \hat{Q} = \frac{1}{m} \sum_{i=1}^{m} \delta_{Y_i}$

• Can compute directly on the empirical distributions:

$\hat{W}_2^2 = W_2^2(\hat{P}, \hat{Q}) = \min_{\gamma \in \Gamma(\hat{P}, \hat{Q})} \int \|y - x\|_2^2 \, \mathrm{d}\gamma(x, y)$

• Fast algorithms available: linear programming, entropic regularization

• Is this a good idea?

• No out-of-sample prediction, but one could use barycentric projection and 1-nearest-neighbor extension

• Statistical guarantees?
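As a concrete illustration of the plug-in computation (not from the talk; the Gaussian test data and sample sizes are arbitrary choices), here is a minimal sketch using the POT library (`pip install pot`):

```python
import numpy as np
import ot  # POT: Python Optimal Transport

rng = np.random.default_rng(0)
n, m, d = 200, 300, 3

# Illustrative data: P a standard Gaussian, Q a shifted Gaussian.
X = rng.standard_normal((n, d))
Y = rng.standard_normal((m, d)) + 1.0

a = np.full(n, 1 / n)  # uniform weights of the empirical measure P_hat
b = np.full(m, 1 / m)  # uniform weights of Q_hat
M = ot.dist(X, Y)      # cost matrix of squared Euclidean distances

W22_plugin = ot.emd2(a, b, M)                  # exact W_2^2(P_hat, Q_hat) via linear programming
W22_entropic = ot.sinkhorn2(a, b, M, reg=0.1)  # entropic regularization (Sinkhorn)
print(W22_plugin, W22_entropic)
```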

Page 8

1. How well does $\hat{P}$ approximate $P$ in $W_2$:

$W_2^2(P, \hat{P}) \lesssim \;?$

2. How well can I estimate $W_2(P, Q)$:

$|W_2^2(\hat{P}, \hat{Q}) - W_2^2(P, Q)| \lesssim \;?$
$|W_2(\hat{P}, \hat{Q}) - W_2(P, Q)| \lesssim \;?$

By the triangle inequality, $|W_2(\hat{P}, \hat{Q}) - W_2(P, Q)| \le W_2(P, \hat{P}) + W_2(Q, \hat{Q})$, so 2. follows from 1. for the plug-in estimator.

3. How well can I estimate $T_0$ or $\gamma_0$? For example,

$\int \|\hat{T}(x) - T_0(x)\|_2^2 \, \mathrm{d}P(x) \lesssim \;?$

Page 9

Theorem (Dudley, 1969; Weed, Bach, 2017)

Let $P = \mathrm{Unif}([0,1]^d)$ and $X_1, \dots, X_n \sim P$ i.i.d. Then

$W_2^2(P, \hat{P}) \gtrsim n^{-2/d}$ almost surely.

Proof.

• Set $\epsilon = \frac{1}{100} n^{-1/d}$. By a volume argument, $\bigcup_i \big( X_i + \epsilon [-1, 1]^d \big)$ covers at most half of $[0,1]^d$. Denote an uncovered half by $A$.
• The mass of $P$ on $A$ must be transported to the sample points $X_i$, so it travels at least $\epsilon$.
• Hence, for any coupling $\gamma$,

$W_2^2(P, \hat{P}) \ge \int \|y - x\|_2^2 \, \mathrm{d}\gamma(x, y) \ge \frac{1}{2} \epsilon^2 = \frac{1}{2 \cdot 100^2} \, n^{-2/d}$
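The slow rate is easy to see numerically. The following simulation (illustrative, not from the talk) uses the two-sample quantity $W_2^2(\hat{P}_n, \hat{P}'_n)$ between two independent empirical measures from $\mathrm{Unif}([0,1]^d)$ as a proxy for the one-sample error, and compares it to $n^{-2/d}$:

```python
import numpy as np
import ot  # POT: Python Optimal Transport

rng = np.random.default_rng(1)
d = 5  # illustrative dimension

for n in [100, 200, 400, 800]:
    X = rng.random((n, d))  # sample from Unif[0,1]^d
    Y = rng.random((n, d))  # second, independent sample
    a = b = np.full(n, 1 / n)
    w22 = ot.emd2(a, b, ot.dist(X, Y))
    print(f"n={n:4d}  W2^2={w22:.4f}  n^(-2/d)={n ** (-2 / d):.4f}")
```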

Page 10

For the rest of the talk, assume $P, Q$ compactly supported and $d \ge 5$.

Theorem (Dudley, 1969; Fournier, Guillin, 2014; Weed, Bach, 2017)
If $P$ is absolutely continuous, then

$\mathbb{E}\, W_2^2(P, \hat{P}) \asymp n^{-2/d}$.

• Also holds when $P$ is supported on a $k$-dimensional compact manifold (Weed, Bach, 2017):

$\mathbb{E}\, W_2^2(P, \hat{P}) \asymp n^{-2/k}$.

• Two-sample case (Chizat et al., 2020; Manole, Niles-Weed, 2021):

$\mathbb{E}\, |W_2^2(\hat{P}, \hat{Q}) - W_2^2(P, Q)| \asymp n^{-2/d}$

• Corollary:

$\mathbb{E}\, |W_2(\hat{P}, \hat{Q}) - W_2(P, Q)| \lesssim n^{-1/d} \wedge \frac{1}{W_2(P, Q)} \, n^{-2/d}$

• Lower bounds (under some assumptions; Niles-Weed, Rigollet, 2020):

$\mathbb{E}\, |W_2^2(\hat{P}, \hat{Q}) - W_2^2(P, Q)| \gtrsim (n \log n)^{-2/d}$

Page 11

• What if the distributions are full-dimensional, but the transport is not?

• Kolouri et al., 2019; Paty, Cuturi, 2019; Niles-Weed, Rigollet, 2019; Lin et al., 2021:

$PW_{2,k}^2(P, Q) = \sup \left\{ W_2^2\big( (\Pi_E)_\# P, (\Pi_E)_\# Q \big) \;:\; E \text{ a } k\text{-dimensional subspace} \right\}$,

where $\Pi_E$ denotes the projection onto the subspace $E$

• Projection-robust Wasserstein distance / Wasserstein projection pursuit

• Could also average instead of taking the supremum; the 1D case is then known as the sliced Wasserstein distance (a sketch follows below)

• Recover lower-dimensional rates:

$\mathbb{E}\, PW_{2,k}^2(P, \hat{P}) \lesssim n^{-2/k}$

• Computationally challenging, but amenable to Riemannian optimization (Lin et al., 2020)
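For the averaged variant, 1D optimal transport between two samples of equal size reduces to matching sorted values, so the sliced $W_2^2$ can be estimated in a few lines. This is a minimal sketch with Monte Carlo projection directions (the number of projections and the test data are illustrative):

```python
import numpy as np

def sliced_w22(X, Y, n_proj=100, rng=None):
    """Monte Carlo estimate of the sliced squared 2-Wasserstein distance
    between two samples of equal size (rows are points)."""
    rng = rng or np.random.default_rng()
    d = X.shape[1]
    # Random directions, uniform on the unit sphere.
    theta = rng.standard_normal((n_proj, d))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)
    total = 0.0
    for t in theta:
        # In 1D, OT between equal-size samples matches sorted values.
        px, py = np.sort(X @ t), np.sort(Y @ t)
        total += np.mean((px - py) ** 2)
    return total / n_proj

rng = np.random.default_rng(2)
X = rng.standard_normal((500, 10))
Y = rng.standard_normal((500, 10)) + 0.5
print(sliced_w22(X, Y, rng=rng))
```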

Page 12

Page 13

• The empirical distribution $\hat{P}$ is a poor approximation to $P$

• For $M > 0$, $\alpha > 1$, assume:
  • $P$ admits a $C^{\alpha-1}$-smooth density (also denoted by $P$)
  • the density is upper and lower bounded: $M^{-1} \le P(x) \le M$
  • some additional assumptions on the support

• Niles-Weed, Berthet, 2019: for wavelet density estimators $\tilde{P}$,

$\mathbb{E}\, W_2^2(P, \tilde{P}) \lesssim n^{-\frac{2\alpha}{2\alpha - 2 + d}}$

• Under additional assumptions, similar results hold for $W_2^2(\tilde{P}, \tilde{Q})$ and for $\int \|\tilde{T}(x) - T_0(x)\|_2^2 \, \mathrm{d}P(x)$, with $\tilde{T}$ the transport map from $\tilde{P}$ to $\tilde{Q}$ (Manole et al., 2021; Deb et al., 2021)

• Shortcoming: to compute the Wasserstein distance or the transport map from the density estimates, one still needs to sample a lot!

Page 14

Theorem (H., Rigollet, 2021)

If
• $T_0 = \nabla f_0$ with $M^{-1} I \preceq \nabla^2 f_0 \preceq M I$ ($f_0$ strongly convex, i.e., $f_0 - M^{-1} \|\cdot\|_2^2$ convex),
• $T_0$ is $C^\alpha$-smooth on a hypercube containing $\mathrm{supp}(P)$, $\alpha > 1$,
• $Q = (T_0)_\# P$,

then there exists an estimator $\hat{T}$ such that

$\mathbb{E} \int \|\hat{T}(x) - T_0(x)\|_2^2 \, \mathrm{d}P(x) \lesssim n^{-\frac{2\alpha}{2\alpha - 2 + d}} (\log n)^2$

• Minimax optimal up to logs over classes mimicking the above assumptions

Page 15

Where does the exponent $2\alpha/(2\alpha - 2 + d)$ come from?

• Change of variables:

$\frac{\mathrm{d}Q}{\mathrm{d}P}(y) \asymp \frac{1}{\det \partial T(T^{-1}(y))}, \quad y \in T(\mathrm{supp}(P))$

• The KL divergence between the "signals" is governed by $\partial T$, while the estimation target is $T$ itself

• Compare the white noise model: estimation rate $n^{-2(\beta - k)/(2\beta + d)}$ for estimating the $k$-th derivative of a $C^\beta$ function

• Setting $\beta = \alpha - 1$ and $k = -1$ (estimation of the anti-derivative of a $C^{\alpha-1}$ function) recovers $n^{-2\alpha/(2\alpha - 2 + d)}$

Page 16

If
• $P \in C^{\alpha-1}(\Omega_P)$, $Q \in C^{\alpha-1}(\Omega_Q)$ for $\alpha > 1$,
• $P$, $Q$ are upper and lower bounded on $\Omega_P$, $\Omega_Q$,
• $\Omega_P$, $\Omega_Q$ are convex with $C^2$-boundary and uniformly convex,

then $T_0 \in C^\alpha$.

Page 17

• For a moment, consider a regression problem with data $(X_i, Y_i)_{i=1,\dots,n}$ and regression function $T$ instead

• Risk function: $\mathcal{S}(T) = \mathbb{E}\, \|T(X) - Y\|_2^2$

• Empirical risk: $\hat{\mathcal{S}}(T) = \frac{1}{n} \sum_{i=1}^{n} \|T(X_i) - Y_i\|_2^2$

$T_0 = \arg\min_T \mathcal{S}(T), \qquad \hat{T} = \arg\min_T \hat{\mathcal{S}}(T)$

• Bound the excess risk $\mathcal{S}(\hat{T}) - \mathcal{S}(T_0)$

• Control via empirical process theory, e.g., the Dudley entropy integral and covering numbers based on the complexity of the function class for $T$

• Stability/margin condition: $\mathbb{E}\, \|\hat{T}(X) - T_0(X)\|_2^2 \lesssim \mathcal{S}(\hat{T}) - \mathcal{S}(T_0)$

Putting these together:

$\mathbb{E}\, \|\hat{T}(X) - T_0(X)\|_2^2 \lesssim \mathcal{S}(\hat{T}) - \mathcal{S}(T_0) = \mathcal{S}(\hat{T}) - \hat{\mathcal{S}}(\hat{T}) + \underbrace{\hat{\mathcal{S}}(\hat{T}) - \hat{\mathcal{S}}(T_0)}_{\le 0} + \hat{\mathcal{S}}(T_0) - \mathcal{S}(T_0)$
$\le \mathcal{S}(\hat{T}) - \hat{\mathcal{S}}(\hat{T}) + \hat{\mathcal{S}}(T_0) - \mathcal{S}(T_0) \le 2 \sup_T \big| \mathcal{S}(T) - \hat{\mathcal{S}}(T) \big|$

Page 18

• Find a suitable risk function to which empirical process theory applies

• Rewrite the objective: expanding the squares yields

$\min_{\gamma \in \Gamma(P, Q)} \int \|y - x\|_2^2 \, \mathrm{d}\gamma(x, y) \iff \max_{\gamma \in \Gamma(P, Q)} \int \langle x, y \rangle \, \mathrm{d}\gamma(x, y)$,

since $\int \|x\|_2^2 \, \mathrm{d}\gamma$ and $\int \|y\|_2^2 \, \mathrm{d}\gamma$ are fixed by the marginals $P$ and $Q$.

Page 19

• Strong duality: for the optimal $\gamma_0, f_0, g_0$,

$\int \langle x, y \rangle \, \mathrm{d}\gamma_0(x, y) = \int \big( f_0(x) + g_0(y) \big) \, \mathrm{d}\gamma_0(x, y) = \int f_0(x) \, \mathrm{d}P(x) + \int g_0(y) \, \mathrm{d}Q(y)$

• Semi-dual: optimizing over $g$ for given $f$ yields

$g_f(y) = \sup_{z \in \mathbb{R}^d} \langle z, y \rangle - f(z) = f^*(y)$, the convex conjugate

Primal:
$\gamma_0 \in \arg\max \left\{ \int \langle x, y \rangle \, \mathrm{d}\gamma(x, y) \;:\; \gamma \in \Gamma(P, Q) \right\}$

Dual:
$(f_0, g_0) \in \arg\min \left\{ \int f(x) \, \mathrm{d}P(x) + \int g(y) \, \mathrm{d}Q(y) \;:\; f(x) + g(y) \ge \langle x, y \rangle \;\; \forall x, y \in \mathbb{R}^d \right\}$

Semi-dual (non-linear, but no constraints!):
$\min_f \mathcal{S}(f) = \int f(x) \, \mathrm{d}P(x) + \int f^*(y) \, \mathrm{d}Q(y)$

Equivalently, one may minimize $\int f^{**}(x) \, \mathrm{d}P(x) + \int f^*(y) \, \mathrm{d}Q(y)$; at the optimum, $T_0 = \nabla f_0$.
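Evaluating the semi-dual objective requires the conjugate $f^*$. On a finite grid this is a discrete Legendre transform, which can be computed by brute force as in this minimal 1D sketch (the grid and the test function are illustrative; the Linear-Time Legendre Transform mentioned later does this more efficiently):

```python
import numpy as np

def conjugate_on_grid(xs, f_vals, ys):
    """Discrete Legendre transform f*(y) = max_i (x_i * y - f(x_i)),
    maximizing over the grid points x_i (1D, brute force)."""
    return np.max(np.outer(ys, xs) - f_vals[None, :], axis=1)

# Sanity check: f(x) = x^2/2 is its own conjugate, so f*(y) = y^2/2.
xs = np.linspace(-5.0, 5.0, 1001)
f_vals = xs ** 2 / 2
ys = np.array([-1.0, 0.0, 2.0])
print(conjugate_on_grid(xs, f_vals, ys))  # approx. [0.5, 0.0, 2.0]
```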

Page 20

• Replace $\mathbb{E} \rightsquigarrow \frac{1}{n} \sum_{i=1}^{n}$:

$\mathcal{S}(f) = \int f(x) \, \mathrm{d}P(x) + \int f^*(y) \, \mathrm{d}Q(y) \;\rightsquigarrow\; \hat{\mathcal{S}}(f) = \frac{1}{n} \sum_{i=1}^{n} f(X_i) + \frac{1}{n} \sum_{i=1}^{n} f^*(Y_i)$

• With hypothesis space $\mathcal{H} = V_J$ (a wavelet cutoff):

$f_0 = \arg\min_{f \in \mathcal{H}} \mathcal{S}(f), \qquad \hat{f} = \arg\min_{f \in \mathcal{H}} \hat{\mathcal{S}}(f), \qquad \hat{T} = \nabla \hat{f}$

$\mathcal{S}(\hat{f}) - \mathcal{S}(f_0) \lesssim \sup_{f \in \mathcal{H}} \big| \mathcal{S}(f) - \hat{\mathcal{S}}(f) \big|$  (generalization error)

• To obtain upper bounds:
  1. Bound the generalization error
  2. Relate the excess risk to $d(\hat{T}, T_0)$ via a stability/margin condition:

$d(\hat{T}, T_0) \lesssim \mathcal{S}(\hat{f}) - \mathcal{S}(f_0)$

Page 21

Proposition (H., Rigollet)

Assume $f$, $f^*$, $f_0$ differentiable, $f$ convex and $\nabla^2 f \preceq M I$ (i.e., $\nabla f$ is $M$-Lipschitz). Then,

$\int \|\nabla f(x) - \nabla f_0(x)\|_2^2 \, \mathrm{d}P(x) \lesssim M \big( \mathcal{S}(f) - \mathcal{S}(f_0) \big)$,

where $\mathcal{S}(f) = \int f(x) \, \mathrm{d}P(x) + \int f^*(y) \, \mathrm{d}Q(y)$.

• Dualization of an argument by Ambrosio (Gigli, 2011)

• If $f$ is additionally $M$-strongly convex, exchange roles to obtain

$\int \|\nabla f(x) - \nabla f_0(x)\|_2^2 \, \mathrm{d}P(x) \asymp \mathcal{S}(f) - \mathcal{S}(f_0)$

Page 22

• Manole et al., 2021: rates for smooth densities. If
  • $g_0$ is $M$-strongly convex and $\nabla g_0$ is $M$-Lipschitz,
  • $\tilde{Q}$ is any measure,
  • $\tilde{f}$ is a Kantorovich potential for $(P, \tilde{Q})$,

then

$\int \|\nabla \tilde{f}(x) - \nabla f_0(x)\|_2^2 \, \mathrm{d}P(x) \lesssim W_2^2(P, \tilde{Q}) - W_2^2(P, Q) - \int g_0(y) \, \mathrm{d}(\tilde{Q} - Q)(y)$

• Used to show plug-in rates for transport map estimation

• The two-sample case is less general and needs regularity of the candidate potential

Page 23

• Ground truth: $T_0(x) = x$, the identity map

• Baseline estimator $\hat{T}_{\mathrm{emp}}$

• Wavelet estimator $\hat{T}_{\mathrm{wav}}$: discretize a hyperbox around the distribution, compute $f^*$ with the Linear-Time Legendre Transform, then interpolate. This is slow.

• Heuristic kernel estimator $\hat{T}_{\mathrm{ker}}$: solve OT between $\hat{P}$ and $\hat{Q}$ to get a matching $\pi : [N] \to [N]$, then compute a kernel ridge estimator with input data $(X_i, Y_{\pi(i)})_i$ (see the sketch below)

Page 24

Measure performance by

$\mathrm{MSE}_n(\hat{T}) = \frac{1}{n} \sum_{i=1}^{n} \|\hat{T}(X_i) - T_0(X_i)\|_2^2$
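A minimal sketch of the heuristic kernel estimator $\hat{T}_{\mathrm{ker}}$ and of $\mathrm{MSE}_n$, using POT for the matching and scikit-learn's KernelRidge; the kernel, regularization, and test data are illustrative choices, not the talk's exact setup:

```python
import numpy as np
import ot
from sklearn.kernel_ridge import KernelRidge

rng = np.random.default_rng(3)
N, d = 200, 2
T0 = lambda x: x                      # ground truth: identity map
X = rng.standard_normal((N, d))       # sample from P
Y = T0(rng.standard_normal((N, d)))   # independent sample from Q = (T0)#P

# Matching pi: with uniform weights, the optimal discrete plan is
# (1/N times) a permutation matrix.
a = b = np.full(N, 1 / N)
plan = ot.emd(a, b, ot.dist(X, Y))
pi = plan.argmax(axis=1)              # extract the permutation

# Kernel ridge regression on the matched pairs (X_i, Y_pi(i)).
T_ker = KernelRidge(kernel="rbf", alpha=1e-3, gamma=0.5).fit(X, Y[pi])

mse_n = np.mean(np.sum((T_ker.predict(X) - T0(X)) ** 2, axis=1))
print(f"MSE_n = {mse_n:.4f}")
```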

Page 25

• Can mimic rates from classical non-parametric estimation for OT

• Smoothness can beat the curse of dimensionality

• Stability estimates were useful

• Tricky: obtaining computationally efficient estimators

Page 26

• Entropic regularization is well known for providing fast algorithms for OT

• Genevay et al., 2019; Mena, Weed, 2019: for compactly supported measures, the optimal potentials satisfy $f_\epsilon, g_\epsilon \in H^\alpha$ with

$\|f_\epsilon\|_{H^\alpha}, \|g_\epsilon\|_{H^\alpha} \lesssim 1 + \epsilon^{-(\alpha - 1)}$

• Obtain plug-in rates:

$\mathbb{E}\, |W_\epsilon^2(\hat{P}, \hat{Q}) - W_\epsilon^2(P, Q)| \lesssim \frac{1}{\sqrt{n}} \big( 1 + \epsilon^{-d/2} \big)$

• Approximation guarantees:

$|W_\epsilon^2(P, Q) - W_2^2(P, Q)| \lesssim \epsilon \log(\epsilon^{-1})$

• Chizat et al., 2020: improved approximation guarantees of order $\epsilon^2 I(P, Q)$

• Still geometrically bad estimation rates for $W_2^2$

Primal:

$W_\epsilon^2(P, Q) = \min_{\gamma \in \Gamma(P, Q)} \int \|y - x\|_2^2 \, \mathrm{d}\gamma(x, y) + \epsilon D_{\mathrm{KL}}(\gamma \,\|\, P \otimes Q)$,

where $D_{\mathrm{KL}}(\gamma \,\|\, P \otimes Q) = \int \log \frac{\gamma(x, y)}{P(x) Q(y)} \, \mathrm{d}\gamma(x, y)$

Dual:

$(f_\epsilon, g_\epsilon) \in \arg\max_{f, g} \int f(x) \, \mathrm{d}P(x) + \int g(y) \, \mathrm{d}Q(y) + \epsilon - \epsilon \int \exp\left( \frac{f(x) + g(y) - \|y - x\|_2^2}{\epsilon} \right) \mathrm{d}P(x) \, \mathrm{d}Q(y)$

Optimality condition:

$f_\epsilon(x) = -\epsilon \log \int \exp\left( \frac{g_\epsilon(y) - \|y - x\|_2^2}{\epsilon} \right) \mathrm{d}Q(y)$

Page 27

• (Pooladian, Niles-Weed, 09/27/2021)

• Estimate the transport map as the barycentric projection of the entropic plan, which can be written in terms of the dual potentials $f_\epsilon, g_\epsilon$:

$\hat{T}_\epsilon(X_i) = \frac{\int y \, \mathrm{d}\hat{\gamma}_\epsilon(X_i, y)}{\int \mathrm{d}\hat{\gamma}_\epsilon(X_i, y)}$, which extends to $\hat{T}_\epsilon(x) = \frac{\int y \exp\left( \frac{g_\epsilon(y) - \|x - y\|_2^2}{\epsilon} \right) \mathrm{d}\hat{Q}(y)}{\int \exp\left( \frac{g_\epsilon(y) - \|x - y\|_2^2}{\epsilon} \right) \mathrm{d}\hat{Q}(y)}$

• Under assumptions (including $M^{-1} I \preceq \nabla^2 f_0 \preceq M I$), for $1 < \alpha \le 3$ and $d' = 2\lceil d/2 \rceil$,

$\mathbb{E} \int \|\hat{T}_\epsilon(x) - T_0(x)\|_2^2 \, \mathrm{d}P(x) \lesssim n^{-\frac{\alpha + 1}{2\alpha + 2 + 2d'}} \log n$

• Compare to the information-theoretic limit $n^{-\frac{2\alpha}{2\alpha - 2 + d}}$

• Goes through the same approximation steps as before for estimation of $W_2^2$, so it won't beat the geometrically bad dependence as $d \to \infty$
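On sample points, the entropic map estimator is just a row-normalized barycentric projection of the Sinkhorn plan. A minimal sketch with POT ($\epsilon$ and the test data are illustrative choices):

```python
import numpy as np
import ot

rng = np.random.default_rng(4)
n, d, eps = 300, 2, 0.05
X = rng.standard_normal((n, d))              # sample from P
Y = 0.5 * rng.standard_normal((n, d)) + 2.0  # illustrative sample from Q

a = b = np.full(n, 1 / n)
M = ot.dist(X, Y)                      # squared Euclidean costs
gamma = ot.sinkhorn(a, b, M, reg=eps)  # entropic plan gamma_eps

# Barycentric projection: T_eps(X_i) = (sum_j gamma_ij Y_j) / (sum_j gamma_ij).
T_eps = (gamma @ Y) / gamma.sum(axis=1, keepdims=True)
print(T_eps[:3])
```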

Page 28

• Reproducing kernel Hilbert space (RKHS):
  • Hilbert space $\mathcal{H}$ of functions on $\mathbb{R}^d$
  • kernel $k : \mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}$ such that $f(x) = \langle k(x, \cdot), f \rangle_{\mathcal{H}}$ for $f \in \mathcal{H}$
  • e.g., the Sobolev space $H^\alpha$ if $\alpha > d/2$

• Advantages:
  • interpolation problems become finite-dimensional (representer theorem)
  • statistical guarantees

• Vacher et al., 2021: under additional assumptions, if $P, Q$ have $H^\alpha$-densities with $\alpha > 3d$, one obtains an estimator $\hat{W}_2^2(P, Q)$ computable in time $O(n^2)$ with

$\mathbb{E}\, |\hat{W}_2^2(P, Q) - W_2^2(P, Q)| \lesssim n^{-1/2} \log n$

• Compare to the information-theoretic rate $n^{-\frac{2\alpha}{2\alpha - 2 + d}} \approx n^{-6/7}$

Constrained dual:

$\max \int f(x) \, \mathrm{d}\hat{P}(x) + \int g(y) \, \mathrm{d}\hat{Q}(y)$ s.t. $f(x) + g(y) \le \|x - y\|_2^2$ for all $x, y \in \mathbb{R}^d$

Kernelized version:

$\min \; \Big\langle f, \frac{1}{n} \sum_{i=1}^{n} k(X_i, \cdot) \Big\rangle + \Big\langle g, \frac{1}{n} \sum_{j=1}^{n} k(Y_j, \cdot) \Big\rangle$

s.t. $f, g \in \mathcal{H}$ and

$\|\tilde{X}_i - \tilde{Y}_j\|_2^2 - f(\tilde{X}_i) - g(\tilde{Y}_j) = \big\langle k(\tilde{X}_i, \cdot) \otimes k(\tilde{Y}_j, \cdot), \; A \, k(\tilde{X}_i, \cdot) \otimes k(\tilde{Y}_j, \cdot) \big\rangle, \quad i, j \in [\ell]$,

for a positive operator $A$ and an additional sample $\tilde{X}_i, \tilde{Y}_j$.

Page 29

Page 30

• Forrow, H., et al., 2019: gain statistical efficiency from sparse Wasserstein barycenters:

$\hat{\rho}_k = \arg\min \big\{ W_2^2(\hat{P}, \rho) + W_2^2(\rho, \hat{Q}) \;:\; \rho \text{ supported on } k \text{ points} \big\}$

• Generalization bounds:

$\sup_\rho \mathbb{E}\, \big| W_2^2(\rho, \hat{P}) - W_2^2(\rho, P) \big| \lesssim \sqrt{\frac{k^3 \log k}{n}}$

• But no approximation guarantees

Page 31

• Paty, d'Aspremont, Cuturi, 2020: an alternative way of ensuring strong convexity:

$\hat{f}_{\mathrm{SSNB}} = \arg\min \big\{ W_2^2\big( (\nabla f)_\# \hat{P}, \hat{Q} \big) \;:\; f \text{ strongly convex and } \nabla f \text{ Lipschitz} \big\}$

• Leads to a mixed quadratically constrained quadratic program / Wasserstein problem

Page 32

• Recall the semi-dual:

$\mathcal{S}(f) = \int f(x) \, \mathrm{d}P(x) + \int f^*(y) \, \mathrm{d}Q(y)$

$f^*(y) = \sup_x \langle x, y \rangle - f(x) = \langle \nabla f^*(y), y \rangle - f(\nabla f^*(y)) \ge \langle y, \nabla g(y) \rangle - f(\nabla g(y)) \quad \forall g$

• Makkuva et al., 2020: replace $f$, $g$ by input-convex neural networks (ICNNs):

$\min_{f \in \mathrm{ICNN}} \max_{g \in \mathrm{ICNN}} \int f(x) \, \mathrm{d}\hat{P}(x) + \int \big( \langle y, \nabla g(y) \rangle - f(\nabla g(y)) \big) \, \mathrm{d}\hat{Q}(y)$

Input-convex neural network:

$z_{l+1} = \sigma_l(W_l z_l + A_l x + b_l)$, with $W_l \ge 0$, for $l = 0, \dots, L-1$; $f(x) = z_L$
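A minimal PyTorch sketch of the ICNN recursion above; the layer sizes, the softplus activation (convex and non-decreasing, as convexity in $x$ requires), and the clamping used to keep $W_l \ge 0$ are illustrative choices:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ICNN(nn.Module):
    """Input-convex network: z_{l+1} = sigma(W_l z_l + A_l x + b_l), W_l >= 0."""

    def __init__(self, dim, hidden=64, n_layers=3):
        super().__init__()
        # A_l: direct connections from the input x into every layer.
        self.A = nn.ModuleList(
            [nn.Linear(dim, hidden)]
            + [nn.Linear(dim, hidden, bias=False) for _ in range(n_layers - 1)]
            + [nn.Linear(dim, 1, bias=False)])
        # W_l: non-negative connections acting on the hidden state z_l.
        self.W = nn.ModuleList(
            [nn.Linear(hidden, hidden) for _ in range(n_layers - 1)]
            + [nn.Linear(hidden, 1)])

    def forward(self, x):
        z = F.softplus(self.A[0](x))
        for A_l, W_l in zip(self.A[1:-1], self.W[:-1]):
            z = F.softplus(W_l(z) + A_l(x))
        return self.W[-1](z) + self.A[-1](x)  # f(x), shape (batch, 1)

    def clamp_weights(self):
        # Project the z-weights onto W_l >= 0 after each optimizer step.
        with torch.no_grad():
            for W_l in self.W:
                W_l.weight.clamp_(min=0)

f = ICNN(dim=2)
f.clamp_weights()
x = torch.randn(5, 2, requires_grad=True)
T = torch.autograd.grad(f(x).sum(), x)[0]  # candidate transport map T = grad f
print(T.shape)  # torch.Size([5, 2])
```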

Page 33

• OT can be very useful, but is also very noisy

• Compromise: projection-robust Wasserstein distance, entropic OT

• Regularize: recent progress on regularized estimators with statistical and computational guarantees

• Improve:
  • computationally efficient estimators matching the information-theoretic rates
  • rates for plug-in transport map estimators with smoothing

Thanks!

Page 34

• Chizat, Lenaic, Pierre Roussillon, Flavien Léger, François-Xavier Vialard, and Gabriel Peyré. “Faster Wasserstein Distance Estimation with the Sinkhorn Divergence.” Advances in Neural Information Processing Systems 33 (2020).

• Courty, Nicolas, Rémi Flamary, Devis Tuia, and Alain Rakotomamonjy. “Optimal Transport for Domain Adaptation.” IEEE Transactions on Pattern Analysis and Machine Intelligence 39, no. 9 (2016): 1853–65.

• Deb, Nabarun, Promit Ghosal, and Bodhisattva Sen. “Rates of Estimation of Optimal Transport Maps Using Plug-in Estimators via Barycentric Projections.” ArXiv:2107.01718 [Math, Stat], July 4, 2021. http://arxiv.org/abs/2107.01718.

• Dudley, Richard Mansfield. “The Speed of Mean Glivenko-Cantelli Convergence.” The Annals of Mathematical Statistics 40, no. 1 (1969): 40–50.

Page 35

• Arjovsky, Martin, Soumith Chintala, and Léon Bottou. “Wasserstein Generative Adversarial Networks.” In International Conference on Machine Learning, 214–23. PMLR, 2017.

• Forrow, Aden, Jan-Christian Hütter, Mor Nitzan, Philippe Rigollet, Geoffrey Schiebinger, and Jonathan Weed. “Statistical Optimal Transport via Factored Couplings.” In Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics, 2454–65. PMLR, 2019. https://proceedings.mlr.press/v89/forrow19a.html.

• Fournier, Nicolas, and Arnaud Guillin. “On the Rate of Convergence in Wasserstein Distance of the Empirical Measure.” Probability Theory and Related Fields 162, no. 3 (2015): 707–38.

• Frogner, Charlie, Chiyuan Zhang, Hossein Mobahi, Mauricio Araya-Polo, and Tomaso Poggio. “Learning with a Wasserstein Loss.” ArXiv Preprint ArXiv:1506.05439, 2015.

• Genevay, Aude, Lénaïc Chizat, Francis Bach, Marco Cuturi, and Gabriel Peyré. "Sample Complexity of Sinkhorn Divergences." In The 22nd International Conference on Artificial Intelligence and Statistics, 1574–83. PMLR, 2019. https://proceedings.mlr.press/v89/genevay19a.html.

• Hütter, Jan-Christian, and Philippe Rigollet. “Minimax Estimation of Smooth Optimal Transport Maps.” The Annals of Statistics 49, no. 2 (April 2021): 1166–94. https://doi.org/10.1214/20-AOS1997.

Page 36

• Kolouri, Soheil, Kimia Nadjahi, Umut Simsekli, Roland Badeau, and Gustavo Rohde. “Generalized Sliced Wasserstein Distances.” In Advances in Neural Information Processing Systems, Vol. 32. Curran Associates, Inc., 2019. https://proceedings.neurips.cc/paper/2019/hash/f0935e4cd5920aa6c7c996a5ee53a70f-Abstract.html.

• Li, Wenbo, and Ricardo H. Nochetto. “Quantitative Stability and Error Estimates for Optimal Transport Plans.” IMA Journal of Numerical Analysis 41, no. 3 (2021): 1941–65.

• Lin, Tianyi, Chenyou Fan, Nhat Ho, Marco Cuturi, and Michael I. Jordan. “Projection Robust Wasserstein Distance and Riemannian Optimization.” ArXiv Preprint ArXiv:2006.07458, 2020.

• Lin, Tianyi, Zeyu Zheng, Elynn Chen, Marco Cuturi, and Michael Jordan. “On Projection Robust Optimal Transport: Sample Complexity and Model Misspecification.” In International Conference on Artificial Intelligence and Statistics, 262–70. PMLR, 2021.

• Makkuva, Ashok, Amirhossein Taghvaei, Sewoong Oh, and Jason Lee. “Optimal Transport Mapping via Input Convex Neural Networks.” In Proceedings of the 37th International Conference on Machine Learning, 6672–81. PMLR, 2020. https://proceedings.mlr.press/v119/makkuva20a.html.

Page 37

• Manole, Tudor, Sivaraman Balakrishnan, Jonathan Niles-Weed, and Larry Wasserman. “Plugin Estimation of Smooth Optimal Transport Maps.” ArXiv Preprint ArXiv:2107.12364, 2021.

• Manole, Tudor, and Jonathan Niles-Weed. “Sharp Convergence Rates for Empirical Optimal Transport with Smooth Costs.” ArXiv Preprint ArXiv:2106.13181, 2021.

• Mena, Gonzalo, and Jonathan Weed. “Statistical Bounds for Entropic Optimal Transport: Sample Complexity and the Central Limit Theorem.” ArXiv:1905.11882 [Cs, Math, Stat], May 30, 2019. http://arxiv.org/abs/1905.11882.

• Niles-Weed, Jonathan, and Philippe Rigollet. “Estimation of Wasserstein Distances in the Spiked Transport Model.” ArXiv Preprint ArXiv:1909.07513, 2019.

• Paty, François-Pierre, Alexandre d’Aspremont, and Marco Cuturi. “Regularity as Regularization: Smooth and Strongly Convex Brenier Potentials in Optimal Transport.” In International Conference on Artificial Intelligence and Statistics, 1222–32. PMLR, 2020.

Page 38

• Pooladian, Aram-Alexandre, and Jonathan Niles-Weed. “Entropic Estimation of Optimal Transport Maps.” ArXiv:2109.12004 [Math, Stat], September 24, 2021. http://arxiv.org/abs/2109.12004.

• Rabin, Julien, Julie Delon, and Yann Gousseau. “Regularization of Transportation Maps for Color and Contrast Transfer.” In 2010 IEEE International Conference on Image Processing, 1933–36. IEEE, 2010.

• Rolet, Antoine, Marco Cuturi, and Gabriel Peyré. “Fast Dictionary Learning with a Smoothed Wasserstein Loss.” In Artificial Intelligence and Statistics, 630–38. PMLR, 2016.

• Schiebinger, Geoffrey, Jian Shu, Marcin Tabaka, Brian Cleary, Vidya Subramanian, Aryeh Solomon, Siyan Liu, Stacie Lin, Peter Berube, and Lia Lee. “Reconstruction of Developmental Landscapes by Optimal-Transport Analysis of Single-Cell Gene Expression Sheds Light on Cellular Reprogramming.” BioRxiv, 2017, 191056.

• Vacher, Adrien, Boris Muzellec, Alessandro Rudi, Francis Bach, and Francois-Xavier Vialard. “A Dimension-Free Computational Upper-Bound for Smooth Optimal Transport Estimation.” ArXiv Preprint ArXiv:2101.05380, 2021.

• Weed, Jonathan, and Francis Bach. “Sharp Asymptotic and Finite-Sample Rates of Convergence of Empirical Measures in Wasserstein Distance.” Bernoulli 25, no. 4A (2019): 2620–48.

• Weed, Jonathan, and Quentin Berthet. “Estimation of Smooth Densities in Wasserstein Distance.” In Conference on Learning Theory, 3118–19. PMLR, 2019.

Page 39

Page 40

Appendix: proof of the stability proposition (Page 21).

Facts about the convex conjugate $f^*(y) = \sup_{z \in \mathbb{R}^d} \langle y, z \rangle - f(z)$:

• $f^*(y) = \langle y, x \rangle - f(x) \iff y = \nabla f(x)$

• if $f$ is convex with $\nabla f$ $L$-Lipschitz, then $f^*$ is $\frac{1}{L}$-strongly convex:

$f^*(y) \ge f^*(x) + \langle \nabla f^*(x), y - x \rangle + \frac{1}{2L} \|y - x\|_2^2$

• $\nabla f^* = (\nabla f)^{-1}$

Hence:

$\int f(x) \, \mathrm{d}P(x) + \int f^*(y) \, \mathrm{d}Q(y)$

$= \int \big( f(x) + f^*(\nabla f_0(x)) \big) \, \mathrm{d}P(x)$, since $Q = (T_0)_\# P = (\nabla f_0)_\# P$

$\ge \int \Big( \underbrace{f(x) + f^*(\nabla f(x))}_{= \langle x, \nabla f(x) \rangle} + \big\langle \underbrace{\nabla f^*(\nabla f(x))}_{= x}, \nabla f_0(x) - \nabla f(x) \big\rangle + \frac{1}{2L} \|\nabla f_0(x) - \nabla f(x)\|_2^2 \Big) \, \mathrm{d}P(x)$

$= \int \Big( \langle x, \nabla f_0(x) \rangle + \frac{1}{2L} \|\nabla f_0(x) - \nabla f(x)\|_2^2 \Big) \, \mathrm{d}P(x)$

$= \int f_0(x) \, \mathrm{d}P(x) + \int f_0^*(y) \, \mathrm{d}Q(y) + \frac{1}{2L} \int \|\nabla f_0(x) - \nabla f(x)\|_2^2 \, \mathrm{d}P(x)$,

i.e., $\mathcal{S}(f) - \mathcal{S}(f_0) \ge \frac{1}{2L} \int \|\nabla f(x) - \nabla f_0(x)\|_2^2 \, \mathrm{d}P(x)$.