@@ -233,6 +233,8 @@ ArrayPartitionStyle(::Val{N}) where N = ArrayPartitionStyle{Broadcast.DefaultArr
233233function Broadcast. BroadcastStyle (:: ArrayPartitionStyle{AStyle} , :: ArrayPartitionStyle{BStyle} ) where {AStyle, BStyle}
234234 ArrayPartitionStyle (Broadcast. BroadcastStyle (AStyle (), BStyle ()))
235235end
236+ Broadcast. BroadcastStyle (:: ArrayPartitionStyle , :: Broadcast.DefaultArrayStyle{0} ) = Broadcast. DefaultArrayStyle {1} ()
237+ Broadcast. BroadcastStyle (:: ArrayPartitionStyle , :: Broadcast.DefaultArrayStyle{N} ) where N = Broadcast. DefaultArrayStyle {N} ()
236238
237239combine_styles (args:: Tuple{} ) = Broadcast. DefaultArrayStyle {0} ()
238240combine_styles (args:: Tuple{Any} ) = Broadcast. result_style (Broadcast. BroadcastStyle (args[1 ]))
252254 ArrayPartition (f, N)
253255end
254256
255- @inline function Base. copyto! (dest:: ArrayPartition , bc:: Broadcast.Broadcasted )
257+ @inline function Base. copyto! (dest:: ArrayPartition , bc:: Broadcast.Broadcasted{ArrayPartitionStyle{Style}} ) where Style
256258 N = npartitions (dest, bc)
257259 for i in 1 : N
258260 copyto! (dest. x[i], unpack (bc, i))
@@ -293,3 +295,10 @@ common_number(a, b) =
293295 (b == 0 ? a :
294296 (a == b ? a :
295297 throw (DimensionMismatch (" number of partitions must be equal" ))))
298+
299+ # # Linear Algebra
300+
301+ ArrayInterface. zeromatrix (A:: ArrayPartition ) = ArrayInterface. zeromatrix (reduce (vcat,vec .(A. x)))
302+ LinearAlgebra. ldiv! (A:: LinearAlgebra.LU ,b:: ArrayPartition ) = ldiv! (A,Array (b))
303+ LinearAlgebra. ldiv! (A:: LinearAlgebra.QR ,b:: ArrayPartition ) = ldiv! (A,Array (b))
304+ LinearAlgebra. ldiv! (A:: LinearAlgebra.SVD ,b:: ArrayPartition ) = ldiv! (A,Array (b))
0 commit comments