#!/usr/bin/env python3
# encoding: utf-8
# pysapcompress - Pure Python implementation of SAP LZH and LZC compression
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# This is a port of the TypeScript sapcomp library by Max Jäger (xje4@xje4.dev),
# which itself was derived from the MaxDB compression code by SAP AG.
# Reference: https://git.sr.ht/~xje4/sapcomp
#
# The original C extension used CS_LZC=0x0 and CS_LZH=0x2 as algorithm selector
# constants passed to compress(). Those values are preserved here for API
# compatibility.
import struct
from array import array
# ---------------------------------------------------------------------------
# Public exceptions
# ---------------------------------------------------------------------------
class CompressError(Exception):
    """Raised when compression fails.

    Used by the encoder-side helpers in this module, for example when a
    Huffman tree cannot be constructed from the collected statistics.
    """
class DecompressError(Exception):
    """Raised when decompression fails.

    Covers truncated input, a missing magic number, an unknown algorithm
    identifier and malformed code streams or Huffman trees.
    """
# ---------------------------------------------------------------------------
# Algorithm selector constants (CS_LZC / CS_LZH from hpa104CsObject.h)
# ---------------------------------------------------------------------------
# Public algorithm selectors; the values mirror CS_LZC / CS_LZH of the
# original C extension and are kept for API compatibility.
ALG_LZC = 0x0 # CS_LZC
ALG_LZH = 0x2 # CS_LZH
# Internal algorithm identifiers stored in the 8-byte compression header
# (low nibble of header byte 4; the high nibble carries the version).
_HDR_ALG_LZC = 1
_HDR_ALG_LZH = 2
# NOTE(review): not referenced in the visible code; presumably a generic
# header version -- confirm against the callers.
_HDR_VERSION = 1
# Header layout: bytes 0-3 little-endian uncompressed length, byte 4
# algorithm/version, bytes 5-6 magic, byte 7 algorithm-specific extra byte.
_HDR_SIZE = 8
# Magic bytes expected at header offsets 5 and 6.
_MAGIC = b'\x1f\x9d'
# CS_END_OF_STREAM return code (success) from the original C library
_CS_END_OF_STREAM = 1
# ---------------------------------------------------------------------------
# Header helpers
# ---------------------------------------------------------------------------
def _parse_header(data):
    """Decode the fixed 8-byte SAP compression header.

    Args:
        data: bytes-like object whose first 8 bytes hold the header.

    Returns:
        Tuple ``(uncompressed_length, alg_id, alg_version, extra)`` where
        ``alg_id`` / ``alg_version`` are the low / high nibble of byte 4
        and ``extra`` is the raw value of byte 7.

    Raises:
        DecompressError: if fewer than 8 bytes are supplied or the magic
            bytes at offsets 5-6 do not match.
    """
    if len(data) < _HDR_SIZE:
        raise DecompressError("invalid input length: header truncated")
    # One unpack covers the whole header: u32 LE, u8, 2 raw bytes, u8.
    length, alg_byte, magic, extra = struct.unpack_from('<IB2sB', data, 0)
    if magic != _MAGIC:
        raise DecompressError("input not compressed: magic bytes not found")
    alg_id = alg_byte & 0x0f
    alg_version = (alg_byte >> 4) & 0x0f
    return length, alg_id, alg_version, extra
def _build_header(uncompressed_length, alg_id, alg_version, extra):
    """Serialize an 8-byte SAP compression header.

    Args:
        uncompressed_length: size of the original data (unsigned 32-bit).
        alg_id: algorithm identifier, stored in the low nibble of byte 4.
        alg_version: algorithm version, stored in the high nibble of byte 4.
        extra: algorithm-specific byte; only its low 8 bits are stored.

    Returns:
        The header as an immutable ``bytes`` object.
    """
    length_field = struct.pack('<I', uncompressed_length)
    trailer = bytes((
        alg_id | (alg_version << 4),
        _MAGIC[0],
        _MAGIC[1],
        extra & 0xff,
    ))
    return length_field + trailer
# ---------------------------------------------------------------------------
# I/O primitives
# ---------------------------------------------------------------------------
class _Reader:
"""Byte/bit reader backed by an immutable bytes-like buffer."""
__slots__ = ('_data', '_pos', '_bits', '_bits_count')
def __init__(self, data):
self._data = bytes(data)
self._pos = 0
self._bits = 0
self._bits_count = 0
@property
def bytes_read(self):
return self._pos
@property
def bytes_left(self):
return len(self._data) - self._pos
@property
def end_reached(self):
return self._pos >= len(self._data)
@property
def bits_left(self):
return self.bytes_left * 8 + self._bits_count
@property
def total_length(self):
return len(self._data)
def read_byte(self):
if self._bits_count > 0:
raise RuntimeError("unfinished bit read pending")
return self._read_byte()
def _read_byte(self):
if self._pos >= len(self._data):
raise DecompressError("unexpected end of compressed data")
b = self._data[self._pos]
self._pos += 1
return b
def read(self, length):
if self._bits_count > 0:
raise RuntimeError("unfinished bit read pending")
if self._pos + length > len(self._data):
raise DecompressError("unexpected end of compressed data")
chunk = self._data[self._pos:self._pos + length]
self._pos += length
return chunk
def peek_bits(self, length):
while self._bits_count < length:
b = self._read_byte()
self._bits |= b << self._bits_count
self._bits_count += 8
return self._bits & ((1 << length) - 1)
def read_bits(self, length):
value = self.peek_bits(length)
self._bits >>= length
self._bits_count -= length
return value
def skip_bits(self, length):
self.read_bits(length)
class _Writer:
"""Byte/bit writer that accumulates output into a bytearray."""
__slots__ = ('_buf', '_bits', '_bits_count')
def __init__(self):
self._buf = bytearray()
self._bits = 0
self._bits_count = 0
@property
def data(self):
return bytes(self._buf)
@property
def bytes_written(self):
return len(self._buf)
def write(self, data):
if self._bits_count > 0:
raise RuntimeError("unfinished bit write pending")
self._buf.extend(data)
def write_byte(self, byte):
if self._bits_count > 0:
raise RuntimeError("unfinished bit write pending")
self._buf.append(byte & 0xff)
def _write_byte(self, byte):
self._buf.append(byte & 0xff)
def write_bits(self, value, bit_count):
self._bits |= value << self._bits_count
self._bits_count += bit_count
while self._bits_count >= 8:
self._buf.append(self._bits & 0xff)
self._bits >>= 8
self._bits_count -= 8
def flush_pending_bits(self):
while self._bits_count > 0:
self._buf.append(self._bits & 0xff)
self._bits >>= 8
self._bits_count = max(self._bits_count - 8, 0)
# ---------------------------------------------------------------------------
# LZC constants
# ---------------------------------------------------------------------------
# Algorithm version written into the header by the LZC compressor.
_LZC_VERSION = 1
# Code width (in bits) at the start of every block.
_LZC_MIN_CODE_LENGTH = 9
# Largest code width accepted in a header; also the size of the chunk buffer.
_LZC_MAX_CODE_LENGTH = 16
# Codes 0..255 stand for the literal byte values.
_LZC_LITERAL_CODE_COUNT = 256
# Control code that ends a block (only reserved in multi-block mode).
_LZC_CODE_END_BLOCK = 256
# Number of input bytes between two compression-ratio checks.
_LZC_RATIO_CHECK_INTERVAL = 4096
# Default upper bound for the code width used by the compressor.
_LZC_DEFAULT_CODE_LENGTH_LIMIT = 13
_LZC_DEFAULT_BLOCK_MODE = 1 # MULTI_BLOCK
# Block modes; the mode is stored in bit 7 of the header's extra byte.
_LZC_SINGLE_BLOCK = 0
_LZC_MULTI_BLOCK = 1
# ---------------------------------------------------------------------------
# LZC compress
# ---------------------------------------------------------------------------
class _LZCCompress:
    """LZW-style (LZC) compressor producing a SAP-compatible byte stream.

    Codes are emitted in "chunks" of ``code_length`` bytes (eight codes of
    the current width). The code width starts at ``_LZC_MIN_CODE_LENGTH``
    and grows up to ``code_length_limit``. In multi-block mode the
    dictionary is reset with an END_BLOCK code when the compression ratio
    stops improving once the dictionary is full.
    """
    def __init__(self, data, code_length_limit=_LZC_DEFAULT_CODE_LENGTH_LIMIT,
                 block_mode=_LZC_DEFAULT_BLOCK_MODE):
        """Prepare a compressor for ``data``.

        Args:
            data: bytes-like input to compress.
            code_length_limit: maximum code width in bits.
            block_mode: ``_LZC_SINGLE_BLOCK`` or ``_LZC_MULTI_BLOCK``.
        """
        self._reader = _Reader(data)
        self._writer = _Writer()
        self._code_length_limit = code_length_limit
        self._code_limit = 1 << code_length_limit
        self._block_mode = block_mode
        self._code_length = _LZC_MIN_CODE_LENGTH
        self._max_code = (1 << _LZC_MIN_CODE_LENGTH) - 1
        self._code_index = {}  # sequence_id → code
        self._next_free_code = -1
        self._latest_ratio = 0
        self._next_ratio_check = 0
        # Chunk buffer (NOTE-6)
        self._chunk_buf = bytearray(_LZC_MAX_CODE_LENGTH)
        self._chunk_cursor = 0
        self._chunk_pending = 0
        self._chunk_pending_count = 0

    @property
    def _first_sequence_code(self):
        """First code available for dictionary sequences in this block mode."""
        if self._block_mode == _LZC_SINGLE_BLOCK:
            return _LZC_LITERAL_CODE_COUNT
        return _LZC_LITERAL_CODE_COUNT + 1  # +1 for END_BLOCK control code

    def _set_code_length(self, value):
        """Set the code width and derive the largest code usable at it."""
        self._code_length = value
        if value == self._code_length_limit:
            # At the limit the width never grows again, so the threshold is
            # the (never exceeded) total code count.
            self._max_code = self._code_limit
        else:
            self._max_code = (1 << value) - 1

    def _current_ratio(self):
        """Return the input/output ratio as a fixed-point integer.

        For small inputs the ratio is scaled by 256; for large inputs the
        output size is scaled down instead. Returns 0 when nothing has been
        written yet.
        """
        br = self._reader.bytes_read
        bw = self._writer.bytes_written
        if bw == 0:
            return 0
        if br <= 0x007fffff:
            return (br << 8) // bw
        if bw < 0x100:
            return 0x7fffffff
        return br // (bw >> 8)

    def _write_code(self, code):
        """Pack ``code`` (current width, LSB first) into the chunk buffer."""
        # NOTE-6: use separate chunk buffer to replicate trash padding bytes
        self._chunk_pending |= code << self._chunk_pending_count
        self._chunk_pending_count += self._code_length
        while self._chunk_pending_count >= 8:
            self._chunk_buf[self._chunk_cursor] = self._chunk_pending & 0xff
            self._chunk_cursor += 1
            self._chunk_pending >>= 8
            self._chunk_pending_count -= 8

    def _finish_chunk(self):
        """Emit a full-size chunk, even if it is only partially filled."""
        # NOTE-7: chunk is exactly code_length bytes
        self._flush_chunk(self._code_length)

    def _flush_chunk(self, chunk_size=None):
        """Write the chunk buffer to the output and reset the chunk state.

        Args:
            chunk_size: number of bytes to emit; when ``None`` only the
                bytes filled so far are written. A larger size emits
                whatever the buffer still holds beyond the cursor.
        """
        if self._chunk_pending_count > 0:
            # Store the trailing partial byte first.
            self._chunk_buf[self._chunk_cursor] = self._chunk_pending & 0xff
            self._chunk_cursor += 1
        end = chunk_size if chunk_size is not None else self._chunk_cursor
        self._writer.write(self._chunk_buf[:end])
        self._chunk_pending = 0
        self._chunk_pending_count = 0
        self._chunk_cursor = 0

    def _start_new_block(self):
        """Emit END_BLOCK, pad the chunk and reset the dictionary state."""
        self._write_code(_LZC_CODE_END_BLOCK)
        self._finish_chunk()
        self._set_code_length(_LZC_MIN_CODE_LENGTH)
        self._code_index.clear()
        self._next_free_code = self._first_sequence_code
        self._latest_ratio = 0

    def compress(self):
        """Compress the input and return header plus code stream as bytes.

        For empty input only the 8-byte header is returned.
        """
        # Extra header byte: code width limit in the low bits, block mode in bit 7.
        extra = self._code_length_limit | (self._block_mode << 7)
        hdr = _build_header(self._reader.total_length, _HDR_ALG_LZC, _LZC_VERSION, extra)
        self._writer.write(hdr)
        if self._reader.end_reached:
            return self._writer.data
        self._next_free_code = self._first_sequence_code
        self._next_ratio_check = _LZC_RATIO_CHECK_INTERVAL
        next_code = self._reader.read_byte()
        while not self._reader.end_reached:
            next_byte = self._reader.read_byte()
            # Key of the sequence "current code followed by next_byte".
            sequence_id = (next_byte << self._code_length_limit) | next_code
            sequence_code = self._code_index.get(sequence_id)
            if sequence_code:
                # Known sequence: keep extending the match.
                next_code = sequence_code
                continue
            # Emit current code
            self._write_code(next_code)
            if self._chunk_cursor >= self._code_length:  # chunk full
                self._finish_chunk()
            # Increase code length if needed
            if self._next_free_code > self._max_code:
                # A width change always starts on a fresh chunk.
                if self._chunk_cursor > 0 or self._chunk_pending_count > 0:
                    self._finish_chunk()
                self._set_code_length(self._code_length + 1)
            if self._next_free_code < self._code_limit:
                self._code_index[sequence_id] = self._next_free_code
                self._next_free_code += 1
            elif self._block_mode == _LZC_MULTI_BLOCK and self._reader.bytes_read >= self._next_ratio_check:
                # Dictionary is full: restart the block unless the ratio
                # improved since the previous check.
                ratio = self._current_ratio()
                if ratio > self._latest_ratio:
                    self._latest_ratio = ratio
                else:
                    self._start_new_block()
                self._next_ratio_check = self._reader.bytes_read + _LZC_RATIO_CHECK_INTERVAL
            next_code = next_byte
        # Emit final code
        self._write_code(next_code)
        self._flush_chunk()
        return self._writer.data

    @staticmethod
    def compress_data(data, code_length_limit=_LZC_DEFAULT_CODE_LENGTH_LIMIT,
                      block_mode=_LZC_DEFAULT_BLOCK_MODE):
        """Convenience wrapper: compress ``data`` in one call."""
        return _LZCCompress(data, code_length_limit, block_mode).compress()
# ---------------------------------------------------------------------------
# LZC decompress
# ---------------------------------------------------------------------------
class _LZCDecompress:
    """Decoder for LZC streams produced in the SAP chunked layout.

    Codes are read chunk by chunk (``code_length`` bytes per chunk). Every
    dictionary entry records its prefix code, its final byte and the
    position of that byte inside the decoded sequence, which allows a
    sequence to be rebuilt directly into a reusable buffer.
    """

    def __init__(self, data, compat_mode=False):
        """Prepare a decoder.

        Args:
            data: bytes-like compressed input including the 8-byte header.
            compat_mode: when true, running out of input ends decoding
                silently instead of raising ``DecompressError``.
        """
        self._reader = _Reader(data)
        self._writer = _Writer()
        self._compat_mode = compat_mode
        # Defaults; replaced by the values parsed from the header.
        self._block_mode = _LZC_DEFAULT_BLOCK_MODE
        self._code_length_limit = _LZC_DEFAULT_CODE_LENGTH_LIMIT
        self._code_limit = 1 << _LZC_DEFAULT_CODE_LENGTH_LIMIT
        self._code_length = _LZC_MIN_CODE_LENGTH
        self._max_code = (1 << _LZC_MIN_CODE_LENGTH) - 1
        self._next_free_code = -1
        self._chunk_reader = None  # _Reader over the current chunk

    @property
    def _first_sequence_code(self):
        """First dictionary code; multi-block mode reserves one for END_BLOCK."""
        reserved = 0 if self._block_mode == _LZC_SINGLE_BLOCK else 1
        return _LZC_LITERAL_CODE_COUNT + reserved

    def _set_code_length(self, value):
        """Set the code width and the threshold that triggers the next growth."""
        self._code_length = value
        at_limit = value == self._code_length_limit
        self._max_code = self._code_limit if at_limit else (1 << value) - 1

    def _read_header(self):
        """Parse and validate the header; return the uncompressed length."""
        raw_header = self._reader.read(_HDR_SIZE)
        length, alg_id, _alg_version, extra = _parse_header(raw_header)
        if alg_id != _HDR_ALG_LZC:
            raise DecompressError("unknown algorithm: expected LZC algorithm identifier")
        limit = extra & 0x1f
        if limit < _LZC_MIN_CODE_LENGTH or limit > _LZC_MAX_CODE_LENGTH:
            raise DecompressError("invalid header: code_length_limit out of range")
        self._block_mode = extra >> 7
        self._code_length_limit = limit
        self._code_limit = 1 << limit
        return length

    def _start_new_chunk(self):
        """Load the next chunk, or clear the chunk reader when input is exhausted."""
        # A short final chunk is accepted; _read_code stops when bits run out.
        size = min(self._code_length, self._reader.bytes_left)
        self._chunk_reader = _Reader(self._reader.read(size)) if size != 0 else None

    def _start_new_block(self):
        """Reset the dictionary state after an END_BLOCK code."""
        self._next_free_code = self._first_sequence_code
        self._set_code_length(_LZC_MIN_CODE_LENGTH)
        self._start_new_chunk()

    def _read_code(self):
        """Return the next code, or ``None`` when the input holds no further code."""
        chunk = self._chunk_reader
        must_grow = self._next_free_code > self._max_code
        if must_grow or chunk is None or chunk.bits_left < self._code_length:
            if must_grow:
                self._set_code_length(self._code_length + 1)
            self._start_new_chunk()
            chunk = self._chunk_reader
        if chunk is None or chunk.bits_left < self._code_length:
            return None
        return chunk.read_bits(self._code_length)

    def decompress(self):
        """Decode the stream and return the uncompressed data as bytes.

        Raises:
            DecompressError: on a malformed header, an unknown code or
                (unless ``compat_mode`` is set) truncated input.
        """
        remaining = self._read_header()
        self._next_free_code = self._first_sequence_code
        multi_block = self._block_mode == _LZC_MULTI_BLOCK
        table = {}  # code → {'base': int, 'next': int, 'chain_index': int}
        chain = bytearray(1 << self._code_length_limit)
        last_code = None
        last_entry = None
        while remaining > 0:
            code = self._read_code()
            if multi_block and code == _LZC_CODE_END_BLOCK:
                self._start_new_block()
                last_code = None
                last_entry = None
                continue
            if code is None:
                if self._compat_mode:
                    break
                raise DecompressError("unexpected end of compressed data")
            if code >= self._code_limit:
                raise DecompressError("unknown code %d encountered" % code)
            # Index of the last byte of the previously decoded sequence.
            last_index = last_entry['chain_index'] if last_entry else 0
            if code == last_code:
                # Same code again: its sequence is still in the buffer.
                length = last_index + 1
            elif code < self._next_free_code:
                # Walk the prefix chain, placing every byte at its position.
                length = 0
                cursor = code
                while cursor >= _LZC_LITERAL_CODE_COUNT:
                    entry = table.get(cursor) if cursor < self._next_free_code else None
                    if entry is None:
                        raise DecompressError("unknown code %d encountered" % cursor)
                    chain[entry['chain_index']] = entry['next']
                    length += 1
                    cursor = entry['base']
                chain[0] = cursor
                length += 1
            elif code == self._next_free_code and last_code is not None:
                # NOTE-8: ababa case - previous sequence plus its own first byte
                chain[last_index + 1] = chain[0]
                length = last_index + 2
            else:
                raise DecompressError("unknown code %d encountered" % code)
            self._writer.write(chain[:length])
            remaining -= length
            if last_code is not None and self._next_free_code < self._code_limit:
                table[self._next_free_code] = {
                    'base': last_code,
                    'next': chain[0],
                    'chain_index': last_index + 1,
                }
                self._next_free_code += 1
            last_code = code
            last_entry = table.get(code)
        return self._writer.data

    @staticmethod
    def decompress_data(data, compat_mode=False):
        """Convenience wrapper: decompress ``data`` in one call."""
        return _LZCDecompress(data, compat_mode).decompress()
# ---------------------------------------------------------------------------
# LZH constants
# ---------------------------------------------------------------------------
# NOTE(review): the constants down to _LZH_BITLEN_HUFFTREE_MAX_BITS are not
# referenced in the visible part of the file; their meaning is inferred from
# their names only -- confirm against the LZH compressor/decompressor.
_LZH_VERSION = 1
_LZH_HEAD_NOISE_LEN = 2
_LZH_DEFAULT_COMPRESSION_LEVEL = 2
_LZH_TREETYPE_STATIC = 1
_LZH_TREETYPE_DYNAMIC = 2
_LZH_LITLEN_HUFFTREE_MAX_BITS = 15
_LZH_DIST_HUFFTREE_MAX_BITS = 15
_LZH_BITLEN_HUFFTREE_MAX_BITS = 7
# The matcher keeps a buffer of twice this size and shifts it by this amount.
_LZH_WINDOW_SIZE = 0x4000
# Shortest / longest match length considered by the matcher.
_LZH_MIN_MATCH = 3
_LZH_MAX_MATCH = 258
# Hash width in bits, the derived mask, and the shift applied per byte.
_LZH_HASH_SIZE = 14
_LZH_HASH_MASK = (1 << _LZH_HASH_SIZE) - 1
_LZH_HASH_SHIFT = (_LZH_HASH_SIZE + _LZH_MIN_MATCH - 1) // _LZH_MIN_MATCH
# The matcher refills its window while fewer than this many bytes lie ahead
# of the cursor (unless the input is exhausted).
_LZH_MIN_LOOKAHEAD = _LZH_MAX_MATCH + _LZH_MIN_MATCH + 1
# Matches further back than this are not considered.
_LZH_MAX_DISTANCE = _LZH_WINDOW_SIZE - _LZH_MIN_LOOKAHEAD
# NOTE(review): not referenced in the visible code; presumably the distance
# cap for minimum-length matches -- confirm.
_LZH_MAX_DISTANCE_3 = 4096
_LZH_MIN_DISTANCE = 1
# NOTE(review): alphabet sizes and special symbols; not referenced in the
# visible code -- confirm against the block encoder/decoder.
_LZH_LITLEN_COUNT = 286
_LZH_DIST_COUNT = 30
_LZH_BITLEN_COUNT = 19
_LZH_END_BLOCK = 256
_LZH_LIT_LAST = 255
_LZH_LENGTH_FIRST = 257 # first length code in litlen alphabet
# Per length code: each of the first 28 codes covers 2**extra consecutive
# match lengths (see _LZHBase._gen_length_code_mapping).
_LZH_LEN_EXTRA = [
    0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
    3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0, 99, 99
]
# Per distance code: each code covers 2**extra consecutive distances
# (see _LZHBase._gen_distance_code_mapping).
_LZH_DIST_EXTRA = [
    0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
    7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13
]
# Per bit-length symbol: number of extra bits written after the symbol
# (see _HuffTree.encode_to).
_LZH_BITLEN_EXTRA = [
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 7
]
# NOTE(review): not referenced in the visible code; presumably the order in
# which the bit-length code lengths are transmitted -- confirm.
_LZH_BITLEN_RANKING = [16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15]
# Run-length symbols of the bit-length alphabet; the matching *_ST value is
# subtracted from the run length to obtain the value stored in the extra bits.
_LZH_BITLEN_REPEAT_NZ_3_6 = 16
_LZH_BITLEN_REPEAT_NZ_3_6_ST = 3
_LZH_BITLEN_REPEAT_Z_3_10 = 17
_LZH_BITLEN_REPEAT_Z_3_10_ST = 3
_LZH_BITLEN_REPEAT_Z_11_138 = 18
_LZH_BITLEN_REPEAT_Z_11_138_ST = 11
# Matcher tuning per compression level. The visible matcher reads
# 'max_chain' (hash-chain hops per search, quartered once the required match
# length exceeds 'good_length') and 'good_length'.
# NOTE(review): 'max_lazy' is not used in the visible code -- confirm.
_LZH_COMPRESSION_LEVELS = [
    {'level': 0, 'good_length': 0, 'max_lazy': 0, 'max_chain': 0},
    {'level': 1, 'good_length': 4, 'max_lazy': 4, 'max_chain': 16},
    {'level': 2, 'good_length': 6, 'max_lazy': 8, 'max_chain': 16},
    {'level': 3, 'good_length': 8, 'max_lazy': 16, 'max_chain': 32},
    {'level': 4, 'good_length': 8, 'max_lazy': 32, 'max_chain': 64},
    {'level': 5, 'good_length': 8, 'max_lazy': 64, 'max_chain': 128},
    {'level': 6, 'good_length': 8, 'max_lazy': 128, 'max_chain': 256},
    {'level': 7, 'good_length': 8, 'max_lazy': 128, 'max_chain': 512},
    {'level': 8, 'good_length': 32, 'max_lazy': 258, 'max_chain': 1024},
    {'level': 9, 'good_length': 32, 'max_lazy': 258, 'max_chain': 4096},
]
# ---------------------------------------------------------------------------
# LZH Huffman tree
# ---------------------------------------------------------------------------
def _reverse_int(value, length):
"""Reverse the low `length` bits of `value`."""
rev = 0
for _ in range(length):
rev = (rev << 1) | (value & 1)
value >>= 1
return rev
class _HuffTree:
    """Huffman tree node container.

    Each node is a dict with keys: value, code, code_length.
    For encode nodes additionally: occurrence, depth, synthetic, padding_node.

    ``nodes`` is indexed by symbol value. Codes are stored bit-reversed so
    that they can be written and matched LSB-first. The decode lookup
    tables, the maximum code length and the bit-length sequence are computed
    lazily and cached.
    """
    def __init__(self, nodes):
        self.nodes = nodes
        self._lookup = None  # (length_lookup, node_lookup), built on demand
        self._max_bits = None  # cached result of get_max_code_length()
        self._bitlen_seq = None  # cached result of get_bitlen_sequence()

    def get_max_code_length(self):
        """Return the longest code length in the tree (0 for an empty tree)."""
        if self._max_bits is None:
            self._max_bits = max((n['code_length'] for n in self.nodes), default=0)
        return self._max_bits

    def get_highest_assigned_node(self):
        """Return the node with the highest index that has a code assigned.

        Raises:
            CompressError: if no node has a positive code length.
        """
        for n in reversed(self.nodes):
            if n['code_length'] > 0:
                return n
        raise CompressError("empty Huffman tree: no assigned node")

    # ------------------------------------------------------------------
    # Encoding
    # ------------------------------------------------------------------
    def write_node(self, node_idx, writer):
        """Write the code of symbol ``node_idx`` to ``writer``."""
        node = self.nodes[node_idx]
        writer.write_bits(node['code'], node['code_length'])

    def calculate_encoded_size(self, occurrences, extra_bits_lookup, extra_bits_start):
        """Return the number of bits needed to encode the given symbol counts.

        Args:
            occurrences: per-symbol occurrence counts, indexed like ``nodes``.
            extra_bits_lookup: extra-bit counts for symbols at or above
                ``extra_bits_start``, indexed relative to that start.
            extra_bits_start: first symbol index that carries extra bits.
        """
        total = 0
        for idx, occ in enumerate(occurrences):
            node = self.nodes[idx]
            extra = extra_bits_lookup[idx - extra_bits_start] if idx >= extra_bits_start else 0
            total += occ * (node['code_length'] + extra)
        return total

    def get_bitlen_sequence(self):
        """Return (and cache) the run-length encoded code-length sequence."""
        if self._bitlen_seq is None:
            self._bitlen_seq = list(self._generate_bitlen_sequence())
        return self._bitlen_seq

    def write_bitlen_occurrences(self, occ):
        """Increment ``occ[symbol]`` for every symbol of the bit-length sequence."""
        for entry in self.get_bitlen_sequence():
            occ[entry['code']] += 1

    def _generate_bitlen_sequence(self):
        """Yield the code lengths of nodes 0..highest assigned, run-length encoded.

        Each yielded dict has ``type``, ``code`` (a bit-length alphabet
        symbol) and ``encoded_value`` (the value for the extra bits, or -1
        for plain lengths). Runs of equal lengths become repeat symbols once
        they reach the minimum run length in effect.
        """
        nodes = self.nodes
        last_node = self.get_highest_assigned_node()
        last_idx = last_node['value']

        def next_match_bounds(curlen, nextlen):
            # Returns (min_run, max_run) for the run that starts next.
            if nextlen == 0:
                return 3, 138
            elif curlen == nextlen:
                return 3, 6
            else:
                return 4, 7
        prevlen = -1
        streak = 0
        min_match, max_match = next_match_bounds(-1, nodes[0]['code_length'])
        for i in range(last_idx + 1):
            curlen = nodes[i]['code_length']
            # -1 acts as a sentinel that ends the final run.
            nextlen = nodes[i + 1]['code_length'] if i < last_idx else -1
            streak += 1
            if curlen == nextlen and streak < max_match:
                continue
            if streak < min_match:
                # Run too short for a repeat symbol: emit each length verbatim.
                for _ in range(streak):
                    yield {'type': 'single', 'code': curlen, 'encoded_value': -1}
            elif curlen != 0:
                if curlen != prevlen:
                    # The repeat symbol copies the previous length, so the
                    # length itself is sent once and the rest is repeated.
                    yield {'type': 'single', 'code': curlen, 'encoded_value': -1}
                    yield {'type': 'repeat-nz', 'code': _LZH_BITLEN_REPEAT_NZ_3_6, 'encoded_value': (streak - 1) - _LZH_BITLEN_REPEAT_NZ_3_6_ST}
                else:
                    yield {'type': 'repeat-nz', 'code': _LZH_BITLEN_REPEAT_NZ_3_6, 'encoded_value': streak - _LZH_BITLEN_REPEAT_NZ_3_6_ST}
            else:
                if streak <= 10:
                    yield {'type': 'repeat-z3', 'code': _LZH_BITLEN_REPEAT_Z_3_10, 'encoded_value': streak - _LZH_BITLEN_REPEAT_Z_3_10_ST}
                else:
                    yield {'type': 'repeat-z11', 'code': _LZH_BITLEN_REPEAT_Z_11_138, 'encoded_value': streak - _LZH_BITLEN_REPEAT_Z_11_138_ST}
            streak = 0
            prevlen = curlen
            min_match, max_match = next_match_bounds(curlen, nextlen)

    def encode_to(self, writer, bitlen_tree):
        """Write this tree's code lengths to ``writer`` using ``bitlen_tree``."""
        for entry in self.get_bitlen_sequence():
            bitlen_tree.write_node(entry['code'], writer)
            extra_bits = _LZH_BITLEN_EXTRA[entry['code']]
            if extra_bits > 0:
                writer.write_bits(entry['encoded_value'], extra_bits)

    # ------------------------------------------------------------------
    # Decoding
    # ------------------------------------------------------------------
    def _build_lookup(self):
        """Build the decode tables.

        ``length_lookup`` maps every ``max_bits``-wide bit pattern to the
        length of the code it starts with; ``node_lookup[length]`` maps a
        code of that length to its node.
        """
        max_bits = self.get_max_code_length()
        length_lookup = [0] * (1 << max_bits)
        node_lookup = [None] * (max_bits + 1)
        for n in self.nodes:
            cl = n['code_length']
            if cl <= 0:
                continue
            if node_lookup[cl] is None:
                node_lookup[cl] = {}
            code = n['code']
            node_lookup[cl][code] = n
            # Fill length_lookup for all extensions of this code
            ext_count = 1 << (max_bits - cl)
            for i in range(ext_count):
                length_lookup[code | (i << cl)] = cl
        self._lookup = (length_lookup, node_lookup)

    def lookup_code(self, code, code_length=None):
        """Return the node matching the low bits of ``code``, or ``None``.

        Args:
            code: bit pattern, LSB first.
            code_length: optional cap on the code length that is considered.
        """
        if self._lookup is None:
            self._build_lookup()
        length_lookup, node_lookup = self._lookup
        cl = length_lookup[code] if code < len(length_lookup) else 0
        if code_length is not None:
            cl = min(cl, code_length)
        if cl == 0:
            return None
        code = code & ((1 << cl) - 1)
        m = node_lookup[cl]
        return m.get(code) if m else None

    def read_code(self, reader):
        """Decode one symbol from ``reader`` and return its node.

        Raises:
            DecompressError: if the upcoming bits match no code.
        """
        max_bits = self.get_max_code_length()
        raw = reader.peek_bits(max_bits)
        node = self.lookup_code(raw)
        if node is None:
            raise DecompressError("bad hufman tree: no node for code 0x%x" % raw)
        reader.skip_bits(node['code_length'])
        return node

    def read_code_two_staged(self, reader, first_stage_len):
        """Decode one symbol, first trying with only ``first_stage_len`` bits.

        More bits are peeked only when the short lookup finds no code.

        Raises:
            DecompressError: if the upcoming bits match no code.
        """
        # NOTE-5
        max_bits = self.get_max_code_length()
        if first_stage_len >= max_bits:
            return self.read_code(reader)
        first = reader.peek_bits(first_stage_len)
        node = self.lookup_code(first, first_stage_len)
        if node is None:
            second = reader.peek_bits(max_bits)
            node = self.lookup_code(second)
            if node is None:
                raise DecompressError("bad hufman tree: no node for code")
        reader.skip_bits(node['code_length'])
        return node

    # ------------------------------------------------------------------
    # Static constructors
    # ------------------------------------------------------------------
    @staticmethod
    def from_distribution(dist):
        """Build a decode tree from a list of code-length values (index=value)."""
        nodes = [{'value': i, 'code': -1, 'code_length': cl} for i, cl in enumerate(dist)]
        _HuffTree._generate_node_codes(nodes)
        return _HuffTree(nodes)

    @staticmethod
    def _generate_node_codes(leaf_nodes, dist=None):
        """Assign canonical (bit-reversed) codes to ``leaf_nodes`` in place.

        Args:
            leaf_nodes: nodes in symbol order; nodes with a non-positive
                code length are skipped.
            dist: optional list where ``dist[n]`` is the number of codes of
                length ``n``; derived from the nodes when omitted.

        Raises:
            DecompressError: if the distribution is over- or under-subscribed.
        """
        if dist is None:
            dist = _HuffTree._make_distribution(leaf_nodes)
        _HuffTree._validate_distribution(dist)
        # next_codes[n] is the next unassigned code of length n.
        next_codes = [0] * (len(dist) + 1)
        code = 0
        for cl in range(1, len(dist) + 1):
            code = (code + (dist[cl - 1] if cl - 1 < len(dist) else 0)) << 1
            next_codes[cl] = code
        for n in leaf_nodes:
            cl = n['code_length']
            if cl <= 0:
                continue
            n['code'] = _reverse_int(next_codes[cl], cl)
            next_codes[cl] += 1

    @staticmethod
    def _make_distribution(nodes):
        """Return a list where index ``n`` counts the nodes with code length ``n``."""
        dist = []
        for n in nodes:
            cl = n['code_length']
            if cl <= 0:
                continue
            while len(dist) <= cl:
                dist.append(0)
            dist[cl] += 1
        return dist

    @staticmethod
    def _validate_distribution(dist):
        """Check that the code lengths describe a complete prefix code.

        Raises:
            DecompressError: if more codes are requested than a length
                allows, or if codes remain unassigned at the end.
        """
        # Number of code points still free at the current length.
        available = 1
        for count in dist:
            available -= count
            if available < 0:
                raise DecompressError("bad hufman tree: invalid code-length distribution")
            available *= 2
        if available > 0:
            raise DecompressError("bad hufman tree: code-length distribution has unassigned codes")
class _HuffHeapTree:
"""Min-heap of Huffman nodes ordered by (occurrence, depth)."""
def __init__(self, nodes):
# 1-based; index 0 is unused placeholder
self.heap = [None] + nodes
@property
def length(self):
return len(self.heap) - 1
def _cmp(self, a, b):
"""Return negative if a < b, 0 if equal, positive if a > b."""
if a['occurrence'] != b['occurrence']:
return a['occurrence'] - b['occurrence']
return a['depth'] - b['depth']
def _update(self, idx):
base = self.heap[idx]
while True:
c1 = idx << 1
if c1 >= len(self.heap):
break
c2 = c1 + 1
# When tied, prefer c2 (NOTE in TS: "When in a tie, child2 is used")
smaller = c2 if (c2 < len(self.heap) and self._cmp(self.heap[c1], self.heap[c2]) >= 0) else c1
if self._cmp(base, self.heap[smaller]) <= 0:
break
self.heap[idx] = self.heap[smaller]
idx = smaller
self.heap[idx] = base
def init(self):
for i in range(self.length // 2, 0, -1):
self._update(i)
def pop(self):
if self.length == 0:
raise CompressError("cannot pop from empty heap")
if self.length == 1:
return self.heap.pop()
top = self.heap[1]
self.heap[1] = self.heap.pop()
self._update(1)
return top
def create_huff_node(self):
if self.length < 2:
raise CompressError("heap has fewer than 2 elements")
top = self.pop()
top2 = self.heap[1]
syn = {
'synthetic': True,
'occurrence': top['occurrence'] + top2['occurrence'],
'depth': max(top['depth'], top2['depth']) + 1,
'code_length': -1,
'parent': None,
'children': (top, top2),
}
top['parent'] = syn
top2['parent'] = syn
self.heap[1] = syn
self._update(1)
return syn
class _HuffTreeEncoder:
@staticmethod
def build_tree(occurrences, max_code_length):
enc = _HuffTreeEncoder()
return enc._build(occurrences, max_code_length)
def _build(self, occurrences, max_code_length):
base_nodes = self._make_bare_nodes(occurrences)
populated = [n for n in base_nodes if n['occurrence'] > 0]
self._pad_nodes(populated, base_nodes)
tree_nodes = self._create_huffman_nodes(populated)
dist = self._arrange_nodes(tree_nodes, max_code_length)
_HuffTree._generate_node_codes(populated, dist)
return _HuffTree(base_nodes)
def _make_bare_nodes(self, occurrences):
return [
{
'value': i,
'occurrence': occ,
'synthetic': False,
'padding_node': False,
'depth': 0,
'parent': None,
'code': -1,
'code_length': 0,
}
for i, occ in enumerate(occurrences)
]
def _pad_nodes(self, nodes, base_nodes):
# NOTE-12
while len(nodes) < 2:
if len(nodes) == 0:
next_pad = 0
else:
next_pad = nodes[0]['value'] + 1 if nodes[0]['value'] < 2 else 0
pad = base_nodes[next_pad]
pad['occurrence'] = 1
pad['padding_node'] = True
nodes.append(pad)
nodes.sort(key=lambda n: n['value'])
def _create_huffman_nodes(self, base_nodes):
heap = _HuffHeapTree(list(base_nodes))
heap.init()
tree_nodes = []
while heap.length >= 2:
syn = heap.create_huff_node()
tree_nodes.extend(syn['children'])
tree_nodes.append(heap.heap[1])
return tree_nodes
def _arrange_nodes(self, tree_nodes, max_code_length):
# Root is the last element, with depth 0
tree_nodes[-1]['code_length'] = 0
overflow = 0
dist = [0] * (max_code_length + 1)
for i in range(len(tree_nodes) - 2, -1, -1):
node = tree_nodes[i]
cl = node['parent']['code_length'] + 1
if cl > max_code_length:
cl = max_code_length
overflow += 1
node['code_length'] = cl
if not node['synthetic']:
dist[cl] += 1
if overflow:
# NOTE-4: rearrange overflowed nodes
while overflow > 0:
next_free = -1
for i in range(max_code_length - 1, -1, -1):
if dist[i] > 0:
next_free = i
break
if next_free < 0:
raise CompressError("no space to fit overflow nodes into Huffman tree")
dist[next_free] -= 1
dist[next_free + 1] += 1
dist[max_code_length] -= 1
dist[next_free + 1] += 1
overflow -= 2
# Re-assign code lengths to natural nodes
natural = [n for n in tree_nodes if not n['synthetic']]
nat_iter = iter(natural)
for layer in range(max_code_length, 0, -1):
for _ in range(dist[layer]):
try:
next(nat_iter)['code_length'] = layer
except StopIteration:
raise CompressError("ran out of leaf nodes during Huffman rearrangement")
return dist
# ---------------------------------------------------------------------------
# LZH static trees (cached at class level)
# ---------------------------------------------------------------------------
class _LZHBase:
    """Holder of lazily built, class-level cached LZH tables and static trees."""

    _len_code_map = None  # (codeLookup: bytes, valueStartLookup: list)
    _dist_code_map = None
    _static_litlen = None
    _static_dist = None

    @classmethod
    def get_length_code_mapping(cls):
        """Return the cached ``(lookup, starts)`` pair for match lengths."""
        mapping = cls._len_code_map
        if mapping is None:
            mapping = cls._len_code_map = cls._gen_length_code_mapping()
        return mapping

    @staticmethod
    def _gen_length_code_mapping():
        """Build the match-length tables.

        Returns:
            ``(lookup, starts)``: ``lookup[length]`` is the length code for
            a match length (indices 0-258) and ``starts[code]`` is the first
            length covered by that code.
        """
        lookup = bytearray(259)  # index 0-258
        starts = [0] * 29
        first_length = _LZH_MIN_MATCH
        for code in range(28):
            span = 1 << _LZH_LEN_EXTRA[code]
            starts[code] = first_length
            lookup[first_length:first_length + span] = bytes([code]) * span
            first_length += span
        # NOTE-3: length 258 gets its own code
        lookup[_LZH_MAX_MATCH] = 28
        starts[28] = _LZH_MAX_MATCH
        return lookup, starts

    @classmethod
    def get_distance_code_mapping(cls):
        """Return the cached ``(lookup, starts)`` pair for match distances."""
        mapping = cls._dist_code_map
        if mapping is None:
            mapping = cls._dist_code_map = cls._gen_distance_code_mapping()
        return mapping

    @staticmethod
    def _gen_distance_code_mapping():
        """Build the distance tables.

        Returns:
            ``(lookup, starts)``: ``lookup[distance]`` is the distance code
            and ``starts[code]`` is the first distance covered by that code.
        """
        lookup = bytearray(32769)
        starts = [0] * 30
        first_dist = 1
        for code in range(30):
            span = 1 << _LZH_DIST_EXTRA[code]
            starts[code] = first_dist
            lookup[first_dist:first_dist + span] = bytes([code]) * span
            first_dist += span
        return lookup, starts

    @classmethod
    def get_static_litlen_tree(cls):
        """Return the cached static literal/length tree."""
        tree = cls._static_litlen
        if tree is None:
            tree = cls._static_litlen = cls._gen_static_litlen_tree()
        return tree

    @staticmethod
    def _gen_static_litlen_tree():
        """Build the static literal/length tree from its fixed code lengths."""
        # 288 nodes needed for canonical tree generation (codes 286-287 are dummies)
        layout = ((144, 8), (112, 9), (24, 7), (8, 8))  # (symbol count, code length)
        nodes = []
        for count, bits in layout:
            first = len(nodes)
            for value in range(first, first + count):
                nodes.append({'value': value, 'code': -1, 'code_length': bits})
        _HuffTree._generate_node_codes(nodes)
        return _HuffTree(nodes)

    @classmethod
    def get_static_dist_tree(cls):
        """Return the cached static distance tree."""
        tree = cls._static_dist
        if tree is None:
            tree = cls._static_dist = cls._gen_static_dist_tree()
        return tree

    @staticmethod
    def _gen_static_dist_tree():
        """Build the static distance tree: 30 fixed 5-bit codes, bit-reversed."""
        nodes = []
        for value in range(30):
            nodes.append({'value': value, 'code_length': 5, 'code': _reverse_int(value, 5)})
        return _HuffTree(nodes)
# ---------------------------------------------------------------------------
# LZH LZSS matcher
# ---------------------------------------------------------------------------
class _LZSSMatcher:
"""Sliding-window LZSS byte-stream matcher."""
def __init__(self, reader, config):
    """Set up an empty window and hash state.

    Args:
        reader: source of the input bytes (provides ``read``,
            ``bytes_left`` and ``end_reached``).
        config: compression-level dict; the matcher reads ``max_chain``
            and ``good_length`` from it.
    """
    self._reader = reader
    self._config = config
    # Window state: the buffer holds two window sizes; the offset, cursor
    # and end are global (input-relative) positions.
    window_bytes = _LZH_WINDOW_SIZE * 2
    self._window = bytearray(window_bytes)
    self._win_off = 0
    self._win_cur = 0
    self._win_end = 0
    self._win_sealed = False
    # Hash state.
    self._hash_index = {}  # hash → local_cursor
    zeroed = bytes(2 * _LZH_WINDOW_SIZE)
    self._hash_history = array('H', zeroed)  # 16-bit entries
    self._cur_hash = 0
    self._nearest = None  # global position of nearest match
def _hash(self, base, next_byte):
    """Fold ``next_byte`` into the rolling hash ``base`` and return the result."""
    shifted = base << _LZH_HASH_SHIFT
    return (shifted ^ next_byte) & _LZH_HASH_MASK
def _populate_window(self):
    """Read more input into the window, shifting it first when it is full.

    Seals the window instead when the reader has no more data.

    Raises:
        RuntimeError: if the window has already been sealed.
    """
    if self._win_sealed:
        raise RuntimeError("window is sealed")
    free = len(self._window) - (self._win_end - self._win_off)
    if free == 0:
        # Dropping the older half frees exactly one window size.
        self._shift_window()
        free = _LZH_WINDOW_SIZE
    source = self._reader
    if source.end_reached:
        self._win_sealed = True
        return
    count = min(free, source.bytes_left)
    # Computed after a possible shift, which changes the window offset.
    start = self._win_end - self._win_off
    self._window[start:start + count] = source.read(count)
    self._win_end += count
def _shift_window(self):
    """Discard the older half of the window and rebase all stored positions."""
    half = _LZH_WINDOW_SIZE
    window = self._window
    window[:half] = window[half:half * 2]
    self._win_off += half
    # Adjust hash_index: drop entries that pointed into the discarded half.
    rebased = {}
    for key, position in self._hash_index.items():
        if position >= half:
            rebased[key] = position - half
    self._hash_index = rebased
    # Adjust hash_history: positions in the discarded half collapse to 0.
    history = self._hash_history
    for slot in range(half):
        history[slot] = max(history[slot] - half, 0)
def _move_cursor_to(self, pos):
    """Advance the cursor to global position ``pos``, hashing every step.

    For each position passed, the rolling hash is extended with the byte
    ``_LZH_MIN_MATCH - 1`` places ahead, the previous holder of that hash
    becomes the nearest-match candidate, and the hash index and history are
    updated. Afterwards the window is refilled until at least
    ``_LZH_MIN_LOOKAHEAD`` bytes lie ahead of the cursor or the input is
    exhausted.

    While the current hash is negative the cursor moves without hashing.
    """
    win_off = self._win_off
    window = self._window
    win_len = len(window)
    h = self._cur_hash
    while self._win_cur < pos:
        self._win_cur += 1
        if h < 0:
            continue
        local = self._win_cur - win_off
        # Near the end of a sealed window the look-ahead byte can lie beyond
        # the buffer; read it as 0 (the same convention the match search
        # uses, NOTE-10) instead of raising IndexError.
        ahead = local + _LZH_MIN_MATCH - 1
        ahead_byte = window[ahead] if ahead < win_len else 0
        h = self._hash(h, ahead_byte)
        prev = self._hash_index.get(h)
        self._nearest = (prev + win_off) if prev is not None else None
        self._hash_index[h] = local
        self._hash_history[local % _LZH_WINDOW_SIZE] = prev if prev is not None else 0
    self._cur_hash = h
    while (self._win_end - self._win_cur < _LZH_MIN_LOOKAHEAD
           and not self._win_sealed):
        self._populate_window()
def _next_match(self, min_length=None):
if min_length is None:
min_length = _LZH_MIN_MATCH
if self._win_cur >= self._win_end:
return None
near = self._nearest
if (near is None
or self._win_cur - near > _LZH_MAX_DISTANCE
or self._win_end - self._win_cur < min_length):
return {'cursor': self._win_cur, 'distance': 0, 'length': 1}
win_off = self._win_off
bound = max(self._win_cur - _LZH_MAX_DISTANCE, 0)
max_hops = (self._config['max_chain'] // 4
if min_length > self._config['good_length']
else self._config['max_chain'])
local_cur = self._win_cur - win_off
win = self._window
hist = self._hash_history
occ_start = near
local_occ_start = occ_start - win_off
hops = 0
best = None
while True:
# Quick reject (NOTE-10: use 0 for out-of-bounds reads)
wc0 = win[local_cur] if local_cur < len(win) else 0
wo0 = win[local_occ_start] if local_occ_start < len(win) else 0
wc1 = win[local_cur + min_length - 1] if (local_cur + min_length - 1) < len(win) else 0
wo1 = win[local_occ_start + min_length - 1] if (local_occ_start + min_length - 1) < len(win) else 0
wc2 = win[local_cur + min_length - 2] if (local_cur + min_length - 2) < len(win) else 0
wo2 = win[local_occ_start + min_length - 2] if (local_occ_start + min_length - 2) < len(win) else 0
if wc0 == wo0 and wc1 == wo1 and wc2 == wo2:
match_len = 1
while (match_len < _LZH_MAX_MATCH
and (win[local_cur + match_len] if (local_cur + match_len) < len(win) else 0)
== (win[local_occ_start + match_len] if (local_occ_start + match_len) < len(win) else 0)):
match_len += 1
if match_len >= min_length:
best = {'cursor': self._win_cur, 'distance': self._win_cur - occ_start, 'length': match_len}
if match_len >= _LZH_MAX_MATCH:
break
min_length = match_len + 1
local_occ_start = hist[local_occ_start % _LZH_WINDOW_SIZE]
occ_start = local_occ_start + win_off
hops += 1
if hops >= max_hops or occ_start <= bound:
break
if best is None:
return {'cursor': self._win_cur, 'distance': 0, 'length': 1}
# Cap to actual remaining bytes (NOTE-10)
best['length'] = min(best['length'], self._win_end - self._win_cur)
# Discard short distant matches
if best['length'] == _LZH_MIN_MATCH and best['distance'] > _LZH_MAX_DISTANCE_3:
return {'cursor': self._win_cur, 'distance': 0, 'length': 1}
return best
def _resolve(self, m):
local = m['cursor'] - self._win_off
return {
'distance': m['distance'],
'length': m['length'],
'first_byte': self._window[local],
}
def match(self):
self._populate_window()
# Prime the hash with MIN_MATCH - 1 bytes ahead
for i in range(_LZH_MIN_MATCH):
self._cur_hash = self._hash(self._cur_hash, self._window[i])
self._hash_index[self._cur_hash] = 0
while self._win_cur < self._win_end:
cur = self._next_match()
while cur and (cur['length'] == 1 or cur['length'] >= self._config['max_lazy']):
yield self._resolve(cur)
self._move_cursor_to(cur['cursor'] + cur['length'])
cur = self._next_match()
while cur:
self._move_cursor_to(self._win_cur + 1)
pending = cur
cur = self._next_match(pending['length'] + 1)
if not cur or pending['length'] >= cur['length']:
yield self._resolve(pending)
self._move_cursor_to(pending['cursor'] + pending['length'])
cur = None
else:
yield self._resolve({'cursor': pending['cursor'], 'distance': 0, 'length': 1})
if cur['length'] >= self._config['max_lazy']:
yield self._resolve(cur)
self._move_cursor_to(cur['cursor'] + cur['length'])
cur = None
class _LZSSBlockMatcher:
    """Groups LZSS matches into encoder blocks."""

    def __init__(self, reader, config):
        """Remember the input *reader* and the compression-level *config*.

        :param reader: input source handed on to ``_LZSSMatcher``
        :param config: mapping; ``level`` is read here, the rest is used by
            ``_LZSSMatcher``
        """
        self._reader = reader
        self._config = config

    def _should_end_block(self, blk):
        """Return True when the block being built must be closed.

        A block ends when it holds 0x3fff matches or 0x4000 distance
        matches.  Above level 2 it also ends at every multiple of 0x1000
        matches if fewer than half of them are distance matches and the
        estimated compressed size (in bytes) is below half of the
        uncompressed size.

        :param blk: dict with ``match_count``, ``dist_match_count``,
            ``compressed_tracker`` (bits) and ``uncompressed_size`` (bytes)
        """
        mc = blk['match_count']
        dmc = blk['dist_match_count']
        if mc == 0x3fff or dmc == 0x4000:
            return True
        if (self._config['level'] > 2
                and (mc & 0xfff) == 0
                and dmc < (mc >> 1)
                and (blk['compressed_tracker'] >> 3) < (blk['uncompressed_size'] >> 1)):
            return True
        return False

    def match(self):
        """Yield the matches of the input grouped into blocks.

        Each yielded dict has ``matches`` (list of match records) and
        ``block_finished``: True for a block closed by
        :meth:`_should_end_block`, False for the trailing block that was
        still open when the input ran out.  The trailing block is only
        yielded when it is non-empty.
        """
        # Only the distance → code lookup is needed here; fetch it once.
        dist_lookup, _ = _LZHBase.get_distance_code_mapping()

        def new_block():
            return {'matches': [], 'match_count': 0, 'dist_match_count': 0,
                    'uncompressed_size': 1,  # off-by-one replication from original
                    'compressed_tracker': 0}

        blk = new_block()
        for m in _LZSSMatcher(self._reader, self._config).match():
            blk['matches'].append(m)
            blk['match_count'] += 1
            if m['distance'] == 0:
                blk['compressed_tracker'] += 8
            else:
                blk['dist_match_count'] += 1
                dc = dist_lookup[m['distance']]
                blk['compressed_tracker'] += 8 + 5 + _LZH_DIST_EXTRA[dc]
            if self._should_end_block(blk):
                yield {'matches': blk['matches'], 'block_finished': True}
                blk = new_block()
                continue
            # Accounted only for matches that did not close the block
            blk['uncompressed_size'] += m['length']
        if blk['matches']:
            yield {'matches': blk['matches'], 'block_finished': False}
# ---------------------------------------------------------------------------
# LZH compress
# ---------------------------------------------------------------------------
class _LZHCompress(_LZHBase):
    """LZH encoder: header, then Huffman-coded blocks of LZSS matches."""

    def __init__(self, data, level):
        """Set up an encoder for *data*.

        :param data: input handed to ``_Reader``
        :param level: key into ``_LZH_COMPRESSION_LEVELS``
        """
        self._reader = _Reader(data)
        self._writer = _Writer()
        self._config = _LZH_COMPRESSION_LEVELS[level]

    def compress(self):
        """Encode the whole input and return the writer's accumulated data.

        Finished blocks are written as they arrive; the trailing open block
        (or an empty block when there is none) is written last with the
        final flag set.
        """
        self._write_head()
        len_lookup, len_starts = self.get_length_code_mapping()
        dist_lookup, dist_starts = self.get_distance_code_mapping()
        tables = {
            'len_lookup': len_lookup,
            'len_starts': len_starts,
            'dist_lookup': dist_lookup,
            'dist_starts': dist_starts,
        }
        trailing = None
        for block in _LZSSBlockMatcher(self._reader, self._config).match():
            if not block['block_finished']:
                trailing = block
                continue
            self._write_match_block(block['matches'], final=False, **tables)
        tail_matches = trailing['matches'] if trailing else []
        self._write_match_block(tail_matches, final=True, **tables)
        out = self._writer
        out.flush_pending_bits()
        # NOTE-5: courtesy zero byte
        out.write_byte(0)
        return out.data

    def _write_head(self):
        """Write the 8-byte header followed by the noise-bit preamble."""
        header = _build_header(self._reader.total_length, _HDR_ALG_LZH,
                               _LZH_VERSION, self._config['level'])
        out = self._writer
        out.write(header)
        # NOTE-16: fixed noise bits (3 bits of noise, value 7)
        out.write_bits(3, _LZH_HEAD_NOISE_LEN)
        out.write_bits(7, 3)

    @staticmethod
    def _count_symbols(matches, len_lookup, dist_lookup):
        """Return (litlen, dist) occurrence tables for *matches*."""
        litlen_occ = [0] * _LZH_LITLEN_COUNT
        dist_occ = [0] * _LZH_DIST_COUNT
        # The end-of-block symbol is always emitted once
        litlen_occ[_LZH_END_BLOCK] = 1
        for entry in matches:
            if m_is_literal := (entry['distance'] == 0):
                litlen_occ[entry['first_byte']] += 1
            if not m_is_literal:
                litlen_occ[len_lookup[entry['length']] + 257] += 1
                dist_occ[dist_lookup[entry['distance']]] += 1
        return litlen_occ, dist_occ

    @staticmethod
    def _last_used_rank(ranked):
        """Return the highest index in *ranked* with a non-zero code length, or -1."""
        last = -1
        for idx, node in enumerate(ranked):
            if node['code_length'] > 0:
                last = idx
        return last

    def _write_match_block(self, matches, final, len_lookup, len_starts,
                           dist_lookup, dist_starts):
        """Write one block, choosing static or dynamic trees by estimated size.

        :param matches: match records of the block
        :param final: True to set the block's final flag
        :raises CompressError: if fewer than four ranked bit-length codes
            are in use
        """
        litlen_occ, dist_occ = self._count_symbols(matches, len_lookup, dist_lookup)

        build = _HuffTreeEncoder.build_tree
        litlen_tree = build(litlen_occ, _LZH_LITLEN_HUFFTREE_MAX_BITS)
        dist_tree = build(dist_occ, _LZH_DIST_HUFFTREE_MAX_BITS)

        bitlen_occ = [0] * _LZH_BITLEN_COUNT
        for tree in (litlen_tree, dist_tree):
            tree.write_bitlen_occurrences(bitlen_occ)
        bitlen_tree = build(bitlen_occ, _LZH_BITLEN_HUFFTREE_MAX_BITS)

        top_rank = self._last_used_rank(self._rank_bitlen_nodes(bitlen_tree))
        if top_rank < 3:
            raise CompressError("max_bitlen_index cannot be below 3")

        static_ll = self.get_static_litlen_tree()
        static_d = self.get_static_dist_tree()

        # NOTE in TS: static tree size calculation uses LEN extra bits for both
        # litlen and dist (a bug kept for compatibility)
        static_bits = (
            static_ll.calculate_encoded_size(litlen_occ, _LZH_LEN_EXTRA, 257)
            + static_d.calculate_encoded_size(dist_occ, _LZH_LEN_EXTRA, 0)
        )
        # +3 for the final/type bits written below, +7 to round up to bytes
        static_bytes = (static_bits + 3 + 7) >> 3

        # Tree header: 3 bits per transmitted bit-length code plus the
        # 5 + 5 + 4 bit counts written by _write_dynamic_trees
        dyn_bits = (
            litlen_tree.calculate_encoded_size(litlen_occ, _LZH_LEN_EXTRA, 257)
            + dist_tree.calculate_encoded_size(dist_occ, _LZH_LEN_EXTRA, 0)
            + bitlen_tree.calculate_encoded_size(bitlen_occ, _LZH_BITLEN_EXTRA, 0)
            + 3 * (top_rank + 1) + 5 + 5 + 4
        )
        dyn_bytes = (dyn_bits + 3 + 7) >> 3

        out = self._writer
        out.write_bits(1 if final else 0, 1)
        use_static = static_bytes <= dyn_bytes
        if use_static:
            out.write_bits(_LZH_TREETYPE_STATIC, 2)
            chosen_ll, chosen_d = static_ll, static_d
        else:
            out.write_bits(_LZH_TREETYPE_DYNAMIC, 2)
            self._write_dynamic_trees(litlen_tree, dist_tree, bitlen_tree)
            chosen_ll, chosen_d = litlen_tree, dist_tree
        self._write_matches(matches, chosen_ll, chosen_d,
                            len_lookup, len_starts, dist_lookup, dist_starts)
        chosen_ll.write_node(_LZH_END_BLOCK, out)

    def _rank_bitlen_nodes(self, bitlen_tree):
        """Return the bit-length tree's nodes in ``_LZH_BITLEN_RANKING`` order."""
        nodes = bitlen_tree.nodes
        return [nodes[_LZH_BITLEN_RANKING[rank]] for rank in range(_LZH_BITLEN_COUNT)]

    def _write_dynamic_trees(self, litlen_tree, dist_tree, bitlen_tree):
        """Write the dynamic-tree header and both encoded trees."""
        out = self._writer
        top_lit = litlen_tree.get_highest_assigned_node()['value']
        top_dist = dist_tree.get_highest_assigned_node()['value']
        out.write_bits(top_lit - 256, 5)
        out.write_bits(top_dist, 5)

        ranked = self._rank_bitlen_nodes(bitlen_tree)
        top_rank = self._last_used_rank(ranked)
        out.write_bits(top_rank - 3, 4)
        for node in ranked[:top_rank + 1]:
            out.write_bits(node['code_length'], 3)

        litlen_tree.encode_to(out, bitlen_tree)
        dist_tree.encode_to(out, bitlen_tree)

    def _write_matches(self, matches, litlen_tree, dist_tree,
                       len_lookup, len_starts, dist_lookup, dist_starts):
        """Write every match as a literal or a length/distance code pair."""
        out = self._writer
        for entry in matches:
            distance = entry['distance']
            if distance == 0:
                litlen_tree.write_node(entry['first_byte'], out)
                continue

            length = entry['length']
            len_code = len_lookup[length]
            litlen_tree.write_node(len_code + 257, out)
            len_bits = _LZH_LEN_EXTRA[len_code]
            if len_bits:
                out.write_bits(length - len_starts[len_code], len_bits)

            dist_code = dist_lookup[distance]
            dist_tree.write_node(dist_code, out)
            dist_bits = _LZH_DIST_EXTRA[dist_code]
            if dist_bits:
                out.write_bits(distance - dist_starts[dist_code], dist_bits)

    @staticmethod
    def compress_data(data, level=_LZH_DEFAULT_COMPRESSION_LEVEL):
        """Compress *data* at *level* and return the encoded stream."""
        return _LZHCompress(data, level).compress()
# ---------------------------------------------------------------------------
# LZH decompress
# ---------------------------------------------------------------------------
class _LZHDecompress(_LZHBase):
    """LZH decoder.

    Decoded bytes are appended to the writer and mirrored into a circular
    history buffer of ``_LZH_WINDOW_SIZE`` bytes that back-references copy
    from.
    """

    def __init__(self, data):
        """Set up a decoder for the compressed stream *data* (header included)."""
        self._reader = _Reader(data)
        self._writer = _Writer()
        self._dec_buf = bytearray(_LZH_WINDOW_SIZE)
        self._dec_cursor = 0  # next write position in _dec_buf

    def decompress(self):
        """Decode all blocks and return the writer's accumulated data.

        :raises DecompressError: on an unknown block type or malformed block
        """
        self._read_head()
        while True:
            last_block = self._reader.read_bits(1)
            block_type = self._reader.read_bits(2)
            if block_type == _LZH_TREETYPE_STATIC:
                self._read_static_block()
            elif block_type == _LZH_TREETYPE_DYNAMIC:
                self._read_dynamic_block()
            else:
                raise DecompressError("unknown block type 0x%x" % block_type)
            if last_block:
                break
        return self._writer.data

    def _read_head(self):
        """Consume the 8-byte header and the noise-bit preamble.

        :raises DecompressError: if the header does not name the LZH algorithm
        """
        hdr_bytes = self._reader.read(_HDR_SIZE)
        _, alg_id, _, _ = _parse_header(hdr_bytes)
        if alg_id != _HDR_ALG_LZH:
            raise DecompressError("unknown algorithm: expected LZH algorithm identifier")
        # The preamble states how many noise bits follow; skip them
        noise_count = self._reader.read_bits(_LZH_HEAD_NOISE_LEN)
        if noise_count:
            self._reader.skip_bits(noise_count)

    def _read_static_block(self):
        """Decode a block that uses the predefined static trees."""
        self._read_block_content(
            self.get_static_litlen_tree(),
            self.get_static_dist_tree(),
        )

    def _read_dynamic_block(self):
        """Read the transmitted trees of a dynamic block, then its content.

        :raises DecompressError: if a code count exceeds its table size or
            the encoded code lengths are malformed
        """
        r = self._reader
        litlen_count = 257 + r.read_bits(5)
        dist_count = 1 + r.read_bits(5)
        bitlen_count = 4 + r.read_bits(4)
        if litlen_count > _LZH_LITLEN_COUNT:
            raise DecompressError("invalid litlen code count %d" % litlen_count)
        if dist_count > _LZH_DIST_COUNT:
            raise DecompressError("invalid dist code count %d" % dist_count)
        # Bit-length code lengths arrive in ranking order, 3 bits each
        bitlen_dist = [0] * _LZH_BITLEN_COUNT
        for i in range(bitlen_count):
            bitlen_dist[_LZH_BITLEN_RANKING[i]] = r.read_bits(3)
        bitlen_tree = _HuffTree.from_distribution(bitlen_dist)
        litlen_cl = self._read_encoded_lengths(r, bitlen_tree, litlen_count)
        litlen_tree = _HuffTree.from_distribution(litlen_cl)
        dist_cl = self._read_encoded_lengths(r, bitlen_tree, dist_count)
        dist_tree = _HuffTree.from_distribution(dist_cl)
        self._read_block_content(litlen_tree, dist_tree)

    @staticmethod
    def _read_encoded_lengths(reader, bitlen_tree, count):
        """Decode *count* code lengths that were run-length coded with *bitlen_tree*.

        Values below 16 are literal code lengths; the three repeat codes
        expand to a run of the previous length or to a run of zeros.

        :returns: list of *count* code lengths
        :raises DecompressError: if a run overflows *count*, if a
            "repeat previous" code appears before any length was read, or on
            an unknown bit-length code
        """
        codes = []
        lastlen = -1  # no previous length yet
        while len(codes) < count:
            node = bitlen_tree.read_code(reader)
            bl_code = node['value']
            if bl_code < 16:
                codes.append(bl_code)
                lastlen = bl_code
            elif bl_code == _LZH_BITLEN_REPEAT_NZ_3_6:
                rep = _LZH_BITLEN_REPEAT_NZ_3_6_ST + reader.read_bits(_LZH_BITLEN_EXTRA[bl_code])
                if lastlen < 0:
                    # Without this check the table would be filled with -1
                    raise DecompressError("repeat code without a previous code length")
                if rep > count - len(codes):
                    raise DecompressError("repeat code overflows expected code count")
                codes.extend([lastlen] * rep)
            elif bl_code == _LZH_BITLEN_REPEAT_Z_3_10:
                rep = _LZH_BITLEN_REPEAT_Z_3_10_ST + reader.read_bits(_LZH_BITLEN_EXTRA[bl_code])
                if rep > count - len(codes):
                    raise DecompressError("repeat code overflows expected code count")
                codes.extend([0] * rep)
                lastlen = 0
            elif bl_code == _LZH_BITLEN_REPEAT_Z_11_138:
                rep = _LZH_BITLEN_REPEAT_Z_11_138_ST + reader.read_bits(_LZH_BITLEN_EXTRA[bl_code])
                if rep > count - len(codes):
                    raise DecompressError("repeat code overflows expected code count")
                codes.extend([0] * rep)
                lastlen = 0
            else:
                raise DecompressError("unknown bitlen code %d" % bl_code)
        return codes

    def _read_block_content(self, litlen_tree, dist_tree):
        """Decode literals and back-references until the end-of-block symbol.

        :raises DecompressError: if the literal/length tree has no
            end-of-block code, or a decoded length or distance exceeds
            ``_LZH_MAX_MATCH`` / ``_LZH_MAX_DISTANCE``
        """
        r = self._reader
        w = self._writer
        buf = self._dec_buf
        eob = litlen_tree.nodes[_LZH_END_BLOCK]
        if not eob or eob['code_length'] == 0:
            raise DecompressError("litlen tree has no end-of-block code")
        _, len_starts = self.get_length_code_mapping()
        _, dist_starts = self.get_distance_code_mapping()
        while True:
            node = litlen_tree.read_code_two_staged(r, eob['code_length'])
            val = node['value']
            if val <= _LZH_LIT_LAST:
                buf[self._dec_cursor] = val
                w.write_byte(val)
                self._dec_cursor = (self._dec_cursor + 1) % _LZH_WINDOW_SIZE
                continue
            if val == _LZH_END_BLOCK:
                break
            # Length-distance back reference
            lc = val - _LZH_LENGTH_FIRST
            len_extra = _LZH_LEN_EXTRA[lc]
            length = len_starts[lc] + r.read_bits(len_extra)
            if length > _LZH_MAX_MATCH:
                raise DecompressError("invalid match length %d" % length)
            dist_node = dist_tree.read_code(r)
            dc = dist_node['value']
            dist_extra = _LZH_DIST_EXTRA[dc]
            distance = dist_starts[dc] + r.read_bits(dist_extra)
            if distance > _LZH_MAX_DISTANCE:
                raise DecompressError("invalid match distance %d" % distance)
            copy_left = length
            copy_start = (self._dec_cursor - distance + _LZH_WINDOW_SIZE) % _LZH_WINDOW_SIZE
            while copy_left > 0:
                # Copy in chunks that wrap neither the source nor the
                # destination around the end of the circular buffer
                slide_left = _LZH_WINDOW_SIZE - max(copy_start, self._dec_cursor)
                copy_length = min(slide_left, copy_left)
                curs = self._dec_cursor
                if curs > copy_start and curs > copy_start + copy_length or curs < copy_start:
                    # Destination is not fed by bytes written in this chunk:
                    # a bulk slice copy gives the same result
                    buf[curs:curs + copy_length] = buf[copy_start:copy_start + copy_length]
                else:
                    # Source and destination overlap: copy byte by byte so
                    # freshly written bytes are re-read
                    for i in range(copy_length):
                        buf[curs + i] = buf[copy_start + i]
                w.write(buf[curs:curs + copy_length])
                self._dec_cursor = (curs + copy_length) % _LZH_WINDOW_SIZE
                copy_left -= copy_length
                copy_start = (copy_start + copy_length) % _LZH_WINDOW_SIZE

    @staticmethod
    def decompress_data(data):
        """Decompress the LZH stream *data* and return the decoded output."""
        return _LZHDecompress(data).decompress()
# ---------------------------------------------------------------------------
# Public API (compatible with the old C pysapcompress extension)
# ---------------------------------------------------------------------------
def compress(data, algorithm=ALG_LZC):
    """Compress *data* using the given SAP algorithm.

    Empty input is rejected.  Any exception other than
    :class:`CompressError` raised while compressing is re-raised as a
    :class:`CompressError` with the original exception chained.

    :param bytes data: data to compress
    :param int algorithm: ALG_LZC (0) or ALG_LZH (2)
    :returns: tuple (status, compressed_length, compressed_bytes)
    :rtype: tuple[int, int, bytes]
    :raises CompressError: on error
    """
    if not data:
        raise CompressError("Compression error (CS_E_IN_BUFFER_LEN: invalid input length)")
    try:
        if algorithm == ALG_LZC:
            out = _LZCCompress.compress_data(bytes(data))
        elif algorithm == ALG_LZH:
            out = _LZHCompress.compress_data(bytes(data))
        else:
            raise CompressError("Compression error (CS_E_UNKNOWN_ALG: unknown algorithm)")
    except CompressError:
        # Already in the public error type; do not wrap it again
        raise
    except Exception as exc:
        raise CompressError("Compression error (%s)" % exc) from exc
    return _CS_END_OF_STREAM, len(out), out
def decompress(data, out_length):
    """Decompress *data*, expecting *out_length* uncompressed bytes.

    The algorithm is taken from the header.  Both the length stored in the
    header and the length of the decoded output must equal *out_length*.
    Any exception other than :class:`DecompressError` raised while decoding
    is re-raised as a :class:`DecompressError` with the original exception
    chained.

    :param bytes data: compressed data (including 8-byte SAP header)
    :param int out_length: expected uncompressed length
    :returns: tuple (status, decompressed_length, decompressed_bytes)
    :rtype: tuple[int, int, bytes]
    :raises DecompressError: on error
    """
    if not data:
        raise DecompressError("Decompression error (CS_E_IN_BUFFER_LEN: invalid input length)")
    data = bytes(data)
    if len(data) < _HDR_SIZE:
        raise DecompressError("Decompression error (CS_E_IN_BUFFER_LEN: invalid input length)")
    # Peek at the algorithm identifier without consuming
    try:
        hdr_length, alg_id, _, _ = _parse_header(data)
    except DecompressError as exc:
        raise DecompressError("Decompression error (%s)" % exc) from exc
    try:
        # An unknown algorithm is reported before any length mismatch
        if alg_id not in (_HDR_ALG_LZC, _HDR_ALG_LZH):
            raise DecompressError(
                "Decompression error (CS_E_UNKNOWN_ALG: unknown algorithm): "
                "algorithm id 0x%x" % alg_id
            )
        if hdr_length != out_length:
            raise DecompressError(
                "Decompression error (CS_E_OUT_BUFFER_LEN: invalid output length): "
                "header says %d but caller expects %d" % (hdr_length, out_length)
            )
        if alg_id == _HDR_ALG_LZC:
            out = _LZCDecompress.decompress_data(data, compat_mode=True)
        else:
            out = _LZHDecompress.decompress_data(data)
    except DecompressError:
        # Already in the public error type; do not wrap it again
        raise
    except Exception as exc:
        raise DecompressError("Decompression error (%s)" % exc) from exc
    if len(out) != out_length:
        raise DecompressError(
            "Decompression error (CS_E_OUT_BUFFER_LEN: invalid output length): "
            "decoded %d but caller expects %d" % (len(out), out_length)
        )
    return _CS_END_OF_STREAM, len(out), out