From d318e2ca906044269dc50a55af922afbc88c1f0c Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Wed, 28 Oct 2020 00:21:56 +0100 Subject: [PATCH] Remove constructor hijacking --- src/read.jl | 33 +++++++++++++-------------------- test/readwrite.jl | 10 ++++------ 2 files changed, 17 insertions(+), 26 deletions(-) diff --git a/src/read.jl b/src/read.jl index 0c86c57..47a66bd 100644 --- a/src/read.jl +++ b/src/read.jl @@ -1,53 +1,46 @@ - -# Convenience function to allow for things like Array(tp) or CuArray(tp) -# Not sure if this counts as type piracy... -(::Type{T})(p::TensorProto) where T = array(p) |> T -(::Type{Ref{T}})(p::TensorProto) where T = array(p) |> T |> Ref - - """ - array(p::TensorProto) + array(p::TensorProto, wrap=Array) -Return `p` as an reshaped and reinterpreted array. +Return `p` as an `Array` of the correct type. Second argument can be used to change type of the returned array """ -function array(p::TensorProto) +function array(p::TensorProto, wrap=Array) # Copy pasted from jl # Can probably be cleaned up a bit # TODO: Add missing datatypes... if p.data_type === TensorProto_DataType.INT64 if isdefined(p, :int64_data) && !isempty(p.int64_data) - return reshape(reinterpret(Int64, p.int64_data), reverse(p.dims)...) + return reshape(reinterpret(Int64, p.int64_data), reverse(p.dims)...) |> wrap end - return reshape(reinterpret(Int64, p.raw_data), reverse(p.dims)...) + return reshape(reinterpret(Int64, p.raw_data), reverse(p.dims)...) |> wrap end if p.data_type === TensorProto_DataType.INT32 if isdefined(p, :int32_data) && !isempty(p.int32_data) - return reshape(p.int32_data , reverse(p.dims)...) + return reshape(p.int32_data , reverse(p.dims)...) |> wrap end - return reshape(reinterpret(Int32, p.raw_data), reverse(p.dims)...) + return reshape(reinterpret(Int32, p.raw_data), reverse(p.dims)...) |> wrap end if p.data_type === TensorProto_DataType.INT8 - return reshape(reinterpret(Int8, p.raw_data), reverse(p.dims)...) + return reshape(reinterpret(Int8, p.raw_data), reverse(p.dims)...) |> wrap end if p.data_type === TensorProto_DataType.DOUBLE if isdefined(p, :double_data) && !isempty(p.double_data) - return reshape(p.double_data , reverse(p.dims)...) + return reshape(p.double_data , reverse(p.dims)...) |> wrap end - return reshape(reinterpret(Float64, p.raw_data), reverse(p.dims)...) + return reshape(reinterpret(Float64, p.raw_data), reverse(p.dims)...) |> wrap end if p.data_type === TensorProto_DataType.FLOAT if isdefined(p,:float_data) && !isempty(p.float_data) - return reshape(reinterpret(Float32, p.float_data), reverse(p.dims)...) + return reshape(reinterpret(Float32, p.float_data), reverse(p.dims)...) |> wrap end - return reshape(reinterpret(Float32, p.raw_data), reverse(p.dims)...) + return reshape(reinterpret(Float32, p.raw_data), reverse(p.dims)...) |> wrap end if p.data_type === TensorProto_DataType.FLOAT16 - return reshape(reinterpret(Float16, p.raw_data), reverse(p.dims)...) + return reshape(reinterpret(Float16, p.raw_data), reverse(p.dims)...) |> wrap end end diff --git a/test/readwrite.jl b/test/readwrite.jl index 80024bf..a1b6e12 100644 --- a/test/readwrite.jl +++ b/test/readwrite.jl @@ -7,7 +7,7 @@ end @testset "TensorProto" begin - import BaseOnnx: TensorProto + import BaseOnnx: TensorProto, array @testset "Tensor type $T size $s" for T in (Int8, Int32, Int64, Float16, Float32, Float64), s in ((1,), (1, 2), @@ -15,10 +15,8 @@ (1, 2, 3, 4), (1, 2, 3, 4, 5)) exp = reshape(collect(T, 1:prod(s)), s...) - @test TensorProto(exp) |> serdeser |> Array == exp - + @test TensorProto(exp) |> serdeser |> array == exp end - end @testset "ValueInfo" begin @@ -41,7 +39,7 @@ end @testset "Attribute" begin - import BaseOnnx: AttributeProto, TensorProto, attribute + import BaseOnnx: AttributeProto, TensorProto, attribute, array @testset "Attribute type $(first(p))" for p in ( :Int64 => 12, @@ -62,7 +60,7 @@ @testset "Attribute type TensorProto" begin # TensorProto has undef fields which mess up straigh comparison arr = collect(1:4) - @test AttributeProto(:ff => TensorProto(arr)) |> serdeser |> attribute |> last |> Array == arr + @test AttributeProto(:ff => TensorProto(arr)) |> serdeser |> attribute |> last |> array == arr end @testset "Attribute Dict" begin