Skip to content

Commit 7b2613c

Browse files
refactor: add Base.view for AbstractVectorOfArray
1 parent 22a81f2 commit 7b2613c

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

src/vector_of_array.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,13 @@ function Base.append!(VA::AbstractVectorOfArray{T, N},
445445
end
446446

447447
# AbstractArray methods
448+
function Base.view(A::AbstractVectorOfArray, I::Vararg{Any,M}) where {M}
449+
@inline
450+
J = map(i->Base.unalias(A,i), to_indices(A, I))
451+
@boundscheck checkbounds(A, J...)
452+
SubArray(IndexStyle(A), A, J, Base.index_dimsum(J...))
453+
end
454+
Base.check_parent_index_match(::RecursiveArrayTools.AbstractVectorOfArray{T,N}, ::NTuple{N,Bool}) where {T,N} = nothing
448455
Base.ndims(::AbstractVectorOfArray{T, N}) where {T, N} = N
449456
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
450457
if checkbounds(Bool, VA.u, last(idx))
@@ -456,6 +463,9 @@ function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
456463
end
457464
return false
458465
end
466+
function Base.checkbounds(VA::AbstractVectorOfArray, idx...)
467+
checkbounds(Bool, VA, idx...) || throw(BoundsError(VA, idx))
468+
end
459469

460470
# Operations
461471
function Base.isapprox(A::AbstractVectorOfArray,
@@ -502,8 +512,12 @@ end
502512
# Tools for creating similar objects
503513
Base.eltype(::VectorOfArray{T}) where {T} = T
504514
# TODO: Is there a better way to do this?
505-
@inline function Base.similar(VA::VectorOfArray, args...)
506-
return Base.similar(ones(eltype(VA)), args...)
515+
@inline function Base.similar(VA::AbstractVectorOfArray, args...)
516+
if args[end] isa Type
517+
return Base.similar(eltype(VA)[], args..., size(VA))
518+
else
519+
return Base.similar(eltype(VA)[], args...)
520+
end
507521
end
508522
@inline function Base.similar(VA::VectorOfArray, ::Type{T} = eltype(VA)) where {T}
509523
VectorOfArray([similar(VA[:, i], T) for i in eachindex(VA.u)])

0 commit comments

Comments
 (0)