@@ -143,3 +143,78 @@ def test_gromov_entropic_barycenter():
143143 'kl_loss' , 2e-3 ,
144144 max_iter = 100 , tol = 1e-3 )
145145 np .testing .assert_allclose (Cb2 .shape , (n_samples , n_samples ))
146+
147+ def test_fgw ():
148+ n_samples = 50 # nb samples
149+
150+ mu_s = np .array ([0 , 0 ])
151+ cov_s = np .array ([[1 , 0 ], [0 , 1 ]])
152+
153+ xs = ot .datasets .make_2D_samples_gauss (n_samples , mu_s , cov_s )
154+
155+ xt = xs [::- 1 ].copy ()
156+
157+ ys = np .random .randn (xs .shape [0 ],2 )
158+ yt = ys [::- 1 ].copy ()
159+
160+ p = ot .unif (n_samples )
161+ q = ot .unif (n_samples )
162+
163+ C1 = ot .dist (xs , xs )
164+ C2 = ot .dist (xt , xt )
165+
166+ C1 /= C1 .max ()
167+ C2 /= C2 .max ()
168+
169+ M = ot .dist (ys ,yt )
170+ M /= M .max ()
171+
172+ G = ot .gromov .fused_gromov_wasserstein (M ,C1 , C2 , p , q , 'square_loss' ,alpha = 0.5 )
173+
174+ # check constratints
175+ np .testing .assert_allclose (
176+ p , G .sum (1 ), atol = 1e-04 ) # cf convergence fgw
177+ np .testing .assert_allclose (
178+ q , G .sum (0 ), atol = 1e-04 ) # cf convergence fgw
179+
180+
181+ def test_fgw_barycenter ():
182+
183+ ns = 50
184+ nt = 60
185+
186+ Xs , ys = ot .datasets .make_data_classif ('3gauss' , ns )
187+ Xt , yt = ot .datasets .make_data_classif ('3gauss2' , nt )
188+
189+ ys = np .random .randn (Xs .shape [0 ],2 )
190+ yt = np .random .randn (Xt .shape [0 ],2 )
191+
192+ C1 = ot .dist (Xs )
193+ C2 = ot .dist (Xt )
194+
195+ n_samples = 3
196+ X ,C ,log = ot .gromov .fgw_barycenters (n_samples ,[ys ,yt ] ,[C1 , C2 ],[ot .unif (ns ), ot .unif (nt )],[.5 , .5 ],0.5 ,
197+ fixed_structure = False ,fixed_features = False ,
198+ p = ot .unif (n_samples ),loss_fun = 'square_loss' ,
199+ max_iter = 100 , tol = 1e-3 )
200+ np .testing .assert_allclose (C .shape , (n_samples , n_samples ))
201+ np .testing .assert_allclose (X .shape , (n_samples , ys .shape [1 ]))
202+
203+ xalea = np .random .randn (n_samples , 2 )
204+ init_C = ot .dist (xalea , xalea )
205+
206+ X ,C ,log = ot .gromov .fgw_barycenters (n_samples ,[ys ,yt ] ,[C1 , C2 ],ps = [ot .unif (ns ), ot .unif (nt )],lambdas = [.5 , .5 ],alpha = 0.5 ,
207+ fixed_structure = True ,init_C = init_C ,fixed_features = False ,
208+ p = ot .unif (n_samples ),loss_fun = 'square_loss' ,
209+ max_iter = 100 , tol = 1e-3 )
210+ np .testing .assert_allclose (C .shape , (n_samples , n_samples ))
211+ np .testing .assert_allclose (X .shape , (n_samples , ys .shape [1 ]))
212+
213+ init_X = np .random .randn (n_samples ,ys .shape [1 ])
214+
215+ X ,C ,log = ot .gromov .fgw_barycenters (n_samples ,[ys ,yt ] ,[C1 , C2 ],[ot .unif (ns ), ot .unif (nt )],[.5 , .5 ],0.5 ,
216+ fixed_structure = False ,fixed_features = True , init_X = init_X ,
217+ p = ot .unif (n_samples ),loss_fun = 'square_loss' ,
218+ max_iter = 100 , tol = 1e-3 )
219+ np .testing .assert_allclose (C .shape , (n_samples , n_samples ))
220+ np .testing .assert_allclose (X .shape , (n_samples , ys .shape [1 ]))
0 commit comments