Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ This new release adds support for sparse cost matrices in the exact EMD solver.
#### New features
- Add support for sparse cost matrices in exact EMD solver `ot.emd` and `ot.emd2` (PR #778)
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` API (PR #TBD)
- Geomloss function now handles both scalar and slice indices for i and j. Using backend agnostic reshaping. Allows to do plan[i,:] and plan[:,j]

#### Closed issues
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
- Fix O(n³) performance bottleneck in sparse bipartite graph arc iteration
- Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770)
- Add test for build from source (PR #772, Issue #764)
- Fix device for batch Ot solver in `ot.batch` (PR #784, Issue #783)
Expand Down
29 changes: 24 additions & 5 deletions ot/bregman/_geomloss.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,32 @@ def get_sinkhorn_geomloss_lazytensor(
shape = (X_a.shape[0], X_b.shape[0])

def func(i, j, X_a, X_b, f, g, a, b, metric, blur):
X_a_i = X_a[i]
X_b_j = X_b[j]

if X_a_i.ndim == 1:
X_a_i = X_a_i[None, :]
if X_b_j.ndim == 1:
X_b_j = X_b_j[None, :]

if metric == "sqeuclidean":
C = dist(X_a[i], X_b[j], metric=metric) / 2
C = dist(X_a_i, X_b_j, metric=metric) / 2
else:
C = dist(X_a[i], X_b[j], metric=metric)
return nx.exp((f[i, None] + g[None, j] - C) / (blur**2)) * (
a[i, None] * b[None, j]
)
C = dist(X_a_i, X_b_j, metric=metric)

# Robust broadcasting using nx backend (handles both numpy and torch)
# For scalars, slice to keep 1D; for arrays, index directly
f_i = f[i : i + 1] if isinstance(i, int) else f[i]
g_j = g[j : j + 1] if isinstance(j, int) else g[j]
a_i = a[i : i + 1] if isinstance(i, int) else a[i]
b_j = b[j : j + 1] if isinstance(j, int) else b[j]

f_i = nx.reshape(f_i, (-1, 1))
g_j = nx.reshape(g_j, (1, -1))
a_i = nx.reshape(a_i, (-1, 1))
b_j = nx.reshape(b_j, (1, -1))

return nx.squeeze(nx.exp((f_i + g_j - C) / (blur**2)) * a_i * b_j)

T = LazyTensor(
shape, func, X_a=X_a, X_b=X_b, f=f, g=g, a=a, b=b, metric=metric, blur=blur
Expand Down
67 changes: 46 additions & 21 deletions ot/lp/sparse_bipartitegraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,13 @@ namespace lemon {

mutable std::vector<std::vector<Arc>> _in_arcs; // _in_arcs[node] = incoming arc IDs
mutable bool _in_arcs_built;

// Position tracking for O(1) iteration
mutable std::vector<int64_t> _arc_to_out_pos; // _arc_to_out_pos[arc_id] = position in _arc_ids
mutable std::vector<int64_t> _arc_to_in_pos; // _arc_to_in_pos[arc_id] = position in _in_arcs[target]
mutable bool _position_maps_built;

SparseBipartiteDigraphBase() : _node_num(0), _arc_num(0), _n1(0), _n2(0), _in_arcs_built(false) {}
SparseBipartiteDigraphBase() : _node_num(0), _arc_num(0), _n1(0), _n2(0), _in_arcs_built(false), _position_maps_built(false) {}

void construct(int n1, int n2) {
_node_num = n1 + n2;
Expand All @@ -58,6 +63,9 @@ namespace lemon {
_arc_ids.clear();
_in_arcs.clear();
_in_arcs_built = false;
_arc_to_out_pos.clear();
_arc_to_in_pos.clear();
_position_maps_built = false;
}

void build_in_arcs() const {
Expand All @@ -72,6 +80,31 @@ namespace lemon {

_in_arcs_built = true;
}

void build_position_maps() const {
if (_position_maps_built) return;

_arc_to_out_pos.resize(_arc_num);
_arc_to_in_pos.resize(_arc_num);

// Build outgoing arc position map from CSR structure
for (int64_t pos = 0; pos < _arc_num; ++pos) {
Arc arc_id = _arc_ids[pos];
_arc_to_out_pos[arc_id] = pos;
}

// Build incoming arc position map
build_in_arcs();
for (Node node = 0; node < _node_num; ++node) {
const std::vector<Arc>& in = _in_arcs[node];
for (size_t pos = 0; pos < in.size(); ++pos) {
Arc arc_id = in[pos];
_arc_to_in_pos[arc_id] = pos;
}
}

_position_maps_built = true;
}

public:

Expand Down Expand Up @@ -212,18 +245,14 @@ namespace lemon {

void nextOut(Arc& arc) const {
if (arc < 0) return;


build_position_maps();

int64_t pos = _arc_to_out_pos[arc];
Node src = _arc_sources[arc];
int64_t start = _row_ptr[src];
int64_t end = _row_ptr[src + 1];

for (int64_t i = start; i < end; ++i) {
if (_arc_ids[i] == arc) {
arc = (i + 1 < end) ? _arc_ids[i + 1] : Arc(-1);
return;
}
}
arc = -1;

arc = (pos + 1 < end) ? _arc_ids[pos + 1] : Arc(-1);
}

void firstIn(Arc& arc, const Node& node) const {
Expand All @@ -240,18 +269,14 @@ namespace lemon {

void nextIn(Arc& arc) const {
if (arc < 0) return;


build_position_maps();

int64_t pos = _arc_to_in_pos[arc];
Node tgt = _arc_targets[arc];
const std::vector<Arc>& in = _in_arcs[tgt];

// Find current arc in the list and return next one
for (size_t i = 0; i < in.size(); ++i) {
if (in[i] == arc) {
arc = (i + 1 < in.size()) ? in[i + 1] : Arc(-1);
return;
}
}
arc = -1;

arc = (pos + 1 < in.size()) ? in[pos + 1] : Arc(-1);
}
};

Expand Down
Loading