Is it possible to speed up my set implementation?


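I am trying to write a fast and space-efficient set implementation for 64-bit unsigned integers. I don't want to use set(), because it converts everything to Python ints, and each of those uses far more than 8 bytes. Here is my attempt: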

import numpy as np
from numba import njit

class HashSet:
    def __init__(self, capacity=1024):
        self.capacity = capacity
        self.size = 0
        self.EMPTY = np.uint64(0xFFFFFFFFFFFFFFFF)  # 2^64 - 1
        self.DELETED = np.uint64(0xFFFFFFFFFFFFFFFE)  # 2^64 - 2
        self.table = np.full(capacity, self.EMPTY)  # Initialize with a special value indicating empty

    def insert(self, key):
        if self.size >= self.capacity:
            raise RuntimeError("Hash table is full")
        if not self._insert(self.table, key, self.capacity, self.EMPTY, self.DELETED, self._hash):
            print(f"Key already exists: {key}")
        else:
            self.size += 1

    def contains(self, key):
        return self._contains(self.table, key, self.capacity, self.EMPTY, self.DELETED, self._hash)

    def remove(self, key):
        if self._remove(self.table, key, self.capacity, self.EMPTY, self.DELETED, self._hash):
            self.size -= 1

    def __len__(self):
        return self.size

    @staticmethod
    @njit
    def _hash(key, capacity):
        return key % capacity

    @staticmethod
    @njit
    def _insert(table, key, capacity, EMPTY, DELETED, hash_func):
        index = hash_func(key, capacity)
        while table[index] != EMPTY and table[index] != DELETED and table[index] != key:
            index = (index + 1) % capacity

        if table[index] == key:
            return False  # Key already exists

        table[index] = key
        return True

    @staticmethod
    @njit
    def _contains(table, key, capacity, EMPTY, DELETED, hash_func):
        index = hash_func(key, capacity)
        while table[index] != EMPTY:
            if table[index] == key:
                return True
            index = (index + 1) % capacity
        return False

    @staticmethod
    @njit
    def _remove(table, key, capacity, EMPTY, DELETED, hash_func):
        index = hash_func(key, capacity)
        while table[index] != EMPTY:
            if table[index] == key:
                table[index] = DELETED
                return True
            index = (index + 1) % capacity
        return False

I use numba wherever I can to speed things up, but it is still not very fast. For example:

hash_set = HashSet(capacity=204800)
keys = np.random.randint(0, 2**64, size=100000, dtype=np.uint64)
def insert_and_remove(hash_set, key):
    hash_set.insert(np.uint64(key))
    hash_set.remove(key)
%timeit insert_and_remove(hash_set, keys[0])

This gives:

16.9 μs ± 407 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
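%timeit is an IPython magic, so it does not work in a plain script. A rough standard-library equivalent is sketched below; it times the same insert-then-remove pattern on a built-in set, the baseline this thread compares against. The helper names insert_and_remove_builtin and time_per_call are hypothetical. It reports the best of several runs rather than %timeit's mean ± std. dev. To time the HashSet instead, pass the question's insert_and_remove function and a HashSet instance.

```python
import timeit


def insert_and_remove_builtin(s, key):
    """Same pattern as the question's insert_and_remove, but on a built-in set."""
    s.add(key)
    s.discard(key)


def time_per_call(func, *args, number=100_000, repeat=7):
    """Best per-call time in seconds over `repeat` runs of `number` calls each."""
    timer = timeit.Timer(lambda: func(*args))
    return min(timer.repeat(repeat=repeat, number=number)) / number


if __name__ == "__main__":
    key = 2**63 + 12345  # an arbitrary value in the uint64 range
    t = time_per_call(insert_and_remove_builtin, set(), key)
    print(f"built-in set: {t * 1e6:.3f} µs per insert+remove")
```

The lambda wrapper adds a little call overhead to every measurement. %timeit on a function call carries a similar overhead, so the numbers are roughly comparable.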

I think the main reason is that I have not managed to wrap the code with numba.

How can I speed this up?

EDIT

@ken suggested defining _hash as a global function outside the class. That sped things up, so it is now only 50% slower than set().

python performance numba
1 Answer

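As requested, here is the class, but using jitclass. I'm not sure how much value all the type annotations add; I've been experimenting to see whether I can get any further improvement. Overall, your original code had a best time of 20 μs, whereas the code below has a best time of 2.3 μs (roughly an order of magnitude faster). However, a Python set is faster again by roughly another order of magnitude, at 0.34 μs. These timings apply only to the test harness you provided; no other performance testing was done.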

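The main things I had to do to get your code working with jitclass were:

  • Add type annotations and a spec for the class's attributes.
  • Convert the literal 1 to numba.uint64. Without this, numba promotes the type of the expression to numba.float64, even if the local variable has a uint64 type annotation. Trying to index an array with a float makes the whole compilation step fail.
  • Remove all njit decorators from the class's methods. jitclass automatically applies njit to every method, and it raises an error if any method is already jitted.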
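The float64 promotion in the second bullet can be reproduced with plain NumPy (this sketch does not use numba). As far as I know, when one variable holds a uint64 in one place and an int64 in another, numba unifies the two types with NumPy's promotion rules. In the probing loops, index comes from _hash as a uint64 and is later reassigned from an expression involving the int64 literal 1. No integer dtype can hold every int64 and every uint64 value, so the common type is float64, and a float is not a valid array index.

```python
import numpy as np

# No integer dtype covers both ranges, so the common type is float64.
mixed = np.result_type(np.uint64, np.int64)
print(mixed)  # float64

# Keeping both operands uint64 keeps the result uint64. Writing
# numba.uint64(1) instead of the literal 1 achieves this in the answer.
same = np.result_type(np.uint64, np.uint64)
print(same)  # uint64

# NumPy rejects a float index at run time; numba rejects it at compile time.
table = np.zeros(4, dtype=np.uint64)
try:
    table[np.float64(1.0)]
    float_index_rejected = False
except IndexError:
    float_index_rejected = True
print(float_index_rejected)  # True
```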

The only parts you really need are:

@jitclass([('table', numba.uint64[:])])
class HashSet:
    capacity: numba.uint64
    size: numba.uint64
    table: np.ndarray

# ...and, in every probing loop:
index = (index + numba.uint64(1)) % self.capacity

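In addition, I made EMPTY and DELETED global constants. If you have lots of small sets, this saves a little space with no loss of performance. To numba they really are constants, not just global variables.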

Code

import numpy as np
import numba
from numba.experimental import jitclass

EMPTY = numba.uint64(0xFFFFFFFFFFFFFFFF)  # 2^64 - 1
DELETED = numba.uint64(0xFFFFFFFFFFFFFFFE)  # 2^64 - 2


@jitclass([('table', numba.uint64[:])])
class HashSet:
    capacity: numba.uint64
    size: numba.uint64
    table: np.ndarray

    def __init__(self, capacity: int = 1024) -> None:
        self.capacity = capacity
        self.size = 0
        self.table = np.full(self.capacity, EMPTY) # Initialize with a special value indicating empty

    def __len__(self) -> int:
        return self.size

    @staticmethod
    def _hash(key: numba.uint64, capacity: numba.uint64) -> numba.uint64:
        return key % capacity

    def insert(self, key: numba.uint64) -> bool:
        if self.size >= self.capacity:
            raise RuntimeError("Hash table is full")

        index = self._hash(key, self.capacity)
        while self.table[index] != EMPTY and self.table[index] != DELETED and self.table[index] != key:
            index = (index + numba.uint64(1)) % self.capacity

        if self.table[index] == key:
            return False  # Key already exists

        self.table[index] = key
        self.size += 1
        return True

    def contains(self, key: numba.uint64) -> bool:
        index = self._hash(key, self.capacity)
        while self.table[index] != EMPTY:
            if self.table[index] == key:
                return True
            index = (index + numba.uint64(1)) % self.capacity
        return False

    def remove(self, key: numba.uint64) -> bool:
        index = self._hash(key, self.capacity)
        while self.table[index] != EMPTY:
            if self.table[index] == key:
                self.table[index] = DELETED
                self.size -= 1
                return True
            index = (index + numba.uint64(1)) % self.capacity
        return False
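One correctness caveat applies to both the question's code and this version. insert stops at the first DELETED slot, so a key already stored further along the same probe chain is written a second time, and size is incremented again. Separately, contains and remove only stop at an EMPTY slot. Once every slot is either occupied or DELETED, they never terminate. The sketch below shows one common fix in plain Python: remember the first tombstone but keep probing until the key or an EMPTY slot is found, and cap every probe loop at capacity steps. It uses a list in place of the uint64 array and does not use numba. The class name ProbingSetSketch and the _find helper are hypothetical. This is meant to illustrate the probing logic, not as a drop-in replacement. A jitclass port would still need the uint64 typing care described above.

```python
EMPTY = 2**64 - 1    # sentinel: slot never used (cannot be stored as a key)
DELETED = 2**64 - 2  # sentinel: tombstone left by remove (cannot be stored as a key)


class ProbingSetSketch:
    """Pure-Python model of the linear-probing set above (hypothetical name)."""

    def __init__(self, capacity=1024):
        self.capacity = capacity
        self.size = 0
        self.table = [EMPTY] * capacity  # stands in for np.full(capacity, EMPTY)

    def _find(self, key):
        """Return (index of key or -1, first reusable slot or -1)."""
        if key in (EMPTY, DELETED):
            raise ValueError("sentinel values cannot be stored as keys")
        index = key % self.capacity
        first_free = -1
        for _ in range(self.capacity):  # bounded: never loops forever
            slot = self.table[index]
            if slot == key:
                return index, first_free
            if slot == EMPTY:
                # Key is absent; prefer an earlier tombstone if we passed one.
                return -1, (index if first_free == -1 else first_free)
            if slot == DELETED and first_free == -1:
                first_free = index
            index = (index + 1) % self.capacity
        return -1, first_free  # scanned every slot without meeting EMPTY

    def insert(self, key):
        found, free = self._find(key)
        if found != -1:
            return False  # already present, even if a tombstone came first
        if free == -1:
            raise RuntimeError("Hash table is full")
        self.table[free] = key
        self.size += 1
        return True

    def contains(self, key):
        return self._find(key)[0] != -1

    def remove(self, key):
        found, _ = self._find(key)
        if found == -1:
            return False
        self.table[found] = DELETED
        self.size -= 1
        return True

    def __len__(self):
        return self.size
```

The trade-off is that lookups for absent keys now walk past tombstones too, so a table with many removals gradually slows down. Implementations commonly rebuild the table once tombstones accumulate.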