I'm trying to speed up this Huffman decoder. For small input hex strings it's fine, but as the input string grows the runtime increases dramatically, up to about 50x (1 ms vs. 55 ms+ for the examples below).
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple, Optional
import numpy as np
from array import array
import ctypes
from sys import byteorder  # native byte order ('little'/'big'), used by int.from_bytes
class Node:
__slots__ = ['char', 'freq', 'left', 'right']
def __init__(self, char: str, freq: int, left=None, right=None):
self.char = char
self.freq = freq
self.left = left
self.right = right
class HybridLookupTable:
"""Hybrid approach combining direct lookup for short codes and binary search for long codes"""
__slots__ = ['short_table', 'long_codes', 'max_short_bits']
def __init__(self, max_short_bits: int = 8):
self.max_short_bits = max_short_bits
self.short_table = [(None, 0)] * (1 << max_short_bits) # Changed to tuple list for safety
self.long_codes = {}
def add_code(self, code: str, char: str) -> None:
code_int = int(code, 2)
code_len = len(code)
if code_len <= self.max_short_bits:
# For short codes, use lookup table with limited prefix expansion
prefix_mask = (1 << (self.max_short_bits - code_len)) - 1
base_index = code_int << (self.max_short_bits - code_len)
for i in range(prefix_mask + 1):
self.short_table[base_index | i] = (char, code_len)
else:
# For long codes, store in dictionary
self.long_codes[code_int] = (char, code_len)
def lookup(self, bits: int, length: int) -> Optional[Tuple[str, int]]:
"""Look up a bit pattern and return (character, code length) if found"""
        if length <= self.max_short_bits:
            entry = self.short_table[bits & ((1 << self.max_short_bits) - 1)]
            # Unfilled slots hold (None, 0); report those as "no match" so callers
            # fall back to tree traversal instead of consuming zero bits forever
            return entry if entry[0] is not None else None
# Try matching long codes
for code_bits, (char, code_len) in self.long_codes.items():
if code_len <= length:
mask = (1 << code_len) - 1
if (bits >> (length - code_len)) == (code_bits & mask):
return (char, code_len)
return None
class BitBuffer:
"""Fast bit buffer implementation using ctypes"""
__slots__ = ['buffer', 'bits_in_buffer']
def __init__(self):
self.buffer = ctypes.c_uint64(0)
self.bits_in_buffer = 0
def add_byte(self, byte: int) -> None:
self.buffer.value = (self.buffer.value << 8) | byte
self.bits_in_buffer += 8
def peek_bits(self, num_bits: int) -> int:
return (self.buffer.value >> (self.bits_in_buffer - num_bits)) & ((1 << num_bits) - 1)
def consume_bits(self, num_bits: int) -> None:
self.buffer.value &= (1 << (self.bits_in_buffer - num_bits)) - 1
self.bits_in_buffer -= num_bits
class ChunkDecoder:
"""Decoder for a chunk of compressed data"""
__slots__ = ['lookup_table', 'tree', 'chunk_size']
def __init__(self, lookup_table, tree, chunk_size=1024):
self.lookup_table = lookup_table
self.tree = tree
self.chunk_size = chunk_size
def decode_chunk(self, data: memoryview, start_bit: int, end_bit: int) -> Tuple[List[str], int]:
"""Decode a chunk of bits and return (decoded_chars, bits_consumed)"""
result = []
pos = start_bit
buffer = BitBuffer()
bytes_processed = start_bit >> 3
bit_offset = start_bit & 7
# Pre-fill buffer
for _ in range(8):
if bytes_processed < len(data):
buffer.add_byte(data[bytes_processed])
bytes_processed += 1
# Skip initial bit offset
if bit_offset:
buffer.consume_bits(bit_offset)
while pos < end_bit and buffer.bits_in_buffer >= 8:
# Try lookup table first (optimized for 8-bit codes)
lookup_bits = buffer.peek_bits(8)
char_info = self.lookup_table.lookup(lookup_bits, 8)
if char_info:
char, code_len = char_info
buffer.consume_bits(code_len)
result.append(char)
pos += code_len
else:
# Fall back to tree traversal
node = self.tree
while node.left and node.right and buffer.bits_in_buffer > 0:
bit = buffer.peek_bits(1)
buffer.consume_bits(1)
node = node.right if bit else node.left
pos += 1
if not (node.left or node.right):
result.append(node.char)
# Refill buffer if needed
while buffer.bits_in_buffer <= 56 and bytes_processed < len(data):
buffer.add_byte(data[bytes_processed])
bytes_processed += 1
return result, pos - start_bit
class OptimizedHuffmanDecoder:
def __init__(self, num_threads=4, chunk_size=1024):
self.tree = None
self.freqs = {}
        self.lookup_table = HybridLookupTable()
        self.codes = {}  # char -> bit-string code, filled by _build_codes, used by encode()
self.num_threads = num_threads
self.chunk_size = chunk_size
self._setup_lookup_tables()
def _setup_lookup_tables(self):
# Pre-calculate bit manipulation tables
self.bit_masks = array('Q', [(1 << i) - 1 for i in range(65)])
self.bit_shifts = array('B', [x & 7 for x in range(8)])
def _build_efficient_tree(self) -> None:
        # Repeatedly merge the two lowest-frequency nodes, re-sorting the list after each merge
nodes = [(freq, i, Node(char, freq)) for i, (char, freq) in enumerate(self.freqs.items())]
# Convert to min-heap
nodes.sort(reverse=True) # Sort once at the beginning
while len(nodes) > 1:
freq1, _, node1 = nodes.pop()
freq2, _, node2 = nodes.pop()
# Create parent node
parent = Node(node1.char + node2.char, freq1 + freq2, node1, node2)
nodes.append((freq1 + freq2, len(nodes), parent))
nodes.sort(reverse=True)
self.tree = nodes[0][2] if nodes else None
self._build_codes(self.tree)
def _build_codes(self, node: Node, code: str = '') -> None:
"""Build lookup table using depth-first traversal"""
if not node:
return
if not node.left and not node.right:
            if code:  # Never store empty codes
                self.lookup_table.add_code(code, node.char)
                self.codes[node.char] = code
return
if node.left:
self._build_codes(node.left, code + '0')
if node.right:
self._build_codes(node.right, code + '1')
def _parse_header_fast(self, data: memoryview) -> int:
"""Optimized header parsing"""
pos = 12 # Skip first 12 bytes (file_len, always0, chars_count)
chars_count = int.from_bytes(data[8:12], byteorder)
        # Reset the frequency table
        self.freqs = {}
# Process all characters in a single loop
for _ in range(chars_count):
count = int.from_bytes(data[pos:pos + 4], byteorder)
char = chr(data[pos + 4]) # Faster than decode
self.freqs[char] = count
pos += 8
return pos
def _decode_bits_parallel(self, data: memoryview, total_bits: int) -> str:
"""Parallel decoding using multiple threads"""
chunk_bits = (total_bits + self.num_threads - 1) // self.num_threads
chunks = []
# Create chunks ensuring they align with byte boundaries when possible
for i in range(0, total_bits, chunk_bits):
end_bit = min(i + chunk_bits, total_bits)
if i > 0:
# Align to byte boundary when possible
while (i & 7) != 0 and i > 0:
i -= 1
chunks.append((i, end_bit))
# Create decoders for each thread
decoders = [
ChunkDecoder(self.lookup_table, self.tree, self.chunk_size)
for _ in range(len(chunks))
]
# Process chunks in parallel
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
futures = [
executor.submit(decoder.decode_chunk, data, start, end)
for decoder, (start, end) in zip(decoders, chunks)
]
# Collect results
results = []
for future in futures:
chunk_result, _ = future.result()
results.extend(chunk_result)
return ''.join(results)
def _decode_bits_optimized(self, data: memoryview, total_bits: int) -> str:
"""Optimized single-threaded decoding for small inputs"""
if total_bits > self.chunk_size:
return self._decode_bits_parallel(data, total_bits)
result = []
buffer = BitBuffer()
pos = 0
bytes_processed = 0
# Pre-fill buffer
while bytes_processed < min(8, len(data)):
buffer.add_byte(data[bytes_processed])
bytes_processed += 1
while pos < total_bits:
# Use lookup table for common patterns
if buffer.bits_in_buffer >= 8:
lookup_bits = buffer.peek_bits(8)
char_info = self.lookup_table.lookup(lookup_bits, 8)
if char_info:
char, code_len = char_info
buffer.consume_bits(code_len)
result.append(char)
pos += code_len
else:
# Tree traversal for uncommon patterns
node = self.tree
while node.left and node.right and buffer.bits_in_buffer > 0:
bit = buffer.peek_bits(1)
buffer.consume_bits(1)
node = node.right if bit else node.left
pos += 1
if not (node.left or node.right):
result.append(node.char)
# Refill buffer
while buffer.bits_in_buffer <= 56 and bytes_processed < len(data):
buffer.add_byte(data[bytes_processed])
bytes_processed += 1
if buffer.bits_in_buffer == 0:
break
return ''.join(result)
def decode_hex(self, hex_string: str) -> str:
# Use numpy for faster hex decoding
clean_hex = hex_string.replace(' ', '')
data = np.frombuffer(bytes.fromhex(clean_hex), dtype=np.uint8)
return self.decode_bytes(data.tobytes())
def decode_bytes(self, data: bytes) -> str:
view = memoryview(data)
pos = self._parse_header_fast(view)
self._build_efficient_tree()
# Get packed data info using numpy for faster parsing
header = np.frombuffer(data[pos:pos + 12], dtype=np.uint32)
packed_bits = int(header[0])
packed_bytes = int(header[1])
pos += 12
# Choose decoding method based on size
if packed_bits > self.chunk_size:
return self._decode_bits_parallel(view[pos:pos + packed_bytes], packed_bits)
else:
return self._decode_bits_optimized(view[pos:pos + packed_bytes], packed_bits)
def encode(self, text: str) -> bytes:
"""Encode text using Huffman coding - for testing purposes"""
# Count frequencies
self.freqs = {}
for char in text:
self.freqs[char] = self.freqs.get(char, 0) + 1
# Build tree and codes
self._build_efficient_tree()
        # Convert text to a list of bit characters ('0'/'1')
        bits = []
        for char in text:
            bits.extend(self.codes[char])
        # Pack bits into bytes, MSB first
        packed_bytes = []
        for i in range(0, len(bits), 8):
            byte = 0
            for j in range(min(8, len(bits) - i)):
                if bits[i + j] == '1':
                    byte |= 1 << (7 - j)
packed_bytes.append(byte)
# Create header
header = bytearray()
header.extend(len(text).to_bytes(4, byteorder))
header.extend(b'\x00' * 4) # always0
header.extend(len(self.freqs).to_bytes(4, byteorder))
# Add frequency table
for char, freq in self.freqs.items():
header.extend(freq.to_bytes(4, byteorder))
header.extend(char.encode('ascii'))
header.extend(b'\x00' * 3) # padding
# Add packed data info
header.extend(len(bits).to_bytes(4, byteorder))
header.extend(len(packed_bytes).to_bytes(4, byteorder))
header.extend(b'\x00' * 4) # unpacked_bytes
# Combine header and packed data
return bytes(header + bytes(packed_bytes))
if __name__ == '__main__':
# Create decoder with custom settings
decoder = OptimizedHuffmanDecoder(
num_threads=4, # Number of threads for parallel processing
chunk_size=1024 # Minimum size for parallel processing
)
test_hex = 'A7 64 00 00 00 00 00 00 0C 00 00 00 38 25 00 00 2D 00 00 00 08 69 00 00 30 00 00 00 2E 13 00 00 31 00 00 00 D4 13 00 00 32 00 00 00 0F 0D 00 00 33 00 00 00 78 08 00 00 34 00 00 00 A4 0A 00 00 35 00 00 00 63 0E 00 00 36 00 00 00 AC 09 00 00 37 00 00 00 D0 07 00 00 38 00 00 00 4D 09 00 00 39 00 00 00 68 0C 00 00 7C 00 00 00 73 21 03 00 2F 64 00 00 01 0B 01 00 C9 63 2A C7 21 77 40 77 25 8D AB E9 E5 E7 80 77'
start_time = time.perf_counter()
# Decode data
result = decoder.decode_hex(test_hex)
execution_time_ms = (time.perf_counter() - start_time) * 1000 # Convert to milliseconds
print(f"\nTotal execution time: {execution_time_ms:.2f} milliseconds")
print(result)
Expected output:
Total execution time: 1.04 milliseconds
19101-0-418-220000000|19102-0-371-530000000
But if you try it with a larger string it becomes very slow: the second hex input takes 55 ms. I want to improve the performance. I tried Cython, but it didn't improve anything, so I'd like to know what I might be doing wrong.
Larger hex input example: text
I wonder if I'm doing something wrong, and whether there is any way to speed this up. I've tried everything I could think of for hours, but I don't know how to improve it any further.
I want to improve performance

For almost every performance question, the answer is: profile it first.

If you haven't profiled, you can't be sure what is slow. And if you don't know what is slow, you can only guess at how to make it faster.
Python has built-in profiling tools. Try them, and/or use timeit for micro-benchmarks.
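As a minimal sketch (assuming the script above has already constructed decoder and test_hex), profiling a single decode call with the standard cProfile module looks like this:

import cProfile
import pstats

# Profile one decode and list the functions with the highest cumulative time
profiler = cProfile.Profile()
profiler.enable()
decoder.decode_hex(test_hex)
profiler.disable()
pstats.Stats(profiler).sort_stats('cumulative').print_stats(10)  # top 10 entries

Whatever is actually dominating the runtime (for example, the repeated nodes.sort() in _build_efficient_tree, or the Python-level bit loops) should then show up directly in the output instead of being guessed at.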
Also consider moving work you don't want to measure (such as the hex-string conversion) outside the section you are timing.
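For instance, a sketch that does the hex conversion once up front and times only decode_bytes (decoder and test_hex as above; the repeat count of 100 is arbitrary):

import timeit

# Convert hex to bytes once, outside the timed code
raw = bytes.fromhex(test_hex.replace(' ', ''))
elapsed = timeit.timeit(lambda: decoder.decode_bytes(raw), number=100)
print(f"mean per decode: {elapsed / 100 * 1000:.3f} ms")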
Finally, you can most likely get better performance from C, C++, Rust, or another compiled language, but you will also need to learn how to profile in those languages to get the most out of them.