Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve setall() inference #120

Merged
merged 2 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 54 additions & 32 deletions src/getsetall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,12 @@ getall(obj, optic::ComposedFunction) = _getall(obj, decompose(optic))
function setall(obj, optic::ComposedFunction, vs)
optics = decompose(optic)
N = length(optics)
vss = to_nested_shape(vs, Val(getall_lengths(obj, optics)), Val(N))
lengths = getall_lengths(obj, optics)

total_length = _val(nestedsum(lengths))
length(vs) == total_length || throw(DimensionMismatch("tried to assign $(length(vs)) elements to $total_length destinations"))

vss = to_nested_shape(vs, lengths, Val(N))
_setall(obj, optics, vss)
end

Expand Down Expand Up @@ -140,46 +145,63 @@ _staticlength(x::AbstractVector) = length(x)

getall_lengths(obj, optics::Tuple{Any}) = _staticlength(getall(obj, only(optics)))
for N in [2:10; :(<: Any)]
@eval function getall_lengths(obj, optics::NTuple{$N,Any})
# convert to Tuple: vectors cannot be put into Val
map(getall(obj, last(optics)) |> Tuple) do o
@eval getall_lengths(obj, optics::NTuple{$N,Any}) =
map(getall(obj, last(optics))) do o
getall_lengths(o, Base.front(optics))
end
end
end

_val(N::Int) = N
_val(::Val{N}) where {N} = N

nestedsum(ls::Union{Int,Val}) = _val(ls)
nestedsum(ls::Tuple) = sum(nestedsum, ls; init=0)

# to_nested_shape() definition uses both @eval and @generated
#
# @eval is needed because the code for different recursion depths should be different for inference,
# not the same method with different parameters.
#
# @generated is used to unpack target lengths from the second argument at compile time to make to_nested_shape() as cheap as possible.
#
# Note: to_nested_shape() only operates on plain Julia types and won't be affected by user lens definition, unlike setall for example.
# That's why it's safe to make it @generated.
to_nested_shape(vs, ::Val{LS}, ::Val{1}) where {LS} = (@assert length(vs) == _val(LS); vs)
_valadd(::Val{N}, ::Val{M}) where {N,M} = Val(N+M)
_valadd(n, m) = _val(n) + _val(m)

# nestedsum(): compute the sum of all values in a nested tuple/vector of int/val(int)
nestedsum(ls::Union{Int,Val}) = ls
nestedsum(ls::Tuple) = _valadd(nestedsum(first(ls)), nestedsum(Base.tail(ls)))
nestedsum(ls::Tuple{}) = Val(0)
nestedsum(ls::Vector) = sum(_val ∘ nestedsum, ls)

# splitelems() - split values provided to setall() into two parts: the first N elements, and the rest
# should always be type-stable
# if more collections should be supported, maybe add a fallback method that materializes to vectors; but is it actually needed?
splitelems(vs::NTuple{M,Any}, ::Val{N}) where {N,M} =
ntuple(j -> vs[j], Val(N)), ntuple(j -> vs[N+j], Val(M-N))
splitelems(vs::Tuple, N) =
map(i -> vs[i], 1:N), map(i -> vs[i], N+1:length(vs))
# staticarrays can be sliced into compile-time length slices for further efficiency, but this is still type-stable
splitelems(vs::AbstractVector, N) =
(@view vs[1:_val(N)]), (@view vs[_val(N)+1:end])

_sliceview(v::AbstractVector, i::AbstractVector) = view(v, i)
_sliceview(v::Tuple, i::AbstractVector) = collect(Iterators.map(i -> v[i], i)) # should be regular map(), but it exceed the recursion depth heuristic

# to_nested_shape(): convert a flat tuple/vector of values (as provided to setall) into a nested structure of tuples/vectors following the shape (ls)
# shape is always a (nested) tuple or vector with int or val(int) values, it is generated by getall_lengths()
# values can be any collection passed to setall, here we support tuples and abstractvectors
to_nested_shape(vs, LS, ::Val{1}) = (@assert length(vs) == _val(LS); vs)

for i in 2:10
@eval @generated function to_nested_shape(vs, ls::Val{LS}, ::Val{$i}) where {LS}
vi = 1
subs = map(LS) do lss
n = nestedsum(lss)
elems = map(vi:vi+n-1) do j
:( vs[$j] )
end
res = :( to_nested_shape(($(elems...),), $(Val(lss)), $(Val($(i - 1)))) )
vi += n
@eval to_nested_shape(vs, ls::Tuple{}, ::Val{$i}) = ()

@eval function to_nested_shape(vs, ls::Tuple, ::Val{$i})
lss = first(ls)
n = nestedsum(lss)
elems, elemstail = splitelems(vs, n)
reshead = to_nested_shape(elems, lss, $(Val(i - 1)))
restail = to_nested_shape(elemstail, Base.tail(ls), $(Val(i)))
return (reshead, restail...)
end

@eval function to_nested_shape(vs, ls::Vector, ::Val{$i})
vi = Ref(1)
map(ls) do lss
n = nestedsum(lss) |> _val
elems = _sliceview(vs, vi[]:vi[]+n-1)
res = to_nested_shape(elems, lss, $(Val(i - 1)))
vi[] += n
res
end
total_n = nestedsum(LS)
quote
length(vs) == $total_n || throw(DimensionMismatch("tried to assign $(length(vs)) elements to $($total_n) destinations"))
($(subs...),)
end
end
end
24 changes: 18 additions & 6 deletions test/test_getsetall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,20 +123,32 @@ end
@test setall(obj.c[1], Elements(), (5,)) === SVector(5)
@test setall(obj.c[1], Elements(), [5, 6]) === SVector(5, 6)
@test setall(obj.c[1], Elements(), [5]) === SVector(5)
@testset for o in (
@testset for (i,o) in (
(@optic _.c |> Elements() |> Elements()),
(@optic _.c |> Elements() |> Elements() |> _ + 1),
)
@test setall(obj, o, getall(obj, o)) === obj
) |> enumerate
@test (@inferred setall(obj, o, getall(obj, o))) === obj
@test setall(obj, o, collect(getall(obj, o))) === obj
if VERSION ≥ v"1.10" || i == 2
@test (@inferred setall(obj, o, Vector{Float64}(collect(getall(obj, o))))) == obj
@test (@inferred setall(obj, o, SVector(getall(obj, o)))) == obj
else
@test setall(obj, o, Vector{Float64}(collect(getall(obj, o)))) == obj
@test setall(obj, o, SVector(getall(obj, o))) == obj
end
end

obj = ([1, 2], 3:5, (6,))
@test obj == setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
@test ([2, 3], 4:6, (7,)) == setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)
# can this infer?..
@test_broken obj == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
@test_broken ([2, 3], 4:6, (7,)) == @inferred setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)

@test obj == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
@test ([2, 3], 4:6, (7,)) == @inferred setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)
@test obj == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), ntuple(identity, 6))
@test obj == @inferred setall(obj, @optic(_ |> identity |> Elements() |> Elements()), ntuple(identity, 6))
@test obj[1] == @inferred setall(obj[1], @optic(_ |> Elements() |> _ + 1), (2, 3))
# impossible to infer:
@test_broken ([1, 2], [3.0, 4.0, 5.0], ("6",)) == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), (1, 2, 3., 4., 5., "6"))
end

@testset "getall/setall consistency" begin
Expand Down