Skip to content

Commit fe4dc89

Browse files
committed
fix strides for reshaped views of abstract vectors, cf. #160
1 parent 6a8efa6 commit fe4dc89

File tree

2 files changed

+54
-20
lines changed

2 files changed

+54
-20
lines changed

src/stridelayout.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ end
8888
contiguous_axis(::Type{T}) -> StaticInt{N}
8989
9090
Returns the axis of an array of type `T` containing contiguous data.
91-
If no axis is contiguous, it returns `StaticInt{-1}`.
91+
If no axis is contiguous, it returns a `StaticInt{-1}`.
9292
If unknown, it returns `nothing`.
9393
"""
9494
contiguous_axis(x) = contiguous_axis(typeof(x))
@@ -297,7 +297,7 @@ contiguous_batch_size(::Type{<:Base.ReinterpretArray{T,N,S,A}}) where {T,N,S,A}
297297
"""
298298
is_column_major(A) -> True/False
299299
300-
Returns `Val{true}` if elements of `A` are stored in column major order. Otherwise returns `Val{false}`.
300+
Returns `True()` if elements of `A` are stored in column major order. Otherwise returns `False()`.
301301
"""
302302
is_column_major(A) = is_column_major(stride_rank(A), contiguous_batch_size(A))
303303
is_column_major(sr::Nothing, cbs) = False()
@@ -310,10 +310,11 @@ _is_column_major(sr::R, cbs::StaticInt) where {R} = False()
310310
_is_column_major(sr::R, cbs::Union{StaticInt{0},StaticInt{-1}}) where {R} = is_increasing(sr)
311311

312312
"""
313-
dense_dims(::Type{T}) -> NTuple{N,Bool}
313+
dense_dims(::Type{<:AbstractArray{N}}) -> NTuple{N,Bool}
314314
315315
Returns a tuple of indicators for whether each axis is dense.
316-
An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)` where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`.
316+
An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)`
317+
where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`.
317318
"""
318319
dense_dims(x) = dense_dims(typeof(x))
319320
function dense_dims(::Type{T}) where {T}
@@ -359,7 +360,7 @@ end
359360
if VERSION v"1.6.0-DEV.1581"
360361
@inline function dense_dims(::Type{A}) where {NB, NA, B <: AbstractArray{<:Any,NB},A<: Base.ReinterpretArray{<:Any, NA, <:Any, B, true}}
361362
ddb = dense_dims(B)
362-
IfElse.ifelse(Static.le(StaticInt(NB), StaticInt(NA)), (True(), ddb...), Base.tail(ddb))
363+
IfElse.ifelse(Static.le(StaticInt(NB), StaticInt(NA)), (True(), ddb...), Base.tail(ddb))
363364
end
364365
end
365366

@@ -464,13 +465,16 @@ julia> A = rand(3,4);
464465
465466
julia> ArrayInterface.strides(A)
466467
(static(1), 3)
468+
```
467469
468470
Additionally, the behavior differs from `Base.strides` for adjoint vectors:
469471
472+
```julia
470473
julia> x = rand(5);
471474
472475
julia> ArrayInterface.strides(x')
473476
(static(1), static(1))
477+
```
474478
475479
This is to support the pattern of using just the first stride for linear indexing, `x[i]`,
476480
while still producing correct behavior when using valid cartesian indices, such as `x[1,i]`.
@@ -485,6 +489,17 @@ function strides(x)
485489
return Base.strides(x)
486490
end
487491
end
492+
493+
# Fixes the example of https://github.com/JuliaArrays/ArrayInterface.jl/issues/160
494+
# TODO: Should be generalized to reshaped arrays wrapping more general array types
495+
function strides(A::ReshapedArray{T,N,P}) where {T, N, P<:AbstractVector}
496+
if defines_strides(A)
497+
return size_to_strides(size(A), first(strides(parent(A))))
498+
else
499+
return Base.strides(A)
500+
end
501+
end
502+
488503
@inline bmap(f::F, t::Tuple{}, x::Number) where {F} = ()
489504
@inline bmap(f::F, t::Tuple{T}, x::Number) where {F, T} = (f(first(t),x), )
490505
@inline bmap(f::F, t::Tuple, x::Number) where {F} = (f(first(t),x), bmap(f, Base.tail(t), x)...)

test/runtests.jl

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -207,22 +207,22 @@ end
207207
@testset "Range Interface" begin
208208
@testset "Range Constructors" begin
209209
@test @inferred(StaticInt(1):StaticInt(10)) == 1:10
210-
@test @inferred(StaticInt(1):StaticInt(2):StaticInt(10)) == 1:2:10
210+
@test @inferred(StaticInt(1):StaticInt(2):StaticInt(10)) == 1:2:10
211211
@test @inferred(1:StaticInt(2):StaticInt(10)) == 1:2:10
212212
@test @inferred(StaticInt(1):StaticInt(2):10) == 1:2:10
213-
@test @inferred(StaticInt(1):2:StaticInt(10)) == 1:2:10
213+
@test @inferred(StaticInt(1):2:StaticInt(10)) == 1:2:10
214214
@test @inferred(1:2:StaticInt(10)) == 1:2:10
215215
@test @inferred(1:StaticInt(2):10) == 1:2:10
216-
@test @inferred(StaticInt(1):2:10) == 1:2:10
217-
@test @inferred(StaticInt(1):UInt(10)) === StaticInt(1):10
216+
@test @inferred(StaticInt(1):2:10) == 1:2:10
217+
@test @inferred(StaticInt(1):UInt(10)) === StaticInt(1):10
218218
@test @inferred(UInt(1):StaticInt(1):StaticInt(10)) === 1:StaticInt(10)
219219
@test @inferred(ArrayInterface.OptionallyStaticUnitRange{Int,Int}(1:10)) == 1:10
220220
@test @inferred(ArrayInterface.OptionallyStaticUnitRange(1:10)) == 1:10
221221

222222
@inferred(ArrayInterface.OptionallyStaticUnitRange(1:10))
223223

224-
@test @inferred(ArrayInterface.OptionallyStaticStepRange(StaticInt(1), 1, UInt(10))) == StaticInt(1):1:10
225-
@test @inferred(ArrayInterface.OptionallyStaticStepRange(UInt(1), 1, StaticInt(10))) == StaticInt(1):1:10
224+
@test @inferred(ArrayInterface.OptionallyStaticStepRange(StaticInt(1), 1, UInt(10))) == StaticInt(1):1:10
225+
@test @inferred(ArrayInterface.OptionallyStaticStepRange(UInt(1), 1, StaticInt(10))) == StaticInt(1):1:10
226226
@test @inferred(ArrayInterface.OptionallyStaticStepRange(1:10)) == 1:1:10
227227

228228
@test_throws ArgumentError ArrayInterface.OptionallyStaticUnitRange(1:2:10)
@@ -331,7 +331,6 @@ using OffsetArrays
331331
@test @inferred(ArrayInterface.defines_strides(D1))
332332
@test !@inferred(ArrayInterface.defines_strides(view(A, :, [1,2],1)))
333333
@test @inferred(ArrayInterface.defines_strides(DenseWrapper{Int,2,Matrix{Int}}))
334-
335334
@test @inferred(device(A)) === ArrayInterface.CPUPointer()
336335
@test @inferred(device(B)) === ArrayInterface.CPUIndex()
337336
@test @inferred(device(-1:19)) === ArrayInterface.CPUIndex()
@@ -372,7 +371,7 @@ using OffsetArrays
372371
@test @inferred(contiguous_axis(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
373372
@test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :))) === nothing
374373
@test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :)')) === nothing
375-
374+
376375
@test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) == (true,false,false)
377376
@test @inferred(ArrayInterface.contiguous_axis_indicator(A)) == (true,false,false)
378377
@test @inferred(ArrayInterface.contiguous_axis_indicator(B)) == (true,false,false)
@@ -424,7 +423,7 @@ using OffsetArrays
424423
@test @inferred(stride_rank(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing
425424
@test @inferred(stride_rank(view(DummyZeros(3,4), 1, :))) === nothing
426425

427-
426+
428427
#=
429428
@btime ArrayInterface.is_column_major($(PermutedDimsArray(A,(3,1,2))))
430429
0.047 ns (0 allocations: 0 bytes)
@@ -494,11 +493,11 @@ using OffsetArrays
494493
@test @inferred(ArrayInterface.defines_strides(C1))
495494
@test @inferred(ArrayInterface.defines_strides(C2))
496495
@test @inferred(ArrayInterface.defines_strides(C3))
497-
496+
498497
@test @inferred(device(C1)) === ArrayInterface.CPUPointer()
499498
@test @inferred(device(C2)) === ArrayInterface.CPUPointer()
500499
@test @inferred(device(C3)) === ArrayInterface.CPUPointer()
501-
500+
502501
@test @inferred(contiguous_batch_size(C1)) === ArrayInterface.StaticInt(0)
503502
@test @inferred(contiguous_batch_size(C2)) === ArrayInterface.StaticInt(0)
504503
@test @inferred(contiguous_batch_size(C3)) === ArrayInterface.StaticInt(0)
@@ -510,7 +509,7 @@ using OffsetArrays
510509
@test @inferred(contiguous_axis(C1)) === StaticInt(1)
511510
@test @inferred(contiguous_axis(C2)) === StaticInt(0)
512511
@test @inferred(contiguous_axis(C3)) === StaticInt(2)
513-
512+
514513
@test @inferred(ArrayInterface.contiguous_axis_indicator(C1)) == (true,false,false,false)
515514
@test @inferred(ArrayInterface.contiguous_axis_indicator(C2)) == (false,false)
516515
@test @inferred(ArrayInterface.contiguous_axis_indicator(C3)) == (false,true)
@@ -675,7 +674,7 @@ end
675674
colormat = reinterpret(reshape, Float64, colors)
676675
@test @inferred(ArrayInterface.strides(colormat)) === (StaticInt(1), StaticInt(3))
677676
@test @inferred(ArrayInterface.dense_dims(colormat)) === (True(),True())
678-
@test @inferred(ArrayInterface.dense_dims(view(colormat,:,4))) === (True(),)
677+
@test @inferred(ArrayInterface.dense_dims(view(colormat,:,4))) === (True(),)
679678
@test @inferred(ArrayInterface.dense_dims(view(colormat,:,4:7))) === (True(),True())
680679
@test @inferred(ArrayInterface.dense_dims(view(colormat,2:3,:))) === (True(),False())
681680

@@ -702,7 +701,7 @@ end
702701
@test @inferred(ArrayInterface.strides(Ac2r)) === (StaticInt(1), StaticInt(2), 10)
703702
Ac2r_static = reinterpret(reshape, Float64, view(@MMatrix(rand(ComplexF64, 5, 7)), 2:4, 3:6));
704703
@test @inferred(ArrayInterface.strides(Ac2r_static)) === (StaticInt(1), StaticInt(2), StaticInt(10))
705-
704+
706705
Ac2t = reinterpret(reshape, Tuple{Float64,Float64}, view(rand(ComplexF64, 5, 7), 2:4, 3:6));
707706
@test @inferred(ArrayInterface.strides(Ac2t)) === (StaticInt(1), 5)
708707
Ac2t_static = reinterpret(reshape, Tuple{Float64,Float64}, view(@MMatrix(rand(ComplexF64, 5, 7)), 2:4, 3:6));
@@ -711,6 +710,26 @@ end
711710
end
712711
end
713712

713+
@testset "Reshaped views" begin
714+
# See
715+
# https://github.com/JuliaArrays/ArrayInterface.jl/issues/160
716+
# https://github.com/JuliaArrays/ArrayInterface.jl/issues/157
717+
u_base = randn(10, 10)
718+
u_view = view(u_base, 3, :)
719+
u_reshaped_view1 = reshape(u_view, 1, :)
720+
u_reshaped_view2 = reshape(u_view, 2, :)
721+
722+
@test @inferred(ArrayInterface.defines_strides(u_base))
723+
@test @inferred(ArrayInterface.defines_strides(u_view))
724+
@test @inferred(ArrayInterface.defines_strides(u_reshaped_view1))
725+
@test @inferred(ArrayInterface.defines_strides(u_reshaped_view2))
726+
727+
@test @inferred(ArrayInterface.strides(u_base)) == (StaticInt(1), 10)
728+
@test @inferred(ArrayInterface.strides(u_view)) == (10,)
729+
@test @inferred(ArrayInterface.strides(u_reshaped_view1)) == (10, 10)
730+
@test @inferred(ArrayInterface.strides(u_reshaped_view2)) == (10, 20)
731+
end
732+
714733
@test ArrayInterface.can_avx(ArrayInterface.can_avx) == false
715734

716735
@testset "can_change_size" begin
@@ -842,6 +861,6 @@ end
842861
@test @inferred(is_lazy_conjugate(d)) == false
843862
e = permutedims(d)
844863
@test @inferred(is_lazy_conjugate(e)) == false
845-
864+
846865
@test @inferred(is_lazy_conjugate([1,2,3]')) == false # We don't care about conj on `<:Real`
847866
end

0 commit comments

Comments
 (0)