@@ -78,16 +78,16 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
7878
7979 if loss_fun == 'square_loss' :
8080 def f1 (a ):
81- return (a ** 2 )
81+ return (a ** 2 )
8282
8383 def f2 (b ):
84- return (b ** 2 )
84+ return (b ** 2 )
8585
8686 def h1 (a ):
8787 return a
8888
8989 def h2 (b ):
90- return 2 * b
90+ return 2 * b
9191 elif loss_fun == 'kl_loss' :
9292 def f1 (a ):
9393 return a * np .log (a + 1e-15 ) - a
@@ -269,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs):
269269 return np .exp (np .divide (tmpsum , ppt ))
270270
271271
272- def gromov_wasserstein (C1 , C2 , p , q , loss_fun , log = False ,amijo = False , ** kwargs ):
272+ def gromov_wasserstein (C1 , C2 , p , q , loss_fun , log = False , amijo = False , ** kwargs ):
273273 """
274274 Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
275275
@@ -344,13 +344,14 @@ def df(G):
344344 return gwggrad (constC , hC1 , hC2 , G )
345345
346346 if log :
347- res , log = cg (p , q , 0 , 1 , f , df , G0 ,log = True ,amijo = amijo ,C1 = C1 ,C2 = C2 ,constC = constC , ** kwargs )
347+ res , log = cg (p , q , 0 , 1 , f , df , G0 , log = True , amijo = amijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
348348 log ['gw_dist' ] = gwloss (constC , hC1 , hC2 , res )
349349 return res , log
350350 else :
351- return cg (p , q , 0 , 1 , f , df , G0 ,amijo = amijo , ** kwargs )
351+ return cg (p , q , 0 , 1 , f , df , G0 , amijo = amijo , ** kwargs )
352352
353- def fused_gromov_wasserstein (M ,C1 ,C2 ,p ,q ,loss_fun = 'square_loss' ,alpha = 0.5 ,amijo = False ,** kwargs ):
353+
354+ def fused_gromov_wasserstein (M , C1 , C2 , p , q , loss_fun = 'square_loss' , alpha = 0.5 , amijo = False , ** kwargs ):
354355 """
355356 Computes the FGW distance between two graphs see [3]
356357 .. math::
@@ -376,7 +377,7 @@ def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=
376377 q : ndarray, shape (nt,)
377378 distribution in the target space
378379 loss_fun : string,optionnal
379- loss function used for the solver
380+ loss function used for the solver
380381 max_iter : int, optional
381382 Max number of iterations
382383 tol : float, optional
@@ -404,19 +405,20 @@ def fused_gromov_wasserstein(M,C1,C2,p,q,loss_fun='square_loss',alpha=0.5,amijo=
404405 International Conference on Machine Learning (ICML). 2019.
405406 """
406407
407- constC ,hC1 ,hC2 = init_matrix (C1 ,C2 ,p , q , loss_fun )
408-
409- G0 = p [:,None ]* q [None ,:]
410-
408+ constC , hC1 , hC2 = init_matrix (C1 , C2 , p , q , loss_fun )
409+
410+ G0 = p [:, None ] * q [None , :]
411+
411412 def f (G ):
412- return gwloss (constC ,hC1 ,hC2 ,G )
413+ return gwloss (constC , hC1 , hC2 , G )
414+
413415 def df (G ):
414- return gwggrad (constC ,hC1 ,hC2 ,G )
415-
416- return cg (p ,q , M , alpha ,f , df ,G0 ,amijo = amijo ,C1 = C1 ,C2 = C2 ,constC = constC ,** kwargs )
416+ return gwggrad (constC , hC1 , hC2 , G )
417+
418+ return cg (p , q , M , alpha , f , df , G0 , amijo = amijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
417419
418420
419- def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , log = False ,amijo = False , ** kwargs ):
421+ def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , log = False , amijo = False , ** kwargs ):
420422 """
421423 Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
422424
@@ -485,7 +487,7 @@ def f(G):
485487
486488 def df (G ):
487489 return gwggrad (constC , hC1 , hC2 , G )
488- res , log = cg (p , q , 0 , 1 , f , df , G0 , log = True ,amijo = amijo ,C1 = C1 ,C2 = C2 ,constC = constC , ** kwargs )
490+ res , log = cg (p , q , 0 , 1 , f , df , G0 , log = True , amijo = amijo , C1 = C1 , C2 = C2 , constC = constC , ** kwargs )
489491 log ['gw_dist' ] = gwloss (constC , hC1 , hC2 , res )
490492 log ['T' ] = res
491493 if log :
@@ -883,14 +885,14 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
883885
884886 return C
885887
886- def fgw_barycenters ( N , Ys , Cs , ps , lambdas , alpha , fixed_structure = False , fixed_features = False ,
887- p = None , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-9 ,
888- verbose = False , log = True , init_C = None , init_X = None ):
889-
888+
889+ def fgw_barycenters ( N , Ys , Cs , ps , lambdas , alpha , fixed_structure = False , fixed_features = False ,
890+ p = None , loss_fun = 'square_loss' , max_iter = 100 , tol = 1e-9 ,
891+ verbose = False , log = True , init_C = None , init_X = None ):
890892 """
891893 Compute the fgw barycenter as presented eq (5) in [3].
892894 ----------
893- N : integer
895+ N : integer
894896 Desired number of samples of the target barycenter
895897 Ys: list of ndarray, each element has shape (ns,d)
896898 Features of all samples
@@ -906,9 +908,9 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
906908 Wether to fix the structure of the barycenter during the updates
907909 fixed_features : bool
908910 Wether to fix the feature of the barycenter during the updates
909- init_C : ndarray, shape (N,N), optional
911+ init_C : ndarray, shape (N,N), optional
910912 initialization for the barycenters' structure matrix. If not set random init
911- init_X : ndarray, shape (N,d), optional
913+ init_X : ndarray, shape (N,d), optional
912914 initialization for the barycenters' features. If not set random init
913915 Returns
914916 ----------
@@ -926,14 +928,14 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
926928 "Optimal Transport for structured data with application on graphs"
927929 International Conference on Machine Learning (ICML). 2019.
928930 """
929-
931+
930932 class UndefinedParameter (Exception ):
931933 pass
932-
934+
933935 S = len (Cs )
934- d = Ys [0 ].shape [1 ] # dimension on the node features
936+ d = Ys [0 ].shape [1 ] # dimension on the node features
935937 if p is None :
936- p = np .ones (N )/ N
938+ p = np .ones (N ) / N
937939
938940 Cs = [np .asarray (Cs [s ], dtype = np .float64 ) for s in range (S )]
939941 Ys = [np .asarray (Ys [s ], dtype = np .float64 ) for s in range (S )]
@@ -944,7 +946,7 @@ class UndefinedParameter(Exception):
944946 if init_C is None :
945947 raise UndefinedParameter ('If C is fixed it must be initialized' )
946948 else :
947- C = init_C
949+ C = init_C
948950 else :
949951 if init_C is None :
950952 xalea = np .random .randn (N , 2 )
@@ -954,67 +956,67 @@ class UndefinedParameter(Exception):
954956
955957 if fixed_features :
956958 if init_X is None :
957- raise UndefinedParameter ('If X is fixed it must be initialized' )
958- else :
959- X = init_X
959+ raise UndefinedParameter ('If X is fixed it must be initialized' )
960+ else :
961+ X = init_X
960962 else :
961- if init_X is None :
962- X = np .zeros ((N ,d ))
963+ if init_X is None :
964+ X = np .zeros ((N , d ))
963965 else :
964966 X = init_X
965-
966- T = [np .outer (p ,q ) for q in ps ]
967+
968+ T = [np .outer (p , q ) for q in ps ]
967969
968970 # X is N,d
969971 # Ys is ns,d
970- Ms = [np .asarray (dist (X ,Ys [s ]), dtype = np .float64 ) for s in range (len (Ys ))]
972+ Ms = [np .asarray (dist (X , Ys [s ]), dtype = np .float64 ) for s in range (len (Ys ))]
971973 # Ms is N,ns
972974
973975 cpt = 0
974976 err_feature = 1
975977 err_structure = 1
976978
977979 if log :
978- log_ = {}
979- log_ ['err_feature' ]= []
980- log_ ['err_structure' ]= []
981- log_ ['Ts_iter' ]= []
980+ log_ = {}
981+ log_ ['err_feature' ] = []
982+ log_ ['err_structure' ] = []
983+ log_ ['Ts_iter' ] = []
982984
983985 while ((err_feature > tol or err_structure > tol ) and cpt < max_iter ):
984986 Cprev = C
985987 Xprev = X
986988
987989 if not fixed_features :
988- Ys_temp = [y .T for y in Ys ]
989- X = update_feature_matrix (lambdas ,Ys_temp ,T , p ).T
990+ Ys_temp = [y .T for y in Ys ]
991+ X = update_feature_matrix (lambdas , Ys_temp , T , p ).T
990992
991993 # X must be N,d
992994 # Ys must be ns,d
993- Ms = [np .asarray (dist (X ,Ys [s ]), dtype = np .float64 ) for s in range (len (Ys ))]
995+ Ms = [np .asarray (dist (X , Ys [s ]), dtype = np .float64 ) for s in range (len (Ys ))]
994996
995997 if not fixed_structure :
996998 if loss_fun == 'square_loss' :
997999 # T must be ns,N
9981000 # Cs must be ns,ns
9991001 # p must be N,1
1000- T_temp = [t .T for t in T ]
1002+ T_temp = [t .T for t in T ]
10011003 C = update_sructure_matrix (p , lambdas , T_temp , Cs )
10021004
10031005 # Ys must be d,ns
10041006 # Ts must be N,ns
10051007 # p must be N,1
10061008 # Ms is N,ns
1007- # C is N,N
1009+ # C is N,N
10081010 # Cs is ns,ns
10091011 # p is N,1
10101012 # ps is ns,1
1011-
1012- T = [fused_gromov_wasserstein ((1 - alpha )* Ms [s ],C ,Cs [s ],p ,ps [s ],loss_fun ,alpha ,numItermax = max_iter , stopThr = 1e-5 , verbose = verbose ) for s in range (S )]
10131013
1014- # T is N,ns
1014+ T = [fused_gromov_wasserstein ((1 - alpha ) * Ms [s ], C , Cs [s ], p , ps [s ], loss_fun , alpha , numItermax = max_iter , stopThr = 1e-5 , verbose = verbose ) for s in range (S )]
1015+
1016+ # T is N,ns
10151017
10161018 log_ ['Ts_iter' ].append (T )
1017- err_feature = np .linalg .norm (X - Xprev .reshape (N ,d ))
1019+ err_feature = np .linalg .norm (X - Xprev .reshape (N , d ))
10181020 err_structure = np .linalg .norm (C - Cprev )
10191021
10201022 if log :
@@ -1029,11 +1031,11 @@ class UndefinedParameter(Exception):
10291031 print ('{:5d}|{:8e}|' .format (cpt , err_feature ))
10301032
10311033 cpt += 1
1032- log_ ['T' ]= T # from target to Ys
1033- log_ ['p' ]= p
1034- log_ ['Ms' ]= Ms # Ms are N,ns
1034+ log_ ['T' ] = T # from target to Ys
1035+ log_ ['p' ] = p
1036+ log_ ['Ms' ] = Ms # Ms are N,ns
10351037
1036- return X ,C , log_
1038+ return X , C , log_
10371039
10381040
10391041def update_sructure_matrix (p , lambdas , T , Cs ):
@@ -1060,8 +1062,8 @@ def update_sructure_matrix(p, lambdas, T, Cs):
10601062
10611063 return np .divide (tmpsum , ppt )
10621064
1063- def update_feature_matrix ( lambdas , Ys , Ts , p ):
1064-
1065+
1066+ def update_feature_matrix ( lambdas , Ys , Ts , p ):
10651067 """
10661068 Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [3]
10671069 calculated at each iteration
@@ -1078,7 +1080,7 @@ def update_feature_matrix(lambdas,Ys,Ts,p):
10781080 Returns
10791081 ----------
10801082 X : ndarray, shape (d,N)
1081-
1083+
10821084 References
10831085 ----------
10841086 .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\' e}mi, Tavenard Romain
@@ -1087,10 +1089,8 @@ def update_feature_matrix(lambdas,Ys,Ts,p):
10871089 International Conference on Machine Learning (ICML). 2019.
10881090 """
10891091
1090- p = np .diag (np .array (1 / p ).reshape (- 1 ,))
1092+ p = np .diag (np .array (1 / p ).reshape (- 1 ,))
10911093
1092- tmpsum = sum ([lambdas [s ] * np .dot (Ys [s ],Ts [s ].T ).dot (p ) for s in range (len (Ts ))])
1094+ tmpsum = sum ([lambdas [s ] * np .dot (Ys [s ], Ts [s ].T ).dot (p ) for s in range (len (Ts ))])
10931095
10941096 return tmpsum
1095-
1096-
0 commit comments