我有一个120 GB的文件保存(通过pickle
二进制),包含大约50,000(600x600)2d numpy数组。我需要使用中位数来堆叠所有这些数组。最简单的方法是将整个文件作为数组列表读取并使用np.median(arrays, axis=0)
。但是,我没有太多的RAM可以使用,所以这不是一个好的选择。
因此,我尝试逐个像素地堆叠它们,就像我一次关注一个像素位置(i, j)
,然后逐个读取每个数组,将给定位置的值附加到列表中。一旦保存了所有数组中某个位置的所有值,我就使用np.median
,然后只需要将该值保存在列表中 - 最后将具有每个像素位置的中位数。最后我可以将其重塑为600x600,我会完成的。这个代码如下。
import pickle
import time
import numpy as np
filename = 'images.dat' #contains my 50,000 2D numpy arrays
def stack_by_pixel(i, j):
pixels_at_position = []
with open(filename, 'rb') as f:
while True:
try:
# Gather pixels at a given position
array = pickle.load(f)
pixels_at_position.append(array[i][j])
except EOFError:
break
# Stacking at position (median)
stacked_at_position = np.median(np.array(pixels_at_position))
return stacked_at_position
# Form whole stacked image
stacked = []
for i in range(600):
for j in range(600):
t1 = time.time()
stacked.append(stack_by_pixel(i, j))
t2 = time.time()
print('Done with element %d, %d: %f seconds' % (i, j, (t2-t1)))
stacked_image = np.reshape(stacked, (600,600))
看到一些打印输出后,我意识到这是非常低效的。每次完成一个位置(i, j)
大约需要150秒左右,这并不奇怪,因为它一个接一个地读取大约50,000个阵列。鉴于我的大阵列中有360,000个(i, j)
位置,预计需要22个月才能完成!显然这是不可行的。但我有点不知所措,因为没有足够的RAM可供读取整个文件。或者也许我可以一次性保存所有像素位置(每个位置的单独列表),因为它逐个打开它们,但不会在Python中保存360,000个列表(大约50,000个元素长)使用了很多RAM也是?
欢迎提出任何建议,如果不使用大量RAM,我可以大大加快运行速度。谢谢!
注意:我使用Python 2.x,将其移植到3.x应该不难。
我的想法很简单 - 磁盘空间很丰富,所以让我们做一些预处理并将那个大的pickle文件转换成更容易以小块处理的东西。
为了测试这个,我写了一个小脚本,生成一个类似于你的pickle文件。我假设您的输入图像是灰度级的,具有8位深度,并使用numpy.random.randint
生成10000个随机图像。
该脚本将作为基准,我们可以比较预处理和处理阶段。
import numpy as np
import pickle
import time
IMAGE_WIDTH = 600
IMAGE_HEIGHT = 600
FILE_COUNT = 10000
t1 = time.time()
with open('data/raw_data.pickle', 'wb') as f:
for i in range(FILE_COUNT):
data = np.random.randint(256, size=IMAGE_WIDTH*IMAGE_HEIGHT, dtype=np.uint8)
data = data.reshape(IMAGE_HEIGHT, IMAGE_WIDTH)
pickle.dump(data, f)
print i,
t2 = time.time()
print '\nDone in %0.3f seconds' % (t2 - t1)
在测试运行中,此脚本在372秒内完成,生成~10 GB的文件。
让我们逐行分割输入图像 - 我们将有600个文件,其中文件N
包含来自每个输入图像的行N
。我们可以使用numpy.ndarray.tofile
以二进制形式存储行数据(稍后使用numpy.fromfile
加载这些文件)。
import numpy as np
import pickle
import time
# Increase open file limit
# See https://stackoverflow.com/questions/6774724/why-python-has-limit-for-count-of-file-handles
import win32file
win32file._setmaxstdio(1024)
IMAGE_WIDTH = 600
IMAGE_HEIGHT = 600
FILE_COUNT = 10000
t1 = time.time()
outfiles = []
for i in range(IMAGE_HEIGHT):
outfilename = 'data/row_%03d.dat' % i
outfiles.append(open(outfilename, 'wb'))
with open('data/raw_data.pickle', 'rb') as f:
for i in range(FILE_COUNT):
data = pickle.load(f)
for j in range(IMAGE_HEIGHT):
data[j].tofile(outfiles[j])
print i,
for i in range(IMAGE_HEIGHT):
outfiles[i].close()
t2 = time.time()
print '\nDone in %0.3f seconds' % (t2 - t1)
在测试运行中,此脚本在134秒内完成,生成600个文件,每个文件600万个字节。它使用~30MB或RAM。
很简单,只需使用numpy.fromfile
加载每个数组,然后使用numpy.median
获取每列中值,将其减少回单行,并在列表中累积这些行。
最后,使用numpy.vstack
重新组合中值图像。
import numpy as np
import time
IMAGE_WIDTH = 600
IMAGE_HEIGHT = 600
t1 = time.time()
result_rows = []
for i in range(IMAGE_HEIGHT):
outfilename = 'data/row_%03d.dat' % i
data = np.fromfile(outfilename, dtype=np.uint8).reshape(-1, IMAGE_WIDTH)
median_row = np.median(data, axis=0)
result_rows.append(median_row)
print i,
result = np.vstack(result_rows)
print result
t2 = time.time()
print '\nDone in %0.3f seconds' % (t2 - t1)
在测试运行中,此脚本在74秒内完成。你甚至可以很容易地将它并行化,但它似乎并不值得。该脚本使用~40MB的RAM。
鉴于这两个脚本都是线性的,所用的时间也应该线性扩展。对于50000张图像,预处理大约需要11分钟,最终处理大约需要6分钟。这是在i7-4930K @ 3.4GHz上,故意使用32位Python。
这是numpy的memory mapped arrays的完美用例。内存映射数组允许您将磁盘上的.npy
文件视为一个numpy数组,而不是实际加载它。这很简单
arr = np.load('filename', mmap_mode='r')
在大多数情况下,您可以将其视为任何其他阵列。数组元素仅根据需要加载到内存中。不幸的是,一些快速实验表明median
不能很好地处理memmory映射数组*,它似乎仍然会将大部分数据同时加载到内存中。所以median(arr, 0)
可能无法正常工作。
但是,您仍然可以遍历每个索引并计算中位数而不会遇到内存问题。
[[np.median([arr[k][i][j] for k in range(50000)]) for i in range(600)] for j in range(600)]
其中50,000反映了阵列的总数。
如果没有解开每个文件的开销只是为了提取单个像素,那么运行时间应该快得多(大约360000次)。
当然,这留下了创建包含所有数据的.npy
文件的问题。可以按如下方式创建文件,
arr = np.lib.format.open_memmap(
'filename', # File to store in
mode='w+', # Specify to create the file and write to it
dtype=float32, # Change this to your data's type
shape=(50000, 600, 600) # Shape of resulting array
)
然后,像以前一样加载数据并将其存储到数组中(它只是在后台将其写入磁盘)。
idx = 0
with open(filename, 'rb') as f:
while True:
try:
arr[idx] = pickle.load(f)
idx += 1
except EOFError:
break
给它几个小时的运行时间,然后回到这个答案的开头,看看如何加载它并取中位数。不能再简单了**。
*我刚刚在一个7GB文件上测试它,取5,000个元素的1500个样本的中位数,内存使用量约为7GB,这表明整个阵列可能已加载到内存中。首先尝试这种方式并没有什么坏处。如果其他人有关于memmapped数组的中位数的经验,请随意发表评论。
**如果你相信互联网上的陌生人。