Dear ReverseDiff,

Forgive me if this is a known issue, or if it is desired behaviour! And thank you for the great work you do.
I get a `MethodError` on the `increment_deriv!()` call when running the following model with ReverseDiff (but not with the other Turing-supported backends). The example is not as minimal as it could be, which I hope you will forgive; if you need a simpler one, let me know and I can try to construct it.
The following is an implementation of the PVL-delta model for the Iowa Gambling Task. All that matters here is that an `expected_value` vector is updated dynamically from incoming `inputs` and in turn controls the probabilities of selecting one of four `actions`. The model has four parameters whose interpretation is not important for this issue.
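To make the update rule easier to follow without Turing or any AD backend, here is a plain-Julia restatement of a single trial of `pvl_delta` from the model below. It is a sketch, not part of the original report: the names `pvl_step`, `softmax_stable` and `run_trials`, and the parameter values, are hypothetical, and `softmax_stable` stands in for `LogExpFunctions.softmax`.

```julia
# Plain-Julia sketch of one PVL-delta trial (no Turing, no AD).
# `pvl_step`, `softmax_stable`, `run_trials` and the parameter values are hypothetical.

# Numerically stable softmax (stand-in for LogExpFunctions.softmax)
function softmax_stable(x::AbstractVector)
    e = exp.(x .- maximum(x))
    return e ./ sum(e)
end

"""
    pvl_step(ev, deck, reward; α, A, β, ω)

Return the action probabilities computed from the current expected values `ev`,
and the expected values updated after observing `reward` on `deck`.
"""
function pvl_step(ev::AbstractVector, deck::Integer, reward::Real; α, A, β, ω)
    # Choice probabilities come from the expected values *before* the update
    probs = softmax_stable(ev .* β)
    # Gains are compressed by A; losses are compressed by A and scaled by ω
    utility = reward >= 0 ? reward^A : -ω * abs(reward)^A
    prediction_error = utility - ev[deck]
    # Only the chosen deck's expected value moves, by α times the prediction error
    new_ev = [ev[d] + α * prediction_error * (d == deck) for d in eachindex(ev)]
    return probs, new_ev
end

# Run a sequence of (deck, reward) trials inside a function to avoid global-scope issues
function run_trials(trials; params...)
    ev = zeros(4)
    for (deck, reward) in trials
        probs, ev = pvl_step(ev, deck, reward; params...)
        println("P(actions) = ", round.(probs; digits = 3), "  EV = ", round.(ev; digits = 3))
    end
    return ev
end

run_trials([(3, 50.0), (3, 60.0), (3, -10.0)]; α = 0.5, A = 0.5, β = 1.0, ω = 2.0)
```

As in the original `pvl_delta`, the probabilities for a trial are computed before the expected values are updated with that trial's reward.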
If `expected_value` is a `Vector{Real}`, the error occurs; if it is changed to a `Vector{Float64}`, everything works.
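The two constructors differ only in element type. The quick check below (plain Julia, not from the original report) shows that `zeros(Real, 4)` produces a vector with an abstract element type, whose zeros are actually stored as `Int` values and which can hold a mix of `Real` subtypes, whereas `zeros(Float64, 4)` has a concrete element type. Whether the abstract element type is what triggers the ReverseDiff error is not established here.

```julia
# Compare the element types produced by the two constructors used in the report
v_real = zeros(Real, 4)      # Vector{Real}: abstract element type
v_float = zeros(Float64, 4)  # Vector{Float64}: concrete element type

println(typeof(v_real), "  concrete eltype? ", isconcretetype(eltype(v_real)))
println(typeof(v_float), "  concrete eltype? ", isconcretetype(eltype(v_float)))

# zero(Real) is convert(Real, 0), which leaves 0 as an Int,
# so the "zeros" inside a Vector{Real} are Int values
println(map(typeof, v_real))

# A Vector{Real} can hold values of different Real subtypes at once
mixed = zeros(Real, 3)
mixed[2] = 1.5      # stored as a Float64
mixed[3] = 1 // 2   # stored as a Rational{Int}
println(map(typeof, mixed))

# Broadcasting identity rebuilds the vector with the narrowest common element type
narrowed = identity.(v_real)
println(typeof(narrowed))
```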
I am not sure what causes the error. As you can see, there is a fairly simple workaround, but I thought I would make you aware of the error in any case.
Let me know if there is anything else I can do!
```julia
using Turing, LogExpFunctions
import ReverseDiff

inputs = [
    (3, 50.0),
    (3, 60.0),
    (3, -10.0),
    (4, 50.0),
    (3, 55.0),
    (1, 100.0),
]
actions = [3, 3, 3, 4, 3, 1]

mutable struct PVLState
    expected_value::Vector
end

@model function pvl_model(inputs, actions)
    #Sample parameters
    α ~ LogitNormal(0, 1) #learning rate
    A ~ LogitNormal(0, 1) #reward sensitivity
    β ~ LogNormal(0, 1) #inverse temperature
    ω ~ LogNormal(0, 1) #loss aversion

    parameters = (α, A, β, ω)

    expected_value = zeros(Real, 4) #If zeros(Float64, 4) is used instead, it works
    state = PVLState(expected_value)

    for (input, action) in zip(inputs, actions)
        i ~ to_submodel(single_trial(parameters, state, input, action))
    end
end

@model function single_trial(parameters, state, input, action)
    action_probability = pvl_delta(parameters, state, input)
    action ~ action_probability
end

# PVL-Delta
function pvl_delta(parameters, state, input)
    #Split input
    deck, reward = input

    #Split parameters
    α, A, β, ω = parameters

    #Get expected value
    expected_value = state.expected_value

    #Transform expected values to parameters
    action_probabilities = softmax(expected_value * β)

    #Update expected values for next trial
    if reward >= 0
        prediction_error = (reward^A) - expected_value[deck]
    else
        prediction_error = -ω * (abs(reward)^A) - expected_value[deck]
    end

    new_expected_value = [
        expected_value[deck_idx] + α * prediction_error * (deck == deck_idx) for
        deck_idx = 1:4
    ]

    state.expected_value = new_expected_value

    return Categorical(action_probabilities)
end

model = pvl_model(inputs, actions)

sample(model, NUTS(; adtype = AutoForwardDiff()), 1000)  # works
sample(model, NUTS(; adtype = AutoReverseDiff()), 1000)  # MethodError in increment_deriv!() with zeros(Real, 4)
```