1515 import cvxopt
1616 from cvxopt import solvers , matrix , sparse , spmatrix
1717except ImportError :
18- cvxopt = False
18+ cvxopt = False
19+
1920
2021def scipy_sparse_to_spmatrix (A ):
2122 """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix"""
2223 coo = A .tocoo ()
2324 SP = spmatrix (coo .data .tolist (), coo .row .tolist (), coo .col .tolist (), size = A .shape )
2425 return SP
2526
26- def barycenter (A , M , weights = None , verbose = False , log = False ,solver = 'interior-point' ):
27+
28+ def barycenter (A , M , weights = None , verbose = False , log = False , solver = 'interior-point' ):
2729 """Compute the entropic regularized wasserstein barycenter of distributions A
2830
2931 The function solves the following optimization problem [16]:
@@ -36,7 +38,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi
3638 - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn)
3739 - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
3840
39- The linear program is solved using the default cvxopt solver if installed.
41+ The linear program is solved using the default cvxopt solver if installed.
4042 If cvxopt is not installed it uses the lp solver from scipy.optimize.
4143
4244 Parameters
@@ -48,13 +50,13 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi
4850 reg : float
4951 Regularization term >0
5052 weights : np.ndarray (n,)
51- Weights of each histogram i_i on the simplex
53+ Weights of each histogram i_i on the simplex
5254 verbose : bool, optional
5355 Print information along iterations
5456 log : bool, optional
5557 record log if True
5658 solver : string, optional
57- the solver used, default 'interior-point' use the lp solver from
59+ the solver used, default 'interior-point' use the lp solver from
5860 scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt.
5961
6062 Returns
@@ -78,61 +80,61 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi
7880 weights = np .ones (A .shape [1 ]) / A .shape [1 ]
7981 else :
8082 assert (len (weights ) == A .shape [1 ])
81-
82- n_distributions = A .shape [1 ]
83- n = A .shape [0 ]
84-
85- n2 = n * n
86- c = np .zeros ((0 ))
87- b_eq1 = np .zeros ((0 ))
83+
84+ n_distributions = A .shape [1 ]
85+ n = A .shape [0 ]
86+
87+ n2 = n * n
88+ c = np .zeros ((0 ))
89+ b_eq1 = np .zeros ((0 ))
8890 for i in range (n_distributions ):
89- c = np .concatenate ((c ,M .ravel ()* weights [i ]))
90- b_eq1 = np .concatenate ((b_eq1 ,A [:,i ]))
91- c = np .concatenate ((c ,np .zeros (n )))
92-
93- lst_idiag1 = [sps .kron (sps .eye (n ),np .ones ((1 ,n ))) for i in range (n_distributions )]
91+ c = np .concatenate ((c , M .ravel () * weights [i ]))
92+ b_eq1 = np .concatenate ((b_eq1 , A [:, i ]))
93+ c = np .concatenate ((c , np .zeros (n )))
94+
95+ lst_idiag1 = [sps .kron (sps .eye (n ), np .ones ((1 , n ))) for i in range (n_distributions )]
9496 # row constraints
95- A_eq1 = sps .hstack ((sps .block_diag (lst_idiag1 ),sps .coo_matrix ((n_distributions * n , n ))))
96-
97+ A_eq1 = sps .hstack ((sps .block_diag (lst_idiag1 ), sps .coo_matrix ((n_distributions * n , n ))))
98+
9799 # columns constraints
98- lst_idiag2 = []
99- lst_eye = []
100+ lst_idiag2 = []
101+ lst_eye = []
100102 for i in range (n_distributions ):
101- if i == 0 :
102- lst_idiag2 .append (sps .kron (np .ones ((1 ,n )),sps .eye (n )))
103+ if i == 0 :
104+ lst_idiag2 .append (sps .kron (np .ones ((1 , n )), sps .eye (n )))
103105 lst_eye .append (- sps .eye (n ))
104106 else :
105- lst_idiag2 .append (sps .kron (np .ones ((1 ,n )),sps .eye (n - 1 , n )))
106- lst_eye .append (- sps .eye (n - 1 , n ))
107-
108- A_eq2 = sps .hstack ((sps .block_diag (lst_idiag2 ),sps .vstack (lst_eye )))
109- b_eq2 = np .zeros ((A_eq2 .shape [0 ]))
110-
107+ lst_idiag2 .append (sps .kron (np .ones ((1 , n )), sps .eye (n - 1 , n )))
108+ lst_eye .append (- sps .eye (n - 1 , n ))
109+
110+ A_eq2 = sps .hstack ((sps .block_diag (lst_idiag2 ), sps .vstack (lst_eye )))
111+ b_eq2 = np .zeros ((A_eq2 .shape [0 ]))
112+
111113 # full problem
112- A_eq = sps .vstack ((A_eq1 ,A_eq2 ))
113- b_eq = np .concatenate ((b_eq1 ,b_eq2 ))
114-
115- if not cvxopt or solver in ['interior-point' ]: # cvxopt not installed or simplex/interior point
116-
114+ A_eq = sps .vstack ((A_eq1 , A_eq2 ))
115+ b_eq = np .concatenate ((b_eq1 , b_eq2 ))
116+
117+ if not cvxopt or solver in ['interior-point' ]: # cvxopt not installed or simplex/interior point
118+
117119 if solver is None :
118- solver = 'interior-point'
119-
120- options = {'sparse' :True ,'disp' : verbose }
121- sol = sp .optimize .linprog (c ,A_eq = A_eq ,b_eq = b_eq ,method = solver ,options = options )
122- x = sol .x
123- b = x [- n :]
124-
120+ solver = 'interior-point'
121+
122+ options = {'sparse' : True , 'disp' : verbose }
123+ sol = sp .optimize .linprog (c , A_eq = A_eq , b_eq = b_eq , method = solver , options = options )
124+ x = sol .x
125+ b = x [- n :]
126+
125127 else :
126-
127- h = np .zeros ((n_distributions * n2 + n ))
128- G = - sps .eye (n_distributions * n2 + n )
129-
130- sol = solvers .lp (matrix (c ),scipy_sparse_to_spmatrix (G ),matrix (h ),A = scipy_sparse_to_spmatrix (A_eq ),b = matrix (b_eq ),solver = solver )
131-
132- x = np .array (sol ['x' ])
133- b = x [- n :].ravel ()
134-
128+
129+ h = np .zeros ((n_distributions * n2 + n ))
130+ G = - sps .eye (n_distributions * n2 + n )
131+
132+ sol = solvers .lp (matrix (c ), scipy_sparse_to_spmatrix (G ), matrix (h ), A = scipy_sparse_to_spmatrix (A_eq ), b = matrix (b_eq ), solver = solver )
133+
134+ x = np .array (sol ['x' ])
135+ b = x [- n :].ravel ()
136+
135137 if log :
136138 return b , sol
137139 else :
138- return b
140+ return b
0 commit comments