aboutsummaryrefslogtreecommitdiff
path: root/calc
diff options
context:
space:
mode:
authorNathan Perry <np@nathanperry.dev>2024-08-07 07:39:16 -0400
committerNathan Perry <np@nathanperry.dev>2024-08-07 07:39:16 -0400
commitb58b03c22d637fe8f7200edb6953325bf359544d (patch)
tree606865b7afc4498b02da14895723b2b0ed49760e /calc
parent96c197bde0f83d8b99ec66238856c76b41bfd5e1 (diff)
split calc into separate subcrate
Diffstat (limited to 'calc')
-rw-r--r--calc/Cargo.toml14
-rw-r--r--calc/src/calc.pest79
-rw-r--r--calc/src/lib.rs189
3 files changed, 282 insertions, 0 deletions
diff --git a/calc/Cargo.toml b/calc/Cargo.toml
new file mode 100644
index 0000000..fe2ecc3
--- /dev/null
+++ b/calc/Cargo.toml
@@ -0,0 +1,14 @@
+[package]
+name = "thulani_calc"
+version = "0.1.0"
+authors.workspace = true
+edition.workspace = true
+
+[dependencies]
+pest = "2.7"
+pest_derive = "2.7"
+thiserror = "1.0"
+statrs = "0.16"
+
+lazy_static.workspace = true
+rand.workspace = true
diff --git a/calc/src/calc.pest b/calc/src/calc.pest
new file mode 100644
index 0000000..07eeddb
--- /dev/null
+++ b/calc/src/calc.pest
@@ -0,0 +1,79 @@
+num = {
+ hex
+ | oct
+ | binary
+ | float
+}
+
+float = @{ int ~ ( "." ~ ASCII_DIGIT*)? ~ (^"e" ~ int)? }
+ int = { "-"? ~ ASCII_DIGIT+ }
+
+hex = @{ "0x" ~ ASCII_HEX_DIGIT+ }
+oct = @{ "0o" ~ ASCII_OCT_DIGIT+ }
+binary = @{ "0b" ~ ASCII_BIN_DIGIT+ }
+
+infix = _{ add | sub | mul | div | modulo }
+ add = { "+" }
+ sub = { "-" }
+ modulo = { "%" | "mod" }
+ mul = { "*" }
+ div = { "/" }
+
+tight_infix = _{ dice | pow }
+ dice = { "d" }
+ pow = { "^" }
+
+trig = _{ sin | cos | tan | asin | acos | atan }
+ sin = { "sin" }
+ cos = { "cos" }
+ tan = { "tan" }
+ asin = { "asin" }
+ acos = { "acos" }
+ atan = { "atan" }
+
+htrig = _{ sinh | cosh | tanh | asinh | acosh | atanh }
+ sinh = { "sinh" }
+ cosh = { "cosh" }
+ tanh = { "tanh" }
+ asinh = { "asinh" }
+ acosh = { "acosh" }
+ atanh = { "atanh" }
+
+unary_prefix = _{ log | sqrt | sgn | htrig | trig | exp | abs | ceil | floor | round }
+ log = { "log" | "ln" }
+ sqrt = { "sqrt" }
+ sgn = { "sgn" }
+ exp = { "exp" }
+ abs = { "abs" }
+ ceil = { "ceil" }
+ floor = { "floor" }
+ round = { "round" }
+
+binary_prefix = _{ min | max | atan2 }
+ min = { "min" }
+ max = { "max" }
+ atan2 = { "atan2" }
+
+suffix = _{ factorial }
+ factorial = { "!" }
+
+term = _{ num | "(" ~ expr ~ ")" }
+
+suffix_expr = { term ~ suffix }
+unary_expr = ${ unary_prefix ~ ws+ ~ outfix_expr }
+binary_expr = ${ binary_prefix ~ ws+ ~ outfix_expr ~ ws+ ~ outfix_expr }
+
+tight = _{ (suffix_expr | term) ~ (tight_infix ~ tight)* }
+
+expr = { outfix_expr ~ (infix ~ outfix_expr)* }
+
+outfix_expr = _{
+ tight |
+ binary_expr |
+ unary_expr
+}
+
+calc = _{ SOI ~ expr ~ EOI }
+
+ws = _{ " " | "\t" | "\n" }
+WHITESPACE = _{ ws }
diff --git a/calc/src/lib.rs b/calc/src/lib.rs
new file mode 100644
index 0000000..05b83a9
--- /dev/null
+++ b/calc/src/lib.rs
@@ -0,0 +1,189 @@
+use lazy_static::lazy_static;
+use pest::iterators::Pair;
+use rand::Rng;
+
+#[derive(pest_derive::Parser)]
+#[grammar = "calc.pest"]
+pub struct Calc;
+
+#[derive(Copy, Clone, thiserror::Error, Debug, PartialEq, Eq, Hash)]
+pub enum CalcError {
+ #[error("pest was unable to parse the input")]
+ Pest,
+
+ #[error("invalid number format")]
+ NumberFormat,
+
+ #[error("bad argument count")]
+ ArgCount,
+}
+
+impl Calc {
+ pub fn eval<S: AsRef<str>>(s: S) -> Result<f64, CalcError> {
+ use pest::{
+ iterators::Pairs,
+ pratt_parser::PrattParser,
+ Parser,
+ };
+
+ use self::Rule::*;
+
+ lazy_static! {
+ static ref CLIMBER: PrattParser<Rule> = {
+ use pest::pratt_parser::{
+ Assoc::*,
+ Op,
+ };
+
+ PrattParser::new()
+ .op(Op::infix(add, Left) | Op::infix(sub, Left) | Op::infix(modulo, Left))
+ .op(Op::infix(mul, Left) | Op::infix(div, Left))
+ .op(Op::infix(dice, Left))
+ .op(Op::infix(pow, Right))
+ .op(Op::postfix(EOI)) // discarded below
+ };
+ }
+
+ let result = Calc::parse(calc, s.as_ref()).map_err(|_| CalcError::Pest)?;
+
+ fn eval_single_pair(pair: Pair<Rule>) -> Result<f64, CalcError> {
+ let result = match pair.as_rule() {
+ oct | hex | binary => {
+ let base = match pair.as_rule() {
+ hex => 16,
+ oct => 8,
+ binary => 2,
+ _ => unreachable!(),
+ };
+
+ u64::from_str_radix(&pair.as_str()[2..], base)
+ .map_err(|_| CalcError::NumberFormat)? as f64
+ },
+ float => pair.as_str().parse::<f64>().map_err(|_| CalcError::NumberFormat)?,
+ expr | num => eval_expr(pair.into_inner())?,
+ unary_expr => {
+ let mut p = pair.into_inner();
+
+ let op = p.next().ok_or(CalcError::ArgCount)?;
+ let arg = eval_expr(p)?;
+
+ match op.as_rule() {
+ log => arg.ln(),
+ sqrt => arg.sqrt(),
+ sgn => arg.signum(),
+
+ sin => arg.sin(),
+ cos => arg.cos(),
+ tan => arg.tan(),
+ asin => arg.asin(),
+ acos => arg.acos(),
+ atan => arg.atan(),
+
+ sinh => arg.sinh(),
+ cosh => arg.cosh(),
+ tanh => arg.tanh(),
+ asinh => arg.asinh(),
+ acosh => arg.acosh(),
+ atanh => arg.atanh(),
+
+ exp => arg.exp(),
+ abs => arg.abs(),
+ ceil => arg.ceil(),
+ floor => arg.floor(),
+ round => arg.round(),
+ _ => unreachable!(),
+ }
+ },
+ binary_expr => {
+ let mut p = pair.into_inner();
+
+ let op = p.next().ok_or(CalcError::ArgCount)?;
+
+ let arg1 = eval_single_pair(p.next().ok_or(CalcError::ArgCount)?)?;
+ let arg2 = eval_single_pair(p.next().ok_or(CalcError::ArgCount)?)?;
+
+ assert!(p.next().is_none());
+
+ match op.as_rule() {
+ min => arg1.min(arg2),
+ max => arg1.max(arg2),
+ atan2 => arg1.atan2(arg2),
+ _ => unreachable!(),
+ }
+ },
+ suffix_expr => {
+ let mut p = pair.into_inner();
+
+ let arg = eval_expr(p.next().ok_or(CalcError::ArgCount)?.into_inner())?;
+ let op = p.next().ok_or(CalcError::ArgCount)?;
+
+ assert!(p.next().is_none());
+
+ match op.as_rule() {
+ factorial => statrs::function::gamma::gamma(arg + 1.),
+ _ => unreachable!(),
+ }
+ },
+ _ => unreachable!(),
+ };
+
+ Ok(result)
+ }
+
+ pub fn eval_expr(p: Pairs<Rule>) -> Result<f64, CalcError> {
+ CLIMBER
+ .map_primary(eval_single_pair)
+ .map_infix(|lhs, op, rhs| {
+ let lhs = lhs?;
+ let rhs = rhs?;
+
+ let result = match op.as_rule() {
+ add => lhs + rhs,
+ sub => lhs - rhs,
+ mul => lhs * rhs,
+ div => lhs / rhs,
+ pow => lhs.powf(rhs),
+ dice => {
+ let dice_count = lhs as usize;
+ let dice_faces = rhs as usize;
+
+ let mut rng = rand::thread_rng();
+ (0..dice_count)
+ .map(|_| rng.gen_range(1..(dice_faces + 1)))
+ .sum::<usize>() as f64
+ },
+ _ => unreachable!(),
+ };
+
+ Ok(result)
+ })
+ .map_postfix(|arg, _post| arg) // discard EOI
+ .parse(p)
+ }
+
+ eval_expr(result)
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_calc_basics() {
+ assert_eq!(3., Calc::eval("1 + 2").unwrap());
+ assert_eq!(3.0f64.ln(), Calc::eval("log 3").unwrap());
+ assert!(6. - Calc::eval("3!").unwrap() < 0.0001);
+ assert_eq!(3., Calc::eval("max 3 2").unwrap());
+ }
+
+ #[test]
+ fn test_binary_unary() {
+ assert_eq!(3.0f64.ln(), Calc::eval("max log 3 log 2").unwrap());
+ }
+
+ #[test]
+ fn test_prefix_suffix() {
+ assert!(6. - Calc::eval("abs 3!").unwrap() < 0.0001);
+ }
+}