diff --git a/src/Nonlinear/ReverseAD/utils.jl b/src/Nonlinear/ReverseAD/utils.jl index 231eb173b8..1b92fcd33b 100644 --- a/src/Nonlinear/ReverseAD/utils.jl +++ b/src/Nonlinear/ReverseAD/utils.jl @@ -164,9 +164,43 @@ function _UnsafeLowerTriangularMatrixView(x::Vector{Float64}, N::Int) return _UnsafeLowerTriangularMatrixView(N, pointer(x)) end +""" + _reinterpret_unsafe(::Type{T}, x::Vector{R}) where {T,R} + +Return an `_UnsafeVectorView` that act as a vector of element type +`T` over the same bytes as `x`. Note that if `length(x) * sizeof(R)` is not +a multiple of `sizeof(T)`, the last bits will be ignored. This is a key +difference with `reinterpret` which errors in that case. + +Given a vector of `Float64` of length equal to the maximum number of nodes of a +set of expressions time the maximum chunk size, this function is used to +reinterpret it as a vector of `ForwardDiff.Partials{N,T}` where `N` is the +chunk size of one of the expressions of the set. In that case, we know that +the vector has enough bytes and we don't care about the leftover bytes at the +end. + +## Examples + +```jldoctest +julia> import MathOptInterface as MOI + +julia> x = [(1, 2, 3), (4, 5, 6), (7, 8, 9)] +3-element Vector{Tuple{Int64, Int64, Int64}}: + (1, 2, 3) + (4, 5, 6) + (7, 8, 9) + +julia> MOI.Nonlinear.ReverseAD._reinterpret_unsafe(NTuple{2,Int}, x) +4-element MathOptInterface.Nonlinear.ReverseAD._UnsafeVectorView{Tuple{Int64, Int64}}: + (1, 2) + (3, 4) + (5, 6) + (7, 8) +``` +""" function _reinterpret_unsafe(::Type{T}, x::Vector{R}) where {T,R} - # how many T's fit into x? @assert isbitstype(T) && isbitstype(R) + # how many T's fit into x? len = length(x) * sizeof(R) p = reinterpret(Ptr{T}, pointer(x)) return _UnsafeVectorView(0, div(len, sizeof(T)), p)