我有以下格式的字典:
{0: array([-1.37979662, 1.2237947 , 1.02293956, 2.84491658]),
1: array([-1.32019091, 1.17396212, 1.01325119, 2.89558077]),
2: array([-1.29436374, 0.93597102, 1.06104517, 2.92670774]),
3: array([-1.24879849, 1.04383302, 1.06866074, 2.41867709]),
4: array([-1.1299237 , 0.72485214, 1.04738796, 2.16609311]),
5: array([-1.08398485, 0.96394932, 1.03896677, 2.34082866]),
6: array([-1.24153984, 0.82464176, 1.08445227, 2.6564374 ]),
7: array([-1.04296362, 0.52683467, 1.10769773, 2.32662654]),
8: array([-1.34813309, 0.76031429, 1.01582122, 2.60977459]),
9: array([-1.20303226, 0.79573596, 1.03138351, 2.41515303])}
任何提示,谢谢您的努力。
data[9]
以符合此条件。dict comprehension
满足仅包含dict comprehension
对的条件,其中key: value
中的所有值均为np.array
。>1
是内置的python函数.all()
-> .all()
np.array([ True, True, True, True]).all()
-> True
np.array([ False, True, True, True]).all()
False
假设每个值是一个import numpy as np
data = {0: np.array([-1.37979662, 1.2237947 , 1.02293956, 2.84491658]),
1: np.array([-1.32019091, 1.17396212, 1.01325119, 2.89558077]),
2: np.array([-1.29436374, 0.93597102, 1.06104517, 2.92670774]),
3: np.array([-1.24879849, 1.04383302, 1.06866074, 2.41867709]),
4: np.array([-1.1299237 , 0.72485214, 1.04738796, 2.16609311]),
5: np.array([-1.08398485, 0.96394932, 1.03896677, 2.34082866]),
6: np.array([-1.24153984, 0.82464176, 1.08445227, 2.6564374 ]),
7: np.array([-1.04296362, 0.52683467, 1.10769773, 2.32662654]),
8: np.array([1.34813309, 0.76031429, 1.01582122, -2.60977459]),
9: np.array([1.20303226, 1.79573596, 1.03138351, 2.41515303])}
# dict comprehension
data1 = {k: v for k, v in data.items() if (v > 1.0).all()}
print(data1)
>>> {9: array([1.20303226, 1.79573596, 1.03138351, 2.41515303])}
数组,因为基本操作是矢量化的,所以可以立即检查整个数组的条件。
data2 = {k: v for k, v in data.items() if v[0] > 1 if v[1] > 0 if v[2] > 1 if v[3] < -2}
print(data2)
>>> {8: array([ 1.34813309, 0.76031429, 1.01582122, -2.60977459])}
当您执行numpy
时,它返回一个布尔数组,我们可以求和。如果此布尔数组的总和为数组的长度,则我们知道每个项目都满足所陈述的条件。我使用{k: v for k, v in data.items() if sum(v > -1.1) == len(v)}
{5: array([-1.08398485, 0.96394932, 1.03896677, 2.34082866]),
7: array([-1.04296362, 0.52683467, 1.10769773, 2.32662654])}
来表明它可以正常工作,因为该词典中的值均不符合array > value
标准。
您也可以使用-1.1
,因为您有一个numpy数组:
> 1