From 975ca6ca2c6bf62f00f7e12f943068ada8f5a4cb Mon Sep 17 00:00:00 2001 From: nathanneike Date: Mon, 1 Dec 2025 17:07:33 +0100 Subject: [PATCH 1/5] Replaced coo_matrix with coo_array better compatability and added test to test coo_array functionnality --- ot/backend.py | 19 +++++++++---------- test/test_ot.py | 9 +++++++++ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 7ca505c0f..4bfe5b8f8 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -94,7 +94,7 @@ import scipy import scipy.linalg import scipy.special as special -from scipy.sparse import coo_matrix, csr_matrix, issparse +from scipy.sparse import coo_array, coo_matrix, csr_matrix, issparse DISABLE_TORCH_KEY = "POT_BACKEND_DISABLE_PYTORCH" DISABLE_JAX_KEY = "POT_BACKEND_DISABLE_JAX" @@ -802,9 +802,9 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None): r""" Creates a sparse tensor in COOrdinate format. - This function follows the api from :any:`scipy.sparse.coo_matrix` + This function follows the api from :any:`scipy.sparse.coo_array` - See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_array.html """ raise NotImplementedError() @@ -1354,9 +1354,9 @@ def randperm(self, size, type_as=None): def coo_matrix(self, data, rows, cols, shape=None, type_as=None): if type_as is None: - return coo_matrix((data, (rows, cols)), shape=shape) + return coo_array((data, (rows, cols)), shape=shape) else: - return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype) + return coo_array((data, (rows, cols)), shape=shape, dtype=type_as.dtype) def issparse(self, a): return issparse(a) @@ -1385,8 +1385,9 @@ def todense(self, a): def sparse_coo_data(self, a): # Convert to COO format if needed - if not isinstance(a, coo_matrix): - a_coo = coo_matrix(a) + if not isinstance(a, (coo_array, coo_matrix)): + # Try to convert to coo_array (prefer modern API) + a_coo = coo_array(a) else: a_coo = a @@ -1815,9 +1816,7 @@ def sparse_coo_data(self, a): # JAX doesn't support sparse matrices, so this shouldn't be called # But if it is, convert the dense array to sparse using scipy a_np = self.to_numpy(a) - from scipy.sparse import coo_matrix - - a_coo = coo_matrix(a_np) + a_coo = coo_array(a_np) return a_coo.row, a_coo.col, a_coo.data, a_coo.shape def where(self, condition, x=None, y=None): diff --git a/test/test_ot.py b/test/test_ot.py index d762a03e1..7d4f459ea 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -992,6 +992,15 @@ def test_emd_sparse_vs_dense(nx): b, nx.to_numpy(nx.sum(G_sparse_dense, 0)), rtol=1e-5, atol=1e-7 ) + # Test coo_array element-wise multiplication (only works with coo_array, not coo_matrix) + if nx.__name__ == "numpy": + # This tests that we're using coo_array which supports element-wise operations + M_sparse_np = M_sparse + G_sparse_np = G_sparse + loss_sparse = np.sum(G_sparse_np * M_sparse_np) + # Verify the loss calculation is reasonable + assert loss_sparse >= 0, "Sparse loss should be non-negative" + def test_emd2_sparse_vs_dense(nx): """Test that sparse and dense emd2 solvers produce identical costs. From d889ac97884bf712b162f97bfcafd2960afa5d64 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Mon, 1 Dec 2025 19:45:39 +0100 Subject: [PATCH 2/5] Updated release file --- RELEASES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASES.md b/RELEASES.md index 4d73da648..daa424c68 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,6 +6,7 @@ This new release adds support for sparse cost matrices in the exact EMD solver. #### New features - Add support for sparse cost matrices in exact EMD solver `ot.emd` and `ot.emd2` (PR #778) +- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` API (PR #TBD) #### Closed issues - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397) From 5ee1a4d9e767abb68791b9e21531122c189ca443 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Tue, 2 Dec 2025 10:29:24 +0100 Subject: [PATCH 3/5] Replaced some more coo_matrix calls --- ot/backend.py | 11 +++++------ ot/plot.py | 6 +++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 4bfe5b8f8..6b03f5cd1 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -94,7 +94,7 @@ import scipy import scipy.linalg import scipy.special as special -from scipy.sparse import coo_array, coo_matrix, csr_matrix, issparse +from scipy.sparse import coo_array, csr_matrix, issparse DISABLE_TORCH_KEY = "POT_BACKEND_DISABLE_PYTORCH" DISABLE_JAX_KEY = "POT_BACKEND_DISABLE_JAX" @@ -1384,9 +1384,8 @@ def todense(self, a): return a def sparse_coo_data(self, a): - # Convert to COO format if needed - if not isinstance(a, (coo_array, coo_matrix)): - # Try to convert to coo_array (prefer modern API) + # Convert to COO array format if needed + if not isinstance(a, coo_array): a_coo = coo_array(a) else: a_coo = a @@ -2803,10 +2802,10 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None): rows = self.from_numpy(rows) cols = self.from_numpy(cols) if type_as is None: - return cupyx.scipy.sparse.coo_matrix((data, (rows, cols)), shape=shape) + return cupyx.scipy.sparse.coo_array((data, (rows, cols)), shape=shape) else: with cp.cuda.Device(type_as.device): - return cupyx.scipy.sparse.coo_matrix( + return cupyx.scipy.sparse.coo_array( (data, (rows, cols)), shape=shape, dtype=type_as.dtype ) diff --git a/ot/plot.py b/ot/plot.py index e3091ac8a..efc08e7cb 100644 --- a/ot/plot.py +++ b/ot/plot.py @@ -15,6 +15,8 @@ import numpy as np import matplotlib.pylab as pl from matplotlib import gridspec +from . import backend +from scipy.sparse import issparse, coo_array def plot1D_mat( @@ -232,8 +234,6 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): parameters given to the plot functions (default color is black if nothing given) """ - from . import backend - from scipy.sparse import issparse, coo_matrix if ("color" not in kwargs) and ("c" not in kwargs): kwargs["color"] = "k" @@ -258,7 +258,7 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): # Not a backend array, check if scipy.sparse is_sparse = issparse(G) if is_sparse: - G_coo = G if isinstance(G, coo_matrix) else G.tocoo() + G_coo = G if isinstance(G, coo_array) else G.tocoo() rows, cols, data = G_coo.row, G_coo.col, G_coo.data if is_sparse: From eddc7c688c9e9557128e7b9e16b99f34a51668d4 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Fri, 12 Dec 2025 15:56:08 +0100 Subject: [PATCH 4/5] =?UTF-8?q?Fix=20O(n=C2=B3)=20performance=20issue=20in?= =?UTF-8?q?=20sparse=20bipartite=20graph=20arc=20iteration?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Added position tracking maps (_arc_to_out_pos, _arc_to_in_pos) for O(1) arc lookups - Modified nextOut() and nextIn() to use position maps instead of linear search --- ot/lp/sparse_bipartitegraph.h | 67 ++++++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 21 deletions(-) diff --git a/ot/lp/sparse_bipartitegraph.h b/ot/lp/sparse_bipartitegraph.h index 7ba13b41a..10e35d85d 100644 --- a/ot/lp/sparse_bipartitegraph.h +++ b/ot/lp/sparse_bipartitegraph.h @@ -43,8 +43,13 @@ namespace lemon { mutable std::vector> _in_arcs; // _in_arcs[node] = incoming arc IDs mutable bool _in_arcs_built; + + // Position tracking for O(1) iteration + mutable std::vector _arc_to_out_pos; // _arc_to_out_pos[arc_id] = position in _arc_ids + mutable std::vector _arc_to_in_pos; // _arc_to_in_pos[arc_id] = position in _in_arcs[target] + mutable bool _position_maps_built; - SparseBipartiteDigraphBase() : _node_num(0), _arc_num(0), _n1(0), _n2(0), _in_arcs_built(false) {} + SparseBipartiteDigraphBase() : _node_num(0), _arc_num(0), _n1(0), _n2(0), _in_arcs_built(false), _position_maps_built(false) {} void construct(int n1, int n2) { _node_num = n1 + n2; @@ -58,6 +63,9 @@ namespace lemon { _arc_ids.clear(); _in_arcs.clear(); _in_arcs_built = false; + _arc_to_out_pos.clear(); + _arc_to_in_pos.clear(); + _position_maps_built = false; } void build_in_arcs() const { @@ -72,6 +80,31 @@ namespace lemon { _in_arcs_built = true; } + + void build_position_maps() const { + if (_position_maps_built) return; + + _arc_to_out_pos.resize(_arc_num); + _arc_to_in_pos.resize(_arc_num); + + // Build outgoing arc position map from CSR structure + for (int64_t pos = 0; pos < _arc_num; ++pos) { + Arc arc_id = _arc_ids[pos]; + _arc_to_out_pos[arc_id] = pos; + } + + // Build incoming arc position map + build_in_arcs(); + for (Node node = 0; node < _node_num; ++node) { + const std::vector& in = _in_arcs[node]; + for (size_t pos = 0; pos < in.size(); ++pos) { + Arc arc_id = in[pos]; + _arc_to_in_pos[arc_id] = pos; + } + } + + _position_maps_built = true; + } public: @@ -212,18 +245,14 @@ namespace lemon { void nextOut(Arc& arc) const { if (arc < 0) return; - + + build_position_maps(); + + int64_t pos = _arc_to_out_pos[arc]; Node src = _arc_sources[arc]; - int64_t start = _row_ptr[src]; int64_t end = _row_ptr[src + 1]; - - for (int64_t i = start; i < end; ++i) { - if (_arc_ids[i] == arc) { - arc = (i + 1 < end) ? _arc_ids[i + 1] : Arc(-1); - return; - } - } - arc = -1; + + arc = (pos + 1 < end) ? _arc_ids[pos + 1] : Arc(-1); } void firstIn(Arc& arc, const Node& node) const { @@ -240,18 +269,14 @@ namespace lemon { void nextIn(Arc& arc) const { if (arc < 0) return; - + + build_position_maps(); + + int64_t pos = _arc_to_in_pos[arc]; Node tgt = _arc_targets[arc]; const std::vector& in = _in_arcs[tgt]; - - // Find current arc in the list and return next one - for (size_t i = 0; i < in.size(); ++i) { - if (in[i] == arc) { - arc = (i + 1 < in.size()) ? in[i + 1] : Arc(-1); - return; - } - } - arc = -1; + + arc = (pos + 1 < in.size()) ? in[pos + 1] : Arc(-1); } }; From 184acd3d0c889a55d8422db84786f855f05e0a3e Mon Sep 17 00:00:00 2001 From: nathanneike Date: Fri, 12 Dec 2025 16:00:33 +0100 Subject: [PATCH 5/5] Added changes to release --- RELEASES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASES.md b/RELEASES.md index f39dfd299..2409a9078 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -11,6 +11,7 @@ This new release adds support for sparse cost matrices in the exact EMD solver. #### Closed issues - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397) +- Fix O(n³) performance bottleneck in sparse bipartite graph arc iteration - Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770) - Add test for build from source (PR #772, Issue #764) - Fix device for batch Ot solver in `ot.batch` (PR #784, Issue #783)