Huffman implementation: reducing decompression time

Question

I'm trying to speed up this Huffman decoder. For small input hex strings it is fine, but the larger the input string, the more the runtime grows; for the strings below (examples follow the code) the gap reaches roughly 50x, 1 ms vs 55 ms+.

import time
import heapq
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple, Optional
import numpy as np
from array import array
import ctypes
from sys import byteorder  # native byte order; the sample data here is little-endian

class Node:
    __slots__ = ['char', 'freq', 'left', 'right']

    def __init__(self, char: str, freq: int, left=None, right=None):
        self.char = char
        self.freq = freq
        self.left = left
        self.right = right


class HybridLookupTable:
    """Hybrid approach: a direct lookup table for short codes plus a fallback dictionary for long codes"""

    __slots__ = ['short_table', 'long_codes', 'max_short_bits']

    def __init__(self, max_short_bits: int = 8):
        self.max_short_bits = max_short_bits
        self.short_table = [(None, 0)] * (1 << max_short_bits)  # (None, 0) marks an unfilled slot
        self.long_codes = {}

    def add_code(self, code: str, char: str) -> None:
        code_int = int(code, 2)
        code_len = len(code)

        if code_len <= self.max_short_bits:
            # For short codes, use lookup table with limited prefix expansion
            prefix_mask = (1 << (self.max_short_bits - code_len)) - 1
            base_index = code_int << (self.max_short_bits - code_len)
            for i in range(prefix_mask + 1):
                self.short_table[base_index | i] = (char, code_len)
        else:
            # For long codes, store in dictionary
            self.long_codes[code_int] = (char, code_len)

    def lookup(self, bits: int, length: int) -> Optional[Tuple[str, int]]:
        """Look up a bit pattern and return (character, code length) if found"""
        if length <= self.max_short_bits:
            entry = self.short_table[bits & ((1 << self.max_short_bits) - 1)]
            # (None, 0) is a truthy tuple, so unfilled slots must be mapped to None here
            return entry if entry[0] is not None else None

        # Try matching long codes (only reachable when length > max_short_bits)
        for code_bits, (char, code_len) in self.long_codes.items():
            if code_len <= length:
                if (bits >> (length - code_len)) == code_bits:
                    return (char, code_len)
        return None


class BitBuffer:
    """Fast bit buffer implementation using ctypes"""

    __slots__ = ['buffer', 'bits_in_buffer']

    def __init__(self):
        self.buffer = ctypes.c_uint64(0)
        self.bits_in_buffer = 0

    def add_byte(self, byte: int) -> None:
        self.buffer.value = (self.buffer.value << 8) | byte
        self.bits_in_buffer += 8

    def peek_bits(self, num_bits: int) -> int:
        return (self.buffer.value >> (self.bits_in_buffer - num_bits)) & ((1 << num_bits) - 1)

    def consume_bits(self, num_bits: int) -> None:
        self.buffer.value &= (1 << (self.bits_in_buffer - num_bits)) - 1
        self.bits_in_buffer -= num_bits


class ChunkDecoder:
    """Decoder for a chunk of compressed data"""

    __slots__ = ['lookup_table', 'tree', 'chunk_size']

    def __init__(self, lookup_table, tree, chunk_size=1024):
        self.lookup_table = lookup_table
        self.tree = tree
        self.chunk_size = chunk_size
    def decode_chunk(self, data: memoryview, start_bit: int, end_bit: int) -> Tuple[List[str], int]:
        """Decode a chunk of bits and return (decoded_chars, bits_consumed)"""
        result = []
        pos = start_bit
        buffer = BitBuffer()
        bytes_processed = start_bit >> 3
        bit_offset = start_bit & 7

        # Pre-fill buffer
        for _ in range(8):
            if bytes_processed < len(data):
                buffer.add_byte(data[bytes_processed])
                bytes_processed += 1

        # Skip initial bit offset
        if bit_offset:
            buffer.consume_bits(bit_offset)

        while pos < end_bit and buffer.bits_in_buffer > 0:
            char_info = None
            if buffer.bits_in_buffer >= 8:
                # Try lookup table first (covers all codes of up to 8 bits)
                lookup_bits = buffer.peek_bits(8)
                char_info = self.lookup_table.lookup(lookup_bits, 8)

            if char_info:
                char, code_len = char_info
                buffer.consume_bits(code_len)
                result.append(char)
                pos += code_len
            else:
                # Fall back to tree traversal (long codes and the sub-8-bit tail)
                node = self.tree
                while node.left and node.right and buffer.bits_in_buffer > 0:
                    bit = buffer.peek_bits(1)
                    buffer.consume_bits(1)
                    node = node.right if bit else node.left
                    pos += 1
                if not (node.left or node.right):
                    result.append(node.char)

            # Refill buffer if needed
            while buffer.bits_in_buffer <= 56 and bytes_processed < len(data):
                buffer.add_byte(data[bytes_processed])
                bytes_processed += 1

        return result, pos - start_bit


class OptimizedHuffmanDecoder:
    def __init__(self, num_threads=4, chunk_size=1024):
        self.tree = None
        self.freqs = {}
        self.codes = {}  # char -> bit-string code, filled by _build_codes (used by encode)
        self.lookup_table = HybridLookupTable()
        self.num_threads = num_threads
        self.chunk_size = chunk_size
        self._setup_lookup_tables()

    def _setup_lookup_tables(self):
        # Pre-calculate bit manipulation tables (currently unused by the decode paths)
        self.bit_masks = array('Q', [(1 << i) - 1 for i in range(65)])
        self.bit_shifts = array('B', [x & 7 for x in range(8)])

    def _build_efficient_tree(self) -> None:
        # Build the tree with a true min-heap rather than repeatedly sorting a list
        heap = [(freq, i, Node(char, freq)) for i, (char, freq) in enumerate(self.freqs.items())]
        heapq.heapify(heap)
        counter = len(heap)  # unique tie-breaker so tuple comparison never reaches the Node

        while len(heap) > 1:
            freq1, _, node1 = heapq.heappop(heap)
            freq2, _, node2 = heapq.heappop(heap)

            # Create parent node
            parent = Node(node1.char + node2.char, freq1 + freq2, node1, node2)
            heapq.heappush(heap, (freq1 + freq2, counter, parent))
            counter += 1

        self.tree = heap[0][2] if heap else None
        self._build_codes(self.tree)

    def _build_codes(self, node: Node, code: str = '') -> None:
        """Build the code table and lookup table using depth-first traversal"""
        if not node:
            return
        if not node.left and not node.right:
            if code:  # Never store empty codes
                self.codes[node.char] = code
                self.lookup_table.add_code(code, node.char)
            return
        if node.left:
            self._build_codes(node.left, code + '0')
        if node.right:
            self._build_codes(node.right, code + '1')

    def _parse_header_fast(self, data: memoryview) -> int:
        """Optimized header parsing"""
        chars_count = int.from_bytes(data[8:12], byteorder)
        pos = 12  # Skip the 12-byte header (file_len, always0, chars_count)

        self.freqs = {}

        # Each table entry is 8 bytes: a 4-byte count, a 1-byte character, 3 bytes padding
        for _ in range(chars_count):
            count = int.from_bytes(data[pos:pos + 4], byteorder)
            char = chr(data[pos + 4])  # Faster than decode for single bytes
            self.freqs[char] = count
            pos += 8
        return pos

    def _decode_bits_parallel(self, data: memoryview, total_bits: int) -> str:
        """Parallel decoding using multiple threads"""
        chunk_bits = (total_bits + self.num_threads - 1) // self.num_threads
        chunks = []

        # Create chunks, aligning each start down to a byte boundary.
        # NOTE: Huffman streams are not self-synchronizing, so a chunk that does
        # not begin exactly on a code boundary will decode garbage.
        for i in range(0, total_bits, chunk_bits):
            end_bit = min(i + chunk_bits, total_bits)
            chunks.append((i & ~7, end_bit))

        # Create decoders for each thread
        decoders = [
            ChunkDecoder(self.lookup_table, self.tree, self.chunk_size)
            for _ in range(len(chunks))
        ]

        # Process chunks in parallel
        with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
            futures = [
                executor.submit(decoder.decode_chunk, data, start, end)
                for decoder, (start, end) in zip(decoders, chunks)
            ]

            # Collect results
            results = []
            for future in futures:
                chunk_result, _ = future.result()
                results.extend(chunk_result)

        return ''.join(results)

    def _decode_bits_optimized(self, data: memoryview, total_bits: int) -> str:
        """Optimized single-threaded decoding for small inputs"""
        if total_bits > self.chunk_size:
            return self._decode_bits_parallel(data, total_bits)

        result = []
        buffer = BitBuffer()
        pos = 0
        bytes_processed = 0

        # Pre-fill buffer
        while bytes_processed < min(8, len(data)):
            buffer.add_byte(data[bytes_processed])
            bytes_processed += 1

        while pos < total_bits and buffer.bits_in_buffer > 0:
            char_info = None
            if buffer.bits_in_buffer >= 8:
                # Use lookup table for common patterns
                lookup_bits = buffer.peek_bits(8)
                char_info = self.lookup_table.lookup(lookup_bits, 8)

            if char_info:
                char, code_len = char_info
                buffer.consume_bits(code_len)
                result.append(char)
                pos += code_len
            else:
                # Tree traversal for long codes and the final partial byte
                node = self.tree
                while node.left and node.right and buffer.bits_in_buffer > 0:
                    bit = buffer.peek_bits(1)
                    buffer.consume_bits(1)
                    node = node.right if bit else node.left
                    pos += 1
                if not (node.left or node.right):
                    result.append(node.char)

            # Refill buffer
            while buffer.bits_in_buffer <= 56 and bytes_processed < len(data):
                buffer.add_byte(data[bytes_processed])
                bytes_processed += 1

        return ''.join(result)

    def decode_hex(self, hex_string: str) -> str:
        # bytes.fromhex already runs in C, so no numpy round-trip is needed here
        clean_hex = hex_string.replace(' ', '')
        return self.decode_bytes(bytes.fromhex(clean_hex))

    def decode_bytes(self, data: bytes) -> str:
        view = memoryview(data)
        pos = self._parse_header_fast(view)

        self._build_efficient_tree()

        # Get packed data info using numpy for faster parsing
        header = np.frombuffer(data[pos:pos + 12], dtype=np.uint32)
        packed_bits = int(header[0])
        packed_bytes = int(header[1])
        pos += 12
        # Choose decoding method based on size
        if packed_bits > self.chunk_size:
            return self._decode_bits_parallel(view[pos:pos + packed_bytes], packed_bits)
        else:
            return self._decode_bits_optimized(view[pos:pos + packed_bytes], packed_bits)

    def encode(self, text: str) -> bytes:
        """Encode text using Huffman coding - for testing purposes"""
        # Count frequencies
        self.freqs = {}
        for char in text:
            self.freqs[char] = self.freqs.get(char, 0) + 1

        # Build tree and codes
        self._build_efficient_tree()

        # Convert text to a list of '0'/'1' characters
        bits = []
        for char in text:
            bits.extend(self.codes[char])

        # Pack bits into bytes
        packed_bytes = []
        for i in range(0, len(bits), 8):
            byte = 0
            for j in range(min(8, len(bits) - i)):
                if bits[i + j] == '1':  # compare explicitly; the character '0' is truthy
                    byte |= 1 << (7 - j)
            packed_bytes.append(byte)

        # Create header
        header = bytearray()
        header.extend(len(text).to_bytes(4, byteorder))
        header.extend(b'\x00' * 4)  # always0
        header.extend(len(self.freqs).to_bytes(4, byteorder))

        # Add frequency table
        for char, freq in self.freqs.items():
            header.extend(freq.to_bytes(4, byteorder))
            header.extend(char.encode('ascii'))
            header.extend(b'\x00' * 3)  # padding

        # Add packed data info
        header.extend(len(bits).to_bytes(4, byteorder))
        header.extend(len(packed_bytes).to_bytes(4, byteorder))
        header.extend(b'\x00' * 4)  # unpacked_bytes

        # Combine header and packed data
        return bytes(header + bytes(packed_bytes))

if __name__ == '__main__':
    # Create decoder with custom settings
    decoder = OptimizedHuffmanDecoder(
        num_threads=4,  # Number of threads for parallel processing
        chunk_size=1024  # Minimum size for parallel processing
    )

    test_hex = 'A7 64 00 00 00 00 00 00 0C 00 00 00 38 25 00 00 2D 00 00 00 08 69 00 00 30 00 00 00 2E 13 00 00 31 00 00 00 D4 13 00 00 32 00 00 00 0F 0D 00 00 33 00 00 00 78 08 00 00 34 00 00 00 A4 0A 00 00 35 00 00 00 63 0E 00 00 36 00 00 00 AC 09 00 00 37 00 00 00 D0 07 00 00 38 00 00 00 4D 09 00 00 39 00 00 00 68 0C 00 00 7C 00 00 00 73 21 03 00 2F 64 00 00 01 0B 01 00 C9 63 2A C7 21 77 40 77 25 8D AB E9 E5 E7 80 77'
    start_time = time.perf_counter()
    # Decode data
    result = decoder.decode_hex(test_hex)
    execution_time_ms = (time.perf_counter() - start_time) * 1000  # Convert to milliseconds
    print(f"\nTotal execution time: {execution_time_ms:.2f} milliseconds")
    print(result)

Expected output:

Total execution time: 1.04 milliseconds
19101-0-418-220000000|19102-0-371-530000000

But if you try it with a larger string it becomes very slow. I want to improve the performance; I tried Cythonizing it, but that did not improve it in any way. The second hex input takes 55 ms, so if anyone knows what I might be doing wrong, I'd appreciate the help.

Example of a larger hex input: text

I would like to know whether I'm doing something wrong and whether there is any way to speed this up. I have tried everything I could think of for hours, but I don't know how to improve it further.

Tags: python numpy huffman-code huffman-tree
1 Answer

"I want to improve performance"

For almost every performance question, the answer is:

Profile your program.

Until you profile, you can't be sure what is actually slow. And if you don't know what is slow, you can only guess at how to make it faster.

Python has built-in profiling tools (cProfile). Try them, and/or use timeit for micro-benchmarking.
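
For example, a minimal cProfile session for this decoder might look like the sketch below, reusing the decoder and test_hex names from the script above:

import cProfile
import pstats

# Profile one decode call and report the 10 most expensive functions
profiler = cProfile.Profile()
profiler.enable()
decoder.decode_hex(test_hex)
profiler.disable()

pstats.Stats(profiler).sort_stats('cumulative').print_stats(10)

If the hot spots turn out to be pure-Python inner loops, line_profiler (which the script already imports from, presumably by accident) can narrow them down to individual lines.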

Also consider moving work you don't want to measure (such as the hex-string conversion) outside the section you are timing.
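
Concretely, that might mean converting the hex string to bytes once, outside the timed region, and letting timeit average many runs; again a sketch reusing the names from the script:

import timeit

# Do the hex conversion once, outside the timed region
raw = bytes.fromhex(test_hex.replace(' ', ''))

# Time only the decode step, averaged over 100 runs
total = timeit.timeit(lambda: decoder.decode_bytes(raw), number=100)
print(f"decode_bytes: {total / 100 * 1000:.3f} ms per call")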

Finally, you can most likely get better performance in C, C++, Rust, or another compiled language, but you will also need to learn how to profile in those languages to get the most out of them.
