Skip to content

Commit 35cd234

Browse files
committed
fix examples
1 parent e8099a1 commit 35cd234

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

docs/src/examples/custom-relu.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ function ChainRulesCore.rrule(::typeof(matrix_relu), y::Matrix{T}) where {T}
     function pullback_matrix_relu(dl_dx)
        ## some value from the backpropagation (e.g., loss) is denoted by `l`
        ## so `dl_dy` is the derivative of `l` wrt `y`
-       x = model[:x] # load decision variable `x` into scope
-       dl_dy = zeros(T, size(dl_dx))
-       dl_dq = zeros(T, size(dl_dx))
+       x = model[:x]::Matrix{JuMP.VariableRef} # load decision variable `x` into scope
+       dl_dy = zeros(T, size(x))
+       dl_dq = zeros(T, size(x))
        ## set sensitivities
        MOI.set.(model, DiffOpt.ReverseVariablePrimal(), x[:], dl_dx[:])
        ## compute grad

docs/src/examples/polyhedral_project.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,12 @@ function ChainRulesCore.rrule(
    model = direct_model(DiffOpt.diff_optimizer(Ipopt.Optimizer))
    xv = polytope(y; model = model)
    function pullback_matrix_projection(dl_dx)
-       layer_size, batch_size = size(dl_dx)
        dl_dx = ChainRulesCore.unthunk(dl_dx)
        ## `dl_dy` is the derivative of `l` wrt `y`
-       x = model[:x]
+       x = model[:x]::Matrix{JuMP.VariableRef}
+       layer_size, batch_size = size(x)
        ## grad wrt input parameters
-       dl_dy = zeros(size(dl_dx))
+       dl_dy = zeros(size(x))
        ## grad wrt layer parameters
        dl_dw = zero.(polytope.w)
        dl_db = zero(polytope.b)

0 commit comments

Comments (0)