Skip to content

Commit 58bb091

Browse files
linear index array partitions
1 parent 3583451 commit 58bb091

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

src/array_partition.jl

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,28 @@ Base.:*(A::ArrayPartition, B::Number) = ArrayPartition((x .* B for x in A.x)...)
1919
Base.:/(A::ArrayPartition, B::Number) = ArrayPartition((x ./ B for x in A.x)...)
2020
Base.:\(A::Number, B::ArrayPartition) = ArrayPartition((x ./ A for x in B.x)...)
2121

22-
Base.getindex( A::ArrayPartition, i::Int) = ArrayPartition((x[i] for x in A.x)...)
23-
Base.setindex!(A::ArrayPartition, v, i::Int) = ArrayPartition((x[i]=v for x in A.x)...)
24-
Base.getindex( A::ArrayPartition, i::Int...) = ArrayPartition((x[i...] for x in A.x)...)
25-
Base.setindex!(A::ArrayPartition, v, i::Int...) = ArrayPartition((x[i...]=v for x in A.x)...)
22+
@inline function Base.getindex( A::ArrayPartition,i::Int)
23+
@boundscheck i > length(A) && throw(BoundsError("Index out of bounds"))
24+
@inbounds for j in 1:length(A.x)
25+
i -= length(A.x[j])
26+
if i <= 0
27+
return A.x[j][length(A.x[j])+i]
28+
end
29+
end
30+
end
31+
Base.getindex( A::ArrayPartition,::Colon) = [A[i] for i in 1:length(A)]
32+
@inline function Base.setindex!(A::ArrayPartition, v, i::Int)
33+
@boundscheck i > length(A) && throw(BoundsError("Index out of bounds"))
34+
@inbounds for j in 1:length(A.x)
35+
i -= length(A.x[j])
36+
if i <= 0
37+
A.x[j][length(A.x[j])+i] = v
38+
break
39+
end
40+
end
41+
end
42+
Base.getindex( A::ArrayPartition, i::Int...) = A.x[i[1]][Base.tail(i)...]
43+
Base.setindex!(A::ArrayPartition, v, i::Int...) = A.x[i[1]][Base.tail(i)...]=v
2644

2745
function recursivecopy!(A::ArrayPartition,B::ArrayPartition)
2846
for (a,b) in zip(A.x,B.x)
@@ -38,6 +56,7 @@ Base.start(A::ArrayPartition) = chain(A.x...)
3856
Base.next(iter::ArrayPartition,state) = next(state,state)
3957
Base.done(iter::ArrayPartition,state) = done(state,state)
4058

41-
Base.length(A::ArrayPartition) = ((length(x) for x in A.x)...)
59+
Base.length(A::ArrayPartition) = sum((length(x) for x in A.x))
60+
Base.size(A::ArrayPartition) = (length(A),)
4261
Base.indices(A::ArrayPartition) = ((indices(x) for x in A.x)...)
4362
Base.eachindex(A::ArrayPartition) = ((indices(x) for x in A.x)...)

test/partitions_test.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ using RecursiveArrayTools, Base.Test
22

33
A = (rand(5),rand(5))
44
p = ArrayPartition(A)
5+
@test (p.x[1][1],p.x[2][1]) == (p[1],p[6])
56

6-
@test (p[1].x[1],p[1].x[2]) == (p.x[1][1],p.x[2][1])
77
p2 = similar(p)
88
p2[1] = 1
99
@test p2.x[1] != p.x[1]
1010

1111
C = rand(10)
12-
p3 = similar(p,indices(C))
13-
@test length(p3.x[1]) == length(p3.x[2]) == 10
12+
p3 = similar(p,indices(p))
13+
@test length(p3.x[1]) == length(p3.x[2]) == 5
1414
@test length(p.x) == length(p2.x) == length(p3.x) == 2

0 commit comments

Comments
 (0)