Skip to content

Commit 060d904

Browse files
committed
add cvx barycenter solver
1 parent be88177 commit 060d904

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed

ot/lp/cvx.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
LP solvers for optimal transport using cvxopt
4+
"""
5+
6+
# Author: Remi Flamary <remi.flamary@unice.fr>
7+
#
8+
# License: MIT License
9+
10+
import numpy as np
11+
import scipy as sp
12+
import scipy.sparse as sps
13+
14+
try:
15+
import cvxopt
16+
from cvxopt import solvers, matrix, sparse, spmatrix
17+
except ImportError:
18+
cvxopt=False
19+
20+
def scipy_sparse_to_spmatrix(A):
21+
"""Efficient conversion from scipy sparse matrix to cvxopt sparse matrix"""
22+
coo = A.tocoo()
23+
SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape)
24+
return SP
25+
26+
def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-point'):
27+
"""Compute the entropic regularized wasserstein barycenter of distributions A
28+
29+
The function solves the following optimization problem [16]:
30+
31+
.. math::
32+
\mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i)
33+
34+
where :
35+
36+
- :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn)
37+
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
38+
39+
The linear program is solved using the default cvxopt solver if installed.
40+
If cvxopt is not installed it uses the lp solver from scipy.optimize.
41+
42+
Parameters
43+
----------
44+
A : np.ndarray (d,n)
45+
n training distributions of size d
46+
M : np.ndarray (d,d)
47+
loss matrix for OT
48+
reg : float
49+
Regularization term >0
50+
weights : np.ndarray (n,)
51+
Weights of each histogram i_i on the simplex
52+
verbose : bool, optional
53+
Print information along iterations
54+
log : bool, optional
55+
record log if True
56+
solver : string, optional
57+
the solver used, default 'interior-point' use the lp solver from
58+
scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt.
59+
60+
Returns
61+
-------
62+
a : (d,) ndarray
63+
Wasserstein barycenter
64+
log : dict
65+
log dictionary return only if log==True in parameters
66+
67+
68+
References
69+
----------
70+
71+
.. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924.
72+
73+
74+
75+
"""
76+
77+
if weights is None:
78+
weights = np.ones(A.shape[1]) / A.shape[1]
79+
else:
80+
assert(len(weights) == A.shape[1])
81+
82+
n_distributions=A.shape[1]
83+
n=A.shape[0]
84+
85+
n2=n*n
86+
c=np.zeros((0))
87+
b_eq1=np.zeros((0))
88+
for i in range(n_distributions):
89+
c=np.concatenate((c,M.ravel()*weights[i]))
90+
b_eq1=np.concatenate((b_eq1,A[:,i]))
91+
c=np.concatenate((c,np.zeros(n)))
92+
93+
lst_idiag1=[sps.kron(sps.eye(n),np.ones((1,n))) for i in range(n_distributions)]
94+
# row constraints
95+
A_eq1=sps.hstack((sps.block_diag(lst_idiag1),sps.coo_matrix((n_distributions*n,n))))
96+
97+
# columns constraints
98+
lst_idiag2=[]
99+
lst_eye=[]
100+
for i in range(n_distributions):
101+
if i==0:
102+
lst_idiag2.append(sps.kron(np.ones((1,n)),sps.eye(n)))
103+
lst_eye.append(-sps.eye(n))
104+
else:
105+
lst_idiag2.append(sps.kron(np.ones((1,n)),sps.eye(n-1,n)))
106+
lst_eye.append(-sps.eye(n-1,n))
107+
108+
A_eq2=sps.hstack((sps.block_diag(lst_idiag2),sps.vstack(lst_eye)))
109+
b_eq2=np.zeros((A_eq2.shape[0]))
110+
111+
# full problem
112+
A_eq=sps.vstack((A_eq1,A_eq2))
113+
b_eq=np.concatenate((b_eq1,b_eq2))
114+
115+
if not cvxopt or solver in ['interior-point']: # cvxopt not installed or simplex/interior point
116+
117+
if solver is None:
118+
solver='interior-point'
119+
120+
options={'sparse':True,'disp': verbose}
121+
sol=sp.optimize.linprog(c,A_eq=A_eq,b_eq=b_eq,method=solver,options=options)
122+
x=sol.x
123+
b=x[-n:]
124+
125+
else:
126+
127+
h=np.zeros((n_distributions*n2+n))
128+
G=-sps.eye(n_distributions*n2+n)
129+
130+
sol=solvers.lp(matrix(c),scipy_sparse_to_spmatrix(G),matrix(h),A=scipy_sparse_to_spmatrix(A_eq),b=matrix(b_eq),solver=solver)
131+
132+
x=np.array(sol['x'])
133+
b=x[-n:].ravel()
134+
135+
if log:
136+
return b, sol
137+
else:
138+
return b

0 commit comments

Comments
 (0)