refactor: use macros for semi-repeated functions
This commit is contained in:
parent
5669a04db0
commit
de4e9b4bd3
2 changed files with 84 additions and 199 deletions
281
mute-interpreter/src/env/core.rs
vendored
281
mute-interpreter/src/env/core.rs
vendored
|
@ -5,79 +5,106 @@ use super::{NativeFunc, Value};
|
|||
use crate::macros::arg_count;
|
||||
use crate::Node;
|
||||
|
||||
macro_rules! arithmetic {
|
||||
($operator:tt) => {
|
||||
NativeFunc(|args| {
|
||||
args.into_iter()
|
||||
.reduce(|lhs, rhs| match (lhs, rhs) {
|
||||
(Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs $operator rhs),
|
||||
(Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs $operator rhs),
|
||||
(Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 $operator rhs),
|
||||
(Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs $operator rhs as f64),
|
||||
(lhs, rhs) => Node::Error(format!(
|
||||
"expected int or float, got {} and {}",
|
||||
lhs.get_type(),
|
||||
rhs.get_type()
|
||||
)),
|
||||
})
|
||||
.unwrap_or(Node::Int(0))
|
||||
})
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! ordering {
|
||||
($operator:tt) => {
|
||||
NativeFunc(|args| {
|
||||
arg_count!(2, args.len());
|
||||
|
||||
let mut args = args.into_iter();
|
||||
let lhs = args.next().unwrap();
|
||||
let rhs = args.next().unwrap();
|
||||
|
||||
let res = match (lhs, rhs) {
|
||||
(Node::Int(lhs), Node::Int(rhs)) => lhs $operator rhs,
|
||||
(Node::Float(lhs), Node::Float(rhs)) => lhs $operator rhs,
|
||||
(Node::Float(lhs), Node::Int(rhs)) => lhs $operator rhs as f64,
|
||||
(Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) $operator rhs,
|
||||
(lhs, rhs) => {
|
||||
return Node::Error(format!(
|
||||
"expected int or float, got {} and {}",
|
||||
lhs.get_type(),
|
||||
rhs.get_type()
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Node::Boolean(res)
|
||||
})
|
||||
};
|
||||
}
|
||||
|
||||
pub(super) fn core() -> HashMap<String, Value> {
|
||||
[
|
||||
// Arithmetic operations
|
||||
("+", arithmetic!(+)),
|
||||
("-", arithmetic!(-)),
|
||||
("*", arithmetic!(*)),
|
||||
("/", arithmetic!(/)),
|
||||
// Ordering
|
||||
("<", ordering!(<)),
|
||||
(">", ordering!(>)),
|
||||
("<=", ordering!(<=)),
|
||||
(">=", ordering!(>=)),
|
||||
// Equalities
|
||||
(
|
||||
"+",
|
||||
"eq?",
|
||||
NativeFunc(|args| {
|
||||
if args.len() < 2 {
|
||||
return Node::Error(format!("expected at least 2 args, got {}", args.len()));
|
||||
}
|
||||
|
||||
args.into_iter()
|
||||
.reduce(|lhs, rhs| match (lhs, rhs) {
|
||||
(Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs + rhs),
|
||||
(Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs + rhs),
|
||||
(Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 + rhs),
|
||||
(Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs + rhs as f64),
|
||||
(Node::Int(lhs), Node::Int(rhs)) => Node::Boolean(lhs == rhs),
|
||||
(Node::Float(lhs), Node::Float(rhs)) => Node::Boolean(lhs == rhs),
|
||||
(Node::Int(lhs), Node::Float(rhs)) => Node::Boolean(lhs as f64 == rhs),
|
||||
(Node::Float(lhs), Node::Int(rhs)) => Node::Boolean(lhs == rhs as f64),
|
||||
(Node::Boolean(lhs), Node::Boolean(rhs)) => Node::Boolean(lhs == rhs),
|
||||
(Node::String(lhs), Node::String(rhs)) => Node::Boolean(lhs == rhs),
|
||||
(Node::Nil, Node::Nil) => Node::Boolean(true),
|
||||
err @ (Node::Error(_), _) => err.0,
|
||||
(lhs, rhs) if lhs.get_type() == rhs.get_type() => Node::Error(format!(
|
||||
"args of type {} cannot be compared",
|
||||
lhs.get_type(),
|
||||
)),
|
||||
(lhs, rhs) => Node::Error(format!(
|
||||
"expected int or float, got {} and {}",
|
||||
"args of types {} and {} cannot be compared",
|
||||
lhs.get_type(),
|
||||
rhs.get_type()
|
||||
)),
|
||||
})
|
||||
.unwrap_or(Node::Int(0))
|
||||
.expect("argument length checked above")
|
||||
}),
|
||||
),
|
||||
(
|
||||
"-",
|
||||
"not",
|
||||
NativeFunc(|args| {
|
||||
args.into_iter()
|
||||
.reduce(|lhs, rhs| match (lhs, rhs) {
|
||||
(Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs - rhs),
|
||||
(Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs - rhs),
|
||||
(Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs - rhs as f64),
|
||||
(Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 - rhs),
|
||||
(lhs, rhs) => Node::Error(format!(
|
||||
"expected int or float, got {} and {}",
|
||||
lhs.get_type(),
|
||||
rhs.get_type()
|
||||
)),
|
||||
})
|
||||
.unwrap_or(Node::Int(0))
|
||||
}),
|
||||
),
|
||||
(
|
||||
"*",
|
||||
NativeFunc(|args| {
|
||||
args.into_iter()
|
||||
.reduce(|lhs, rhs| match (lhs, rhs) {
|
||||
(Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs * rhs),
|
||||
(Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs * rhs),
|
||||
(Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs * rhs as f64),
|
||||
(Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 * rhs),
|
||||
(lhs, rhs) => Node::Error(format!(
|
||||
"expected int or float, got {} and {}",
|
||||
lhs.get_type(),
|
||||
rhs.get_type()
|
||||
)),
|
||||
})
|
||||
.unwrap_or(Node::Int(0))
|
||||
}),
|
||||
),
|
||||
(
|
||||
"/",
|
||||
NativeFunc(|args| {
|
||||
args.into_iter()
|
||||
.reduce(|lhs, rhs| match (lhs, rhs) {
|
||||
(Node::Int(lhs), Node::Int(rhs)) => Node::Int(lhs / rhs),
|
||||
(Node::Float(lhs), Node::Float(rhs)) => Node::Float(lhs / rhs),
|
||||
(Node::Float(lhs), Node::Int(rhs)) => Node::Float(lhs / rhs as f64),
|
||||
(Node::Int(lhs), Node::Float(rhs)) => Node::Float(lhs as f64 / rhs),
|
||||
(lhs, rhs) => Node::Error(format!(
|
||||
"expected int or float, got {} and {}",
|
||||
lhs.get_type(),
|
||||
rhs.get_type()
|
||||
)),
|
||||
})
|
||||
.unwrap_or(Node::Int(0))
|
||||
arg_count!(1, args.len());
|
||||
|
||||
match args.into_iter().next().unwrap() {
|
||||
Node::Boolean(val) => Node::Boolean(!val),
|
||||
expr => Node::Error(format!("expected boolean, got {}", expr.get_type())),
|
||||
}
|
||||
}),
|
||||
),
|
||||
// Errors!
|
||||
|
@ -199,146 +226,6 @@ pub(super) fn core() -> HashMap<String, Value> {
|
|||
}
|
||||
}),
|
||||
),
|
||||
// Ordering
|
||||
(
|
||||
"eq?",
|
||||
NativeFunc(|args| {
|
||||
arg_count!(2, args.len());
|
||||
|
||||
let lhs = args[0].borrow();
|
||||
let rhs = args[1].borrow();
|
||||
|
||||
match (lhs, rhs) {
|
||||
(Node::Int(lhs), Node::Int(rhs)) => Node::Boolean(lhs == rhs),
|
||||
(Node::Boolean(lhs), Node::Boolean(rhs)) => Node::Boolean(lhs == rhs),
|
||||
(Node::String(lhs), Node::String(rhs)) => Node::Boolean(lhs == rhs),
|
||||
(Node::Nil, Node::Nil) => Node::Boolean(true),
|
||||
(lhs, rhs) if lhs.get_type() == rhs.get_type() => {
|
||||
Node::Error("TypeError: expected int, boolean, or string.".to_string())
|
||||
}
|
||||
(lhs, rhs) => Node::Error(format!(
|
||||
"TypeError: expected pair of int, boolean, or string, got {} and {}.",
|
||||
lhs.get_type(),
|
||||
rhs.get_type()
|
||||
)),
|
||||
}
|
||||
}),
|
||||
),
|
||||
(
|
||||
"not",
|
||||
NativeFunc(|args| {
|
||||
arg_count!(1, args.len());
|
||||
|
||||
match args.into_iter().next().unwrap() {
|
||||
Node::Boolean(val) => Node::Boolean(!val),
|
||||
expr => Node::Error(format!("expected boolean, got {}", expr.get_type())),
|
||||
}
|
||||
}),
|
||||
),
|
||||
(
|
||||
"<",
|
||||
NativeFunc(|args| {
|
||||
arg_count!(2, args.len());
|
||||
|
||||
let mut args = args.into_iter();
|
||||
let lhs = args.next().unwrap();
|
||||
let rhs = args.next().unwrap();
|
||||
|
||||
let less_than = match (lhs, rhs) {
|
||||
(Node::Int(lhs), Node::Int(rhs)) => lhs < rhs,
|
||||
(Node::Float(lhs), Node::Float(rhs)) => lhs < rhs,
|
||||
(Node::Float(lhs), Node::Int(rhs)) => lhs < rhs as f64,
|
||||
(Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) < rhs,
|
||||
(lhs, rhs) => {
|
||||
return Node::Error(format!(
|
||||
"expected int or float, got {} and {}",
|
||||
lhs.get_type(),
|
||||
rhs.get_type()
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Node::Boolean(less_than)
|
||||
}),
|
||||
),
|
||||
(
|
||||
">",
|
||||
NativeFunc(|args| {
|
||||
arg_count!(2, args.len());
|
||||
|
||||
let mut args = args.into_iter();
|
||||
let lhs = args.next().unwrap();
|
||||
let rhs = args.next().unwrap();
|
||||
|
||||
let greater_than = match (lhs, rhs) {
|
||||
(Node::Int(lhs), Node::Int(rhs)) => lhs > rhs,
|
||||
(Node::Float(lhs), Node::Float(rhs)) => lhs > rhs,
|
||||
(Node::Float(lhs), Node::Int(rhs)) => lhs > rhs as f64,
|
||||
(Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) > rhs,
|
||||
(lhs, rhs) => {
|
||||
return Node::Error(format!(
|
||||
"expected int or float, got {} and {}",
|
||||
lhs.get_type(),
|
||||
rhs.get_type()
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Node::Boolean(greater_than)
|
||||
}),
|
||||
),
|
||||
(
|
||||
"<=",
|
||||
NativeFunc(|args| {
|
||||
arg_count!(2, args.len());
|
||||
|
||||
let mut args = args.into_iter();
|
||||
let lhs = args.next().unwrap();
|
||||
let rhs = args.next().unwrap();
|
||||
|
||||
let less_than_equal = match (lhs, rhs) {
|
||||
(Node::Int(lhs), Node::Int(rhs)) => lhs <= rhs,
|
||||
(Node::Float(lhs), Node::Float(rhs)) => lhs <= rhs,
|
||||
(Node::Float(lhs), Node::Int(rhs)) => lhs <= rhs as f64,
|
||||
(Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) <= rhs,
|
||||
(lhs, rhs) => {
|
||||
return Node::Error(format!(
|
||||
"expected int or float, got {} and {}",
|
||||
lhs.get_type(),
|
||||
rhs.get_type()
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Node::Boolean(less_than_equal)
|
||||
}),
|
||||
),
|
||||
(
|
||||
">=",
|
||||
NativeFunc(|args| {
|
||||
arg_count!(2, args.len());
|
||||
|
||||
let mut args = args.into_iter();
|
||||
let lhs = args.next().unwrap();
|
||||
let rhs = args.next().unwrap();
|
||||
|
||||
let greater_than_equal = match (lhs, rhs) {
|
||||
(Node::Int(lhs), Node::Int(rhs)) => lhs >= rhs,
|
||||
(Node::Float(lhs), Node::Float(rhs)) => lhs >= rhs,
|
||||
(Node::Float(lhs), Node::Int(rhs)) => lhs >= rhs as f64,
|
||||
(Node::Int(lhs), Node::Float(rhs)) => (lhs as f64) >= rhs,
|
||||
(lhs, rhs) => {
|
||||
return Node::Error(format!(
|
||||
"expected int or float, got {} and {}",
|
||||
lhs.get_type(),
|
||||
rhs.get_type()
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Node::Boolean(greater_than_equal)
|
||||
}),
|
||||
),
|
||||
// Strings
|
||||
(
|
||||
"str",
|
||||
|
|
|
@ -272,8 +272,6 @@ mod test {
|
|||
#[case("(eq? nil nil)", "true")]
|
||||
#[case("(not false)", "true")]
|
||||
#[case("(not true)", "false")]
|
||||
#[case("(not nil)", "false")]
|
||||
#[case("(not 1)", "false")]
|
||||
// Ordering
|
||||
#[case("(< 1 2)", "true")]
|
||||
#[case("(< 2 1)", "false")]
|
||||
|
|
Loading…
Reference in a new issue