Skip to content

Commit 351594c

Browse files
fix broadcast allocations
1 parent 53bcecb commit 351594c

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

src/array_partition.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,15 +230,15 @@ ArrayPartitionStyle(::S, ::Val{N}) where {S,N} = ArrayPartitionStyle(S(Val(N)))
230230
ArrayPartitionStyle(::Val{N}) where N = ArrayPartitionStyle{Broadcast.DefaultArrayStyle{N}}()
231231

232232
# promotion rules
233-
function Broadcast.BroadcastStyle(::ArrayPartitionStyle{AStyle}, ::ArrayPartitionStyle{BStyle}) where {AStyle, BStyle}
233+
@inline function Broadcast.BroadcastStyle(::ArrayPartitionStyle{AStyle}, ::ArrayPartitionStyle{BStyle}) where {AStyle, BStyle}
234234
ArrayPartitionStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
235235
end
236-
Broadcast.BroadcastStyle(::ArrayPartitionStyle, ::Broadcast.DefaultArrayStyle{0}) = Broadcast.DefaultArrayStyle{1}()
236+
Broadcast.BroadcastStyle(::ArrayPartitionStyle{Style}, ::Broadcast.DefaultArrayStyle{0}) where Style = ArrayPartitionStyle{Style}()
237237
Broadcast.BroadcastStyle(::ArrayPartitionStyle, ::Broadcast.DefaultArrayStyle{N}) where N = Broadcast.DefaultArrayStyle{N}()
238238

239239
combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
240-
combine_styles(args::Tuple{Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
241-
combine_styles(args::Tuple{Any, Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), Broadcast.BroadcastStyle(args[2]))
240+
@inline combine_styles(args::Tuple{Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
241+
@inline combine_styles(args::Tuple{Any, Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), Broadcast.BroadcastStyle(args[2]))
242242
@inline combine_styles(args::Tuple) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), combine_styles(Base.tail(args)))
243243

244244
function Broadcast.BroadcastStyle(::Type{ArrayPartition{T,S}}) where {T, S}

test/partitions_test.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,17 @@ for sizes in S
119119
@test all([x[i] == y[i] for i in eachindex(x)])
120120
@test all([x[i] == y_array[i] for i in eachindex(x)])
121121
end
122+
123+
# Non-allocating broadcast
124+
xce0 = ArrayPartition(zeros(2),[0.])
125+
xcde0 = copy(xce0)
126+
function foo(y, x)
127+
y .= y .+ x
128+
end
129+
foo(xcde0, xce0)
130+
@test 0 == @allocated foo(xcde0, xce0)
131+
function foo(y, x)
132+
y .= y .+ 2 .* x
133+
end
134+
foo(xcde0, xce0)
135+
@test 0 == @allocated foo(xcde0, xce0)

0 commit comments

Comments
 (0)