Skip to content

Commit

Permalink
test: EvalFullInterp gives same results as EvalFullStep
Browse files Browse the repository at this point in the history
This adds code to check for alpha equality of terms

Signed-off-by: Ben Price <ben@hackworthltd.com>
  • Loading branch information
brprice committed Dec 5, 2023
1 parent ab1c547 commit c545482
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 3 deletions.
8 changes: 7 additions & 1 deletion primer/src/Primer/Core/Type/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ module Primer.Core.Type.Utils (
freeVarsTy,
boundVarsTy,
alphaEqTy,
alphaEqTy',
concreteTy,
) where

Expand Down Expand Up @@ -130,7 +131,12 @@ boundVarsTy = foldMap' getBoundHereDnTy . universe
-- Note that we do not expand TLets, they must be structurally
-- the same (perhaps with a different named binding)
alphaEqTy :: Type' () () -> Type' () () -> Bool
alphaEqTy = go (0, mempty, mempty)
alphaEqTy = alphaEqTy' (0, mempty, mempty)

-- Check two types for alpha equality where each may be from a
-- different alpha-related context
alphaEqTy' :: (Int, Map TyVarName Int, Map TyVarName Int) -> Type' () () -> Type' () () -> Bool
alphaEqTy' = go
where
go _ (TEmptyHole _) (TEmptyHole _) = True
go bs (THole _ s) (THole _ t) = go bs s t
Expand Down
67 changes: 67 additions & 0 deletions primer/src/Primer/Core/Utils.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
{-# LANGUAGE ViewPatterns #-}

module Primer.Core.Utils (
freshLocalName,
freshLocalName',
Expand All @@ -23,6 +25,7 @@ module Primer.Core.Utils (
freeGlobalVars,
alphaEqTy,
concreteTy,
alphaEq,
freshen,
) where

Expand All @@ -31,8 +34,10 @@ import Foreword
import Control.Monad.Fresh (MonadFresh, fresh)
import Data.Data (Data)
import Data.Generics.Uniplate.Data (universe)
import Data.Map.Strict qualified as M
import Data.Set qualified as S
import Data.Set.Optics (setOf)
import Data.Tuple.Extra (firstM)
import Optics (
Fold,
Traversal,
Expand All @@ -52,6 +57,7 @@ import Optics (

import Primer.Core (
CaseBranch' (..),
CaseFallback' (CaseExhaustive, CaseFallback),
Expr,
Expr' (..),
GVarName,
Expand All @@ -73,6 +79,7 @@ import Primer.Core (
import Primer.Core.Fresh (freshLocalName, freshLocalName')
import Primer.Core.Type.Utils (
alphaEqTy,
alphaEqTy',
boundVarsTy,
concreteTy,
forgetKindMetadata,
Expand Down Expand Up @@ -196,6 +203,66 @@ freeGlobalVars e = S.fromList [v | Var _ (GlobalVarRef v) <- universe e]
exprIDs :: (HasID a, HasID b, HasID c) => Traversal' (Expr' a b c) ID
exprIDs = (_exprMeta % _id) `adjoin` (_exprTypeMeta % _id) `adjoin` (_exprKindMeta % _id)

-- Check two terms for alpha equality
--
-- it makes usage easier if this is pure
-- i.e. we don't want to need a fresh name supply
-- We assume both inputs are both from the same context
--
-- Note that we do not expand let bindings, they must be structurally
-- the same (perhaps with a different named binding)
alphaEq :: Expr' () () () -> Expr' () () () -> Bool
alphaEq = go (0, mempty, mempty)
where
go bs (Hole _ t1) (Hole _ t2) = go bs t1 t2
go _ (EmptyHole _) (EmptyHole _) = True
go bs (Ann _ t1 ty1) (Ann _ t2 ty2) = go bs t1 t2 && alphaEqTy' (extractTypeEnv bs) ty1 ty2
go bs (App _ f1 t1) (App _ f2 t2) = go bs f1 f2 && go bs t1 t2
go bs (APP _ e1 ty1) (APP _ e2 ty2) = go bs e1 e2 && alphaEqTy' (extractTypeEnv bs) ty1 ty2
go bs (Con _ c1 as1) (Con _ c2 as2) = c1 == c2 && length as1 == length as2 && and (zipWith (go bs) as1 as2)
go bs (Lam _ v1 t1) (Lam _ v2 t2) = go (newTm bs v1 v2) t1 t2
go bs (LAM _ v1 t1) (LAM _ v2 t2) = go (newTy bs v1 v2) t1 t2
go (_, bs1, bs2) (Var _ (LocalVarRef v1)) (Var _ (LocalVarRef v2)) = bs1 ! Left v1 == bs2 ! Left v2
go _ (Var _ (GlobalVarRef v1)) (Var _ (GlobalVarRef v2)) = v1 == v2
go bs (Let _ v1 s1 t1) (Let _ v2 s2 t2) = go bs s1 s2 && go (newTm bs v1 v2) t1 t2
go bs (LetType _ v1 ty1 t1) (LetType _ v2 ty2 t2) = alphaEqTy' (extractTypeEnv bs) ty1 ty2 && go (newTy bs v1 v2) t1 t2
go bs (Letrec _ v1 t1 ty1 e1) (Letrec _ v2 t2 ty2 e2) =
go (newTm bs v1 v2) t1 t2
&& alphaEqTy' (extractTypeEnv bs) ty1 ty2
&& go (newTm bs v1 v2) e1 e2
go bs (Case _ e1 brs1 fb1) (Case _ e2 brs2 fb2) =
go bs e1 e2
&& and
( zipWith
( \(CaseBranch c1 (fmap bindName -> vs1) t1)
(CaseBranch c2 (fmap bindName -> vs2) t2) ->
c1
== c2
&& length vs1
== length vs2
&& go (foldl' (uncurry . newTm) bs $ zip vs1 vs2) t1 t2
)
brs1
brs2
)
&& case (fb1, fb2) of
(CaseExhaustive, CaseExhaustive) -> True
(CaseFallback f1, CaseFallback f2) -> go bs f1 f2
_ -> False
go _ (PrimCon _ c1) (PrimCon _ c2) = c1 == c2
go _ _ _ = False
p ! n = case p M.!? n of
Nothing -> Left n -- free vars: compare by name
Just i -> Right i -- bound vars: up to alpha
-- Note that the maps 'p' and 'q' map names to "which forall
-- they came from", in some sense. The @c@ value is how many
-- binders we have gone under, and is thus the next value free
-- in the map.
new (c, bs1, bs2) n m = (c + 1 :: Int, M.insert n c bs1, M.insert m c bs2)
newTm bs v1 v2 = new bs (Left v1) (Left v2)
newTy bs v1 v2 = new bs (Right v1) (Right v2)
extractTypeEnv (c, bs1, bs2) = let f = M.fromList . mapMaybe (firstM rightToMaybe) . M.assocs in (c, f bs1, f bs2)

freshen :: Set Name -> LocalName k -> LocalName k
freshen fvs n = go (0 :: Int)
where
Expand Down
4 changes: 3 additions & 1 deletion primer/src/Primer/EvalFullInterp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ interp (MicroSec t) tydefs env dir e = do
-- inside holes, by convincing Haskell's runtime system to do the
-- evaluation for us, in a call-by-need fashion. We return an AST
-- of the evaluated term, which will be type-correct (assuming the
-- input was): see 'Tests.EvalFullInterp.tasty_type_preservation'.
-- input was): see 'Tests.EvalFullInterp.tasty_type_preservation';
-- and will agree with iterating the small-step interpreter: see
-- 'Tests.EvalFullInterp.tasty_two_interp_agree'.
--
-- Warnings:
-- - Trying to evaluate a divergent term will (unsurprisingly) not terminate,
Expand Down
51 changes: 50 additions & 1 deletion primer/test/Tests/AlphaEquality.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import Foreword
import Hedgehog hiding (Property, check, property)
import Primer.Builtins
import Primer.Core (
Expr,
Type',
)
import Primer.Core.DSL
import Primer.Core.Utils (alphaEqTy, forgetTypeMetadata)
import Primer.Core.Utils (alphaEq, alphaEqTy, forgetMetadata, forgetTypeMetadata)
import Primer.Gen.Core.Raw (
evalExprGen,
genTyVarName,
Expand Down Expand Up @@ -101,6 +102,48 @@ tasty_alpha = property $ do
where
f v = create_ $ tforall v ktype $ tvar v

unit_tm_1 :: Assertion
unit_tm_1 = alphaNotEqTm (con0 cTrue) (con0 cFalse)

unit_tm_2 :: Assertion
unit_tm_2 = alphaEqTm (con cCons [con0 cTrue, con0 cNil]) (con cCons [con0 cTrue, con0 cNil])

unit_tm_3 :: Assertion
unit_tm_3 = alphaNotEqTm (con cCons [con0 cFalse, con0 cNil]) (con cCons [con0 cTrue, con0 cNil])

unit_tm_4 :: Assertion
unit_tm_4 = alphaNotEqTm (con cCons [con0 cFalse, con0 cNil]) (con0 cTrue)

unit_tm_5 :: Assertion
unit_tm_5 = alphaNotEqTm (lam "x" $ con0 cTrue) (con0 cTrue)

unit_tm_6 :: Assertion
unit_tm_6 = alphaEqTm (lam "x" $ lvar "x") (lam "y" $ lvar "y")

unit_tm_7 :: Assertion
unit_tm_7 = alphaNotEqTm (lam "x" $ lvar "x") (lam "y" $ con0 cTrue)

unit_tm_8 :: Assertion
unit_tm_8 = alphaNotEqTm (lAM "x" emptyHole) (lam "y" emptyHole)

unit_tm_9 :: Assertion
unit_tm_9 = alphaNotEqTm (lam "x" $ lam "y" $ lvar "x") (lam "x" $ lam "y" $ lvar "y")

unit_tm_10 :: Assertion
unit_tm_10 = alphaNotEqTm (lam "x" $ con1 cJust $ lvar "x") (con1 cJust $ lam "x" $ lvar "x")

unit_tm_11 :: Assertion
unit_tm_11 = alphaNotEqTm (lam "x" $ lvar "x" `app` con0 cTrue) (lam "x" (lvar "x") `app` con0 cTrue)

unit_tm_repeated_names :: Assertion
unit_tm_repeated_names = alphaEqTm (lam "a" $ lam "b" $ lvar "x" `app` lvar "x") (lam "a" $ lam "a" $ lvar "x" `app` lvar "x")

unit_tm_tmp :: Assertion
unit_tm_tmp =
alphaEqTm
(lAM "x" $ lAM "y" $ lam "x" $ hole $ case_ emptyHole [branch cTrue [("x", Nothing)] emptyHole])
(lAM "x" $ lAM "y" $ lam "x0" $ hole $ case_ emptyHole [branch cTrue [("x1", Nothing)] emptyHole])

create_ :: S (Type' a b) -> Alpha
create_ = Alpha . forgetTypeMetadata . create'

Expand All @@ -113,3 +156,9 @@ instance Eq Alpha where

assertNotEqual :: Alpha -> Alpha -> Assertion
assertNotEqual s t = assertBool "types are equal" $ s /= t

alphaEqTm :: S Expr -> S Expr -> Assertion
alphaEqTm s t = assertBool "terms should be equal" $ alphaEq (forgetMetadata $ create' s) (forgetMetadata $ create' t)

alphaNotEqTm :: S Expr -> S Expr -> Assertion
alphaNotEqTm s t = assertBool "terms should not be equal" $ not $ alphaEq (forgetMetadata $ create' s) (forgetMetadata $ create' t)
19 changes: 19 additions & 0 deletions primer/test/Tests/EvalFullInterp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ import Primer.Builtins.DSL (boolAnn, bool_, list_, nat)
import Primer.Core
import Primer.Core.DSL
import Primer.Core.Utils (
alphaEq,
forgetMetadata,
generateIDs,
)
import Primer.Def (DefMap)
import Primer.Eval
import Primer.EvalFullInterp (InterpError (..), Timeout (MicroSec), interp, mkGlobalEnv)
import Primer.EvalFullStep (evalFullStepCount)
import Primer.Examples qualified as Examples (
even,
map,
Expand Down Expand Up @@ -83,6 +85,7 @@ import Primer.Test.Expected (
mapEven,
)
import Primer.Test.Util (
failWhenSevereLogs,
primDefs,
)
import Primer.TypeDef (TypeDefMap)
Expand Down Expand Up @@ -469,6 +472,22 @@ tasty_type_preservation = withTests 1000
s'' <- checkTest ty =<< generateIDs s'
s' === forgetMetadata s'' -- check no smart holes happened

tasty_two_interp_agree :: Property
tasty_two_interp_agree = withTests 1000
$ withDiscards 2000
$ propertyWT testModules
$ do
let globs = foldMap' moduleDefsQualified $ create' $ sequence testModules
tds <- asks typeDefs
(dir, t, _ty) <- genDirTm
let optsV = ViewRedexOptions{groupedLets = True, aggressiveElision = True, avoidShadowing = False}
let optsR = RunRedexOptions{pushAndElide = True}
(_, ss) <- failWhenSevereLogs $ evalFullStepCount @EvalLog UnderBinders optsV optsR tds globs 100 dir t
si <- liftIO (evalFullTest' (MicroSec 10_000) tds globs dir $ forgetMetadata t)
case (ss, si) of
(Right ss', Right si') -> label "both terminated" >> Hedgehog.diff (forgetMetadata ss') alphaEq si'
_ -> label "one failed to terminate"

---- Unsaturated primitives are stuck terms
unit_prim_stuck :: Assertion
unit_prim_stuck =
Expand Down

0 comments on commit c545482

Please sign in to comment.