Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement type checking for sgir #3

Merged
merged 8 commits into from
Mar 6, 2024
Merged
28 changes: 20 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,30 @@ fn main() {
Binding { id: "j".to_owned(), typ: Type::Number },
],
body: Box::new(Variable("alex!".to_owned())),
}),
}), // (number, boolean, number, number) -> number
arguments: vec![
Number(420),
Boolean(true),
Function {
parameters: vec![Binding { id: "x".to_owned(), typ: Type::Number }],
body: Box::new(Variable("x".to_owned())),
},
Number(694208008135),
Number(420), // : number
Boolean(true), // : boolean
Application {
function: Box::new(Function {
parameters: vec![Binding { id: "x".to_owned(), typ: Type::Number }],
body: Box::new(Variable("x".to_owned())),
}), // : (number) -> number
arguments: vec![Number(42)], // : number
}, // : number
Number(694208008135), // : number
]
};

let typ = match sgir::check(prog.clone()) {
Ok(typ) => typ,
Err(type_error) => {
eprintln!("[ERROR] {:?}", type_error);
return
},
};
println!("TYPE: {:?}", typ);

let result = sgir::run(prog);
println!("{:?}", result);
}
202 changes: 200 additions & 2 deletions src/sgir/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
//! This module implements `sgir`, an intermediate representation for sanguinello.
//!
//! It is based on System Fω with explicit typing.

use std::collections::HashMap;
use thiserror::Error;

Expand All @@ -8,7 +12,10 @@ type Identifier = String;

#[derive(Clone, Debug, PartialEq)]
pub enum Kind {
/// The type of types.
Star,

/// A type constructor, or type function.
Arrow {
from: Vec<Kind>,
to: Box<Kind>,
Expand All @@ -31,6 +38,7 @@ pub enum Type {
parameters: Vec<TypeBinding>,
typ: Box<Type>,
},

/// type instantiation, e.g. T<U...>
Instantiate {
typ: Box<Type>,
Expand All @@ -44,20 +52,84 @@ pub enum Type {
arguments: Vec<Type>,
result: Box<Type>,
},

/// a boolean
Boolean,

/// a number
Number,
}

type TypeSubstitution = HashMap<Identifier, Type>;

impl Type {
fn apply(self, subst: &TypeSubstitution) -> Type {
match self {
Type::Variable(id) => match subst.get(&id) {
Some(replacement) => replacement.clone(),
None => Type::Variable(id),
},

Type::ForAll { parameters, typ } => {
// to handle shadowing properly, we have to remove any occurrences
// of any of the `parameters` found in the substitution.
let mut extended_subst = subst.clone();
for TypeBinding { id, .. } in parameters.iter() {
extended_subst.remove(id);
}

Type::ForAll {
parameters,
typ: Box::new(typ.apply(&extended_subst)),
}
},

Type::Instantiate { typ, arguments } => Type::Instantiate {
typ: Box::new(typ.apply(&subst)),
arguments: arguments.into_iter().map(|typ| typ.apply(&subst)).collect(),
},

Type::Function { arguments, result } => Type::Function {
arguments: arguments.into_iter().map(|typ| typ.apply(&subst)).collect(),
result: Box::new(result.apply(&subst)),
},

Type::Boolean => Type::Boolean,
Type::Number => Type::Number,
}
}
}

#[derive(Debug, Error, Clone, PartialEq)]
enum TypeError {
pub enum TypeError {
#[error("kind mismatch: expected {expected:?}, found {found:?}")]
KindMismatch {
expected: Kind,
found: Kind,
},

#[error("type mismatch: expected {expected:?}, found {found:?}")]
TypeMismatch {
expected: Type,
found: Type,
},

#[error("arity mismatch: expected {expected} arguments, found {found}")]
ArityMismatch {
expected: usize,
found: usize,
},

#[error("cannot call a non-function: {found:?}")]
CannotCallNonFunction {
found: Type,
},

#[error("cannot instantiate a non-quantification: {found:?}")]
CannotCallNonQuantification {
found: Type,
},

#[error("kind mismatch: expected a quantifier in type {found:?}")]
ExpectedQuantifier {
found: Type,
Expand All @@ -67,7 +139,7 @@ enum TypeError {
UnboundIdentifier(Identifier),
}

type TC<T> = Result<T, TypeError>;
pub type TC<T> = Result<T, TypeError>;

type KindEnv = HashMap<Identifier, Kind>;

Expand Down Expand Up @@ -122,6 +194,18 @@ pub enum Expression {
Boolean(bool),
Number(i64), // haha, this should be a bignum

/// type quantification (i.e. \Lambda)
Quantify {
parameters: Vec<TypeBinding>,
body: Box<Expression>,
},

/// type instantiation (i.e. application of a \Lambda)
Instantiate {
function: Box<Expression>,
arguments: Vec<Type>,
},

Function {
parameters: Vec<Binding>,
body: Box<Expression>,
Expand All @@ -133,6 +217,116 @@ pub enum Expression {
},
}

type TypeEnv = HashMap<Identifier, Type>;

fn check_type(tenv: &TypeEnv, kenv: &KindEnv, expr: Expression) -> TC<Type> {
match expr {
Expression::Variable(id) => match tenv.get(&id) {
Some(ty) => Ok(ty.clone()),
None => Err(TypeError::UnboundIdentifier(id.clone())),
},

Expression::Boolean(_) => Ok(Type::Boolean),

Expression::Number(_) => Ok(Type::Number),

Expression::Quantify { parameters, body } => {
let mut extended_kenv = kenv.clone();
extended_kenv.extend(parameters.clone().into_iter()
.map(|TypeBinding { id, kind }| (id, kind)));

let typ = check_type(tenv, &extended_kenv, *body)?;
match check_kinds(kenv, typ.clone())? {
// the resulting type is a type...
Kind::Star => Ok(Type::ForAll { parameters , typ: Box::new(typ) }),

// the resulting type is a type function...
kind => Err(TypeError::KindMismatch { expected: Kind::Star, found: kind }),
}
}

Expression::Instantiate { function, arguments } => {
match check_type(tenv, kenv, *function)? {
Type::ForAll { parameters, typ } => {
if arguments.len() != parameters.len() {
return Err(TypeError::ArityMismatch { expected: parameters.len(), found: arguments.len() })
}

let argument_kind_pairs = arguments.clone().into_iter()
.zip(parameters.iter().map(|TypeBinding { kind, .. }| kind.clone()));

for (argument, expected_kind) in argument_kind_pairs {
let computed_kind = check_kinds(kenv, argument)?;

if computed_kind != expected_kind {
return Err(TypeError::KindMismatch { expected: expected_kind, found: computed_kind })
}
}

// we have to substitute the types in `arguments` for the type parameters in `parameters`
let subst: TypeSubstitution = parameters.into_iter()
.map(|TypeBinding { id, .. }| id)
.zip(arguments)
.collect();

Ok(typ.apply(&subst))
},

// Unexpected type here, it must be a forall!
found => Err(TypeError::CannotCallNonQuantification { found })
}
}

Expression::Function { parameters, body } => {
for Binding { typ, .. } in parameters.iter() {
if let kind@Kind::Arrow { .. } = check_kinds(kenv, typ.clone())? {
return Err(TypeError::KindMismatch { expected: Kind::Star, found: kind })
}
}

let arguments = parameters.iter()
.map(|Binding { typ, .. }| typ.clone())
.collect();

let mut extended_tenv = tenv.clone();
extended_tenv.extend(parameters.into_iter()
.map(|Binding { id, typ }| (id, typ)));
let result = Box::new(check_type(&extended_tenv, &kenv, *body)?);

Ok(Type::Function { arguments, result })
},

Expression::Application { function, arguments } => {
match check_type(tenv, kenv, *function)? {
Type::Function { arguments: expected_types, result: result_type } => {
if arguments.len() != expected_types.len() {
return Err(TypeError::ArityMismatch { expected: expected_types.len(), found: arguments.len() })
}

for (argument, expected_type) in arguments.into_iter().zip(expected_types) {
let computed_type = check_type(tenv, kenv, argument)?;

if computed_type != expected_type {
return Err(TypeError::TypeMismatch { expected: expected_type, found: computed_type })
}
}

Ok(*result_type)
},

// Unexpected type here, it must be a function!
found => Err(TypeError::CannotCallNonFunction { found })
}
},
}
}

pub fn check(expr: Expression) -> TC<Type> {
let type_env = HashMap::new();
let kind_env = HashMap::new();
check_type(&type_env, &kind_env, expr)
}

#[derive(Clone, Debug)]
pub enum Value {
// Primitives
Expand All @@ -152,6 +346,10 @@ fn eval(subst: &Substitution, expr: Expression) -> Value {
Expression::Variable(identifier) => subst[&identifier].clone(),
Expression::Boolean(value) => Value::Boolean(value),
Expression::Number(value) => Value::Number(value),
// quantification has no runtime semantics
Expression::Quantify { body, .. } => eval(subst, *body),
// instantiation has no runtime semantics
Expression::Instantiate { function, .. } => eval(subst, *function),
Expression::Function { parameters, body } => Value::Function { parameters: parameters.clone(), body: body.clone() },
Expression::Application { function, arguments } => match eval(subst, *function) {
Value::Function { parameters, body } => {
Expand Down
Loading
Loading