From 62feb5869443a554a830e6e76a72a5cc9da164eb Mon Sep 17 00:00:00 2001 From: Akira Kyle Date: Mon, 18 Nov 2024 13:15:07 -0700 Subject: [PATCH] WIP moshi --- Project.toml | 4 +- src/SymbolicUtils.jl | 3 +- src/code.jl | 4 +- src/matchers.jl | 18 ++- src/methods.jl | 2 +- src/ordering.jl | 7 +- src/polyform.jl | 1 + src/substitute.jl | 2 + src/types.jl | 350 +++++++++++++++++++++++++++---------------- src/utils.jl | 19 ++- test/basics.jl | 6 +- test/rewrite.jl | 12 +- test/runtests.jl | 1 + test/types.jl | 113 ++++++++++++++ 14 files changed, 390 insertions(+), 152 deletions(-) create mode 100644 test/types.jl diff --git a/Project.toml b/Project.toml index ec3162512..850126888 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" @@ -25,7 +26,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415" WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1" [weakdeps] @@ -48,6 +48,7 @@ DocStringExtensions = "0.8, 0.9" DynamicPolynomials = "0.5, 0.6" IfElse = "0.1" LabelledArrays = "1.5" +Moshi = "0.3.5" MultivariatePolynomials = "0.5" NaNMath = "0.3, 1" ReverseDiff = "1" @@ -57,7 +58,6 @@ StaticArrays = "0.12, 1.0" SymbolicIndexingInterface = "0.3" TermInterface = "2.0" TimerOutputs = "0.5" -Unityper = "0.1.2" WeakValueDicts = "0.1.0" julia = "1.3" diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index fb13f50b4..20d54df7b 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -7,7 +7,8 @@ using DocStringExtensions export @syms, term, showraw, hasmetadata, getmetadata, setmetadata -using Unityper +using Moshi.Data: @data, data_type_name, variant_name +using Moshi.Match: @match using TermInterface using DataStructures using Setfield diff --git a/src/code.jl b/src/code.jl index 4128a39fd..6ddfbd4b2 100644 --- a/src/code.jl +++ b/src/code.jl @@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, import ..SymbolicUtils import ..SymbolicUtils.Rewriters import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym, - symtype, sorted_arguments, metadata, isterm, term, maketerm + isconst, symtype, sorted_arguments, metadata, isterm, term, maketerm import SymbolicIndexingInterface: symbolic_type, NotSymbolic ##== state management ==## @@ -182,6 +182,8 @@ function toexpr(O, st) if issym(O) O = substitute_name(O, st) return issym(O) ? nameof(O) : toexpr(O, st) + elseif isconst(O) + return toexpr(O.val, st) end O = substitute_name(O, st) diff --git a/src/matchers.jl b/src/matchers.jl index 7f4dea537..edf7a0484 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -6,9 +6,23 @@ # 3. Callback: takes arguments Dictionary × Number of elements matched # function matcher(val::Any) - iscall(val) && return term_matcher(val) + if isconst(val) + slot = val.val + return matcher(slot) + elseif iscall(val) + return term_matcher(val) + end function literal_matcher(next, data, bindings) - islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing + if islist(data) + cd = car(data) + if isconst(cd) + cd = cd.val + end + if isequal(cd, val) + return next(bindings, 1) + end + end + nothing end end diff --git a/src/methods.jl b/src/methods.jl index 2baef6424..42480ac4a 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -188,7 +188,7 @@ end for f in [!, ~] @eval begin promote_symtype(::$(typeof(f)), ::Type{<:Bool}) = Bool - (::$(typeof(f)))(s::Symbolic{Bool}) = Term{Bool}(!, [s]) + (::$(typeof(f)))(s::Symbolic{Bool}) = isconst(s) ? !s.val : Term{Bool}(!, [s]) end end diff --git a/src/ordering.jl b/src/ordering.jl index 332f11cf8..2279ce7a7 100644 --- a/src/ordering.jl +++ b/src/ordering.jl @@ -27,7 +27,7 @@ function get_degrees(expr) elseif iscall(expr) op = operation(expr) args = sorted_arguments(expr) - if op == (^) && args[2] isa Number + if op == (^) && (args[2] isa Number || (isconst(args[2]) && args[2].val isa Number)) return map(get_degrees(args[1])) do (base, pow) (base => pow * args[2]) end @@ -79,12 +79,15 @@ function <ₑ(a::Tuple, b::Tuple) end function <ₑ(a::BasicSymbolic, b::BasicSymbolic) + isconst(a) && isconst(b) && return a.val <ₑ b.val + isconst(a) && return a.val <ₑ b + isconst(b) && return a <ₑ b.val da, db = get_degrees(a), get_degrees(b) fw = monomial_lt(da, db) bw = monomial_lt(db, da) if fw === bw && !isequal(a, b) if _arglen(a) == _arglen(b) - return (operation(a), arguments(a)...,) <ₑ (operation(b), arguments(b)...,) + return (operation(a), arguments(a)...) <ₑ (operation(b), arguments(b)...) else return _arglen(a) < _arglen(b) end diff --git a/src/polyform.jl b/src/polyform.jl index 7d6bc906e..7c741dc68 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -95,6 +95,7 @@ end _isone(p::PolyForm) = isone(p.p) function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) + x = isconst(x) ? x.val : x if x isa Number return x elseif iscall(x) diff --git a/src/substitute.jl b/src/substitute.jl index 828f88b14..4548f7c29 100644 --- a/src/substitute.jl +++ b/src/substitute.jl @@ -22,6 +22,7 @@ function substitute(expr, dict; fold=true) canfold = !(op isa Symbolic) args = map(arguments(expr)) do x x′ = substitute(x, dict; fold=fold) + x′ = isconst(x) ? x′.val : x′ canfold = canfold && !(x′ isa Symbolic) x′ end @@ -54,6 +55,7 @@ function _occursin(needle, haystack) if iscall(haystack) args = arguments(haystack) for arg in args + arg = isconst(arg) ? arg.val : arg if needle isa Integer || needle isa AbstractFloat isequal(needle, arg) && return true else diff --git a/src/types.jl b/src/types.jl index 898259f44..2377766a1 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1,79 +1,82 @@ -#------------------- -#-------------------- -#### Symbolic -#-------------------- abstract type Symbolic{T} end -### -### Uni-type design -### - -@enum ExprType::UInt8 SYM TERM ADD MUL POW DIV +@enum ExprType::UInt8 SYM TERM ADD MUL POW DIV CONST const Metadata = Union{Nothing,Base.ImmutableDict{DataType,Any}} const NO_METADATA = nothing -sdict(kv...) = Dict{Any, Any}(kv...) - using Base: RefValue -const EMPTY_ARGS = [] const EMPTY_HASH = RefValue(UInt(0)) -const NOT_SORTED = RefValue(false) -const EMPTY_DICT = sdict() -const EMPTY_DICT_T = typeof(EMPTY_DICT) -@compactify show_methods=false begin - @abstract mutable struct BasicSymbolic{T} <: Symbolic{T} - metadata::Metadata = NO_METADATA - end - mutable struct Sym{T} <: BasicSymbolic{T} - name::Symbol = :OOF - end - mutable struct Term{T} <: BasicSymbolic{T} - f::Any = identity # base/num if Pow; issorted if Add/Dict - arguments::Vector{Any} = EMPTY_ARGS - hash::RefValue{UInt} = EMPTY_HASH - end - mutable struct Mul{T} <: BasicSymbolic{T} - coeff::Any = 0 # exp/den if Pow - dict::EMPTY_DICT_T = EMPTY_DICT - hash::RefValue{UInt} = EMPTY_HASH - arguments::Vector{Any} = EMPTY_ARGS - issorted::RefValue{Bool} = NOT_SORTED - end - mutable struct Add{T} <: BasicSymbolic{T} - coeff::Any = 0 # exp/den if Pow - dict::EMPTY_DICT_T = EMPTY_DICT - hash::RefValue{UInt} = EMPTY_HASH - arguments::Vector{Any} = EMPTY_ARGS - issorted::RefValue{Bool} = NOT_SORTED - end - mutable struct Div{T} <: BasicSymbolic{T} - num::Any = 1 - den::Any = 1 - simplified::Bool = false - arguments::Vector{Any} = EMPTY_ARGS - end - mutable struct Pow{T} <: BasicSymbolic{T} - base::Any = 1 - exp::Any = 1 - arguments::Vector{Any} = EMPTY_ARGS - end -end +# TODO: Actually close the type system by making everything hold only BasicSymbolicExpr except Const +@data BasicSymbolicExpr{T} <: Symbolic{T} begin + struct Sym + metadata::Metadata = NO_METADATA + name::Symbol = :OOF + end + struct Term + metadata::Metadata = NO_METADATA + f::Any = identity + arguments::Vector{Symbolic} = Symbolic[] + hash::RefValue{UInt} = EMPTY_HASH + end + struct Mul + metadata::Metadata = NO_METADATA + coeff::Any = 0 + dict::Dict{BasicSymbolicExpr.Type, Any} = Dict{BasicSymbolicExpr.Type, Any}() + hash::RefValue{UInt} = EMPTY_HASH + arguments::Vector{Any} = [] + issorted::RefValue{Bool} = RefValue(false) + end + struct Add + metadata::Metadata = NO_METADATA + coeff::Any = 0 + dict::Dict{BasicSymbolicExpr.Type, Any} = Dict{BasicSymbolicExpr.Type, Any}() + hash::RefValue{UInt} = EMPTY_HASH + arguments::Vector{Any} = [] + issorted::RefValue{Bool} = RefValue(false) + end + struct Div + metadata::Metadata = NO_METADATA + num::Any = 1 + den::Any = 1 + simplified::Bool = false + arguments::Vector{Any} = [] + end + struct Pow + metadata::Metadata = NO_METADATA + base::Any = 1 + exp::Any = 1 + arguments::Vector{Any} = [] + end + struct Const + metadata::Metadata = NO_METADATA + val::Any = 1 + end +end + +const BasicSymbolic = BasicSymbolicExpr.Type +const Term = BasicSymbolicExpr.Term +const Sym = BasicSymbolicExpr.Sym +const Add = BasicSymbolicExpr.Add +const Mul = BasicSymbolicExpr.Mul +const Div = BasicSymbolicExpr.Div +const Pow = BasicSymbolicExpr.Pow +const Const = BasicSymbolicExpr.Const function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic}) ScalarSymbolic() end function exprtype(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Term => TERM - Add => ADD - Mul => MUL - Div => DIV - Pow => POW - Sym => SYM - _ => error_on_type() + @match x begin + Term(_) => TERM + Add(_) => ADD + Mul(_) => MUL + Div(_) => DIV + Pow(_) => POW + Sym(_) => SYM + Const(_) => CONST end end @@ -82,6 +85,7 @@ const wvd = WeakValueDict{UInt, BasicSymbolic}() # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") +@noinline error_const() = error("Const doesn't have a operation or arguments!") @noinline error_property(E, s) = error("$E doesn't have field $s") # We can think about bits later @@ -94,10 +98,16 @@ const SIMPLIFIED = 0x01 << 0 function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple)::BasicSymbolic{T} where T nt = getproperties(obj) nt_new = merge(nt, patch) - # Call outer constructor because hash consing cannot be applied in inner constructor - @compactified obj::BasicSymbolic begin - Sym => Sym{T}(nt_new.name; nt_new...) - _ => Unityper.rt_constructor(obj){T}(;nt_new...) + #data_type_name(obj){T}(;nt_new...) + # TODO which to use? + @match obj begin + Sym(_) => Sym{T}(;nt_new...) + Term(_) => Term{T}(;nt_new...) + Add(_) => Add{T}(;nt_new...) + Mul(_) => Mul{T}(;nt_new...) + Div(_) => Div{T}(;nt_new...) + Pow(_) => Pow{T}(;nt_new...) + Const(_) => Const{T}(;nt_new...) end end @@ -120,14 +130,14 @@ symtype(x) = typeof(x) # We're returning a function pointer @inline function operation(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Term => x.f - Add => (+) - Mul => (*) - Div => (/) - Pow => (^) - Sym => error_sym() - _ => error_on_type() + @match x begin + Term(_) => x.f + Add(_) => (+) + Mul(_) => (*) + Div(_) => (/) + Pow(_) => (^) + Sym(_) => error_sym() + Const(_) => error_const() end end @@ -135,9 +145,9 @@ end function TermInterface.sorted_arguments(x::BasicSymbolic) args = arguments(x) - @compactified x::BasicSymbolic begin - Add => @goto ADD - Mul => @goto MUL + @match x begin + Add(_) => @goto ADD + Mul(_) => @goto MUL _ => return args end @label MUL @@ -160,14 +170,14 @@ end TermInterface.children(x::BasicSymbolic) = arguments(x) TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x) function TermInterface.arguments(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Term => return x.arguments - Add => @goto ADDMUL - Mul => @goto ADDMUL - Div => @goto DIV - Pow => @goto POW - Sym => error_sym() - _ => error_on_type() + @match x begin + Term(_) => return x.arguments + Add(_) => @goto ADDMUL + Mul(_) => @goto ADDMUL + Div(_) => @goto DIV + Pow(_) => @goto POW + Sym(_) => error_sym() + Const(_) => error_const() end @label ADDMUL @@ -175,7 +185,7 @@ function TermInterface.arguments(x::BasicSymbolic) args = x.arguments isempty(args) || return args siz = length(x.dict) - idcoeff = E === ADD ? iszero(x.coeff) : isone(x.coeff) + idcoeff = E === ADD ? _iszero(x.coeff) : _isone(x.coeff) sizehint!(args, idcoeff ? siz : siz + 1) idcoeff || push!(args, x.coeff) if isadd(x) @@ -207,10 +217,17 @@ function TermInterface.arguments(x::BasicSymbolic) return args end -isexpr(s::BasicSymbolic) = !issym(s) +function isexpr(x::BasicSymbolic) + @match x begin + Sym(_) => false + Const(_) => false + _ => true + end +end + iscall(s::BasicSymbolic) = isexpr(s) -@inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false +@inline isa_SymType(S, x) = x isa BasicSymbolic ? variant_name(x) == S : false """ issym(x) @@ -218,12 +235,13 @@ iscall(s::BasicSymbolic) = isexpr(s) Returns `true` if `x` is a `Sym`. If true, `nameof` must be defined on `x` and must return a `Symbol`. """ -issym(x) = isa_SymType(Val(:Sym), x) -isterm(x) = isa_SymType(Val(:Term), x) -ismul(x) = isa_SymType(Val(:Mul), x) -isadd(x) = isa_SymType(Val(:Add), x) -ispow(x) = isa_SymType(Val(:Pow), x) -isdiv(x) = isa_SymType(Val(:Div), x) +issym(x) = isa_SymType(:Sym, x) +isterm(x) = isa_SymType(:Term, x) +ismul(x) = isa_SymType(:Mul, x) +isadd(x) = isa_SymType(:Add, x) +ispow(x) = isa_SymType(:Pow, x) +isdiv(x) = isa_SymType(:Div, x) +isconst(x) = isa_SymType(:Const, x) ### ### Base interface @@ -266,6 +284,8 @@ function _isequal(a, b, E) a1 = arguments(a) a2 = arguments(b) isequal(operation(a), operation(b)) && _allarequal(a1, a2) + elseif E === CONST + isequal(a.val, b.val) else error_on_type() end @@ -303,6 +323,7 @@ const ADD_SALT = 0xaddaddaddaddadda % UInt const SUB_SALT = 0xaaaaaaaaaaaaaaaa % UInt const DIV_SALT = 0x334b218e73bbba53 % UInt const POW_SALT = 0x2b55b97a6efb080c % UInt +const COS_SALT = 0xdc3d6b8f18b75e3c % UInt function Base.hash(s::BasicSymbolic, salt::UInt)::UInt E = exprtype(s) if E === SYM @@ -328,6 +349,8 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt h′ = hashvec(arguments(s), hash(oph, salt)) s.hash[] = h′ return h′ + elseif E === CONST + return hash(s.val, salt ⊻ COS_SALT) else error_on_type() end @@ -375,7 +398,7 @@ Custom functions `hash2` and `isequal_with_metadata` are used instead of `Base.h `Base.isequal` to accommodate metadata without disrupting existing tests reliant on the original behavior of those functions. """ -function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic +function BasicSymbolicEquivalent(s::BasicSymbolic)::BasicSymbolic h = hash2(s) t = get!(wvd, h, s) if t === s || isequal_with_metadata(t, s) @@ -385,24 +408,42 @@ function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic end end +# TODO: figure out how to implement BasicSymbolicEquivalent function Sym{T}(name::Symbol; kw...) where {T} - s = Sym{T}(; name, kw...) - BasicSymbolic(s) + #s = Sym{T}(; name=name, kw...) + #BasicSymbolicEquivalent(s) + Sym{T}(; name=name, kw...) end function Term{T}(f, args; kw...) where T - if eltype(args) !== Any - args = convert(Vector{Any}, args) - end + #if eltype(args) !== Symbolic + # args = convert(Vector{Symbolic}, args) + #end - Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), kw...) + # TODO: revisit convert after https://github.com/Roger-luo/Moshi.jl/issues/32 is resolved + Term{T}(;f=f, arguments=convert(Vector{Any}, args), hash=Ref(UInt(0)), kw...) end function Term(f, args; metadata=NO_METADATA) Term{_promote_symtype(f, args)}(f, args, metadata=metadata) end -function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T +function Const(val::T; kwargs...) where {T} + Const{T}(; val=val, kwargs...) +end + +function Base.convert(::Type{Symbolic}, x) + Const(x) +end + +function Base.convert(::Type{BasicSymbolic}, x) + Const(x) +end +function Base.convert(::Type{BasicSymbolic}, x::BasicSymbolic) + x +end + +function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where {T} if isempty(dict) return coeff elseif _iszero(coeff) && length(dict) == 1 @@ -415,10 +456,12 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T end end - Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...) + # TODO: revisit convert after https://github.com/Roger-luo/Moshi.jl/issues/32 is resolved + Add{T}(; coeff=coeff, dict=convert(Dict{BasicSymbolic, Any}, dict), + hash=Ref(UInt(0)), metadata=metadata, arguments=[], issorted=RefValue(false), kw...) end -function Mul(T, a, b; metadata=NO_METADATA, kw...) +function Mul(::Type{T}, a, b; metadata=NO_METADATA, kw...) where {T} isempty(b) && return a if _isone(a) && length(b) == 1 pair = first(b) @@ -430,7 +473,23 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...) else coeff = a dict = b - Mul{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...) + # TODO: revisit convert after https://github.com/Roger-luo/Moshi.jl/issues/32 is resolved + Mul{T}(; coeff=coeff, dict=convert(Dict{BasicSymbolic, Any}, dict), + hash=Ref(UInt(0)), metadata=metadata, arguments=[], issorted=RefValue(false), kw...) + end +end + +function _iszero(x::BasicSymbolic) + @match x begin + Const(_) => iszero(x.val) + _ => false + end +end + +function _isone(x::BasicSymbolic) + @match x begin + Const(_) => isone(x.val) + _ => false end end @@ -495,7 +554,7 @@ function Div{T}(n, d, simplified=false; metadata=nothing) where {T} end end - Div{T}(; num=n, den=d, simplified, arguments=[], metadata) + Div{T}(; num=n, den=d, simplified=simplified, arguments=[], metadata=metadata) end function Div(n,d, simplified=false; kw...) @@ -512,7 +571,7 @@ end function Pow{T}(a, b; metadata=NO_METADATA) where {T} _iszero(b) && return 1 _isone(b) && return a - Pow{T}(; base=a, exp=b, arguments=[], metadata) + Pow{T}(; base=a, exp=b, arguments=[], metadata=metadata) end function Pow(a, b; metadata=NO_METADATA) @@ -524,14 +583,14 @@ function toterm(t::BasicSymbolic{T}) where T if E === SYM || E === TERM return t elseif E === ADD || E === MUL - args = Any[] + args = BasicSymbolic[] push!(args, t.coeff) for (k, coeff) in t.dict - push!(args, coeff == 1 ? k : Term{T}(E === MUL ? (^) : (*), Any[coeff, k])) + push!(args, coeff == 1 ? k : Term{T}(E === MUL ? (^) : (*), [Const(coeff), k])) end Term{T}(operation(t), args) elseif E === DIV - Term{T}(/, Any[t.num, t.den]) + Term{T}(/, [t.num, t.den]) elseif E === POW Term{T}(^, [t.base, t.exp]) else @@ -546,7 +605,7 @@ Any Muls inside an Add should always have a coeff of 1 and the key (in Add) should instead be used to store the actual coefficient """ function makeadd(sign, coeff, xs...) - d = sdict() + d = Dict{BasicSymbolic, Any}() for x in xs if isadd(x) coeff += x.coeff @@ -573,7 +632,7 @@ function makeadd(sign, coeff, xs...) coeff, d end -function makemul(coeff, xs...; d=sdict()) +function makemul(coeff, xs...; d=Dict{BasicSymbolic, Any}()) for x in xs if ispow(x) && x.exp isa Number d[x.base] = x.exp + get(d, x.base, 0) @@ -612,7 +671,7 @@ function term(f, args...; type = nothing) else T = type end - Term{T}(f, Any[args...]) + Term{T}(f, [args...]) end """ @@ -624,7 +683,7 @@ function unflatten(t::Symbolic{T}) where{T} f = operation(t) if f == (+) || f == (*) # TODO check out for other n-ary --> binary ops a = arguments(t) - return foldl((x,y) -> Term{T}(f, Any[x, y]), a) + return foldl((x,y) -> Term{T}(f, [x, y]), a) end end return t @@ -709,10 +768,10 @@ end issafecanon(f, s) = true function issafecanon(f, s::Symbolic) - if isnothing(metadata(s)) || issym(s) - return true - else - _issafecanon(f, s) + isnothing(metadata(s)) || @match s begin + Sym(_) => true + Const(_) => true + _ => _issafecanon(f, s) end end _issafecanon(::typeof(*), s) = !iscall(s) || !(operation(s) in (+,*,^)) @@ -778,6 +837,10 @@ const show_simplified = Ref(false) isnegative(t::Real) = t < 0 function isnegative(t) + if isconst(t) + val = t.val + return isnegative(val) + end if iscall(t) && operation(t) === (*) coeff = first(arguments(t)) return isnegative(coeff) @@ -812,8 +875,12 @@ function remove_minus(t) !iscall(t) && return -t @assert operation(t) == (*) args = arguments(t) - @assert args[1] < 0 - Any[-args[1], args[2:end]...] + arg1 = args[1] + if isconst(arg1) + arg1 = arg1.val + end + @assert arg1 < 0 + Any[-arg1, args[2:end]...] end @@ -848,17 +915,27 @@ function show_pow(io, args) end function show_mul(io, args) + if isconst(args) + print(io, args.val) + return + end length(args) == 1 && return print_arg(io, *, args[1]) - minus = args[1] isa Number && args[1] == -1 - unit = args[1] isa Number && args[1] == 1 + arg1 = args[1] + if isconst(arg1) + arg1 = arg1.val + end + + minus = arg1 isa Number && arg1 == -1 + unit = arg1 isa Number && arg1 == 1 - paren_scalar = (args[1] isa Complex && !_iszero(imag(args[1]))) || - args[1] isa Rational || - (args[1] isa Number && !isfinite(args[1])) + paren_scalar = (arg1 isa Complex && !_iszero(imag(arg1))) || + arg1 isa Rational || + (arg1 isa Number && !isfinite(arg1)) nostar = minus || unit || - (!paren_scalar && args[1] isa Number && !(args[2] isa Number)) + (!paren_scalar && arg1 isa Number && + !(isconst(args[2]) && args[2].val isa Number)) for (i, t) in enumerate(args) if i != 1 @@ -947,10 +1024,10 @@ showraw(io, t) = Base.show(IOContext(io, :simplify=>false), t) showraw(t) = showraw(stdout, t) function Base.show(io::IO, v::BasicSymbolic) - if issym(v) - Base.show_unquoted(io, v.name) - else - show_term(io, v) + @match v begin + Sym(_) => Base.show_unquoted(io, v.name) + Const(_) => print(io, v.val) + _ => show_term(io, v) end end @@ -1164,6 +1241,12 @@ sub_t(a) = promote_symtype(-, symtype(a)) import Base: (+), (-), (*), (//), (/), (\), (^) function +(a::SN, b::SN) + if isconst(a) + return a.val + b + end + if isconst(b) + return b.val + a + end !issafecanon(+, a,b) && return term(+, a, b) # Don't flatten if args have metadata if isadd(a) && isadd(b) return Add(add_t(a,b), @@ -1180,6 +1263,9 @@ function +(a::SN, b::SN) end function +(a::Number, b::SN) + if isconst(b) + return a + b.val + end !issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata iszero(a) && return b if isadd(b) @@ -1194,6 +1280,7 @@ end +(a::SN) = a function -(a::SN) + isconst(a) && return Const(-a.val) !issafecanon(*, a) && return term(-, a) isadd(a) ? Add(sub_t(a), -a.coeff, mapvalues((_,v) -> -v, a.dict)) : Add(sub_t(a), makeadd(-1, 0, a)...) @@ -1218,6 +1305,8 @@ mul_t(a) = promote_symtype(*, symtype(a)) *(a::SN) = a function *(a::SN, b::SN) + isconst(a) && return a.val * b + isconst(b) && return b.val * a # Always make sure Div wraps Mul !issafecanon(*, a, b) && return term(*, a, b) if isdiv(a) && isdiv(b) @@ -1246,6 +1335,7 @@ function *(a::SN, b::SN) end function *(a::Number, b::SN) + isconst(b) && return a * b.val !issafecanon(*, b) && return term(*, a, b) if iszero(a) a @@ -1256,7 +1346,7 @@ function *(a::Number, b::SN) elseif isone(-a) && isadd(b) # -1(a+b) -> -a - b T = promote_symtype(+, typeof(a), symtype(b)) - Add(T, b.coeff * a, Dict{Any,Any}(k=>v*a for (k, v) in b.dict)) + Add(T, b.coeff * a, Dict{BasicSymbolic,Any}(k=>v*a for (k, v) in b.dict)) else Mul(mul_t(a, b), makemul(a, b)...) end diff --git a/src/utils.jl b/src/utils.jl index 812e229fb..78c2c17d8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -64,8 +64,10 @@ end sym_isa(::Type{T}) where {T} = @nospecialize(x) -> x isa T || symtype(x) <: T -isliteral(::Type{T}) where {T} = x -> x isa T -is_literal_number(x) = isliteral(Number)(x) +function is_literal_number(x) + x = isconst(x) ? x.val : x + x isa Number +end # checking the type directly is faster than dynamic dispatch in type unstable code _iszero(x) = x isa Number && iszero(x) @@ -179,10 +181,15 @@ Base.length(l::LL) = length(l.v)-l.i+1 @inline car(l::LL) = l.v[l.i] @inline cdr(l::LL) = isempty(l) ? empty(l) : LL(l.v, l.i+1) -Base.length(t::Term) = length(arguments(t)) + 1 # PIRACY -Base.isempty(t::Term) = false -@inline car(t::Term) = operation(t) -@inline cdr(t::Term) = arguments(t) +function Base.length(t::BasicSymbolic) + @match t begin + Term(_) => length(arguments(t)) + 1 + _ => 1 + end +end +Base.isempty(t::BasicSymbolic) = false +@inline car(t::BasicSymbolic) = operation(t) +@inline cdr(t::BasicSymbolic) = arguments(t) @inline car(v) = iscall(v) ? operation(v) : first(v) @inline function cdr(v) diff --git a/test/basics.jl b/test/basics.jl index a5f0b5149..b309ba90e 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -82,10 +82,10 @@ struct Ctx1 end struct Ctx2 end @testset "metadata" begin - @syms a b c - for a = [a, sin(a), a+b, a*b, a^3] + @syms a b + for x = [a, sin(a), a+b, a*b, a^3] - a′ = setmetadata(a, Ctx1, "meta_1") + a′ = setmetadata(x, Ctx1, "meta_1") @test hasmetadata(a′, Ctx1) @test !hasmetadata(a′, Ctx2) diff --git a/test/rewrite.jl b/test/rewrite.jl index c2e920f9b..8dc27ba5c 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -1,3 +1,4 @@ +using SymbolicUtils: Symbolic, Const using SymbolicUtils include("utils.jl") @@ -42,9 +43,12 @@ end @eqtest @rule(+(~~x) => ~~x)(a + b) == [a,b] @eqtest @rule(+(~~x) => ~~x)(term(+, a, b, c)) == [a,b,c] - @eqtest @rule(+(~~x,~y, ~~x) => (~~x, ~y))(term(+,9,8,9,type=Any)) == ([9,],8) - @eqtest @rule(+(~~x,~y, ~~x) => (~~x, ~y, ~~x))(term(+,9,8,9,9,8,type=Any)) == ([9,8], 9, [9,8]) - @eqtest @rule(+(~~x,~y,~~x) => (~~x, ~y, ~~x))(term(+,6,type=Any)) == ([], 6, []) + @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y))(term(+, 9, 8, 9; type = Any)) == + (Symbolic[9], Const(8)) + @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y, ~~x))(term(+, 9, 8, 9, 9, 8; type = Any)) == + (Symbolic[9, 8], Const(9), Symbolic[9, 8]) + @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y, ~~x))(term(+, 6; type = Any)) == + (Symbolic[], Const(6), Symbolic[]) end using SymbolicUtils: @capture @@ -108,4 +112,4 @@ end ex1 = ex * b @test getmetadata(arguments(ex1)[1], MetaData) == :metadata -end \ No newline at end of file +end diff --git a/test/runtests.jl b/test/runtests.jl index 9ea8354a8..8e4a5b836 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,7 @@ using Pkg, Test, SafeTestsets @safetestset "Benchmark" begin include("benchmark.jl") end else @safetestset "Doc" begin include("doctest.jl") end + @safetestset "Types" begin include("types.jl") end @safetestset "Basics" begin include("basics.jl") end @safetestset "Order" begin include("order.jl") end @safetestset "PolyForm" begin include("polyform.jl") end diff --git a/test/types.jl b/test/types.jl new file mode 100644 index 000000000..292c25a28 --- /dev/null +++ b/test/types.jl @@ -0,0 +1,113 @@ +using SymbolicUtils: Symbolic, BasicSymbolic, Sym, Term, Add, Mul, Div, Pow, Const +using SymbolicUtils + +s1 = Sym{Float64}(:abc) +s2 = Sym{Int64}(; name = :def) +@testset "Sym" begin + @test typeof(s1) <: BasicSymbolic + @test typeof(s1) == BasicSymbolic{Float64} + @test s1 isa BasicSymbolic + @test s1 isa SymbolicUtils.Symbolic + @test s1.metadata isa SymbolicUtils.Metadata + @test s1.metadata == SymbolicUtils.NO_METADATA + @test s1.name == :abc + @test typeof(s2) <: BasicSymbolic + @test typeof(s2) == BasicSymbolic{Int64} + @test typeof(s2.name) == Symbol + @test s2.name == :def +end + +@testset "Term" begin + t1 = Term(sin, [s1]) + @test typeof(t1) <: BasicSymbolic + @test typeof(t1) == BasicSymbolic{Real} + @test t1.f == sin + @test isequal(t1.arguments, [s1]) + @test typeof(t1.arguments) == Vector{Symbolic} +end + +c1 = Const(1) +c2 = Const(3.14) +@testset "Const" begin + @test typeof(c1) <: BasicSymbolic + @test typeof(c1.val) == Int + @test c1.val == 1 + @test typeof(c2.val) == Float64 + @test c2.val == 3.14 + c3 = Const(big"123456789012345678901234567890") + @test typeof(c3.val) == BigInt + @test c3.val == big"123456789012345678901234567890" + c4 = Const(big"1.23456789012345678901") + @test typeof(c4.val) == BigFloat + @test c4.val == big"1.23456789012345678901" +end + +coeff = c1 +dict = Dict{BasicSymbolic, Any}(s1 => 3, s2 => 5) +@testset "Add" begin + a1 = Add{Real}(; coeff=coeff, dict=dict) + @test typeof(a1) <: BasicSymbolic + @test a1.coeff isa BasicSymbolic + @test isequal(a1.coeff, c1) + @test typeof(a1.dict) == Dict{BasicSymbolic, Any} + @test a1.dict == dict + @test typeof(a1.arguments) == Vector{Any} + @test isempty(a1.arguments) + @test typeof(a1.issorted) == Base.RefValue{Bool} + @test !a1.issorted[] +end + +@testset "Mul" begin + m1 = Mul{Real}(; coeff=coeff, dict=dict) + @test typeof(m1) <: BasicSymbolic + @test m1.coeff isa BasicSymbolic + @test isequal(m1.coeff, c1) + @test typeof(m1.dict) == Dict{BasicSymbolic, Any} + @test m1.dict == dict + @test typeof(m1.arguments) == Vector{Any} + @test isempty(m1.arguments) + @test typeof(m1.issorted) == Base.RefValue{Bool} + @test !m1.issorted[] +end + +@testset "Div" begin + d1 = Div(s1, s2) + @test typeof(d1) <: BasicSymbolic + @test typeof(d1) == BasicSymbolic{Float64} + @test isequal(d1.num, s1) + @test isequal(d1.den, s2) + @test typeof(d1.simplified) == Bool + @test !d1.simplified + @test isequal(arguments(d1), [s1, s2]) + d2 = Div{Real}(; num=s1, den=s2) + @test isequal(d2.num, s1) + @test isequal(d2.den, s2) +end + +@testset "Pow" begin + p1 = Pow(s1, s2) + @test typeof(p1) <: BasicSymbolic + @test isequal(p1.base, s1) + @test isequal(p1.exp, s2) + @test isequal(arguments(p1), [s1, s2]) + p2 = Pow{Real}(; base=s1, exp=s2) + @test isequal(p2.base, s1) + @test isequal(p2.exp, s2) +end + +@testset "BasicSymbolic iszero" begin + c1 = Const(0) + @test SymbolicUtils._iszero(c1) + c2 = Const(1) + @test !SymbolicUtils._iszero(c2) + c3 = Const(0.0) + @test SymbolicUtils._iszero(c3) + c4 = Const(0.00000000000000000000000001) + @test !SymbolicUtils._iszero(c4) + c5 = Const(big"326264532521352634435352152") + @test !SymbolicUtils._iszero(c5) + c6 = Const(big"0.314654523452") + @test !SymbolicUtils._iszero(c6) + s = Sym{Real}(:y) + @test !SymbolicUtils._iszero(s) +end