From 9e759cfd65d77ef66983fda78f905bc1c43661b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Riedemann?= <38795484+longemen3000@users.noreply.github.com> Date: Tue, 13 May 2025 22:55:48 -0400 Subject: [PATCH 1/4] improve performance of hessians with static arrays --- ext/ForwardDiffStaticArraysExt.jl | 10 ++++++++++ test/HessianTest.jl | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/ext/ForwardDiffStaticArraysExt.jl b/ext/ForwardDiffStaticArraysExt.jl index 63f841db..8226aea3 100644 --- a/ext/ForwardDiffStaticArraysExt.jl +++ b/ext/ForwardDiffStaticArraysExt.jl @@ -81,6 +81,16 @@ end end end +@generated function extract_jacobian(::Type{T}, ydual::Partials{M}, x::S) where {M, T, S<:StaticArray} + N = length(x) + result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...) + return quote + $(Expr(:meta, :inline)) + V = StaticArrays.similar_type(S, valtype(eltype($ydual)), Size($M, $N)) + return V($result) + end +end + @inline function ForwardDiff.vector_mode_jacobian(f::F, x::StaticArray) where {F} T = typeof(Tag(f, eltype(x))) return extract_jacobian(T, static_dual_eval(T, f, x), x) diff --git a/test/HessianTest.jl b/test/HessianTest.jl index 4c667e5e..b826c2ff 100644 --- a/test/HessianTest.jl +++ b/test/HessianTest.jl @@ -163,4 +163,12 @@ end @test ForwardDiff.hessian(x->dot(x,H,x), zeros(3)) ≈ [2 6 10; 6 10 14; 10 14 18] end +@testset "allocation-free hessian with StaticArrays" begin + #https://github.com/JuliaDiff/ForwardDiff.jl/issues/720 + g = r -> (r[1]^2 - 3) * (r[2]^2 - 2) + x = SA_F32[0.5, 2.7] + hres = DiffResults.HessianResult(x) + @test @allocated(ForwardDiff.hessian!(hres, g, x))) == 0 +end + end # module From c715da36834fa992a621fd0fa774e65b48a55959 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Riedemann?= <38795484+longemen3000@users.noreply.github.com> Date: Tue, 13 May 2025 23:02:53 -0400 Subject: [PATCH 2/4] ForwardDiffStaticArraysExt.jl: import Partials --- ext/ForwardDiffStaticArraysExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ForwardDiffStaticArraysExt.jl b/ext/ForwardDiffStaticArraysExt.jl index 8226aea3..fe65eb33 100644 --- a/ext/ForwardDiffStaticArraysExt.jl +++ b/ext/ForwardDiffStaticArraysExt.jl @@ -3,7 +3,7 @@ module ForwardDiffStaticArraysExt using ForwardDiff, StaticArrays using ForwardDiff.LinearAlgebra using ForwardDiff.DiffResults -using ForwardDiff: Dual, partials, GradientConfig, JacobianConfig, HessianConfig, Tag, Chunk, +using ForwardDiff: Dual, partials, Partials, GradientConfig, JacobianConfig, HessianConfig, Tag, Chunk, gradient, hessian, jacobian, gradient!, hessian!, jacobian!, extract_gradient!, extract_jacobian!, extract_value!, vector_mode_gradient, vector_mode_gradient!, From dd44ec49391bf638c19ee6a4e5ff3c5cb6c6fa9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Riedemann?= <38795484+longemen3000@users.noreply.github.com> Date: Tue, 13 May 2025 23:11:11 -0400 Subject: [PATCH 3/4] Update HessianTest.jl --- test/HessianTest.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/HessianTest.jl b/test/HessianTest.jl index b826c2ff..051377fe 100644 --- a/test/HessianTest.jl +++ b/test/HessianTest.jl @@ -168,7 +168,7 @@ end g = r -> (r[1]^2 - 3) * (r[2]^2 - 2) x = SA_F32[0.5, 2.7] hres = DiffResults.HessianResult(x) - @test @allocated(ForwardDiff.hessian!(hres, g, x))) == 0 + @test @allocated(ForwardDiff.hessian!(hres, g, x)) == 0 end end # module From ab86fe861561caad3f288c8eeebe8d78a81021d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Riedemann?= <38795484+longemen3000@users.noreply.github.com> Date: Tue, 13 May 2025 23:21:11 -0400 Subject: [PATCH 4/4] Update HessianTest.jl --- test/HessianTest.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/HessianTest.jl b/test/HessianTest.jl index 051377fe..9f12f91c 100644 --- a/test/HessianTest.jl +++ b/test/HessianTest.jl @@ -168,6 +168,7 @@ end g = r -> (r[1]^2 - 3) * (r[2]^2 - 2) x = SA_F32[0.5, 2.7] hres = DiffResults.HessianResult(x) + ForwardDiff.hessian!(hres, g, x) @test @allocated(ForwardDiff.hessian!(hres, g, x)) == 0 end