/*
 * deflate.c - DEFLATE (RFC 1951) compressor (fixed Huffman) and decompressor.
 *
 * Originally published as GitHub gist 7etsuo/e4726d8f56dd03bfa38ddf450d3eda4f
 * by @7etsuo, created January 9, 2026.
 */
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define MAX_BITS 15       /* Longest Huffman code length DEFLATE allows */
#define MAX_LIT_CODES 288 /* Literal/length alphabet */
#define MAX_DIST_CODES 32 /* Distance alphabet */

/* Length symbols 257..285 (index = symbol - 257): base match length and the
   number of extra bits that follow the symbol.  Together they span 3..258. */
static const uint16_t length_base[29] = {
  3,   4,   5,   6,   7,   8,   9,   10,  /* 257-264 */
  11,  13,  15,  17,                      /* 265-268 */
  19,  23,  27,  31,                      /* 269-272 */
  35,  43,  51,  59,                      /* 273-276 */
  67,  83,  99,  115,                     /* 277-280 */
  131, 163, 195, 227,                     /* 281-284 */
  258                                     /* 285 */
};
static const uint8_t length_extra[29] = {
  0, 0, 0, 0, 0, 0, 0, 0, /* 257-264 */
  1, 1, 1, 1,             /* 265-268 */
  2, 2, 2, 2,             /* 269-272 */
  3, 3, 3, 3,             /* 273-276 */
  4, 4, 4, 4,             /* 277-280 */
  5, 5, 5, 5,             /* 281-284 */
  0                       /* 285 */
};

/* Distance symbols 0..29: base distance and extra-bit count, covering
   distances 1..32768. */
static const uint16_t dist_base[30] = {
  1,    2,    3,    4,    5,    7,     9,     13,    17,    25,
  33,   49,   65,   97,   129,  193,   257,   385,   513,   769,
  1025, 1537, 2049, 3073, 4097, 6145,  8193,  12289, 16385, 24577
};
static const uint8_t dist_extra[30] = {
  0, 0, 0,  0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,
  6, 7, 7,  8,  8,  9,  9,  10, 10, 11, 11, 12, 12, 13, 13
};
/* Decoding lookup table built by huffman_build().  Each entry of `symbols`
   packs the symbol in the low 12 bits and its code length in the high 4;
   unused slots hold 0xFFFF. */
typedef struct
{
  uint16_t *symbols; /* 1 << max_bits packed entries, indexed by reversed code */
  int *counts;       /* counts[len] = number of codes of length len */
  int max_bits;      /* Longest code length in the table (0 = empty table) */
} HuffmanTable;

/* LSB-first bit cursor over an input buffer. */
typedef struct
{
  const uint8_t *data; /* Input bytes */
  size_t size;         /* Number of input bytes */
  size_t byte_pos;     /* Index of the byte currently being read */
  int bit_pos;         /* Next bit within data[byte_pos], 0-7 */
} BitReader;

#define WINDOW_SIZE 32768 /* 32 KiB sliding window */
#define HASH_BITS 15
#define HASH_SIZE (1 << HASH_BITS)
#define MIN_MATCH 3
#define MAX_MATCH 258
#define MAX_DIST WINDOW_SIZE
#define MAX_CHAIN 256 /* Cap on hash-chain steps per search (speed) */

/* Growable decompression output buffer. */
typedef struct
{
  uint8_t *output;        /* Heap buffer */
  size_t output_size;     /* Bytes written so far */
  size_t output_capacity; /* Bytes allocated */
} Inflater;

/* Hash-chain index used by the compressor to locate earlier matches. */
typedef struct
{
  int head[HASH_SIZE];   /* Latest position for each 3-byte hash, or -1 */
  int prev[WINDOW_SIZE]; /* Previous position with the same hash, or -1 */
  const uint8_t *data;   /* Data being compressed */
  size_t size;           /* Length of data */
} MatchFinder;

/* Node of the Huffman tree built while computing code lengths. */
typedef struct
{
  uint32_t freq;
  int16_t symbol; /* Leaf symbol, or -1 for an internal node */
  int16_t left;   /* Child node indices, -1 when absent */
  int16_t right;
} TreeNode;

/* (symbol, value) pair used for sorting; value is a frequency or a length. */
typedef struct
{
  int symbol;
  uint32_t freq;
} SymFreq;

/* LSB-first bit accumulator that appends to a growable byte buffer. */
typedef struct
{
  uint8_t *data;       /* Heap buffer */
  size_t capacity;     /* Bytes allocated */
  size_t byte_pos;     /* Bytes written */
  int bit_pos;         /* Number of pending bits held in bit_buffer */
  uint32_t bit_buffer; /* Pending bits, lowest bit first */
} BitWriter;
/* ---- Bit reader ---- */
void bitreader_init (BitReader *br, const uint8_t *data, size_t size);
uint32_t bitreader_read (BitReader *br, int n);
void bitreader_align (BitReader *br);
void bitreader_read_bytes (BitReader *br, uint8_t *dest, size_t count);

/* ---- Bit writer ---- */
void bitwriter_init (BitWriter *bw, size_t initial_capacity);
void bitwriter_write (BitWriter *bw, uint32_t value, int n);
void bitwriter_write_huffman (BitWriter *bw, uint16_t code, int length);
void bitwriter_flush (BitWriter *bw);
void bitwriter_free (BitWriter *bw);

/* ---- Huffman decoding tables ---- */
int huffman_build (HuffmanTable *ht, const uint8_t *lengths, int num_symbols);
int initialize_huffman_table (HuffmanTable *ht, int max_bits, int *bl_count);
uint16_t huffman_decode (HuffmanTable *ht, BitReader *br);
void huffman_free (HuffmanTable *ht);
void build_fixed_literal_table (HuffmanTable *ht);
void build_fixed_distance_table (HuffmanTable *ht);

/* ---- Length / distance symbol decoding ---- */
uint16_t decode_length (uint16_t symbol, BitReader *br);
uint16_t decode_distance (uint16_t symbol, BitReader *br);

/* ---- Dynamic Huffman block header ---- */
int decode_dynamic_tables (BitReader *br, HuffmanTable *lit_ht,
                           HuffmanTable *dist_ht);

/* ---- Decompression output buffer ---- */
int inflater_init (Inflater *inf, size_t initial_capacity);
int inflater_grow (Inflater *inf, size_t needed);
void inflater_emit (Inflater *inf, uint8_t byte);
void inflater_copy (Inflater *inf, uint16_t distance, uint16_t length);
void inflater_free (Inflater *inf);

/* ---- Decompression ---- */
int inflate_block (BitReader *br, Inflater *inf, int btype);
int inflate (const uint8_t *input, size_t input_size, uint8_t **output,
             size_t *output_size);

/* ---- Match finding (compression) ---- */
void matchfinder_init (MatchFinder *mf, const uint8_t *data, size_t size);
int matchfinder_find (MatchFinder *mf, size_t pos, uint16_t *match_dist,
                      uint16_t *match_len);
void matchfinder_insert (MatchFinder *mf, size_t pos);

/* ---- Huffman code construction (compression) ---- */
void build_code_lengths (const uint32_t *freqs, uint8_t *lengths,
                         int num_symbols, int max_bits);
void generate_codes (const uint8_t *lengths, uint16_t *codes, int num_symbols);

/* ---- Compression ---- */
int deflate_fixed (const uint8_t *input, size_t input_size, uint8_t **output,
                   size_t *output_size);
/* Build the code-length histogram for a table.
 * For every symbol with a nonzero length, bl_count[len] is incremented and
 * *max_bits is raised to the longest length seen (the caller initializes
 * both).  Returns 1 when at least one code exists, 0 for an empty table. */
static int
count_code_lengths (const uint8_t *lengths, int num_symbols, int *bl_count,
                    int *max_bits)
{
  for (int sym = num_symbols - 1; sym >= 0; sym--)
    {
      int len = lengths[sym];
      if (len == 0)
        continue; /* symbol unused */
      bl_count[len]++;
      if (*max_bits < len)
        *max_bits = len;
    }
  return *max_bits != 0;
}
/* Compute the first canonical code of each length (RFC 1951 3.2.2 step 2).
 * bl_count[0] is forced to 0 so unused symbols do not shift the codes;
 * next_code[1..max_bits] receives the starting code for each length. */
static void
calculate_next_codes (int *bl_count, int *next_code, int max_bits)
{
  bl_count[0] = 0;
  int code = 0;
  int bits = 1;
  while (bits <= max_bits)
    {
      code += bl_count[bits - 1];
      code <<= 1;
      next_code[bits] = code;
      bits++;
    }
}
/* Fill the decode lookup table with every symbol's packed entry.
 * Each symbol with a nonzero length takes the next canonical code of that
 * length (next_code[len] is advanced).  DEFLATE stores Huffman codes
 * MSB-first inside an LSB-first bit stream, so the table is indexed by the
 * bit-reversed code.  Every slot whose low `len` bits equal that reversed
 * code (i.e. every possible trailing-bit suffix) receives
 * symbol | (len << 12). */
static void
populate_huffman_table (const uint8_t *lengths, int *next_code,
                        int num_symbols, int table_size, HuffmanTable *ht)
{
  for (int sym = 0; sym < num_symbols; sym++)
    {
      int len = lengths[sym];
      if (len == 0)
        continue;

      int code = next_code[len]++;

      /* Reverse the low `len` bits of the code. */
      int rev = 0;
      for (int b = len; b > 0; b--, code >>= 1)
        rev = (rev << 1) | (code & 1);

      uint16_t packed = (uint16_t) (sym | (len << 12));
      for (int slot = rev; slot < table_size; slot += 1 << len)
        ht->symbols[slot] = packed;
    }
}
/* Reset a table to the empty state (no allocations, max_bits 0).
 * Does not free anything; always returns 0. */
static int
clear_huffman_table (HuffmanTable *ht)
{
  *ht = (HuffmanTable){ .symbols = NULL, .counts = NULL, .max_bits = 0 };
  return 0;
}
/* Allocate the lookup and count arrays for a table of `max_bits` depth.
 * The lookup array has 1 << max_bits entries, all set to 0xFFFF (meaning
 * "no code"); counts receives a copy of bl_count[0..max_bits].
 * Returns 0 on success.  On allocation failure, releases whatever was
 * allocated, leaves ht empty (both pointers NULL, max_bits 0) and returns -1,
 * so callers never see a half-built table. */
int
initialize_huffman_table (HuffmanTable *ht, int max_bits, int *bl_count)
{
  size_t table_size = (size_t) 1 << max_bits;
  size_t count_size = (size_t) (max_bits + 1);

  ht->symbols = malloc (table_size * sizeof *ht->symbols);
  ht->counts = malloc (count_size * sizeof *ht->counts);
  if (!ht->symbols || !ht->counts)
    {
      /* Previously one successful allocation leaked here. */
      free (ht->symbols);
      free (ht->counts);
      ht->symbols = NULL;
      ht->counts = NULL;
      ht->max_bits = 0;
      return -1;
    }
  memset (ht->symbols, 0xFF, table_size * sizeof *ht->symbols);
  memcpy (ht->counts, bl_count, count_size * sizeof *ht->counts);
  ht->max_bits = max_bits;
  return 0;
}
/* Build a decode table from per-symbol code lengths (0 = unused).
 * An all-zero length set yields an empty table (pointers NULL, max_bits 0)
 * and still returns 0.  Returns -1 if allocation fails. */
int
huffman_build (HuffmanTable *ht, const uint8_t *lengths, int num_symbols)
{
  int bl_count[MAX_BITS + 1] = { 0 };
  int next_code[MAX_BITS + 1];
  int longest = 0;

  if (count_code_lengths (lengths, num_symbols, bl_count, &longest) == 0)
    {
      clear_huffman_table (ht);
      return 0;
    }

  calculate_next_codes (bl_count, next_code, longest);

  int rc = initialize_huffman_table (ht, longest, bl_count);
  if (rc != 0)
    return -1;

  populate_huffman_table (lengths, next_code, num_symbols, 1 << longest, ht);
  return 0;
}
/* Decode one symbol from `br` using table `ht`.
 *
 * Peeks up to ht->max_bits bits (LSB-first) directly from the input.
 * Bits past the end of the input are treated as zero, so a short code in
 * the last byte(s) of the stream decodes without a spurious error.  Only the
 * decoded code's own length is then consumed.
 *
 * Returns the symbol (0..4094).  Returns 0xFFFF, after printing a message,
 * when:
 *   - the table is empty,
 *   - the bits match no code, or
 *   - the matched code extends past the end of input.
 * Callers already reject any symbol outside their alphabet. */
uint16_t
huffman_decode (HuffmanTable *ht, BitReader *br)
{
  /* An all-zero length set produces a table with no lookup array. */
  if (ht->symbols == NULL || ht->max_bits <= 0)
    {
      fprintf (stderr, "Error: Decoding with an empty Huffman table\n");
      return 0xFFFF;
    }

  /* Peek without consuming, counting how many real input bits were seen. */
  uint32_t peek = 0;
  int avail = 0;
  for (int i = 0; i < ht->max_bits; i++)
    {
      size_t bit = (size_t) br->bit_pos + (size_t) i;
      size_t byte = br->byte_pos + bit / 8;
      if (byte >= br->size)
        break;
      peek |= (uint32_t) ((br->data[byte] >> (bit % 8)) & 1u) << i;
      avail++;
    }

  /* peek < 1 << max_bits, so the index is always in range. */
  uint16_t entry = ht->symbols[peek];
  if (entry == 0xFFFF)
    {
      fprintf (stderr, "Error: Invalid Huffman code\n");
      return 0xFFFF;
    }

  int code_len = entry >> 12;
  if (code_len > avail)
    {
      fprintf (stderr, "Error: Unexpected end of input\n");
      return 0xFFFF;
    }

  /* Consume exactly the matched code. */
  size_t consumed = (size_t) br->bit_pos + (size_t) code_len;
  br->byte_pos += consumed / 8;
  br->bit_pos = (int) (consumed % 8);
  return entry & 0x0FFF;
}
/* Release a table's arrays and leave both pointers NULL (safe to call on an
 * empty or already-freed table). */
void
huffman_free (HuffmanTable *ht)
{
  uint16_t *symbols = ht->symbols;
  int *counts = ht->counts;
  ht->symbols = NULL;
  ht->counts = NULL;
  free (symbols);
  free (counts);
}
/* Build the fixed literal/length decode table of RFC 1951 3.2.6:
 *   symbols   0-143 -> 8 bits
 *   symbols 144-255 -> 9 bits
 *   symbols 256-279 -> 7 bits
 *   symbols 280-287 -> 8 bits
 * The huffman_build result is not checked here. */
void
build_fixed_literal_table (HuffmanTable *ht)
{
  uint8_t lengths[MAX_LIT_CODES];
  for (int sym = 0; sym < MAX_LIT_CODES; sym++)
    {
      if (sym < 144)
        lengths[sym] = 8;
      else if (sym < 256)
        lengths[sym] = 9;
      else if (sym < 280)
        lengths[sym] = 7;
      else
        lengths[sym] = 8;
    }
  huffman_build (ht, lengths, MAX_LIT_CODES);
}
/* Build the fixed distance decode table: all 32 distance symbols use 5-bit
 * codes (RFC 1951 3.2.6).  The huffman_build result is not checked here. */
void
build_fixed_distance_table (HuffmanTable *ht)
{
  uint8_t lengths[MAX_DIST_CODES];
  memset (lengths, 5, sizeof lengths);
  huffman_build (ht, lengths, MAX_DIST_CODES);
}
/* Decode a match length from a length symbol (257-285) plus its extra bits.
 * Symbols outside 257-285 (286/287 exist in the fixed code but are invalid)
 * would index past length_base/length_extra; for those this returns 0
 * without reading any bits.  A zero length copies nothing. */
uint16_t
decode_length (uint16_t symbol, BitReader *br)
{
  const size_t table_len = sizeof length_base / sizeof length_base[0];
  if (symbol < 257 || (size_t) (symbol - 257) >= table_len)
    return 0;

  int index = symbol - 257;
  uint16_t length = length_base[index];
  int extra = length_extra[index];
  if (extra > 0)
    length += bitreader_read (br, extra);
  return length;
}
/* Decode a match distance from a distance symbol (0-29) plus its extra bits.
 * Only symbols 0-29 are defined.  Symbols 30/31, which the fixed code can
 * produce, and error markers such as 0xFFFF would index past
 * dist_base/dist_extra.  For any symbol >= 30 this returns 0 (never a valid
 * distance) without reading bits. */
uint16_t
decode_distance (uint16_t symbol, BitReader *br)
{
  if (symbol >= sizeof dist_base / sizeof dist_base[0])
    return 0;

  uint16_t distance = dist_base[symbol];
  int extra = dist_extra[symbol];
  if (extra > 0)
    distance += bitreader_read (br, extra);
  return distance;
}
/* Permutation in which the 19 code-length-code lengths appear in a dynamic
   block header (RFC 1951 3.2.7). */
static const int codelen_order[19] = {
  16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
};
/* Read HCLEN (4 bits, +4) and that many 3-bit code-length-code lengths,
 * placed via codelen_order; unlisted entries stay 0.  Builds the resulting
 * 19-symbol decode table and returns huffman_build's result. */
static int
build_codelen_table (BitReader *br, HuffmanTable *codelen_ht)
{
  uint8_t cl_lengths[19];
  memset (cl_lengths, 0, sizeof cl_lengths);

  int remaining = (int) bitreader_read (br, 4) + 4;
  const int *slot = codelen_order;
  while (remaining-- > 0)
    cl_lengths[*slot++] = (uint8_t) bitreader_read (br, 3);

  return huffman_build (codelen_ht, cl_lengths, 19);
}
/* Write `value` into lengths[*pos ...] `count` times, stopping at `total`.
 * Advances *pos by the number of entries written; non-positive counts and
 * a full buffer write nothing. */
static void
fill_repeat (uint8_t *lengths, int *pos, int total, int count, uint8_t value)
{
  int room = total - *pos;
  int n = (count < room) ? count : room;
  if (n <= 0)
    return;
  memset (lengths + *pos, value, (size_t) n);
  *pos += n;
}
/* Decode one code-length-alphabet symbol (RFC 1951 3.2.7).  The length(s)
 * it produces are stored at lengths[*pos] and *pos is advanced:
 *   0-15 : a literal code length
 *   16   : repeat the previous length 3-6 times (2 extra bits)
 *   17   : repeat zero 3-10 times (3 extra bits)
 *   18   : repeat zero 11-138 times (7 extra bits)
 * Requires *pos < total.  Returns 0 on success.  Returns -1 when:
 *   - the symbol is invalid,
 *   - code 16 has no previous length to repeat, or
 *   - a repeat would run past `total`.
 * RFC 1951 treats all three as malformed input, so they are rejected
 * instead of being silently patched up. */
static int
decode_codelen_symbol (HuffmanTable *codelen_ht, BitReader *br,
                       uint8_t *lengths, int *pos, int total)
{
  uint16_t sym = huffman_decode (codelen_ht, br);
  if (sym < 16)
    {
      lengths[(*pos)++] = (uint8_t) sym;
      return 0;
    }

  int count;
  uint8_t value = 0;
  switch (sym)
    {
    case 16:
      if (*pos == 0)
        return -1; /* nothing to repeat */
      value = lengths[*pos - 1];
      count = (int) bitreader_read (br, 2) + 3;
      break;
    case 17:
      count = (int) bitreader_read (br, 3) + 3;
      break;
    case 18:
      count = (int) bitreader_read (br, 7) + 11;
      break;
    default:
      return -1;
    }

  if (count > total - *pos)
    return -1; /* repeat overruns the HLIT + HDIST lengths */
  fill_repeat (lengths, pos, total, count, value);
  return 0;
}
/* Decode `total` code lengths (literal/length followed by distance) with
 * the code-length table.  Prints a message and returns -1 as soon as one
 * symbol fails to decode; returns 0 once all lengths are filled. */
static int
decode_all_lengths (HuffmanTable *codelen_ht, BitReader *br, uint8_t *lengths,
                    int total)
{
  for (int filled = 0; filled < total;)
    {
      int rc = decode_codelen_symbol (codelen_ht, br, lengths, &filled, total);
      if (rc < 0)
        {
          fprintf (stderr, "Invalid code length symbol\n");
          return -1;
        }
    }
  return 0;
}
/* Build the literal/length table from lengths[0..hlit) and the distance
 * table from lengths[hlit..hlit+hdist).  If the distance table fails, the
 * literal table is freed again.  Returns 0 on success, -1 on failure. */
static int
build_lit_dist_tables (uint8_t *lengths, int hlit, int hdist,
                       HuffmanTable *lit_ht, HuffmanTable *dist_ht)
{
  int rc = huffman_build (lit_ht, lengths, hlit);
  if (rc >= 0)
    {
      rc = huffman_build (dist_ht, lengths + hlit, hdist);
      if (rc < 0)
        huffman_free (lit_ht);
    }
  return (rc < 0) ? -1 : 0;
}
/* Parse a dynamic-Huffman block header (RFC 1951 3.2.7) and build the
 * literal/length and distance decode tables.
 *
 * Rejects headers that are malformed per the RFC:
 *   - HLIT > 286 or HDIST > 30;
 *   - an empty code-length code (decoding with it would dereference a
 *     NULL table);
 *   - a literal/length code with no end-of-block symbol (256).
 *
 * Returns 0 on success and -1 on error.  On error nothing is left
 * allocated in lit_ht or dist_ht. */
int
decode_dynamic_tables (BitReader *br, HuffmanTable *lit_ht,
                       HuffmanTable *dist_ht)
{
  int hlit = (int) bitreader_read (br, 5) + 257;
  int hdist = (int) bitreader_read (br, 5) + 1;
  if (hlit > 286 || hdist > 30)
    {
      fprintf (stderr, "Invalid dynamic header: HLIT=%d HDIST=%d\n", hlit,
               hdist);
      return -1;
    }

  HuffmanTable codelen_ht = { 0 };
  if (build_codelen_table (br, &codelen_ht) < 0)
    {
      huffman_free (&codelen_ht); /* release any partial allocation */
      return -1;
    }
  if (codelen_ht.symbols == NULL)
    {
      fprintf (stderr, "Empty code length code\n");
      return -1;
    }

  uint8_t lengths[MAX_LIT_CODES + MAX_DIST_CODES];
  int result = decode_all_lengths (&codelen_ht, br, lengths, hlit + hdist);
  huffman_free (&codelen_ht);
  if (result < 0)
    return -1;

  if (lengths[256] == 0)
    {
      fprintf (stderr, "Missing end-of-block code\n");
      return -1;
    }
  return build_lit_dist_tables (lengths, hlit, hdist, lit_ht, dist_ht);
}
/* Position a bit reader at the first bit of data[0..size). */
void
bitreader_init (BitReader *br, const uint8_t *data, size_t size)
{
  *br = (BitReader){ .data = data, .size = size, .byte_pos = 0, .bit_pos = 0 };
}
/* Read n bits (n <= 32) from the stream, LSB first; the first bit read
 * becomes bit 0 of the result.  If the input runs out, prints an error and
 * returns 0.  Bits consumed before the failure stay consumed. */
uint32_t
bitreader_read (BitReader *br, int n)
{
  uint32_t value = 0;
  int have = 0;

  while (have < n)
    {
      if (br->byte_pos >= br->size)
        {
          fprintf (stderr, "Error: Unexpected end of input\n");
          return 0;
        }

      int left_in_byte = 8 - br->bit_pos;
      int want = n - have;
      int take = (want < left_in_byte) ? want : left_in_byte;

      uint32_t chunk = ((uint32_t) br->data[br->byte_pos] >> br->bit_pos)
                       & ((1u << take) - 1u);
      value |= chunk << have;
      have += take;

      br->bit_pos += take;
      if (br->bit_pos >= 8)
        {
          br->byte_pos++;
          br->bit_pos = 0;
        }
    }
  return value;
}
/* Skip to the next byte boundary.  Does nothing if the reader is already
 * at the start of a byte. */
void
bitreader_align (BitReader *br)
{
  if (br->bit_pos <= 0)
    return;
  br->byte_pos += 1;
  br->bit_pos = 0;
}
/* Copy `count` raw bytes starting at the current byte (the reader must be
 * byte-aligned) and advance past them.
 * Never reads beyond the input.  If fewer than `count` bytes remain, it:
 *   - copies what is available,
 *   - zero-fills the rest of dest,
 *   - prints an error, and
 *   - advances only past the bytes actually read.
 * Callers that need to fail should check the remaining size first. */
void
bitreader_read_bytes (BitReader *br, uint8_t *dest, size_t count)
{
  size_t avail = (br->byte_pos < br->size) ? br->size - br->byte_pos : 0;
  size_t n = (count < avail) ? count : avail;

  if (n > 0)
    memcpy (dest, br->data + br->byte_pos, n);
  if (n < count)
    {
      fprintf (stderr, "Error: Unexpected end of input\n");
      memset (dest + n, 0, count - n);
    }
  br->byte_pos += n;
}
/* ============================================
 * Bit Stream Writer
 * ============================================ */

/* Enlarge the writer's buffer, doubling its capacity.
 * A zero capacity (e.g. from bitwriter_init(bw, 0)) used to stay zero, and
 * the next byte store then wrote out of bounds.  Growth now starts from a
 * small floor instead.  The void API cannot report failure, so running out
 * of memory (or overflowing size_t) aborts explicitly.  Previously the
 * realloc result overwrote bw->data: a failure leaked the old buffer and
 * crashed on the next store. */
static void
bitwriter_grow (BitWriter *bw)
{
  size_t new_capacity;
  if (bw->capacity == 0)
    new_capacity = 64;
  else if (bw->capacity > SIZE_MAX / 2)
    {
      fprintf (stderr, "Error: bit writer size overflow\n");
      abort ();
    }
  else
    new_capacity = bw->capacity * 2;

  uint8_t *grown = realloc (bw->data, new_capacity);
  if (grown == NULL)
    {
      fprintf (stderr, "Error: out of memory growing bit writer\n");
      abort ();
    }
  bw->data = grown;
  bw->capacity = new_capacity;
}
/* Make sure there is room to store one more byte at byte_pos. */
static void
bitwriter_ensure_capacity (BitWriter *bw)
{
  if (bw->byte_pos < bw->capacity)
    return;
  bitwriter_grow (bw);
}
/* Append one whole byte to the output buffer, growing it if needed. */
static void
bitwriter_store_byte (BitWriter *bw, uint8_t byte)
{
  bitwriter_ensure_capacity (bw);
  size_t at = bw->byte_pos;
  bw->data[at] = byte;
  bw->byte_pos = at + 1;
}
/* Move the lowest 8 pending bits into the output buffer (requires
 * bit_pos >= 8). */
static void
bitwriter_emit_byte (BitWriter *bw)
{
  uint8_t low = (uint8_t) (bw->bit_buffer & 0xFFu);
  bitwriter_store_byte (bw, low);
  bw->bit_pos -= 8;
  bw->bit_buffer >>= 8;
}
/* Append the low n bits of `value` above the pending bits.  value is not
 * masked, so callers must pass values that fit in n bits. */
static void
bitwriter_accumulate_bits (BitWriter *bw, uint32_t value, int n)
{
  uint32_t shifted = value << bw->bit_pos;
  bw->bit_buffer = bw->bit_buffer | shifted;
  bw->bit_pos = bw->bit_pos + n;
}
/* Emit every complete byte of pending bits; fewer than 8 remain buffered. */
static void
bitwriter_flush_full_bytes (BitWriter *bw)
{
  for (;;)
    {
      if (bw->bit_pos < 8)
        break;
      bitwriter_emit_byte (bw);
    }
}
/* Return the low `length` bits of `code` in reverse order.  Huffman codes
 * are defined MSB-first but DEFLATE packs bits LSB-first. */
static uint16_t
reverse_bits (uint16_t code, int length)
{
  uint16_t out = 0;
  for (int i = length; i-- > 0; code >>= 1)
    out = (uint16_t) ((out << 1) | (code & 1u));
  return out;
}
/* Prepare an empty writer whose buffer initially holds initial_capacity
 * bytes.
 * A request for 0 bytes is raised to 1, so doubling growth always makes
 * progress.  With capacity 0, the first byte store could write out of
 * bounds.  If malloc fails, data is NULL and capacity is 0; callers should
 * check bw->data. */
void
bitwriter_init (BitWriter *bw, size_t initial_capacity)
{
  if (initial_capacity == 0)
    initial_capacity = 1;
  bw->data = malloc (initial_capacity);
  bw->capacity = bw->data ? initial_capacity : 0;
  bw->byte_pos = 0;
  bw->bit_pos = 0;
  bw->bit_buffer = 0;
}
/* Append the low n bits of value (LSB first), then move any complete bytes
 * to the output buffer. */
void
bitwriter_write (BitWriter *bw, uint32_t value, int n)
{
  bw->bit_buffer |= value << bw->bit_pos;
  bw->bit_pos += n;
  while (bw->bit_pos >= 8)
    bitwriter_emit_byte (bw);
}
/* Write a Huffman code of `length` bits.  It is bit-reversed first because
 * DEFLATE emits Huffman codes starting from their most significant bit. */
void
bitwriter_write_huffman (BitWriter *bw, uint16_t code, int length)
{
  bitwriter_write (bw, reverse_bits (code, length), length);
}
/* Write out any partial byte, zero-padded in its high bits, and clear the
 * pending-bit state. */
void
bitwriter_flush (BitWriter *bw)
{
  if (bw->bit_pos <= 0)
    return;
  bitwriter_store_byte (bw, (uint8_t) (bw->bit_buffer & 0xFF));
  bw->bit_pos = 0;
  bw->bit_buffer = 0;
}
/* Release the writer's buffer and clear the pointer. */
void
bitwriter_free (BitWriter *bw)
{
  uint8_t *buf = bw->data;
  bw->data = NULL;
  free (buf);
}
/* Allocate the decompression output buffer.  Returns 0 on success, or -1 if
 * malloc returned NULL (inf->output is then NULL). */
int
inflater_init (Inflater *inf, size_t initial_capacity)
{
  uint8_t *buf = malloc (initial_capacity);
  inf->output = buf;
  if (buf == NULL)
    return -1;
  inf->output_capacity = initial_capacity;
  inf->output_size = 0;
  return 0;
}
/* Ensure room for `needed` more bytes after the current output.
 * Capacity doubles, so repeated appends are amortized O(1).
 * Returns 0 on success.  Returns -1 if the required size overflows size_t
 * or realloc fails; the existing buffer is then left untouched.
 * Previously, a zero capacity made the doubling loop spin forever. */
int
inflater_grow (Inflater *inf, size_t needed)
{
  if (needed > SIZE_MAX - inf->output_size)
    return -1;
  size_t required = inf->output_size + needed;
  if (required <= inf->output_capacity)
    return 0;

  size_t new_capacity = inf->output_capacity ? inf->output_capacity : 1;
  while (new_capacity < required)
    {
      if (new_capacity > SIZE_MAX / 2)
        {
          new_capacity = required; /* doubling would overflow */
          break;
        }
      new_capacity *= 2;
    }

  uint8_t *new_output = realloc (inf->output, new_capacity);
  if (!new_output)
    return -1;
  inf->output = new_output;
  inf->output_capacity = new_capacity;
  return 0;
}
/* Append one literal byte.  If growing the buffer fails, the byte is
 * dropped silently. */
void
inflater_emit (Inflater *inf, uint8_t byte)
{
  if (inflater_grow (inf, 1) == 0)
    inf->output[inf->output_size++] = byte;
}
/* Append `length` bytes copied from `distance` bytes back in the output.
 *
 * Does nothing when:
 *   - distance is 0 or exceeds the bytes produced so far (the source would
 *     be uninitialized or lie before the buffer), or
 *   - the buffer cannot grow.
 *
 * Source and destination may overlap.  For example, distance 1 with
 * length 10 repeats the last byte.  The copy therefore runs byte by byte,
 * front to back, instead of using memcpy/memmove. */
void
inflater_copy (Inflater *inf, uint16_t distance, uint16_t length)
{
  if (distance == 0 || distance > inf->output_size)
    return;
  if (inflater_grow (inf, length) < 0)
    return;

  uint8_t *out = inf->output;
  size_t dst = inf->output_size;
  size_t src = dst - distance;
  for (size_t i = 0; i < length; i++)
    out[dst + i] = out[src + i];
  inf->output_size = dst + length;
}
/* Release the output buffer and clear the pointer. */
void
inflater_free (Inflater *inf)
{
  uint8_t *buf = inf->output;
  inf->output = NULL;
  free (buf);
}
/* Decode a stored (uncompressed, BTYPE=00) block.
 * Layout: skip to a byte boundary, read LEN and NLEN (one's complement of
 * LEN), then LEN raw bytes.  LEN comes from untrusted input, so it is
 * checked against the bytes actually left before anything is copied.
 * Returns 0 on success, -1 on a bad header, truncated input or allocation
 * failure. */
static int
inflate_stored_block (BitReader *br, Inflater *inf)
{
  bitreader_align (br);
  uint16_t len = bitreader_read (br, 16);
  uint16_t nlen = bitreader_read (br, 16);
  if ((len ^ nlen) != 0xFFFF)
    {
      fprintf (stderr, "Invalid stored block length\n");
      return -1;
    }

  size_t remaining = (br->byte_pos < br->size) ? br->size - br->byte_pos : 0;
  if ((size_t) len > remaining)
    {
      fprintf (stderr, "Stored block extends past end of input\n");
      return -1;
    }

  if (inflater_grow (inf, len) < 0)
    return -1;
  bitreader_read_bytes (br, inf->output + inf->output_size, len);
  inf->output_size += len;
  return 0;
}
/* Build the decode tables for a compressed block.
 *   btype 1: fixed Huffman tables.
 *   btype 2: dynamic tables parsed from the stream.
 *   other:   prints "Invalid block type" and returns -1.
 * The fixed builders cannot report failure themselves, so their results
 * are checked here.  A table without both arrays means allocation failed.
 * In that case both tables are freed and -1 is returned, instead of letting
 * the decoder dereference NULL or an unfilled table. */
static int
setup_block_tables (BitReader *br, int btype, HuffmanTable *lit_ht,
                    HuffmanTable *dist_ht)
{
  if (btype == 1)
    {
      build_fixed_literal_table (lit_ht);
      build_fixed_distance_table (dist_ht);
      if (!lit_ht->symbols || !lit_ht->counts || !dist_ht->symbols
          || !dist_ht->counts)
        {
          fprintf (stderr, "Out of memory building fixed Huffman tables\n");
          huffman_free (lit_ht);
          huffman_free (dist_ht);
          return -1;
        }
      return 0;
    }
  if (btype == 2)
    return decode_dynamic_tables (br, lit_ht, dist_ht);

  fprintf (stderr, "Invalid block type: %d\n", btype);
  return -1;
}
/* Decode a <length, distance> pair for length symbol 257-285 and copy the
 * referenced bytes.
 * Returns -1 for corrupt input:
 *   - the block declares no distance codes (allowed for all-literal data,
 *     but then any back-reference is invalid);
 *   - the distance symbol is outside 0-29;
 *   - the distance reaches before the start of the output.
 * Returns 0 otherwise. */
static int
decode_back_reference (BitReader *br, Inflater *inf, HuffmanTable *dist_ht,
                       uint16_t symbol)
{
  uint16_t length = decode_length (symbol, br);

  if (dist_ht->symbols == NULL)
    {
      fprintf (stderr, "Back-reference in block without distance codes\n");
      return -1;
    }

  uint16_t dist_sym = huffman_decode (dist_ht, br);
  /* Symbols 30/31 exist in the fixed code but are undefined, and would
     index past dist_base/dist_extra. */
  if (dist_sym >= sizeof dist_base / sizeof dist_base[0])
    {
      fprintf (stderr, "Invalid distance symbol: %u\n", (unsigned) dist_sym);
      return -1;
    }

  uint16_t distance = decode_distance (dist_sym, br);
  if (distance > inf->output_size)
    {
      fprintf (stderr, "Invalid distance: %d > %zu\n", distance,
               inf->output_size);
      return -1;
    }
  inflater_copy (inf, distance, length);
  return 0;
}
/* Decode and act on one literal/length symbol:
 *   0-255   : emit the byte;
 *   256     : end of block (sets *done);
 *   257-285 : back-reference;
 *   anything else : error (-1). */
static int
decode_symbol (BitReader *br, Inflater *inf, HuffmanTable *lit_ht,
               HuffmanTable *dist_ht, int *done)
{
  uint16_t symbol = huffman_decode (lit_ht, br);

  if (symbol == 256)
    {
      *done = 1;
      return 0;
    }
  if (symbol < 256)
    {
      inflater_emit (inf, (uint8_t) symbol);
      return 0;
    }
  if (symbol > 285)
    {
      fprintf (stderr, "Invalid literal/length symbol: %d\n", symbol);
      return -1;
    }
  return decode_back_reference (br, inf, dist_ht, symbol);
}
/* Decode symbols until the end-of-block code (256) is seen.
 * Every symbol needs at least one bit, so once the input is exhausted
 * decoding fails immediately.  Otherwise a truncated stream would keep
 * decoding phantom zero bits, and with a dynamic table that maps them to a
 * literal the loop would never end.
 * Returns 0 on end of block, -1 on error. */
static int
decode_compressed_data (BitReader *br, Inflater *inf, HuffmanTable *lit_ht,
                        HuffmanTable *dist_ht)
{
  int done = 0;
  while (!done)
    {
      if (br->byte_pos >= br->size)
        {
          fprintf (stderr, "Error: Unexpected end of input\n");
          return -1;
        }
      if (decode_symbol (br, inf, lit_ht, dist_ht, &done) < 0)
        return -1;
    }
  return 0;
}
/* Decode one block whose BTYPE was already read.
 * Stored blocks are copied directly.  Compressed blocks get their tables
 * built, their data decoded, and the tables freed.  Returns 0 on success,
 * -1 on error. */
int
inflate_block (BitReader *br, Inflater *inf, int btype)
{
  if (btype == 0)
    return inflate_stored_block (br, inf);

  HuffmanTable lit_ht = { 0 };
  HuffmanTable dist_ht = { 0 };
  if (setup_block_tables (br, btype, &lit_ht, &dist_ht) < 0)
    return -1;

  int rc = decode_compressed_data (br, inf, &lit_ht, &dist_ht);
  huffman_free (&dist_ht);
  huffman_free (&lit_ht);
  return rc;
}
/* Decompress a raw DEFLATE stream (no zlib/gzip wrapper).
 * Reads blocks until one has BFINAL set.  On success *output receives a
 * heap buffer (the caller must free it) and *output_size its length, and 0
 * is returned.  Returns -1 on any error; nothing is allocated then. */
int
inflate (const uint8_t *input, size_t input_size, uint8_t **output,
         size_t *output_size)
{
  BitReader br;
  Inflater inf;

  bitreader_init (&br, input, input_size);
  if (inflater_init (&inf, 4096) < 0)
    return -1;

  for (int last = 0; !last;)
    {
      last = (int) bitreader_read (&br, 1);
      int btype = (int) bitreader_read (&br, 2);
      if (inflate_block (&br, &inf, btype) < 0)
        {
          inflater_free (&inf);
          return -1;
        }
    }

  *output_size = inf.output_size;
  *output = inf.output;
  return 0;
}
/* Hash the 3 bytes at p into a HASH_BITS-wide bucket index. */
static uint32_t
hash3 (const uint8_t *p)
{
  uint32_t h = (uint32_t) p[0] << 10;
  h ^= (uint32_t) p[1] << 5;
  h ^= p[2];
  return h & (HASH_SIZE - 1);
}
/* Point the match finder at data[0..size) and mark every hash bucket and
 * chain link as empty (-1). */
void
matchfinder_init (MatchFinder *mf, const uint8_t *data, size_t size)
{
  for (size_t i = 0; i < HASH_SIZE; i++)
    mf->head[i] = -1;
  for (size_t i = 0; i < WINDOW_SIZE; i++)
    mf->prev[i] = -1;
  mf->data = data;
  mf->size = size;
}
/* Length of the common prefix of p1 and p2, capped at max_len. */
static size_t
compare_strings (const uint8_t *p1, const uint8_t *p2, size_t max_len)
{
  size_t n = 0;
  for (; n != max_len; n++)
    {
      if (p1[n] != p2[n])
        break;
    }
  return n;
}
/* Longest match allowed at pos: the remaining input, capped at MAX_MATCH. */
static size_t
get_max_match_len (MatchFinder *mf, size_t pos)
{
  size_t remaining = mf->size - pos;
  if (remaining > MAX_MATCH)
    remaining = MAX_MATCH;
  return remaining;
}
/* Compare the data at pos with an earlier candidate at match_pos, and
 * record it in *best_len/*best_dist when it is at least MIN_MATCH long and
 * beats the current best.
 * Returns:
 *   -1  candidate is beyond MAX_DIST (stop walking the chain);
 *    1  a MAX_MATCH-length match was found (nothing can beat it);
 *    0  otherwise. */
static int
try_match (MatchFinder *mf, size_t pos, int match_pos, uint16_t *best_len,
           uint16_t *best_dist)
{
  size_t dist = pos - match_pos;
  if (dist > MAX_DIST)
    return -1;

  size_t limit = get_max_match_len (mf, pos);
  size_t len = compare_strings (mf->data + pos, mf->data + match_pos, limit);
  if (len < MIN_MATCH || len <= *best_len)
    return 0;

  *best_len = (uint16_t) len;
  *best_dist = (uint16_t) dist;
  return (len == MAX_MATCH) ? 1 : 0;
}
/* Walk the hash chain starting at match_pos, trying at most MAX_CHAIN
 * candidates.  Stops early when try_match reports a distance that is too
 * far or a maximal match.  The best match found is left in
 * *match_len/*match_dist. */
static void
search_hash_chain (MatchFinder *mf, size_t pos, int match_pos,
                   uint16_t *match_len, uint16_t *match_dist)
{
  for (int steps = 0; steps < MAX_CHAIN && match_pos >= 0; steps++)
    {
      if (try_match (mf, pos, match_pos, match_len, match_dist) != 0)
        return;
      match_pos = mf->prev[match_pos & (WINDOW_SIZE - 1)];
    }
}
/* Find the best earlier match for the bytes at pos.
 * Outputs are zeroed first.  Returns 1 and fills *match_len/*match_dist
 * when a match of at least MIN_MATCH bytes exists, 0 otherwise (including
 * when fewer than MIN_MATCH bytes remain). */
int
matchfinder_find (MatchFinder *mf, size_t pos, uint16_t *match_dist,
                  uint16_t *match_len)
{
  *match_dist = 0;
  *match_len = 0;
  if (mf->size < pos + MIN_MATCH)
    return 0;

  int candidate = mf->head[hash3 (mf->data + pos)];
  search_hash_chain (mf, pos, candidate, match_len, match_dist);
  return (*match_len >= MIN_MATCH) ? 1 : 0;
}
/* Make pos the newest entry of its hash bucket, linking it to the previous
 * head.  Positions with fewer than MIN_MATCH bytes left are skipped. */
void
matchfinder_insert (MatchFinder *mf, size_t pos)
{
  if (mf->size < pos + MIN_MATCH)
    return;
  int *slot = &mf->head[hash3 (&mf->data[pos])];
  mf->prev[pos & (WINDOW_SIZE - 1)] = *slot;
  *slot = pos;
}
/* Number of symbols with a nonzero frequency. */
static int
count_symbols (const uint32_t *freqs, int num_symbols)
{
  int used = 0;
  for (int i = num_symbols - 1; i >= 0; i--)
    used += (freqs[i] != 0);
  return used;
}
/* Zero the first num_symbols code lengths. */
static void
init_lengths (uint8_t *lengths, int num_symbols)
{
  if (num_symbols > 0)
    memset (lengths, 0, (size_t) num_symbols);
}
/* Give the first symbol with a nonzero frequency a 1-bit code.
 * Returns 1 if such a symbol was found, 0 otherwise. */
static int
handle_single_symbol (const uint32_t *freqs, uint8_t *lengths, int num_symbols)
{
  int i = 0;
  while (i < num_symbols && freqs[i] == 0)
    i++;
  if (i >= num_symbols)
    return 0;
  lengths[i] = 1;
  return 1;
}
/* Copy every (symbol, frequency) with a nonzero frequency into out, in
 * symbol order.  Returns the number written. */
static int
collect_symbols (const uint32_t *freqs, int num_symbols, SymFreq *out)
{
  SymFreq *w = out;
  for (int sym = 0; sym < num_symbols; sym++)
    {
      if (freqs[sym] != 0)
        *w++ = (SymFreq){ .symbol = sym, .freq = freqs[sym] };
    }
  return (int) (w - out);
}
/* Stable insertion sort by ascending frequency.  Stability matters: it
 * fixes tie order and hence the resulting tree shape. */
static void
sort_by_freq_asc (SymFreq *arr, int count)
{
  for (int next = 1; next < count; next++)
    {
      SymFreq item = arr[next];
      int hole = next;
      for (; hole > 0 && arr[hole - 1].freq > item.freq; hole--)
        arr[hole] = arr[hole - 1];
      arr[hole] = item;
    }
}
/* Turn each (symbol, frequency) into a childless leaf node at the same
 * index. */
static void
init_leaf_nodes (TreeNode *nodes, const SymFreq *syms, int count)
{
  for (int i = 0; i < count; i++)
    nodes[i] = (TreeNode){ .freq = syms[i].freq,
                           .symbol = (int16_t) syms[i].symbol,
                           .left = -1,
                           .right = -1 };
}
/* Index into `active` of the lowest-frequency node; the earliest one wins
 * ties.  Returns 0 when num_active <= 1. */
static int
find_minimum (const TreeNode *nodes, const int *active, int num_active)
{
  int best = 0;
  for (int i = 1; i < num_active; i++)
    {
      if (nodes[active[best]].freq > nodes[active[i]].freq)
        best = i;
    }
  return best;
}
/* Remove active[idx] by moving the last entry into its slot (order is not
 * preserved) and shrink the count. */
static void
remove_from_active (int *active, int *num_active, int idx)
{
  int last = --*num_active;
  active[idx] = active[last];
}
/* Append an internal node whose frequency is the sum of its two children.
 * Returns the new node's index and increments *node_count. */
static int
create_internal_node (TreeNode *nodes, int *node_count, int left, int right)
{
  int idx = (*node_count)++;
  nodes[idx] = (TreeNode){ .freq = nodes[left].freq + nodes[right].freq,
                           .symbol = -1,
                           .left = (int16_t) left,
                           .right = (int16_t) right };
  return idx;
}
/* Depth-first walk from node idx.  Each leaf's depth becomes its symbol's
 * code length. */
static void
compute_depths_recursive (const TreeNode *nodes, int idx, int depth,
                          uint8_t *lengths)
{
  const TreeNode *n = &nodes[idx];
  if (n->symbol < 0)
    {
      int kids[2] = { n->left, n->right };
      for (int k = 0; k < 2; k++)
        {
          if (kids[k] >= 0)
            compute_depths_recursive (nodes, kids[k], depth + 1, lengths);
        }
      return;
    }
  lengths[n->symbol] = (uint8_t) depth;
}
/* Build an unconstrained Huffman tree over `count` (>= 2) symbols sorted by
 * frequency, and write each symbol's depth into lengths.  The two
 * lowest-frequency live nodes are repeatedly merged until a single root
 * remains. */
static void
build_huffman_tree (const SymFreq *syms, int count, uint8_t *lengths)
{
  TreeNode nodes[2 * MAX_LIT_CODES];
  int active[MAX_LIT_CODES];
  int live = count;
  int total = count;

  init_leaf_nodes (nodes, syms, count);
  for (int i = 0; i < count; i++)
    active[i] = i;

  while (live > 1)
    {
      int slot = find_minimum (nodes, active, live);
      int a = active[slot];
      remove_from_active (active, &live, slot);

      slot = find_minimum (nodes, active, live);
      int b = active[slot];
      remove_from_active (active, &live, slot);

      active[live++] = create_internal_node (nodes, &total, a, b);
    }

  compute_depths_recursive (nodes, active[0], 0, lengths);
}
/* Largest code length among the first num_symbols entries (0 if none). */
static int
find_max_length (const uint8_t *lengths, int num_symbols)
{
  int best = 0;
  for (int i = num_symbols; i-- > 0;)
    best = (lengths[i] > best) ? lengths[i] : best;
  return best;
}
/* Histogram of nonzero code lengths.  bl_count[0..max_len] is cleared
 * first; lengths above max_len must not occur. */
static void
count_at_each_length (const uint8_t *lengths, int num_symbols, int *bl_count,
                      int max_len)
{
  if (max_len >= 0)
    memset (bl_count, 0, (size_t) (max_len + 1) * sizeof *bl_count);
  for (int i = 0; i < num_symbols; i++)
    {
      if (lengths[i] != 0)
        bl_count[lengths[i]] += 1;
    }
}
/* Gather every symbol with a nonzero code length.  The SymFreq `freq`
 * field is reused to hold the length.  *out_count receives the number
 * gathered. */
static void
collect_symbols_with_lengths (const uint8_t *lengths, int num_symbols,
                              SymFreq *out, int *out_count)
{
  int n = 0;
  for (int sym = 0; sym < num_symbols; sym++)
    {
      if (lengths[sym] == 0)
        continue;
      out[n].symbol = sym;
      out[n].freq = lengths[sym];
      n++;
    }
  *out_count = n;
}
/* Stable insertion sort by ascending code length (stored in `freq`). */
static void
sort_by_length_asc (SymFreq *arr, int count)
{
  int i = 1;
  while (i < count)
    {
      SymFreq cur = arr[i];
      int j = i;
      while (j > 0 && arr[j - 1].freq > cur.freq)
        {
          arr[j] = arr[j - 1];
          j--;
        }
      arr[j] = cur;
      i++;
    }
}
static void
reassign_lengths_from_counts (uint8_t *lengths, int num_symbols,
const int *bl_count, int max_bits)
{
SymFreq syms[MAX_LIT_CODES];
int count;
collect_symbols_with_lengths (lengths, num_symbols, syms, &count);
sort_by_length_asc (syms, count);
int sym_idx = 0;
for (int len = 1; len <= max_bits && sym_idx < count; len++)
{
for (int j = 0; j < bl_count[len] && sym_idx < count; j++)
{
lengths[syms[sym_idx].symbol] = len;
sym_idx++;
}
}
}
static void
limit_code_lengths (uint8_t *lengths, int num_symbols, int max_bits)
{
int bl_count[MAX_BITS + 2] = { 0 };
int max_len = find_max_length (lengths, num_symbols);
if (max_len <= max_bits)
return;
count_at_each_length (lengths, num_symbols, bl_count, max_len);
int overflow = 0;
for (int i = max_len; i > max_bits; i--)
{
overflow += bl_count[i];
bl_count[i] = 0;
}
bl_count[max_bits] += overflow;
while (1)
{
int kraft_sum = 0;
for (int i = 1; i <= max_bits; i++)
kraft_sum += bl_count[i] << (max_bits - i);
if (kraft_sum <= (1 << max_bits))
break;
bl_count[max_bits]--;
for (int i = max_bits - 1; i >= 1; i--)
{
if (bl_count[i] > 0)
{
bl_count[i]--;
bl_count[i + 1] += 2;
break;
}
}
}
reassign_lengths_from_counts (lengths, num_symbols, bl_count, max_bits);
}
/* Compute length-limited Huffman code lengths (at most max_bits) from
 * symbol frequencies.
 *   - Unused symbols get length 0.
 *   - A lone used symbol gets length 1.
 *   - Otherwise a Huffman tree is built and then depth-limited. */
void
build_code_lengths (const uint32_t *freqs, uint8_t *lengths, int num_symbols,
                    int max_bits)
{
  init_lengths (lengths, num_symbols);

  switch (count_symbols (freqs, num_symbols))
    {
    case 0:
      return;
    case 1:
      handle_single_symbol (freqs, lengths, num_symbols);
      return;
    default:
      break;
    }

  SymFreq syms[MAX_LIT_CODES];
  int used = collect_symbols (freqs, num_symbols, syms);
  sort_by_freq_asc (syms, used);
  build_huffman_tree (syms, used, lengths);
  limit_code_lengths (lengths, num_symbols, max_bits);
}
/* Clear bl_count[0..MAX_BITS] and count how many symbols use each nonzero
 * length. */
static void
count_code_lengths_for_gen (const uint8_t *lengths, int num_symbols,
                            int *bl_count)
{
  memset (bl_count, 0, (MAX_BITS + 1) * sizeof *bl_count);
  for (int sym = 0; sym < num_symbols; sym++)
    {
      uint8_t len = lengths[sym];
      if (len != 0)
        bl_count[len]++;
    }
}
/* First canonical code for each length 1..MAX_BITS (RFC 1951 3.2.2). */
static void
compute_first_codes (const int *bl_count, int *next_code)
{
  int code = 0;
  for (int bits = 1; bits <= MAX_BITS; bits++)
    next_code[bits] = code = (code + bl_count[bits - 1]) << 1;
}
/* Give each used symbol the next canonical code of its length, in symbol
 * order; unused symbols get 0. */
static void
assign_symbol_codes (const uint8_t *lengths, int num_symbols, int *next_code,
                     uint16_t *codes)
{
  for (int i = 0; i < num_symbols; i++)
    {
      int len = lengths[i];
      codes[i] = len ? (uint16_t) next_code[len]++ : 0;
    }
}
/* Derive canonical Huffman codes (MSB-first values) from code lengths. */
void
generate_codes (const uint8_t *lengths, uint16_t *codes, int num_symbols)
{
  int per_length[MAX_BITS + 1];
  int first_code[MAX_BITS + 1];

  count_code_lengths_for_gen (lengths, num_symbols, per_length);
  compute_first_codes (per_length, first_code);
  assign_symbol_codes (lengths, num_symbols, first_code, codes);
}
/* ============================================
 * DEFLATE Compression (Fixed Huffman)
 * ============================================ */

/* Map a match length (3-258) to its length symbol (257-285).
 * Per RFC 1951 3.2.5, length 258 has a dedicated code, 285, with no extra
 * bits.  Code 284 covers only 227-257, even though its 5 extra bits could
 * arithmetically reach 258.  Length 258 is therefore handled before the
 * table scan; otherwise it was emitted as 284 + 31 extra bits, which
 * conforming decoders may reject.  Lengths outside the table also map to
 * 285. */
static int
find_length_code (uint16_t length)
{
  if (length >= 258)
    return 285;

  /* Entries 0..27 (symbols 257..284); entry 28 is the 258 case above. */
  for (int i = 0; i < 28; i++)
    {
      int base = length_base[i];
      int max_val = base + (1 << length_extra[i]) - 1;
      if (length >= base && length <= max_val)
        return 257 + i;
    }
  return 285;
}
/* Map a match distance (1-32768) to its distance symbol (0-29); any value
 * not covered maps to 29. */
static int
find_distance_code (uint16_t distance)
{
  int code = 0;
  while (code < 30)
    {
      int lo = dist_base[code];
      int hi = lo + (1 << dist_extra[code]) - 1;
      if (distance >= lo && distance <= hi)
        return code;
      code++;
    }
  return 29;
}
/* Fill the 288 fixed literal/length code lengths (RFC 1951 3.2.6). */
static void
init_fixed_lit_lengths (uint8_t *lengths)
{
  /* Each range ends (exclusive) at `end` and uses `bits`-bit codes. */
  static const struct
  {
    int end;
    uint8_t bits;
  } ranges[] = { { 144, 8 }, { 256, 9 }, { 280, 7 }, { 288, 8 } };

  int sym = 0;
  for (size_t r = 0; r < sizeof ranges / sizeof ranges[0]; r++)
    {
      for (; sym < ranges[r].end; sym++)
        lengths[sym] = ranges[r].bits;
    }
}
/* Fixed distance code: all 32 symbols use 5-bit codes. */
static void
init_fixed_dist_lengths (uint8_t *lengths)
{
  memset (lengths, 5, MAX_DIST_CODES);
}
/* Emit a block header: 1-bit BFINAL followed by 2-bit BTYPE = 01 (fixed
 * Huffman). */
static void
write_block_header_fixed (BitWriter *bw, int is_final)
{
  enum { BTYPE_FIXED = 1 };
  bitwriter_write (bw, (uint32_t) is_final, 1);
  bitwriter_write (bw, BTYPE_FIXED, 2);
}
/* Emit the extra bits of a length symbol: match_len minus the symbol's base
 * length.  Nothing is written for symbols without extra bits. */
static void
write_length_extra (BitWriter *bw, int len_code, uint16_t match_len)
{
  int idx = len_code - 257;
  int nbits = length_extra[idx];
  if (nbits == 0)
    return;
  bitwriter_write (bw, (uint32_t) (match_len - length_base[idx]), nbits);
}
/* Emit the extra bits of a distance symbol: match_dist minus the symbol's
 * base distance.  Nothing is written for symbols without extra bits. */
static void
write_distance_extra (BitWriter *bw, int dist_code, uint16_t match_dist)
{
  int nbits = dist_extra[dist_code];
  if (nbits == 0)
    return;
  bitwriter_write (bw, (uint32_t) (match_dist - dist_base[dist_code]), nbits);
}
/* Emit a back-reference in stream order:
 *   length code, length extra bits, distance code, distance extra bits. */
static void
write_match (BitWriter *bw, const uint16_t *lit_codes,
             const uint8_t *lit_lengths, const uint16_t *dist_codes,
             const uint8_t *dist_lengths, uint16_t match_len,
             uint16_t match_dist)
{
  int lc = find_length_code (match_len);
  bitwriter_write_huffman (bw, lit_codes[lc], lit_lengths[lc]);
  write_length_extra (bw, lc, match_len);

  int dc = find_distance_code (match_dist);
  bitwriter_write_huffman (bw, dist_codes[dc], dist_lengths[dc]);
  write_distance_extra (bw, dc, match_dist);
}
/* Emit the Huffman code for one literal byte (symbols 0-255). */
static void
write_literal (BitWriter *bw, const uint16_t *lit_codes,
               const uint8_t *lit_lengths, uint8_t byte)
{
  uint16_t code = lit_codes[byte];
  int nbits = lit_lengths[byte];
  bitwriter_write_huffman (bw, code, nbits);
}
/* Emit the end-of-block symbol (256). */
static void
write_end_of_block (BitWriter *bw, const uint16_t *lit_codes,
                    const uint8_t *lit_lengths)
{
  const int eob = 256;
  bitwriter_write_huffman (bw, lit_codes[eob], lit_lengths[eob]);
}
/* Index every position covered by a match, so later data can refer back
 * into it. */
static void
insert_match_positions (MatchFinder *mf, size_t pos, uint16_t match_len)
{
  for (size_t p = pos, end = pos + match_len; p < end; p++)
    matchfinder_insert (mf, p);
}
/* Encode the input starting at pos as either a match or a single literal,
 * and update the match index.  Returns how many input bytes were
 * consumed. */
static size_t
process_position (BitWriter *bw, MatchFinder *mf, const uint8_t *input,
                  size_t pos, const uint16_t *lit_codes,
                  const uint8_t *lit_lengths, const uint16_t *dist_codes,
                  const uint8_t *dist_lengths)
{
  uint16_t match_dist, match_len;

  if (!matchfinder_find (mf, pos, &match_dist, &match_len))
    {
      write_literal (bw, lit_codes, lit_lengths, input[pos]);
      matchfinder_insert (mf, pos);
      return 1;
    }

  write_match (bw, lit_codes, lit_lengths, dist_codes, dist_lengths,
               match_len, match_dist);
  insert_match_positions (mf, pos, match_len);
  return match_len;
}
/* Greedily encode the whole input, one match or literal at a time. */
static void
compress_data (BitWriter *bw, MatchFinder *mf, const uint8_t *input,
               size_t input_size, const uint16_t *lit_codes,
               const uint8_t *lit_lengths, const uint16_t *dist_codes,
               const uint8_t *dist_lengths)
{
  for (size_t pos = 0; pos < input_size;)
    pos += process_position (bw, mf, input, pos, lit_codes, lit_lengths,
                             dist_codes, dist_lengths);
}
/* Compress input into a single final fixed-Huffman DEFLATE block.
 * On success, *output receives a heap buffer (the caller must free it),
 * *output_size its length, and 0 is returned.  Returns -1 if an allocation
 * fails.
 *
 * The MatchFinder holds two 32768-entry int arrays (about 256 KiB), so it
 * is heap-allocated instead of risking a stack overflow.  The writer starts
 * with room for 9/8 of the input plus slack, because fixed literals take up
 * to 9 bits.  That also keeps the capacity nonzero for empty input, which
 * the doubling growth depends on. */
int
deflate_fixed (const uint8_t *input, size_t input_size, uint8_t **output,
               size_t *output_size)
{
  MatchFinder *mf = malloc (sizeof *mf);
  if (mf == NULL)
    return -1;

  size_t slack = input_size / 8 + 16;
  size_t capacity
      = (input_size <= SIZE_MAX - slack) ? input_size + slack : input_size;

  BitWriter bw;
  bitwriter_init (&bw, capacity);
  if (bw.data == NULL)
    {
      free (mf);
      return -1;
    }
  matchfinder_init (mf, input, input_size);

  uint8_t lit_lengths[MAX_LIT_CODES];
  uint16_t lit_codes[MAX_LIT_CODES];
  uint8_t dist_lengths[MAX_DIST_CODES];
  uint16_t dist_codes[MAX_DIST_CODES];
  init_fixed_lit_lengths (lit_lengths);
  generate_codes (lit_lengths, lit_codes, MAX_LIT_CODES);
  init_fixed_dist_lengths (dist_lengths);
  generate_codes (dist_lengths, dist_codes, MAX_DIST_CODES);

  write_block_header_fixed (&bw, 1);
  compress_data (&bw, mf, input, input_size, lit_codes, lit_lengths,
                 dist_codes, dist_lengths);
  write_end_of_block (&bw, lit_codes, lit_lengths);
  bitwriter_flush (&bw);
  free (mf);

  *output = bw.data;
  *output_size = bw.byte_pos;
  return 0;
}
/* Round-trip demo: compress a sample string with deflate_fixed, decompress
 * it with inflate, compare the result, and hex-dump the compressed bytes.
 * Exits with status 1 if compression or decompression fails, or if the
 * round trip does not reproduce the input.  Previously a mismatch was only
 * printed and the program still exited 0, hiding failures from scripts
 * and CI. */
int
main (void)
{
  const char *test_string = "The quick brown fox jumps over the lazy dog. "
                            "The quick brown fox jumps over the lazy dog. "
                            "The quick brown fox jumps over the lazy dog. "
                            "Pack my box with five dozen liquor jugs.";
  const uint8_t *input = (const uint8_t *) test_string;
  size_t input_size = strlen (test_string);
  printf ("Original size: %zu bytes\n", input_size);
  printf ("Original text: %s\n\n", test_string);

  /* Compress */
  uint8_t *compressed = NULL;
  size_t compressed_size = 0;
  if (deflate_fixed (input, input_size, &compressed, &compressed_size) < 0)
    {
      fprintf (stderr, "Compression failed\n");
      return 1;
    }
  printf ("Compressed size: %zu bytes (%.1f%% of original)\n", compressed_size,
          100.0 * compressed_size / input_size);

  /* Decompress */
  uint8_t *decompressed = NULL;
  size_t decompressed_size = 0;
  if (inflate (compressed, compressed_size, &decompressed, &decompressed_size)
      < 0)
    {
      fprintf (stderr, "Decompression failed\n");
      free (compressed);
      return 1;
    }
  printf ("Decompressed size: %zu bytes\n", decompressed_size);

  /* Verify */
  int matches = decompressed_size == input_size
                && memcmp (decompressed, input, input_size) == 0;
  if (matches)
    printf ("\nSUCCESS: Decompressed data matches original!\n");
  else
    printf ("\nFAILURE: Data mismatch!\n");

  /* Print hex dump of compressed data */
  printf ("\nCompressed data (hex):\n");
  for (size_t i = 0; i < compressed_size; i++)
    {
      printf ("%02X ", compressed[i]);
      if ((i + 1) % 16 == 0)
        printf ("\n");
    }
  printf ("\n");

  free (compressed);
  free (decompressed);
  return matches ? 0 : 1;
}
/* End of deflate.c. */