如何在Python中实现Softmax,其中输入是有符号的8个整数

问题描述 投票:0回答:1

我正在尝试实现一个 softmax 函数,它接受有符号 int8 输入并返回有符号 int8 输出数组。

我目前正在进行的实施是这样的,

 import numpy as np

def softmax_int8(inputs):
    inputs = np.array(inputs, dtype=np.int8)
    
    x = inputs.astype(np.int32)
    x_max = np.max(x)
    x_shifted = x - x_max
    scale_factor = 2 ** 14 
    exp_limit = 16
    exp_x = np.clip(x_shifted + exp_limit, 0, None)
    exp_x = (1 << exp_x)
    sum_exp_x = np.sum(exp_x)

    if sum_exp_x == 0:
        sum_exp_x = 1

    softmax_probs = (exp_x * scale_factor) // sum_exp_x
    max_prob = np.max(softmax_probs)
    min_prob = np.min(softmax_probs)
    range_prob = max_prob - min_prob if max_prob != min_prob else 1

    scaled_probs = ((softmax_probs - min_prob) * 255) // range_prob - 128
    outputs = scaled_probs.astype(np.int8)

    return outputs

我使用这个输入来测试它,

Input = [101, 49, 6, -34, -75, -79, -38, 120, -55, 115]

但是我得到了这个输出

array([-128, -128, -128, -128, -128, -128, -128,  127, -128, -121],dtype=int8)

我的预期产出是

array([-57, -70, -79, -86, -92, -94, -88, -54, -91, -56], dtype=int8)

我在这里做错了什么以及如何解决它?

python machine-learning deep-learning
1个回答
0
投票

我认为在不同的上下文中,softmax 有不同的数学定义。

  • 维基百科定义(关于实数):
    exp(z) / sum(exp(z))
  • 我从你的代码中推断出:
    (1<<(z-z_max + 16)) / sum((1 << (z-z_max + 16)))
    或类似的东西。
    1<<
    ===
    2**
    显然。

主要区别在于指数的底数。如果基数太高,您很可能会出现下溢并得到很多

-128
。此外,还有一个偏差将结果映射到 [-128, 127] 范围,这是琐碎且不太重要的

您从中获取测试用例的库很可能使用与上述两者不同的定义。

我使用 matplotlib 对您的测试用例和 softmax 的浮点定义进行了一些测试,以下表达式给出了很好的拟合:

softmax_naive = (np.exp(inarr / 128) / np.sum(np.exp(inarr / 128)) * 256) - 100

您可以想象,在执行

>>7
基于 2 的指数之前,您可能需要执行
1<<
来输入字节。为了给出完全相同的结果,您当然应该深入研究该库代码,但我没有时间这样做。

以下是验证码:

import numpy as np
import matplotlib.pyplot as plt

inarr = np.array([101, 49, 6, -34, -75, -79, -38, 120, -55, 115], dtype=np.int8).astype(np.double)
expected_arr = np.array([-57, -70, -79, -86, -92, -94, -88, -54, -91, -56], dtype=np.int8).astype(np.double)
print(expected_arr)

softmax_naive = (np.exp(inarr / 128) / np.sum(np.exp(inarr / 128)) * 256) - 100
print(softmax_naive - expected_arr)
plt.plot(inarr)
plt.plot(expected_arr)
plt.plot(softmax_naive)
plt.show()

validation of softmax

最新问题
© www.soinside.com 2019 - 2025. All rights reserved.