Skip to content

Commit 84d9be5

Browse files
Merge pull request #499 from ChrisRackauckas-Claude/arraypart_zero
Fix mapreduce type-stability on Julia 1.10 using @generated functions
2 parents b1a3980 + b9c30c8 commit 84d9be5

File tree

4 files changed

+56
-6
lines changed

4 files changed

+56
-6
lines changed

src/array_partition.jl

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,40 @@ Base.:(==)(A::ArrayPartition, B::ArrayPartition) = A.x == B.x
171171
## Iterable Collection Constructs
172172

173173
Base.map(f, A::ArrayPartition) = ArrayPartition(map(x -> map(f, x), A.x))
174-
function Base.mapreduce(f, op, A::ArrayPartition{T}; kwargs...) where {T}
175-
mapreduce(f, op, (i for i in A); kwargs...)
174+
# Use @generated function for type stability on Julia 1.10
175+
# The generated approach avoids type inference issues with kwargs in older Julia versions
176+
@generated function _mapreduce_impl(f, op, A::ArrayPartition{T, S}) where {T, S}
177+
N = length(S.parameters)
178+
if N == 1
179+
return :(mapreduce(f, op, A.x[1]))
180+
else
181+
expr = :(mapreduce(f, op, A.x[$N]))
182+
for i in (N - 1):-1:1
183+
expr = :(op(mapreduce(f, op, A.x[$i]), $expr))
184+
end
185+
return expr
186+
end
187+
end
188+
@generated function _mapreduce_impl_init(f, op, A::ArrayPartition{T, S}, init) where {T, S}
189+
N = length(S.parameters)
190+
if N == 1
191+
return :(mapreduce(f, op, A.x[1]))
192+
else
193+
expr = :(mapreduce(f, op, A.x[$N]))
194+
for i in (N - 1):-1:1
195+
expr = :(op(mapreduce(f, op, A.x[$i]), $expr))
196+
end
197+
# Apply init only at the outermost reduction
198+
return :(op(init, $expr))
199+
end
200+
end
201+
@inline function Base.mapreduce(f, op, A::ArrayPartition;
202+
init = Base._InitialValue(), kwargs...)
203+
if init isa Base._InitialValue
204+
_mapreduce_impl(f, op, A)
205+
else
206+
_mapreduce_impl_init(f, op, A, init)
207+
end
176208
end
177209
Base.filter(f, A::ArrayPartition) = ArrayPartition(map(x -> filter(f, x), A.x))
178210
Base.any(f, A::ArrayPartition) = any((any(f, x) for x in A.x))
@@ -442,7 +474,10 @@ end
442474

443475
## Linear Algebra
444476

445-
ArrayInterface.zeromatrix(A::ArrayPartition) = ArrayInterface.zeromatrix(Vector(A))
477+
function ArrayInterface.zeromatrix(A::ArrayPartition)
478+
x = reduce(vcat,vec.(A.x))
479+
x .* x' .* false
480+
end
446481

447482
function __get_subtypes_in_module(
448483
mod, supertype; include_supertype = true, all = false, except = [])

src/named_array_partition.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,13 @@ end
149149
return dest
150150
end
151151

152+
#Overwrite ArrayInterface zeromatrix to work with NamedArrayPartitions & implicit solvers within OrdinaryDiffEq
153+
function ArrayInterface.zeromatrix(A::NamedArrayPartition)
154+
B = ArrayPartition(A)
155+
x = reduce(vcat,vec.(B.x))
156+
x .* x' .* false
157+
end
158+
152159
# `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments.
153160
find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args)
154161
function find_NamedArrayPartition(args::Tuple)

test/gpu/arraypartition_gpu.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using RecursiveArrayTools, CUDA, Test, Adapt
1+
using RecursiveArrayTools, ArrayInterface, CUDA, Adapt, Test
22
CUDA.allowscalar(false)
33

44
# Test indexing with colon
@@ -40,4 +40,9 @@ for i in 1:length(part_a.x)
4040
sub_b = part_b.x[i]
4141
@test sub_a == sub_b # Test for value equality in sub-arrays
4242
@test typeof(sub_a) === typeof(sub_b) # Test type equality
43-
end
43+
end
44+
45+
x = ArrayPartition((CUDA.zeros(2),CUDA.zeros(2)))
46+
@test ArrayInterface.zeromatrix(x) isa CuMatrix
47+
@test size(ArrayInterface.zeromatrix(x)) == (4,4)
48+
@test maximum(abs, x) == 0f0

test/named_array_partition_tests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using RecursiveArrayTools, Test
1+
using RecursiveArrayTools, ArrayInterface, Test
22

33
@testset "NamedArrayPartition tests" begin
44
x = NamedArrayPartition(a = ones(10), b = rand(20))
@@ -9,10 +9,13 @@ using RecursiveArrayTools, Test
99
@test x.a ones(10)
1010
@test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence
1111
@test all(x .== x[1:end])
12+
@test ArrayInterface.zeromatrix(x) isa Matrix
13+
@test size(ArrayInterface.zeromatrix(x)) == (30,30)
1214
y = copy(x)
1315
@test zero(x, (10, 20)) == zero(x) # test that ignoring dims works
1416
@test typeof(zero(x)) <: NamedArrayPartition
1517
@test (y .*= 2).a[1] 2 # test in-place bcast
18+
1619

1720
@test length(Array(x)) == 30
1821
@test typeof(Array(x)) <: Array

0 commit comments

Comments
 (0)