Skip to content

Commit 299ce1e

Browse files
committed
add static size information for hvcat and hvncat
This significantly improves the performance of creating multi-dimensional arrays from scalar numbers with the literal syntaxes [a b; c d] and [a b;; c d] and their typed forms T[a b; c d] and T[a b;; c d]. For small numeric arrays (length <= 16), the element assignments are manually unrolled to further minimize the overhead, so construction now has zero overhead and is as fast as the array initialization method.
1 parent d0a521f commit 299ce1e

File tree

3 files changed

+314
-6
lines changed

3 files changed

+314
-6
lines changed

base/abstractarray.jl

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2255,6 +2255,77 @@ end
22552255

22562256
typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, as...) where T = typed_hvncat(T, rows_to_dimshape(rows), true, as...)
22572257

2258+
# A fast version of hvcat for the case where we have static size information of xs
2259+
# and the number of rows is known at compile time -- we can eliminate all the runtime
2260+
# size checks. For cases where static size information is not beneficial, we fall back to
2261+
# the general hvcat/typed_hvcat methods.
2262+
# Build a `Matrix{T}` from the scalar numbers `xs`, given in row-major order, where
# `rows` (the number of columns in each row) is a compile-time constant.
#
# `rows` and the argument count are both known while the method body is generated, so
# the shape is validated in the generator itself: an invalid shape compiles to a single
# `throw` and a valid one to an allocation followed by the element stores.
#
# Throws `DimensionMismatch` when the rows do not all have the same number of columns,
# and `ArgumentError` when the number of arguments does not match the shape.
@generated function typed_hvcat_static(::Type{T}, ::Val{rows}, xs::Number...) where {T<:Number, rows}
    nr = length(rows)
    nc = rows[1]
    for i = 2:nr
        if nc != rows[i]
            # Every part of the message is a generation-time constant, so the emitted
            # code only has to throw.
            msg = "row " * string(i) * " has mismatched number of columns (expected " *
                  string(nc) * ", got " * string(rows[i]) * ")"
            return :(throw(DimensionMismatch($msg)))
        end
    end

    len = length(xs)
    if nr * nc != len
        msg = "argument count " * string(len) * " does not match specified shape " *
              string((nr, nc))
        return :(throw(ArgumentError($msg)))
    end

    if len <= 16
        # For small array construction, manually unroll the loop for better performance.
        stores = Expr[]
        k = 1
        for i in 1:nr
            for j in 1:nc
                push!(stores, :(a[$i, $j] = xs[$k]))
                k += 1
            end
        end

        return quote
            a = Matrix{$T}(undef, $nr, $nc)
            # Every index is a literal within `1:$nr`, `1:$nc` and `1:length(xs)` (checked
            # above), so the bounds checks can be dropped just like in the loop below.
            @inbounds begin
                $(stores...)
            end
            return a
        end
    end

    # Larger arrays: a regular row-major fill loop keeps the generated code small.
    return quote
        a = Matrix{$T}(undef, $nr, $nc)
        k = 1
        @inbounds for i in 1:$nr
            for j in 1:$nc
                a[i, j] = xs[k]
                k += 1
            end
        end
        return a
    end
end
2313+
# Scalars that all share the type `T`: use `T` as the element type directly.
@inline hvcat_static(v::Val{rows}, x::T, xs::Vararg{T}) where {rows, T<:Number} =
    typed_hvcat_static(T, v, x, xs...)

# Scalars of mixed types: the element type is their promoted type.
@inline hvcat_static(v::Val{rows}, xs::Number...) where {rows} =
    typed_hvcat_static(promote_typeof(xs...), v, xs...)

# Any other element type or argument kind: unwrap `rows` from the `Val` and hand the
# call to the general `typed_hvcat`.
@inline typed_hvcat_static(::Type{T}, ::Val{rows}, xs...) where {T, rows} =
    typed_hvcat(T, rows, xs...)

# Any other argument kind: unwrap `rows` from the `Val` and hand the call to the
# general `hvcat`.
@inline hvcat_static(::Val{rows}, xs...) where {rows} = hvcat(rows, xs...)
2327+
2328+
22582329
## N-dimensional concatenation ##
22592330

22602331
"""
@@ -2750,6 +2821,94 @@ end
27502821
Ai
27512822
end
27522823

2824+
# Static version of hvncat for better performance with scalar numbers
2825+
# See the comments for hvcat_static for more details.
2826+
# Build an `Array{T, N}` of size `dims` from the scalar numbers `xs`, where both `dims`
# and `row_first` are compile-time constants. With `row_first` the elements of each
# 2-D slice are given row by row; otherwise `xs` is in column-major (memory) order.
#
# The shape is validated while the method body is generated, so an invalid shape
# compiles to a single `throw`. Throws `ArgumentError` when `dims` contains a
# non-positive entry or when the number of arguments does not equal `prod(dims)`.
@generated function typed_hvncat_static(
    ::Type{T}, ::Val{dims}, ::Val{row_first}, xs::Number...
) where {T<:Number, dims, row_first}
    for d in dims
        if d <= 0
            return :(throw(ArgumentError("`dims` argument must contain positive integers")))
        end
    end

    N = length(dims)
    lengtha = prod(dims)
    lengthx = length(xs)
    if lengtha != lengthx
        # Both counts are generation-time constants, so the message is built here.
        msg = "argument count does not match specified shape (expected " *
              string(lengtha) * ", got " * string(lengthx) * ")"
        return :(throw(ArgumentError($msg)))
    end

    if N < 2
        # The row-first fill below needs both a row and a column extent. Shapes with
        # fewer than two dimensions are handed to the generic method instead of
        # failing on `dims[2]` inside this generator.
        return :(typed_hvncat($T, $dims, $row_first, xs...))
    end

    if lengthx <= 16
        # For small array construction, manually unroll the loop.
        stores = Expr[]
        if row_first
            nr, nc = dims[1], dims[2]
            nrc = nr * nc
            k = 1
            # `div(lengtha, nrc)` is the number of 2-D slices (product of dims[3:end]).
            for slice in 1:div(lengtha, nrc)
                offset = nrc * (slice - 1)
                for i in 1:nr
                    Ai = offset + i      # linear index of element (i, 1) in this slice
                    for j in 1:nc
                        push!(stores, :(A[$Ai] = xs[$k]))
                        k += 1
                        Ai += nr         # step one column to the right
                    end
                end
            end
        else
            # Column-major input already matches the memory layout.
            for i in 1:lengtha
                push!(stores, :(A[$i] = xs[$i]))
            end
        end

        return quote
            A = Array{$T, $N}(undef, $(dims...))
            # All indices are literals in `1:length(A)` and `1:length(xs)`; the two
            # lengths were checked to be equal above.
            @inbounds begin
                $(stores...)
            end
            return A
        end
    end

    # For larger arrays, use the regular loop
    return quote
        A = Array{$T, $N}(undef, $(dims...))
        hvncat_fill!(A, $row_first, xs)
        return A
    end
end
2897+
# Scalars that all share the type `T`: use `T` as the element type directly.
@inline hvncat_static(
    d::Val{dims}, rf::Val{row_first}, x::T, xs::Vararg{T}
) where {dims, row_first, T<:Number} = typed_hvncat_static(T, d, rf, x, xs...)

# Scalars of mixed types: the element type is their promoted type.
@inline hvncat_static(
    d::Val{dims}, rf::Val{row_first}, xs::Number...
) where {dims, row_first} = typed_hvncat_static(promote_typeof(xs...), d, rf, xs...)

# Any other element type or argument kind: unwrap the `Val`s and hand the call to the
# general `typed_hvncat`.
@inline typed_hvncat_static(
    ::Type{T}, ::Val{dims}, ::Val{row_first}, xs...
) where {T, dims, row_first} = typed_hvncat(T, dims, row_first, xs...)

# Any other argument kind: unwrap the `Val`s and hand the call to the general `hvncat`.
@inline hvncat_static(::Val{dims}, ::Val{row_first}, xs...) where {dims, row_first} =
    hvncat(dims, row_first, xs...)
2911+
27532912
"""
27542913
stack(iter; [dims])
27552914

src/julia-syntax.scm

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2254,7 +2254,7 @@
22542254

22552255
(define (expand-vcat e
22562256
(vcat '((top vcat)))
2257-
(hvcat '((top hvcat)))
2257+
(hvcat '((top hvcat_static)))
22582258
(hvcat_rows '((top hvcat_rows))))
22592259
(let ((a (cdr e)))
22602260
(if (any assignment? a)
@@ -2276,11 +2276,13 @@
22762276
(if (any (lambda (row) (any vararg? row)) rows)
22772277
`(call ,@hvcat_rows ,@(map (lambda (x) `(tuple ,@x)) rows))
22782278
`(call ,@hvcat
2279-
(tuple ,@(map length rows))
2279+
(new (curly (top Val) (call (core tuple) ,@(map length rows))))
22802280
,@(apply append rows))))
22812281
`(call ,@vcat ,@a))))))
22822282

2283-
(define (expand-ncat e (hvncat '((top hvncat))))
2283+
(define (expand-ncat e
2284+
(hvncat '((top hvncat)))
2285+
(hvncat_static '((top hvncat_static))))
22842286
(define (is-row a) (and (pair? a)
22852287
(or (eq? (car a) 'row)
22862288
(eq? (car a) 'nrow))))
@@ -2384,7 +2386,7 @@
23842386
(let ((shape (get-shape a is-row-first d)))
23852387
(if (is-balanced shape)
23862388
(let ((dims `(tuple ,@(reverse (get-dims a is-row-first d)))))
2387-
`(call ,@hvncat ,dims ,(tf is-row-first) ,@aflat))
2389+
`(call ,@hvncat_static (new (curly (top Val) ,dims)) (new (curly (top Val) ,(tf is-row-first))) ,@aflat))
23882390
`(call ,@hvncat ,(tuplize shape) ,(tf is-row-first) ,@aflat))))))))
23892391

23902392
(define (maybe-ssavalue lhss x in-lhs?)
@@ -2899,13 +2901,13 @@
28992901
(lambda (e)
29002902
(let ((t (cadr e))
29012903
(e (cdr e)))
2902-
(expand-vcat e `((top typed_vcat) ,t) `((top typed_hvcat) ,t) `((top typed_hvcat_rows) ,t))))
2904+
(expand-vcat e `((top typed_vcat) ,t) `((top typed_hvcat_static) ,t) `((top typed_hvcat_rows) ,t))))
29032905

29042906
'typed_ncat
29052907
(lambda (e)
29062908
(let ((t (cadr e))
29072909
(e (cdr e)))
2908-
(expand-ncat e `((top typed_hvncat) ,t))))
2910+
(expand-ncat e `((top typed_hvncat) ,t) `((top typed_hvncat_static) ,t))))
29092911

29102912
'|'| (lambda (e) (expand-forms `(call |'| ,(cadr e))))
29112913

test/abstractarray.jl

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,6 +1745,153 @@ using Base: typed_hvncat
17451745
@test [["A";"B"];;"C";"D"] == ["A" "C"; "B" "D"]
17461746
end
17471747

1748+
@testset "array construction using numbers" begin
1749+
# test array construction using hvcat [x x; y y] and hvncat [x x;;; y y]
1750+
1751+
@testset "hvcat array construction" begin
1752+
function test_hvcat(x1, x2, x3, x4)
1753+
# small arrays are constructed differently (manually unrolled)
1754+
A = [x1 x2; x3 x4]
1755+
@test A[1,1] == x1
1756+
@test A[1,2] == x2
1757+
@test A[2,1] == x3
1758+
@test A[2,2] == x4
1759+
1760+
AT = Float64[x1 x2; x3 x4]
1761+
@test AT == A
1762+
1763+
# large arrays are constructed using a loop
1764+
A = [x1 x2 x3 x4; x2 x3 x4 x1; x3 x4 x1 x2; x4 x1 x2 x3; x1 x2 x3 x4]
1765+
@test A[1,1] == x1
1766+
@test A[1,2] == x2
1767+
@test A[1,3] == x3
1768+
@test A[1,4] == x4
1769+
@test A[2,1] == x2
1770+
@test A[2,2] == x3
1771+
@test A[2,3] == x4
1772+
@test A[2,4] == x1
1773+
@test A[3,1] == x3
1774+
@test A[3,2] == x4
1775+
@test A[3,3] == x1
1776+
@test A[3,4] == x2
1777+
@test A[4,1] == x4
1778+
@test A[4,2] == x1
1779+
@test A[4,3] == x2
1780+
@test A[4,4] == x3
1781+
@test A[5,1] == x1
1782+
@test A[5,2] == x2
1783+
@test A[5,3] == x3
1784+
@test A[5,4] == x4
1785+
AT = Float64[x1 x2 x3 x4; x2 x3 x4 x1; x3 x4 x1 x2; x4 x1 x2 x3; x1 x2 x3 x4]
1786+
@test AT == A
1787+
end
1788+
1789+
test_hvcat(1, 2, 3, 4)
1790+
test_hvcat(1.0, 2.0, 3.0, 4.0)
1791+
test_hvcat(1.0, 2, 3.0, 4)
1792+
end
1793+
1794+
@testset "hvncat array construction" begin
1795+
# Test hvncat through the literal `;;;` syntax (balanced, row-first shapes)
1796+
# Testing 3D outputs, same and mixed element types, and small and large sizes
1797+
1798+
function test_hvncat(x1, x2, x3, x4)
1799+
# 3D arrays with dims as Int (row-first by default)
1800+
A = [x1 x2;;; x3 x4] # 1x2x2 Array (row-first)
1801+
@test size(A) == (1, 2, 2)
1802+
@test A[1,1,1] == x1
1803+
@test A[1,1,2] == x3
1804+
@test A[1,2,1] == x2
1805+
@test A[1,2,2] == x4
1806+
1807+
AT = Float64[x1 x2;;; x3 x4]
1808+
@test AT == A
1809+
1810+
A = [x1 x2; x3 x4;;; x2 x3; x4 x1;;; x3 x4; x1 x2;;; x4 x1; x2 x3;;; x1 x2; x3 x4]
1811+
@test size(A) == (2, 2, 5)
1812+
1813+
@test A[:, :, 1] == [x1 x2; x3 x4]
1814+
@test A[:, :, 2] == [x2 x3; x4 x1]
1815+
@test A[:, :, 3] == [x3 x4; x1 x2]
1816+
@test A[:, :, 4] == [x4 x1; x2 x3]
1817+
@test A[:, :, 5] == [x1 x2; x3 x4]
1818+
1819+
AT = Float64[x1 x2; x3 x4;;; x2 x3; x4 x1;;; x3 x4; x1 x2;;; x4 x1; x2 x3;;; x1 x2; x3 x4]
1820+
@test AT == A
1821+
end
1822+
1823+
test_hvncat(1, 2, 3, 4)
1824+
test_hvncat(1.0, 2.0, 3.0, 4.0)
1825+
test_hvncat(1.0, 2, 3.0, 4)
1826+
end
1827+
1828+
@testset "hvcat vs hvcat_static" begin
1829+
# number cases generate the same result
1830+
@test Base.hvcat_static(Val{(2,2)}(), 1, 2, 3, 4) == hvcat((2, 2), 1, 2, 3, 4)
1831+
@test Base.hvcat_static(Val{(2,2)}(), 1, 2, 3.0, 4.0) == hvcat((2, 2), 1, 2, 3.0, 4.0)
1832+
@test Base.typed_hvcat_static(Float64, Val{(2,2)}(), 1, 2, 3, 4) == Base.typed_hvcat(Float64, (2, 2), 1, 2, 3, 4)
1833+
@test Base.typed_hvcat_static(Float64, Val{(2,2)}(), 1, 2, 3.0, 4.0) == Base.typed_hvcat(Float64, (2, 2), 1, 2, 3.0, 4.0)
1834+
1835+
# non-number cases fall back to hvcat
1836+
@test Base.hvcat_static(Val{(2,2)}(), "a", "b", "c", "d") == hvcat((2, 2), "a", "b", "c", "d")
1837+
@test Base.typed_hvcat_static(String, Val{(2,2)}(), "a", "b", "c", "d") == Base.typed_hvcat(String, (2, 2), "a", "b", "c", "d")
1838+
1839+
# non-scalar cases fall back to hvcat
1840+
@test Base.hvcat_static(Val{(2,2)}(), [1 2], [2 2], [3 3], [4 4]) == hvcat((2, 2), [1 2], [2 2], [3 3], [4 4])
1841+
@test Base.typed_hvcat_static(Float64, Val{(2,2)}(), [1 2], [2 2], [3 3], [4 4]) == Base.typed_hvcat(Float64, (2, 2), [1 2], [2 2], [3 3], [4 4])
1842+
1843+
@test_throws DimensionMismatch hvcat((2,4), 2, 3, 4, 5)
1844+
@test_throws DimensionMismatch Base.hvcat_static(Val{(2,4)}(), 2, 3, 4, 5)
1845+
end
1846+
1847+
@testset "hvncat vs hvncat_static" begin
1848+
# basic test
1849+
x = rand(8)
1850+
1851+
A = Base.hvncat_static(Val{(2, 2, 2)}(), Val{true}(), x...)
1852+
B = hvncat((2, 2, 2), true, x...)
1853+
@test A == B
1854+
1855+
A = Base.hvncat_static(Val{(2, 2, 2)}(), Val{false}(), x...)
1856+
B = hvncat((2, 2, 2), false, x...)
1857+
@test A == B
1858+
1859+
# test different eltypes
1860+
x, y = rand(4), rand(1:10, 4)
1861+
A = Base.hvncat_static(Val{(2, 2, 2)}(), Val{true}(), x..., y...)
1862+
B = hvncat((2, 2, 2), true, x..., y...)
1863+
@test A == B
1864+
1865+
A = Base.hvncat_static(Val{(2, 2, 2)}(), Val{false}(), x..., y...)
1866+
B = hvncat((2, 2, 2), false, x..., y...)
1867+
@test A == B
1868+
1869+
# test large array
1870+
x = rand(24)
1871+
A = Base.hvncat_static(Val{(2, 3, 4)}(), Val{true}(), x...)
1872+
B = hvncat((2, 3, 4), true, x...)
1873+
@test A == B
1874+
1875+
A = Base.hvncat_static(Val{(2, 3, 4)}(), Val{false}(), x...)
1876+
B = hvncat((2, 3, 4), false, x...)
1877+
@test A == B
1878+
end
1879+
1880+
@testset "Static method doesn't apply to non-number types" begin
1881+
# StaticArrays "abused" the syntax SA[1 2; 3 4] to create a static array
1882+
# but unfortunately SA isn't a number-like type.
1883+
# We need to ensure that our typed_hvcat_static specialization doesn't affect its usage
1884+
# and returns the same result as the non-static version.
1885+
1886+
struct FOO end
1887+
Base.typed_hvcat(::Type{FOO}, dims::Dims, xs::Number...) = "Foo typed_hvcat"
1888+
Base.typed_hvncat(::Type{FOO}, dims::Dims, row_first::Bool, xs::Number...) = "Foo typed_hvncat"
1889+
1890+
@test FOO[1 2; 3 4] == "Foo typed_hvcat"
1891+
@test FOO[1 2;;; 3 4] == "Foo typed_hvncat"
1892+
end
1893+
end
1894+
17481895
@testset "stack" begin
17491896
# Basics
17501897
for args in ([[1, 2]], [1:2, 3:4], [[1 2; 3 4], [5 6; 7 8]],

0 commit comments

Comments
 (0)