From de4e9b4bd3b20c35cd42320220249c66bcd30149 Mon Sep 17 00:00:00 2001 From: Roman Godmaire Date: Fri, 10 May 2024 21:27:00 -0400 Subject: [PATCH] refactor: use macros for semi-repeated functions --- mute-interpreter/src/env/core.rs | 281 +++++++++--------------------- mute-interpreter/src/evaluator.rs | 2 - 2 files changed, 84 insertions(+), 199 deletions(-) diff --git a/mute-interpreter/src/env/core.rs b/mute-interpreter/src/env/core.rs index e0dc07c..6d4d548 100644 --- a/mute-interpreter/src/env/core.rs +++ b/mute-interpreter/src/env/core.rs @@ -5,79 +5,106 @@ use super::{NativeFunc, Value}; use crate::macros::arg_count; use crate::Node; +macro_rules! arithmetic { + ($operator:tt) => { + NativeFunc(|args| { + args.into_iter() + .reduce(|lhs, rhs| match (lhs, rhs) { + (Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs $operator rhs), + (Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs $operator rhs), + (Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 $operator rhs), + (Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs $operator rhs as f64), + (lhs, rhs) => Node::Error(format!( + "expected int or float, got {} and {}", + lhs.get_type(), + rhs.get_type() + )), + }) + .unwrap_or(Node::Int(0)) + }) + }; +} + +macro_rules! ordering { + ($operator:tt) => { + NativeFunc(|args| { + arg_count!(2, args.len()); + + let mut args = args.into_iter(); + let lhs = args.next().unwrap(); + let rhs = args.next().unwrap(); + + let res = match (lhs, rhs) { + (Node::Int(lhs), Node::Int(rhs)) => lhs $operator rhs, + (Node::Float(lhs), Node::Float(rhs)) => lhs $operator rhs, + (Node::Float(lhs), Node::Int(rhs)) => lhs $operator rhs as f64, + (Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) $operator rhs, + (lhs, rhs) => { + return Node::Error(format!( + "expected int or float, got {} and {}", + lhs.get_type(), + rhs.get_type() + )) + } + }; + + Node::Boolean(res) + }) + }; +} + pub(super) fn core() -> HashMap { [ // Arithmetic operations + ("+", arithmetic!(+)), + ("-", arithmetic!(-)), + ("*", arithmetic!(*)), + ("/", arithmetic!(/)), + // Ordering + ("<", ordering!(<)), + (">", ordering!(>)), + ("<=", ordering!(<=)), + (">=", ordering!(>=)), + // Equalities ( - "+", + "eq?", NativeFunc(|args| { + if args.len() < 2 { + return Node::Error(format!("expected at least 2 args, got {}", args.len())); + } + args.into_iter() .reduce(|lhs, rhs| match (lhs, rhs) { - (Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs + rhs), - (Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs + rhs), - (Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 + rhs), - (Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs + rhs as f64), + (Node::Int(lhs), Node::Int(rhs)) => Node::Boolean(lhs == rhs), + (Node::Float(lhs), Node::Float(rhs)) => Node::Boolean(lhs == rhs), + (Node::Int(lhs), Node::Float(rhs)) => Node::Boolean(lhs as f64 == rhs), + (Node::Float(lhs), Node::Int(rhs)) => Node::Boolean(lhs == rhs as f64), + (Node::Boolean(lhs), Node::Boolean(rhs)) => Node::Boolean(lhs == rhs), + (Node::String(lhs), Node::String(rhs)) => Node::Boolean(lhs == rhs), + (Node::Nil, Node::Nil) => Node::Boolean(true), + err @ (Node::Error(_), _) => err.0, + (lhs, rhs) if lhs.get_type() == rhs.get_type() => Node::Error(format!( + "args of type {} cannot be compared", + lhs.get_type(), + )), (lhs, rhs) => Node::Error(format!( - "expected int or float, got {} and {}", + "args of types {} and {} cannot be compared", lhs.get_type(), rhs.get_type() )), }) - .unwrap_or(Node::Int(0)) + .expect("argument length checked above") }), ), ( - "-", + "not", NativeFunc(|args| { - args.into_iter() - .reduce(|lhs, rhs| match (lhs, rhs) { - (Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs - rhs), - (Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs - rhs), - (Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs - rhs as f64), - (Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 - rhs), - (lhs, rhs) => Node::Error(format!( - "expected int or float, got {} and {}", - lhs.get_type(), - rhs.get_type() - )), - }) - .unwrap_or(Node::Int(0)) - }), - ), - ( - "*", - NativeFunc(|args| { - args.into_iter() - .reduce(|lhs, rhs| match (lhs, rhs) { - (Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs * rhs), - (Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs * rhs), - (Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs * rhs as f64), - (Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 * rhs), - (lhs, rhs) => Node::Error(format!( - "expected int or float, got {} and {}", - lhs.get_type(), - rhs.get_type() - )), - }) - .unwrap_or(Node::Int(0)) - }), - ), - ( - "/", - NativeFunc(|args| { - args.into_iter() - .reduce(|lhs, rhs| match (lhs, rhs) { - (Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs / rhs), - (Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs / rhs), - (Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs / rhs as f64), - (Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 / rhs), - (lhs, rhs) => Node::Error(format!( - "expected int or float, got {} and {}", - lhs.get_type(), - rhs.get_type() - )), - }) - .unwrap_or(Node::Int(0)) + arg_count!(1, args.len()); + + match args.into_iter().next().unwrap() { + Node::Boolean(val) => Node::Boolean(!val), + expr => Node::Error(format!("expected boolean, got {}", expr.get_type())), + } }), ), // Errors! @@ -199,146 +226,6 @@ pub(super) fn core() -> HashMap { } }), ), - // Ordering - ( - "eq?", - NativeFunc(|args| { - arg_count!(2, args.len()); - - let lhs = args[0].borrow(); - let rhs = args[1].borrow(); - - match (lhs, rhs) { - (Node::Int(lhs), Node::Int(rhs)) => Node::Boolean(lhs == rhs), - (Node::Boolean(lhs), Node::Boolean(rhs)) => Node::Boolean(lhs == rhs), - (Node::String(lhs), Node::String(rhs)) => Node::Boolean(lhs == rhs), - (Node::Nil, Node::Nil) => Node::Boolean(true), - (lhs, rhs) if lhs.get_type() == rhs.get_type() => { - Node::Error("TypeError: expected int, boolean, or string.".to_string()) - } - (lhs, rhs) => Node::Error(format!( - "TypeError: expected pair of int, boolean, or string, got {} and {}.", - lhs.get_type(), - rhs.get_type() - )), - } - }), - ), - ( - "not", - NativeFunc(|args| { - arg_count!(1, args.len()); - - match args.into_iter().next().unwrap() { - Node::Boolean(val) => Node::Boolean(!val), - expr => Node::Error(format!("expected boolean, got {}", expr.get_type())), - } - }), - ), - ( - "<", - NativeFunc(|args| { - arg_count!(2, args.len()); - - let mut args = args.into_iter(); - let lhs = args.next().unwrap(); - let rhs = args.next().unwrap(); - - let less_than = match (lhs, rhs) { - (Node::Int(lhs), Node::Int(rhs)) => lhs < rhs, - (Node::Float(lhs), Node::Float(rhs)) => lhs < rhs, - (Node::Float(lhs), Node::Int(rhs)) => lhs < rhs as f64, - (Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) < rhs, - (lhs, rhs) => { - return Node::Error(format!( - "expected int or float, got {} and {}", - lhs.get_type(), - rhs.get_type() - )) - } - }; - - Node::Boolean(less_than) - }), - ), - ( - ">", - NativeFunc(|args| { - arg_count!(2, args.len()); - - let mut args = args.into_iter(); - let lhs = args.next().unwrap(); - let rhs = args.next().unwrap(); - - let greater_than = match (lhs, rhs) { - (Node::Int(lhs), Node::Int(rhs)) => lhs > rhs, - (Node::Float(lhs), Node::Float(rhs)) => lhs > rhs, - (Node::Float(lhs), Node::Int(rhs)) => lhs > rhs as f64, - (Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) > rhs, - (lhs, rhs) => { - return Node::Error(format!( - "expected int or float, got {} and {}", - lhs.get_type(), - rhs.get_type() - )) - } - }; - - Node::Boolean(greater_than) - }), - ), - ( - "<=", - NativeFunc(|args| { - arg_count!(2, args.len()); - - let mut args = args.into_iter(); - let lhs = args.next().unwrap(); - let rhs = args.next().unwrap(); - - let less_than_equal = match (lhs, rhs) { - (Node::Int(lhs), Node::Int(rhs)) => lhs <= rhs, - (Node::Float(lhs), Node::Float(rhs)) => lhs <= rhs, - (Node::Float(lhs), Node::Int(rhs)) => lhs <= rhs as f64, - (Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) <= rhs, - (lhs, rhs) => { - return Node::Error(format!( - "expected int or float, got {} and {}", - lhs.get_type(), - rhs.get_type() - )) - } - }; - - Node::Boolean(less_than_equal) - }), - ), - ( - ">=", - NativeFunc(|args| { - arg_count!(2, args.len()); - - let mut args = args.into_iter(); - let lhs = args.next().unwrap(); - let rhs = args.next().unwrap(); - - let greater_than_equal = match (lhs, rhs) { - (Node::Int(lhs), Node::Int(rhs)) => lhs >= rhs, - (Node::Float(lhs), Node::Float(rhs)) => lhs >= rhs, - (Node::Float(lhs), Node::Int(rhs)) => lhs >= rhs as f64, - (Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) >= rhs, - (lhs, rhs) => { - return Node::Error(format!( - "expected int or float, got {} and {}", - lhs.get_type(), - rhs.get_type() - )) - } - }; - - Node::Boolean(greater_than_equal) - }), - ), // Strings ( "str", diff --git a/mute-interpreter/src/evaluator.rs b/mute-interpreter/src/evaluator.rs index 873607d..cfb7761 100644 --- a/mute-interpreter/src/evaluator.rs +++ b/mute-interpreter/src/evaluator.rs @@ -272,8 +272,6 @@ mod test { #[case("(eq? nil nil)", "true")] #[case("(not false)", "true")] #[case("(not true)", "false")] - #[case("(not nil)", "false")] - #[case("(not 1)", "false")] // Ordering #[case("(< 1 2)", "true")] #[case("(< 2 1)", "false")]