Skip to content

Commit

Permalink
Merge pull request #46 from rveltz/syntax
Browse files Browse the repository at this point in the history
update for new diffeq syntax
  • Loading branch information
rveltz authored Jan 24, 2018
2 parents 3c05d38 + 6c74dac commit d97a23d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 11 deletions.
2 changes: 1 addition & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ julia 0.6
BinDeps 0.4.3
Compat 0.17.0
Parameters 0.5.0
DiffEqBase 1.5.1
DiffEqBase 3.0.0
28 changes: 20 additions & 8 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
## Common Interface Solve Functions

type CommonFunction{F,P}
func::F
p::P
neq::Cint
end

function commonfun(t::T1,y::T2,yp::T3,comfun::CommonFunction) where {T1,T2,T3}
y_ = unsafe_wrap(Array,y,comfun.neq)
ydot_ = unsafe_wrap(Array,yp,comfun.neq)
comfun.func(ydot_,y_,comfun.p,t)
return Int32(0)
end

function solve{uType,tType,isinplace}(
prob::AbstractODEProblem{uType,tType,isinplace},
alg::LSODAAlgorithm,
Expand Down Expand Up @@ -84,14 +97,13 @@ function solve{uType,tType,isinplace}(

### Fix the more general function to Sundials allowed style
if !isinplace && (typeof(prob.u0)<:Vector{Float64} || typeof(prob.u0)<:Number)
f! = (t,u,du,userdata) -> (du[:] = prob.f(t,u); nothing)
f! = (du,u,p,t) -> (du[:] = prob.f(u,p,t); nothing)
elseif !isinplace && typeof(prob.u0)<:AbstractArray
f! = (t,u,du,userdata) -> (du[:] = vec(prob.f(t,reshape(u,sizeu))); nothing)
f! = (du,u,p,t) -> (du[:] = vec(prob.f(reshape(u,sizeu),p,t)); nothing)
elseif typeof(prob.u0)<:Vector{Float64}
f! = (t,u,du,userdata) -> prob.f(t,u,du)
f! = prob.f
else # Then it's an in-place function on an abstract array
f! = (t,u,du,userdata) -> (prob.f(t,reshape(u,sizeu),reshape(du,sizeu));
u = vec(u); du=vec(du); nothing)
f! = (du,u,p,t) -> (prob.f(reshape(du,sizeu),reshape(u,sizeu),p,t); nothing)
end

ures = Vector{Vector{Float64}}()
Expand All @@ -104,7 +116,7 @@ function solve{uType,tType,isinplace}(
save_start ? ts = [t0] : ts = Vector{typeof(t0)}(0)

neq = Int32(length(u0))
userfun = UserFunctionAndData(f!, userdata,neq)
comfun = CommonFunction(f!,prob.p,neq)

atol = ones(Float64,neq)
rtol = ones(Float64,neq)
Expand Down Expand Up @@ -132,13 +144,13 @@ function solve{uType,tType,isinplace}(
end
opt.itask = itask_tmp

const fex_c = cfunction(lsodafun,Cint,(Cdouble,Ptr{Cdouble},Ptr{Cdouble},Ref{typeof(userfun)}))
const fex_c = cfunction(commonfun,Cint,(Cdouble,Ptr{Cdouble},Ptr{Cdouble},Ref{typeof(comfun)}))

ctx = lsoda_context_t()
ctx.function_ = fex_c
ctx.neq = neq
ctx.state = 1
ctx.data = pointer_from_objref(userfun)
ctx.data = pointer_from_objref(comfun)
ch = ContextHandle(ctx)

lsoda_prepare(ctx,opt)
Expand Down
4 changes: 2 additions & 2 deletions src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ function lsoda(f::Function, y0::Vector{Float64}, tspan::Vector{Float64}; userdat
@assert (ctx_ptr.state >0) string("LSODA error istate = ", ctx_ptr.state, ", error = ",unsafe_string(ctx_ptr.error))
yres[k,:] = copy(y)
end

lsoda_free(ctx_ptr)
return yres
end
Expand All @@ -127,7 +127,7 @@ function lsoda_evolve!(ctx::lsoda_context_t,y::Vector{Float64},tspan::Vector{Flo
# # ctx.data.data = userdata
# # unsafe_pointer_to_objref(ctx.data).data = userdata
# end

t = Array{Float64}(1)
tout = Array{Float64}(1)
t[1] = tspan[1]
Expand Down

0 comments on commit d97a23d

Please sign in to comment.