diff --git a/RELEASES.md b/RELEASES.md index c221bc0f8..2409a9078 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -7,9 +7,11 @@ This new release adds support for sparse cost matrices in the exact EMD solver. #### New features - Add support for sparse cost matrices in exact EMD solver `ot.emd` and `ot.emd2` (PR #778) - Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` API (PR #TBD) +- Geomloss function now handles both scalar and slice indices for `i` and `j` using backend-agnostic reshaping, allowing `plan[i,:]` and `plan[:,j]` #### Closed issues - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397) +- Fix O(n³) performance bottleneck in sparse bipartite graph arc iteration - Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770) - Add test for build from source (PR #772, Issue #764) - Fix device for batch Ot solver in `ot.batch` (PR #784, Issue #783) diff --git a/ot/bregman/_geomloss.py b/ot/bregman/_geomloss.py index 1df423db2..f901663a6 100644 --- a/ot/bregman/_geomloss.py +++ b/ot/bregman/_geomloss.py @@ -54,13 +54,32 @@ def get_sinkhorn_geomloss_lazytensor( shape = (X_a.shape[0], X_b.shape[0]) def func(i, j, X_a, X_b, f, g, a, b, metric, blur): + X_a_i = X_a[i] + X_b_j = X_b[j] + + if X_a_i.ndim == 1: + X_a_i = X_a_i[None, :] + if X_b_j.ndim == 1: + X_b_j = X_b_j[None, :] + if metric == "sqeuclidean": - C = dist(X_a[i], X_b[j], metric=metric) / 2 + C = dist(X_a_i, X_b_j, metric=metric) / 2 else: - C = dist(X_a[i], X_b[j], metric=metric) - return nx.exp((f[i, None] + g[None, j] - C) / (blur**2)) * ( - a[i, None] * b[None, j] - ) + C = dist(X_a_i, X_b_j, metric=metric) + + # Robust broadcasting using nx backend (handles both numpy and torch) + # For scalars, slice to keep 1D; for arrays, index directly + f_i = f[i : i + 1] if isinstance(i, int) else f[i] + g_j = g[j : j + 1] if isinstance(j, int) else g[j] + a_i = a[i : i + 1] if isinstance(i, int) else a[i] + b_j = b[j : j + 1] if 
isinstance(j, int) else b[j] + + f_i = nx.reshape(f_i, (-1, 1)) + g_j = nx.reshape(g_j, (1, -1)) + a_i = nx.reshape(a_i, (-1, 1)) + b_j = nx.reshape(b_j, (1, -1)) + + return nx.squeeze(nx.exp((f_i + g_j - C) / (blur**2)) * a_i * b_j) T = LazyTensor( shape, func, X_a=X_a, X_b=X_b, f=f, g=g, a=a, b=b, metric=metric, blur=blur diff --git a/ot/lp/sparse_bipartitegraph.h b/ot/lp/sparse_bipartitegraph.h index 7ba13b41a..10e35d85d 100644 --- a/ot/lp/sparse_bipartitegraph.h +++ b/ot/lp/sparse_bipartitegraph.h @@ -43,8 +43,13 @@ namespace lemon { mutable std::vector> _in_arcs; // _in_arcs[node] = incoming arc IDs mutable bool _in_arcs_built; + + // Position tracking for O(1) iteration + mutable std::vector _arc_to_out_pos; // _arc_to_out_pos[arc_id] = position in _arc_ids + mutable std::vector _arc_to_in_pos; // _arc_to_in_pos[arc_id] = position in _in_arcs[target] + mutable bool _position_maps_built; - SparseBipartiteDigraphBase() : _node_num(0), _arc_num(0), _n1(0), _n2(0), _in_arcs_built(false) {} + SparseBipartiteDigraphBase() : _node_num(0), _arc_num(0), _n1(0), _n2(0), _in_arcs_built(false), _position_maps_built(false) {} void construct(int n1, int n2) { _node_num = n1 + n2; @@ -58,6 +63,9 @@ namespace lemon { _arc_ids.clear(); _in_arcs.clear(); _in_arcs_built = false; + _arc_to_out_pos.clear(); + _arc_to_in_pos.clear(); + _position_maps_built = false; } void build_in_arcs() const { @@ -72,6 +80,31 @@ namespace lemon { _in_arcs_built = true; } + + void build_position_maps() const { + if (_position_maps_built) return; + + _arc_to_out_pos.resize(_arc_num); + _arc_to_in_pos.resize(_arc_num); + + // Build outgoing arc position map from CSR structure + for (int64_t pos = 0; pos < _arc_num; ++pos) { + Arc arc_id = _arc_ids[pos]; + _arc_to_out_pos[arc_id] = pos; + } + + // Build incoming arc position map + build_in_arcs(); + for (Node node = 0; node < _node_num; ++node) { + const std::vector& in = _in_arcs[node]; + for (size_t pos = 0; pos < in.size(); ++pos) { + Arc arc_id 
= in[pos]; + _arc_to_in_pos[arc_id] = pos; + } + } + + _position_maps_built = true; + } public: @@ -212,18 +245,14 @@ namespace lemon { void nextOut(Arc& arc) const { if (arc < 0) return; - + + build_position_maps(); + + int64_t pos = _arc_to_out_pos[arc]; Node src = _arc_sources[arc]; - int64_t start = _row_ptr[src]; int64_t end = _row_ptr[src + 1]; - - for (int64_t i = start; i < end; ++i) { - if (_arc_ids[i] == arc) { - arc = (i + 1 < end) ? _arc_ids[i + 1] : Arc(-1); - return; - } - } - arc = -1; + + arc = (pos + 1 < end) ? _arc_ids[pos + 1] : Arc(-1); } void firstIn(Arc& arc, const Node& node) const { @@ -240,18 +269,14 @@ namespace lemon { void nextIn(Arc& arc) const { if (arc < 0) return; - + + build_position_maps(); + + int64_t pos = _arc_to_in_pos[arc]; Node tgt = _arc_targets[arc]; const std::vector& in = _in_arcs[tgt]; - - // Find current arc in the list and return next one - for (size_t i = 0; i < in.size(); ++i) { - if (in[i] == arc) { - arc = (i + 1 < in.size()) ? in[i + 1] : Arc(-1); - return; - } - } - arc = -1; + + arc = (pos + 1 < in.size()) ? in[pos + 1] : Arc(-1); } };