我有以下数组:
array([8.1837177e-05, 1.0788739e-03, 4.4837892e-03, 3.4919381e-04, 7.6085329e-05, 7.6562166e-05, 5.3864717e-04, 5.4001808e-04, 3.3849746e-02, 2.9903650e-04], dtype = float32)
我想将其转换为:
array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0], dtype = float32)
我需要找到该行的最大值,将其替换为1.然后,将该行的其他9个值替换为0。
我需要为2D数组(一系列看起来像示例中的数组)完成此操作
x = array([1, 2, 3, 4])
x = np.where(x == max(x), 1, 0) # will replace max with 1 and the others with 0
这将创建:
array([0, 0, 0, 1])
对于2D阵列,您可以执行以下操作:
x = array([[0, 3, 4, 5],
[1, 2, 3, 1],
[6, 9, 1, 2]])
x = np.array([np.where(l == max(l), 1, 0) for l in x])
这将创建:
array([[0, 0, 0, 1],
[0, 0, 1, 0],
[0, 1, 0, 0]])`
将np.where
与max
结合使用:
a = np.array([8.1837177e-05, 1.0788739e-03, 4.4837892e-03, 3.4919381e-04, 7.6085329e-05, 7.6562166e-05, 5.3864717e-04, 5.4001808e-04, 3.3849746e-02, 2.9903650e-04])
np.where(a == a.max(), 1, 0)
输出:
array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0])
在2D情况下,我们取每行的最大值:
np.where(a == a.max(axis=1)[:, np.newaxis], 1, 0)
那就是说,我觉得keras
应该有一些内置的东西为你做这个...
你可以像这样使用列表理解:
x = [5,6,7,8,9]
y = [1 if num == max(x) else 0 for num in x]
这种方法需要两行,但它避免了将每个数组元素与max进行比较,并且在2D中运行良好。我不知道它会真的更快(当然不是渐近),但我认为两行比在python中执行for循环更好,可读性可能比使用np.where
更好。
import numpy as np
# here's your example input
# note - the input must be 2D even if there's just one row
# it's easy to adapt this to the 1D case, but you'll be working with 2D arrays for this anyway
class_probs = np.array([[
8.1837177e-05, 1.0788739e-03, 4.4837892e-03, 3.4919381e-04, 7.6085329e-05,
7.6562166e-05, 5.3864717e-04, 5.4001808e-04, 3.3849746e-02, 2.9903650e-04,
]])
pred_classes = np.zeros_like(class_probs)
pred_classes[range(len(class_probs)), class_probs.argmax(-1)] = 1
print(pred_classes) # [[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]]
# and here's showing the same code works for multiple rows
class_probs = np.random.rand(100, 10)
pred_classes = np.zeros_like(class_probs)
pred_classes[range(len(class_probs)), class_probs.argmax(-1)] = 1
pred_classes
(这不是你的实际问题,但是你的意思是使用sigmoid激活函数吗?而不是softmax?你得到的输出不是10个可能类的单一分布(你可以看到它甚至不是相反,你有10个分布,每个类一个(因此,输入为0级的概率是8.1837177e-05
,而不是0级的概率是1 - 8.1837177e-05
)。这在进行多标签分类时是有意义的(其中)可以应用多个标签),但是你不希望找到具有最高概率的类,你可以预测所有具有高于阈值的概率的类(例如0.5)。)