Skip to content

Instantly share code, notes, and snippets.

@resilar
Last active February 6, 2026 23:46
Show Gist options
  • Select an option

  • Save resilar/2ebc21800b2c27d27a335c11659ce806 to your computer and use it in GitHub Desktop.

Select an option

Save resilar/2ebc21800b2c27d27a335c11659ce806 to your computer and use it in GitHub Desktop.
Zero-cost fixed-point decimals in Rust
use std::cmp::{Eq, Ordering, PartialEq, PartialOrd};
use std::convert::From;
use std::error::Error;
use std::fmt;
use std::num::ParseIntError;
use std::str::FromStr;
/// Fixed-point decimal for accurate monetary calculations.
///
/// The value is stored as a scaled `i64` mantissa with `N` fractional base-10
/// digits: `Decimal::<2>(1234)` represents `12.34`. `Default` is zero.
#[derive(Clone, Copy, Default, Hash)]
pub struct Decimal<const N: u32>(i64);
impl<const N: u32> fmt::Display for Decimal<N> {
    /// Formats the value as `[-]<int>.<frac>` with exactly `N` fractional
    /// digits, or as a plain integer when `N == 0`.
    ///
    /// Width, fill, alignment and the `+`/`0` flags apply to the whole number,
    /// exactly as they do for integers (e.g. `{:>8}` or `{:08}`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        if N == 0 {
            return fmt::Display::fmt(&self.0, f);
        }
        // 10^N saturates to u64::MAX for N >= 20. That is still larger than any
        // |i64|, so the integer part is 0 and every digit lands in the fraction.
        let (val, div) = (self.0.unsigned_abs(), 10_u64.saturating_pow(N));
        let (int_part, frac_part) = (val / div, val % div);
        let digits = format!("{int_part}.{frac_part:0>width$}", width = N as usize);
        // pad_integral writes the sign and honours the formatter's flags.
        f.pad_integral(self.0 >= 0, "", &digits)
    }
}
impl<const N: u32> fmt::Debug for Decimal<N> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Decimal<{}>({})", N, self.0)
}
}
impl<const N: u32> FromStr for Decimal<N> {
type Err = DecimalError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut parts = s.splitn(3, '.');
if let (Some(lhs), cdr, None) = (parts.next(), parts.next(), parts.next())
&& !lhs.is_empty()
&& let Some(rhs) = cdr.or(Some(""))
&& (cdr.is_none() || (!rhs.is_empty() && rhs.len() <= N as usize))
{
let mantissa = format!("{lhs}{rhs:0<pad$}", pad = N as usize);
Ok(Decimal(mantissa.parse()?))
} else {
Err(DecimalError::InvalidFormat)
}
}
}
impl<const N: u32> From<i64> for Decimal<N> {
fn from(value: i64) -> Decimal<N> {
Decimal(value)
}
}
impl<const N: u32> From<Decimal<N>> for i64 {
fn from(value: Decimal<N>) -> i64 {
value.0
}
}
/// Compares `lhs / 10^lhs_digits` with `rhs / 10^rhs_digits` exactly.
///
/// Both mantissas are brought to the larger precision in `i128`. A shift of
/// 19 digits already lifts any non-zero `i64` beyond the whole `i64` range.
/// Larger shifts are therefore clamped to 19: a non-zero shifted mantissa then
/// dominates the other side exactly as it would under the unbounded shift, and
/// the product (at most ~9.3e37) stays below `i128::MAX`.
fn cmp_scaled(lhs: i64, lhs_digits: u32, rhs: i64, rhs_digits: u32) -> Ordering {
    let factor = |shift: u32| 10_i128.pow(shift.min(19));
    let lhs = i128::from(lhs) * factor(rhs_digits.saturating_sub(lhs_digits));
    let rhs = i128::from(rhs) * factor(lhs_digits.saturating_sub(rhs_digits));
    lhs.cmp(&rhs)
}
impl<const N: u32, const M: u32> PartialEq<Decimal<M>> for Decimal<N> {
    /// Numeric equality across precisions, e.g. `12.34 == 12.3400`.
    fn eq(&self, other: &Decimal<M>) -> bool {
        cmp_scaled(self.0, N, other.0, M) == Ordering::Equal
    }
}
impl<const N: u32> Eq for Decimal<N> {}
impl<const N: u32, const M: u32> PartialOrd<Decimal<M>> for Decimal<N> {
    /// Numeric ordering across precisions. Always returns `Some`; scaling can
    /// no longer overflow into "incomparable".
    fn partial_cmp(&self, other: &Decimal<M>) -> Option<Ordering> {
        Some(cmp_scaled(self.0, N, other.0, M))
    }
}
impl<const N: u32> Ord for Decimal<N> {
    /// Total order for equal precisions: the mantissas compare directly.
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.cmp(&other.0)
    }
}
/// Errors produced when parsing or operating on [`Decimal`] values.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DecimalError {
    /// The string does not have the `<int>[.<frac>]` shape expected when
    /// parsing a `Decimal`.
    InvalidFormat,
    /// The digit string could not be parsed as an `i64` mantissa.
    ParseInt(ParseIntError),
    /// Fixed-point arithmetic overflow.
    Overflow,
}
impl fmt::Display for DecimalError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            DecimalError::InvalidFormat => f.write_str("Invalid fixed point format"),
            DecimalError::ParseInt(err) => write!(f, "ParseIntError: {err}"),
            DecimalError::Overflow => f.write_str("Fixed point arithmetic overflow"),
        }
    }
}
impl Error for DecimalError {
    /// Exposes the wrapped [`ParseIntError`]; the other variants have no cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            DecimalError::ParseInt(err) => Some(err),
            DecimalError::InvalidFormat | DecimalError::Overflow => None,
        }
    }
}
impl From<ParseIntError> for DecimalError {
    /// Lets `?` lift integer-parse failures into [`DecimalError::ParseInt`].
    fn from(err: ParseIntError) -> DecimalError {
        DecimalError::ParseInt(err)
    }
}
/// Compile-time maximum of two expressions.
///
/// Used by the `const` assertions in `dec` that check an output precision
/// `O == max(L, R)`. Each argument may be evaluated twice, so pass only
/// side-effect-free expressions such as const generic parameters.
macro_rules! max {
    ($lhs:expr, $rhs:expr) => {
        if $lhs < $rhs { $rhs } else { $lhs }
    };
}
/// Module for fixed-point decimal arithmetic functions.
///
/// The `checked_*` functions return `None` when the result does not fit in an
/// `i64` mantissa (and, for division, on a zero divisor). The unchecked
/// variants panic in those cases.
mod dec {
    use super::Decimal;

    /// Returns `value * 10^digits`, or `None` if that does not fit in `i64`.
    ///
    /// Zero rescales to zero for any `digits`, even when `10^digits` itself
    /// would overflow.
    fn rescale(value: i64, digits: u32) -> Option<i64> {
        if value == 0 {
            return Some(0);
        }
        value.checked_mul(10_i64.checked_pow(digits)?)
    }

    /// Add L-digit & R-digit decimals, return O-digit sum.
    ///
    /// Requirement: `O = max(L, R)` (verified statically).
    ///
    /// # Panics
    ///
    /// Panics if the sum does not fit in an `i64` mantissa.
    pub fn add<const O: u32, const L: u32, const R: u32>(
        lhs: Decimal<L>,
        rhs: Decimal<R>,
    ) -> Decimal<O> {
        checked_add(lhs, rhs).expect("fixed::dec::add() overflow")
    }

    /// Checked variant of [`add`]; `None` on overflow.
    pub fn checked_add<const O: u32, const L: u32, const R: u32>(
        lhs: Decimal<L>,
        rhs: Decimal<R>,
    ) -> Option<Decimal<O>> {
        const { assert!(O == max!(L, R)) };
        let lhs = rescale(lhs.0, R.saturating_sub(L))?;
        let rhs = rescale(rhs.0, L.saturating_sub(R))?;
        lhs.checked_add(rhs).map(Decimal)
    }

    /// Subtract L-digit & R-digit decimals, return O-digit difference.
    ///
    /// Requirement: `O = max(L, R)` (verified statically).
    ///
    /// # Panics
    ///
    /// Panics if the difference does not fit in an `i64` mantissa.
    pub fn sub<const O: u32, const L: u32, const R: u32>(
        lhs: Decimal<L>,
        rhs: Decimal<R>,
    ) -> Decimal<O> {
        checked_sub(lhs, rhs).expect("fixed::dec::sub() overflow")
    }

    /// Checked variant of [`sub`]; `None` on overflow.
    pub fn checked_sub<const O: u32, const L: u32, const R: u32>(
        lhs: Decimal<L>,
        rhs: Decimal<R>,
    ) -> Option<Decimal<O>> {
        const { assert!(O == max!(L, R)) };
        let lhs = rescale(lhs.0, R.saturating_sub(L))?;
        let rhs = rescale(rhs.0, L.saturating_sub(R))?;
        lhs.checked_sub(rhs).map(Decimal)
    }

    /// Multiply L-digit & R-digit decimals, return O-digit product.
    ///
    /// Requirement: `O = L + R` (verified statically).
    ///
    /// # Panics
    ///
    /// Panics if the product does not fit in an `i64` mantissa.
    pub fn mul<const O: u32, const L: u32, const R: u32>(
        lhs: Decimal<L>,
        rhs: Decimal<R>,
    ) -> Decimal<O> {
        checked_mul(lhs, rhs).expect("fixed::dec::mul() overflow")
    }

    /// Checked variant of [`mul`]; `None` on overflow.
    ///
    /// With `O = L + R` the raw mantissa product is already in O-digit units.
    pub fn checked_mul<const O: u32, const L: u32, const R: u32>(
        lhs: Decimal<L>,
        rhs: Decimal<R>,
    ) -> Option<Decimal<O>> {
        const { assert!(O == L + R) };
        lhs.0.checked_mul(rhs.0).map(Decimal)
    }

    /// Divide L-digit & R-digit decimals, return O-digit quotient.
    ///
    /// Default rounding mode is [`RoundingMode::Down`](super::RoundingMode::Down),
    /// i.e., towards zero. Any output precision `O` is allowed.
    ///
    /// # Panics
    ///
    /// Panics if the quotient does not fit in an `i64` mantissa or `rhs` is zero.
    ///
    /// # Example (non-default rounding)
    ///
    /// `dec` is a private module, so this example is illustrative and is not
    /// compiled as a doctest.
    ///
    /// ```ignore
    /// let lhs: Decimal<2> = "1.49".parse().unwrap();
    /// let rhs: Decimal<0> = "2".parse().unwrap();
    /// let res: Decimal<3> = div(lhs, rhs);
    /// assert_eq!(res.to_string(), "0.745");
    /// assert_eq!(res.round::<2>(RoundingMode::Ceil).to_string(), "0.75");
    /// ```
    pub fn div<const O: u32, const L: u32, const R: u32>(
        lhs: Decimal<L>,
        rhs: Decimal<R>,
    ) -> Decimal<O> {
        checked_div(lhs, rhs).expect("fixed::dec::div() overflow or division by zero")
    }

    /// Checked variant of [`div`].
    ///
    /// Returns `None` if `rhs` is zero or the quotient does not fit in `i64`.
    pub fn checked_div<const O: u32, const L: u32, const R: u32>(
        lhs: Decimal<L>,
        rhs: Decimal<R>,
    ) -> Option<Decimal<O>> {
        let (num, den) = (i128::from(lhs.0), i128::from(rhs.0));
        if den == 0 {
            return None;
        }
        if num == 0 {
            return Some(Decimal(0));
        }
        // The O-digit quotient is trunc(num * 10^(O + R - L) / den). Applying
        // the net exponent to one side only keeps intermediates minimal.
        let exponent = i64::from(O) + i64::from(R) - i64::from(L);
        let power = u32::try_from(exponent.unsigned_abs())
            .ok()
            .and_then(|e| 10_i128.checked_pow(e));
        let quotient = if exponent >= 0 {
            // If num * 10^exponent overflows i128, then dividing by any
            // |den| < 2^63 still leaves a quotient beyond i64, so None is exact.
            num.checked_mul(power?)? / den
        } else {
            // A divisor that overflows i128 exceeds every i64 numerator, so
            // the truncated quotient is zero.
            power.and_then(|p| den.checked_mul(p)).map_or(0, |d| num / d)
        };
        i64::try_from(quotient).ok().map(Decimal)
    }
}
/// Rounding strategy used by `Decimal::round`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum RoundingMode {
    /// Towards zero (truncation). This is the default.
    #[default]
    Down,
    /// Away from zero.
    Up,
    /// To nearest; ties towards zero.
    HalfDown,
    /// To nearest; ties away from zero.
    HalfUp,
    /// Towards negative infinity.
    Floor,
    /// Towards positive infinity.
    Ceil,
}
impl<const N: u32> Decimal<N> {
/// Truncate N-digit decimal to M digits, rounding according to [`RoundingMode`].
pub fn round<const M: u32>(self, mode: RoundingMode) -> Decimal<M> {
const { assert!(M < N) };
let div = 10_i64.pow(N.saturating_sub(M));
let (val, rem) = (self.0 / div, (self.0 % div).unsigned_abs() as i64);
match mode {
_ if rem == 0 => val * 10_i64.pow(M.saturating_sub(N)),
RoundingMode::Up => val + self.0.signum(),
RoundingMode::HalfDown if rem > div / 2 => val + self.0.signum(),
RoundingMode::HalfUp if rem >= div / 2 => val + self.0.signum(),
RoundingMode::Floor if self.0.is_negative() => val - 1,
RoundingMode::Ceil if self.0.is_positive() => val + 1,
_ => val,
}
.into()
}
/// Truncate N-digit decimal to M digits (`M < N`).
pub fn truncate<const M: u32>(self) -> Decimal<M> {
const { assert!(M < N) };
(self.0 / 10_i64.pow(N - M)).into()
}
/// Extend N-digit decimal to M digits (`M > N`).
pub fn extend<const M: u32>(self) -> Decimal<M> {
self.checked_extend()
.expect("fixed::Decimal::extend() overflow")
}
pub fn checked_extend<const M: u32>(self) -> Option<Decimal<M>> {
const { assert!(M > N) };
Some(self.0.checked_mul(10_i64.pow(M - N))?.into())
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `Decimal` formatting/parsing, cross-precision comparison,
    // the `dec` arithmetic functions, and rounding/truncation/extension.
    use super::dec::*;
    use super::*;
    // Display output plus FromStr acceptance and rejection rules.
    #[test]
    fn decimal_fmt() {
        let x = Decimal::<5>(12345_67890_i64);
        assert_eq!(format!("{x}"), "12345.67890");
        assert_eq!("8765.4321".parse(), Ok(Decimal::<4>(8765_4321)));
        // A shorter fraction is right-padded with zeros to N digits.
        assert_eq!("8765.43".parse(), Ok(Decimal::<4>(8765_4300)));
        assert_eq!("8765".parse(), Ok(Decimal::<4>(8765_0000)));
        // Empty fraction, empty integer part and over-long fraction are rejected.
        assert!("8765.".parse::<Decimal<4>>().is_err());
        assert!(".8765".parse::<Decimal<4>>().is_err());
        assert!("8765.43210".parse::<Decimal<4>>().is_err());
        let x: Result<Decimal<0>, _> = "-42".parse();
        assert_eq!(x, Ok(Decimal(-42)));
        assert_eq!(x.unwrap().to_string(), "-42");
        // A leading '+' sign is accepted and not echoed by Display.
        let x: Result<Decimal<1>, _> = "+4.2".parse();
        assert_eq!(x, Ok(Decimal(42)));
        assert_eq!(x.unwrap().to_string(), "4.2");
        let x = Decimal::<4>(-1234_5678);
        assert_eq!(x, Decimal::<4>(-12345678));
        assert_eq!(x.to_string(), "-1234.5678");
        // Negative values with a zero integer part keep their sign.
        let x: Decimal<10> = "-0.0012345678".parse().unwrap();
        assert_eq!(x, Decimal::<10>(-12345678));
        assert_eq!(x.to_string(), "-0.0012345678");
    }
    // Equality and ordering compare numeric values across precisions.
    #[test]
    fn decimal_cmp() {
        // Same raw mantissa, different precision: 12.34 vs 0.1234.
        let x = Decimal::<2>(12_34);
        let y = Decimal::<4>(1234);
        assert_ne!(x, y);
        assert!(x > y);
        // 12.34 == 12.3400.
        let z = Decimal::<4>(12_3400);
        assert_eq!(x, z);
        assert_ne!(Decimal::<1>(-12345678), Decimal::<3>(-12345678));
    }
    // Addition, including i64 mantissa overflow detection.
    #[test]
    fn decimal_add() {
        let x = Decimal::<4>(1234_5678);
        let y = Decimal::<4>(8765_4321);
        let xaddy: Decimal<4> = add(x, y);
        assert_eq!(
            "1234.5678 + 8765.4321 = 9999.9999",
            format!("{x} + {y} = {xaddy}")
        );
        assert_eq!(xaddy, add::<4, _, _>(y, x));
        // `into()` takes raw mantissas: these are tiny 20-digit values.
        let x: Decimal<20> = (i64::MAX - 1).into();
        let (one, two): (Decimal<20>, Decimal<20>) = (1.into(), 2.into());
        assert_eq!(checked_add::<20, _, _>(x, one), Some(Decimal(i64::MAX)));
        assert_eq!(checked_add::<20, _, _>(x, two), None);
    }
    // Subtraction, including i64 mantissa underflow detection.
    #[test]
    fn decimal_sub() {
        let x = Decimal::<4>(9999_9999);
        let y = Decimal::<4>(1234_5678);
        let xsuby: Decimal<4> = sub(x, y);
        assert_eq!(
            "9999.9999 - 1234.5678 = 8765.4321",
            format!("{x} - {y} = {xsuby}")
        );
        let x: Decimal<20> = (i64::MIN + 1).into();
        let (one, two): (Decimal<20>, Decimal<20>) = (1.into(), 2.into());
        assert_eq!(checked_sub::<20, _, _>(x, one), Some(Decimal(i64::MIN)));
        assert_eq!(checked_sub::<20, _, _>(x, two), None);
    }
    // Multiplication: the product carries L + R fractional digits.
    #[test]
    fn decimal_mul() {
        let x = Decimal::<4>(12345678);
        let y = Decimal::<4>(100000001);
        let xmuly: Decimal<8> = mul(x, y);
        assert_eq!(
            "1234.5678 * 10000.0001 = 12345678.12345678",
            format!("{x} * {y} = {xmuly}")
        );
        let (x, y) = (Decimal::<0>(1234), Decimal::<4>(1));
        let xmuly: Decimal<4> = mul(x, y);
        assert_eq!("1234 * 0.0001 = 0.1234", format!("{x} * {y} = {xmuly}"));
        let (x, y) = (Decimal::<1>(1234), Decimal::<4>(1));
        let xmuly: Decimal<5> = mul(x, y);
        assert_eq!("123.4 * 0.0001 = 0.01234", format!("{x} * {y} = {xmuly}"));
        let (x, y) = (Decimal::<2>(1234), Decimal::<2>(10000));
        let xmuly: Decimal<4> = mul(x, y);
        assert_eq!("12.34 * 100.00 = 1234.0000", format!("{x} * {y} = {xmuly}"));
    }
    // Division into an arbitrary output precision, truncating towards zero.
    #[test]
    fn decimal_div() {
        let x: Decimal<2> = "1.49".parse().unwrap();
        let y: Decimal<0> = "2".parse().unwrap();
        let xdivy: Decimal<3> = div(x, y);
        assert_eq!(format!("{x} / {y} = {xdivy}"), "1.49 / 2 = 0.745");
        let x = Decimal::<4>(1234_5678);
        let y = Decimal::<2>(1);
        let xdivy: Decimal<4> = div(x, y);
        assert_eq!(x.to_string(), "1234.5678");
        assert_eq!(y.to_string(), "0.01");
        assert_eq!(xdivy.to_string(), "123456.7800");
        // Varying the divisor precision R and the output precision O.
        assert_eq!(div::<7, _, 2>(x, Decimal(1)).to_string(), "123456.7800000");
        assert_eq!(div::<4, _, 2>(x, Decimal(2)).to_string(), "61728.3900");
        assert_eq!(div::<4, _, 3>(x, Decimal(1000)).to_string(), "1234.5678");
        assert_eq!(div::<4, _, 3>(x, Decimal(100_000)).to_string(), "12.3456");
        assert_eq!(div::<4, _, 5>(x, Decimal(1)).to_string(), "123456780.0000");
        assert_eq!(div::<4, _, 5>(x, Decimal(12_34567)).to_string(), "100.0000");
        let x: Decimal<12> = "1.003000000001".parse().unwrap();
        let y: Decimal<4> = "1234.5678".parse().unwrap();
        let z: Decimal<4> = "12.5678".parse().unwrap();
        assert_eq!(div::<18, _, _>(x, y).to_string(), "0.000812430066620075");
        assert_eq!(div::<16, _, _>(x, z).to_string(), "0.0798071261478540");
        assert_eq!(div::<15, _, _>(x, z).to_string(), "0.079807126147854");
        // Output precision lower than the dividend's (O < L) truncates.
        let x: Decimal<10> = "0.1234567890".parse().unwrap();
        assert_eq!(div::<9, _, 0>(x, Decimal(1)).to_string(), "0.123456789");
        assert_eq!(div::<9, _, 0>(x, Decimal(-1)).to_string(), "-0.123456789");
        assert_eq!(
            div::<3, 2, 0>(Decimal(149), Decimal(2)).to_string(),
            "0.745"
        );
    }
    // Every rounding mode, for a positive and a negative tie-free/tie value.
    #[test]
    fn decimal_round() {
        // Expected results for rounding to 2, 1 and 0 fractional digits.
        let res: Decimal<3> = "0.745".parse().unwrap();
        for (mode, (a, b, c)) in [
            (RoundingMode::Down, ("0.74", "0.7", "0")),
            (RoundingMode::Up, ("0.75", "0.8", "1")),
            (RoundingMode::HalfDown, ("0.74", "0.7", "1")),
            (RoundingMode::HalfUp, ("0.75", "0.7", "1")),
            (RoundingMode::Floor, ("0.74", "0.7", "0")),
            (RoundingMode::Ceil, ("0.75", "0.8", "1")),
        ]
        .iter()
        {
            assert_eq!(*a, res.round::<2>(*mode).to_string());
            assert_eq!(*b, res.round::<1>(*mode).to_string());
            assert_eq!(*c, res.round::<0>(*mode).to_string());
        }
        // Mirror image: Floor/Ceil swap roles for negative values.
        let res: Decimal<3> = "-0.745".parse().unwrap();
        for (mode, (a, b, c)) in [
            (RoundingMode::Down, ("-0.74", "-0.7", "0")),
            (RoundingMode::Up, ("-0.75", "-0.8", "-1")),
            (RoundingMode::HalfDown, ("-0.74", "-0.7", "-1")),
            (RoundingMode::HalfUp, ("-0.75", "-0.7", "-1")),
            (RoundingMode::Floor, ("-0.75", "-0.8", "-1")),
            (RoundingMode::Ceil, ("-0.74", "-0.7", "0")),
        ]
        .iter()
        {
            assert_eq!(*a, res.round::<2>(*mode).to_string());
            assert_eq!(*b, res.round::<1>(*mode).to_string());
            assert_eq!(*c, res.round::<0>(*mode).to_string());
        }
    }
    // Truncation drops digits towards zero.
    #[test]
    fn decimal_truncate() {
        let x: Decimal<3> = "-0.745".parse().unwrap();
        let y: Decimal<2> = x.truncate();
        let z: Decimal<1> = x.truncate();
        assert_eq!(format!("{x} -> {y} -> {z}"), "-0.745 -> -0.74 -> -0.7");
    }
    // Extension appends zero digits; checked_extend reports overflow.
    #[test]
    fn decimal_extend() {
        let x: Decimal<3> = "-0.745".parse().unwrap();
        let y: Decimal<4> = x.extend();
        assert_eq!(y.to_string(), "-0.7450");
        // i64::MAX / 10 survives one extra digit but not two.
        let x: Decimal<5> = Decimal(i64::MAX / 10);
        let y: Option<Decimal<6>> = x.checked_extend();
        let z: Option<Decimal<7>> = x.checked_extend();
        assert!(y.is_some());
        assert!(z.is_none());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment