我正在尝试为 64 位无符号整数实现一个快速且节省空间的集合。我不想使用 set(),因为它会把所有元素转换为 Python int,而每个 Python int 占用的空间远多于 8 个字节。这是我的尝试:
import numpy as np
from numba import njit
class HashSet:
    """Fixed-capacity set of 64-bit unsigned integers using linear probing.

    Keys are stored unboxed in a ``uint64`` NumPy array. Two values are
    reserved as slot markers and therefore cannot be stored as keys:
    ``2**64 - 1`` (never-used slot) and ``2**64 - 2`` (tombstone left by
    ``remove``).

    NOTE(review): every public call passes the ``_hash`` dispatcher into the
    jitted kernel as an argument; this is reported to be a significant part
    of the per-call cost. Consider a module-level jitted hash function if
    call overhead matters -- confirm by measuring.
    """

    def __init__(self, capacity=1024):
        """Create an empty set with room for ``capacity`` keys."""
        self.capacity = capacity
        self.size = 0
        self.EMPTY = np.uint64(0xFFFFFFFFFFFFFFFF)  # 2^64 - 1
        self.DELETED = np.uint64(0xFFFFFFFFFFFFFFFE)  # 2^64 - 2
        # Every slot starts out marked as never used.
        self.table = np.full(capacity, self.EMPTY)

    def insert(self, key):
        """Add ``key`` to the set.

        Prints a message and leaves the set unchanged if the key is already
        present.

        Raises:
            RuntimeError: if the table already holds ``capacity`` keys.
            ValueError: if ``key`` equals one of the reserved slot markers.
        """
        if self.size >= self.capacity:
            raise RuntimeError("Hash table is full")
        if not self._insert(self.table, key, self.capacity, self.EMPTY, self.DELETED, self._hash):
            print(f"Key already exists: {key}")
        else:
            self.size += 1

    def contains(self, key):
        """Return True if ``key`` is currently stored in the set."""
        return self._contains(self.table, key, self.capacity, self.EMPTY, self.DELETED, self._hash)

    def remove(self, key):
        """Remove ``key`` if present; do nothing otherwise."""
        if self._remove(self.table, key, self.capacity, self.EMPTY, self.DELETED, self._hash):
            self.size -= 1

    def __len__(self):
        """Return the number of keys currently stored."""
        return self.size

    @staticmethod
    @njit
    def _hash(key, capacity):
        """Map ``key`` to its home slot index."""
        return key % capacity

    @staticmethod
    @njit
    def _insert(table, key, capacity, EMPTY, DELETED, hash_func):
        """Store ``key`` in ``table``; return False if it was already there.

        The whole probe chain (up to the first never-used slot) is scanned
        before a tombstone is reused. Stopping at the first tombstone would
        miss a copy of ``key`` stored further along and insert a duplicate.
        """
        if key == EMPTY or key == DELETED:
            # A reserved value written into the table would be read back as
            # a slot marker, not as a key.
            raise ValueError("Key collides with a reserved sentinel value")
        index = hash_func(key, capacity)
        # `slot` starts as a copy of `index` so both share one inferred type.
        slot = index
        have_slot = False
        # At most `capacity` probes: guarantees termination even when no
        # never-used slot is left in the table.
        for _ in range(capacity):
            current = table[index]
            if current == key:
                return False  # Key already exists
            if current == EMPTY:
                if not have_slot:
                    slot = index
                    have_slot = True
                break
            if current == DELETED and not have_slot:
                # Remember the first reusable tombstone but keep scanning.
                slot = index
                have_slot = True
            index = (index + 1) % capacity
        if not have_slot:
            raise RuntimeError("Hash table is full")
        table[slot] = key
        return True

    @staticmethod
    @njit
    def _contains(table, key, capacity, EMPTY, DELETED, hash_func):
        """Return True if ``key`` is found along its probe chain."""
        if key == EMPTY or key == DELETED:
            # Slot markers are never real keys.
            return False
        index = hash_func(key, capacity)
        # Bounded scan: a table without never-used slots must not spin forever.
        for _ in range(capacity):
            current = table[index]
            if current == key:
                return True
            if current == EMPTY:
                return False
            index = (index + 1) % capacity
        return False

    @staticmethod
    @njit
    def _remove(table, key, capacity, EMPTY, DELETED, hash_func):
        """Replace ``key`` with a tombstone; return True if it was present."""
        if key == EMPTY or key == DELETED:
            # Slot markers are never real keys.
            return False
        index = hash_func(key, capacity)
        # Bounded scan: a table without never-used slots must not spin forever.
        for _ in range(capacity):
            current = table[index]
            if current == key:
                # Tombstone (not EMPTY) so later keys in the chain stay reachable.
                table[index] = DELETED
                return True
            if current == EMPTY:
                return False
            index = (index + 1) % capacity
        return False
我尽可能使用 numba 来加快速度,但它仍然不是很快。例如:
# Benchmark harness: a table with 204800 slots and 100000 random uint64 keys.
hash_set = HashSet(capacity=204800)
keys = np.random.randint(0, 2**64, size=100000, dtype=np.uint64)
# One insert followed by one remove of the same key, so the set is left as it
# was and the call can be repeated by the timer.
def insert_and_remove(hash_set, key):
hash_set.insert(np.uint64(key))
hash_set.remove(key)
# IPython magic: times repeated calls using the first generated key.
%timeit insert_and_remove(hash_set, keys[0])
这给出:
16.9 μs ± 407 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
我认为主要原因在于有一部分代码我没能用 numba 包装起来。
如何加快速度?
编辑
@ken 建议将 _hash 定义为类外的全局函数。这会加快速度,所以现在它只比 set() 慢 50%。
应要求,下面给出这个类的 `jitclass` 版本。我不确定那些类型注解究竟能带来多少价值;我只是一直在尝试,看能否得到一些改进。总体而言,您的原始代码的最佳耗时约为 20 μs,而下面的代码的最佳耗时为 2.3 μs(快了一个数量级)。但是,使用 Python 的 `set` 又快了一个数量级,为 0.34 μs。
这些时间仅针对您提供的测试代码,我没有做其他性能测试。
为了让您的代码能与 `jitclass` 一起工作,我必须做的主要事情是:
1. 将字面量 `1` 转换为 `numba.uint64`。如果不这样做,即使局部变量带有 `uint64` 的类型注解,numba 也会把表达式的类型提升为 `numba.float64`;而用浮点数去索引数组会导致整个编译步骤失败。
2. 移除 `njit` 装饰器。`jitclass` 会自动将 `njit` 应用于所有方法;如果某个类方法已经被 jit 编译过,`jitclass` 就会报错。
您真正需要的部分只有:
# Excerpt: `table` is typed through the explicit spec list passed to jitclass;
# `capacity` and `size` are typed through the class-level annotations.
@jitclass([('table', numba.uint64[:])])
class HashSet:
capacity: numba.uint64
size: numba.uint64
table: np.ndarray
和
# The literal 1 is cast to numba.uint64 so the sum stays an unsigned integer
# and the result can be used as an array index.
index = (index + numba.uint64(1)) % self.capacity
此外,我还把 `EMPTY` 和 `DELETED` 改成了全局常量。如果您有很多小型集合,这可以节省一点空间,而且不会降低性能。对 numba 来说,它们是真正的常量,而不仅仅是全局变量。
import numpy as np
import numba
from numba.experimental import jitclass
# Reserved slot markers for the hash table; a stored key must never equal
# either of them, because the table could not tell it apart from a marker.
EMPTY = numba.uint64(2**64 - 1)    # slot has never held a key
DELETED = numba.uint64(2**64 - 2)  # tombstone left behind by remove()
@jitclass([('table', numba.uint64[:])])
class HashSet:
    """Fixed-capacity set of 64-bit unsigned integers using linear probing.

    Compiled as a Numba jitclass. The module-level values ``EMPTY`` and
    ``DELETED`` are reserved as slot markers (never-used slot and tombstone)
    and therefore cannot be stored as keys.
    """

    capacity: numba.uint64
    size: numba.uint64
    table: np.ndarray

    def __init__(self, capacity: int = 1024) -> None:
        """Create an empty set with room for ``capacity`` keys."""
        self.capacity = capacity
        self.size = 0
        # Every slot starts out marked as never used.
        self.table = np.full(self.capacity, EMPTY)

    def __len__(self) -> int:
        """Return the number of keys currently stored."""
        return self.size

    @staticmethod
    def _hash(key: numba.uint64, capacity: numba.uint64) -> numba.uint64:
        """Map ``key`` to its home slot index."""
        return key % capacity

    def insert(self, key: numba.uint64) -> bool:
        """Add ``key``; return True if it was added, False if already present.

        The whole probe chain (up to the first never-used slot) is scanned
        before a tombstone is reused. Stopping at the first tombstone would
        miss a copy of ``key`` stored further along and insert a duplicate.

        Raises:
            RuntimeError: if the table already holds ``capacity`` keys.
            ValueError: if ``key`` equals one of the reserved slot markers.
        """
        if self.size >= self.capacity:
            raise RuntimeError("Hash table is full")
        if key == EMPTY or key == DELETED:
            # A reserved value written into the table would be read back as
            # a slot marker, not as a key.
            raise ValueError("Key collides with a reserved sentinel value")
        index = self._hash(key, self.capacity)
        # `slot` starts as a copy of `index` so both share one inferred type.
        slot = index
        have_slot = False
        # At most `capacity` probes: guarantees termination even when no
        # never-used slot is left in the table.
        for _ in range(self.capacity):
            current = self.table[index]
            if current == key:
                return False  # Key already exists
            if current == EMPTY:
                if not have_slot:
                    slot = index
                    have_slot = True
                break
            if current == DELETED and not have_slot:
                # Remember the first reusable tombstone but keep scanning.
                slot = index
                have_slot = True
            # Explicit uint64 literal keeps the index an unsigned integer.
            index = (index + numba.uint64(1)) % self.capacity
        if not have_slot:
            raise RuntimeError("Hash table is full")
        self.table[slot] = key
        self.size += 1
        return True

    def contains(self, key: numba.uint64) -> bool:
        """Return True if ``key`` is currently stored in the set."""
        if key == EMPTY or key == DELETED:
            # Slot markers are never real keys.
            return False
        index = self._hash(key, self.capacity)
        # Bounded scan: a table without never-used slots must not spin forever.
        for _ in range(self.capacity):
            current = self.table[index]
            if current == key:
                return True
            if current == EMPTY:
                return False
            index = (index + numba.uint64(1)) % self.capacity
        return False

    def remove(self, key: numba.uint64) -> bool:
        """Remove ``key``; return True if it was present, False otherwise."""
        if key == EMPTY or key == DELETED:
            # Slot markers are never real keys.
            return False
        index = self._hash(key, self.capacity)
        # Bounded scan: a table without never-used slots must not spin forever.
        for _ in range(self.capacity):
            current = self.table[index]
            if current == key:
                # Tombstone (not EMPTY) so later keys in the chain stay reachable.
                self.table[index] = DELETED
                self.size -= 1
                return True
            if current == EMPTY:
                return False
            index = (index + numba.uint64(1)) % self.capacity
        return False