use pest::{ iterators::{ Pair, Pairs, }, pratt_parser::PrattParser, Parser, }; use rand::Rng as _; #[derive(pest_derive::Parser)] #[grammar = "calc.pest"] pub struct Calc; #[derive(Copy, Clone, thiserror::Error, Debug, PartialEq, Eq, Hash)] pub enum Error { #[error("unable to parse input")] Pest, #[error("invalid number format")] NumberFormat, #[error("bad argument count")] ArgCount, } lazy_static::lazy_static! { static ref PRATT_PARSER: PrattParser = { use pest::pratt_parser::{ Assoc::*, Op, }; use Rule::*; PrattParser::new() .op(Op::infix(add, Left) | Op::infix(sub, Left) | Op::infix(modulo, Left)) .op(Op::infix(mul, Left) | Op::infix(div, Left)) .op(Op::infix(dice, Left)) .op(Op::infix(pow, Right)) .op(Op::postfix(EOI)) // discarded below }; } impl Calc { #[inline] pub fn eval>(s: S) -> Result { let result = Calc::parse(Rule::calc, s.as_ref()).map_err(|_| Error::Pest)?; eval_expr(result) } } fn eval_single_pair(pair: Pair) -> Result { use Rule::*; let result = match pair.as_rule() { oct | hex | binary => { let base = match pair.as_rule() { hex => 16, oct => 8, binary => 2, _ => unreachable!(), }; u64::from_str_radix(&pair.as_str()[2..], base).map_err(|_| Error::NumberFormat)? as f64 }, float => pair.as_str().parse::().map_err(|_| Error::NumberFormat)?, expr | num => eval_expr(pair.into_inner())?, unary_expr => { let mut p = pair.into_inner(); let op = p.next().ok_or(Error::ArgCount)?; let arg = eval_expr(p)?; match op.as_rule() { log => arg.ln(), sqrt => arg.sqrt(), sgn => arg.signum(), sin => arg.sin(), cos => arg.cos(), tan => arg.tan(), asin => arg.asin(), acos => arg.acos(), atan => arg.atan(), sinh => arg.sinh(), cosh => arg.cosh(), tanh => arg.tanh(), asinh => arg.asinh(), acosh => arg.acosh(), atanh => arg.atanh(), exp => arg.exp(), abs => arg.abs(), ceil => arg.ceil(), floor => arg.floor(), round => arg.round(), _ => unreachable!(), } }, binary_expr => { let mut p = pair.into_inner(); let op = p.next().ok_or(Error::ArgCount)?; let arg1 = eval_single_pair(p.next().ok_or(Error::ArgCount)?)?; let arg2 = eval_single_pair(p.next().ok_or(Error::ArgCount)?)?; assert!(p.next().is_none()); match op.as_rule() { min => arg1.min(arg2), max => arg1.max(arg2), atan2 => arg1.atan2(arg2), _ => unreachable!(), } }, suffix_expr => { let mut p = pair.into_inner(); let arg = eval_expr(p.next().ok_or(Error::ArgCount)?.into_inner())?; let op = p.next().ok_or(Error::ArgCount)?; assert!(p.next().is_none()); match op.as_rule() { factorial => statrs::function::gamma::gamma(arg + 1.), _ => unreachable!(), } }, _ => unreachable!(), }; Ok(result) } fn eval_expr(p: Pairs) -> Result { use Rule::*; PRATT_PARSER .map_primary(eval_single_pair) .map_infix(|lhs, op, rhs| { let lhs = lhs?; let rhs = rhs?; let result = match op.as_rule() { add => lhs + rhs, sub => lhs - rhs, mul => lhs * rhs, div => lhs / rhs, pow => lhs.powf(rhs), dice => { let dice_count = lhs as usize; let dice_faces = rhs as usize; let mut rng = rand::thread_rng(); (0..dice_count) .map(|_| rng.gen_range(1..(dice_faces + 1))) .sum::() as f64 } _ => unreachable!(), }; Ok(result) }) .map_postfix(|arg, _post| arg) // discard EOI .parse(p) } #[cfg(test)] mod test { use super::*; #[test] fn test_calc_basics() { assert_eq!(3., Calc::eval("1 + 2").unwrap()); assert_eq!(3.0f64.ln(), Calc::eval("log 3").unwrap()); assert!(6. - Calc::eval("3!").unwrap() < 0.0001); assert_eq!(3., Calc::eval("max 3 2").unwrap()); } #[test] fn test_binary_unary() { assert_eq!(3.0f64.ln(), Calc::eval("max log 3 log 2").unwrap()); } #[test] fn test_prefix_suffix() { assert!(6. - Calc::eval("abs 3!").unwrap() < 0.0001); } }