diff --git a/src/env.rs b/src/env.rs deleted file mode 100644 index 53ae9d1..0000000 --- a/src/env.rs +++ /dev/null @@ -1,123 +0,0 @@ -use std::{borrow::Borrow, collections::HashMap, rc::Rc}; - -use crate::eval::{Error, Expression}; - -pub type Environment = HashMap>; - -pub fn core_environment() -> Environment { - [ - // Arithmetic operations - ( - "+".to_string(), - Expression::NativeFunc { - func: |args| { - let res = args - .into_iter() - .reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) { - (Expression::Int(lhs), Expression::Int(rhs)) => { - Expression::Int(lhs + rhs).into() - } - _ => todo!(), - }) - .unwrap_or(Rc::new(Expression::Int(0))); - - Ok(res) - }, - } - .into(), - ), - ( - "-".to_string(), - Expression::NativeFunc { - func: |args| { - let res = args - .into_iter() - .reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) { - (Expression::Int(lhs), Expression::Int(rhs)) => { - Expression::Int(lhs - rhs).into() - } - _ => todo!(), - }) - .unwrap_or(Rc::new(Expression::Int(0))); - - Ok(res) - }, - } - .into(), - ), - ( - "*".to_string(), - Expression::NativeFunc { - func: |args| { - let res = args - .into_iter() - .reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) { - (Expression::Int(lhs), Expression::Int(rhs)) => { - Expression::Int(lhs * rhs).into() - } - _ => todo!(), - }) - .unwrap_or(Rc::new(Expression::Int(0))); - - Ok(res) - }, - } - .into(), - ), - ( - "/".to_string(), - Expression::NativeFunc { - func: |args| { - let res = args - .into_iter() - .reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) { - (Expression::Int(lhs), Expression::Int(rhs)) => { - Expression::Int(lhs / rhs).into() - } - _ => todo!(), - }) - .unwrap_or(Rc::new(Expression::Int(0))); - - Ok(res) - }, - } - .into(), - ), - // Collections - ( - "vector".to_string(), - Expression::NativeFunc { - func: |args| Ok(Rc::new(Expression::Vector(args))), - } - .into(), - ), - ( - "hashmap".to_string(), - Expression::NativeFunc { - func: |args| { - if args.len() % 2 != 0 { - Err(Error::MismatchedArgCount)? - } - - let mut index = -1; - let (keys, values): (Vec<_>, Vec<_>) = args.into_iter().partition(|_| { - index += 1; - index % 2 == 0 - }); - - let res = keys - .into_iter() - // We turn the keys into strings because they're hashable - // This feels so hacky, but ¯\_(ツ)_/¯ - .map(|key| key.to_string()) - .zip(values.into_iter()) - .collect(); - - Ok(Rc::new(Expression::HashMap(res))) - }, - } - .into(), - ), - ] - .into() -} diff --git a/src/evaluator/env.rs b/src/evaluator/env.rs new file mode 100644 index 0000000..8e09a63 --- /dev/null +++ b/src/evaluator/env.rs @@ -0,0 +1,186 @@ +use std::{borrow::Borrow, cell::RefCell, collections::HashMap, rc::Rc}; + +use super::{eval_ast_node, Error, Expression}; +use crate::parser::Node; + +pub type RawEnvironment = HashMap>; +pub struct Environment { + current: RefCell, + outer: Option>, +} + +impl Environment { + pub fn get(&self, ident: &str) -> Option> { + if let Some(val) = self.current.borrow().get(ident) { + return Some(val.clone()); + } + + self.outer.clone()?.get(ident) + } + + fn set(&self, ident: String, val: Rc) { + self.current.borrow_mut().insert(ident, val); + } +} + +pub fn core_environment() -> Rc { + let env = [ + // Arithmetic operations + ( + "+".to_string(), + Expression::NativeFunc(|args| { + let res = args + .into_iter() + .reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) { + (Expression::Int(lhs), Expression::Int(rhs)) => { + Expression::Int(lhs + rhs).into() + } + _ => todo!(), + }) + .unwrap_or(Rc::new(Expression::Int(0))); + + Ok(res) + }) + .into(), + ), + ( + "-".to_string(), + Expression::NativeFunc(|args| { + let res = args + .into_iter() + .reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) { + (Expression::Int(lhs), Expression::Int(rhs)) => { + Expression::Int(lhs - rhs).into() + } + _ => todo!(), + }) + .unwrap_or(Rc::new(Expression::Int(0))); + + Ok(res) + }) + .into(), + ), + ( + "*".to_string(), + Expression::NativeFunc(|args| { + let res = args + .into_iter() + .reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) { + (Expression::Int(lhs), Expression::Int(rhs)) => { + Expression::Int(lhs * rhs).into() + } + _ => todo!(), + }) + .unwrap_or(Rc::new(Expression::Int(0))); + + Ok(res) + }) + .into(), + ), + ( + "/".to_string(), + Expression::NativeFunc(|args| { + let res = args + .into_iter() + .reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) { + (Expression::Int(lhs), Expression::Int(rhs)) => { + Expression::Int(lhs / rhs).into() + } + _ => todo!(), + }) + .unwrap_or(Rc::new(Expression::Int(0))); + + Ok(res) + }) + .into(), + ), + // Collections + ( + "vector".to_string(), + Expression::NativeFunc(|args| Ok(Rc::new(Expression::Vector(args)))).into(), + ), + ( + "hashmap".to_string(), + Expression::NativeFunc(|args| { + if args.len() % 2 != 0 { + Err(Error::MismatchedArgCount)? + } + + let mut index = -1; + let (keys, values): (Vec<_>, Vec<_>) = args.into_iter().partition(|_| { + index += 1; + index % 2 == 0 + }); + + let res = keys + .into_iter() + // We turn the keys into strings because they're hashable + // This feels so hacky, but ¯\_(ツ)_/¯ + .map(|key| key.to_string()) + .zip(values.into_iter()) + .collect(); + + Ok(Rc::new(Expression::HashMap(res))) + }) + .into(), + ), + // Environment Manipulation + ( + "define!".to_string(), + Expression::Special(|env, args| { + let mut args = args.into_iter(); + if args.len() != 2 { + Err(Error::MismatchedArgCount)? + } + + let key = match args.next().unwrap() { + Node::Symbol(name) => name, + _ => Err(Error::ExpectedSymbol)?, + }; + + let val = eval_ast_node(env.clone(), args.next().unwrap())?; + + env.set(key, val.clone()); + + Ok(val) + }) + .into(), + ), + ( + "let*".to_string(), + Expression::Special(|env, args| { + if args.len() != 2 { + Err(Error::MismatchedArgCount)? + } + + let mut args = args.into_iter(); + let new_env = match args.next().unwrap() { + Node::List(list) if list.len() == 2 => { + let mut list = list; + let sym = match list.remove(0) { + Node::Symbol(name) => name, + _ => Err(Error::ExpectedSymbol)?, + }; + + let val = eval_ast_node(env.clone(), list.remove(0))?; + Environment { + current: RefCell::new([(sym.clone(), val)].into()), + outer: Some(env.clone()), + } + } + Node::List(_) => Err(Error::MismatchedArgCount)?, + _ => Err(Error::ExpectedList)?, + }; + + eval_ast_node(Rc::new(new_env), args.next().unwrap()) + }) + .into(), + ), + ]; + + Environment { + current: RefCell::new(env.into()), + outer: None, + } + .into() +} diff --git a/src/eval.rs b/src/evaluator/mod.rs similarity index 72% rename from src/eval.rs rename to src/evaluator/mod.rs index 04a53b7..603b724 100644 --- a/src/eval.rs +++ b/src/evaluator/mod.rs @@ -3,7 +3,10 @@ use std::{borrow::Borrow, collections::HashMap, rc::Rc}; use anyhow::Result; use thiserror::Error; -use crate::{env::Environment, parser::Node}; +use crate::parser::Node; + +mod env; +pub use env::{core_environment, Environment}; thread_local! { static TRUE: Rc = Rc::new(Expression::Boolean(true)); @@ -24,9 +27,8 @@ pub enum Expression { Vector(Vec>), HashMap(HashMap>), - NativeFunc { - func: fn(args: Vec>) -> Result>, - }, + NativeFunc(fn(args: Vec>) -> Result>), + Special(fn(env: Rc, args: Vec) -> Result>), } impl std::fmt::Display for Expression { @@ -58,7 +60,8 @@ impl std::fmt::Display for Expression { write!(f, "{{{res}}}") } - Expression::NativeFunc { .. } => write!(f, "function"), + Expression::NativeFunc(func) => write!(f, "{func:?}"), + Expression::Special(func) => write!(f, "{func:?}"), } } } @@ -69,38 +72,45 @@ pub enum Error { NotInEnv, #[error("expression does not have a valid operator")] InvalidOperator, + #[error("expected symbol")] + ExpectedSymbol, + #[error("expected list")] + ExpectedList, #[error("incorrect number of arguments passed to function")] MismatchedArgCount, } -pub fn eval(env: &Environment, ast: Vec) -> Result>> { +pub fn eval(env: Rc, ast: Vec) -> Result>> { let mut res = Vec::new(); for node in ast { - res.push(eval_ast_node(env, node)?); + res.push(eval_ast_node(env.clone(), node)?); } Ok(res) } -fn eval_ast_node(env: &Environment, ast_node: Node) -> Result> { +fn eval_ast_node(env: Rc, ast_node: Node) -> Result> { let expr = match ast_node { Node::List(list) => { - let mut res = Vec::new(); - for node in list { - res.push(eval_ast_node(env, node)?); + let mut list = list.into_iter(); + let operator = eval_ast_node(env.clone(), list.next().ok_or(Error::InvalidOperator)?)?; + + match operator.borrow() { + Expression::NativeFunc(func) => { + let mut args = Vec::new(); + for node in list { + args.push(eval_ast_node(env.clone(), node)?); + } + + func(args)? + } + Expression::Special(func) => { + let args = list.collect(); + func(env, args)? + } + _ => Err(Error::InvalidOperator)?, } - - let mut list = res.into_iter(); - let operator = match list.next() { - Some(rc) => match rc.borrow() { - Expression::NativeFunc { func } => *func, - _ => Err(Error::InvalidOperator)?, - }, - None => Err(Error::InvalidOperator)?, - }; - - operator(list.collect())? } Node::Symbol(sym) => env.get(&sym).ok_or(Error::NotInEnv)?.to_owned(), @@ -118,7 +128,6 @@ fn eval_ast_node(env: &Environment, ast_node: Node) -> Result> { #[cfg(test)] mod test { - use crate::core_environment; use crate::lexer; use crate::parser; @@ -149,8 +158,11 @@ mod test { #[case("[1 (+ 1 2)]", "[1 3]")] #[case("{}", "{}")] #[case("{:a \"uwu\"}", "{:a: uwu}")] + // Environment manipulation + #[case("(define! asdf (+ 2 2)) (+ asdf 2)", "4\n6")] + #[case("(let* (a 2) (+ a 2))", "4")] fn test_evaluator(#[case] input: &str, #[case] expected: &str) { - let env = &core_environment(); + let env = core_environment(); let tokens = lexer::read(input).unwrap(); let ast = parser::parse(tokens).unwrap(); let res = eval(env, ast) @@ -167,7 +179,7 @@ mod test { #[case("{:a}")] #[case("(not-a-func :uwu)")] fn test_evaluator_fail(#[case] input: &str) { - let env = &core_environment(); + let env = core_environment(); let tokens = lexer::read(input).unwrap(); let ast = parser::parse(tokens).unwrap(); let res = eval(env, ast); diff --git a/src/lexer.rs b/src/lexer.rs index 34cf0e8..01634cf 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -108,7 +108,7 @@ fn next_token(input: &mut Peekable) -> Result> { Some(tok) => tok, None => return Ok(None), }, - _ => bail!("ilegal token"), + _ => bail!("illegal token"), }; Ok(Some(tok)) @@ -164,10 +164,11 @@ fn read_int(input: &mut Peekable, first: char) -> Token { } fn read_ident(input: &mut Peekable, first: char) -> Token { + let special_characters = ['!', '?', '*', '-', '_']; let mut raw_ident = vec![first]; while let Some(c) = input.peek() { - if !c.is_ascii_alphanumeric() { + if !c.is_ascii_alphanumeric() && !special_characters.contains(c) { break; } @@ -198,6 +199,9 @@ mod test { #[case("(/ 1 2)", vec![Token::LeftParen, Token::Slash, Token::Int(1), Token::Int(2), Token::RightParen])] #[case("(- -2 1)", vec![Token::LeftParen, Token::Minus, Token::Int(-2), Token::Int(1), Token::RightParen])] #[case("(\"string and stuff\")", vec![Token::LeftParen, Token::String("string and stuff".into()), Token::RightParen])] + #[case("define!", vec![Token::Ident("define!".into())])] + #[case("let*", vec![Token::Ident("let*".into())])] + #[case("is-int?", vec![Token::Ident("is-int?".into())])] #[case( "(func a b)", vec![ diff --git a/src/main.rs b/src/main.rs index e45e2df..8ad44a7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,11 @@ use std::io::{self, Write}; -use crate::env::core_environment; - -mod env; -mod eval; +mod evaluator; mod lexer; mod parser; fn main() { - let env = core_environment(); + let env = evaluator::core_environment(); let mut input = String::new(); println!("MAL -- REPL"); @@ -27,7 +24,7 @@ fn main() { let tokens = lexer::read(&input).unwrap(); let ast = parser::parse(tokens).unwrap(); - let res = eval::eval(&env, ast).unwrap(); + let res = evaluator::eval(env.clone(), ast).unwrap(); for expr in res { println!("{expr}")