@@ -493,7 +493,7 @@ class label
 
         # pairwise distance
         self.cost_ = dist(Xs, Xt, metric=self.metric)
-        self.cost_ = cost_normalization(self.cost_, self.norm)
+        self.cost_, self.norm_cost_ = cost_normalization(self.cost_, self.norm, return_value=True)
 
         if (ys is not None) and (yt is not None):
@@ -1055,13 +1055,18 @@ class SinkhornTransport(BaseTransport):
         The ground metric for the Wasserstein problem
     norm : string, optional (default=None)
         If given, normalize the ground metric to avoid numerical errors that
-        can occur with large metric values.
+        can occur with large metric values. Accepted values are 'median',
+        'max', 'log' and 'loglog'.
     distribution_estimation : callable, optional (defaults to the uniform)
         The kind of distribution estimation to employ
-    out_of_sample_map : string, optional (default="ferradans")
+    out_of_sample_map : string, optional (default="continuous")
         The kind of out of sample mapping to apply to transport samples
-        from a domain into another one. Currently the only possible option is
-        "ferradans" which uses the method proposed in :ref:`[6] <references-sinkhorntransport>`.
+        from a domain into another one. Two options are available:
+        "ferradans", which uses the nearest neighbor method proposed in
+        :ref:`[6] <references-sinkhorntransport>`, and "continuous", which
+        uses the out-of-sample method from
+        :ref:`[66] <references-sinkhorntransport>` and
+        :ref:`[19] <references-sinkhorntransport>`.
     limit_max : float, optional (default=np.infty)
         Controls the semi-supervised mode. Transport between labeled source
         and target samples of different classes will exhibit a cost defined
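
For context, a minimal usage sketch of the class with the new default mapping (random data purely for illustration):

```python
# Minimal usage sketch of SinkhornTransport (random data for illustration).
import numpy as np
import ot

rng = np.random.RandomState(42)
Xs = rng.randn(50, 2)        # source samples
Xt = rng.randn(60, 2) + 2.0  # shifted target samples

mapping = ot.da.SinkhornTransport(reg_e=0.5, out_of_sample_map='continuous')
mapping.fit(Xs=Xs, Xt=Xt)

Xs_mapped = mapping.transform(Xs=Xs)                    # map the training samples
Xs_new_mapped = mapping.transform(Xs=rng.randn(5, 2))   # out-of-sample points
```
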
@@ -1089,13 +1094,26 @@ class SinkhornTransport(BaseTransport):
     .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
         Regularized discrete optimal transport. SIAM Journal on Imaging
         Sciences, 7(3), 1853-1882.
+
+    .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N.,
+        Rolet, A., & Blondel, M. (2018). Large-scale optimal transport and
+        mapping estimation. International Conference on Learning
+        Representations (ICLR).
+
+    .. [66] Pooladian, A.-A., & Niles-Weed, J. (2021). Entropic estimation
+        of optimal transport maps. arXiv preprint arXiv:2109.12004.
+
     """
 
-    def __init__(self, reg_e=1., method="sinkhorn", max_iter=1000,
+    def __init__(self, reg_e=1., method="sinkhorn_log", max_iter=1000,
                  tol=10e-9, verbose=False, log=False,
                  metric="sqeuclidean", norm=None,
                  distribution_estimation=distribution_estimation_uniform,
-                 out_of_sample_map='ferradans', limit_max=np.infty):
+                 out_of_sample_map='continuous', limit_max=np.infty):
+
+        if out_of_sample_map not in ['ferradans', 'continuous']:
+            raise ValueError('Unknown out_of_sample_map method')
+
         self.reg_e = reg_e
         self.method = method
         self.max_iter = max_iter
@@ -1135,6 +1153,12 @@ class label
 
         super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
 
+        if self.out_of_sample_map == 'continuous':
+            self.log = True
+            if self.method != 'sinkhorn_log':
+                self.method = 'sinkhorn_log'
+                warnings.warn("The method has been set to 'sinkhorn_log' as it is the only method available for out_of_sample_map='continuous'.")
+
         # coupling estimation
         returned_ = sinkhorn(
             a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
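
Reading the new `transform` below as an equation: with `g = log_v` the target dual potential returned by the Sinkhorn log, `ε = reg_e`, and `c` the (normalized) ground cost, an out-of-sample source point is mapped to a softmax-weighted barycenter of the target samples, i.e. the entropic map of [66] (the barycentric counterpart of [19]):

```latex
% Out-of-sample map computed by transform() for a new source point x,
% given target samples x^t_j, dual potential g = log_v and epsilon = reg_e.
T(\mathbf{x}) =
    \frac{\sum_j \exp\big(g_j - c(\mathbf{x}, \mathbf{x}^t_j)/\varepsilon\big)\, \mathbf{x}^t_j}
         {\sum_j \exp\big(g_j - c(\mathbf{x}, \mathbf{x}^t_j)/\varepsilon\big)}
```
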
@@ -1150,6 +1174,120 @@ class label
 
         return self
 
+    def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+        r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}`
+
+        Parameters
+        ----------
+        Xs : array-like, shape (n_source_samples, n_features)
+            The source input samples.
+        ys : array-like, shape (n_source_samples,)
+            The class labels for source samples
+        Xt : array-like, shape (n_target_samples, n_features)
+            The target input samples.
+        yt : array-like, shape (n_target_samples,)
+            The class labels for target. If some target samples are unlabelled,
+            fill the :math:`\mathbf{y_t}`'s elements with -1.
+
+            Warning: Note that, due to this convention, -1 cannot be used as a
+            class label
+        batch_size : int, optional (default=128)
+            The batch size for the out of sample transform
+
+        Returns
+        -------
+        transp_Xs : array-like, shape (n_source_samples, n_features)
+            The transported source samples.
+        """
+        nx = self.nx
+
+        if self.out_of_sample_map == 'ferradans':
+            return super(SinkhornTransport, self).transform(Xs, ys, Xt, yt, batch_size)
+
+        else:  # self.out_of_sample_map == 'continuous'
+
+            # target dual potential estimated by the Sinkhorn solver
+            g = self.log_['log_v']
+
+            indices = nx.arange(Xs.shape[0])
+            batch_ind = [
+                indices[i:i + batch_size]
+                for i in range(0, len(indices), batch_size)]
+
+            transp_Xs = []
+            for bi in batch_ind:
+                # cost from the batch of source samples to all target samples
+                M = dist(Xs[bi], self.xt_, metric=self.metric)
+                M = cost_normalization(M, self.norm, value=self.norm_cost_)
+
+                K = nx.exp(-M / self.reg_e + g[None, :])
+
+                transp_Xs_ = nx.dot(K, self.xt_) / nx.sum(K, axis=1)[:, None]
+
+                transp_Xs.append(transp_Xs_)
+
+            transp_Xs = nx.concatenate(transp_Xs, axis=0)
+
+            return transp_Xs
+
+    def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+        r"""Transports target samples :math:`\mathbf{X_t}` onto source ones :math:`\mathbf{X_s}`
+
+        Parameters
+        ----------
+        Xs : array-like, shape (n_source_samples, n_features)
+            The source input samples.
+        ys : array-like, shape (n_source_samples,)
+            The class labels for source samples
+        Xt : array-like, shape (n_target_samples, n_features)
+            The target input samples.
+        yt : array-like, shape (n_target_samples,)
+            The class labels for target. If some target samples are unlabelled,
+            fill the :math:`\mathbf{y_t}`'s elements with -1.
+
+            Warning: Note that, due to this convention, -1 cannot be used as a
+            class label
+        batch_size : int, optional (default=128)
+            The batch size for the out of sample inverse transform
+
+        Returns
+        -------
+        transp_Xt : array-like, shape (n_target_samples, n_features)
+            The transported target samples.
+        """
+        nx = self.nx
+
+        if self.out_of_sample_map == 'ferradans':
+            return super(SinkhornTransport, self).inverse_transform(Xs, ys, Xt, yt, batch_size)
+
+        else:  # self.out_of_sample_map == 'continuous'
+
+            # source dual potential estimated by the Sinkhorn solver
+            f = self.log_['log_u']
+
+            indices = nx.arange(Xt.shape[0])
+            batch_ind = [
+                indices[i:i + batch_size]
+                for i in range(0, len(indices), batch_size)]
+
+            transp_Xt = []
+            for bi in batch_ind:
+                # cost from the batch of target samples to all source samples
+                M = dist(Xt[bi], self.xs_, metric=self.metric)
+                M = cost_normalization(M, self.norm, value=self.norm_cost_)
+
+                K = nx.exp(-M / self.reg_e + f[None, :])
+
+                transp_Xt_ = nx.dot(K, self.xs_) / nx.sum(K, axis=1)[:, None]
+
+                transp_Xt.append(transp_Xt_)
+
+            transp_Xt = nx.concatenate(transp_Xt, axis=0)
+
+            return transp_Xt
+
 
 class EMDTransport(BaseTransport):
 
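
For a standalone sense of what the continuous out-of-sample path computes, here is a NumPy sketch of the batched map from `transform` above (illustrative names and random potentials; POT's `nx` backend abstracts the array library):

```python
# Standalone NumPy illustration of the batched entropic map in transform().
import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.RandomState(0)
Xt = rng.randn(60, 2)       # target samples stored at fit time (self.xt_)
g = rng.randn(60)           # target dual potential (self.log_['log_v'])
reg_e = 1.0                 # entropic regularization (self.reg_e)
Xs_new = rng.randn(300, 2)  # out-of-sample source points
batch_size = 128

out = []
for i in range(0, len(Xs_new), batch_size):
    M = cdist(Xs_new[i:i + batch_size], Xt, metric='sqeuclidean')
    K = np.exp(-M / reg_e + g[None, :])          # Gibbs kernel times potential
    out.append(K @ Xt / K.sum(axis=1)[:, None])  # row-normalized barycenter
T_Xs = np.concatenate(out, axis=0)
assert T_Xs.shape == Xs_new.shape
```
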