如何有效地使用SIMD来统计大型单词搜索网格(包括垂直和对角线)中的4个字符匹配?

问题描述 投票:0回答:1

对于 2024 年代码出现的第 4 天,存在一个问题,您需要查找字符网格中包含多少个“XMAS”字符串,例如

MMMSXXMASM MSAMXMSMSA AMXSXMAAMM MSAMASMSMX XMASAMXAMM XXAMMXXAMA SMSMSASXSS SAXAMASAAA MAMMMXMMMM MXMXAXMASX
单词可以出现在 8 个方向中的任意一个方向:水平(上/下)/垂直(左/右),或 4 个对角线中的任意一个。

从问题陈述中不清楚网格大小是否是固定的。

他们给你的实际测试用例是 140x140。


我解决这个问题的方法是将 X 开始的 8 个可能方向的所有 32 个字符加载到 SIMD 寄存器中,并将其与掩码进行比较。

这有效并且相对较快(当我增加输入文件大小来测试其限制时,它达到了 460 MB/s),但是当在包含很多开发人员以及一些超过 15 年的开发人员的不和谐服务器上谈论我们的解决方案时具有专业经验的其中一位告诉我,对于这个特定问题,使用 SIMD 指令并不能让代码运行得更快,底层算法本身更重要。

所以我想知道,在什么情况下SIMD值得使用?我如何识别使用 SIMD 指令会产生显着差异的情况?我认为比较 1 条指令中的所有 32 个字符会产生很大的差异,而不是嵌套 2 个 for 循环,但我可能错了?

typedef struct { std::vector<char> chars; uint32_t line_length; } Data; Data readFile() { std::ifstream file("./input/day4_input.txt", std::ios::binary | std::ios::ate); std::streamsize size = file.tellg(); file.seekg(0, std::ios::beg); std::vector<char> buffer(size); file.read(buffer.data(), size); file.close(); uint32_t line_length = std::ranges::find(buffer, '\n') - buffer.begin(); buffer.erase(std::remove(buffer.begin(), buffer.end(), '\n'), buffer.end()); Data input_data { buffer, line_length }; return input_data; } int main(int argc, char** argv) { Data idata = readFile(); auto coordsToIdx = [&idata] (uint32_t i, uint32_t j) -> uint32_t { return i*idata.line_length + j; }; auto idxToCoords = [&idata] (uint32_t idx) -> std::pair<uint32_t, uint32_t> { return std::make_pair(static_cast<uint32_t>(idx / idata.line_length), static_cast<uint32_t>(idx % idata.line_length)); }; unsigned int res = 0; const __m256i _mask = _mm256_setr_epi8( 'X', 'M', 'A', 'S', 'X', 'M', 'A', 'S', 'X', 'M', 'A', 'S', 'X', 'M', 'A', 'S', 'X', 'M', 'A', 'S', 'X', 'M', 'A', 'S', 'X', 'M', 'A', 'S', 'X', 'M', 'A', 'S' ); std::array<int, 8> out_buff; for(size_t i = 0; i < idata.chars.size(); i++) { if(idata.chars[i] != 'X') continue; char buff[32] = {0}; std::pair<uint32_t, uint32_t> coords = idxToCoords(i); bool u = coords.first >= 3; bool d = coords.first <= (idata.chars.size() / idata.line_length) - 4; // 0 indexed bool l = coords.second >= 3; bool r = coords.second <= idata.line_length - 4; if(u) for(size_t c = 0; c < 4; c++) buff[c] = idata.chars[coordsToIdx(coords.first - c, coords.second)]; if(u && r) for(size_t c = 0; c < 4; c++) buff[c + 4] = idata.chars[coordsToIdx(coords.first - c, coords.second + c)]; if(r) for(size_t c = 0; c < 4; c++) buff[c + 8] = idata.chars[coordsToIdx(coords.first, coords.second + c)]; if(d && r) for(size_t c = 0; c < 4; c++) buff[c + 12] = idata.chars[coordsToIdx(coords.first + c, coords.second + c)]; if(d) for(size_t c = 0; c < 4; c++) buff[c + 16] = idata.chars[coordsToIdx(coords.first + c, coords.second)]; if(d && l) for(size_t c = 0; c < 4; c++) buff[c + 20] = idata.chars[coordsToIdx(coords.first + c, coords.second - c)]; if(l) for(size_t c = 0; c < 4; c++) buff[c + 24] = idata.chars[coordsToIdx(coords.first, coords.second - c)]; if(u && l) for(size_t c = 0; c < 4; c++) buff[c + 28] = idata.chars[coordsToIdx(coords.first - c, coords.second - c)]; __m256i _block = _mm256_loadu_epi8(buff); __m256i _cmpeq = _mm256_cmpeq_epi32(_block, _mask); _mm256_storeu_epi32(out_buff.data(), _cmpeq); // Contains signed 32 bits integers for each direction // If a direction match, all bits are set to 1 // Which is -1. So we substract the total to add it instead res -= std::accumulate(out_buff.begin(), out_buff.end(), 0); } std::cout << "Result is : << res << std::endl; }
    
c++ optimization simd micro-optimization
1个回答
0
投票
一般主题

我认为与你交谈的那个人陷入了一种错误的二分法。使用高级算法和矢量化并不相互排斥(尽管可能更困难)。举个例子:该任务是针干草堆问题的变体,其中

strstr

 是标准实现。 Glibc 为此实现了一种高级算法(修改后的 Horspool),但也有一个 SSE2 加速版本。相比之下,他们删除了 AVX512 加速版本,似乎是因为它实现了朴素的算法。

那个人的说法是正确的,首先你应该关注算法

然后从矢量化或并行化开始。后两者只能给你带来线性的性能提升。选择更好的算法可以获得更多;但修改矢量化或并行化代码要困难得多。

但是,这里我们有一个非常具体的任务:我们在大海捞针中搜索的针正好是 4 个字节,并且它在输入数据中出现非常频繁。复杂的算法和大 O 表示法对于小数字来说毫无意义。这是简单的暴力算法的领域。

矢量化算法

这是我的看法。为了理解它,我们从头开始构建它。让我们关注水平匹配,仅向前方向,仅一行。我们想要的是

的矢量化形式

int count = 0; for(int col = 0; col <= cols - 4; ++col) count += (haystack[col] == 'X') & (haystack[col+1] == 'M') & (haystack[col+1] == 'A') & (haystack[col+3] == 'S');
使用 AVX512,我们可以一次比较 32 个条目,如下所示:

const __m512i xvec = _mm512_set1_epi8('X'); const __m512i mvec = _mm512_set1_epi8('M'); const __m512i avec = _mm512_set1_epi8('A'); const __m512i svec = _mm512_set1_epi8('S'); for(int col = 0; col <= cols - 64 - 4; ++col) { const char* colptr = haystack + col; const __m512i c1 = _mm512_loadu_si512(colptr); const __m512i c2 = _mm512_loadu_si512(colptr + 1); const __m512i c3 = _mm512_loadu_si512(colptr + 2); const __m512i c4 = _mm512_loadu_si512(colptr + 3); __mmask64 fe = _mm512_cmpeq_epi8_mask(c1, xvec); fe = _mm512_mask_cmpeq_epi8_mask(fe, c2, mvec); fe = _mm512_mask_cmpeq_epi8_mask(fe, c3, avec); fe = _mm512_mask_cmpeq_epi8_mask(fe, c4, svec); count += _popcnt64(fe); }
我们对每个值进行重叠加载,比较它们并对掩码中设置的位进行计数。您会在许多应用(例如卷积)中发现这种重叠负载模式。

在 AVX512 中处理尾端非常简单:只需加载遮罩即可。由于零字节无法匹配我们的针,因此算法的其余部分可以保持不变。

int col; for(col = 0; col <= cols - 64 - 4; ++col) { ... // (main loop) } const std::uint64_t tailcount = static_cast<std::uint64_t>( std::min(64, cols - col)); if(tailcount) { const std::uint64_t tailmask = (~std::uint64_t{0}) >> (64 - tailcount); const char* colptr = haystack + col; const __m512i c1 = _mm512_maskz_loadu_epi8(tailmask, colptr); const __m512i c2 = _mm512_maskz_loadu_epi8(tailmask >> 1, colptr + 1); const __m512i c3 = _mm512_maskz_loadu_epi8(tailmask >> 2, colptr + 2); const __m512i c4 = _mm512_maskz_loadu_epi8(tailmask >> 3, colptr + 3); ... // same as main loop from here }
现在我们必须将其扩展到 2D 网格和其他 7 个方向。
检查反向方向可以简单地通过第二组比较来完成,只需将 

xvec

 反转为 
svec
 顺序即可。

垂直和对角线检查很简单,因为我们不是前进一列,而是前进行加负 1 列。进行范围检查很容易出错,否则都是一样的。

这是完整的算法。我已经验证了结果,但我不能排除测试数据未涵盖的边缘情况。

#include <chrono> #include <cstdint> #include <cstring> #include <fstream> #include <iostream> #include <string> #include <immintrin.h> struct Grid { std::string buf; int rows, cols; }; static Grid read_grid(const char* filepath) { Grid grid {}; std::ifstream infile { filepath }; std::string tmp; while(std::getline(infile, tmp)) { grid.rows += 1; grid.cols = tmp.size(); grid.buf += tmp; } return grid; } static int count_xmas_dense(const Grid& grid) { const char* gridptr = grid.buf.c_str(); const std::ptrdiff_t rows = grid.rows; const std::ptrdiff_t cols = grid.cols; const __m512i xvec = _mm512_set1_epi8('X'); const __m512i mvec = _mm512_set1_epi8('M'); const __m512i avec = _mm512_set1_epi8('A'); const __m512i svec = _mm512_set1_epi8('S'); auto check_forward_backward = [&](__m512i c1, __m512i c2, __m512i c3, __m512i c4) { // check forward __mmask64 fe = _mm512_cmpeq_epi8_mask(c1, xvec); fe = _mm512_mask_cmpeq_epi8_mask(fe, c2, mvec); fe = _mm512_mask_cmpeq_epi8_mask(fe, c3, avec); fe = _mm512_mask_cmpeq_epi8_mask(fe, c4, svec); int rtrn = _popcnt64(fe); // check backward __mmask64 be = _mm512_cmpeq_epi8_mask(c1, svec); be = _mm512_mask_cmpeq_epi8_mask(be, c2, avec); be = _mm512_mask_cmpeq_epi8_mask(be, c3, mvec); be = _mm512_mask_cmpeq_epi8_mask(be, c4, xvec); rtrn += _popcnt64(be); return rtrn; }; int rtrn = 0; if(cols > 3) { // check left to right and right to left for(std::ptrdiff_t row = 0; row < rows; ++row) { const char* rowptr = gridptr + row * cols; std::ptrdiff_t col; for(col = 0; col <= cols - 64 - 3; col += 64) { const char* colptr = rowptr + col; const __m512i c1 = _mm512_loadu_si512(colptr); const __m512i c2 = _mm512_loadu_si512(colptr + 1); const __m512i c3 = _mm512_loadu_si512(colptr + 2); const __m512i c4 = _mm512_loadu_si512(colptr + 3); rtrn += check_forward_backward(c1, c2, c3, c4); } const std::uint64_t tailcount = static_cast<std::uint64_t>( std::min<std::ptrdiff_t>(64, cols - col)); if(!tailcount) continue; const std::uint64_t tailmask = (~std::uint64_t{0}) >> (64 - tailcount); const char* colptr = rowptr + col; const __m512i c1 = _mm512_maskz_loadu_epi8(tailmask, colptr); const __m512i c2 = _mm512_maskz_loadu_epi8(tailmask >> 1, colptr + 1); const __m512i c3 = _mm512_maskz_loadu_epi8(tailmask >> 2, colptr + 2); const __m512i c4 = _mm512_maskz_loadu_epi8(tailmask >> 3, colptr + 3); rtrn += check_forward_backward(c1, c2, c3, c4); } } if(rows > 3) { // check top to bottom and bottom up for(std::ptrdiff_t row = 0; row <= rows - 4; ++row) { const char* rowptr = gridptr + row * cols; std::ptrdiff_t col; for(col = 0; col <= cols - 64; col += 64) { const char* colptr = rowptr + col; const __m512i c1 = _mm512_loadu_si512(colptr); const __m512i c2 = _mm512_loadu_si512(colptr + cols); const __m512i c3 = _mm512_loadu_si512(colptr + 2 * cols); const __m512i c4 = _mm512_loadu_si512(colptr + 3 * cols); rtrn += check_forward_backward(c1, c2, c3, c4); } const std::uint64_t tailcount = static_cast<std::uint64_t>( std::min<std::ptrdiff_t>(64, cols - col)); if(! tailcount) continue; const std::uint64_t tailmask = (~std::uint64_t{0}) >> (64 - tailcount); const char* colptr = rowptr + col; const __m512i c1 = _mm512_maskz_loadu_epi8(tailmask, colptr); const __m512i c2 = _mm512_maskz_loadu_epi8(tailmask, colptr + cols); const __m512i c3 = _mm512_maskz_loadu_epi8(tailmask, colptr + 2 * cols); const __m512i c4 = _mm512_maskz_loadu_epi8(tailmask, colptr + 3 * cols); rtrn += check_forward_backward(c1, c2, c3, c4); } } if(rows > 3 && cols > 3) { for(std::ptrdiff_t row = 0; row <= rows - 4; ++row) { const char* rowptr = gridptr + row * cols; std::ptrdiff_t col; for(col = 0; col <= cols - 64 - 3; col += 64) { const char* colptr = rowptr + col; // top left to bottom right __m512i c1 = _mm512_loadu_si512(colptr); __m512i c2 = _mm512_loadu_si512(colptr + cols + 1); __m512i c3 = _mm512_loadu_si512(colptr + 2 * (cols + 1)); __m512i c4 = _mm512_loadu_si512(colptr + 3 * (cols + 1)); rtrn += check_forward_backward(c1, c2, c3, c4); // top right to bottom left c1 = _mm512_loadu_si512(colptr + 3); c2 = _mm512_loadu_si512(colptr + 2 + cols); c3 = _mm512_loadu_si512(colptr + 1 + 2 * cols); c4 = _mm512_loadu_si512(colptr + 3 * cols); rtrn += check_forward_backward(c1, c2, c3, c4); } const std::uint64_t tailcount = static_cast<std::uint64_t>( std::min<std::ptrdiff_t>(64, cols - col)); const std::uint64_t tailmask = (~std::uint64_t{0}) >> (64 - tailcount); const char* colptr = rowptr + col; __m512i c1 = _mm512_maskz_loadu_epi8(tailmask, colptr); __m512i c2 = _mm512_maskz_loadu_epi8(tailmask >> 1, colptr + cols + 1); __m512i c3 = _mm512_maskz_loadu_epi8(tailmask >> 2, colptr + 2 * (cols + 1)); __m512i c4 = _mm512_maskz_loadu_epi8(tailmask >> 3, colptr + 3 * (cols + 1)); rtrn += check_forward_backward(c1, c2, c3, c4); c1 = _mm512_maskz_loadu_epi8(tailmask >> 3, colptr + 3); c2 = _mm512_maskz_loadu_epi8(tailmask >> 2, colptr + 2 + cols); c3 = _mm512_maskz_loadu_epi8(tailmask >> 1, colptr + 1 + 2 * cols); c4 = _mm512_maskz_loadu_epi8(tailmask, colptr + 3 * cols); rtrn += check_forward_backward(c1, c2, c3, c4); } } return rtrn; } int main() { Grid grid = read_grid("xmas.txt"); std::cout << grid.rows << ' ' << grid.cols << ' ' << grid.buf.size() <<'\n'; using clock_t = std::chrono::steady_clock; auto t1 = clock_t::now(); int result = count_xmas_dense(grid); auto t2 = clock_t::now(); auto nanos = std::chrono::duration_cast<std::chrono::nanoseconds>(t2 - t1); double throughput = (grid.buf.size() / (1024. * 1024.)) / (nanos.count() * 1e-9); std::cout << result << ' ' << throughput << " MiB/s\n"; }
如果我将测试数据文件扩展 10,000 倍(1400140 行 x 140 列),我将实现 3.9 GiB/s 的吞吐量。

AVX2版本

将其转换为 AVX2 应该不会太难。为了进行比较,您只需使用

_mm256_cmpeq_epi8

,然后使用 
_mm256_and_si256
 来合并结果,然后使用 
_mm256_movemask_epi8
 来获取位掩码。

对于尾部,您需要展开一次 16 字节的迭代,然后可能是最后几次的标量循环。

© www.soinside.com 2019 - 2024. All rights reserved.