feat: functions

This commit is contained in:
Roman Godmaire 2024-02-16 10:03:22 -05:00
parent f5ac02aedb
commit 49e99b6fa4
3 changed files with 91 additions and 11 deletions

View file

@ -4,6 +4,7 @@ use super::{eval_ast_node, Error, Expression};
use crate::parser::Node; use crate::parser::Node;
pub type RawEnvironment = HashMap<String, Rc<Expression>>; pub type RawEnvironment = HashMap<String, Rc<Expression>>;
#[derive(Debug, PartialEq, Clone)]
pub struct Environment { pub struct Environment {
current: RefCell<RawEnvironment>, current: RefCell<RawEnvironment>,
outer: Option<Rc<Environment>>, outer: Option<Rc<Environment>>,
@ -18,6 +19,13 @@ impl Environment {
self.outer.clone()?.get(ident) self.outer.clone()?.get(ident)
} }
pub fn wrap(&self, records: Vec<(String, Rc<Expression>)>) -> Environment {
Environment {
current: RefCell::new(records.into_iter().collect()),
outer: Some(Rc::new(self.clone())),
}
}
fn set(&self, ident: String, val: Rc<Expression>) { fn set(&self, ident: String, val: Rc<Expression>) {
self.current.borrow_mut().insert(ident, val); self.current.borrow_mut().insert(ident, val);
} }
@ -27,7 +35,7 @@ pub fn core_environment() -> Rc<Environment> {
let env = [ let env = [
// Arithmetic operations // Arithmetic operations
( (
"+".to_string(), "+",
Expression::NativeFunc(|args| { Expression::NativeFunc(|args| {
let res = args let res = args
.into_iter() .into_iter()
@ -43,7 +51,7 @@ pub fn core_environment() -> Rc<Environment> {
}), }),
), ),
( (
"-".to_string(), "-",
Expression::NativeFunc(|args| { Expression::NativeFunc(|args| {
let res = args let res = args
.into_iter() .into_iter()
@ -59,7 +67,7 @@ pub fn core_environment() -> Rc<Environment> {
}), }),
), ),
( (
"*".to_string(), "*",
Expression::NativeFunc(|args| { Expression::NativeFunc(|args| {
let res = args let res = args
.into_iter() .into_iter()
@ -75,7 +83,7 @@ pub fn core_environment() -> Rc<Environment> {
}), }),
), ),
( (
"/".to_string(), "/",
Expression::NativeFunc(|args| { Expression::NativeFunc(|args| {
let res = args let res = args
.into_iter() .into_iter()
@ -92,11 +100,11 @@ pub fn core_environment() -> Rc<Environment> {
), ),
// Collections // Collections
( (
"vector".to_string(), "vector",
Expression::NativeFunc(|args| Ok(Rc::new(Expression::Vector(args)))), Expression::NativeFunc(|args| Ok(Rc::new(Expression::Vector(args)))),
), ),
( (
"hashmap".to_string(), "hashmap",
Expression::NativeFunc(|args| { Expression::NativeFunc(|args| {
if args.len() % 2 != 0 { if args.len() % 2 != 0 {
Err(Error::MismatchedArgCount)? Err(Error::MismatchedArgCount)?
@ -121,7 +129,7 @@ pub fn core_environment() -> Rc<Environment> {
), ),
// Branching // Branching
( (
"if".to_string(), "if",
Expression::NativeFunc(|args| { Expression::NativeFunc(|args| {
if args.len() != 3 { if args.len() != 3 {
Err(Error::MismatchedArgCount)? Err(Error::MismatchedArgCount)?
@ -140,7 +148,7 @@ pub fn core_environment() -> Rc<Environment> {
), ),
// Environment Manipulation // Environment Manipulation
( (
"define!".to_string(), "define!",
Expression::Special(|env, args| { Expression::Special(|env, args| {
let mut args = args.into_iter(); let mut args = args.into_iter();
if args.len() != 2 { if args.len() != 2 {
@ -160,7 +168,7 @@ pub fn core_environment() -> Rc<Environment> {
}), }),
), ),
( (
"let*".to_string(), "let*",
Expression::Special(|env, args| { Expression::Special(|env, args| {
if args.len() != 2 { if args.len() != 2 {
Err(Error::MismatchedArgCount)? Err(Error::MismatchedArgCount)?
@ -188,9 +196,38 @@ pub fn core_environment() -> Rc<Environment> {
eval_ast_node(Rc::new(new_env), args.next().unwrap()) eval_ast_node(Rc::new(new_env), args.next().unwrap())
}), }),
), ),
(
"fn*",
Expression::Special(|env, args| {
let mut args = args;
if args.len() != 2 {
Err(Error::MismatchedArgCount)?
}
let arg_list = args.remove(0);
let body = args.remove(0);
let args: Vec<String> = match arg_list {
Node::List(args) => args
.into_iter()
.map(|v| match v {
Node::Symbol(sym) => Ok(sym),
_ => Err(Error::ExpectedSymbol),
})
.collect(),
_ => Err(Error::ExpectedList),
}?;
Ok(Rc::new(Expression::Function {
params: args,
env: (*env).clone(),
body,
}))
}),
),
] ]
.into_iter() .into_iter()
.map(|(k, v)| (k, Rc::new(v))); .map(|(k, v)| (k.to_string(), Rc::new(v)));
Environment { Environment {
current: RefCell::new(HashMap::from_iter(env)), current: RefCell::new(HashMap::from_iter(env)),

View file

@ -27,6 +27,11 @@ pub enum Expression {
Vector(Vec<Rc<Expression>>), Vector(Vec<Rc<Expression>>),
HashMap(HashMap<String, Rc<Expression>>), HashMap(HashMap<String, Rc<Expression>>),
Function {
params: Vec<String>,
env: Environment,
body: Node,
},
NativeFunc(fn(args: Vec<Rc<Expression>>) -> Result<Rc<Expression>>), NativeFunc(fn(args: Vec<Rc<Expression>>) -> Result<Rc<Expression>>),
Special(fn(env: Rc<Environment>, args: Vec<Node>) -> Result<Rc<Expression>>), Special(fn(env: Rc<Environment>, args: Vec<Node>) -> Result<Rc<Expression>>),
} }
@ -60,6 +65,14 @@ impl std::fmt::Display for Expression {
write!(f, "{{{res}}}") write!(f, "{{{res}}}")
} }
Expression::Function {
params: args,
env: _,
body: _,
} => {
write!(f, "(fn* ({}) /body/)", args.join(" "))
}
Expression::NativeFunc(func) => write!(f, "{func:?}"), Expression::NativeFunc(func) => write!(f, "{func:?}"),
Expression::Special(func) => write!(f, "{func:?}"), Expression::Special(func) => write!(f, "{func:?}"),
} }
@ -111,10 +124,36 @@ fn eval_ast_node(env: Rc<Environment>, ast_node: Node) -> Result<Rc<Expression>>
func(args)? func(args)?
} }
Expression::Special(func) => { Expression::Special(func) => {
let args = list.collect(); let args = list.collect();
func(env, args)? func(env, args)?
} }
Expression::Function {
params,
env: inner_env,
body,
} => {
let args = list;
if args.len() != params.len() {
Err(Error::MismatchedArgCount)?
}
let args = args
.map(|node| eval_ast_node(env.clone(), node))
.collect::<Result<Vec<Rc<Expression>>>>()?;
let records = params
.into_iter()
.map(|k| k.to_owned())
.zip(args.into_iter())
.collect();
let env = inner_env.wrap(records).into();
eval_ast_node(env, (*body).clone())?
}
_ => Err(Error::InvalidOperator)?, _ => Err(Error::InvalidOperator)?,
} }
} }
@ -176,6 +215,10 @@ mod test {
#[case("(if false true false)", "false")] #[case("(if false true false)", "false")]
#[case("(if 4 true false)", "false")] #[case("(if 4 true false)", "false")]
#[case("(if \"blue\" true false)", "false")] #[case("(if \"blue\" true false)", "false")]
// Functions
#[case("((fn* (a) a) 3)", "3")]
#[case("((fn* (a) (+ a 2)) 1)", "3")]
#[case("((fn* (a b) (+ a b)) 1 2)", "3")]
fn test_evaluator(#[case] input: &str, #[case] expected: &str) { fn test_evaluator(#[case] input: &str, #[case] expected: &str) {
let env = core_environment(); let env = core_environment();
let tokens = lexer::read(input).unwrap(); let tokens = lexer::read(input).unwrap();

View file

@ -4,7 +4,7 @@ use anyhow::{bail, Result};
use crate::lexer::Token; use crate::lexer::Token;
#[derive(Debug, PartialEq, PartialOrd)] #[derive(Debug, PartialEq, PartialOrd, Clone)]
pub enum Node { pub enum Node {
List(Vec<Node>), List(Vec<Node>),