diff options
| author | Nathan Perry <np@nathanperry.dev> | 2024-08-07 07:39:16 -0400 |
|---|---|---|
| committer | Nathan Perry <np@nathanperry.dev> | 2024-08-07 07:39:16 -0400 |
| commit | b58b03c22d637fe8f7200edb6953325bf359544d (patch) | |
| tree | 606865b7afc4498b02da14895723b2b0ed49760e /calc | |
| parent | 96c197bde0f83d8b99ec66238856c76b41bfd5e1 (diff) | |
split calc into separate subcrate
Diffstat (limited to 'calc')
| -rw-r--r-- | calc/Cargo.toml | 14 | ||||
| -rw-r--r-- | calc/src/calc.pest | 79 | ||||
| -rw-r--r-- | calc/src/lib.rs | 189 |
3 files changed, 282 insertions, 0 deletions
diff --git a/calc/Cargo.toml b/calc/Cargo.toml new file mode 100644 index 0000000..fe2ecc3 --- /dev/null +++ b/calc/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "thulani_calc" +version = "0.1.0" +authors.workspace = true +edition.workspace = true + +[dependencies] +pest = "2.7" +pest_derive = "2.7" +thiserror = "1.0" +statrs = "0.16" + +lazy_static.workspace = true +rand.workspace = true diff --git a/calc/src/calc.pest b/calc/src/calc.pest new file mode 100644 index 0000000..07eeddb --- /dev/null +++ b/calc/src/calc.pest @@ -0,0 +1,79 @@ +num = { + hex + | oct + | binary + | float +} + +float = @{ int ~ ( "." ~ ASCII_DIGIT*)? ~ (^"e" ~ int)? } + int = { "-"? ~ ASCII_DIGIT+ } + +hex = @{ "0x" ~ ASCII_HEX_DIGIT+ } +oct = @{ "0o" ~ ASCII_OCT_DIGIT+ } +binary = @{ "0b" ~ ASCII_BIN_DIGIT+ } + +infix = _{ add | sub | mul | div | modulo } + add = { "+" } + sub = { "-" } + modulo = { "%" | "mod" } + mul = { "*" } + div = { "/" } + +tight_infix = _{ dice | pow } + dice = { "d" } + pow = { "^" } + +trig = _{ sin | cos | tan | asin | acos | atan } + sin = { "sin" } + cos = { "cos" } + tan = { "tan" } + asin = { "asin" } + acos = { "acos" } + atan = { "atan" } + +htrig = _{ sinh | cosh | tanh | asinh | acosh | atanh } + sinh = { "sinh" } + cosh = { "cosh" } + tanh = { "tanh" } + asinh = { "asinh" } + acosh = { "acosh" } + atanh = { "atanh" } + +unary_prefix = _{ log | sqrt | sgn | htrig | trig | exp | abs | ceil | floor | round } + log = { "log" | "ln" } + sqrt = { "sqrt" } + sgn = { "sgn" } + exp = { "exp" } + abs = { "abs" } + ceil = { "ceil" } + floor = { "floor" } + round = { "round" } + +binary_prefix = _{ min | max | atan2 } + min = { "min" } + max = { "max" } + atan2 = { "atan2" } + +suffix = _{ factorial } + factorial = { "!" } + +term = _{ num | "(" ~ expr ~ ")" } + +suffix_expr = { term ~ suffix } +unary_expr = ${ unary_prefix ~ ws+ ~ outfix_expr } +binary_expr = ${ binary_prefix ~ ws+ ~ outfix_expr ~ ws+ ~ outfix_expr } + +tight = _{ (suffix_expr | term) ~ (tight_infix ~ tight)* } + +expr = { outfix_expr ~ (infix ~ outfix_expr)* } + +outfix_expr = _{ + tight | + binary_expr | + unary_expr +} + +calc = _{ SOI ~ expr ~ EOI } + +ws = _{ " " | "\t" | "\n" } +WHITESPACE = _{ ws } diff --git a/calc/src/lib.rs b/calc/src/lib.rs new file mode 100644 index 0000000..05b83a9 --- /dev/null +++ b/calc/src/lib.rs @@ -0,0 +1,189 @@ +use lazy_static::lazy_static; +use pest::iterators::Pair; +use rand::Rng; + +#[derive(pest_derive::Parser)] +#[grammar = "calc.pest"] +pub struct Calc; + +#[derive(Copy, Clone, thiserror::Error, Debug, PartialEq, Eq, Hash)] +pub enum CalcError { + #[error("pest was unable to parse the input")] + Pest, + + #[error("invalid number format")] + NumberFormat, + + #[error("bad argument count")] + ArgCount, +} + +impl Calc { + pub fn eval<S: AsRef<str>>(s: S) -> Result<f64, CalcError> { + use pest::{ + iterators::Pairs, + pratt_parser::PrattParser, + Parser, + }; + + use self::Rule::*; + + lazy_static! { + static ref CLIMBER: PrattParser<Rule> = { + use pest::pratt_parser::{ + Assoc::*, + Op, + }; + + PrattParser::new() + .op(Op::infix(add, Left) | Op::infix(sub, Left) | Op::infix(modulo, Left)) + .op(Op::infix(mul, Left) | Op::infix(div, Left)) + .op(Op::infix(dice, Left)) + .op(Op::infix(pow, Right)) + .op(Op::postfix(EOI)) // discarded below + }; + } + + let result = Calc::parse(calc, s.as_ref()).map_err(|_| CalcError::Pest)?; + + fn eval_single_pair(pair: Pair<Rule>) -> Result<f64, CalcError> { + let result = match pair.as_rule() { + oct | hex | binary => { + let base = match pair.as_rule() { + hex => 16, + oct => 8, + binary => 2, + _ => unreachable!(), + }; + + u64::from_str_radix(&pair.as_str()[2..], base) + .map_err(|_| CalcError::NumberFormat)? as f64 + }, + float => pair.as_str().parse::<f64>().map_err(|_| CalcError::NumberFormat)?, + expr | num => eval_expr(pair.into_inner())?, + unary_expr => { + let mut p = pair.into_inner(); + + let op = p.next().ok_or(CalcError::ArgCount)?; + let arg = eval_expr(p)?; + + match op.as_rule() { + log => arg.ln(), + sqrt => arg.sqrt(), + sgn => arg.signum(), + + sin => arg.sin(), + cos => arg.cos(), + tan => arg.tan(), + asin => arg.asin(), + acos => arg.acos(), + atan => arg.atan(), + + sinh => arg.sinh(), + cosh => arg.cosh(), + tanh => arg.tanh(), + asinh => arg.asinh(), + acosh => arg.acosh(), + atanh => arg.atanh(), + + exp => arg.exp(), + abs => arg.abs(), + ceil => arg.ceil(), + floor => arg.floor(), + round => arg.round(), + _ => unreachable!(), + } + }, + binary_expr => { + let mut p = pair.into_inner(); + + let op = p.next().ok_or(CalcError::ArgCount)?; + + let arg1 = eval_single_pair(p.next().ok_or(CalcError::ArgCount)?)?; + let arg2 = eval_single_pair(p.next().ok_or(CalcError::ArgCount)?)?; + + assert!(p.next().is_none()); + + match op.as_rule() { + min => arg1.min(arg2), + max => arg1.max(arg2), + atan2 => arg1.atan2(arg2), + _ => unreachable!(), + } + }, + suffix_expr => { + let mut p = pair.into_inner(); + + let arg = eval_expr(p.next().ok_or(CalcError::ArgCount)?.into_inner())?; + let op = p.next().ok_or(CalcError::ArgCount)?; + + assert!(p.next().is_none()); + + match op.as_rule() { + factorial => statrs::function::gamma::gamma(arg + 1.), + _ => unreachable!(), + } + }, + _ => unreachable!(), + }; + + Ok(result) + } + + pub fn eval_expr(p: Pairs<Rule>) -> Result<f64, CalcError> { + CLIMBER + .map_primary(eval_single_pair) + .map_infix(|lhs, op, rhs| { + let lhs = lhs?; + let rhs = rhs?; + + let result = match op.as_rule() { + add => lhs + rhs, + sub => lhs - rhs, + mul => lhs * rhs, + div => lhs / rhs, + pow => lhs.powf(rhs), + dice => { + let dice_count = lhs as usize; + let dice_faces = rhs as usize; + + let mut rng = rand::thread_rng(); + (0..dice_count) + .map(|_| rng.gen_range(1..(dice_faces + 1))) + .sum::<usize>() as f64 + }, + _ => unreachable!(), + }; + + Ok(result) + }) + .map_postfix(|arg, _post| arg) // discard EOI + .parse(p) + } + + eval_expr(result) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_calc_basics() { + assert_eq!(3., Calc::eval("1 + 2").unwrap()); + assert_eq!(3.0f64.ln(), Calc::eval("log 3").unwrap()); + assert!(6. - Calc::eval("3!").unwrap() < 0.0001); + assert_eq!(3., Calc::eval("max 3 2").unwrap()); + } + + #[test] + fn test_binary_unary() { + assert_eq!(3.0f64.ln(), Calc::eval("max log 3 log 2").unwrap()); + } + + #[test] + fn test_prefix_suffix() { + assert!(6. - Calc::eval("abs 3!").unwrap() < 0.0001); + } +} |
