@@ -919,7 +919,7 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
919919 return geometricBar (weights , UKv )
920920
921921
922- def convolutional_barycenter2d (A , reg , weights = None , numItermax = 10000 , stopThr = 1e-9 , verbose = False , log = False ):
922+ def convolutional_barycenter2d (A , reg , weights = None , numItermax = 10000 , stopThr = 1e-9 , stabThr = 1e-30 , verbose = False , log = False ):
923923 """Compute the entropic regularized wasserstein barycenter of distributions A
924924 where A is a collection of 2D images.
925925
@@ -948,6 +948,8 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
948948 Max number of iterations
949949 stopThr : float, optional
950950 Stop threshol on error (>0)
951+ stabThr : float, optional
952+ Stabilization threshold to avoid numerical precision issue
951953 verbose : bool, optional
952954 Print information along iterations
953955 log : bool, optional
@@ -983,7 +985,6 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
983985 b = np .zeros_like (A [0 , :, :])
984986 U = np .ones_like (A )
985987 KV = np .ones_like (A )
986- threshold = 1e-30 # in order to avoids numerical precision issues
987988
988989 cpt = 0
989990 err = 1
@@ -993,7 +994,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
993994 [Y , X ] = np .meshgrid (t , t )
994995 xi1 = np .exp (- (X - Y )** 2 / reg )
995996
996- def K (x ):
997+ def K (x ):
997998 return np .dot (np .dot (xi1 , x ), xi1 )
998999
9991000 while (err > stopThr and cpt < numItermax ):
@@ -1003,11 +1004,11 @@ def K(x):
10031004
10041005 b = np .zeros_like (A [0 , :, :])
10051006 for r in range (A .shape [0 ]):
1006- KV [r , :, :] = K (A [r , :, :] / np .maximum (threshold , K (U [r , :, :])))
1007- b += weights [r ] * np .log (np .maximum (threshold , U [r , :, :] * KV [r , :, :]))
1007+ KV [r , :, :] = K (A [r , :, :] / np .maximum (stabThr , K (U [r , :, :])))
1008+ b += weights [r ] * np .log (np .maximum (stabThr , U [r , :, :] * KV [r , :, :]))
10081009 b = np .exp (b )
10091010 for r in range (A .shape [0 ]):
1010- U [r , :, :] = b / np .maximum (threshold , KV [r , :, :])
1011+ U [r , :, :] = b / np .maximum (stabThr , KV [r , :, :])
10111012
10121013 if cpt % 10 == 1 :
10131014 err = np .sum (np .abs (bold - b ))
0 commit comments