Sinkhorn 解决最优传输问题,即把一个概率分布以最小代价转换成另外一个分布。本文是关于 Sinkhorn 相关网络资料的整理。
Sinkhorn 算法
1 介绍
Sinkhorn 解决最优传输问题,即把一个概率分布以最小代价转换成另外一个分布。举例说明:
现有 5 种小吃:merveilleux、eclair、chocolate mousse、bavarois、carrot cake,每种小吃数量如下表:
小吃
merveilleux
eclair
chocolate mousse
bavarois
carrot cake
数量
4
2
6
4
4
要将这些小吃分享给 8 个人,Bernard、Jan、Willem、Hilde、Steffie、Marlies、Tim、Wouter,每个人对小吃的需求量如下表
名字
Bernard
Jan
Willem
Hilde
Steffie
Marlies
Tim
Wouter
需求量
3
3
3
4
2
2
2
1
每个人对各小吃的喜爱程度如下表所示,其中 − 2 -2 − 2 分表示非常不喜欢,2 2 2 分表示非常喜欢:
merveilleux
eclair
chocolate mousse
bavarois
carrot cake
Bernard
2
2
1
0
0
Jan
0
-2
-2
-2
2
Willem
1
2
2
2
-1
Hilde
2
1
0
1
-1
Steffie
0.5
2
2
1
0
Marlies
0
1
1
1
-1
Tim
-2
2
2
1
1
Wouter
2
1
2
1
-1
任务是将这些小吃分发给这 8 个人,同时使得大家满意度最高。
2 问题的数学表达
设 r \boldsymbol{r} r 为每个人对小吃的需求量,r = [ 3 , 3 , 3 , 4 , 2 , 2 , 2 , 1 ] T \boldsymbol{r}=[3,3,3,4,2,2,2,1]^T r = [ 3 , 3 , 3 , 4 , 2 , 2 , 2 , 1 ] T ,设 r \boldsymbol{r} r 的维度为 n n n ,此处 n = 8 n=8 n = 8 ;
设 c \boldsymbol{c} c 为每种小吃的数量,c = [ 4 , 2 , 6 , 4 , 4 ] T \boldsymbol{c}=[4,2,6,4,4]^T c = [ 4 , 2 , 6 , 4 , 4 ] T ,设 c \boldsymbol{c} c 的维度为 m m m ,此处 m = 5 m=5 m = 5 ;
一般地,r \boldsymbol{r} r 和 c \boldsymbol{c} c (即从一个概率分布转为另一个概率分布的两个概率分布)表示边缘分布,因此 r \boldsymbol{r} r 和 c \boldsymbol{c} c 中元素和为 1 \boldsymbol{1} 1 ,这需要在后续过程中加以处理。
定义
U ( r , c ) = { P ∈ R + n × m ∣ P 1 m = r , P T 1 n = c } U(\pmb{r}, \pmb{c}) = \{ \pmb{P} \in \mathbb{R}_{+}^{n \times m} | \pmb{P} \pmb{1}_m = \pmb{r}, \pmb{P}^T \pmb{1}_n = \pmb{c} \}
U ( r r r , c c c ) = { P P P ∈ R + n × m ∣ P P P 1 1 1 m = r r r , P P P T 1 1 1 n = c c c }
U ( r , c ) U(\boldsymbol{r}, \boldsymbol{c}) U ( r , c ) 包含所有可能的小吃分配方案,注意,此处每份小吃可以随意切割进行分配。
每个用户对各小吃的喜爱程度存储在矩阵 M ∈ R n × m \boldsymbol{M} \in \mathbb{R}^{n \times m} M ∈ R n × m 中(M \boldsymbol{M} M 在一些文献中也被称为代价矩阵,本例中将矩阵 M \boldsymbol{M} M 中的元素取负即可得到其对应的代价矩阵)。
最终问题可以表示为:
d M ( r , c ) = min P ∈ U ( r , c ) ∑ P i j M i j d_{\pmb{M}}(\pmb{r}, \pmb{c}) = \min_{\pmb{P} \in U(\pmb{r}, \pmb{c})} \sum P_{ij} M_{ij}
d M M M ( r r r , c c c ) = P P P ∈ U ( r r r , c c c ) min ∑ P i j M i j
其中 d M ( r , c ) d_{\boldsymbol{M}} (\boldsymbol{r}, \boldsymbol{c}) d M ( r , c ) 又被称为 Wasserstein 距离。
上述问题还可添加正则项,使得问题的描述更加合理,即:
d M λ ( r , c ) = min P ∈ U ( r , c ) ∑ P i j M i j + 1 λ ( − P i j log P i j ) d_{\pmb{M}}^\lambda (\pmb{r}, \pmb{c}) = \min_{\pmb{P} \in U(\pmb{r}, \pmb{c})} \sum P_{ij} M_{ij} + \frac{1}{\lambda} (-P_{ij} \log P_{ij})
d M M M λ ( r r r , c c c ) = P P P ∈ U ( r r r , c c c ) min ∑ P i j M i j + λ 1 ( − P i j log P i j )
d M λ ( r , c ) d_{\boldsymbol{M}}^\lambda (\boldsymbol{r}, \boldsymbol{c}) d M λ ( r , c ) 又被称为 Sinkhorn 距离。
3 Sinkhorn 算法
上述问题最优解可表示为:
P i j = α i β j e − λ M i j P_{ij} = \alpha_i \beta_j e^{-\lambda M_{ij}}
P i j = α i β j e − λ M i j
其中 α i \alpha_i α i 和 β j \beta_j β j 为待求解的常数,具体算法为:
1 . g i v e n : M , r , c , λ 2 . i n i t i a l i z e : P λ = e − λ M 3 . r e p e a t ∣ 1. s c a l e t h e r o w s s u c h t h a t t h e r o w s s u m s m a t c h r ∣ 2. s c a l e t h e c o l u m n s s u c h t h a t t h e c o l u m n s u m s m a t c h c 4 . u n t i l c o n v e r g e n c e \begin{aligned}
1&. \mathrm{\pmb{given:}} \ M, \pmb{r}, \pmb{c}, \lambda \\
2&. \mathrm{\pmb{initialize:}} \ P_{\lambda} = e^{-\lambda M} \\
3&. \mathrm{\pmb{repeat}} \\
&| 1.\ \mathrm{scale\ the\ rows\ such\ that\ the\ rows\ sums\ match}\ r \\
&| 2.\ \mathrm{scale\ the\ columns\ such\ that\ the\ column\ sums\ match}\ c \\
4&. \mathrm{\pmb{until}\ convergence}
\end{aligned}
1 2 3 4 . g i v e n : g i v e n : g i v e n : M , r r r , c c c , λ . i n i t i a l i z e : ini t ia l i ze : i n i t i a l i z e : P λ = e − λ M . r e p e a t re p e a t r e p e a t ∣ 1 . s c a l e t h e r o w s s u c h t h a t t h e r o w s s u m s m a t c h r ∣ 2 . s c a l e t h e c o l u m n s s u c h t h a t t h e c o l u m n s u m s m a t c h c . u n t i l u n t i l u n t i l c o n v e r g e n c e
4 Python 实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 def compute_optimal_transport (M, r, c, lam, epsilon=1e-8 ): """ Computes the optimal transport matrix and Slinkhorn distance using the Sinkhorn-Knopp algorithm Inputs: - M : cost matrix (n x m) - r : vector of marginals (n, ) - c : vector of marginals (m, ) - lam : strength of the entropic regularization - epsilon : convergence parameter Outputs: - P : optimal transport matrix (n x m) - dist : Sinkhorn distance """ n, m = M.shape P = np.exp(- lam * M) P /= P.sum () u = np.zeros(n) while np.max (np.abs (u - P.sum (1 ))) > epsilon: u = P.sum (1 ) P *= (r / u).reshape((-1 , 1 )) P *= (c / P.sum (0 )).reshape((1 , -1 )) return P, np.sum (P * M)
参考