refactor: use macros for semi-repeated functions

This commit is contained in:
Roman Godmaire 2024-05-10 21:27:00 -04:00
parent 5669a04db0
commit de4e9b4bd3
2 changed files with 84 additions and 199 deletions

View file

@ -5,79 +5,106 @@ use super::{NativeFunc, Value};
use crate::macros::arg_count; use crate::macros::arg_count;
use crate::Node; use crate::Node;
macro_rules! arithmetic {
($operator:tt) => {
NativeFunc(|args| {
args.into_iter()
.reduce(|lhs, rhs| match (lhs, rhs) {
(Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs $operator rhs),
(Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs $operator rhs),
(Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 $operator rhs),
(Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs $operator rhs as f64),
(lhs, rhs) => Node::Error(format!(
"expected int or float, got {} and {}",
lhs.get_type(),
rhs.get_type()
)),
})
.unwrap_or(Node::Int(0))
})
};
}
macro_rules! ordering {
($operator:tt) => {
NativeFunc(|args| {
arg_count!(2, args.len());
let mut args = args.into_iter();
let lhs = args.next().unwrap();
let rhs = args.next().unwrap();
let res = match (lhs, rhs) {
(Node::Int(lhs), Node::Int(rhs)) => lhs $operator rhs,
(Node::Float(lhs), Node::Float(rhs)) => lhs $operator rhs,
(Node::Float(lhs), Node::Int(rhs)) => lhs $operator rhs as f64,
(Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) $operator rhs,
(lhs, rhs) => {
return Node::Error(format!(
"expected int or float, got {} and {}",
lhs.get_type(),
rhs.get_type()
))
}
};
Node::Boolean(res)
})
};
}
pub(super) fn core() -> HashMap<String, Value> { pub(super) fn core() -> HashMap<String, Value> {
[ [
// Arithmetic operations // Arithmetic operations
("+", arithmetic!(+)),
("-", arithmetic!(-)),
("*", arithmetic!(*)),
("/", arithmetic!(/)),
// Ordering
("<", ordering!(<)),
(">", ordering!(>)),
("<=", ordering!(<=)),
(">=", ordering!(>=)),
// Equalities
( (
"+", "eq?",
NativeFunc(|args| { NativeFunc(|args| {
if args.len() < 2 {
return Node::Error(format!("expected at least 2 args, got {}", args.len()));
}
args.into_iter() args.into_iter()
.reduce(|lhs, rhs| match (lhs, rhs) { .reduce(|lhs, rhs| match (lhs, rhs) {
(Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs + rhs), (Node::Int(lhs), Node::Int(rhs)) => Node::Boolean(lhs == rhs),
(Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs + rhs), (Node::Float(lhs), Node::Float(rhs)) => Node::Boolean(lhs == rhs),
(Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 + rhs), (Node::Int(lhs), Node::Float(rhs)) => Node::Boolean(lhs as f64 == rhs),
(Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs + rhs as f64), (Node::Float(lhs), Node::Int(rhs)) => Node::Boolean(lhs == rhs as f64),
(Node::Boolean(lhs), Node::Boolean(rhs)) => Node::Boolean(lhs == rhs),
(Node::String(lhs), Node::String(rhs)) => Node::Boolean(lhs == rhs),
(Node::Nil, Node::Nil) => Node::Boolean(true),
err @ (Node::Error(_), _) => err.0,
(lhs, rhs) if lhs.get_type() == rhs.get_type() => Node::Error(format!(
"args of type {} cannot be compared",
lhs.get_type(),
)),
(lhs, rhs) => Node::Error(format!( (lhs, rhs) => Node::Error(format!(
"expected int or float, got {} and {}", "args of types {} and {} cannot be compared",
lhs.get_type(), lhs.get_type(),
rhs.get_type() rhs.get_type()
)), )),
}) })
.unwrap_or(Node::Int(0)) .expect("argument length checked above")
}), }),
), ),
( (
"-", "not",
NativeFunc(|args| { NativeFunc(|args| {
args.into_iter() arg_count!(1, args.len());
.reduce(|lhs, rhs| match (lhs, rhs) {
(Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs - rhs), match args.into_iter().next().unwrap() {
(Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs - rhs), Node::Boolean(val) => Node::Boolean(!val),
(Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs - rhs as f64), expr => Node::Error(format!("expected boolean, got {}", expr.get_type())),
(Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 - rhs), }
(lhs, rhs) => Node::Error(format!(
"expected int or float, got {} and {}",
lhs.get_type(),
rhs.get_type()
)),
})
.unwrap_or(Node::Int(0))
}),
),
(
"*",
NativeFunc(|args| {
args.into_iter()
.reduce(|lhs, rhs| match (lhs, rhs) {
(Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs * rhs),
(Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs * rhs),
(Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs * rhs as f64),
(Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 * rhs),
(lhs, rhs) => Node::Error(format!(
"expected int or float, got {} and {}",
lhs.get_type(),
rhs.get_type()
)),
})
.unwrap_or(Node::Int(0))
}),
),
(
"/",
NativeFunc(|args| {
args.into_iter()
.reduce(|lhs, rhs| match (lhs, rhs) {
(Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs / rhs),
(Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs / rhs),
(Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs / rhs as f64),
(Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 / rhs),
(lhs, rhs) => Node::Error(format!(
"expected int or float, got {} and {}",
lhs.get_type(),
rhs.get_type()
)),
})
.unwrap_or(Node::Int(0))
}), }),
), ),
// Errors! // Errors!
@ -199,146 +226,6 @@ pub(super) fn core() -> HashMap<String, Value> {
} }
}), }),
), ),
// Ordering
(
"eq?",
NativeFunc(|args| {
arg_count!(2, args.len());
let lhs = args[0].borrow();
let rhs = args[1].borrow();
match (lhs, rhs) {
(Node::Int(lhs), Node::Int(rhs)) => Node::Boolean(lhs == rhs),
(Node::Boolean(lhs), Node::Boolean(rhs)) => Node::Boolean(lhs == rhs),
(Node::String(lhs), Node::String(rhs)) => Node::Boolean(lhs == rhs),
(Node::Nil, Node::Nil) => Node::Boolean(true),
(lhs, rhs) if lhs.get_type() == rhs.get_type() => {
Node::Error("TypeError: expected int, boolean, or string.".to_string())
}
(lhs, rhs) => Node::Error(format!(
"TypeError: expected pair of int, boolean, or string, got {} and {}.",
lhs.get_type(),
rhs.get_type()
)),
}
}),
),
(
"not",
NativeFunc(|args| {
arg_count!(1, args.len());
match args.into_iter().next().unwrap() {
Node::Boolean(val) => Node::Boolean(!val),
expr => Node::Error(format!("expected boolean, got {}", expr.get_type())),
}
}),
),
(
"<",
NativeFunc(|args| {
arg_count!(2, args.len());
let mut args = args.into_iter();
let lhs = args.next().unwrap();
let rhs = args.next().unwrap();
let less_than = match (lhs, rhs) {
(Node::Int(lhs), Node::Int(rhs)) => lhs < rhs,
(Node::Float(lhs), Node::Float(rhs)) => lhs < rhs,
(Node::Float(lhs), Node::Int(rhs)) => lhs < rhs as f64,
(Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) < rhs,
(lhs, rhs) => {
return Node::Error(format!(
"expected int or float, got {} and {}",
lhs.get_type(),
rhs.get_type()
))
}
};
Node::Boolean(less_than)
}),
),
(
">",
NativeFunc(|args| {
arg_count!(2, args.len());
let mut args = args.into_iter();
let lhs = args.next().unwrap();
let rhs = args.next().unwrap();
let greater_than = match (lhs, rhs) {
(Node::Int(lhs), Node::Int(rhs)) => lhs > rhs,
(Node::Float(lhs), Node::Float(rhs)) => lhs > rhs,
(Node::Float(lhs), Node::Int(rhs)) => lhs > rhs as f64,
(Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) > rhs,
(lhs, rhs) => {
return Node::Error(format!(
"expected int or float, got {} and {}",
lhs.get_type(),
rhs.get_type()
))
}
};
Node::Boolean(greater_than)
}),
),
(
"<=",
NativeFunc(|args| {
arg_count!(2, args.len());
let mut args = args.into_iter();
let lhs = args.next().unwrap();
let rhs = args.next().unwrap();
let less_than_equal = match (lhs, rhs) {
(Node::Int(lhs), Node::Int(rhs)) => lhs <= rhs,
(Node::Float(lhs), Node::Float(rhs)) => lhs <= rhs,
(Node::Float(lhs), Node::Int(rhs)) => lhs <= rhs as f64,
(Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) <= rhs,
(lhs, rhs) => {
return Node::Error(format!(
"expected int or float, got {} and {}",
lhs.get_type(),
rhs.get_type()
))
}
};
Node::Boolean(less_than_equal)
}),
),
(
">=",
NativeFunc(|args| {
arg_count!(2, args.len());
let mut args = args.into_iter();
let lhs = args.next().unwrap();
let rhs = args.next().unwrap();
let greater_than_equal = match (lhs, rhs) {
(Node::Int(lhs), Node::Int(rhs)) => lhs >= rhs,
(Node::Float(lhs), Node::Float(rhs)) => lhs >= rhs,
(Node::Float(lhs), Node::Int(rhs)) => lhs >= rhs as f64,
(Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) >= rhs,
(lhs, rhs) => {
return Node::Error(format!(
"expected int or float, got {} and {}",
lhs.get_type(),
rhs.get_type()
))
}
};
Node::Boolean(greater_than_equal)
}),
),
// Strings // Strings
( (
"str", "str",

View file

@ -272,8 +272,6 @@ mod test {
#[case("(eq? nil nil)", "true")] #[case("(eq? nil nil)", "true")]
#[case("(not false)", "true")] #[case("(not false)", "true")]
#[case("(not true)", "false")] #[case("(not true)", "false")]
#[case("(not nil)", "false")]
#[case("(not 1)", "false")]
// Ordering // Ordering
#[case("(< 1 2)", "true")] #[case("(< 1 2)", "true")]
#[case("(< 2 1)", "false")] #[case("(< 2 1)", "false")]