Skip to content

Commit c0c959d

Browse files
committed
speedup einsum constraint violation
1 parent 5cd6c0a commit c0c959d

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

ot/bregman.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
350350
np.exp(K, out=K)
351351

352352
# print(np.min(K))
353-
tmp = np.empty(K.shape, dtype=M.dtype)
354353
tmp2 = np.empty(b.shape, dtype=M.dtype)
355354

356355
Kp = (1 / a).reshape(-1, 1) * K
@@ -379,11 +378,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
379378
err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
380379
np.sum((v - vprev)**2) / np.sum((v)**2)
381380
else:
382-
np.multiply(u.reshape(-1, 1), K, out=tmp)
383-
np.multiply(tmp, v.reshape(1, -1), out=tmp)
384-
np.sum(tmp, axis=0, out=tmp2)
385-
tmp2 -= b
386-
err = np.linalg.norm(tmp2)**2
381+
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
382+
np.einsum('i,ij,j->j',u,K,v,out=tmp2)
383+
err = np.linalg.norm(tmp2-b)**2 # violation of marginal
387384
if log:
388385
log['err'].append(err)
389386

0 commit comments

Comments
 (0)