diff options
Diffstat (limited to 'native/thulani_calc/src/lib.rs')
| -rw-r--r-- | native/thulani_calc/src/lib.rs | 215 |
1 files changed, 215 insertions, 0 deletions
diff --git a/native/thulani_calc/src/lib.rs b/native/thulani_calc/src/lib.rs new file mode 100644 index 0000000..75a6b37 --- /dev/null +++ b/native/thulani_calc/src/lib.rs @@ -0,0 +1,215 @@ +use log::error; +use rand::prelude::*; +use statrs; +use thiserror::Error; +use lazy_static::lazy_static; +use rustler::Encoder; + +mod atoms { + rustler::atoms! { + ok, + error, + inf, + ninf, + nan, + } +} + +#[rustler::nif] +pub fn eval<'a>(env: rustler::Env<'a>, args: String) -> Result<rustler::Term<'a>, String> { + Calc::eval(&args) + .map(|x| match x { + f64::INFINITY => atoms::inf().encode(env), + f64::NEG_INFINITY => atoms::ninf().encode(env), + x if x.is_nan() => atoms::nan().encode(env), + x => x.encode(env), + }) + .map_err(|e| e.to_string()) +} + +rustler::init!("Elixir.Thulani.Calc", [eval]); + + +#[derive(pest_derive::Parser)] +#[grammar = "calc.pest"] +struct Calc; + +#[derive(Copy, Clone, Error, Debug, PartialEq, Eq, Hash)] +pub(crate) enum CalcError { + #[error("pest was unable to parse the input")] + Pest, + + #[error("invalid number format")] + NumberFormat, + + #[error("bad argument count")] + ArgCount, +} + +impl Calc { + fn eval<S: AsRef<str>>(s: S) -> Result<f64, CalcError> { + use pest::{ + Parser, + prec_climber::PrecClimber, + iterators::{Pair, Pairs}, + }; + + use self::Rule::*; + + lazy_static! { + static ref CLIMBER: PrecClimber<self::Rule> = { + use pest::prec_climber::{ + Operator, + Assoc::*, + }; + + PrecClimber::new(vec![ + Operator::new(add, Left) | Operator::new(sub, Left) | Operator::new(modulo, Left), + Operator::new(mul, Left) | Operator::new(div, Left), + Operator::new(dice, Left), + Operator::new(pow, Right), + ]) + }; + } + + let result = Calc::parse(calc, s.as_ref()).map_err(|_| CalcError::Pest)?; + + fn eval_single_pair(pair: Pair<self::Rule>) -> Result<f64, CalcError> { + let result = match pair.as_rule() { + oct | hex | binary => { + let base = match pair.as_rule() { + hex => 16, + oct => 8, + binary => 2, + _ => unreachable!(), + }; + + u64::from_str_radix(&pair.as_str()[2..], base).map_err(|_| CalcError::NumberFormat)? as f64 + }, + float => pair.as_str().parse::<f64>().map_err(|_| CalcError::NumberFormat)?, + expr | num => eval_expr(pair.into_inner())?, + unary_expr => { + let mut p = pair.into_inner(); + + let op = p.next().ok_or(CalcError::ArgCount)?; + let arg = eval_expr(p)?; + + match op.as_rule() { + log => arg.ln(), + sqrt => arg.sqrt(), + sgn => arg.signum(), + + sin => arg.sin(), + cos => arg.cos(), + tan => arg.tan(), + asin => arg.asin(), + acos => arg.acos(), + atan => arg.atan(), + + sinh => arg.sinh(), + cosh => arg.cosh(), + tanh => arg.tanh(), + asinh => arg.asinh(), + acosh => arg.acosh(), + atanh => arg.atanh(), + + exp => arg.exp(), + abs => arg.abs(), + ceil => arg.ceil(), + floor => arg.floor(), + round => arg.round(), + _ => unreachable!(), + } + }, + binary_expr => { + let mut p = pair.into_inner(); + + let op = p.next().ok_or(CalcError::ArgCount)?; + + let arg1 = eval_single_pair(p.next().ok_or(CalcError::ArgCount)?)?; + let arg2 = eval_single_pair(p.next().ok_or(CalcError::ArgCount)?)?; + + assert!(p.next().is_none()); + + match op.as_rule() { + min => arg1.min(arg2), + max => arg1.max(arg2), + atan2 => arg1.atan2(arg2), + _ => unreachable!(), + } + }, + suffix_expr => { + let mut p = pair.into_inner(); + + let arg = eval_expr(p.next().ok_or(CalcError::ArgCount)?.into_inner())?; + let op = p.next().ok_or(CalcError::ArgCount)?; + + assert!(p.next().is_none()); + + match op.as_rule() { + factorial => statrs::function::gamma::gamma(arg + 1.), + _ => unreachable!(), + } + }, + _ => unreachable!(), + }; + + Ok(result) + } + + fn eval_expr(p: Pairs<self::Rule>) -> Result<f64, CalcError> { + CLIMBER.climb( + p, + eval_single_pair, + |lhs, op, rhs| { + let lhs = lhs?; + let rhs = rhs?; + + let result = match op.as_rule() { + add => lhs + rhs, + sub => lhs - rhs, + mul => lhs * rhs, + div => lhs / rhs, + pow => lhs.powf(rhs), + dice => { + let dice_count = lhs as usize; + let dice_faces = rhs as usize; + + let mut rng = thread_rng(); + (0..dice_count).map(|_| rng.gen_range(1..=dice_faces)).sum::<usize>() as f64 + }, + _ => unreachable!(), + }; + + Ok(result) + } + ) + } + + eval_expr(result) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_calc_basics() { + assert_eq!(3., Calc::eval("1 + 2").unwrap()); + assert_eq!(3.0f64.ln(), Calc::eval("log 3").unwrap()); + assert!(6. - Calc::eval("3!").unwrap() < 0.0001); + assert_eq!(3., Calc::eval("max 3 2").unwrap()); + } + + #[test] + fn test_binary_unary() { + assert_eq!(3.0f64.ln(), Calc::eval("max log 3 log 2").unwrap()); + } + + #[test] + fn test_prefix_suffix() { + assert!(6. - Calc::eval("abs 3!").unwrap() < 0.0001); + } +} + |
