diff --git a/src/read.jl b/src/read.jl index 0c86c57..3368156 100644 --- a/src/read.jl +++ b/src/read.jl @@ -1,8 +1,8 @@ # Convenience function to allow for things like Array(tp) or CuArray(tp) # Not sure if this counts as type piracy... -(::Type{T})(p::TensorProto) where T = array(p) |> T -(::Type{Ref{T}})(p::TensorProto) where T = array(p) |> T |> Ref +(::Type{T})(p::TensorProto) where T <: AbstractArray = array(p) |> T +(::Type{Ref{T}})(p::TensorProto) where T <: AbstractArray = array(p) |> T |> Ref """ @@ -67,9 +67,9 @@ function attribute(p::AttributeProto) if (p._type != 0) field = [:f, :i, :s, :t, :g, :floats, :ints, :strings, :tensors, :graphs][p._type] if field === :s - return Symbol(p.name) => String(getfield(p, field)) + return Symbol(p.name) => String(copy(getfield(p, field))) elseif field === :strings - return Symbol(p.name) => String.(getfield(p, field)) + return Symbol(p.name) => String.(copy.(getfield(p, field))) end return Symbol(p.name) => getfield(p, field) end