对于 2024 年代码出现的第 4 天,存在一个问题,您需要查找字符网格中包含多少个“XMAS”字符串,例如
MMMSXXMASM
MSAMXMSMSA
AMXSXMAAMM
MSAMASMSMX
XMASAMXAMM
XXAMMXXAMA
SMSMSASXSS
SAXAMASAAA
MAMMMXMMMM
MXMXAXMASX
单词可以出现在 8 个方向中的任意一个方向:水平(上/下)/垂直(左/右),或 4 个对角线中的任意一个。从问题陈述中不清楚网格大小是否是固定的。
他们给你的实际测试用例是 140x140。
这有效并且相对较快(当我增加输入文件大小来测试其限制时,它达到了 460 MB/s),但是当在包含很多开发人员以及一些超过 15 年的开发人员的不和谐服务器上谈论我们的解决方案时具有专业经验的其中一位告诉我,对于这个特定问题,使用 SIMD 指令并不能让代码运行得更快,底层算法本身更重要。
所以我想知道,在什么情况下SIMD值得使用?我如何识别使用 SIMD 指令会产生显着差异的情况?我认为比较 1 条指令中的所有 32 个字符会产生很大的差异,而不是嵌套 2 个 for 循环,但我可能错了?
typedef struct {
std::vector<char> chars;
uint32_t line_length;
} Data;
Data readFile() {
std::ifstream file("./input/day4_input.txt", std::ios::binary | std::ios::ate);
std::streamsize size = file.tellg();
file.seekg(0, std::ios::beg);
std::vector<char> buffer(size);
file.read(buffer.data(), size);
file.close();
uint32_t line_length = std::ranges::find(buffer, '\n') - buffer.begin();
buffer.erase(std::remove(buffer.begin(), buffer.end(), '\n'), buffer.end());
Data input_data { buffer, line_length };
return input_data;
}
int main(int argc, char** argv) {
Data idata = readFile();
auto coordsToIdx = [&idata] (uint32_t i, uint32_t j) -> uint32_t { return i*idata.line_length + j; };
auto idxToCoords = [&idata] (uint32_t idx) -> std::pair<uint32_t, uint32_t> {
return std::make_pair(static_cast<uint32_t>(idx / idata.line_length), static_cast<uint32_t>(idx % idata.line_length));
};
unsigned int res = 0;
const __m256i _mask = _mm256_setr_epi8(
'X', 'M', 'A', 'S',
'X', 'M', 'A', 'S',
'X', 'M', 'A', 'S',
'X', 'M', 'A', 'S',
'X', 'M', 'A', 'S',
'X', 'M', 'A', 'S',
'X', 'M', 'A', 'S',
'X', 'M', 'A', 'S'
);
std::array<int, 8> out_buff;
for(size_t i = 0; i < idata.chars.size(); i++) {
if(idata.chars[i] != 'X') continue;
char buff[32] = {0};
std::pair<uint32_t, uint32_t> coords = idxToCoords(i);
bool u = coords.first >= 3;
bool d = coords.first <= (idata.chars.size() / idata.line_length) - 4; // 0 indexed
bool l = coords.second >= 3;
bool r = coords.second <= idata.line_length - 4;
if(u) for(size_t c = 0; c < 4; c++) buff[c] = idata.chars[coordsToIdx(coords.first - c, coords.second)];
if(u && r) for(size_t c = 0; c < 4; c++) buff[c + 4] = idata.chars[coordsToIdx(coords.first - c, coords.second + c)];
if(r) for(size_t c = 0; c < 4; c++) buff[c + 8] = idata.chars[coordsToIdx(coords.first, coords.second + c)];
if(d && r) for(size_t c = 0; c < 4; c++) buff[c + 12] = idata.chars[coordsToIdx(coords.first + c, coords.second + c)];
if(d) for(size_t c = 0; c < 4; c++) buff[c + 16] = idata.chars[coordsToIdx(coords.first + c, coords.second)];
if(d && l) for(size_t c = 0; c < 4; c++) buff[c + 20] = idata.chars[coordsToIdx(coords.first + c, coords.second - c)];
if(l) for(size_t c = 0; c < 4; c++) buff[c + 24] = idata.chars[coordsToIdx(coords.first, coords.second - c)];
if(u && l) for(size_t c = 0; c < 4; c++) buff[c + 28] = idata.chars[coordsToIdx(coords.first - c, coords.second - c)];
__m256i _block = _mm256_loadu_epi8(buff);
__m256i _cmpeq = _mm256_cmpeq_epi32(_block, _mask);
_mm256_storeu_epi32(out_buff.data(), _cmpeq);
// Contains signed 32 bits integers for each direction
// If a direction match, all bits are set to 1
// Which is -1. So we substract the total to add it instead
res -= std::accumulate(out_buff.begin(), out_buff.end(), 0);
}
std::cout << "Result is : << res << std::endl;
}
然后从矢量化或并行化开始。后两者只能给你带来线性的性能提升。选择更好的算法可以获得更多;但修改矢量化或并行化代码要困难得多。
但是,这里我们有一个非常具体的任务:我们在大海捞针中搜索的针正好是 4 个字节,并且它在输入数据中出现非常频繁。复杂的算法和大 O 表示法对于小数字来说毫无意义。这是简单的暴力算法的领域。矢量化算法
的矢量化形式
int count = 0;
for(int col = 0; col <= cols - 4; ++col)
count += (haystack[col] == 'X') & (haystack[col+1] == 'M')
& (haystack[col+1] == 'A') & (haystack[col+3] == 'S');
使用 AVX512,我们可以一次比较 32 个条目,如下所示:
const __m512i xvec = _mm512_set1_epi8('X');
const __m512i mvec = _mm512_set1_epi8('M');
const __m512i avec = _mm512_set1_epi8('A');
const __m512i svec = _mm512_set1_epi8('S');
for(int col = 0; col <= cols - 64 - 4; ++col) {
const char* colptr = haystack + col;
const __m512i c1 = _mm512_loadu_si512(colptr);
const __m512i c2 = _mm512_loadu_si512(colptr + 1);
const __m512i c3 = _mm512_loadu_si512(colptr + 2);
const __m512i c4 = _mm512_loadu_si512(colptr + 3);
__mmask64 fe = _mm512_cmpeq_epi8_mask(c1, xvec);
fe = _mm512_mask_cmpeq_epi8_mask(fe, c2, mvec);
fe = _mm512_mask_cmpeq_epi8_mask(fe, c3, avec);
fe = _mm512_mask_cmpeq_epi8_mask(fe, c4, svec);
count += _popcnt64(fe);
}
我们对每个值进行重叠加载,比较它们并对掩码中设置的位进行计数。您会在许多应用(例如卷积)中发现这种重叠负载模式。在 AVX512 中处理尾端非常简单:只需加载遮罩即可。由于零字节无法匹配我们的针,因此算法的其余部分可以保持不变。
int col;
for(col = 0; col <= cols - 64 - 4; ++col) {
... // (main loop)
}
const std::uint64_t tailcount = static_cast<std::uint64_t>(
std::min(64, cols - col));
if(tailcount) {
const std::uint64_t tailmask = (~std::uint64_t{0}) >> (64 - tailcount);
const char* colptr = haystack + col;
const __m512i c1 = _mm512_maskz_loadu_epi8(tailmask, colptr);
const __m512i c2 = _mm512_maskz_loadu_epi8(tailmask >> 1, colptr + 1);
const __m512i c3 = _mm512_maskz_loadu_epi8(tailmask >> 2, colptr + 2);
const __m512i c4 = _mm512_maskz_loadu_epi8(tailmask >> 3, colptr + 3);
... // same as main loop from here
}
现在我们必须将其扩展到 2D 网格和其他 7 个方向。
检查反向方向可以简单地通过第二组比较来完成,只需将 xvec
反转为
svec
顺序即可。垂直和对角线检查很简单,因为我们不是前进一列,而是前进行加负 1 列。进行范围检查很容易出错,否则都是一样的。
这是完整的算法。我已经验证了结果,但我不能排除测试数据未涵盖的边缘情况。
#include <chrono>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <immintrin.h>
struct Grid
{
std::string buf;
int rows, cols;
};
static Grid read_grid(const char* filepath)
{
Grid grid {};
std::ifstream infile { filepath };
std::string tmp;
while(std::getline(infile, tmp)) {
grid.rows += 1;
grid.cols = tmp.size();
grid.buf += tmp;
}
return grid;
}
static int count_xmas_dense(const Grid& grid)
{
const char* gridptr = grid.buf.c_str();
const std::ptrdiff_t rows = grid.rows;
const std::ptrdiff_t cols = grid.cols;
const __m512i xvec = _mm512_set1_epi8('X');
const __m512i mvec = _mm512_set1_epi8('M');
const __m512i avec = _mm512_set1_epi8('A');
const __m512i svec = _mm512_set1_epi8('S');
auto check_forward_backward =
[&](__m512i c1, __m512i c2, __m512i c3, __m512i c4) {
// check forward
__mmask64 fe = _mm512_cmpeq_epi8_mask(c1, xvec);
fe = _mm512_mask_cmpeq_epi8_mask(fe, c2, mvec);
fe = _mm512_mask_cmpeq_epi8_mask(fe, c3, avec);
fe = _mm512_mask_cmpeq_epi8_mask(fe, c4, svec);
int rtrn = _popcnt64(fe);
// check backward
__mmask64 be = _mm512_cmpeq_epi8_mask(c1, svec);
be = _mm512_mask_cmpeq_epi8_mask(be, c2, avec);
be = _mm512_mask_cmpeq_epi8_mask(be, c3, mvec);
be = _mm512_mask_cmpeq_epi8_mask(be, c4, xvec);
rtrn += _popcnt64(be);
return rtrn;
};
int rtrn = 0;
if(cols > 3) {
// check left to right and right to left
for(std::ptrdiff_t row = 0; row < rows; ++row) {
const char* rowptr = gridptr + row * cols;
std::ptrdiff_t col;
for(col = 0; col <= cols - 64 - 3; col += 64) {
const char* colptr = rowptr + col;
const __m512i c1 = _mm512_loadu_si512(colptr);
const __m512i c2 = _mm512_loadu_si512(colptr + 1);
const __m512i c3 = _mm512_loadu_si512(colptr + 2);
const __m512i c4 = _mm512_loadu_si512(colptr + 3);
rtrn += check_forward_backward(c1, c2, c3, c4);
}
const std::uint64_t tailcount = static_cast<std::uint64_t>(
std::min<std::ptrdiff_t>(64, cols - col));
if(!tailcount)
continue;
const std::uint64_t tailmask = (~std::uint64_t{0}) >> (64 - tailcount);
const char* colptr = rowptr + col;
const __m512i c1 = _mm512_maskz_loadu_epi8(tailmask, colptr);
const __m512i c2 = _mm512_maskz_loadu_epi8(tailmask >> 1, colptr + 1);
const __m512i c3 = _mm512_maskz_loadu_epi8(tailmask >> 2, colptr + 2);
const __m512i c4 = _mm512_maskz_loadu_epi8(tailmask >> 3, colptr + 3);
rtrn += check_forward_backward(c1, c2, c3, c4);
}
}
if(rows > 3) {
// check top to bottom and bottom up
for(std::ptrdiff_t row = 0; row <= rows - 4; ++row) {
const char* rowptr = gridptr + row * cols;
std::ptrdiff_t col;
for(col = 0; col <= cols - 64; col += 64) {
const char* colptr = rowptr + col;
const __m512i c1 = _mm512_loadu_si512(colptr);
const __m512i c2 = _mm512_loadu_si512(colptr + cols);
const __m512i c3 = _mm512_loadu_si512(colptr + 2 * cols);
const __m512i c4 = _mm512_loadu_si512(colptr + 3 * cols);
rtrn += check_forward_backward(c1, c2, c3, c4);
}
const std::uint64_t tailcount = static_cast<std::uint64_t>(
std::min<std::ptrdiff_t>(64, cols - col));
if(! tailcount)
continue;
const std::uint64_t tailmask = (~std::uint64_t{0}) >> (64 - tailcount);
const char* colptr = rowptr + col;
const __m512i c1 = _mm512_maskz_loadu_epi8(tailmask, colptr);
const __m512i c2 = _mm512_maskz_loadu_epi8(tailmask, colptr + cols);
const __m512i c3 = _mm512_maskz_loadu_epi8(tailmask, colptr + 2 * cols);
const __m512i c4 = _mm512_maskz_loadu_epi8(tailmask, colptr + 3 * cols);
rtrn += check_forward_backward(c1, c2, c3, c4);
}
}
if(rows > 3 && cols > 3) {
for(std::ptrdiff_t row = 0; row <= rows - 4; ++row) {
const char* rowptr = gridptr + row * cols;
std::ptrdiff_t col;
for(col = 0; col <= cols - 64 - 3; col += 64) {
const char* colptr = rowptr + col;
// top left to bottom right
__m512i c1 = _mm512_loadu_si512(colptr);
__m512i c2 = _mm512_loadu_si512(colptr + cols + 1);
__m512i c3 = _mm512_loadu_si512(colptr + 2 * (cols + 1));
__m512i c4 = _mm512_loadu_si512(colptr + 3 * (cols + 1));
rtrn += check_forward_backward(c1, c2, c3, c4);
// top right to bottom left
c1 = _mm512_loadu_si512(colptr + 3);
c2 = _mm512_loadu_si512(colptr + 2 + cols);
c3 = _mm512_loadu_si512(colptr + 1 + 2 * cols);
c4 = _mm512_loadu_si512(colptr + 3 * cols);
rtrn += check_forward_backward(c1, c2, c3, c4);
}
const std::uint64_t tailcount = static_cast<std::uint64_t>(
std::min<std::ptrdiff_t>(64, cols - col));
const std::uint64_t tailmask = (~std::uint64_t{0}) >> (64 - tailcount);
const char* colptr = rowptr + col;
__m512i c1 = _mm512_maskz_loadu_epi8(tailmask, colptr);
__m512i c2 = _mm512_maskz_loadu_epi8(tailmask >> 1, colptr + cols + 1);
__m512i c3 = _mm512_maskz_loadu_epi8(tailmask >> 2, colptr + 2 * (cols + 1));
__m512i c4 = _mm512_maskz_loadu_epi8(tailmask >> 3, colptr + 3 * (cols + 1));
rtrn += check_forward_backward(c1, c2, c3, c4);
c1 = _mm512_maskz_loadu_epi8(tailmask >> 3, colptr + 3);
c2 = _mm512_maskz_loadu_epi8(tailmask >> 2, colptr + 2 + cols);
c3 = _mm512_maskz_loadu_epi8(tailmask >> 1, colptr + 1 + 2 * cols);
c4 = _mm512_maskz_loadu_epi8(tailmask, colptr + 3 * cols);
rtrn += check_forward_backward(c1, c2, c3, c4);
}
}
return rtrn;
}
int main()
{
Grid grid = read_grid("xmas.txt");
std::cout << grid.rows << ' ' << grid.cols << ' ' << grid.buf.size() <<'\n';
using clock_t = std::chrono::steady_clock;
auto t1 = clock_t::now();
int result = count_xmas_dense(grid);
auto t2 = clock_t::now();
auto nanos = std::chrono::duration_cast<std::chrono::nanoseconds>(t2 - t1);
double throughput = (grid.buf.size() / (1024. * 1024.)) / (nanos.count() * 1e-9);
std::cout << result
<< ' ' << throughput << " MiB/s\n";
}
如果我将测试数据文件扩展 10,000 倍(1400140 行 x 140 列),我将实现 3.9 GiB/s 的吞吐量。AVX2版本
_mm256_cmpeq_epi8
,然后使用
_mm256_and_si256
来合并结果,然后使用
_mm256_movemask_epi8
来获取位掩码。对于尾部,您需要展开一次 16 字节的迭代,然后可能是最后几次的标量循环。