Skip to content

Commit b6f5bf2

Browse files
committed
adding NamedArrayPartition and tests
1 parent d869b10 commit b6f5bf2

File tree

2 files changed

+148
-0
lines changed

2 files changed

+148
-0
lines changed

src/named_array_partition.jl

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""
2+
NamedArrayPartition(; kwargs...)
3+
NamedArrayPartition(x::NamedTuple)
4+
5+
Similar to an `ArrayPartition` but the individual arrays can be accessed via the
6+
constructor-specified names. However, unlike `ArrayPartition`, each individual array
7+
must have the same element type.
8+
"""
9+
struct NamedArrayPartition{T, A<:ArrayPartition{T}, NT<:NamedTuple} <: AbstractVector{T}
10+
array_partition::A
11+
names_to_indices::NT
12+
end
13+
NamedArrayPartition(; kwargs...) = NamedArrayPartition(NamedTuple(kwargs))
14+
function NamedArrayPartition(x::NamedTuple)
15+
names_to_indices = NamedTuple(Pair(symbol, index) for (index, symbol) in enumerate(keys(x)))
16+
17+
# enforce homogeneity of eltypes
18+
@assert all(eltype.(values(x)) .== eltype(first(x)))
19+
T = eltype(first(x))
20+
S = typeof(values(x))
21+
return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices)
22+
end
23+
24+
# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
25+
# fields except through `getfield` and accessor functions.
26+
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition)
27+
28+
Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x))
29+
30+
Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN} =
31+
NamedArrayPartition{T, S, TN}(zero(ArrayPartition(x)), getfield(x, :names_to_indices))
32+
Base.zero(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors
33+
34+
35+
Base.propertynames(x::NamedArrayPartition) = propertynames(getfield(x, :names_to_indices))
36+
Base.getproperty(x::NamedArrayPartition, s::Symbol) =
37+
getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s))
38+
39+
# this enables x.s = some_array.
40+
@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v)
41+
index = getproperty(getfield(x, :names_to_indices), s)
42+
ArrayPartition(x).x[index] .= v
43+
end
44+
45+
# print out NamedArrayPartition as a NamedTuple
46+
Base.summary(x::NamedArrayPartition) = string(typeof(x), " with arrays:")
47+
Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition) =
48+
show(io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x)))
49+
50+
Base.size(x::NamedArrayPartition) = size(ArrayPartition(x))
51+
Base.length(x::NamedArrayPartition) = length(ArrayPartition(x))
52+
Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...)
53+
54+
Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...)
55+
Base.map(f, x::NamedArrayPartition) = NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices))
56+
Base.mapreduce(f, op, x::NamedArrayPartition) = mapreduce(f, op, ArrayPartition(x))
57+
# Base.filter(f, x::NamedArrayPartition) = filter(f, ArrayPartition(x))
58+
59+
Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT} =
60+
NamedArrayPartition{T, S, NT}(similar(ArrayPartition(x)), getfield(x, :names_to_indices))
61+
62+
# broadcasting
63+
Base.BroadcastStyle(::Type{<:NamedArrayPartition}) = Broadcast.ArrayStyle{NamedArrayPartition}()
64+
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}},
65+
::Type{ElType}) where {ElType}
66+
x = find_NamedArrayPartition(bc)
67+
return NamedArrayPartition(similar(ArrayPartition(x)), getfield(x, :names_to_indices))
68+
end
69+
70+
# when broadcasting with ArrayPartition + another array type, the output is the other array tupe
71+
Base.BroadcastStyle(::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1}) =
72+
Broadcast.DefaultArrayStyle{1}()
73+
74+
# hook into ArrayPartition broadcasting routines
75+
@inline RecursiveArrayTools.npartitions(x::NamedArrayPartition) = npartitions(ArrayPartition(x))
76+
@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) =
77+
Broadcast.Broadcasted(bc.f, RecursiveArrayTools.unpack_args(i, bc.args))
78+
@inline RecursiveArrayTools.unpack(x::NamedArrayPartition, i) = unpack(ArrayPartition(x), i)
79+
80+
Base.copy(A::NamedArrayPartition{T,S,NT}) where {T,S,NT} =
81+
NamedArrayPartition{T,S,NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices))
82+
83+
@inline NamedArrayPartition(f::F, N, names_to_indices) where F<:Function =
84+
NamedArrayPartition(ArrayPartition(ntuple(f, Val(N))), names_to_indices)
85+
86+
@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
87+
N = npartitions(bc)
88+
@inline function f(i)
89+
copy(unpack(bc, i))
90+
end
91+
x = find_NamedArrayPartition(bc)
92+
NamedArrayPartition(f, N, getfield(x, :names_to_indices))
93+
end
94+
95+
@inline function Base.copyto!(dest::NamedArrayPartition,
96+
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
97+
N = npartitions(dest, bc)
98+
@inline function f(i)
99+
copyto!(ArrayPartition(dest).x[i], unpack(bc, i))
100+
end
101+
ntuple(f, Val(N))
102+
return dest
103+
end
104+
105+
# `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments.
106+
find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args)
107+
find_NamedArrayPartition(args::Tuple) =
108+
find_NamedArrayPartition(find_NamedArrayPartition(args[1]), Base.tail(args))
109+
find_NamedArrayPartition(x) = x
110+
find_NamedArrayPartition(::Tuple{}) = nothing
111+
find_NamedArrayPartition(x::NamedArrayPartition, rest) = x
112+
find_NamedArrayPartition(::Any, rest) = find_NamedArrayPartition(rest)
113+
114+
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
@testset "NamedArrayPartition tests" begin
2+
x = NamedArrayPartition(a = ones(10), b = rand(20))
3+
@test typeof(@. sin(x * x^2 / x - 1)) <: NamedArrayPartition
4+
@test typeof(x.^2) <: NamedArrayPartition
5+
@test x.a ones(10)
6+
@test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence
7+
@test all(x .== x[1:end])
8+
y = copy(x)
9+
@test zero(x, (10, 20)) == zero(x) # test that ignoring dims works
10+
@test typeof(zero(x)) <: NamedArrayPartition
11+
@test (y .*= 2).a[1] 2 # test in-place bcast
12+
13+
@test length(Array(x))==30
14+
@test typeof(Array(x)) <: Array
15+
@test propertynames(x) == (:a, :b)
16+
17+
x = NamedArrayPartition(a = ones(1), b = 2*ones(1))
18+
@test Base.summary(x) == string(typeof(x), " with arrays:")
19+
@test (@capture_out Base.show(stdout, MIME"text/plain"(), x)) == "(a = [1.0], b = [2.0])"
20+
21+
using StructArrays
22+
using StaticArrays: SVector
23+
x = NamedArrayPartition(a = StructArray{SVector{2, Float64}}((ones(5), 2*ones(5))),
24+
b = StructArray{SVector{2, Float64}}((3 * ones(2,2), 4*ones(2,2))))
25+
@test typeof(x.a) <: StructVector{<:SVector{2}}
26+
@test typeof(x.b) <: StructArray{<:SVector{2}, 2}
27+
@test typeof((x->x[1]).(x)) <: NamedArrayPartition
28+
@test typeof(map(x->x[1], x)) <: NamedArrayPartition
29+
end
30+
31+
# x = NamedArrayPartition(a = ones(10), b = rand(20))
32+
# x_ap = ArrayPartition(x)
33+
# @btime @. x_ap * x_ap; # 498.836 ns (5 allocations: 2.77 KiB)
34+
# @btime @. x * x; # 2.032 μs (5 allocations: 2.84 KiB) - 5x slower than ArrayPartition

0 commit comments

Comments
 (0)