Created
October 29, 2024 14:15
-
-
Save pgp/1877ea82d472424c2055d85b1e371f33 to your computer and use it in GitHub Desktop.
Python implementation of Java BigDecimal
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import re | |
from numbers import Number | |
""" | |
Equivalent of Java's BigDecimal, since python's decimal.Decimal works differently - maximum precision is fixed and set globally, | |
and division and usage of (truncated) non-finite-representation numbers is allowed | |
(hence, arbitrary precision multiplication is not always lossless when using decimal.Decimal) | |
""" | |
class BigDecimal(Number): | |
def __hash__(self): | |
n1 = self.normalize() | |
return (n1.sign, n1.mantissa, n1.scale).__hash__() | |
# https://regex101.com/r/yP5nX5/1 | |
DECIMAL_PATTERN = re.compile(r"^[-+]?\d*[.]?\d*$") | |
@classmethod | |
def of(cls, s: str|int): | |
if isinstance(s, int): | |
i = s | |
return BigDecimal(abs(i), 0, 0 if i == 0 else 1 if i > 0 else -1) | |
assert isinstance(s, str), f'Unsupported input type. Expected: str|int, found: {type(s)}' | |
assert cls.DECIMAL_PATTERN.match(s), f'Not a decimal number: {s}' | |
sign = -1 if s[0] == '-' else +1 | |
s_abs = s[1:] if s[0] == '-' else s.lstrip('+') | |
idx = s_abs.find('.') | |
if idx < 0: # decimal point not found, just an exact integer | |
mantissa = int(s_abs) | |
scale = 0 | |
else: | |
s_mantissa = s_abs[:idx] + s_abs[idx+1:] | |
mantissa = int(s_mantissa) | |
scale = idx - len(s_mantissa) # <= 0 (for > 0, just store the integer itself with trailing zeros) | |
if mantissa == 0: | |
sign = 0 | |
return BigDecimal(mantissa, scale, sign) | |
# removes trailing zeros after the decimal separator | |
def normalize(self, inplace=False): | |
# E.g. 1.2300 (mantissa=12300, scale=-4) -> 1.23 (mantissa=123, scale=-2) | |
m = self.mantissa | |
sc = self.scale | |
while m > 0 and m % 10 == 0: | |
m //= 10 | |
sc += 1 | |
if m == 0: | |
sc = 0 # scale is not relevant if the number is 0 | |
if inplace: | |
self.mantissa = m | |
self.scale = sc | |
return self | |
else: | |
return BigDecimal(m, sc, self.sign) | |
def __init__(self, mantissa: int, scale: int, sign: int): | |
assert isinstance(mantissa, int) and mantissa >= 0, 'Mantissa must be a non-negative integer' | |
assert scale <= 0, 'Scale must be negative or 0' | |
assert sign in {-1,0,1}, 'Allowed values for sign are -1,0,+1' | |
assert (sign == 0 and mantissa == 0) or (sign != 0 and mantissa != 0), 'Incoherent sign w.r.t. mantissa' | |
self.mantissa = mantissa | |
self.scale = scale | |
self.sign = sign | |
def __repr__(self): | |
return f'(mantissa={self.mantissa}; scale={self.scale}; sign={self.sign}) {self.__str__()}' | |
def __str__(self): | |
sign_prefix = '' if self.sign == 0 else ('+' if self.sign >= 0 else '-') | |
sc = -self.scale | |
tmp = str(self.mantissa).zfill(sc) | |
if sc != 0: | |
tmp = tmp[:-sc] + '.' + tmp[-sc:] | |
return sign_prefix + (f'0{tmp}' if tmp[0] == '.' else tmp) | |
def __eq__(self, other): | |
if not isinstance(other, BigDecimal): | |
return False | |
n1, n2 = self.normalize(), other.normalize() | |
return (n1.mantissa, n1.scale, n1.sign) == (n2.mantissa, n2.scale, n2.sign) | |
def __cmp__(self, other): | |
if isinstance(other, int): | |
other = BigDecimal.ofInt(other) | |
else: | |
assert isinstance(other, BigDecimal), f'Only BigDecimal and int instances allowed, supplied: {type(other)}' | |
sc, osc = -self.scale, -other.scale | |
absmaxscale = max(sc, osc) | |
i1 = self.sign * self.mantissa * (10 ** (absmaxscale - sc)) | |
i2 = other.sign * other.mantissa * (10 ** (absmaxscale - osc)) | |
return i1 - i2 | |
def __ne__(self, other): | |
return self.__cmp__(other) != 0 | |
def __gt__(self, other): # > | |
return self.__cmp__(other) > 0 | |
def __ge__(self, other): # >= | |
return self.__cmp__(other) >= 0 | |
def __lt__(self, other): # < | |
return self.__cmp__(other) < 0 | |
def __le__(self, other): # <= | |
return self.__cmp__(other) <= 0 | |
def __neg__(self): | |
return BigDecimal(self.mantissa, self.scale, -1*self.sign) | |
def __pos__(self): | |
return BigDecimal(self.mantissa, self.scale, self.sign) | |
def __add__(self, other): | |
if isinstance(other, int): | |
other = BigDecimal.ofInt(other) | |
else: | |
assert isinstance(other, BigDecimal), f'Only BigDecimal and int instances allowed, supplied: {type(other)}' | |
sc, osc = -self.scale, -other.scale | |
absmaxscale = max(sc, osc) | |
# - denormalize both numbers w.r.t. the higher scale | |
i1 = self.sign * self.mantissa * (10**(absmaxscale-sc)) | |
i2 = other.sign * other.mantissa * (10**(absmaxscale-osc)) | |
isum = i1+i2 | |
isum_sign = 0 if isum == 0 else (+1 if isum > 0 else -1) | |
return BigDecimal(abs(isum), -absmaxscale, isum_sign) | |
def __iadd__(self, other): | |
return self.__add__(other) | |
def __sub__(self, other): | |
return self.__add__(-other) | |
def __isub__(self, other): | |
return self.__sub__(other) | |
def __mul__(self, other): | |
if isinstance(other, int): | |
other = BigDecimal.ofInt(other) | |
else: | |
assert isinstance(other, BigDecimal), f'Only BigDecimal and int instances allowed, supplied: {type(other)}' | |
msign = self.sign * other.sign | |
mmantissa = self.mantissa * other.mantissa | |
mscale = self.scale + other.scale | |
return BigDecimal(mmantissa, mscale, msign) | |
def __imul__(self, other): | |
return self.__mul__(other) | |
def __truediv__(self, other): | |
if isinstance(other, int): | |
other = BigDecimal.ofInt(other) | |
else: | |
assert isinstance(other, BigDecimal), f'Only BigDecimal and int instances allowed, supplied: {type(other)}' | |
i1, i2 = self.mantissa, other.mantissa | |
g = math.gcd(i1,i2) | |
i1_g, i2_g = i1 // g, i2 // g | |
acc_2,acc_5 = 0,0 | |
x = i2_g | |
while x > 1: | |
if x % 2 == 0: | |
acc_2 += 1 | |
x //= 2 | |
elif x % 5 == 0: | |
acc_5 += 1 | |
x //= 5 | |
else: | |
break | |
if x != 1: | |
raise RuntimeError('Non terminating decimal expansion detected') | |
shift_10s = max(acc_2, acc_5) | |
qq,rr = divmod(i1_g * (10**shift_10s) , i2_g) | |
assert rr == 0, 'Guard block' | |
qsign = self.sign // other.sign | |
qmantissa = qq | |
qscale = self.scale - other.scale - shift_10s | |
if qscale > 0: | |
qmantissa *= 10**qscale | |
qscale = 0 | |
return BigDecimal(qmantissa, qscale, qsign) | |
def __int__(self): | |
return self.sign * (self.mantissa // (10**(-self.scale))) | |
def __float__(self): | |
return self.sign * (self.mantissa / (10**(-self.scale))) | |
# Commonly used constants, attached after the class body because the class
# name is not bound yet while its body is still executing.
BigDecimal.ZERO, BigDecimal.ONE, BigDecimal.TWO, BigDecimal.TEN = (
    BigDecimal(0, 0, 0),
    BigDecimal(1, 0, 1),
    BigDecimal(2, 0, 1),
    BigDecimal(10, 0, 1),
)

# Short factory alias: bd('1.25') or bd(3) is the same as BigDecimal.of(...).
bd = BigDecimal.of
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment