feat: functions

This commit is contained in:
Roman Godmaire 2024-02-16 10:03:22 -05:00
parent f5ac02aedb
commit 49e99b6fa4
3 changed files with 91 additions and 11 deletions

View file

@ -4,6 +4,7 @@ use super::{eval_ast_node, Error, Expression};
use crate::parser::Node;
pub type RawEnvironment = HashMap<String, Rc<Expression>>;
#[derive(Debug, PartialEq, Clone)]
pub struct Environment {
current: RefCell<RawEnvironment>,
outer: Option<Rc<Environment>>,
@ -18,6 +19,13 @@ impl Environment {
self.outer.clone()?.get(ident)
}
pub fn wrap(&self, records: Vec<(String, Rc<Expression>)>) -> Environment {
Environment {
current: RefCell::new(records.into_iter().collect()),
outer: Some(Rc::new(self.clone())),
}
}
fn set(&self, ident: String, val: Rc<Expression>) {
self.current.borrow_mut().insert(ident, val);
}
@ -27,7 +35,7 @@ pub fn core_environment() -> Rc<Environment> {
let env = [
// Arithmetic operations
(
"+".to_string(),
"+",
Expression::NativeFunc(|args| {
let res = args
.into_iter()
@ -43,7 +51,7 @@ pub fn core_environment() -> Rc<Environment> {
}),
),
(
"-".to_string(),
"-",
Expression::NativeFunc(|args| {
let res = args
.into_iter()
@ -59,7 +67,7 @@ pub fn core_environment() -> Rc<Environment> {
}),
),
(
"*".to_string(),
"*",
Expression::NativeFunc(|args| {
let res = args
.into_iter()
@ -75,7 +83,7 @@ pub fn core_environment() -> Rc<Environment> {
}),
),
(
"/".to_string(),
"/",
Expression::NativeFunc(|args| {
let res = args
.into_iter()
@ -92,11 +100,11 @@ pub fn core_environment() -> Rc<Environment> {
),
// Collections
(
"vector".to_string(),
"vector",
Expression::NativeFunc(|args| Ok(Rc::new(Expression::Vector(args)))),
),
(
"hashmap".to_string(),
"hashmap",
Expression::NativeFunc(|args| {
if args.len() % 2 != 0 {
Err(Error::MismatchedArgCount)?
@ -121,7 +129,7 @@ pub fn core_environment() -> Rc<Environment> {
),
// Branching
(
"if".to_string(),
"if",
Expression::NativeFunc(|args| {
if args.len() != 3 {
Err(Error::MismatchedArgCount)?
@ -140,7 +148,7 @@ pub fn core_environment() -> Rc<Environment> {
),
// Environment Manipulation
(
"define!".to_string(),
"define!",
Expression::Special(|env, args| {
let mut args = args.into_iter();
if args.len() != 2 {
@ -160,7 +168,7 @@ pub fn core_environment() -> Rc<Environment> {
}),
),
(
"let*".to_string(),
"let*",
Expression::Special(|env, args| {
if args.len() != 2 {
Err(Error::MismatchedArgCount)?
@ -188,9 +196,38 @@ pub fn core_environment() -> Rc<Environment> {
eval_ast_node(Rc::new(new_env), args.next().unwrap())
}),
),
(
"fn*",
Expression::Special(|env, args| {
let mut args = args;
if args.len() != 2 {
Err(Error::MismatchedArgCount)?
}
let arg_list = args.remove(0);
let body = args.remove(0);
let args: Vec<String> = match arg_list {
Node::List(args) => args
.into_iter()
.map(|v| match v {
Node::Symbol(sym) => Ok(sym),
_ => Err(Error::ExpectedSymbol),
})
.collect(),
_ => Err(Error::ExpectedList),
}?;
Ok(Rc::new(Expression::Function {
params: args,
env: (*env).clone(),
body,
}))
}),
),
]
.into_iter()
.map(|(k, v)| (k, Rc::new(v)));
.map(|(k, v)| (k.to_string(), Rc::new(v)));
Environment {
current: RefCell::new(HashMap::from_iter(env)),

View file

@ -27,6 +27,11 @@ pub enum Expression {
Vector(Vec<Rc<Expression>>),
HashMap(HashMap<String, Rc<Expression>>),
Function {
params: Vec<String>,
env: Environment,
body: Node,
},
NativeFunc(fn(args: Vec<Rc<Expression>>) -> Result<Rc<Expression>>),
Special(fn(env: Rc<Environment>, args: Vec<Node>) -> Result<Rc<Expression>>),
}
@ -60,6 +65,14 @@ impl std::fmt::Display for Expression {
write!(f, "{{{res}}}")
}
Expression::Function {
params: args,
env: _,
body: _,
} => {
write!(f, "(fn* ({}) /body/)", args.join(" "))
}
Expression::NativeFunc(func) => write!(f, "{func:?}"),
Expression::Special(func) => write!(f, "{func:?}"),
}
@ -111,10 +124,36 @@ fn eval_ast_node(env: Rc<Environment>, ast_node: Node) -> Result<Rc<Expression>>
func(args)?
}
Expression::Special(func) => {
let args = list.collect();
func(env, args)?
}
Expression::Function {
params,
env: inner_env,
body,
} => {
let args = list;
if args.len() != params.len() {
Err(Error::MismatchedArgCount)?
}
let args = args
.map(|node| eval_ast_node(env.clone(), node))
.collect::<Result<Vec<Rc<Expression>>>>()?;
let records = params
.into_iter()
.map(|k| k.to_owned())
.zip(args.into_iter())
.collect();
let env = inner_env.wrap(records).into();
eval_ast_node(env, (*body).clone())?
}
_ => Err(Error::InvalidOperator)?,
}
}
@ -176,6 +215,10 @@ mod test {
#[case("(if false true false)", "false")]
#[case("(if 4 true false)", "false")]
#[case("(if \"blue\" true false)", "false")]
// Functions
#[case("((fn* (a) a) 3)", "3")]
#[case("((fn* (a) (+ a 2)) 1)", "3")]
#[case("((fn* (a b) (+ a b)) 1 2)", "3")]
fn test_evaluator(#[case] input: &str, #[case] expected: &str) {
let env = core_environment();
let tokens = lexer::read(input).unwrap();

View file

@ -4,7 +4,7 @@ use anyhow::{bail, Result};
use crate::lexer::Token;
#[derive(Debug, PartialEq, PartialOrd)]
#[derive(Debug, PartialEq, PartialOrd, Clone)]
pub enum Node {
List(Vec<Node>),