@@ -169,24 +169,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL)(using Config):
169
169
case Neg (rhs) =>
170
170
mono(rhs, ! pol).!
171
171
case CompType (lhs, rhs, pol) =>
172
- val (l, r) = (typeMonoType(lhs), typeMonoType(rhs))
173
- if ! pol then
174
- val lfa = l.toDnf.conjs.flatMap(_.i.v).collect:
175
- case (f :: fs) =>
176
- val fd = Type .discriminant(f.args).flatten
177
- fs.foldLeft(fd : Type , fd.fields.keys.toSet): (x, y) =>
178
- val d = Type .discriminant(y.args).flatten
179
- (x._1 | d, x._2 & d.fields.keys.toSet)
180
- val rfa = r.toDnf.conjs.flatMap(_.i.v).collect:
181
- case (f :: fs) =>
182
- val fd = Type .discriminant(f.args).flatten
183
- fs.foldLeft(fd : Type , fd.fields.keys.toSet): (x, y) =>
184
- val d = Type .discriminant(y.args).flatten
185
- (x._1 | d, x._2 & d.fields.keys.toSet)
186
- val d = lfa.iterator.flatMap(x => rfa.iterator.map(y => (x, y))).exists:
187
- case ((x, u), (y, w)) => (u & w).isEmpty || Type .disjoint(x, y) =/= S (Set .empty)
188
- if d then error(msg " Ill-formed functions intersection " -> ty.toLoc :: Nil )
189
- Type .mkComposedType(l, r, pol)
172
+ Type .mkComposedType(typeMonoType(lhs), typeMonoType(rhs), pol)
190
173
case _ =>
191
174
ty.symbol.flatMap(_.asTpe) match
192
175
case S (cls : (ClassSymbol | TypeAliasSymbol )) => typeAndSubstType(Term .TyApp (ty, Nil )(N ), pol)
@@ -201,17 +184,39 @@ class BBTyper(using elState: Elaborator.State, tl: TL)(using Config):
201
184
tv -> qv
202
185
bds.foreach:
203
186
case (tv, QuantVar (_, ub, lb)) =>
204
- ub.foreach(ub => tv.state.upperBounds ::= typeMonoType( ub))
205
- lb.foreach(lb => tv.state.lowerBounds ::= typeMonoType( lb))
187
+ ub.foreach(ub => tv.state.upperBounds ::= monoOrErr(typeType(ub), ub))
188
+ lb.foreach(lb => tv.state.lowerBounds ::= monoOrErr(typeType(lb), lb))
206
189
val lbty = tv.state.lowerBounds.foldLeft[Type ](Bot )(_ | _)
207
190
val ubty = tv.state.upperBounds.foldLeft[Type ](Top )(_ & _)
208
191
solver.constrain(lbty, ubty)
209
192
PolyType (bds.map(_._1), S (outer), body)
210
193
211
- private def typeMonoType (ty : Term )(using ctx : BbCtx , cctx : CCtx ): Type = monoOrErr(typeType(ty), ty)
212
-
213
- private def typeType (ty : Term )(using ctx : BbCtx , cctx : CCtx ): GeneralType =
214
- typeAndSubstType(ty, pol = true )(using Map .empty)
194
+ private def typeMonoType (ty : Term )(using ctx : BbCtx , cctx : CCtx ): Type = monoOrErr(typeAndSubstType(ty, true )(using Map .empty), ty)
195
+
196
+ private def wffuns (fs : Ls [FunType ]) =
197
+ val wf = fs.forall(f => (f.ret :: f.eff :: f.args).forall(wftype))
198
+ wf && fs.combinations(2 ).forall: u =>
199
+ Type .disjoint(Type .discriminant(u.head.args), Type .discriminant(u.tail.head.args)).exists(_.isEmpty)
200
+ private def wfrcds (rs : Ls [RcdType ]) =
201
+ val wf = rs.forall(_.fields.forall(u => wftype(u._2)))
202
+ wf && rs.combinations(2 ).forall: u =>
203
+ Type .disjoint(u.head, u.tail.head).exists(_.isEmpty)
204
+ private def wfcls (cs : Ls [ClassLikeType ]) =
205
+ cs.forall(_.targs.forall(u => wftype(u.posPart) && wftype(u.negPart)))
206
+ private def wftype (ty : GeneralType ): Bool = ty match
207
+ case t : PolyType => wftype(t.body)
208
+ case t : PolyFunType => (t.ret :: t.eff :: t.args).forall(wftype)
209
+ case t : Type =>
210
+ val n = t.! .toDnf.conjs
211
+ val nf = n.iterator.map(_.i.v).forall:
212
+ case S (f : Ls [FunType ]) => false
213
+ case _ => true
214
+ nf && wffuns(n.flatMap(_.u.fun)) && n.forall(c => wfrcds(c.u.rcd) && wfcls(c.u.cls))
215
+
216
+ private def typeType (ty : Term , map : Map [Uid [Symbol ], TypeArg ] = Map .empty)(using ctx : BbCtx , cctx : CCtx ): GeneralType =
217
+ val t = typeAndSubstType(ty, pol = true )(using map)
218
+ if ! wftype(t) then error(msg " Ill-formed type " -> ty.toLoc :: Nil )
219
+ t
215
220
216
221
private def instantiate (ty : PolyType )(using ctx : BbCtx ): GeneralType =
217
222
ty.instantiate(infVarState.nextUid, freshEnv(new TempSymbol (N , " env" )), ctx.lvl)
@@ -486,13 +491,29 @@ class BBTyper(using elState: Elaborator.State, tl: TL)(using Config):
486
491
ascribe(term, typeType(ty))
487
492
ascribe(term, rhs)
488
493
case _ =>
489
- val (lhsTy, eff) = typeCheck(lhs)
490
494
rhs match
491
495
case pf : PolyFunType if pf.isPoly =>
492
496
(error(msg " Cannot type non-function term ${lhs.toString} as ${rhs.show}" -> lhs.toLoc :: Nil ), Bot )
493
497
case _ =>
494
- constrain(tryMkMono(lhsTy, lhs), monoOrErr(rhs, lhs))
495
- (rhs, eff)
498
+ val r = monoOrErr(rhs, lhs)
499
+ val n = r.! .toDnf.conjs
500
+ if n.flatMap(_.u.fun).length > 1 then
501
+ val eff = n.foldLeft(Bot : Type ): (e, c) =>
502
+ (c.i.v, c.u, c.vars) match
503
+ case (i, u, v) =>
504
+ val n = i.fold(Bot : Type ):
505
+ case x : (ClassLikeType | RcdType ) => x.!
506
+ case x : Ls [FunType ] => x.reduce[Type ](_ & _).!
507
+ val k = Ls (u.fun, u.cls, u.cls).flatten.foldLeft(Bot : Type )(_ | _)
508
+ val vs = v.foldLeft(Bot : Type ): (x, y) =>
509
+ x | (if y._2 then y._1.! else y._1)
510
+ val (_, eff) = ascribe(lhs, n | k | vs)
511
+ e | eff
512
+ (rhs, eff)
513
+ else
514
+ val (lhsTy, eff) = typeCheck(lhs)
515
+ constrain(tryMkMono(lhsTy, lhs), r)
516
+ (rhs, eff)
496
517
497
518
// TODO: t -> loc when toLoc is implemented
498
519
private def app (lhs : (GeneralType , Type ), rhs : Ls [Elem ], t : Term )
@@ -620,7 +641,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL)(using Config):
620
641
case Param (_, sym, sign, _) =>
621
642
if sym.nme === field.name then sign else N
622
643
}.filter(_.isDefined)) match
623
- case S (res) :: Nil => (typeAndSubstType (res, pol = true )( using map.toMap), eff)
644
+ case S (res) :: Nil => (typeType (res, map.toMap), eff)
624
645
case _ => (error(msg " ${field.name} is not a valid member in class ${clsSym.nme}" -> t.toLoc :: Nil ), Bot )
625
646
case N =>
626
647
(error(msg " Not a valid class: ${cls.describe}" -> cls.toLoc :: Nil ), Bot )
@@ -650,7 +671,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL)(using Config):
650
671
require(clsDfn.paramsOpt.forall(_.restParam.isEmpty))
651
672
args.iterator.zip(clsDfn.params.params).foreach {
652
673
case (arg, Param (sign = S (sign))) =>
653
- val (ty, eff) = ascribe(arg, typeAndSubstType (sign, pol = true )( using map.toMap))
674
+ val (ty, eff) = ascribe(arg, typeType (sign, map.toMap))
654
675
effBuff += eff
655
676
case _ => ???
656
677
}
0 commit comments