Last active
February 6, 2026 23:46
-
-
Save resilar/2ebc21800b2c27d27a335c11659ce806 to your computer and use it in GitHub Desktop.
Zero-cost fixed-point decimals in Rust
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use std::cmp::{Eq, Ordering, PartialEq, PartialOrd}; | |
| use std::convert::From; | |
| use std::error::Error; | |
| use std::fmt; | |
| use std::num::ParseIntError; | |
| use std::str::FromStr; | |
/// Fixed-point decimal for accurate monetary calculations.
///
/// The wrapped `i64` is the raw mantissa and `N` is the number of fractional
/// base-10 digits, so `Decimal::<2>(1234)` is displayed as `12.34`.
#[derive(Clone, Copy)]
pub struct Decimal<const N: u32>(i64);
| impl<const N: u32> fmt::Display for Decimal<N> { | |
| fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | |
| if N > 0 { | |
| if self.0.is_negative() { | |
| use fmt::Write; | |
| f.write_char('-')?; | |
| } | |
| let (val, div) = (self.0.unsigned_abs(), 10_u64.saturating_pow(N)); | |
| let (lhs, rhs) = (val / div, val % div); | |
| write!(f, "{lhs}.{rhs:0>pad$}", pad = N as usize) | |
| } else { | |
| write!(f, "{}", self.0) | |
| } | |
| } | |
| } | |
| impl<const N: u32> fmt::Debug for Decimal<N> { | |
| fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | |
| write!(f, "Decimal<{}>({})", N, self.0) | |
| } | |
| } | |
| impl<const N: u32> FromStr for Decimal<N> { | |
| type Err = DecimalError; | |
| fn from_str(s: &str) -> Result<Self, Self::Err> { | |
| let mut parts = s.splitn(3, '.'); | |
| if let (Some(lhs), cdr, None) = (parts.next(), parts.next(), parts.next()) | |
| && !lhs.is_empty() | |
| && let Some(rhs) = cdr.or(Some("")) | |
| && (cdr.is_none() || (!rhs.is_empty() && rhs.len() <= N as usize)) | |
| { | |
| let mantissa = format!("{lhs}{rhs:0<pad$}", pad = N as usize); | |
| Ok(Decimal(mantissa.parse()?)) | |
| } else { | |
| Err(DecimalError::InvalidFormat) | |
| } | |
| } | |
| } | |
| impl<const N: u32> From<i64> for Decimal<N> { | |
| fn from(value: i64) -> Decimal<N> { | |
| Decimal(value) | |
| } | |
| } | |
| impl<const N: u32> From<Decimal<N>> for i64 { | |
| fn from(value: Decimal<N>) -> i64 { | |
| value.0 | |
| } | |
| } | |
| impl<const N: u32, const M: u32> PartialEq<Decimal<M>> for Decimal<N> { | |
| fn eq(&self, other: &Decimal<M>) -> bool { | |
| if let (Some(lhs), Some(rhs)) = ( | |
| self.0.checked_mul(10_i64.pow(M.saturating_sub(N))), | |
| other.0.checked_mul(10_i64.pow(N.saturating_sub(M))), | |
| ) { | |
| lhs == rhs | |
| } else { | |
| false | |
| } | |
| } | |
| } | |
| impl<const N: u32> Eq for Decimal<N> {} | |
| impl<const N: u32, const M: u32> PartialOrd<Decimal<M>> for Decimal<N> { | |
| fn partial_cmp(&self, other: &Decimal<M>) -> Option<Ordering> { | |
| let lhs = self.0.checked_mul(10_i64.pow(M.saturating_sub(N)))?; | |
| let rhs = other.0.checked_mul(10_i64.pow(N.saturating_sub(M)))?; | |
| lhs.partial_cmp(&rhs) | |
| } | |
| } | |
/// Errors reported by the fixed-point decimal code.
#[derive(Debug, PartialEq)]
pub enum DecimalError {
    /// The input text does not have the `integer[.fraction]` shape accepted
    /// by the `FromStr` implementation of `Decimal`.
    InvalidFormat,
    /// The digits were well-formed but could not be parsed into the
    /// underlying integer; wraps the original error.
    ParseInt(ParseIntError),
    /// Fixed-point arithmetic overflow.
    // NOTE(review): no code visible in this file constructs this variant.
    Overflow,
}
| impl fmt::Display for DecimalError { | |
| fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | |
| use DecimalError::*; | |
| match *self { | |
| InvalidFormat => write!(f, "Invalid fixed point format"), | |
| ParseInt(ref err) => write!(f, "ParseIntError: {err}"), | |
| Overflow => write!(f, "Fixed point arithmetic overflow"), | |
| } | |
| } | |
| } | |
| impl Error for DecimalError { | |
| fn source(&self) -> Option<&(dyn Error + 'static)> { | |
| match *self { | |
| DecimalError::InvalidFormat => None, | |
| DecimalError::ParseInt(ref err) => Some(err), | |
| DecimalError::Overflow => None, | |
| } | |
| } | |
| } | |
| impl From<ParseIntError> for DecimalError { | |
| fn from(err: ParseIntError) -> DecimalError { | |
| DecimalError::ParseInt(err) | |
| } | |
| } | |
| /// Helper macro for complie-time const asserts | |
| macro_rules! max { | |
| ($L: ident, $R: ident) => { | |
| if $L < $R { $R } else { $L } | |
| }; | |
| } | |
| /// Module for fixed-point decimal arithmetic functions. | |
| mod dec { | |
| use super::Decimal; | |
| /// Add L-digit & R-digit decimals, return O-digit sum. | |
| /// | |
| /// Requirement: `O = max(L, R)` (verified statically). | |
| pub fn add<const O: u32, const L: u32, const R: u32>( | |
| lhs: Decimal<L>, | |
| rhs: Decimal<R>, | |
| ) -> Decimal<O> { | |
| checked_add(lhs, rhs).expect("fixed::dec::add() overflow") | |
| } | |
| pub fn checked_add<const O: u32, const L: u32, const R: u32>( | |
| lhs: Decimal<L>, | |
| rhs: Decimal<R>, | |
| ) -> Option<Decimal<O>> { | |
| const { assert!(O == max!(L, R)) }; | |
| let lhs = lhs.0.checked_mul(10_i64.pow(R.saturating_sub(L)))?; | |
| let rhs = rhs.0.checked_mul(10_i64.pow(L.saturating_sub(R)))?; | |
| lhs.checked_add(rhs).map(Decimal) | |
| } | |
| /// Subtract L-digit & R-digit decimals, return O-digit difference. | |
| /// | |
| /// Requirement: `O = max(L, R)` (verified statically). | |
| pub fn sub<const O: u32, const L: u32, const R: u32>( | |
| lhs: Decimal<L>, | |
| rhs: Decimal<R>, | |
| ) -> Decimal<O> { | |
| checked_sub(lhs, rhs).expect("fixed::dec::sub() overflow") | |
| } | |
| pub fn checked_sub<const O: u32, const L: u32, const R: u32>( | |
| lhs: Decimal<L>, | |
| rhs: Decimal<R>, | |
| ) -> Option<Decimal<O>> { | |
| const { assert!(O == max!(L, R)) }; | |
| let lhs = lhs.0.checked_mul(10_i64.pow(R.saturating_sub(L)))?; | |
| let rhs = rhs.0.checked_mul(10_i64.pow(L.saturating_sub(R)))?; | |
| lhs.checked_sub(rhs).map(Decimal) | |
| } | |
| /// Multiply L-digit & R-digit decimals, return O-digit product. | |
| /// | |
| /// Requirement: `O = L + R` (verified statically). | |
| pub fn mul<const O: u32, const L: u32, const R: u32>( | |
| lhs: Decimal<L>, | |
| rhs: Decimal<R>, | |
| ) -> Decimal<O> { | |
| checked_mul(lhs, rhs).expect("fixed::dec::mul() overflow") | |
| } | |
| pub fn checked_mul<const O: u32, const L: u32, const R: u32>( | |
| lhs: Decimal<L>, | |
| rhs: Decimal<R>, | |
| ) -> Option<Decimal<O>> { | |
| const { assert!(O == L + R) }; | |
| Some(Decimal(lhs.0.checked_mul(rhs.0)?)) | |
| } | |
| /// Divide L-digit & R-digit decimals, return O-digit quotinent. | |
| /// | |
| /// Default rounding mode is [`RoundingMode::Down`], i.e., towards zero. | |
| /// | |
| /// # Example (non-default rounding) | |
| /// | |
| /// ``` | |
| /// let lhs: Decimal<2> = "1.49".parse().unwrap(); | |
| /// let rhs: Decimal<0> = "2".parse().unwrap(); | |
| /// let res: Decimal<3> = div(lhs, rhs); | |
| /// assert_eq!(res.to_string(), "0.745"); | |
| /// assert_eq!(res.round::<2>(RoundingMode::Ceil).to_string(), "0.75"); | |
| /// ``` | |
| pub fn div<const O: u32, const L: u32, const R: u32>( | |
| lhs: Decimal<L>, | |
| rhs: Decimal<R>, | |
| ) -> Decimal<O> { | |
| checked_div(lhs, rhs).expect("fixed::dec::div() overflow") | |
| } | |
| pub fn checked_div<const O: u32, const L: u32, const R: u32>( | |
| lhs: Decimal<L>, | |
| rhs: Decimal<R>, | |
| ) -> Option<Decimal<O>> { | |
| let lhs = (lhs.0 as i128).checked_mul(10_i128.pow(O.saturating_sub(L) + R))?; | |
| let rhs = (rhs.0 as i128).checked_mul(10_i128.pow(L.saturating_sub(O)))?; | |
| Some(Decimal(lhs.checked_div(rhs)?.try_into().ok()?)) | |
| } | |
| } | |
/// Strategy used by `Decimal::round` for the fractional digits it discards.
#[derive(Clone, Copy, Debug)]
pub enum RoundingMode {
    /// Towards zero.
    Down,
    /// Away from zero.
    Up,
    /// To the nearest value; ties go towards zero.
    HalfDown,
    /// To the nearest value; ties go away from zero.
    HalfUp,
    /// Towards negative infinity.
    Floor,
    /// Towards positive infinity.
    Ceil,
}
| impl<const N: u32> Decimal<N> { | |
| /// Truncate N-digit decimal to M digits, rounding according to [`RoundingMode`]. | |
| pub fn round<const M: u32>(self, mode: RoundingMode) -> Decimal<M> { | |
| const { assert!(M < N) }; | |
| let div = 10_i64.pow(N.saturating_sub(M)); | |
| let (val, rem) = (self.0 / div, (self.0 % div).unsigned_abs() as i64); | |
| match mode { | |
| _ if rem == 0 => val * 10_i64.pow(M.saturating_sub(N)), | |
| RoundingMode::Up => val + self.0.signum(), | |
| RoundingMode::HalfDown if rem > div / 2 => val + self.0.signum(), | |
| RoundingMode::HalfUp if rem >= div / 2 => val + self.0.signum(), | |
| RoundingMode::Floor if self.0.is_negative() => val - 1, | |
| RoundingMode::Ceil if self.0.is_positive() => val + 1, | |
| _ => val, | |
| } | |
| .into() | |
| } | |
| /// Truncate N-digit decimal to M digits (`M < N`). | |
| pub fn truncate<const M: u32>(self) -> Decimal<M> { | |
| const { assert!(M < N) }; | |
| (self.0 / 10_i64.pow(N - M)).into() | |
| } | |
| /// Extend N-digit decimal to M digits (`M > N`). | |
| pub fn extend<const M: u32>(self) -> Decimal<M> { | |
| self.checked_extend() | |
| .expect("fixed::Decimal::extend() overflow") | |
| } | |
| pub fn checked_extend<const M: u32>(self) -> Option<Decimal<M>> { | |
| const { assert!(M > N) }; | |
| Some(self.0.checked_mul(10_i64.pow(M - N))?.into()) | |
| } | |
| } | |
// Unit tests for formatting, parsing, comparison, arithmetic, and rounding.
#[cfg(test)]
mod tests {
    use super::dec::*;
    use super::*;

    // Display output and FromStr parsing, including the rejected inputs
    // (trailing dot, leading dot, too many fractional digits).
    #[test]
    fn decimal_fmt() {
        let x = Decimal::<5>(12345_67890_i64);
        assert_eq!(format!("{x}"), "12345.67890");
        assert_eq!("8765.4321".parse(), Ok(Decimal::<4>(8765_4321)));
        assert_eq!("8765.43".parse(), Ok(Decimal::<4>(8765_4300)));
        assert_eq!("8765".parse(), Ok(Decimal::<4>(8765_0000)));
        assert!("8765.".parse::<Decimal<4>>().is_err());
        assert!(".8765".parse::<Decimal<4>>().is_err());
        assert!("8765.43210".parse::<Decimal<4>>().is_err());
        let x: Result<Decimal<0>, _> = "-42".parse();
        assert_eq!(x, Ok(Decimal(-42)));
        assert_eq!(x.unwrap().to_string(), "-42");
        // An explicit plus sign is accepted on input but not printed back.
        let x: Result<Decimal<1>, _> = "+4.2".parse();
        assert_eq!(x, Ok(Decimal(42)));
        assert_eq!(x.unwrap().to_string(), "4.2");
        let x = Decimal::<4>(-1234_5678);
        assert_eq!(x, Decimal::<4>(-12345678));
        assert_eq!(x.to_string(), "-1234.5678");
        // The integer part is zero, so the sign must come from the fraction.
        let x: Decimal<10> = "-0.0012345678".parse().unwrap();
        assert_eq!(x, Decimal::<10>(-12345678));
        assert_eq!(x.to_string(), "-0.0012345678");
    }

    // Comparison across different digit counts goes by value, not by raw
    // mantissa.
    #[test]
    fn decimal_cmp() {
        let x = Decimal::<2>(12_34);
        let y = Decimal::<4>(1234);
        assert_ne!(x, y);
        assert!(x > y);
        let z = Decimal::<4>(12_3400);
        assert_eq!(x, z);
        assert_ne!(Decimal::<1>(-12345678), Decimal::<3>(-12345678));
    }

    // Addition, its commutativity, and the overflow boundary of checked_add.
    #[test]
    fn decimal_add() {
        let x = Decimal::<4>(1234_5678);
        let y = Decimal::<4>(8765_4321);
        let xaddy: Decimal<4> = add(x, y);
        assert_eq!(
            "1234.5678 + 8765.4321 = 9999.9999",
            format!("{x} + {y} = {xaddy}")
        );
        assert_eq!(xaddy, add::<4, _, _>(y, x));
        // `into()` wraps the raw mantissa, so these sit at the i64 limit.
        let x: Decimal<20> = (i64::MAX - 1).into();
        let (one, two): (Decimal<20>, Decimal<20>) = (1.into(), 2.into());
        assert_eq!(checked_add::<20, _, _>(x, one), Some(Decimal(i64::MAX)));
        assert_eq!(checked_add::<20, _, _>(x, two), None);
    }

    // Subtraction and the overflow boundary of checked_sub.
    #[test]
    fn decimal_sub() {
        let x = Decimal::<4>(9999_9999);
        let y = Decimal::<4>(1234_5678);
        let xsuby: Decimal<4> = sub(x, y);
        assert_eq!(
            "9999.9999 - 1234.5678 = 8765.4321",
            format!("{x} - {y} = {xsuby}")
        );
        let x: Decimal<20> = (i64::MIN + 1).into();
        let (one, two): (Decimal<20>, Decimal<20>) = (1.into(), 2.into());
        assert_eq!(checked_sub::<20, _, _>(x, one), Some(Decimal(i64::MIN)));
        assert_eq!(checked_sub::<20, _, _>(x, two), None);
    }

    // Multiplication: the result carries the sum of both digit counts.
    #[test]
    fn decimal_mul() {
        let x = Decimal::<4>(12345678);
        let y = Decimal::<4>(100000001);
        let xmuly: Decimal<8> = mul(x, y);
        assert_eq!(
            "1234.5678 * 10000.0001 = 12345678.12345678",
            format!("{x} * {y} = {xmuly}")
        );
        let (x, y) = (Decimal::<0>(1234), Decimal::<4>(1));
        let xmuly: Decimal<4> = mul(x, y);
        assert_eq!("1234 * 0.0001 = 0.1234", format!("{x} * {y} = {xmuly}"));
        let (x, y) = (Decimal::<1>(1234), Decimal::<4>(1));
        let xmuly: Decimal<5> = mul(x, y);
        assert_eq!("123.4 * 0.0001 = 0.01234", format!("{x} * {y} = {xmuly}"));
        let (x, y) = (Decimal::<2>(1234), Decimal::<2>(10000));
        let xmuly: Decimal<4> = mul(x, y);
        assert_eq!("12.34 * 100.00 = 1234.0000", format!("{x} * {y} = {xmuly}"));
    }

    // Division with the output precision above, equal to, and below that of
    // the dividend; quotients are truncated towards zero.
    #[test]
    fn decimal_div() {
        let x: Decimal<2> = "1.49".parse().unwrap();
        let y: Decimal<0> = "2".parse().unwrap();
        let xdivy: Decimal<3> = div(x, y);
        assert_eq!(format!("{x} / {y} = {xdivy}"), "1.49 / 2 = 0.745");
        let x = Decimal::<4>(1234_5678);
        let y = Decimal::<2>(1);
        let xdivy: Decimal<4> = div(x, y);
        assert_eq!(x.to_string(), "1234.5678");
        assert_eq!(y.to_string(), "0.01");
        assert_eq!(xdivy.to_string(), "123456.7800");
        assert_eq!(div::<7, _, 2>(x, Decimal(1)).to_string(), "123456.7800000");
        assert_eq!(div::<4, _, 2>(x, Decimal(2)).to_string(), "61728.3900");
        assert_eq!(div::<4, _, 3>(x, Decimal(1000)).to_string(), "1234.5678");
        assert_eq!(div::<4, _, 3>(x, Decimal(100_000)).to_string(), "12.3456");
        assert_eq!(div::<4, _, 5>(x, Decimal(1)).to_string(), "123456780.0000");
        assert_eq!(div::<4, _, 5>(x, Decimal(12_34567)).to_string(), "100.0000");
        let x: Decimal<12> = "1.003000000001".parse().unwrap();
        let y: Decimal<4> = "1234.5678".parse().unwrap();
        let z: Decimal<4> = "12.5678".parse().unwrap();
        assert_eq!(div::<18, _, _>(x, y).to_string(), "0.000812430066620075");
        assert_eq!(div::<16, _, _>(x, z).to_string(), "0.0798071261478540");
        assert_eq!(div::<15, _, _>(x, z).to_string(), "0.079807126147854");
        // Fewer output digits than the dividend has.
        let x: Decimal<10> = "0.1234567890".parse().unwrap();
        assert_eq!(div::<9, _, 0>(x, Decimal(1)).to_string(), "0.123456789");
        assert_eq!(div::<9, _, 0>(x, Decimal(-1)).to_string(), "-0.123456789");
        assert_eq!(
            div::<3, 2, 0>(Decimal(149), Decimal(2)).to_string(),
            "0.745"
        );
    }

    // Every rounding mode on a positive and a negative value, each reduced to
    // 2, 1, and 0 fractional digits.
    #[test]
    fn decimal_round() {
        let res: Decimal<3> = "0.745".parse().unwrap();
        for (mode, (a, b, c)) in [
            (RoundingMode::Down, ("0.74", "0.7", "0")),
            (RoundingMode::Up, ("0.75", "0.8", "1")),
            (RoundingMode::HalfDown, ("0.74", "0.7", "1")),
            (RoundingMode::HalfUp, ("0.75", "0.7", "1")),
            (RoundingMode::Floor, ("0.74", "0.7", "0")),
            (RoundingMode::Ceil, ("0.75", "0.8", "1")),
        ]
        .iter()
        {
            assert_eq!(*a, res.round::<2>(*mode).to_string());
            assert_eq!(*b, res.round::<1>(*mode).to_string());
            assert_eq!(*c, res.round::<0>(*mode).to_string());
        }
        let res: Decimal<3> = "-0.745".parse().unwrap();
        for (mode, (a, b, c)) in [
            (RoundingMode::Down, ("-0.74", "-0.7", "0")),
            (RoundingMode::Up, ("-0.75", "-0.8", "-1")),
            (RoundingMode::HalfDown, ("-0.74", "-0.7", "-1")),
            (RoundingMode::HalfUp, ("-0.75", "-0.7", "-1")),
            (RoundingMode::Floor, ("-0.75", "-0.8", "-1")),
            (RoundingMode::Ceil, ("-0.74", "-0.7", "0")),
        ]
        .iter()
        {
            assert_eq!(*a, res.round::<2>(*mode).to_string());
            assert_eq!(*b, res.round::<1>(*mode).to_string());
            assert_eq!(*c, res.round::<0>(*mode).to_string());
        }
    }

    // Truncation drops digits towards zero, also for negative values.
    #[test]
    fn decimal_truncate() {
        let x: Decimal<3> = "-0.745".parse().unwrap();
        let y: Decimal<2> = x.truncate();
        let z: Decimal<1> = x.truncate();
        assert_eq!(format!("{x} -> {y} -> {z}"), "-0.745 -> -0.74 -> -0.7");
    }

    // Extension appends zero digits; checked_extend reports overflow as None.
    #[test]
    fn decimal_extend() {
        let x: Decimal<3> = "-0.745".parse().unwrap();
        let y: Decimal<4> = x.extend();
        assert_eq!(y.to_string(), "-0.7450");
        // One extra digit still fits into i64, two do not.
        let x: Decimal<5> = Decimal(i64::MAX / 10);
        let y: Option<Decimal<6>> = x.checked_extend();
        let z: Option<Decimal<7>> = x.checked_extend();
        assert!(y.is_some());
        assert!(z.is_none());
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment