Sinkhorn Algorithm

Sinkhorn 解决最优传输问题,即把一个概率分布以最小代价转换成另外一个分布。本文是关于 Sinkhorn 相关网络资料的整理。

Sinkhorn 算法

1 介绍

Sinkhorn 解决最优传输问题,即把一个概率分布以最小代价转换成另外一个分布。举例说明:

  • 现有 5 种小吃:merveilleux、eclair、chocolate mousse、bavarois、carrot cake,每种小吃数量如下表:

    小吃 merveilleux eclair chocolate mousse bavarois carrot cake
    数量 4 2 6 4 4
  • 要将这些小吃分享给 8 个人,Bernard、Jan、Willem、Hilde、Steffie、Marlies、Tim、Wouter,每个人对小吃的需求量如下表

    名字 Bernard Jan Willem Hilde Steffie Marlies Tim Wouter
    需求量 3 3 3 4 2 2 2 1
  • 每个人对各小吃的喜爱程度如下表所示,其中 2-2 分表示非常不喜欢,22 分表示非常喜欢:

    merveilleux eclair chocolate mousse bavarois carrot cake
    Bernard 2 2 1 0 0
    Jan 0 -2 -2 -2 2
    Willem 1 2 2 2 -1
    Hilde 2 1 0 1 -1
    Steffie 0.5 2 2 1 0
    Marlies 0 1 1 1 -1
    Tim -2 2 2 1 1
    Wouter 2 1 2 1 -1
  • 任务是将这些小吃分发给这 8 个人,同时使得大家满意度最高。

2 问题的数学表达

  • r\boldsymbol{r} 为每个人对小吃的需求量,r=[3,3,3,4,2,2,2,1]T\boldsymbol{r}=[3,3,3,4,2,2,2,1]^T,设 r\boldsymbol{r} 的维度为 nn,此处 n=8n=8
  • c\boldsymbol{c} 为每种小吃的数量,c=[4,2,6,4,4]T\boldsymbol{c}=[4,2,6,4,4]^T,设 c\boldsymbol{c} 的维度为 mm,此处 m=5m=5

一般地,r\boldsymbol{r}c\boldsymbol{c}(即从一个概率分布转为另一个概率分布的两个概率分布)表示边缘分布,因此 r\boldsymbol{r}c\boldsymbol{c} 中元素和为 1\boldsymbol{1},这需要在后续过程中加以处理。

  • 定义

    U(r,c)={PR+n×mP1m=r,PT1n=c}U(\pmb{r}, \pmb{c}) = \{ \pmb{P} \in \mathbb{R}_{+}^{n \times m} | \pmb{P} \pmb{1}_m = \pmb{r}, \pmb{P}^T \pmb{1}_n = \pmb{c} \}

    U(r,c)U(\boldsymbol{r}, \boldsymbol{c}) 包含所有可能的小吃分配方案,注意,此处每份小吃可以随意切割进行分配。

  • 每个用户对各小吃的喜爱程度存储在矩阵 MRn×m\boldsymbol{M} \in \mathbb{R}^{n \times m} 中(M\boldsymbol{M} 在一些文献中也被称为代价矩阵,本例中将矩阵 M\boldsymbol{M} 中的元素取负即可得到其对应的代价矩阵)。

最终问题可以表示为:

dM(r,c)=minPU(r,c)PijMijd_{\pmb{M}}(\pmb{r}, \pmb{c}) = \min_{\pmb{P} \in U(\pmb{r}, \pmb{c})} \sum P_{ij} M_{ij}

其中 dM(r,c)d_{\boldsymbol{M}} (\boldsymbol{r}, \boldsymbol{c}) 又被称为 Wasserstein 距离。

上述问题还可添加正则项,使得问题的描述更加合理,即:

dMλ(r,c)=minPU(r,c)PijMij+1λ(PijlogPij)d_{\pmb{M}}^\lambda (\pmb{r}, \pmb{c}) = \min_{\pmb{P} \in U(\pmb{r}, \pmb{c})} \sum P_{ij} M_{ij} + \frac{1}{\lambda} (-P_{ij} \log P_{ij})

dMλ(r,c)d_{\boldsymbol{M}}^\lambda (\boldsymbol{r}, \boldsymbol{c}) 又被称为 Sinkhorn 距离。

3 Sinkhorn 算法

上述问题最优解可表示为:

Pij=αiβjeλMijP_{ij} = \alpha_i \beta_j e^{-\lambda M_{ij}}

其中 αi\alpha_iβj\beta_j 为待求解的常数,具体算法为:

1.given: M,r,c,λ2.initialize: Pλ=eλM3.repeat1. scale the rows such that the rows sums match r2. scale the columns such that the column sums match c4.until convergence\begin{aligned} 1&. \mathrm{\pmb{given:}} \ M, \pmb{r}, \pmb{c}, \lambda \\ 2&. \mathrm{\pmb{initialize:}} \ P_{\lambda} = e^{-\lambda M} \\ 3&. \mathrm{\pmb{repeat}} \\ &| 1.\ \mathrm{scale\ the\ rows\ such\ that\ the\ rows\ sums\ match}\ r \\ &| 2.\ \mathrm{scale\ the\ columns\ such\ that\ the\ column\ sums\ match}\ c \\ 4&. \mathrm{\pmb{until}\ convergence} \end{aligned}

4 Python 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def compute_optimal_transport(M, r, c, lam, epsilon=1e-8):
"""
Computes the optimal transport matrix and Slinkhorn distance using the
Sinkhorn-Knopp algorithm
Inputs:
- M : cost matrix (n x m)
- r : vector of marginals (n, )
- c : vector of marginals (m, )
- lam : strength of the entropic regularization
- epsilon : convergence parameter
Outputs:
- P : optimal transport matrix (n x m)
- dist : Sinkhorn distance
"""
n, m = M.shape
P = np.exp(- lam * M)
P /= P.sum()
u = np.zeros(n)
# normalize this matrix
while np.max(np.abs(u - P.sum(1))) > epsilon:
u = P.sum(1)
P *= (r / u).reshape((-1, 1)) # 行归 r 化,注意python中*号含义
P *= (c / P.sum(0)).reshape((1, -1)) # 列归 c 化
return P, np.sum(P * M)

参考