diff --git a/src/evaluator/env.rs b/src/evaluator/env.rs index 92f4ca7..e5139ba 100644 --- a/src/evaluator/env.rs +++ b/src/evaluator/env.rs @@ -4,6 +4,7 @@ use super::{eval_ast_node, Error, Expression}; use crate::parser::Node; pub type RawEnvironment = HashMap>; +#[derive(Debug, PartialEq, Clone)] pub struct Environment { current: RefCell, outer: Option>, @@ -18,6 +19,13 @@ impl Environment { self.outer.clone()?.get(ident) } + pub fn wrap(&self, records: Vec<(String, Rc)>) -> Environment { + Environment { + current: RefCell::new(records.into_iter().collect()), + outer: Some(Rc::new(self.clone())), + } + } + fn set(&self, ident: String, val: Rc) { self.current.borrow_mut().insert(ident, val); } @@ -27,7 +35,7 @@ pub fn core_environment() -> Rc { let env = [ // Arithmetic operations ( - "+".to_string(), + "+", Expression::NativeFunc(|args| { let res = args .into_iter() @@ -43,7 +51,7 @@ pub fn core_environment() -> Rc { }), ), ( - "-".to_string(), + "-", Expression::NativeFunc(|args| { let res = args .into_iter() @@ -59,7 +67,7 @@ pub fn core_environment() -> Rc { }), ), ( - "*".to_string(), + "*", Expression::NativeFunc(|args| { let res = args .into_iter() @@ -75,7 +83,7 @@ pub fn core_environment() -> Rc { }), ), ( - "/".to_string(), + "/", Expression::NativeFunc(|args| { let res = args .into_iter() @@ -92,11 +100,11 @@ pub fn core_environment() -> Rc { ), // Collections ( - "vector".to_string(), + "vector", Expression::NativeFunc(|args| Ok(Rc::new(Expression::Vector(args)))), ), ( - "hashmap".to_string(), + "hashmap", Expression::NativeFunc(|args| { if args.len() % 2 != 0 { Err(Error::MismatchedArgCount)? @@ -121,7 +129,7 @@ pub fn core_environment() -> Rc { ), // Branching ( - "if".to_string(), + "if", Expression::NativeFunc(|args| { if args.len() != 3 { Err(Error::MismatchedArgCount)? @@ -140,7 +148,7 @@ pub fn core_environment() -> Rc { ), // Environment Manipulation ( - "define!".to_string(), + "define!", Expression::Special(|env, args| { let mut args = args.into_iter(); if args.len() != 2 { @@ -160,7 +168,7 @@ pub fn core_environment() -> Rc { }), ), ( - "let*".to_string(), + "let*", Expression::Special(|env, args| { if args.len() != 2 { Err(Error::MismatchedArgCount)? @@ -188,9 +196,38 @@ pub fn core_environment() -> Rc { eval_ast_node(Rc::new(new_env), args.next().unwrap()) }), ), + ( + "fn*", + Expression::Special(|env, args| { + let mut args = args; + if args.len() != 2 { + Err(Error::MismatchedArgCount)? + } + + let arg_list = args.remove(0); + let body = args.remove(0); + + let args: Vec = match arg_list { + Node::List(args) => args + .into_iter() + .map(|v| match v { + Node::Symbol(sym) => Ok(sym), + _ => Err(Error::ExpectedSymbol), + }) + .collect(), + _ => Err(Error::ExpectedList), + }?; + + Ok(Rc::new(Expression::Function { + params: args, + env: (*env).clone(), + body, + })) + }), + ), ] .into_iter() - .map(|(k, v)| (k, Rc::new(v))); + .map(|(k, v)| (k.to_string(), Rc::new(v))); Environment { current: RefCell::new(HashMap::from_iter(env)), diff --git a/src/evaluator/mod.rs b/src/evaluator/mod.rs index 59f9364..ecd9421 100644 --- a/src/evaluator/mod.rs +++ b/src/evaluator/mod.rs @@ -27,6 +27,11 @@ pub enum Expression { Vector(Vec>), HashMap(HashMap>), + Function { + params: Vec, + env: Environment, + body: Node, + }, NativeFunc(fn(args: Vec>) -> Result>), Special(fn(env: Rc, args: Vec) -> Result>), } @@ -60,6 +65,14 @@ impl std::fmt::Display for Expression { write!(f, "{{{res}}}") } + Expression::Function { + params: args, + env: _, + body: _, + } => { + write!(f, "(fn* ({}) /body/)", args.join(" ")) + } + Expression::NativeFunc(func) => write!(f, "{func:?}"), Expression::Special(func) => write!(f, "{func:?}"), } @@ -111,10 +124,36 @@ fn eval_ast_node(env: Rc, ast_node: Node) -> Result> func(args)? } + Expression::Special(func) => { let args = list.collect(); func(env, args)? } + + Expression::Function { + params, + env: inner_env, + body, + } => { + let args = list; + if args.len() != params.len() { + Err(Error::MismatchedArgCount)? + } + + let args = args + .map(|node| eval_ast_node(env.clone(), node)) + .collect::>>>()?; + + let records = params + .into_iter() + .map(|k| k.to_owned()) + .zip(args.into_iter()) + .collect(); + + let env = inner_env.wrap(records).into(); + + eval_ast_node(env, (*body).clone())? + } _ => Err(Error::InvalidOperator)?, } } @@ -176,6 +215,10 @@ mod test { #[case("(if false true false)", "false")] #[case("(if 4 true false)", "false")] #[case("(if \"blue\" true false)", "false")] + // Functions + #[case("((fn* (a) a) 3)", "3")] + #[case("((fn* (a) (+ a 2)) 1)", "3")] + #[case("((fn* (a b) (+ a b)) 1 2)", "3")] fn test_evaluator(#[case] input: &str, #[case] expected: &str) { let env = core_environment(); let tokens = lexer::read(input).unwrap(); diff --git a/src/parser.rs b/src/parser.rs index e588f0b..2dbfb1a 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -4,7 +4,7 @@ use anyhow::{bail, Result}; use crate::lexer::Token; -#[derive(Debug, PartialEq, PartialOrd)] +#[derive(Debug, PartialEq, PartialOrd, Clone)] pub enum Node { List(Vec),