diff --git a/REQUIRE b/REQUIRE index 98002ef..1558cb1 100644 --- a/REQUIRE +++ b/REQUIRE @@ -2,4 +2,4 @@ julia 0.6 BinDeps 0.4.3 Compat 0.17.0 Parameters 0.5.0 -DiffEqBase 1.5.1 +DiffEqBase 3.0.0 diff --git a/src/common.jl b/src/common.jl index f211e24..9c2b027 100644 --- a/src/common.jl +++ b/src/common.jl @@ -1,5 +1,18 @@ ## Common Interface Solve Functions +type CommonFunction{F,P} + func::F + p::P + neq::Cint +end + +function commonfun(t::T1,y::T2,yp::T3,comfun::CommonFunction) where {T1,T2,T3} + y_ = unsafe_wrap(Array,y,comfun.neq) + ydot_ = unsafe_wrap(Array,yp,comfun.neq) + comfun.func(ydot_,y_,comfun.p,t) + return Int32(0) +end + function solve{uType,tType,isinplace}( prob::AbstractODEProblem{uType,tType,isinplace}, alg::LSODAAlgorithm, @@ -84,14 +97,13 @@ function solve{uType,tType,isinplace}( ### Fix the more general function to Sundials allowed style if !isinplace && (typeof(prob.u0)<:Vector{Float64} || typeof(prob.u0)<:Number) - f! = (t,u,du,userdata) -> (du[:] = prob.f(t,u); nothing) + f! = (du,u,p,t) -> (du[:] = prob.f(u,p,t); nothing) elseif !isinplace && typeof(prob.u0)<:AbstractArray - f! = (t,u,du,userdata) -> (du[:] = vec(prob.f(t,reshape(u,sizeu))); nothing) + f! = (du,u,p,t) -> (du[:] = vec(prob.f(reshape(u,sizeu),p,t)); nothing) elseif typeof(prob.u0)<:Vector{Float64} - f! = (t,u,du,userdata) -> prob.f(t,u,du) + f! = prob.f else # Then it's an in-place function on an abstract array - f! = (t,u,du,userdata) -> (prob.f(t,reshape(u,sizeu),reshape(du,sizeu)); - u = vec(u); du=vec(du); nothing) + f! = (du,u,p,t) -> (prob.f(reshape(du,sizeu),reshape(u,sizeu),p,t); nothing) end ures = Vector{Vector{Float64}}() @@ -104,7 +116,7 @@ function solve{uType,tType,isinplace}( save_start ? ts = [t0] : ts = Vector{typeof(t0)}(0) neq = Int32(length(u0)) - userfun = UserFunctionAndData(f!, userdata,neq) + comfun = CommonFunction(f!,prob.p,neq) atol = ones(Float64,neq) rtol = ones(Float64,neq) @@ -132,13 +144,13 @@ function solve{uType,tType,isinplace}( end opt.itask = itask_tmp - const fex_c = cfunction(lsodafun,Cint,(Cdouble,Ptr{Cdouble},Ptr{Cdouble},Ref{typeof(userfun)})) + const fex_c = cfunction(commonfun,Cint,(Cdouble,Ptr{Cdouble},Ptr{Cdouble},Ref{typeof(comfun)})) ctx = lsoda_context_t() ctx.function_ = fex_c ctx.neq = neq ctx.state = 1 - ctx.data = pointer_from_objref(userfun) + ctx.data = pointer_from_objref(comfun) ch = ContextHandle(ctx) lsoda_prepare(ctx,opt) diff --git a/src/solver.jl b/src/solver.jl index fbd04dd..3356b9b 100644 --- a/src/solver.jl +++ b/src/solver.jl @@ -110,7 +110,7 @@ function lsoda(f::Function, y0::Vector{Float64}, tspan::Vector{Float64}; userdat @assert (ctx_ptr.state >0) string("LSODA error istate = ", ctx_ptr.state, ", error = ",unsafe_string(ctx_ptr.error)) yres[k,:] = copy(y) end - + lsoda_free(ctx_ptr) return yres end @@ -127,7 +127,7 @@ function lsoda_evolve!(ctx::lsoda_context_t,y::Vector{Float64},tspan::Vector{Flo # # ctx.data.data = userdata # # unsafe_pointer_to_objref(ctx.data).data = userdata # end - + t = Array{Float64}(1) tout = Array{Float64}(1) t[1] = tspan[1]