Skip to content

Commit

Permalink
Merge pull request #6 from MarcosAndradeV/chsvm
Browse files Browse the repository at this point in the history
[ast] Fix bugs
  • Loading branch information
MarcosAndradeV authored Dec 23, 2024
2 parents c819d19 + fe9d0b5 commit afda26c
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 127 deletions.
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[workspace]
members = ["chs", "chs_ast", "chs_lexer", "chs_util"]
members = ["chs", "chs_ast", "chs_lexer", "chs_util", "chs_vm"]
resolver = "2"

[profile.release]
Expand Down
27 changes: 26 additions & 1 deletion chs/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use chs_ast::Parser;
use chs_ast::{types::{infer, InferEnv}, Parser};
use chs_lexer::Lexer;
use std::{
env::{self, Args},
Expand Down Expand Up @@ -38,6 +38,29 @@ fn main() {
}
}
}
"eval" => {
let (fpath, bytes) = get_file(&mut args);
let lex = Lexer::new(fpath, bytes);
let parser = Parser::new(lex);
match parser.parse() {
Ok(ok) => {
let mut env = InferEnv::default();
for expr in &ok.top_level {
match infer(&mut env, expr, 0) {
Err(err) => {
eprintln!("{err}");
exit(1)
}
_ => ()
}
}
}
Err(err) => {
eprintln!("{err}");
exit(1)
}
}
}
"version" => {
println!("Version: 0.0.1");
}
Expand All @@ -60,6 +83,8 @@ fn usage(program_path: &String) {
println!("Command:");
println!(" version Display compiler version information.");
println!(" lex Dump tokens from a file.");
println!(" parse Dump AST from a file.");
println!(" eval Evaluate a file.");
println!(" help Show this message.");
}

Expand Down
27 changes: 6 additions & 21 deletions chs_ast/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use chs_lexer::{Lexer, Token, TokenKind};
use chs_util::{chs_error, CHSError};
use nodes::{Call, Expression, FnDecl, Module, Var, VarDecl};
use types::{generalize, infer, unify, CHSType, CHSBOOL, CHSCHAR, CHSINT, CHSSTRING, CHSVOID};
use types::{CHSType, CHSBOOL, CHSCHAR, CHSINT, CHSSTRING, CHSVOID};

pub mod nodes;
pub mod types;
Expand Down Expand Up @@ -54,12 +54,6 @@ impl Parser {

pub fn parse(mut self) -> Result<Module, CHSError> {
use chs_lexer::TokenKind::*;
let a = CHSType::Const("int".to_string());
self.module.env.insert(
"add".to_string(),
CHSType::Arrow(vec![a.clone(), a.clone()], a.into()),
);

loop {
let token = self.next();
if token.kind.is_eof() {
Expand All @@ -77,28 +71,21 @@ impl Parser {
// TODO: MAKE THE TYPE INFER AFTER PARSING EVERYTHING
fn parse_top_expression(&mut self, token: Token) -> Result<(), CHSError> {
use chs_lexer::TokenKind::*;
self.module.id.reset_id();
match token.kind {
Word if self.peek().kind == Colon => {
self.next();
let (value, ttype) = if let Some(ttype) = self.parse_type()? {
let ttype = self.parse_type()?;
if ttype.is_some() {
self.expect_kind(Assign)?;
let value = self.parse_expression()?;
let ty = infer(&mut self.module, &value, 1)?;
(value, unify(ttype, generalize(ty, 1))?)
} else {
let value = self.parse_expression()?;
let ty = infer(&mut self.module, &value, 1)?;
(value, generalize(ty, 1))
};
}
let value = self.parse_expression()?;
let name = token.value;
let expr = Expression::VarDecl(Box::new(VarDecl {
loc: token.loc,
name,
ttype,
value,
}));
infer(&mut self.module, &expr, 1)?;
self.module.push(expr);
Ok(())
}
Expand All @@ -108,7 +95,6 @@ impl Parser {
self.expect_kind(ParenOpen)?;
let (args, ret_type) = self.parse_fn_type_list()?;
self.expect_kind(Assign)?;
self.module.env.insert(name.clone(), ret_type.clone());
let body = self.parse_expression()?;
let expr = Expression::FnDecl(Box::new(FnDecl {
loc: token.loc,
Expand All @@ -117,7 +103,6 @@ impl Parser {
ret_type,
body,
}));
infer(&mut self.module, &expr, 1)?;
self.module.push(expr);
Ok(())
}
Expand Down Expand Up @@ -215,7 +200,7 @@ impl Parser {
fn parse_fn_type_list(&mut self) -> Result<(Vec<(String, CHSType)>, CHSType), CHSError> {
use chs_lexer::TokenKind::*;
let mut list = vec![];
let mut ret_type = CHSType::Const("()".to_string());
let mut ret_type = CHSVOID.clone();
loop {
let ptoken = self.peek();
match ptoken.kind {
Expand Down
17 changes: 2 additions & 15 deletions chs_ast/src/nodes.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,19 @@
use std::collections::HashMap;

use chs_lexer::Token;
use chs_util::{chs_error, CHSError, Loc};

use crate::types::{CHSType, CHSTypeId};
use crate::types::CHSType;

#[derive(Debug, Default)]
pub struct Module {
pub top_level: Vec<Expression>,
pub env: HashMap<String, CHSType>,
pub id: CHSTypeId,
}

impl Module {
pub fn with_env(env: HashMap<String, CHSType>) -> Self {
Self {
env,
..Default::default()
}
}

pub fn push(&mut self, expr: Expression) {
self.top_level.push(expr);
}

pub fn set_env(&mut self, env: HashMap<String, CHSType>) {
self.env = env;
}
}

pub type VarId = usize;
Expand Down Expand Up @@ -103,7 +90,7 @@ pub struct VarDecl {
pub loc: Loc,
pub name: String,
pub value: Expression,
pub ttype: CHSType,
pub ttype: Option<CHSType>,
}

#[derive(Debug)]
Expand Down
115 changes: 27 additions & 88 deletions chs_ast/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{collections::HashMap, sync::LazyLock};

use chs_util::{chs_error, CHSError};

use crate::nodes::{Expression, Literal, Module, Var};
use crate::nodes::{Expression, Literal, Var};

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct CHSTypeId(usize);
Expand Down Expand Up @@ -73,7 +73,7 @@ impl CHSType {
}
}

pub fn generalize(ty: CHSType, level: CHSTypeLevel) -> CHSType {
fn generalize(ty: CHSType, level: CHSTypeLevel) -> CHSType {
match ty {
CHSType::Var(CHSTypeVar::Unbound(mut id, other_level)) if other_level > level => {
CHSType::new_gen_var(&mut id)
Expand Down Expand Up @@ -112,8 +112,8 @@ let rec generalize level = function
| TVar {contents = Generic _} | TVar {contents = Unbound _} | TConst _ as ty -> ty
*/

pub fn instantiate(
m: &mut Module,
fn instantiate(
m: &mut InferEnv,
id_var: &mut HashMap<usize, CHSType>,
ty: CHSType,
level: CHSTypeLevel,
Expand Down Expand Up @@ -172,8 +172,8 @@ let instantiate level ty =
f ty
*/

pub fn match_fun_ty(
m: &mut Module,
fn match_fun_ty(
m: &mut InferEnv,
num_params: usize,
ty: &mut CHSType,
) -> Result<(Vec<CHSType>, CHSType), CHSError> {
Expand Down Expand Up @@ -223,13 +223,12 @@ let rec match_fun_ty num_params = function
| _ -> error "expected a function"
*/

pub fn unify(ty1: CHSType, ty2: CHSType) -> Result<CHSType, CHSError> {
fn unify(ty1: CHSType, ty2: CHSType) -> Result<CHSType, CHSError> {
use CHSType::*;
dbg!(&ty1, &ty2);
if ty1 == ty2 {
return Ok(ty1);
}
let res = match (ty1, ty2) {
match (ty1, ty2) {
(Const(n1), Const(n2)) if n1 == n2 => Ok(CHSType::Const(n1)),
(App(ty1, ty_arg_list1), App(ty2, ty_arg_list2)) => {
let ty1 = unify(*ty1, *ty2)?;
Expand Down Expand Up @@ -265,9 +264,7 @@ pub fn unify(ty1: CHSType, ty2: CHSType) -> Result<CHSType, CHSError> {
Ok(Var(CHSTypeVar::Link(ty.clone().into())))
}
(ty1, ty2) => chs_error!("cannot unify types {:?} and {:?}", ty1, ty2),
};
dbg!(&res);
res
}
}

/*
Expand All @@ -291,7 +288,7 @@ let rec unify ty1 ty2 =
| _, _ -> error ("cannot unify types " ^ string_of_ty ty1 ^ " and " ^ string_of_ty ty2)
*/

pub fn occurs_check_adjust_levels(
fn occurs_check_adjust_levels(
id: &mut CHSTypeId,
level: &mut CHSTypeLevel,
ty: &mut CHSType,
Expand Down Expand Up @@ -349,7 +346,13 @@ let occurs_check_adjust_levels tvar_id tvar_level ty =
f ty
*/

pub fn infer(m: &mut Module, expr: &Expression, level: CHSTypeLevel) -> Result<CHSType, CHSError> {
#[derive(Debug, Default, Clone)]
pub struct InferEnv {
pub env: HashMap<String, CHSType>,
pub id: CHSTypeId,
}

pub fn infer(m: &mut InferEnv, expr: &Expression, level: CHSTypeLevel) -> Result<CHSType, CHSError> {
match expr {
Expression::Literal(literal) => match literal {
Literal::IntegerLiteral { .. } => return Ok(CHSINT.clone()),
Expand All @@ -363,21 +366,18 @@ pub fn infer(m: &mut Module, expr: &Expression, level: CHSTypeLevel) -> Result<C
let var_ty = infer(m, &v.value, level + 1)?;
let generalized_ty = generalize(var_ty, level);
m.env.insert(v.name.clone(), generalized_ty);
return Ok(CHSType::Const("()".into()));
return Ok(CHSVOID.clone());
}
Expression::FnDecl(fd) => {
let prev_env = m.env.clone();
m.env.extend(fd.args.clone());
let ret_type = unify(fd.ret_type.clone(), infer(m, &fd.body, level)?)?;
m.env = prev_env;
m.env.insert(
fd.name.clone(),
CHSType::Arrow(
fd.args.clone().into_iter().map(|(_, t)| t).collect(),
ret_type.into(),
),
);
return Ok(CHSType::Const("()".into()));
let mut fn_env = m.clone();
fn_env.env.extend(fd.args.clone());
fn_env.env.insert(fd.name.clone(), CHSType::Arrow(
fd.args.clone().into_iter().map(|(_, t)| t).collect(),
fd.ret_type.clone().into(),
));
unify(fd.ret_type.clone(), infer(&mut fn_env, &fd.body, level)?)?;
m.env.insert(fd.name.clone(), fn_env.env.get(&fd.name).cloned().unwrap());
return Ok(CHSVOID.clone());
}
Expression::Var(Var { name, loc: _ }) => {
if let Some(ty) = m.env.get(name) {
Expand Down Expand Up @@ -421,64 +421,3 @@ pub fn infer(m: &mut Module, expr: &Expression, level: CHSTypeLevel) -> Result<C
}
}
}

/*
| Var name -> begin
try
instantiate level (Env.lookup env name)
with Not_found -> error ("variable " ^ name ^ " not found")
end
| Fun(param_list, body_expr) ->
let param_ty_list = List.map (fun _ -> new_var level) param_list in
let fn_env = List.fold_left2
(fun env param_name param_ty -> Env.extend env param_name param_ty)
env param_list param_ty_list
in
let return_ty = infer fn_env level body_expr in
TArrow(param_ty_list, return_ty)
| Let(var_name, value_expr, body_expr) ->
let var_ty = infer env (level + 1) value_expr in
let generalized_ty = generalize level var_ty in
infer (Env.extend env var_name generalized_ty) level body_expr
| Call(fn_expr, arg_list) ->
let param_ty_list, return_ty =
match_fun_ty (List.length arg_list) (infer env level fn_expr)
in
List.iter2
(fun param_ty arg_expr -> unify param_ty (infer env level arg_expr))
param_ty_list arg_list
;
return_ty
*/

#[cfg(test)]
mod tests {
use chs_util::Loc;

use crate::nodes::VarDecl;

use super::*;

#[test]
fn test_name() {
let mut m = Module::default();
m.push(Expression::VarDecl(Box::new(VarDecl {
name: "x".into(),
value: Expression::Literal(Literal::IntegerLiteral {
loc: Loc::default(),
value: 10,
}),
loc: Loc::default(),
ttype: CHSType::Const("int".into()),
})));
let res = infer(
&mut m,
&Expression::Literal(Literal::IntegerLiteral {
loc: Loc::default(),
value: 10,
}),
1,
);
assert!(res.is_ok())
}
}
Loading

0 comments on commit afda26c

Please sign in to comment.