Skip to content

Commit 3aee908

Browse files
committed
pep8
1 parent 060d904 commit 3aee908

File tree

2 files changed

+54
-52
lines changed

2 files changed

+54
-52
lines changed

ot/bregman.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,7 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
845845
reg : float
846846
Regularization term >0
847847
weights : np.ndarray (n,)
848-
Weights of each histogram i_i on the simplex
848+
Weights of each histogram i_i on the simplex
849849
numItermax : int, optional
850850
Max number of iterations
851851
stopThr : float, optional

ot/lp/cvx.py

Lines changed: 53 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@
1515
import cvxopt
1616
from cvxopt import solvers, matrix, sparse, spmatrix
1717
except ImportError:
18-
cvxopt=False
18+
cvxopt = False
19+
1920

2021
def scipy_sparse_to_spmatrix(A):
2122
"""Efficient conversion from scipy sparse matrix to cvxopt sparse matrix"""
2223
coo = A.tocoo()
2324
SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape)
2425
return SP
2526

26-
def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-point'):
27+
28+
def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
2729
"""Compute the entropic regularized wasserstein barycenter of distributions A
2830
2931
The function solves the following optimization problem [16]:
@@ -36,7 +38,7 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi
3638
- :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn)
3739
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
3840
39-
The linear program is solved using the default cvxopt solver if installed.
41+
The linear program is solved using the default cvxopt solver if installed.
4042
If cvxopt is not installed it uses the lp solver from scipy.optimize.
4143
4244
Parameters
@@ -48,13 +50,13 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi
4850
reg : float
4951
Regularization term >0
5052
weights : np.ndarray (n,)
51-
Weights of each histogram i_i on the simplex
53+
Weights of each histogram i_i on the simplex
5254
verbose : bool, optional
5355
Print information along iterations
5456
log : bool, optional
5557
record log if True
5658
solver : string, optional
57-
the solver used, default 'interior-point' use the lp solver from
59+
the solver used, default 'interior-point' use the lp solver from
5860
scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt.
5961
6062
Returns
@@ -78,61 +80,61 @@ def barycenter(A, M, weights=None, verbose=False, log=False,solver='interior-poi
7880
weights = np.ones(A.shape[1]) / A.shape[1]
7981
else:
8082
assert(len(weights) == A.shape[1])
81-
82-
n_distributions=A.shape[1]
83-
n=A.shape[0]
84-
85-
n2=n*n
86-
c=np.zeros((0))
87-
b_eq1=np.zeros((0))
83+
84+
n_distributions = A.shape[1]
85+
n = A.shape[0]
86+
87+
n2 = n * n
88+
c = np.zeros((0))
89+
b_eq1 = np.zeros((0))
8890
for i in range(n_distributions):
89-
c=np.concatenate((c,M.ravel()*weights[i]))
90-
b_eq1=np.concatenate((b_eq1,A[:,i]))
91-
c=np.concatenate((c,np.zeros(n)))
92-
93-
lst_idiag1=[sps.kron(sps.eye(n),np.ones((1,n))) for i in range(n_distributions)]
91+
c = np.concatenate((c, M.ravel() * weights[i]))
92+
b_eq1 = np.concatenate((b_eq1, A[:, i]))
93+
c = np.concatenate((c, np.zeros(n)))
94+
95+
lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)]
9496
# row constraints
95-
A_eq1=sps.hstack((sps.block_diag(lst_idiag1),sps.coo_matrix((n_distributions*n,n))))
96-
97+
A_eq1 = sps.hstack((sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n))))
98+
9799
# columns constraints
98-
lst_idiag2=[]
99-
lst_eye=[]
100+
lst_idiag2 = []
101+
lst_eye = []
100102
for i in range(n_distributions):
101-
if i==0:
102-
lst_idiag2.append(sps.kron(np.ones((1,n)),sps.eye(n)))
103+
if i == 0:
104+
lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n)))
103105
lst_eye.append(-sps.eye(n))
104106
else:
105-
lst_idiag2.append(sps.kron(np.ones((1,n)),sps.eye(n-1,n)))
106-
lst_eye.append(-sps.eye(n-1,n))
107-
108-
A_eq2=sps.hstack((sps.block_diag(lst_idiag2),sps.vstack(lst_eye)))
109-
b_eq2=np.zeros((A_eq2.shape[0]))
110-
107+
lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n)))
108+
lst_eye.append(-sps.eye(n - 1, n))
109+
110+
A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye)))
111+
b_eq2 = np.zeros((A_eq2.shape[0]))
112+
111113
# full problem
112-
A_eq=sps.vstack((A_eq1,A_eq2))
113-
b_eq=np.concatenate((b_eq1,b_eq2))
114-
115-
if not cvxopt or solver in ['interior-point']: # cvxopt not installed or simplex/interior point
116-
114+
A_eq = sps.vstack((A_eq1, A_eq2))
115+
b_eq = np.concatenate((b_eq1, b_eq2))
116+
117+
if not cvxopt or solver in ['interior-point']: # cvxopt not installed or simplex/interior point
118+
117119
if solver is None:
118-
solver='interior-point'
119-
120-
options={'sparse':True,'disp': verbose}
121-
sol=sp.optimize.linprog(c,A_eq=A_eq,b_eq=b_eq,method=solver,options=options)
122-
x=sol.x
123-
b=x[-n:]
124-
120+
solver = 'interior-point'
121+
122+
options = {'sparse': True, 'disp': verbose}
123+
sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options)
124+
x = sol.x
125+
b = x[-n:]
126+
125127
else:
126-
127-
h=np.zeros((n_distributions*n2+n))
128-
G=-sps.eye(n_distributions*n2+n)
129-
130-
sol=solvers.lp(matrix(c),scipy_sparse_to_spmatrix(G),matrix(h),A=scipy_sparse_to_spmatrix(A_eq),b=matrix(b_eq),solver=solver)
131-
132-
x=np.array(sol['x'])
133-
b=x[-n:].ravel()
134-
128+
129+
h = np.zeros((n_distributions * n2 + n))
130+
G = -sps.eye(n_distributions * n2 + n)
131+
132+
sol = solvers.lp(matrix(c), scipy_sparse_to_spmatrix(G), matrix(h), A=scipy_sparse_to_spmatrix(A_eq), b=matrix(b_eq), solver=solver)
133+
134+
x = np.array(sol['x'])
135+
b = x[-n:].ravel()
136+
135137
if log:
136138
return b, sol
137139
else:
138-
return b
140+
return b

0 commit comments

Comments
 (0)