Skip to content

Commit 85e93de

Browse files
authored
Merge pull request #69 from SciML/static
Added Static integers.
2 parents c83328b + 581b18a commit 85e93de

File tree

6 files changed

+183
-38
lines changed

6 files changed

+183
-38
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Requires = "0.5, 1.0"
1212
julia = "1.2"
1313

1414
[extras]
15+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
1516
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
1617
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
1718
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
@@ -21,4 +22,4 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
2122
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2223

2324
[targets]
24-
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random"]
25+
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "Aqua"]

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,14 @@ Otherwise, returns `nothing`. For example, `known_step(UnitRange{Int})` returns
134134
If `length` of an instance of type `T` is known at compile time, return it.
135135
Otherwise, return `nothing`.
136136

137+
## Static(N::Int)
138+
139+
Creates a static integer with value known at compile time. It is a number,
140+
supporting basic arithmetic. Many operations with two `Static` integers
141+
will produce another `Static` integer. If one of the arguments to a
142+
function call isn't static (e.g., `Static(4) + 3`) then the `Static`
143+
number will promote to a dynamic value.
144+
137145
# List of things to add
138146

139147
- https://github.com/JuliaLang/julia/issues/22216

src/ArrayInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,7 @@ function __init__()
699699
end
700700
end
701701

702+
include("static.jl")
702703
include("ranges.jl")
703704

704705
end

src/ranges.jl

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)
4343
# add methods to support ArrayInterface
4444

4545
_get(x) = x
46-
_get(::Val{V}) where {V} = V
46+
_get(::Static{V}) where {V} = V
47+
_get(::Type{Static{V}}) where {V} = V
4748
_convert(::Type{T}, x) where {T} = convert(T, x)
4849
_convert(::Type{T}, ::Val{V}) where {T,V} = Val(convert(T, V))
4950

@@ -56,7 +57,7 @@ at compile time. An `OptionallyStaticUnitRange` is intended to be constructed in
5657
from other valid indices. Therefore, users should not expect the same checks are used
5758
to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`.
5859
"""
59-
struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T}
60+
struct OptionallyStaticUnitRange{T <: Integer, F <: Integer, L <: Integer} <: AbstractUnitRange{T}
6061
start::F
6162
stop::L
6263

@@ -79,28 +80,26 @@ struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T}
7980

8081
function OptionallyStaticUnitRange(x::AbstractRange)
8182
if step(x) == 1
82-
fst = known_first(x)
83-
fst = fst === nothing ? first(x) : Val(fst)
84-
lst = known_last(x)
85-
lst = lst === nothing ? last(x) : Val(lst)
83+
fst = static_first(x)
84+
lst = static_last(x)
8685
return OptionallyStaticUnitRange(fst, lst)
8786
else
8887
throw(ArgumentError("step must be 1, got $(step(r))"))
8988
end
9089
end
9190
end
9291

93-
Base.first(r::OptionallyStaticUnitRange{<:Any,Val{F}}) where {F} = F
94-
Base.first(r::OptionallyStaticUnitRange{<:Any,<:Any}) = r.start
92+
Base.:(:)(L::Integer, ::Static{U}) where {U} = OptionallyStaticUnitRange(L, Static(U))
93+
Base.:(:)(::Static{L}, U::Integer) where {L} = OptionallyStaticUnitRange(Static(L), U)
94+
Base.:(:)(::Static{L}, ::Static{U}) where {L,U} = OptionallyStaticUnitRange(Static(L), Static(U))
9595

96+
Base.first(r::OptionallyStaticUnitRange) = r.start
9697
Base.step(r::OptionallyStaticUnitRange{T}) where {T} = oneunit(T)
98+
Base.last(r::OptionallyStaticUnitRange) = r.stop
9799

98-
Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}) where {L} = L
99-
Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,<:Any}) = r.stop
100-
101-
known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Val{F}}}) where {F} = F
100+
known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Static{F}}}) where {F} = F
102101
known_step(::Type{<:OptionallyStaticUnitRange{T}}) where {T} = one(T)
103-
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}}) where {L} = L
102+
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Static{L}}}) where {L} = L
104103

105104
function Base.isempty(r::OptionallyStaticUnitRange)
106105
if known_first(r) === oneunit(eltype(r))
@@ -141,10 +140,20 @@ end
141140
return convert(eltype(r), val)
142141
end
143142

144-
_try_static(x, y) = Val(x)
145-
_try_static(::Nothing, y) = Val(y)
146-
_try_static(x, ::Nothing) = Val(x)
147-
_try_static(::Nothing, ::Nothing) = nothing
143+
@inline _try_static(::Static{N}, ::Static{N}) where {N} = Static{N}()
144+
@inline _try_static(::Static{M}, ::Static{N}) where {M, N} = @assert false "Unequal Indices: Static{$M}() != Static{$N}()"
145+
function _try_static(::Static{N}, x) where {N}
146+
@assert N == x "Unequal Indices: Static{$N}() != x == $x"
147+
Static{N}()
148+
end
149+
function _try_static(x, ::Static{N}) where {N}
150+
@assert N == x "Unequal Indices: x == $x != Static{$N}()"
151+
Static{N}()
152+
end
153+
function _try_static(x, y)
154+
@assert x == y "Unequal Indicess: x == $x != $y == y"
155+
x
156+
end
148157

149158
###
150159
### length
@@ -193,7 +202,7 @@ specified then indices for visiting each index of `x` is returned.
193202
"""
194203
@inline function indices(x)
195204
inds = eachindex(x)
196-
if inds isa AbstractUnitRange{<:Integer}
205+
if inds isa AbstractUnitRange && eltype(inds) <: Integer
197206
return Base.Slice(OptionallyStaticUnitRange(inds))
198207
else
199208
return inds
@@ -202,30 +211,24 @@ end
202211

203212
function indices(x::Tuple)
204213
inds = map(eachindex, x)
205-
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
206214
return reduce(_pick_range, inds)
207215
end
208216

209-
indices(x, d) = indices(axes(x, d))
217+
@inline indices(x, d) = indices(axes(x, d))
210218

211-
@inline function indices(x::NTuple{N,<:Any}, dim) where {N}
219+
@inline function indices(x::Tuple{Vararg{Any,N}}, dim) where {N}
212220
inds = map(x_i -> indices(x_i, dim), x)
213-
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
214221
return reduce(_pick_range, inds)
215222
end
216223

217-
@inline function indices(x::NTuple{N,<:Any}, dim::NTuple{N,<:Any}) where {N}
224+
@inline function indices(x::Tuple{Vararg{Any,N}}, dim::Tuple{Vararg{Any,N}}) where {N}
218225
inds = map(indices, x, dim)
219-
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
220226
return reduce(_pick_range, inds)
221227
end
222228

223229
@inline function _pick_range(x, y)
224-
fst = _try_static(known_first(x), known_first(y))
225-
fst = fst === nothing ? first(x) : fst
226-
227-
lst = _try_static(known_last(x), known_last(y))
228-
lst = lst === nothing ? last(x) : lst
230+
fst = _try_static(static_first(x), static_first(y))
231+
lst = _try_static(static_last(x), static_last(y))
229232
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
230233
end
231234

src/static.jl

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
2+
"""
3+
A statically sized `Int`.
4+
Use `Static(N)` instead of `Val(N)` when you want it to behave like a number.
5+
"""
6+
struct Static{N} <: Integer
7+
Static{N}() where {N} = new{N::Int}()
8+
end
9+
Base.@pure Static(N::Int) = Static{N}()
10+
Static(N::Integer) = Static(convert(Int, N))
11+
Static(::Static{N}) where {N} = Static{N}()
12+
Static(::Val{N}) where {N} = Static{N}()
13+
Base.Val(::Static{N}) where {N} = Val{N}()
14+
Base.convert(::Type{T}, ::Static{N}) where {T<:Number,N} = convert(T, N)
15+
Base.convert(::Type{Static{N}}, ::Static{N}) where {N} = Static{N}()
16+
17+
Base.promote_rule(::Type{<:Static}, ::Type{T}) where {T <: AbstractIrrational} = promote_rule(Int, T)
18+
Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T <: AbstractIrrational} = promote_rule(T, Int)
19+
for (S,T) [(:Complex,:Real), (:Rational, :Integer), (:(Base.TwicePrecision),:Any)]
20+
@eval Base.promote_rule(::Type{$S{T}}, ::Type{<:Static}) where {T <: $T} = promote_rule($S{T}, Int)
21+
end
22+
Base.promote_rule(::Type{Union{Nothing,Missing}}, ::Type{<:Static}) = Union{Nothing, Missing, Int}
23+
Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Union{Missing,Nothing}} = promote_rule(T, Int)
24+
Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Nothing} = promote_rule(T, Int)
25+
Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Missing} = promote_rule(T, Int)
26+
for T [:Bool, :Missing, :BigFloat, :BigInt, :Nothing, :Any]
27+
# let S = :Any
28+
@eval begin
29+
Base.promote_rule(::Type{S}, ::Type{$T}) where {S <: Static} = promote_rule(Int, $T)
30+
Base.promote_rule(::Type{$T}, ::Type{S}) where {S <: Static} = promote_rule($T, Int)
31+
end
32+
end
33+
Base.promote_rule(::Type{<:Static}, ::Type{<:Static}) = Int
34+
Base.:(%)(::Static{N}, ::Type{Integer}) where {N} = N
35+
36+
Base.iszero(::Static{0}) = true
37+
Base.iszero(::Static) = false
38+
Base.isone(::Static{1}) = true
39+
Base.isone(::Static) = false
40+
41+
for T = [:Real, :Rational, :Integer]
42+
@eval begin
43+
@inline Base.:(+)(i::$T, ::Static{0}) = i
44+
@inline Base.:(+)(i::$T, ::Static{M}) where {M} = i + M
45+
@inline Base.:(+)(::Static{0}, i::$T) = i
46+
@inline Base.:(+)(::Static{M}, i::$T) where {M} = M + i
47+
@inline Base.:(-)(i::$T, ::Static{0}) = i
48+
@inline Base.:(-)(i::$T, ::Static{M}) where {M} = i - M
49+
@inline Base.:(*)(i::$T, ::Static{0}) = Static{0}()
50+
@inline Base.:(*)(i::$T, ::Static{1}) = i
51+
@inline Base.:(*)(i::$T, ::Static{M}) where {M} = i * M
52+
@inline Base.:(*)(::Static{0}, i::$T) = Static{0}()
53+
@inline Base.:(*)(::Static{1}, i::$T) = i
54+
@inline Base.:(*)(::Static{M}, i::$T) where {M} = M * i
55+
end
56+
end
57+
@inline Base.:(+)(::Static{0}, ::Static{0}) = Static{0}()
58+
@inline Base.:(+)(::Static{0}, ::Static{M}) where {M} = Static{M}()
59+
@inline Base.:(+)(::Static{M}, ::Static{0}) where {M} = Static{M}()
60+
61+
@inline Base.:(-)(::Static{M}, ::Static{0}) where {M} = Static{M}()
62+
63+
@inline Base.:(*)(::Static{0}, ::Static{0}) = Static{0}()
64+
@inline Base.:(*)(::Static{1}, ::Static{0}) = Static{0}()
65+
@inline Base.:(*)(::Static{0}, ::Static{1}) = Static{0}()
66+
@inline Base.:(*)(::Static{1}, ::Static{1}) = Static{1}()
67+
@inline Base.:(*)(::Static{M}, ::Static{0}) where {M} = Static{0}()
68+
@inline Base.:(*)(::Static{0}, ::Static{M}) where {M} = Static{0}()
69+
@inline Base.:(*)(::Static{M}, ::Static{1}) where {M} = Static{M}()
70+
@inline Base.:(*)(::Static{1}, ::Static{M}) where {M} = Static{M}()
71+
for f [:(+), :(-), :(*), :(/), :(÷), :(%), :(<<), :(>>), :(>>>), :(&), :(|), :()]
72+
@eval @generated Base.$f(::Static{M}, ::Static{N}) where {M,N} = Expr(:call, Expr(:curly, :Static, $f(M, N)))
73+
end
74+
for f [:(==), :(!=), :(<), :(), :(>), :()]
75+
@eval begin
76+
@inline Base.$f(::Static{M}, ::Static{N}) where {M,N} = $f(M, N)
77+
@inline Base.$f(::Static{M}, x::Int) where {M} = $f(M, x)
78+
@inline Base.$f(x::Int, ::Static{M}) where {M} = $f(x, M)
79+
end
80+
end
81+
82+
@inline function maybe_static(f::F, g::G, x) where {F, G}
83+
L = f(x)
84+
isnothing(L) ? g(x) : Static(L)
85+
end
86+
@inline static_length(x) = maybe_static(known_length, length, x)
87+
@inline static_first(x) = maybe_static(known_first, first, x)
88+
@inline static_last(x) = maybe_static(known_last, last, x)
89+
@inline static_step(x) = maybe_static(known_step, step, x)
90+

test/runtests.jl

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
using ArrayInterface, Test
22
using Base: setindex
3-
import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance
3+
import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance, Static
44
@test ArrayInterface.ismutable(rand(3))
55

6+
using Aqua
7+
Aqua.test_all(ArrayInterface)
8+
69
using StaticArrays
710
x = @SVector [1,2,3]
811
@test ArrayInterface.ismutable(x) == false
@@ -220,12 +223,51 @@ end
220223
end
221224

222225
@testset "indices" begin
223-
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)))) == 1:6
224-
@test @inferred(ArrayInterface.indices(ones(2, 3))) == 1:6
225-
@test @inferred(ArrayInterface.indices(ones(2, 3), 1)) == 1:2
226-
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)), (1, 2))) == 1:2
227-
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(2, 3)), 1)) == 1:2
228-
@test_throws AssertionError ArrayInterface.indices((ones(2, 3), ones(3, 3)), 1)
229-
@test_throws AssertionError ArrayInterface.indices((ones(2, 3), ones(3, 3)), (1, 2))
226+
A23 = ones(2,3); SA23 = @SMatrix ones(2,3);
227+
A32 = ones(3,2); SA32 = @SMatrix ones(3,2);
228+
@test @inferred(ArrayInterface.indices((A23, A32))) == 1:6
229+
@test @inferred(ArrayInterface.indices((SA23, A32))) == 1:6
230+
@test @inferred(ArrayInterface.indices((A23, SA32))) == 1:6
231+
@test @inferred(ArrayInterface.indices((SA23, SA32))) == 1:6
232+
@test @inferred(ArrayInterface.indices(A23)) == 1:6
233+
@test @inferred(ArrayInterface.indices(SA23)) == 1:6
234+
@test @inferred(ArrayInterface.indices(A23, 1)) == 1:2
235+
@test @inferred(ArrayInterface.indices(SA23, Static(1))) === Base.Slice(Static(1):Static(2))
236+
@test @inferred(ArrayInterface.indices((A23, A32), (1, 2))) == 1:2
237+
@test @inferred(ArrayInterface.indices((SA23, A32), (Static(1), 2))) === Base.Slice(Static(1):Static(2))
238+
@test @inferred(ArrayInterface.indices((A23, SA32), (1, Static(2)))) === Base.Slice(Static(1):Static(2))
239+
@test @inferred(ArrayInterface.indices((SA23, SA32), (Static(1), Static(2)))) === Base.Slice(Static(1):Static(2))
240+
@test @inferred(ArrayInterface.indices((A23, A23), 1)) == 1:2
241+
@test @inferred(ArrayInterface.indices((SA23, SA23), Static(1))) === Base.Slice(Static(1):Static(2))
242+
@test @inferred(ArrayInterface.indices((SA23, A23), Static(1))) === Base.Slice(Static(1):Static(2))
243+
@test @inferred(ArrayInterface.indices((A23, SA23), Static(1))) === Base.Slice(Static(1):Static(2))
244+
@test @inferred(ArrayInterface.indices((SA23, SA23), Static(1))) === Base.Slice(Static(1):Static(2))
245+
@test_throws AssertionError ArrayInterface.indices((A23, ones(3, 3)), 1)
246+
@test_throws AssertionError ArrayInterface.indices((A23, ones(3, 3)), (1, 2))
247+
@test_throws AssertionError ArrayInterface.indices((SA23, ones(3, 3)), Static(1))
248+
@test_throws AssertionError ArrayInterface.indices((SA23, ones(3, 3)), (Static(1), 2))
249+
@test_throws AssertionError ArrayInterface.indices((SA23, SA23), (Static(1), Static(2)))
250+
end
251+
252+
@testset "Static" begin
253+
@test iszero(Static(0))
254+
@test !iszero(Static(1))
255+
# test for ambiguities and correctness
256+
for i [Static(0), Static(1), Static(2), 3]
257+
for j [Static(0), Static(1), Static(2), 3]
258+
i === j === 3 && continue
259+
for f [+, -, *, ÷, %, <<, >>, >>>, &, |, , ==, , ]
260+
(iszero(j) && ((f === ÷) || (f === %))) && continue # integer division error
261+
@test convert(Int, @inferred(f(i,j))) == f(convert(Int, i), convert(Int, j))
262+
end
263+
end
264+
i == 3 && break
265+
for f [+, -, *, /, ÷, %, ==, , ]
266+
x = f(convert(Int, i), 1.4)
267+
y = f(1.4, convert(Int, i))
268+
@test convert(typeof(x), @inferred(f(i, 1.4))) === x
269+
@test convert(typeof(y), @inferred(f(1.4, i))) === y # if f is division and i === Static(0), returns `NaN`; hence use of ==== in check.
270+
end
271+
end
230272
end
231273

0 commit comments

Comments
 (0)