88from numpy .testing .utils import assert_allclose , assert_equal
99
1010import ot
11- from ot .datasets import get_data_classif
11+ from ot .datasets import make_data_classif
1212from ot .utils import unif
1313
1414
@@ -19,8 +19,8 @@ def test_sinkhorn_lpl1_transport_class():
1919 ns = 150
2020 nt = 200
2121
22- Xs , ys = get_data_classif ('3gauss' , ns )
23- Xt , yt = get_data_classif ('3gauss2' , nt )
22+ Xs , ys = make_data_classif ('3gauss' , ns )
23+ Xt , yt = make_data_classif ('3gauss2' , nt )
2424
2525 otda = ot .da .SinkhornLpl1Transport ()
2626
@@ -45,7 +45,7 @@ def test_sinkhorn_lpl1_transport_class():
4545 transp_Xs = otda .transform (Xs = Xs )
4646 assert_equal (transp_Xs .shape , Xs .shape )
4747
48- Xs_new , _ = get_data_classif ('3gauss' , ns + 1 )
48+ Xs_new , _ = make_data_classif ('3gauss' , ns + 1 )
4949 transp_Xs_new = otda .transform (Xs_new )
5050
5151 # check that the oos method is working
@@ -55,7 +55,7 @@ def test_sinkhorn_lpl1_transport_class():
5555 transp_Xt = otda .inverse_transform (Xt = Xt )
5656 assert_equal (transp_Xt .shape , Xt .shape )
5757
58- Xt_new , _ = get_data_classif ('3gauss2' , nt + 1 )
58+ Xt_new , _ = make_data_classif ('3gauss2' , nt + 1 )
5959 transp_Xt_new = otda .inverse_transform (Xt = Xt_new )
6060
6161 # check that the oos method is working
@@ -92,8 +92,8 @@ def test_sinkhorn_l1l2_transport_class():
9292 ns = 150
9393 nt = 200
9494
95- Xs , ys = get_data_classif ('3gauss' , ns )
96- Xt , yt = get_data_classif ('3gauss2' , nt )
95+ Xs , ys = make_data_classif ('3gauss' , ns )
96+ Xt , yt = make_data_classif ('3gauss2' , nt )
9797
9898 otda = ot .da .SinkhornL1l2Transport ()
9999
@@ -119,7 +119,7 @@ def test_sinkhorn_l1l2_transport_class():
119119 transp_Xs = otda .transform (Xs = Xs )
120120 assert_equal (transp_Xs .shape , Xs .shape )
121121
122- Xs_new , _ = get_data_classif ('3gauss' , ns + 1 )
122+ Xs_new , _ = make_data_classif ('3gauss' , ns + 1 )
123123 transp_Xs_new = otda .transform (Xs_new )
124124
125125 # check that the oos method is working
@@ -129,7 +129,7 @@ def test_sinkhorn_l1l2_transport_class():
129129 transp_Xt = otda .inverse_transform (Xt = Xt )
130130 assert_equal (transp_Xt .shape , Xt .shape )
131131
132- Xt_new , _ = get_data_classif ('3gauss2' , nt + 1 )
132+ Xt_new , _ = make_data_classif ('3gauss2' , nt + 1 )
133133 transp_Xt_new = otda .inverse_transform (Xt = Xt_new )
134134
135135 # check that the oos method is working
@@ -173,8 +173,8 @@ def test_sinkhorn_transport_class():
173173 ns = 150
174174 nt = 200
175175
176- Xs , ys = get_data_classif ('3gauss' , ns )
177- Xt , yt = get_data_classif ('3gauss2' , nt )
176+ Xs , ys = make_data_classif ('3gauss' , ns )
177+ Xt , yt = make_data_classif ('3gauss2' , nt )
178178
179179 otda = ot .da .SinkhornTransport ()
180180
@@ -200,7 +200,7 @@ def test_sinkhorn_transport_class():
200200 transp_Xs = otda .transform (Xs = Xs )
201201 assert_equal (transp_Xs .shape , Xs .shape )
202202
203- Xs_new , _ = get_data_classif ('3gauss' , ns + 1 )
203+ Xs_new , _ = make_data_classif ('3gauss' , ns + 1 )
204204 transp_Xs_new = otda .transform (Xs_new )
205205
206206 # check that the oos method is working
@@ -210,7 +210,7 @@ def test_sinkhorn_transport_class():
210210 transp_Xt = otda .inverse_transform (Xt = Xt )
211211 assert_equal (transp_Xt .shape , Xt .shape )
212212
213- Xt_new , _ = get_data_classif ('3gauss2' , nt + 1 )
213+ Xt_new , _ = make_data_classif ('3gauss2' , nt + 1 )
214214 transp_Xt_new = otda .inverse_transform (Xt = Xt_new )
215215
216216 # check that the oos method is working
@@ -252,8 +252,8 @@ def test_emd_transport_class():
252252 ns = 150
253253 nt = 200
254254
255- Xs , ys = get_data_classif ('3gauss' , ns )
256- Xt , yt = get_data_classif ('3gauss2' , nt )
255+ Xs , ys = make_data_classif ('3gauss' , ns )
256+ Xt , yt = make_data_classif ('3gauss2' , nt )
257257
258258 otda = ot .da .EMDTransport ()
259259
@@ -278,7 +278,7 @@ def test_emd_transport_class():
278278 transp_Xs = otda .transform (Xs = Xs )
279279 assert_equal (transp_Xs .shape , Xs .shape )
280280
281- Xs_new , _ = get_data_classif ('3gauss' , ns + 1 )
281+ Xs_new , _ = make_data_classif ('3gauss' , ns + 1 )
282282 transp_Xs_new = otda .transform (Xs_new )
283283
284284 # check that the oos method is working
@@ -288,7 +288,7 @@ def test_emd_transport_class():
288288 transp_Xt = otda .inverse_transform (Xt = Xt )
289289 assert_equal (transp_Xt .shape , Xt .shape )
290290
291- Xt_new , _ = get_data_classif ('3gauss2' , nt + 1 )
291+ Xt_new , _ = make_data_classif ('3gauss2' , nt + 1 )
292292 transp_Xt_new = otda .inverse_transform (Xt = Xt_new )
293293
294294 # check that the oos method is working
@@ -329,9 +329,9 @@ def test_mapping_transport_class():
329329 ns = 60
330330 nt = 120
331331
332- Xs , ys = get_data_classif ('3gauss' , ns )
333- Xt , yt = get_data_classif ('3gauss2' , nt )
334- Xs_new , _ = get_data_classif ('3gauss' , ns + 1 )
332+ Xs , ys = make_data_classif ('3gauss' , ns )
333+ Xt , yt = make_data_classif ('3gauss2' , nt )
334+ Xs_new , _ = make_data_classif ('3gauss' , ns + 1 )
335335
336336 ##########################################################################
337337 # kernel == linear mapping tests
@@ -449,8 +449,8 @@ def test_linear_mapping():
449449 ns = 150
450450 nt = 200
451451
452- Xs , ys = get_data_classif ('3gauss' , ns )
453- Xt , yt = get_data_classif ('3gauss2' , nt )
452+ Xs , ys = make_data_classif ('3gauss' , ns )
453+ Xt , yt = make_data_classif ('3gauss2' , nt )
454454
455455 A , b = ot .da .OT_mapping_linear (Xs , Xt )
456456
@@ -467,8 +467,8 @@ def test_linear_mapping_class():
467467 ns = 150
468468 nt = 200
469469
470- Xs , ys = get_data_classif ('3gauss' , ns )
471- Xt , yt = get_data_classif ('3gauss2' , nt )
470+ Xs , ys = make_data_classif ('3gauss' , ns )
471+ Xt , yt = make_data_classif ('3gauss2' , nt )
472472
473473 otmap = ot .da .LinearTransport ()
474474
@@ -491,8 +491,8 @@ def test_otda():
491491 n_samples = 150 # nb samples
492492 np .random .seed (0 )
493493
494- xs , ys = ot .datasets .get_data_classif ('3gauss' , n_samples )
495- xt , yt = ot .datasets .get_data_classif ('3gauss2' , n_samples )
494+ xs , ys = ot .datasets .make_data_classif ('3gauss' , n_samples )
495+ xt , yt = ot .datasets .make_data_classif ('3gauss2' , n_samples )
496496
497497 a , b = ot .unif (n_samples ), ot .unif (n_samples )
498498
0 commit comments