Skip to content

Commit 6a8efa6

Browse files
authored
Merge pull request JuliaArrays#156 from ranocha/hr/fix_contiguous_batch_size
fix `contiguous_batch_size` for reshaped views
2 parents 9789fa4 + ada0cf2 commit 6a8efa6

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

src/stridelayout.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ function _contiguous_batch_size(::StaticInt{D}, ::R) where {D,R<:Tuple}
269269
return nothing
270270
end
271271
end
272+
_contiguous_batch_size(::StaticInt{-1}, ::R) where {R<:Tuple} = -One()
272273

273274
contiguous_batch_size(::Type{Array{T,N}}) where {T,N} = Zero()
274275
contiguous_batch_size(::Type{BitArray{N}}) where {N} = Zero()

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,12 @@ using OffsetArrays
398398
@test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) === ArrayInterface.StaticInt(-1)
399399
@test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.StaticInt(-1)
400400
@test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.StaticInt(0)
401+
let u_base = randn(10, 10)
402+
u_view = view(u_base, 3, :)
403+
u_reshaped_view = reshape(u_view, 1, size(u_base, 2))
404+
@test @inferred(contiguous_batch_size(u_view)) === ArrayInterface.StaticInt(-1)
405+
@test @inferred(contiguous_batch_size(u_reshaped_view)) === ArrayInterface.StaticInt(-1)
406+
end
401407

402408
@test @inferred(stride_rank(@SArray(zeros(2,2,2)))) == (1, 2, 3)
403409
@test @inferred(stride_rank(A)) == (1,2,3)

0 commit comments

Comments
 (0)