Mal step 2

This commit is contained in:
Roman Godmaire 2023-09-19 08:17:13 -04:00
parent 96f822ace8
commit b9b833bf5c
6 changed files with 347 additions and 22 deletions

21
Cargo.lock generated
View file

@ -136,6 +136,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"rstest", "rstest",
"thiserror",
] ]
[[package]] [[package]]
@ -273,6 +274,26 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "thiserror"
version = "1.0.48"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.48"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.11" version = "1.0.11"

View file

@ -7,6 +7,7 @@ edition = "2021"
[dependencies] [dependencies]
anyhow = "1.0.75" anyhow = "1.0.75"
thiserror = "1.0.48"
[dev-dependencies] [dev-dependencies]
rstest = "0.18.2" rstest = "0.18.2"

177
src/env.rs Normal file
View file

@ -0,0 +1,177 @@
use std::{borrow::Borrow, collections::HashMap, rc::Rc};
use anyhow::Result;
use crate::eval::Error;
/// A runtime value produced by evaluating an AST node.
///
/// Values are shared through `Rc`, so collections and the environment can
/// hold the same expression without copying it.
#[derive(Debug, PartialEq)]
pub enum Expression {
// Values
/// A signed 64-bit integer.
Int(i64),
/// A boolean, displayed as `true` or `false`.
Boolean(bool),
/// A keyword; the stored name has no leading `:` (it is added back when displayed).
Keyword(String),
/// A string, displayed without surrounding quotes.
String(String),
/// The `nil` value.
Nil,
// Collections
/// An ordered sequence of expressions, displayed as `[a b c]`.
Vector(Vec<Rc<Expression>>),
/// A map keyed by the *displayed* form of each key expression
/// (the `hashmap` native function stringifies keys before inserting them).
HashMap(HashMap<String, Rc<Expression>>),
/// A function implemented in Rust. The evaluator calls `func` with the
/// already-evaluated arguments of the list it heads.
NativeFunc {
func: fn(args: Vec<Rc<Expression>>) -> Result<Rc<Expression>>,
},
}
impl std::fmt::Display for Expression {
    /// Renders the expression the way the REPL prints it.
    ///
    /// Scalars print as their literal value (keywords regain their leading
    /// `:`, strings print without quotes), vectors print as `[a b c]`, maps as
    /// `{k: v, k: v}` in the map's iteration order, and native functions as
    /// the fixed word `function`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Expression::Int(number) => write!(f, "{number}"),
            Expression::Boolean(flag) => write!(f, "{flag}"),
            Expression::Keyword(name) => write!(f, ":{name}"),
            Expression::String(text) => write!(f, "{text}"),
            Expression::Nil => write!(f, "nil"),
            Expression::Vector(items) => {
                // Render each element first, then glue them with single spaces.
                let rendered: Vec<String> = items.iter().map(|item| item.to_string()).collect();
                write!(f, "[{}]", rendered.join(" "))
            }
            Expression::HashMap(entries) => {
                let rendered: Vec<String> = entries
                    .iter()
                    .map(|(key, value)| format!("{key}: {value}"))
                    .collect();
                write!(f, "{{{}}}", rendered.join(", "))
            }
            Expression::NativeFunc { .. } => write!(f, "function"),
        }
    }
}
/// Maps a symbol name to the expression bound to it; the evaluator looks
/// symbols up here.
pub type Environment = HashMap<String, Rc<Expression>>;
pub fn core_environment() -> Environment {
[
// Arithmetic operations
(
"+".to_string(),
Expression::NativeFunc {
func: |args| {
let res = args
.into_iter()
.reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) {
(Expression::Int(lhs), Expression::Int(rhs)) => {
Expression::Int(lhs + rhs).into()
}
_ => todo!(),
})
.unwrap_or(Rc::new(Expression::Int(0)));
Ok(res)
},
}
.into(),
),
(
"-".to_string(),
Expression::NativeFunc {
func: |args| {
let res = args
.into_iter()
.reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) {
(Expression::Int(lhs), Expression::Int(rhs)) => {
Expression::Int(lhs - rhs).into()
}
_ => todo!(),
})
.unwrap_or(Rc::new(Expression::Int(0)));
Ok(res)
},
}
.into(),
),
(
"*".to_string(),
Expression::NativeFunc {
func: |args| {
let res = args
.into_iter()
.reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) {
(Expression::Int(lhs), Expression::Int(rhs)) => {
Expression::Int(lhs * rhs).into()
}
_ => todo!(),
})
.unwrap_or(Rc::new(Expression::Int(0)));
Ok(res)
},
}
.into(),
),
(
"/".to_string(),
Expression::NativeFunc {
func: |args| {
let res = args
.into_iter()
.reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) {
(Expression::Int(lhs), Expression::Int(rhs)) => {
Expression::Int(lhs / rhs).into()
}
_ => todo!(),
})
.unwrap_or(Rc::new(Expression::Int(0)));
Ok(res)
},
}
.into(),
),
// Collections
(
"vector".to_string(),
Expression::NativeFunc {
func: |args| Ok(Rc::new(Expression::Vector(args))),
}
.into(),
),
(
"hashmap".to_string(),
Expression::NativeFunc {
func: |args| {
if args.len() % 2 != 0 {
Err(Error::MismatchedArgCount)?
}
let mut index = -1;
let (keys, values): (Vec<_>, Vec<_>) = args.into_iter().partition(|_| {
index += 1;
index % 2 == 0
});
let res = keys
.into_iter()
// We turn the keys into strings because they're hashable
// This feels so hacky, but ¯\_(ツ)_/¯
.map(|key| key.to_string())
.zip(values.into_iter())
.collect();
Ok(Rc::new(Expression::HashMap(res)))
},
}
.into(),
),
]
.into()
}

128
src/eval.rs Normal file
View file

@ -0,0 +1,128 @@
use std::{borrow::Borrow, rc::Rc};
use anyhow::Result;
use thiserror::Error;
use crate::{
env::{Environment, Expression},
parser::Node,
};
// Shared instances of the constant expressions. Evaluating `true`, `false`
// or `nil` clones the matching `Rc` from here. They are thread-local rather
// than plain statics because `Rc` is not `Sync`.
thread_local! {
static TRUE: Rc<Expression> = Rc::new(Expression::Boolean(true));
static FALSE: Rc<Expression> = Rc::new(Expression::Boolean(false));
static NIL: Rc<Expression> = Rc::new(Expression::Nil);
}
/// Errors raised while evaluating an AST.
#[derive(Debug, Error)]
pub enum Error {
/// A symbol was evaluated but has no binding in the environment.
#[error("could not find symbol in environment")]
NotInEnv,
/// A list was evaluated whose first element is missing (empty list) or
/// did not evaluate to a native function.
#[error("expression does not have a valid operator")]
InvalidOperator,
/// A native function received a number of arguments it cannot accept
/// (e.g. `hashmap` with an odd number of arguments).
#[error("incorrect number of arguments passed to function")]
MismatchedArgCount,
}
/// Evaluates every top-level node of `ast`, in order, against `env`.
///
/// # Errors
///
/// Stops at the first node that fails to evaluate and returns its error;
/// later nodes are not evaluated.
pub fn eval(env: &Environment, ast: Vec<Node>) -> Result<Vec<Rc<Expression>>> {
    ast.into_iter()
        .map(|node| eval_ast_node(env, node))
        .collect()
}
fn eval_ast_node(env: &Environment, ast_node: Node) -> Result<Rc<Expression>> {
let expr = match ast_node {
Node::List(list) => {
let mut res = Vec::new();
for node in list {
res.push(eval_ast_node(env, node)?);
}
let mut list = res.into_iter();
let operator = match list.next() {
Some(rc) => match rc.borrow() {
Expression::NativeFunc { func } => *func,
_ => Err(Error::InvalidOperator)?,
},
None => Err(Error::InvalidOperator)?,
};
operator(list.collect())?
}
Node::Symbol(sym) => env.get(&sym).ok_or(Error::NotInEnv)?.to_owned(),
Node::Int(i) => Expression::Int(i).into(),
Node::Keyword(val) => Expression::Keyword(val).into(),
Node::String(s) => Expression::String(s).into(),
Node::True => TRUE.with(|val| val.clone()),
Node::False => FALSE.with(|val| val.clone()),
Node::Nil => NIL.with(|val| val.clone()),
};
Ok(expr)
}
// End-to-end tests for the evaluator: each case is lexed, parsed and
// evaluated against the core environment.
#[cfg(test)]
mod test {
// `core_environment` is reachable at the crate root because main.rs imports
// it there with `use crate::env::core_environment`.
use crate::core_environment;
use crate::lexer;
use crate::parser;
use super::*;
use rstest::rstest;
// Each case is (source text, expected display output). When the input holds
// several top-level forms, their results are joined with newlines.
#[rstest]
// Raw values evaluate to themselves
#[case("1", "1")]
#[case("\"uwu\"", "uwu")]
#[case(":owo", ":owo")]
// Binary integer arithmetic
#[case("(+ 1 2)", "3")]
#[case("(- 5 1)", "4")]
#[case("(* 8 9)", "72")]
#[case("(/ 86 2)", "43")]
// Variadic and nested native function calls
#[case("(+ 1 2 (- 3 4))", "2")]
#[case("(vector 1 2 3)", "[1 2 3]")]
// Native functions called with no arguments
// (every arithmetic operator currently yields 0, including `*` and `/`)
#[case("(+)", "0")]
#[case("(-)", "0")]
#[case("(*)", "0")]
#[case("(/)", "0")]
#[case("(vector)", "[]")]
// Collection literals; the parser rewrites them into `vector` / `hashmap` calls
#[case("[]", "[]")]
#[case("[1 2]", "[1 2]")]
#[case("[1 (+ 1 2)]", "[1 3]")]
#[case("{}", "{}")]
// Map keys are stored in their displayed form, hence the `:a:` in the output
#[case("{:a \"uwu\"}", "{:a: uwu}")]
fn test_evaluator(#[case] input: &str, #[case] expected: &str) {
let env = &core_environment();
let tokens = lexer::read(input).unwrap();
let ast = parser::parse(tokens).unwrap();
let res = eval(env, ast)
.unwrap()
.into_iter()
.map(|elem| elem.to_string())
.reduce(|lhs, rhs| format!("{lhs}\n{rhs}"))
.unwrap();
assert_eq!(res, expected);
}
// Inputs that lex and parse successfully but must fail during evaluation.
#[rstest]
// Odd number of arguments to `hashmap`
#[case("{:a}")]
// Symbol that is not bound in the environment
#[case("(not-a-func :uwu)")]
fn test_evaluator_fail(#[case] input: &str) {
let env = &core_environment();
let tokens = lexer::read(input).unwrap();
let ast = parser::parse(tokens).unwrap();
let res = eval(env, ast);
assert!(res.is_err())
}
}

View file

@ -1,9 +1,14 @@
use std::io::{self, Write}; use std::io::{self, Write};
use crate::env::core_environment;
mod env;
mod eval;
mod lexer; mod lexer;
mod parser; mod parser;
fn main() { fn main() {
let env = core_environment();
let mut input = String::new(); let mut input = String::new();
println!("MAL -- REPL"); println!("MAL -- REPL");
@ -22,14 +27,12 @@ fn main() {
let tokens = lexer::read(&input).unwrap(); let tokens = lexer::read(&input).unwrap();
let ast = parser::parse(tokens).unwrap(); let ast = parser::parse(tokens).unwrap();
let res = eval(ast); let res = eval::eval(&env, ast).unwrap();
println!("{res}"); for expr in res {
println!("{expr}")
}
input.clear(); input.clear();
} }
} }
fn eval(input: Vec<parser::Node>) -> String {
format!("{input:?}")
}

View file

@ -7,8 +7,6 @@ use crate::lexer::Token;
#[derive(Debug, PartialEq, PartialOrd)] #[derive(Debug, PartialEq, PartialOrd)]
pub enum Node { pub enum Node {
List(Vec<Node>), List(Vec<Node>),
Vector(Vec<Node>),
HashMap(Vec<Node>),
Symbol(String), Symbol(String),
Keyword(String), Keyword(String),
@ -74,7 +72,12 @@ fn next_statement(tokens: &mut Peekable<IntoIter<Token>>) -> Result<Option<Node>
} }
fn read_list(tokens: &mut Peekable<IntoIter<Token>>, closer: Token) -> Result<Node> { fn read_list(tokens: &mut Peekable<IntoIter<Token>>, closer: Token) -> Result<Node> {
let mut list = Vec::new(); let mut list = match closer {
Token::RightParen => Vec::new(),
Token::RightBracket => vec![Node::Symbol("vector".into())],
Token::RightBrace => vec![Node::Symbol("hashmap".into())],
_ => bail!("unreachable"),
};
loop { loop {
if tokens.peek() == Some(&closer) { if tokens.peek() == Some(&closer) {
@ -98,17 +101,7 @@ fn read_list(tokens: &mut Peekable<IntoIter<Token>>, closer: Token) -> Result<No
} }
} }
match closer { Ok(Node::List(list))
Token::RightParen => Ok(Node::List(list)),
Token::RightBracket => Ok(Node::Vector(list)),
Token::RightBrace => Ok(Node::HashMap(list)),
// This should theoretically be unreachable
_ => bail!(
"invalid collection type using closer {:?}. This is a bug; please file a bug report",
closer
),
}
} }
fn read_quote(tokens: &mut Peekable<IntoIter<Token>>, quote_type: &str) -> Result<Node> { fn read_quote(tokens: &mut Peekable<IntoIter<Token>>, quote_type: &str) -> Result<Node> {
@ -142,11 +135,13 @@ mod test {
Node::Int(10), Node::Int(10),
Node::Int(2)])])] Node::Int(2)])])]
#[case("[10 2]", vec![ #[case("[10 2]", vec![
Node::Vector(vec![ Node::List(vec![
Node::Symbol("vector".into()),
Node::Int(10), Node::Int(10),
Node::Int(2)])])] Node::Int(2)])])]
#[case("{10 2}", vec![ #[case("{10 2}", vec![
Node::HashMap(vec![ Node::List(vec![
Node::Symbol("hashmap".into()),
Node::Int(10), Node::Int(10),
Node::Int(2)])])] Node::Int(2)])])]
#[case("(+ - * /)", vec![ #[case("(+ - * /)", vec![