diff --git a/Project.toml b/Project.toml index 3a2a4a3..e274282 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Static" uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" authors = ["chriselrod", "ChrisRackauckas", "Tokazama"] -version = "0.7.6" +version = "0.7.7" [deps] IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" diff --git a/src/Static.jl b/src/Static.jl index 5a557f7..84f38fe 100644 --- a/src/Static.jl +++ b/src/Static.jl @@ -437,15 +437,6 @@ Base.real(@nospecialize(x::StaticNumber)) = x Base.real(@nospecialize(T::Type{<:StaticNumber})) = eltype(T) Base.imag(@nospecialize(x::StaticNumber)) = zero(x) -""" - field_type(::Type{T}, f) - -Functionally equivalent to `fieldtype(T, f)` except `f` may be a static type. -""" -@inline field_type(T::Type, f::Union{Int, Symbol}) = fieldtype(T, f) -@inline field_type(::Type{T}, ::StaticInt{N}) where {T, N} = fieldtype(T, N) -@inline field_type(::Type{T}, ::StaticSymbol{S}) where {T, S} = fieldtype(T, S) - Base.rad2deg(::StaticFloat64{M}) where {M} = StaticFloat64(rad2deg(M)) Base.deg2rad(::StaticFloat64{M}) where {M} = StaticFloat64(deg2rad(M)) @generated Base.cbrt(::StaticFloat64{M}) where {M} = StaticFloat64(cbrt(M)) @@ -939,4 +930,39 @@ function Base.show(io::IO, m::MIME"text/plain", @nospecialize(x::NDIndex)) show(io, m, Tuple(x)) end +# field and property accessors +""" + field_type(::Type{T}, f) + +Functionally equivalent to `fieldtype(T, f)` except `f` may be a static type. +""" +@inline field_type(T::Type, f::Union{Int, Symbol}) = fieldtype(T, f) +@inline field_type(::Type{T}, ::StaticInt{N}) where {T, N} = fieldtype(T, N) +@inline field_type(::Type{T}, ::StaticSymbol{S}) where {T, S} = fieldtype(T, S) + +function (::Base.Fix2{typeof(getfield), <:Union{StaticSymbol{f}, StaticInt{f}}})(x) where {f + } + getfield(x, f) +end +function (::Base.Fix2{typeof(fieldtype), <:Union{StaticSymbol{f}, StaticInt{f}}})(x) where { + f + } + fieldtype(x, f) +end + +Base.getproperty(x, ::StaticSymbol{S}) where {S} = getproperty(x, S) +Base.setproperty!(x, ::StaticSymbol{S}, v) where {S} = setproperty!(x, S, v) +Base.hasproperty(x, ::StaticSymbol{S}) where {S} = hasproperty(x, S) + +Base.getindex(nt::NamedTuple, ::StaticSymbol{S}) where {S} = getfield(nt, S) +function Base.getindex(nt::NamedTuple, idxs::Tuple{<:StaticSymbol, Vararg{<:StaticSymbol}}) + NamedTuple{known(idxs)}(nt) +end +function Base.setindex(nt::NamedTuple, v, ::StaticSymbol{S}) where {S} + merge(nt, NamedTuple{(S,)}((v,))) +end +function Base.setindex(nt::NamedTuple, vs, idxs::Tuple{Vararg{<:StaticSymbol}}) + merge(nt, NamedTuple{known(idxs)}((vs...,))) +end + end diff --git a/test/runtests.jl b/test/runtests.jl index 502ee35..346b828 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,6 +16,22 @@ Aqua.test_all(Static) @test @inferred(StaticSymbol(x, y, z)) === static(:xy1) @test @inferred(static(nothing)) === nothing @test_throws ErrorException static([]) + nt = (x = 1, y = 2) + @test nt[static(:x)] === 1 + @test hasproperty(nt, static(:x)) + @test Base.setindex((a = 1, b = 2, c = 3), 4, static(:b)) == + (a = 1, b = 4, c = 3) + @test Base.setindex((a = 1, b = 2, c = 3), (4, 5), (static(:b), static(:d))) == + (a = 1, b = 4, c = 3, d = 5) + @test getindex((a = 1, b = 2, c = 3), (static(:b), static(:c))) == + (b = 2, c = 3) + mutable struct Foo + x::Int + end + f = Foo(3) + @test getproperty(f, static(:x)) === 3 + setproperty!(f, static(:x), 4) + @test Base.Fix2(getfield, static(:x))(f) === 4 end @testset "StaticInt" begin @@ -314,6 +330,7 @@ end @test @inferred(Static.permute(x, y)) === y @test @inferred(Static.eachop(getindex, x)) === x + @test Base.Fix2(fieldtype, static(:x))(typeof((x = 1, y = 2))) <: Int @test Static.field_type(typeof((x = 1, y = 2)), :x) <: Int @test Static.field_type(typeof((x = 1, y = 2)), static(:x)) <: Int function get_tuple_add(::Type{T}, ::Type{X}, dim::StaticInt) where {T, X}