feat: reduce + fold

This commit is contained in:
Roman Godmaire 2024-05-14 16:08:21 -04:00
parent a2210f2db3
commit 285b93aa89

View file

@ -285,8 +285,62 @@ pub fn eval_node(env: &Environment, ast_node: Node) -> Result<Node> {
Node::List(res) Node::List(res)
} }
Node::Reduce(mut body) => todo!(), Node::Reduce(mut body) => {
Node::Fold(mut body) => todo!(), if body.len() != 2 {
todo!();
}
let list = body.pop_front().expect("arg count verified above");
let mut list = match eval_node(env, list)? {
Node::List(list) => list,
_ => todo!(),
};
let func = match body.pop_front().expect("arg count verified above") {
node @ Node::Function(_) => node,
node @ Node::Symbol(_) => eval_node(env, node)?,
_ => todo!(),
};
let mut acc = match list.pop_front() {
Some(node) => eval_node(env, node)?,
None => Node::Nil,
};
while let Some(node) = list.pop_front() {
let node = eval_node(env, node)?;
acc = eval_node(env, Node::List(vec![func.clone(), acc, node].into()))?;
}
acc
}
Node::Fold(mut body) => {
if body.len() != 3 {
todo!();
}
let list = body.pop_front().expect("arg count verified above");
let mut list = match eval_node(env, list)? {
Node::List(list) => list,
_ => todo!(),
};
let mut acc = body.pop_front().expect("arg count verified above");
let func = match body.pop_front().expect("arg count verified above") {
node @ Node::Function(_) => node,
node @ Node::Symbol(_) => eval_node(env, node)?,
_ => todo!(),
};
while let Some(node) = list.pop_front() {
let node = eval_node(env, node)?;
acc = eval_node(env, Node::List(vec![func.clone(), acc, node].into()))?;
}
acc
}
Node::Symbol(sym) => env Node::Symbol(sym) => env
.get_node(&sym) .get_node(&sym)
@ -484,6 +538,10 @@ mod test {
// Iteration // Iteration
#[case("(map '(1 2 3) (fn* (x) (* x x)))", "(1 4 9)")] #[case("(map '(1 2 3) (fn* (x) (* x x)))", "(1 4 9)")]
#[case("(filter '(1 2 3) (fn* (x) (> x 1)))", "(2 3)")] #[case("(filter '(1 2 3) (fn* (x) (> x 1)))", "(2 3)")]
#[case("(reduce '(1 2 3 4) +)", "10")]
#[case("(reduce '(1 2 3 4) (fn* (lhs rhs) (+ lhs rhs)))", "10")]
#[case("(fold '(1 2 3 4) 5 +)", "15")]
#[case("(fold '(1 2 3 4) 5 (fn* (lhs rhs) (+ lhs rhs)))", "15")]
fn test_evaluator(#[case] input: &str, #[case] expected: &str) { fn test_evaluator(#[case] input: &str, #[case] expected: &str) {
dbg!(input); dbg!(input);