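TL;DR: How can I use Java's Vector API to multiply 16-bit integer arrays without overflow?
I'm trying to optimize a performance-critical loop that applies an activation function and computes the dot product of two int16 arrays, using Java's (incubating) Vector API. This is my current scalar implementation: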
for (int i = 0; i < HIDDEN_SIZE; i++)
{
    result += screlu(us.values[i]) * network.L1Weights[i]
            + screlu(them.values[i]) * network.L1Weights[i + HIDDEN_SIZE];
}
where
private static int screlu(short i)
{
    // Clamp to [0, QA], then square.
    int v = Math.max(0, Math.min(i, QA));
    return v * v;
}
I tried to vectorize it like this:
// Widen the int16 inputs to int so the squared values cannot overflow.
int[] usValues = new int[HIDDEN_SIZE];
int[] themValues = new int[HIDDEN_SIZE];
for (int i = 0; i < HIDDEN_SIZE; i++)
{
    usValues[i] = (int) us.values[i];
    themValues[i] = (int) them.values[i];
}
IntVector sum = IntVector.zero(INT_SPECIES);
// upperBound is assumed to be INT_SPECIES.loopBound(HIDDEN_SIZE).
// IntVector.fromArray needs int[] data, so L1Weights has to be an int[] here.
int i = 0;
for (; i < upperBound; i += INT_SPECIES.length())
{
    IntVector va = IntVector.fromArray(INT_SPECIES, usValues, i);
    IntVector vb = IntVector.fromArray(INT_SPECIES, themValues, i);
    IntVector vc = IntVector.fromArray(INT_SPECIES, network.L1Weights, i);
    IntVector vd = IntVector.fromArray(INT_SPECIES, network.L1Weights, i + HIDDEN_SIZE);
    // screlu(us) * weight
    va = va.max(0).min(QA);
    va = va.mul(va).mul(vc);
    // screlu(them) * weight
    vb = vb.max(0).min(QA);
    vb = vb.mul(vb).mul(vd);
    sum = sum.add(va).add(vb);
}
int result = sum.reduceLanes(VectorOperators.ADD);
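Because of overflow, I had to use 32-bit lanes, which halves the throughput. As a result, performance is only slightly better than the scalar version. After some research, I found that an intrinsic like
_mm256_madd_epi16
solves exactly my problem, but I couldn't find anything about it in the documentation. Is there an equivalent operation in the Vector API, and if not, is there another way to solve this?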
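For reference, here is a minimal scalar sketch of what vpmaddwd (the instruction behind _mm256_madd_epi16) computes per 32-bit lane, as I understand it: adjacent pairs of signed 16-bit values are multiplied into 32-bit products, and each pair of products is added. The names MaddModel, madd and dot are made up for illustration and are not part of any API.

```java
// Scalar model of a multiply-add over adjacent int16 pairs (illustrative only).
final class MaddModel {
    /** out[k] = a[2k]*b[2k] + a[2k+1]*b[2k+1], computed in 32-bit arithmetic. */
    static int[] madd(short[] a, short[] b) {
        if (a.length != b.length || (a.length & 1) != 0) {
            throw new IllegalArgumentException("arrays must have equal, even length");
        }
        int[] out = new int[a.length / 2];
        for (int k = 0; k < out.length; k++) {
            // short * short is promoted to int in Java, so each product is 32-bit.
            out[k] = a[2 * k] * b[2 * k] + a[2 * k + 1] * b[2 * k + 1];
        }
        return out;
    }

    /** Dot product built on top of madd: sum of all pairwise results. */
    static int dot(short[] a, short[] b) {
        int sum = 0;
        for (int lane : madd(a, b)) {
            sum += lane;
        }
        return sum;
    }

    public static void main(String[] args) {
        short[] a = {1, 2, 3, 4};
        short[] b = {5, 6, 7, 8};
        int[] out = madd(a, b);
        System.out.println(out[0] + " " + out[1]); // 17 53
        System.out.println(dot(a, b));             // 70
    }
}
```

The point of the instruction is that it consumes twice as many 16-bit inputs per vector as a 32-bit multiply while still accumulating in 32 bits.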
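With some help, I eventually arrived at an implementation using the S2I operator. It is faster than the scalar implementation, but probably slower than what
vpmaddwd
would give if it were available.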
import static jdk.incubator.vector.VectorOperators.S2I;

// Accumulator. INT_SPECIES is assumed to have the same bit width as SHORT_SPECIES,
// because convert(S2I, part) keeps the vector shape.
IntVector sum = IntVector.zero(INT_SPECIES);
for (int i = 0; i < UPPERBOUND; i += SHORT_SPECIES.length())
{
    ShortVector usInputs = ShortVector.fromArray(SHORT_SPECIES, us.values, i);
    ShortVector themInputs = ShortVector.fromArray(SHORT_SPECIES, them.values, i);
    ShortVector usWeights = ShortVector.fromArray(SHORT_SPECIES, network.L2Weights[chosenBucket], i);
    ShortVector themWeights = ShortVector.fromArray(SHORT_SPECIES, network.L2Weights[chosenBucket],
            i + HIDDEN_SIZE);
    // Clamp the inputs to [0, QA] in 16-bit lanes.
    usInputs = usInputs.max(ShortVector.zero(SHORT_SPECIES)).min(ShortVector.broadcast(SHORT_SPECIES, QA));
    themInputs = themInputs.max(ShortVector.zero(SHORT_SPECIES)).min(ShortVector.broadcast(SHORT_SPECIES, QA));
    // 16-bit multiply: this wraps unless clamped input * weight fits in a short.
    ShortVector usWeightedTerms = usInputs.mul(usWeights);
    ShortVector themWeightedTerms = themInputs.mul(themWeights);
    // Widen to 32-bit lanes: part 0 is the lower half of the short lanes, part 1 the upper half.
    Vector<Integer> usInputsLo = usInputs.convert(S2I, 0);
    Vector<Integer> usInputsHi = usInputs.convert(S2I, 1);
    Vector<Integer> themInputsLo = themInputs.convert(S2I, 0);
    Vector<Integer> themInputsHi = themInputs.convert(S2I, 1);
    Vector<Integer> usWeightedTermsLo = usWeightedTerms.convert(S2I, 0);
    Vector<Integer> usWeightedTermsHi = usWeightedTerms.convert(S2I, 1);
    Vector<Integer> themWeightedTermsLo = themWeightedTerms.convert(S2I, 0);
    Vector<Integer> themWeightedTermsHi = themWeightedTerms.convert(S2I, 1);
    // Second multiply in 32 bits: input * (input * weight).
    sum = sum.add(usInputsLo.mul(usWeightedTermsLo)).add(usInputsHi.mul(usWeightedTermsHi))
            .add(themInputsLo.mul(themWeightedTermsLo)).add(themInputsHi.mul(themWeightedTermsHi));
}
int result = sum.reduceLanes(VectorOperators.ADD);
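One thing worth checking about the loop above: it computes v * (v * w) with the inner product held in a 16-bit lane, rather than (v * v) * w as in the scalar code. The two agree only when v * w fits in a short. Below is a small scalar sketch of that condition. It assumes QA = 255, which is not stated in the question, and the class and method names are hypothetical.

```java
// Scalar check of the reordered SCReLU product (illustrative only).
final class ScreluReorderCheck {
    static final int QA = 255; // assumption: the real value is not given above

    static int clamp(short x) {
        return Math.max(0, Math.min(x, QA));
    }

    /** Scalar reference: (v * v) * w in 32-bit arithmetic. */
    static int reference(short x, short w) {
        int v = clamp(x);
        return v * v * w;
    }

    /** Reordered form: v * w truncated to 16 bits first, then widened and multiplied by v. */
    static int reordered(short x, short w) {
        int v = clamp(x);
        short vw = (short) (v * w); // wraps the same way a 16-bit lane multiply does
        return v * vw;
    }

    /** True if QA * w fits in a short, i.e. the 16-bit product cannot wrap for this weight. */
    static boolean productFits(short w) {
        int p = QA * w;
        return p >= Short.MIN_VALUE && p <= Short.MAX_VALUE;
    }

    public static void main(String[] args) {
        // Squaring first would not fit in 16 bits: 255 * 255 = 65025.
        System.out.println((short) (255 * 255));                    // -511
        // Weight small enough: both forms agree.
        System.out.println(reference((short) 300, (short) 100));    // 6502500
        System.out.println(reordered((short) 300, (short) 100));    // 6502500
        // Weight too large: the 16-bit product wraps and the results differ.
        System.out.println(reference((short) 255, (short) 129));    // 8388225
        System.out.println(reordered((short) 255, (short) 129));    // -8323455
    }
}
```

Under the QA = 255 assumption, the reordering is safe for weights in [-128, 128]. Whether the actual weights stay inside such a range depends on how the network was quantized.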