Skip to content

Commit

Permalink
Dev (#309)
Browse files Browse the repository at this point in the history
* update Project.toml

* s/logpdf/logdensity/

* update deps

* scratchpad

* version bump

* representative => rootmeasure

* update dependency versions

* reduce dependencies

* update deps

* scratchpad

* representative => rootmeasure

* update dependency versions

* reduce dependencies

* `as` methods for `xform`

* cleanup

* require latest MeasureTheory

* dorp old distributions code

* drop old iid code

* drop extra space

* limit deps to three newest releases

* update dynamichmc

* add Aqua

* bump version

* Better `predict` method

* withmeasures(::ConditionalModel)

* update dependencies

* updating symbolics

* some updates to symbolics

* hmm example

* update MeasureBase bound

* better dispatch for `predict`

* drop redundant method

* remove whitespace

* bump version
  • Loading branch information
cscherrer authored Nov 3, 2021
1 parent cd93be7 commit 36bc868
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 12 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Soss"
uuid = "8ce77f84-9b61-11e8-39ff-d17a774bf41c"
author = ["Chad Scherrer <chad.scherrer@gmail.com>"]
version = "0.20.8"
version = "0.20.9"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down Expand Up @@ -49,7 +49,7 @@ JuliaVariables = "0.2"
MLStyle = "0.3,0.4"
MacroTools = "0.5"
MappedArrays = "0.3, 0.4"
MeasureBase = "0.4"
MeasureBase = "0.5"
MeasureTheory = "0.13"
NamedTupleTools = "0.12, 0.13"
NestedTuples = "0.3"
Expand All @@ -65,7 +65,7 @@ SpecialFunctions = "0.9, 0.10, 1"
StatsBase = "0.33"
StatsFuns = "0.9"
SymbolicCodegen = "0.2"
SymbolicUtils = "0.14, 0.15, 0.16"
SymbolicUtils = "0.15, 0.16, 0.17"
TransformVariables = "0.4"
TupleVectors = "0.1"
julia = "1.5"
Expand Down
48 changes: 48 additions & 0 deletions scratchpad/hmm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using MeasureTheory
using Base.Iterators
using Statistics

using Random
rng = Random.Xoshiro(3)

x = Chain(Normal()) do xj Normal=xj) end
xobs = rand(rng, x)
y = For(xobs) do xj Poisson(logλ=xj) end
yobs = rand(rng, y)
xv = take(xobs, 10) |> collect
yv = take(yobs, 10) |> collect

take(xobs.parent, 10) |> collect
take(yobs.parent, 10) |> collect


exp.(xv)
yv

# using Plots

# plt = scatter(normcdf.(xv, 1, yv), label=false)

# for j in 1:10
# xobs = rand(rng, x)
# yobs = rand(rng, y)
# xv = take(xobs, 100) |> collect;
# yv = take(yobs, 100) |> collect;
# plt = scatter!(plt, normcdf.(xv, 1, yv), label=false)
# end
# plt

using Soss

m = @model begin
x ~ Chain(Normal()) do xj Normal=xj) end
y ~ For(xobs) do xj Poisson(logλ=xj) end
end

truth = rand(rng, m())

xobs = take(truth.x, 10) |> collect
yobs = take(truth.y, 10) |> collect

logdensity(m(), (x=xobs, y=yobs))

1 change: 1 addition & 0 deletions src/symbolic/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ function SymbolicCodegen.codegen(cm :: ConditionalModel; kwargs...)
pushfirst!(code.args, :($v = getproperty(_pars, $vname)))
end

code = MacroTools.flatten(code)

return mk_function(getmodule(cm), (:_args, :_data, :_pars), (), code)

Expand Down
8 changes: 4 additions & 4 deletions src/symbolic/symbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ export schema

export symlogdensity

symlogdensity(d, x::Symbolic) = logdensity(d,x)
symlogdensity(d, x::Symbolic) = logpdf(d,x)

function symlogdensity(d::ProductMeasure{<:AbstractArray}, x::Symbolic{A}) where {A <: AbstractArray}
function symlogdensity(d::ProductMeasure{F,S,<:AbstractArray}, x::Symbolic{A}) where {F,S,A <: AbstractArray}
dims = size(d)

iters = Sym{Int}.(gensym.(Symbol.(:i, 1:length(dims))))

marginals = d.data
mar = marginals(d)

# To begin, the result is just the summand
result = getsummand(marginals, x, iters)
result = getsummand(mar, x, iters)

# Then we wrap in a summation index for each dimension
for i in 1:length(dims)
Expand Down
10 changes: 5 additions & 5 deletions src/transforms/predict.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
export predict
using TupleVectors
using SampleChains

predict(m::AbstractModel, args...) = predict(Random.GLOBAL_RNG, m, args...)
predict(d::AbstractMeasure, x) = x
predict(args...; kwargs...) = predict(Random.GLOBAL_RNG, args...; kwargs...)

# TODO: Fix this hack
predict(d::AbstractMeasure, x) = x
predict(d::Dists.Distribution, x) = x
predict(d::AbstractModel, args...; kwargs...) = predict(Random.GLOBAL_RNG, d, args...; kwargs...)

@inline function predict(rng::AbstractRNG, m::AbstractModel, nt::NamedTuple{N}) where {N}
pred = predictive(Model(m), N...)
Expand All @@ -13,7 +16,6 @@ end

predict(rng::AbstractRNG, m::AbstractModel; kwargs...) = predict(rng, m, (;kwargs...))


@inline function predict(rng::AbstractRNG, d::AbstractModel, nt::LazyMerge)
predict(rng, d, convert(NamedTuple, nt))
end
Expand All @@ -33,8 +35,6 @@ function predict(rng::AbstractRNG, d::AbstractModel, post::AbstractVector{<:Name
v
end

using SampleChains

function predict(rng::AbstractRNG, d::ConditionalModel, post::MultiChain)
[predict(rng, d, c) for c in getchains(post)]
end

2 comments on commit 36bc868

@cscherrer
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/48084

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.20.9 -m "<description of version>" 36bc8681988d2e6a7938dc89597039dd621a04b6
git push origin v0.20.9

Please sign in to comment.