aboutsummaryrefslogtreecommitdiff
path: root/calc/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'calc/src/lib.rs')
-rw-r--r--calc/src/lib.rs189
1 files changed, 189 insertions, 0 deletions
diff --git a/calc/src/lib.rs b/calc/src/lib.rs
new file mode 100644
index 0000000..05b83a9
--- /dev/null
+++ b/calc/src/lib.rs
@@ -0,0 +1,189 @@
+use lazy_static::lazy_static;
+use pest::iterators::Pair;
+use rand::Rng;
+
+#[derive(pest_derive::Parser)]
+#[grammar = "calc.pest"]
+pub struct Calc;
+
+#[derive(Copy, Clone, thiserror::Error, Debug, PartialEq, Eq, Hash)]
+pub enum CalcError {
+ #[error("pest was unable to parse the input")]
+ Pest,
+
+ #[error("invalid number format")]
+ NumberFormat,
+
+ #[error("bad argument count")]
+ ArgCount,
+}
+
+impl Calc {
+ pub fn eval<S: AsRef<str>>(s: S) -> Result<f64, CalcError> {
+ use pest::{
+ iterators::Pairs,
+ pratt_parser::PrattParser,
+ Parser,
+ };
+
+ use self::Rule::*;
+
+ lazy_static! {
+ static ref CLIMBER: PrattParser<Rule> = {
+ use pest::pratt_parser::{
+ Assoc::*,
+ Op,
+ };
+
+ PrattParser::new()
+ .op(Op::infix(add, Left) | Op::infix(sub, Left) | Op::infix(modulo, Left))
+ .op(Op::infix(mul, Left) | Op::infix(div, Left))
+ .op(Op::infix(dice, Left))
+ .op(Op::infix(pow, Right))
+ .op(Op::postfix(EOI)) // discarded below
+ };
+ }
+
+ let result = Calc::parse(calc, s.as_ref()).map_err(|_| CalcError::Pest)?;
+
+ fn eval_single_pair(pair: Pair<Rule>) -> Result<f64, CalcError> {
+ let result = match pair.as_rule() {
+ oct | hex | binary => {
+ let base = match pair.as_rule() {
+ hex => 16,
+ oct => 8,
+ binary => 2,
+ _ => unreachable!(),
+ };
+
+ u64::from_str_radix(&pair.as_str()[2..], base)
+ .map_err(|_| CalcError::NumberFormat)? as f64
+ },
+ float => pair.as_str().parse::<f64>().map_err(|_| CalcError::NumberFormat)?,
+ expr | num => eval_expr(pair.into_inner())?,
+ unary_expr => {
+ let mut p = pair.into_inner();
+
+ let op = p.next().ok_or(CalcError::ArgCount)?;
+ let arg = eval_expr(p)?;
+
+ match op.as_rule() {
+ log => arg.ln(),
+ sqrt => arg.sqrt(),
+ sgn => arg.signum(),
+
+ sin => arg.sin(),
+ cos => arg.cos(),
+ tan => arg.tan(),
+ asin => arg.asin(),
+ acos => arg.acos(),
+ atan => arg.atan(),
+
+ sinh => arg.sinh(),
+ cosh => arg.cosh(),
+ tanh => arg.tanh(),
+ asinh => arg.asinh(),
+ acosh => arg.acosh(),
+ atanh => arg.atanh(),
+
+ exp => arg.exp(),
+ abs => arg.abs(),
+ ceil => arg.ceil(),
+ floor => arg.floor(),
+ round => arg.round(),
+ _ => unreachable!(),
+ }
+ },
+ binary_expr => {
+ let mut p = pair.into_inner();
+
+ let op = p.next().ok_or(CalcError::ArgCount)?;
+
+ let arg1 = eval_single_pair(p.next().ok_or(CalcError::ArgCount)?)?;
+ let arg2 = eval_single_pair(p.next().ok_or(CalcError::ArgCount)?)?;
+
+ assert!(p.next().is_none());
+
+ match op.as_rule() {
+ min => arg1.min(arg2),
+ max => arg1.max(arg2),
+ atan2 => arg1.atan2(arg2),
+ _ => unreachable!(),
+ }
+ },
+ suffix_expr => {
+ let mut p = pair.into_inner();
+
+ let arg = eval_expr(p.next().ok_or(CalcError::ArgCount)?.into_inner())?;
+ let op = p.next().ok_or(CalcError::ArgCount)?;
+
+ assert!(p.next().is_none());
+
+ match op.as_rule() {
+ factorial => statrs::function::gamma::gamma(arg + 1.),
+ _ => unreachable!(),
+ }
+ },
+ _ => unreachable!(),
+ };
+
+ Ok(result)
+ }
+
+ pub fn eval_expr(p: Pairs<Rule>) -> Result<f64, CalcError> {
+ CLIMBER
+ .map_primary(eval_single_pair)
+ .map_infix(|lhs, op, rhs| {
+ let lhs = lhs?;
+ let rhs = rhs?;
+
+ let result = match op.as_rule() {
+ add => lhs + rhs,
+ sub => lhs - rhs,
+ mul => lhs * rhs,
+ div => lhs / rhs,
+ pow => lhs.powf(rhs),
+ dice => {
+ let dice_count = lhs as usize;
+ let dice_faces = rhs as usize;
+
+ let mut rng = rand::thread_rng();
+ (0..dice_count)
+ .map(|_| rng.gen_range(1..(dice_faces + 1)))
+ .sum::<usize>() as f64
+ },
+ _ => unreachable!(),
+ };
+
+ Ok(result)
+ })
+ .map_postfix(|arg, _post| arg) // discard EOI
+ .parse(p)
+ }
+
+ eval_expr(result)
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_calc_basics() {
+ assert_eq!(3., Calc::eval("1 + 2").unwrap());
+ assert_eq!(3.0f64.ln(), Calc::eval("log 3").unwrap());
+ assert!(6. - Calc::eval("3!").unwrap() < 0.0001);
+ assert_eq!(3., Calc::eval("max 3 2").unwrap());
+ }
+
+ #[test]
+ fn test_binary_unary() {
+ assert_eq!(3.0f64.ln(), Calc::eval("max log 3 log 2").unwrap());
+ }
+
+ #[test]
+ fn test_prefix_suffix() {
+ assert!(6. - Calc::eval("abs 3!").unwrap() < 0.0001);
+ }
+}