@@ -350,7 +350,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
350350 np .exp (K , out = K )
351351
352352 # print(np.min(K))
353- tmp = np .empty (K .shape , dtype = M .dtype )
354353 tmp2 = np .empty (b .shape , dtype = M .dtype )
355354
356355 Kp = (1 / a ).reshape (- 1 , 1 ) * K
@@ -359,6 +358,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
359358 while (err > stopThr and cpt < numItermax ):
360359 uprev = u
361360 vprev = v
361+
362362 KtransposeU = np .dot (K .T , u )
363363 v = np .divide (b , KtransposeU )
364364 u = 1. / np .dot (Kp , v )
@@ -379,11 +379,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
379379 err = np .sum ((u - uprev )** 2 ) / np .sum ((u )** 2 ) + \
380380 np .sum ((v - vprev )** 2 ) / np .sum ((v )** 2 )
381381 else :
382- np .multiply (u .reshape (- 1 , 1 ), K , out = tmp )
383- np .multiply (tmp , v .reshape (1 , - 1 ), out = tmp )
384- np .sum (tmp , axis = 0 , out = tmp2 )
385- tmp2 -= b
386- err = np .linalg .norm (tmp2 )** 2
382+ # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
383+ np .einsum ('i,ij,j->j' , u , K , v , out = tmp2 )
384+ err = np .linalg .norm (tmp2 - b )** 2 # violation of marginal
387385 if log :
388386 log ['err' ].append (err )
389387
@@ -398,10 +396,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
398396 log ['v' ] = v
399397
400398 if nbb : # return only loss
401- res = np .zeros ((nbb ))
402- for i in range (nbb ):
403- res [i ] = np .sum (
404- u [:, i ].reshape ((- 1 , 1 )) * K * v [:, i ].reshape ((1 , - 1 )) * M )
399+ res = np .einsum ('ik,ij,jk,ij->k' , u , K , v , M )
405400 if log :
406401 return res , log
407402 else :
0 commit comments