Optimizing an int16 vector dot product in Java with the Vector API


TL;DR: Using Java's Vector API to multiply 16-bit integer arrays efficiently without overflow.

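I'm trying to optimize a performance-critical loop that applies an activation function and computes the dot product of two int16 arrays, using Java's (incubating) Vector API. Here is my current scalar implementation: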

int result = 0;

for (int i = 0; i < HIDDEN_SIZE; i++)
{
    result += screlu(us.values[i]) * network.L1Weights[i]
        + screlu(them.values[i]) * network.L1Weights[i + HIDDEN_SIZE];
}

where

// SCReLU (squared clipped ReLU): clamp to [0, QA], then square
private static int screlu(short i)
{
    int v = Math.max(0, Math.min(i, QA));
    return v * v;
}
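A quick way to see why 16-bit lanes overflow: screlu returns values in [0, QA²]. The question never states QA, so the sketch below assumes QA = 255 (a common choice in NNUE-style evaluation networks, but an assumption here). The class name ScreluRange is hypothetical.

```java
public class ScreluRange {
    // Assumed quantization constant; the question does not give QA's value.
    static final int QA = 255;

    // Same SCReLU as in the question: clamp to [0, QA], then square.
    static int screlu(short i) {
        int v = Math.max(0, Math.min(i, QA));
        return v * v;
    }

    public static void main(String[] args) {
        int max = screlu((short) QA);               // largest possible activation
        System.out.println(max);                    // 65025
        System.out.println(max > Short.MAX_VALUE); // true: does not fit in an int16 lane
        System.out.println((short) max);            // -511: the value a 16-bit lane would wrap to
    }
}
```

So under this assumption even the activation alone cannot be held in a 16-bit lane, before any multiplication by a weight.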

I tried to vectorize it like this:

int[] usValues = new int[HIDDEN_SIZE];
int[] themValues = new int[HIDDEN_SIZE];

for (int i = 0; i < HIDDEN_SIZE; i++)
{
    usValues[i] = (int) us.values[i];
    themValues[i] = (int) them.values[i];
}

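IntVector sum = IntVector.zero(INT_SPECIES);

// Largest multiple of the lane count that does not exceed HIDDEN_SIZE
int upperBound = INT_SPECIES.loopBound(HIDDEN_SIZE);
int i = 0;

for (; i < upperBound; i += INT_SPECIES.length())
{
    IntVector va = IntVector.fromArray(INT_SPECIES, usValues, i);
    IntVector vb = IntVector.fromArray(INT_SPECIES, themValues, i);
    IntVector vc = IntVector.fromArray(INT_SPECIES, network.L1Weights, i);
    IntVector vd = IntVector.fromArray(INT_SPECIES, network.L1Weights, i + HIDDEN_SIZE);

    // SCReLU: clamp to [0, QA], square, then multiply by the weight
    va = va.max(0).min(QA);
    va = va.mul(va).mul(vc);

    vb = vb.max(0).min(QA);
    vb = vb.mul(vb).mul(vd);

    sum = sum.add(va).add(vb);
}

int result = sum.reduceLanes(VectorOperators.ADD);

// Scalar tail for elements left over when HIDDEN_SIZE is not a multiple of the lane count
for (; i < HIDDEN_SIZE; i++)
{
    result += screlu(us.values[i]) * network.L1Weights[i]
        + screlu(them.values[i]) * network.L1Weights[i + HIDDEN_SIZE];
}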

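Because of overflow, I had to use 32-bit lanes, which halves the throughput, so the vectorized version is only slightly faster than the scalar one. After some research I found that an intrinsic like _mm256_madd_epi16 solves my problem exactly, but I can't find anything about it in the documentation. Is there an equivalent operation in the Vector API, and if not, is there another way to solve this?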
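For reference, _mm256_madd_epi16 (the vpmaddwd instruction) multiplies 16 pairs of signed 16-bit integers into 32-bit products and adds adjacent products, producing 8 32-bit lanes. The sketch below is a plain scalar Java model of those semantics for one 256-bit register, for illustration only; it is not a Vector API call, and the class name MaddEpi16Model is hypothetical.

```java
import java.util.Arrays;

public class MaddEpi16Model {
    /**
     * Scalar model of x86 pmaddwd / _mm256_madd_epi16 on one 256-bit register:
     * out[j] = a[2j]*b[2j] + a[2j+1]*b[2j+1], computed in 32-bit arithmetic.
     */
    static int[] maddEpi16(short[] a, short[] b) {
        if (a.length != 16 || b.length != 16) {
            throw new IllegalArgumentException("expected 16 lanes of int16");
        }
        int[] out = new int[8];
        for (int j = 0; j < 8; j++) {
            // Each int16 * int16 product fits in an int; only the sum of two
            // (-32768 * -32768) products wraps, matching the instruction's behavior.
            out[j] = a[2 * j] * b[2 * j] + a[2 * j + 1] * b[2 * j + 1];
        }
        return out;
    }

    public static void main(String[] args) {
        short[] a = new short[16];
        short[] b = new short[16];
        for (int k = 0; k < 16; k++) {
            a[k] = (short) (k + 1);
            b[k] = 1000;
        }
        // Adjacent pairs: (1+2)*1000, (3+4)*1000, ...
        System.out.println(Arrays.toString(maddEpi16(a, b)));
        // [3000, 7000, 11000, 15000, 19000, 23000, 27000, 31000]
    }
}
```

Both operands must already be int16, so for this problem the inputs would have to be the clamped value v and the product v·w, not v² (which exceeds the int16 range).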

Answer

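With some help, I eventually arrived at an implementation using the S2I conversion operator. It is faster than the scalar implementation, but probably slower than vpmaddwd would be where that instruction is available.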

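import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorOperators;

import static jdk.incubator.vector.VectorOperators.S2I;

// INT_SPECIES must have the same shape as SHORT_SPECIES (e.g. SHORT_SPECIES.withLanes(int.class)),
// because S2I keeps the shape and returns each half of the short lanes as a separate int vector
IntVector sum = IntVector.zero(INT_SPECIES);

// UPPERBOUND is assumed to be SHORT_SPECIES.loopBound(HIDDEN_SIZE)
for (int i = 0; i < UPPERBOUND; i += SHORT_SPECIES.length())
{
    ShortVector usInputs = ShortVector.fromArray(SHORT_SPECIES, us.values, i);
    ShortVector themInputs = ShortVector.fromArray(SHORT_SPECIES, them.values, i);
    ShortVector usWeights = ShortVector.fromArray(SHORT_SPECIES, network.L2Weights[chosenBucket], i);
    ShortVector themWeights = ShortVector.fromArray(SHORT_SPECIES, network.L2Weights[chosenBucket],
            i + HIDDEN_SIZE);

    // Clamp to [0, QA] in 16-bit lanes
    usInputs = usInputs.max(ShortVector.zero(SHORT_SPECIES)).min(ShortVector.broadcast(SHORT_SPECIES, QA));
    themInputs = themInputs.max(ShortVector.zero(SHORT_SPECIES)).min(ShortVector.broadcast(SHORT_SPECIES, QA));

    // v * w in 16-bit lanes: exact only if QA * |weight| fits in a short, otherwise it wraps
    ShortVector usWeightedTerms = usInputs.mul(usWeights);
    ShortVector themWeightedTerms = themInputs.mul(themWeights);

    // Widen to 32-bit ints: part 0 holds the first half of the lanes, part 1 the second half
    Vector<Integer> usInputsLo = usInputs.convert(S2I, 0);
    Vector<Integer> usInputsHi = usInputs.convert(S2I, 1);
    Vector<Integer> themInputsLo = themInputs.convert(S2I, 0);
    Vector<Integer> themInputsHi = themInputs.convert(S2I, 1);

    Vector<Integer> usWeightedTermsLo = usWeightedTerms.convert(S2I, 0);
    Vector<Integer> usWeightedTermsHi = usWeightedTerms.convert(S2I, 1);
    Vector<Integer> themWeightedTermsLo = themWeightedTerms.convert(S2I, 0);
    Vector<Integer> themWeightedTermsHi = themWeightedTerms.convert(S2I, 1);

    // v * (v * w) = screlu(v) * w, accumulated in 32-bit lanes
    sum = sum.add(usInputsLo.mul(usWeightedTermsLo)).add(usInputsHi.mul(usWeightedTermsHi))
            .add(themInputsLo.mul(themWeightedTermsLo)).add(themInputsHi.mul(themWeightedTermsHi));
}

int result = sum.reduceLanes(VectorOperators.ADD);

// Scalar tail for any elements past UPPERBOUND
for (int i = UPPERBOUND; i < HIDDEN_SIZE; i++)
{
    result += screlu(us.values[i]) * network.L2Weights[chosenBucket][i]
        + screlu(them.values[i]) * network.L2Weights[chosenBucket][i + HIDDEN_SIZE];
}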
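The loop above works because screlu(x)·w = v·(v·w) with v = clamp(x, 0, QA): the product v·w is computed in 16-bit lanes and only the final multiply is widened to 32 bits. That is exact only while every v·w fits in a short. The scalar sketch below compares the two forms under hypothetical bounds (QA = 255 and |w| ≤ 127, so 255 · 127 = 32385 ≤ 32767); neither bound comes from the question, and the class and method names are illustrative.

```java
import java.util.Random;

public class SplitMultiplyCheck {
    // Hypothetical quantization bounds: 255 * 127 = 32385 <= Short.MAX_VALUE.
    static final int QA = 255;
    static final int MAX_WEIGHT = 127;

    static int clamp(short x) {
        return Math.max(0, Math.min(x, QA));
    }

    // Reference: sum of clamp(x)^2 * w, computed entirely in int.
    static int reference(short[] x, short[] w) {
        int sum = 0;
        for (int i = 0; i < x.length; i++) {
            int v = clamp(x[i]);
            sum += v * v * w[i];
        }
        return sum;
    }

    // Mirrors the answer's lane arithmetic: v * w in 16 bits, then * v in 32 bits.
    static int split(short[] x, short[] w) {
        int sum = 0;
        for (int i = 0; i < x.length; i++) {
            short v = (short) clamp(x[i]);
            short vw = (short) (v * w[i]); // like ShortVector.mul: wraps if out of range
            sum += v * vw;                 // like the widened (S2I) int multiply
        }
        return sum;
    }

    public static void main(String[] args) {
        Random rnd = new Random(42);
        int n = 1024;
        short[] x = new short[n];
        short[] w = new short[n];
        for (int i = 0; i < n; i++) {
            x[i] = (short) (rnd.nextInt(600) - 100);                       // some inputs get clamped
            w[i] = (short) (rnd.nextInt(2 * MAX_WEIGHT + 1) - MAX_WEIGHT); // |w| <= 127
        }
        System.out.println(reference(x, w) == split(x, w)); // true
    }
}
```

With a weight outside the bound, for example v = 255 and w = 200, v·w = 51000 wraps to -14536 in a short and the two methods disagree, so this approach assumes the weights are quantized with such a bound in mind.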