我正在尝试实现一个 softmax 函数,它接受有符号 int8 输入并返回有符号 int8 输出数组。
我目前正在进行的实施是这样的,
import numpy as np
def softmax_int8(inputs):
inputs = np.array(inputs, dtype=np.int8)
x = inputs.astype(np.int32)
x_max = np.max(x)
x_shifted = x - x_max
scale_factor = 2 ** 14
exp_limit = 16
exp_x = np.clip(x_shifted + exp_limit, 0, None)
exp_x = (1 << exp_x)
sum_exp_x = np.sum(exp_x)
if sum_exp_x == 0:
sum_exp_x = 1
softmax_probs = (exp_x * scale_factor) // sum_exp_x
max_prob = np.max(softmax_probs)
min_prob = np.min(softmax_probs)
range_prob = max_prob - min_prob if max_prob != min_prob else 1
scaled_probs = ((softmax_probs - min_prob) * 255) // range_prob - 128
outputs = scaled_probs.astype(np.int8)
return outputs
我使用这个输入来测试它,
Input = [101, 49, 6, -34, -75, -79, -38, 120, -55, 115]
但是我得到了这个输出
array([-128, -128, -128, -128, -128, -128, -128, 127, -128, -121],dtype=int8)
。
我的预期产出是
array([-57, -70, -79, -86, -92, -94, -88, -54, -91, -56], dtype=int8)
。
我在这里做错了什么以及如何解决它?
我认为在不同的上下文中,softmax 有不同的数学定义。
exp(z) / sum(exp(z))
(1<<(z-z_max + 16)) / sum((1 << (z-z_max + 16)))
或类似的东西。 1<<
=== 2**
显然。主要区别在于指数的底数。如果基数太高,您很可能会出现下溢并得到很多
-128
。此外,还有一个偏差将结果映射到 [-128, 127] 范围,这是琐碎且不太重要的
您从中获取测试用例的库很可能使用与上述两者不同的定义。
我使用 matplotlib 对您的测试用例和 softmax 的浮点定义进行了一些测试,以下表达式给出了很好的拟合:
softmax_naive = (np.exp(inarr / 128) / np.sum(np.exp(inarr / 128)) * 256) - 100
您可以想象,在执行
>>7
基于 2 的指数之前,您可能需要执行 1<<
来输入字节。为了给出完全相同的结果,您当然应该深入研究该库代码,但我没有时间这样做。
以下是验证码:
import numpy as np
import matplotlib.pyplot as plt
inarr = np.array([101, 49, 6, -34, -75, -79, -38, 120, -55, 115], dtype=np.int8).astype(np.double)
expected_arr = np.array([-57, -70, -79, -86, -92, -94, -88, -54, -91, -56], dtype=np.int8).astype(np.double)
print(expected_arr)
softmax_naive = (np.exp(inarr / 128) / np.sum(np.exp(inarr / 128)) * 256) - 100
print(softmax_naive - expected_arr)
plt.plot(inarr)
plt.plot(expected_arr)
plt.plot(softmax_naive)
plt.show()