
Commit f70aabf

pep8
1 parent 6484c9e commit f70aabf

2 files changed: +91 -92 lines changed

ot/gromov.py

Lines changed: 62 additions & 62 deletions
@@ -78,16 +78,16 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
 
     if loss_fun == 'square_loss':
         def f1(a):
-            return (a**2)
+            return (a**2)
 
         def f2(b):
-            return (b**2)
+            return (b**2)
 
         def h1(a):
             return a
 
         def h2(b):
-            return 2*b
+            return 2 * b
     elif loss_fun == 'kl_loss':
         def f1(a):
             return a * np.log(a + 1e-15) - a
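The f1/f2/h1/h2 helpers encode the loss decomposition L(a, b) = f1(a) + f2(b) - h1(a) * h2(b) that init_matrix builds on (the decomposition of Peyre et al., "Gromov-Wasserstein averaging of kernel and distance matrices"). A minimal standalone check of the square-loss identity, with made-up values:

import numpy as np

# square loss: (a - b)**2 == f1(a) + f2(b) - h1(a) * h2(b)
#                         == a**2  + b**2  - a * (2 * b)
a = np.array([0.5, 1.0, 2.0])
b = np.array([1.5, 0.0, 2.0])
assert np.allclose((a - b) ** 2, a ** 2 + b ** 2 - a * (2 * b))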
@@ -269,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs):
     return np.exp(np.divide(tmpsum, ppt))
 
 
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
     """
     Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
 
@@ -344,13 +344,14 @@ def df(G):
         return gwggrad(constC, hC1, hC2, G)
 
     if log:
-        res, log = cg(p, q, 0, 1, f, df, G0,log=True,amijo=amijo,C1=C1,C2=C2,constC=constC, **kwargs)
+        res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
         log['gw_dist'] = gwloss(constC, hC1, hC2, res)
         return res, log
     else:
-        return cg(p, q, 0, 1, f, df, G0,amijo=amijo, **kwargs)
+        return cg(p, q, 0, 1, f, df, G0, amijo=amijo, **kwargs)
 
-def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=False,**kwargs):
+
+def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, amijo=False, **kwargs):
     """
     Computes the FGW distance between two graphs see [3]
     .. math::
@@ -376,7 +377,7 @@ def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=
     q : ndarray, shape (nt,)
         distribution in the target space
     loss_fun : string,optionnal
-        loss function used for the solver
+        loss function used for the solver
     max_iter : int, optional
         Max number of iterations
     tol : float, optional
@@ -404,19 +405,20 @@ def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=
         International Conference on Machine Learning (ICML). 2019.
     """
 
-    constC,hC1,hC2=init_matrix(C1,C2,p,q,loss_fun)
-
-    G0=p[:,None]*q[None,:]
-
+    constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+    G0 = p[:, None] * q[None, :]
+
     def f(G):
-        return gwloss(constC,hC1,hC2,G)
+        return gwloss(constC, hC1, hC2, G)
+
     def df(G):
-        return gwggrad(constC,hC1,hC2,G)
-
-    return cg(p,q,M,alpha,f,df,G0,amijo=amijo,C1=C1,C2=C2,constC=constC,**kwargs)
+        return gwggrad(constC, hC1, hC2, G)
+
+    return cg(p, q, M, alpha, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
 
 
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False,amijo=False, **kwargs):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
     """
     Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
 
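For context, a hedged usage sketch of the solver reformatted above. The ot.gromov import path matches this branch; the toy data, the uniform weights, and the (1 - alpha) scaling of M (mirroring how fgw_barycenters below calls it) are illustrative assumptions:

import numpy as np
from ot.gromov import fused_gromov_wasserstein

rng = np.random.RandomState(0)
xs, xt = rng.randn(10, 2), rng.randn(12, 2)   # toy node embeddings
fs, ft = rng.randn(10, 3), rng.randn(12, 3)   # toy node features

# symmetric intra-graph structure matrices and cross-graph feature cost
C1 = np.linalg.norm(xs[:, None, :] - xs[None, :, :], axis=2)
C2 = np.linalg.norm(xt[:, None, :] - xt[None, :, :], axis=2)
M = np.linalg.norm(fs[:, None, :] - ft[None, :, :], axis=2)

p, q = np.ones(10) / 10, np.ones(12) / 12     # uniform node weights
alpha = 0.5                                    # feature/structure trade-off

T = fused_gromov_wasserstein((1 - alpha) * M, C1, C2, p, q,
                             loss_fun='square_loss', alpha=alpha)
print(T.shape)  # (10, 12) coupling between the two node sets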
@@ -485,7 +487,7 @@ def f(G):
 
     def df(G):
         return gwggrad(constC, hC1, hC2, G)
-    res, log = cg(p, q, 0, 1, f, df, G0, log=True,amijo=amijo,C1=C1,C2=C2,constC=constC, **kwargs)
+    res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
     log['gw_dist'] = gwloss(constC, hC1, hC2, res)
     log['T'] = res
     if log:
@@ -883,14 +885,14 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
 
     return C
 
-def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_features=False,
-                    p=None,loss_fun='square_loss',max_iter=100, tol=1e-9,
-                    verbose=False,log=True,init_C=None,init_X=None):
-
+
+def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
+                    p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
+                    verbose=False, log=True, init_C=None, init_X=None):
     """
     Compute the fgw barycenter as presented eq (5) in [3].
     ----------
-    N : integer
+    N : integer
         Desired number of samples of the target barycenter
     Ys: list of ndarray, each element has shape (ns,d)
         Features of all samples
@@ -906,9 +908,9 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
         Wether to fix the structure of the barycenter during the updates
     fixed_features : bool
         Wether to fix the feature of the barycenter during the updates
-    init_C : ndarray, shape (N,N), optional
+    init_C : ndarray, shape (N,N), optional
         initialization for the barycenters' structure matrix. If not set random init
-    init_X : ndarray, shape (N,d), optional
+    init_X : ndarray, shape (N,d), optional
         initialization for the barycenters' features. If not set random init
     Returns
     ----------
@@ -926,14 +928,14 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
         "Optimal Transport for structured data with application on graphs"
         International Conference on Machine Learning (ICML). 2019.
     """
-
+
     class UndefinedParameter(Exception):
         pass
-
+
     S = len(Cs)
-    d = Ys[0].shape[1] #dimension on the node features
+    d = Ys[0].shape[1]  # dimension on the node features
     if p is None:
-        p = np.ones(N)/N
+        p = np.ones(N) / N
 
     Cs = [np.asarray(Cs[s], dtype=np.float64) for s in range(S)]
     Ys = [np.asarray(Ys[s], dtype=np.float64) for s in range(S)]
@@ -944,7 +946,7 @@ class UndefinedParameter(Exception):
         if init_C is None:
             raise UndefinedParameter('If C is fixed it must be initialized')
         else:
-            C=init_C
+            C = init_C
     else:
         if init_C is None:
             xalea = np.random.randn(N, 2)
@@ -954,67 +956,67 @@ class UndefinedParameter(Exception):
 
     if fixed_features:
         if init_X is None:
-            raise UndefinedParameter('If X is fixed it must be initialized')
-        else :
-            X= init_X
+            raise UndefinedParameter('If X is fixed it must be initialized')
+        else:
+            X = init_X
     else:
-        if init_X is None:
-            X=np.zeros((N,d))
+        if init_X is None:
+            X = np.zeros((N, d))
         else:
             X = init_X
-
-    T=[np.outer(p,q) for q in ps]
+
+    T = [np.outer(p, q) for q in ps]
 
     # X is N,d
     # Ys is ns,d
-    Ms = [np.asarray(dist(X,Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+    Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
     # Ms is N,ns
 
     cpt = 0
     err_feature = 1
     err_structure = 1
 
     if log:
-        log_={}
-        log_['err_feature']=[]
-        log_['err_structure']=[]
-        log_['Ts_iter']=[]
+        log_ = {}
+        log_['err_feature'] = []
+        log_['err_structure'] = []
+        log_['Ts_iter'] = []
 
     while((err_feature > tol or err_structure > tol) and cpt < max_iter):
         Cprev = C
         Xprev = X
 
         if not fixed_features:
-            Ys_temp=[y.T for y in Ys]
-            X=update_feature_matrix(lambdas,Ys_temp,T,p).T
+            Ys_temp = [y.T for y in Ys]
+            X = update_feature_matrix(lambdas, Ys_temp, T, p).T
 
         # X must be N,d
         # Ys must be ns,d
-        Ms=[np.asarray(dist(X,Ys[s]), dtype=np.float64) for s in range(len(Ys))]
+        Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
 
         if not fixed_structure:
            if loss_fun == 'square_loss':
                 # T must be ns,N
                 # Cs must be ns,ns
                 # p must be N,1
-                T_temp=[t.T for t in T]
+                T_temp = [t.T for t in T]
                 C = update_sructure_matrix(p, lambdas, T_temp, Cs)
 
         # Ys must be d,ns
         # Ts must be N,ns
         # p must be N,1
         # Ms is N,ns
-        # C is N,N
+        # C is N,N
         # Cs is ns,ns
         # p is N,1
         # ps is ns,1
-
-        T = [fused_gromov_wasserstein((1-alpha)*Ms[s],C,Cs[s],p,ps[s],loss_fun,alpha,numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
 
-        # T is N,ns
+        T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
+
+        # T is N,ns
 
         log_['Ts_iter'].append(T)
-        err_feature = np.linalg.norm(X - Xprev.reshape(N,d))
+        err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
         err_structure = np.linalg.norm(C - Cprev)
 
         if log:
@@ -1029,11 +1031,11 @@ class UndefinedParameter(Exception):
                 print('{:5d}|{:8e}|'.format(cpt, err_feature))
 
         cpt += 1
-    log_['T']=T # from target to Ys
-    log_['p']=p
-    log_['Ms']=Ms #Ms are N,ns
+    log_['T'] = T  # from target to Ys
+    log_['p'] = p
+    log_['Ms'] = Ms  # Ms are N,ns
 
-    return X,C,log_
+    return X, C, log_
 
 
 def update_sructure_matrix(p, lambdas, T, Cs):
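A similarly hedged sketch of the barycenter entry point as it stands after this commit. The import path comes from this branch; the sizes, weights, and the symmetrization of the toy structure matrices (so that the default exact line search applies) are assumptions:

import numpy as np
from ot.gromov import fgw_barycenters

rng = np.random.RandomState(1)
sizes = (8, 10)
Ys = [rng.randn(n, 3) for n in sizes]           # node features, shape (ns, d)
Cs = [np.abs(rng.randn(n, n)) for n in sizes]
Cs = [(c + c.T) / 2 for c in Cs]                # symmetric structure matrices
ps = [np.ones(n) / n for n in sizes]            # node weights

X, C, log_ = fgw_barycenters(N=6, Ys=Ys, Cs=Cs, ps=ps,
                             lambdas=[0.5, 0.5], alpha=0.5)
print(X.shape, C.shape)  # (6, 3) barycenter features, (6, 6) structure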
@@ -1060,8 +1062,8 @@ def update_sructure_matrix(p, lambdas, T, Cs):
 
     return np.divide(tmpsum, ppt)
 
-def update_feature_matrix(lambdas,Ys,Ts,p):
-
+
+def update_feature_matrix(lambdas, Ys, Ts, p):
     """
     Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [3]
     calculated at each iteration
@@ -1078,7 +1080,7 @@ def update_feature_matrix(lambdas,Ys,Ts,p):
     Returns
     ----------
     X : ndarray, shape (d,N)
-
+
     References
     ----------
     .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
@@ -1087,10 +1089,8 @@ def update_feature_matrix(lambdas,Ys,Ts,p):
         International Conference on Machine Learning (ICML). 2019.
     """
 
-    p=np.diag(np.array(1/p).reshape(-1,))
+    p = np.diag(np.array(1 / p).reshape(-1,))
 
-    tmpsum = sum([lambdas[s] * np.dot(Ys[s],Ts[s].T).dot(p) for s in range(len(Ts))])
+    tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T).dot(p) for s in range(len(Ts))])
 
     return tmpsum
-
-
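The update above is the barycentric feature step: column k of the result is the lambdas-weighted average of the features that the couplings Ts transport onto barycenter node k, rescaled by 1/p. A standalone numeric sketch using the docstring's shapes and made-up values:

import numpy as np

d, N, ns = 3, 4, 6
rng = np.random.RandomState(2)
Ys = [rng.randn(d, ns)]                     # features, shape (d, ns)
Ts = [np.full((N, ns), 1.0 / (N * ns))]     # coupling, shape (N, ns)
p = np.ones(N) / N                          # barycenter weights
lambdas = [1.0]

# X = sum_s lambdas[s] * Ys[s] @ Ts[s].T @ diag(1 / p)
D = np.diag(1.0 / p)
X = sum(lambdas[s] * Ys[s].dot(Ts[s].T).dot(D) for s in range(len(Ts)))
print(X.shape)  # (d, N), matching the documented return shape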

ot/optim.py

Lines changed: 29 additions & 30 deletions
@@ -71,8 +71,9 @@ def phi(alpha1):
 
     return alpha, fc[0], phi1
 
-def do_linesearch(cost,G,deltaG,Mi,f_val,
-                  amijo=False,C1=None,C2=None,reg=None,Gc=None,constC=None,M=None):
+
+def do_linesearch(cost, G, deltaG, Mi, f_val,
+                  amijo=False, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
     """
     Solve the linesearch in the FW iterations
     Parameters
@@ -119,22 +120,22 @@ def do_linesearch(cost,G,deltaG,Mi,f_val,
     """
     if amijo:
         alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
-    else: # requires symetric matrices
-        dot1=np.dot(C1,deltaG)
-        dot12=dot1.dot(C2)
-        a=-2*reg*np.sum(dot12*deltaG)
-        b=np.sum((M+reg*constC)*deltaG)-2*reg*(np.sum(dot12*G)+np.sum(np.dot(C1,G).dot(C2)*deltaG))
-        c=cost(G)
+    else:  # requires symetric matrices
+        dot1 = np.dot(C1, deltaG)
+        dot12 = dot1.dot(C2)
+        a = -2 * reg * np.sum(dot12 * deltaG)
+        b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
+        c = cost(G)
+
+        alpha = solve_1d_linesearch_quad_funct(a, b, c)
+        fc = None
+        f_val = cost(G + alpha * deltaG)
 
-        alpha=solve_1d_linesearch_quad_funct(a,b,c)
-        fc=None
-        f_val=cost(G+alpha*deltaG)
-
-    return alpha,fc,f_val
+    return alpha, fc, f_val
 
 
 def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
-       stopThr=1e-9, verbose=False, log=False,**kwargs):
+       stopThr=1e-9, verbose=False, log=False, **kwargs):
     """
     Solve the general regularized OT problem with conditional gradient
 
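In the non-Armijo branch above, the coefficients a, b, c are exact: for the square-loss (F)GW objective, where hC1 = C1 and hC2 = 2 * C2, the cost restricted to the segment G + alpha * deltaG is the quadratic a * alpha**2 + b * alpha + c. A self-contained numeric check, with all matrices made up:

import numpy as np

rng = np.random.RandomState(0)
n, reg = 4, 0.5
C1, C2 = rng.rand(n, n), rng.rand(n, n)
M, constC = rng.rand(n, n), rng.rand(n, n)
G, Gc = rng.rand(n, n), rng.rand(n, n)
deltaG = Gc - G

def cost(G):
    # linear term plus the square-loss Gromov term (hC1 = C1, hC2 = 2 * C2)
    return np.sum((M + reg * constC) * G) - 2 * reg * np.sum(np.dot(C1, G).dot(C2) * G)

dot12 = np.dot(C1, deltaG).dot(C2)
a = -2 * reg * np.sum(dot12 * deltaG)
b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
c = cost(G)

for alpha in (0.0, 0.3, 1.0):
    assert np.isclose(cost(G + alpha * deltaG), a * alpha ** 2 + b * alpha + c)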
@@ -240,7 +241,7 @@ def cost(G):
         deltaG = Gc - G
 
         # line search
-        alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,**kwargs)
+        alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
 
         G = G + alpha * deltaG
 
@@ -403,11 +404,12 @@ def cost(G):
     else:
         return G
 
-def solve_1d_linesearch_quad_funct(a,b,c):
+
+def solve_1d_linesearch_quad_funct(a, b, c):
     """
-    Solve on 0,1 the following problem:
+    Solve on 0,1 the following problem:
     .. math::
-        \min f(x)=a*x^{2}+b*x+c
+        \min f(x)=a*x^{2}+b*x+c
 
     Parameters
     ----------
@@ -416,22 +418,19 @@ def solve_1d_linesearch_quad_funct(a,b,c):
 
     Returns
     -------
-    x : float
+    x : float
         The optimal value which leads to the minimal cost
-
+
     """
-    f0=c
-    df0=b
-    f1=a+f0+df0
+    f0 = c
+    df0 = b
+    f1 = a + f0 + df0
 
-    if a>0: # convex
-        minimum=min(1,max(0,-b/(2*a)))
-        #print('entrelesdeux')
+    if a > 0:  # convex
+        minimum = min(1, max(0, -b / (2 * a)))
         return minimum
-    else: # non convexe donc sur les coins
-        if f0>f1:
-            #print('sur1 f(1)={}'.format(f(1)))
+    else:  # non convex
+        if f0 > f1:
             return 1
         else:
-            #print('sur0 f(0)={}'.format(f(0)))
             return 0
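The rewritten helper picks the minimizer of a * x**2 + b * x + c on [0, 1] in closed form: clip the vertex -b / (2 * a) to the interval when the parabola is convex, otherwise take the cheaper endpoint. A quick brute-force cross-check of that logic (a standalone re-implementation, not an import from the module):

import numpy as np

def solve_quad_on_unit_interval(a, b, c):
    # mirrors solve_1d_linesearch_quad_funct above
    f0, df0 = c, b                # f(0) and f'(0)
    f1 = a + f0 + df0             # f(1)
    if a > 0:                     # convex: clip the vertex to [0, 1]
        return min(1, max(0, -b / (2 * a)))
    return 1 if f0 > f1 else 0    # concave or linear: best endpoint

for a, b, c in [(2.0, -3.0, 1.0), (-1.0, 0.5, 0.0), (0.0, 1.0, 2.0)]:
    xs = np.linspace(0, 1, 10001)
    brute = xs[np.argmin(a * xs ** 2 + b * xs + c)]
    assert abs(solve_quad_on_unit_interval(a, b, c) - brute) < 1e-3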
