Created June 11, 2021 16:28
Save jaburns/4c4a597021ac884dab9f7ec65d06a270 to your computer and use it in GitHub Desktop.
Rust implementation of an arithmetic coder with some basic models
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Reference: https://github.com/rygorous/gaffer_net/blob/master/main.cpp

/// Number of fractional bits in a fixed-point probability.
const PROBABILITY_BITS: u32 = 15;

/// Fixed-point value representing a probability of 1.0, i.e. `2^PROBABILITY_BITS`.
pub const PROBABILITY_MAX: u32 = 2u32.pow(PROBABILITY_BITS);
struct BinaryArithCoder { | |
lo: u32, | |
hi: u32, | |
bytes: Vec<u8>, | |
} | |
impl BinaryArithCoder { | |
pub fn new() -> Self { | |
Self { | |
lo: 0, | |
hi: !0, | |
bytes: Vec::new(), | |
} | |
} | |
pub fn encode(&mut self, bit: bool, prob: u32) { | |
let lo64 = self.lo as u64; | |
let hi64 = self.hi as u64; | |
let prob64 = prob as u64; | |
let prob_bits64 = PROBABILITY_BITS as u64; | |
let x = self.lo + (((hi64 - lo64) * prob64) >> prob_bits64) as u32; | |
if bit { | |
self.hi = x; | |
} else { | |
self.lo = x + 1; | |
} | |
while (self.lo ^ self.hi) < (1u32 << 24) { | |
self.bytes.push((self.lo >> 24) as u8); | |
self.lo <<= 8; | |
self.hi = (self.hi << 8) | 0xff; | |
} | |
} | |
pub fn finalize(mut self) -> Vec<u8> { | |
let mut round_up: u32 = 0xffffffu32; | |
while round_up > 0 { | |
if (self.lo | round_up) != !0u32 { | |
let rounded: u32 = (self.lo + round_up) & !round_up; | |
if rounded <= self.hi { | |
self.lo = rounded; | |
break; | |
} | |
} | |
round_up >>= 8; | |
} | |
while self.lo > 0 { | |
self.bytes.push((self.lo >> 24) as u8); | |
self.lo <<= 8; | |
} | |
self.bytes | |
} | |
} | |
struct BinaryArithDecoder<'a> { | |
code: u32, | |
lo: u32, | |
hi: u32, | |
bytes: &'a [u8], | |
read_pos: usize, | |
} | |
impl<'a> BinaryArithDecoder<'a> { | |
pub fn new(bytes: &'a [u8]) -> Self { | |
let mut ret = Self { | |
lo: 0, | |
hi: !0, | |
code: 0, | |
bytes, | |
read_pos: 0, | |
}; | |
for _ in 0..4 { | |
ret.code = (ret.code << 8) | ret.get_byte() as u32; | |
} | |
ret | |
} | |
fn get_byte(&mut self) -> u8 { | |
if self.read_pos < self.bytes.len() { | |
let i = self.read_pos; | |
self.read_pos += 1; | |
self.bytes[i] | |
} else { | |
0 | |
} | |
} | |
pub fn decode(&mut self, prob: u32) -> bool { | |
let lo64 = self.lo as u64; | |
let hi64 = self.hi as u64; | |
let prob64 = prob as u64; | |
let prob_bits64 = PROBABILITY_BITS as u64; | |
let x = self.lo + (((hi64 - lo64) * prob64) >> prob_bits64) as u32; | |
let bit; | |
if self.code <= x { | |
self.hi = x; | |
bit = true; | |
} else { | |
self.lo = x + 1; | |
bit = false; | |
} | |
while (self.lo ^ self.hi) < (1u32 << 24) { | |
self.code = (self.code << 8) | self.get_byte() as u32; | |
self.lo <<= 8; | |
self.hi = (self.hi << 8) | 0xff; | |
} | |
bit | |
} | |
} | |
trait EncDec { | |
fn encode(&mut self, coder: &mut BinaryArithCoder, bit: bool); | |
fn decode(&mut self, coder: &mut BinaryArithDecoder) -> bool; | |
} | |
#[derive(Clone)] | |
struct TwoBinShiftModel<const INERTIA_0: u32, const INERTIA_1: u32> { | |
prob0: u32, | |
prob1: u32, | |
} | |
impl<const INERTIA_0: u32, const INERTIA_1: u32> Default | |
for TwoBinShiftModel<INERTIA_0, INERTIA_1> | |
{ | |
fn default() -> Self { | |
Self { | |
prob0: PROBABILITY_MAX / 4, | |
prob1: PROBABILITY_MAX / 4, | |
} | |
} | |
} | |
impl<const INERTIA_0: u32, const INERTIA_1: u32> TwoBinShiftModel<INERTIA_0, INERTIA_1> { | |
pub fn new() -> Self { | |
Self { | |
prob0: PROBABILITY_MAX / 4, | |
prob1: PROBABILITY_MAX / 4, | |
} | |
} | |
fn adapt(&mut self, bit: bool) { | |
if bit { | |
self.prob0 += (PROBABILITY_MAX / 2 - self.prob0) >> INERTIA_0; | |
self.prob1 += (PROBABILITY_MAX / 2 - self.prob1) >> INERTIA_1; | |
} else { | |
self.prob0 -= self.prob0 >> INERTIA_0; | |
self.prob1 -= self.prob1 >> INERTIA_1; | |
} | |
} | |
} | |
impl<const INERTIA_0: u32, const INERTIA_1: u32> EncDec for TwoBinShiftModel<INERTIA_0, INERTIA_1> { | |
fn encode(&mut self, coder: &mut BinaryArithCoder, bit: bool) { | |
coder.encode(bit, self.prob0 + self.prob1); | |
self.adapt(bit); | |
} | |
fn decode(&mut self, coder: &mut BinaryArithDecoder) -> bool { | |
let bit = coder.decode(self.prob0 + self.prob1); | |
self.adapt(bit); | |
bit | |
} | |
} | |
struct BitTreeModel<MODEL, const NUM_BITS: usize> { | |
model: Vec<MODEL>, | |
} | |
impl<MODEL: Clone + Default + EncDec, const NUM_BITS: usize> BitTreeModel<MODEL, NUM_BITS> { | |
const NUM_SYMS: usize = 1 << NUM_BITS; | |
const MSB: usize = 1 << (NUM_BITS - 1); | |
pub fn new() -> Self { | |
Self { | |
model: vec![MODEL::default(); Self::NUM_SYMS - 1], | |
} | |
} | |
pub fn encode(&mut self, coder: &mut BinaryArithCoder, mut value: usize) { | |
std::assert!(value < Self::NUM_SYMS); | |
let mut ctx: usize = 1; | |
while ctx < Self::NUM_SYMS { | |
let bit = (value & Self::MSB) != 0; | |
value += value; | |
self.model[ctx - 1].encode(coder, bit); | |
ctx = ctx + ctx + bit as usize; | |
} | |
} | |
pub fn decode(&mut self, coder: &mut BinaryArithDecoder) -> usize { | |
let mut ctx: usize = 1; | |
while ctx < Self::NUM_SYMS { | |
ctx = ctx + ctx + self.model[ctx - 1].decode(coder) as usize; | |
} | |
ctx - Self::NUM_SYMS | |
} | |
} | |
fn test_encode() { | |
let bytes = std::fs::read("test.txt").unwrap(); | |
let mut coder = BinaryArithCoder::new(); | |
let mut model = BitTreeModel::<TwoBinShiftModel<3, 7>, 8>::new(); | |
for byte in &bytes { | |
model.encode(&mut coder, *byte as usize); | |
} | |
let mut out = coder.finalize(); | |
let size_bytes = unsafe { std::mem::transmute::<u32, [u8; 4]>(bytes.len() as u32) }; | |
out.splice(0..0, size_bytes.iter().cloned()); | |
std::fs::write("test.out", out).unwrap(); | |
} | |
fn test_decode() { | |
let bytes = std::fs::read("test.out").unwrap(); | |
let mut bit_count = 8 * unsafe { | |
std::mem::transmute::<[u8; 4], u32>([bytes[0], bytes[1], bytes[2], bytes[3]]) | |
}; | |
let mut coder = BinaryArithDecoder::new(&bytes[4..]); | |
let mut model = BitTreeModel::<TwoBinShiftModel<3, 7>, 8>::new(); | |
let mut out_bytes = Vec::<u8>::new(); | |
while bit_count > 0 { | |
out_bytes.push(model.decode(&mut coder) as u8); | |
bit_count -= 8; | |
} | |
std::fs::write("test.out.txt", out_bytes).unwrap(); | |
} | |
fn main() { | |
test_encode(); | |
test_decode(); | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.