Skip to content

Commit eeb4c86

Browse files
authored
feat: use Mooncake's copy utilities (#809)
* feat: use Mooncake's copy utilities * Fix compat * Adapt to latest Mooncake
1 parent 0124a0e commit eeb4c86

File tree

5 files changed

+17
-12
lines changed

5 files changed

+17
-12
lines changed

DifferentiationInterface/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
## [0.7.1]
1111

12+
### Feat
13+
14+
- Use Mooncake's internal copy utilities ([#809])
15+
1216
### Fixed
1317

1418
- Take `absstep` into account for FiniteDiff ([#812])
@@ -42,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4246

4347
[#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812
4448
[#810]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/810
49+
[#809]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/809
4550
[#799]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/799
4651
[#795]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/795
4752
[#790]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/790

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ JET = "0.9"
7474
JLArrays = "0.2.0"
7575
JuliaFormatter = "1,2"
7676
LinearAlgebra = "1"
77-
Mooncake = "0.4.88"
77+
Mooncake = "0.4.122"
7878
Pkg = "1"
7979
PolyesterForwardDiff = "0.1.2"
8080
Random = "1"

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,12 @@ using Mooncake:
1010
tangent_type,
1111
value_and_gradient!!,
1212
value_and_pullback!!,
13-
zero_tangent
13+
zero_tangent,
14+
_copy_output,
15+
_copy_to_output!!
1416

1517
DI.check_available(::AutoMooncake) = true
1618

17-
copyto!!(dst::Number, src::Number) = convert(typeof(dst), src)
18-
copyto!!(dst, src) = DI.ismutable_array(dst) ? copyto!(dst, src) : convert(typeof(dst), src)
19-
2019
get_config(::AutoMooncake{Nothing}) = Config()
2120
get_config(backend::AutoMooncake{<:Config}) = backend.config
2221

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ function DI.value_and_pullback(
3030
) where {F,Y,C}
3131
DI.check_prep(f, prep, backend, x, ty, contexts...)
3232
dy = only(ty)
33-
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
33+
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
3434
new_y, (_, new_dx) = value_and_pullback!!(
3535
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
3636
)
37-
return new_y, (mycopy(new_dx),)
37+
return new_y, (_copy_output(new_dx),)
3838
end
3939

4040
function DI.value_and_pullback(
@@ -47,11 +47,12 @@ function DI.value_and_pullback(
4747
) where {F,Y,C}
4848
DI.check_prep(f, prep, backend, x, ty, contexts...)
4949
ys_and_tx = map(ty) do dy
50-
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
50+
dy_righttype =
51+
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
5152
y, (_, new_dx) = value_and_pullback!!(
5253
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
5354
)
54-
y, mycopy(new_dx)
55+
y, _copy_output(new_dx)
5556
end
5657
y = first(ys_and_tx[1])
5758
tx = last.(ys_and_tx)
@@ -126,7 +127,7 @@ function DI.value_and_gradient(
126127
) where {F,C}
127128
DI.check_prep(f, prep, backend, x, contexts...)
128129
y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...)
129-
return y, mycopy(new_grad)
130+
return y, _copy_output(new_grad)
130131
end
131132

132133
function DI.value_and_gradient!(

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function DI.value_and_pullback(
5858
map(DI.unwrap, contexts)...,
5959
)
6060
copyto!(y, y_after)
61-
return y, (mycopy(dx),)
61+
return y, (_copy_output(dx),)
6262
end
6363

6464
function DI.value_and_pullback(
@@ -83,7 +83,7 @@ function DI.value_and_pullback(
8383
map(DI.unwrap, contexts)...,
8484
)
8585
copyto!(y, y_after)
86-
mycopy(dx)
86+
_copy_output(dx)
8787
end
8888
return y, tx
8989
end

0 commit comments

Comments
 (0)