Skip to content

Commit a04112c

Browse files
committed
correction size
1 parent bbe4117 commit a04112c

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

ot/bregman.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -358,9 +358,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
358358
while (err > stopThr and cpt < numItermax):
359359
uprev = u
360360
vprev = v
361-
KtransposeU = np.einsum('ij,i->j',K,u)#np.dot(K.T, u)
362-
v = np.divide(b, KtransposeU)
363-
u = 1. / np.einsum('ij,j->i',Kp,v)#np.dot(Kp, v)
361+
if nbb:
362+
KtransposeU = np.einsum('ij,i,k->jk',K,u)#np.dot(K.T, u)
363+
v = np.divide(b, KtransposeU)
364+
u = 1. / np.einsum('ij,jk->ik',Kp,v)#np.dot(Kp, v)
365+
else:
366+
KtransposeU = np.einsum('ij,i->j',K,u)#np.dot(K.T, u)
367+
v = np.divide(b, KtransposeU)
368+
u = 1. / np.einsum('ij,j->i',Kp,v)#np.dot(Kp, v)
364369

365370
if (np.any(KtransposeU == 0) or
366371
np.any(np.isnan(u)) or np.any(np.isnan(v)) or

0 commit comments

Comments
 (0)