From b9b833bf5c28ac4fbd6c9f5b0619748a03fbe87c Mon Sep 17 00:00:00 2001 From: Roman Godmaire Date: Tue, 19 Sep 2023 08:17:13 -0400 Subject: [PATCH] Mal step 2 --- Cargo.lock | 21 ++++++ Cargo.toml | 1 + src/env.rs | 177 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/eval.rs | 128 ++++++++++++++++++++++++++++++++++++ src/main.rs | 15 +++-- src/parser.rs | 27 ++++---- 6 files changed, 347 insertions(+), 22 deletions(-) create mode 100644 src/env.rs create mode 100644 src/eval.rs diff --git a/Cargo.lock b/Cargo.lock index c840342..b1819c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -136,6 +136,7 @@ version = "0.1.0" dependencies = [ "anyhow", "rstest", + "thiserror", ] [[package]] @@ -273,6 +274,26 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "thiserror" +version = "1.0.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicode-ident" version = "1.0.11" diff --git a/Cargo.toml b/Cargo.toml index 04df988..820afc2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] anyhow = "1.0.75" +thiserror = "1.0.48" [dev-dependencies] rstest = "0.18.2" diff --git a/src/env.rs b/src/env.rs new file mode 100644 index 0000000..10ac386 --- /dev/null +++ b/src/env.rs @@ -0,0 +1,177 @@ +use std::{borrow::Borrow, collections::HashMap, rc::Rc}; + +use anyhow::Result; + +use crate::eval::Error; + +#[derive(Debug, PartialEq)] +pub enum Expression { + // Values + Int(i64), + Boolean(bool), + Keyword(String), + String(String), + Nil, + + // Collections + Vector(Vec>), + HashMap(HashMap>), + + NativeFunc { + func: fn(args: Vec>) -> Result>, + }, +} + +impl std::fmt::Display for Expression { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Expression::Int(val) => write!(f, "{}", val), + Expression::Boolean(true) => write!(f, "true"), + Expression::Boolean(false) => write!(f, "false"), + Expression::Keyword(val) => write!(f, ":{}", val), + Expression::String(val) => write!(f, "{}", val), + Expression::Nil => write!(f, "nil"), + + Expression::Vector(vec) => { + let s = vec + .iter() + .map(|elem| elem.to_string()) + .reduce(|lhs, rhs| format!("{lhs} {rhs}")) + .unwrap_or_default(); + + write!(f, "[{s}]") + } + Expression::HashMap(map) => { + let res = map + .into_iter() + .map(|(k, v)| format!("{k}: {v}")) + .reduce(|lhs, rhs| format!("{lhs}, {rhs}")) + .unwrap_or_default(); + + write!(f, "{{{res}}}") + } + + Expression::NativeFunc { .. } => write!(f, "function"), + } + } +} + +pub type Environment = HashMap>; + +pub fn core_environment() -> Environment { + [ + // Arithmetic operations + ( + "+".to_string(), + Expression::NativeFunc { + func: |args| { + let res = args + .into_iter() + .reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) { + (Expression::Int(lhs), Expression::Int(rhs)) => { + Expression::Int(lhs + rhs).into() + } + _ => todo!(), + }) + .unwrap_or(Rc::new(Expression::Int(0))); + + Ok(res) + }, + } + .into(), + ), + ( + "-".to_string(), + Expression::NativeFunc { + func: |args| { + let res = args + .into_iter() + .reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) { + (Expression::Int(lhs), Expression::Int(rhs)) => { + Expression::Int(lhs - rhs).into() + } + _ => todo!(), + }) + .unwrap_or(Rc::new(Expression::Int(0))); + + Ok(res) + }, + } + .into(), + ), + ( + "*".to_string(), + Expression::NativeFunc { + func: |args| { + let res = args + .into_iter() + .reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) { + (Expression::Int(lhs), Expression::Int(rhs)) => { + Expression::Int(lhs * rhs).into() + } + _ => todo!(), + }) + .unwrap_or(Rc::new(Expression::Int(0))); + + Ok(res) + }, + } + .into(), + ), + ( + "/".to_string(), + Expression::NativeFunc { + func: |args| { + let res = args + .into_iter() + .reduce(|lhs, rhs| match (lhs.borrow(), rhs.borrow()) { + (Expression::Int(lhs), Expression::Int(rhs)) => { + Expression::Int(lhs / rhs).into() + } + _ => todo!(), + }) + .unwrap_or(Rc::new(Expression::Int(0))); + + Ok(res) + }, + } + .into(), + ), + // Collections + ( + "vector".to_string(), + Expression::NativeFunc { + func: |args| Ok(Rc::new(Expression::Vector(args))), + } + .into(), + ), + ( + "hashmap".to_string(), + Expression::NativeFunc { + func: |args| { + if args.len() % 2 != 0 { + Err(Error::MismatchedArgCount)? + } + + let mut index = -1; + let (keys, values): (Vec<_>, Vec<_>) = args.into_iter().partition(|_| { + index += 1; + index % 2 == 0 + }); + + let res = keys + .into_iter() + // We turn the keys into strings because they're hashable + // This feels so hacky, but ¯\_(ツ)_/¯ + .map(|key| key.to_string()) + .zip(values.into_iter()) + .collect(); + + Ok(Rc::new(Expression::HashMap(res))) + }, + } + .into(), + ), + ] + .into() +} diff --git a/src/eval.rs b/src/eval.rs new file mode 100644 index 0000000..2256572 --- /dev/null +++ b/src/eval.rs @@ -0,0 +1,128 @@ +use std::{borrow::Borrow, rc::Rc}; + +use anyhow::Result; +use thiserror::Error; + +use crate::{ + env::{Environment, Expression}, + parser::Node, +}; + +thread_local! { + static TRUE: Rc = Rc::new(Expression::Boolean(true)); + static FALSE: Rc = Rc::new(Expression::Boolean(false)); + static NIL: Rc = Rc::new(Expression::Nil); +} + +#[derive(Debug, Error)] +pub enum Error { + #[error("could not find symbol in environment")] + NotInEnv, + #[error("expression does not have a valid operator")] + InvalidOperator, + #[error("incorrect number of arguments passed to function")] + MismatchedArgCount, +} + +pub fn eval(env: &Environment, ast: Vec) -> Result>> { + let mut res = Vec::new(); + + for node in ast { + res.push(eval_ast_node(env, node)?); + } + + Ok(res) +} + +fn eval_ast_node(env: &Environment, ast_node: Node) -> Result> { + let expr = match ast_node { + Node::List(list) => { + let mut res = Vec::new(); + for node in list { + res.push(eval_ast_node(env, node)?); + } + + let mut list = res.into_iter(); + let operator = match list.next() { + Some(rc) => match rc.borrow() { + Expression::NativeFunc { func } => *func, + _ => Err(Error::InvalidOperator)?, + }, + None => Err(Error::InvalidOperator)?, + }; + + operator(list.collect())? + } + Node::Symbol(sym) => env.get(&sym).ok_or(Error::NotInEnv)?.to_owned(), + + Node::Int(i) => Expression::Int(i).into(), + Node::Keyword(val) => Expression::Keyword(val).into(), + Node::String(s) => Expression::String(s).into(), + + Node::True => TRUE.with(|val| val.clone()), + Node::False => FALSE.with(|val| val.clone()), + Node::Nil => NIL.with(|val| val.clone()), + }; + + Ok(expr) +} + +#[cfg(test)] +mod test { + use crate::core_environment; + use crate::lexer; + use crate::parser; + + use super::*; + use rstest::rstest; + + #[rstest] + // Basic test of raw values + #[case("1", "1")] + #[case("\"uwu\"", "uwu")] + #[case(":owo", ":owo")] + #[case("(+ 1 2)", "3")] + #[case("(- 5 1)", "4")] + #[case("(* 8 9)", "72")] + #[case("(/ 86 2)", "43")] + // Native functions + #[case("(+ 1 2 (- 3 4))", "2")] + #[case("(vector 1 2 3)", "[1 2 3]")] + // Native functions defaults + #[case("(+)", "0")] + #[case("(-)", "0")] + #[case("(*)", "0")] + #[case("(/)", "0")] + #[case("(vector)", "[]")] + // Collections + #[case("[]", "[]")] + #[case("[1 2]", "[1 2]")] + #[case("[1 (+ 1 2)]", "[1 3]")] + #[case("{}", "{}")] + #[case("{:a \"uwu\"}", "{:a: uwu}")] + fn test_evaluator(#[case] input: &str, #[case] expected: &str) { + let env = &core_environment(); + let tokens = lexer::read(input).unwrap(); + let ast = parser::parse(tokens).unwrap(); + let res = eval(env, ast) + .unwrap() + .into_iter() + .map(|elem| elem.to_string()) + .reduce(|lhs, rhs| format!("{lhs}\n{rhs}")) + .unwrap(); + + assert_eq!(res, expected); + } + + #[rstest] + #[case("{:a}")] + #[case("(not-a-func :uwu)")] + fn test_evaluator_fail(#[case] input: &str) { + let env = &core_environment(); + let tokens = lexer::read(input).unwrap(); + let ast = parser::parse(tokens).unwrap(); + let res = eval(env, ast); + + assert!(res.is_err()) + } +} diff --git a/src/main.rs b/src/main.rs index fa028d6..e45e2df 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,14 @@ use std::io::{self, Write}; +use crate::env::core_environment; + +mod env; +mod eval; mod lexer; mod parser; fn main() { + let env = core_environment(); let mut input = String::new(); println!("MAL -- REPL"); @@ -22,14 +27,12 @@ fn main() { let tokens = lexer::read(&input).unwrap(); let ast = parser::parse(tokens).unwrap(); - let res = eval(ast); + let res = eval::eval(&env, ast).unwrap(); - println!("{res}"); + for expr in res { + println!("{expr}") + } input.clear(); } } - -fn eval(input: Vec) -> String { - format!("{input:?}") -} diff --git a/src/parser.rs b/src/parser.rs index 9461282..e588f0b 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -7,8 +7,6 @@ use crate::lexer::Token; #[derive(Debug, PartialEq, PartialOrd)] pub enum Node { List(Vec), - Vector(Vec), - HashMap(Vec), Symbol(String), Keyword(String), @@ -74,7 +72,12 @@ fn next_statement(tokens: &mut Peekable>) -> Result } fn read_list(tokens: &mut Peekable>, closer: Token) -> Result { - let mut list = Vec::new(); + let mut list = match closer { + Token::RightParen => Vec::new(), + Token::RightBracket => vec![Node::Symbol("vector".into())], + Token::RightBrace => vec![Node::Symbol("hashmap".into())], + _ => bail!("unreachable"), + }; loop { if tokens.peek() == Some(&closer) { @@ -98,17 +101,7 @@ fn read_list(tokens: &mut Peekable>, closer: Token) -> Result Ok(Node::List(list)), - Token::RightBracket => Ok(Node::Vector(list)), - Token::RightBrace => Ok(Node::HashMap(list)), - - // This should theoretically be unreachable - _ => bail!( - "invalid collection type using closer {:?}. This is a bug; please file a bug report", - closer - ), - } + Ok(Node::List(list)) } fn read_quote(tokens: &mut Peekable>, quote_type: &str) -> Result { @@ -142,11 +135,13 @@ mod test { Node::Int(10), Node::Int(2)])])] #[case("[10 2]", vec![ - Node::Vector(vec![ + Node::List(vec![ + Node::Symbol("vector".into()), Node::Int(10), Node::Int(2)])])] #[case("{10 2}", vec![ - Node::HashMap(vec![ + Node::List(vec![ + Node::Symbol("hashmap".into()), Node::Int(10), Node::Int(2)])])] #[case("(+ - * /)", vec![