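I was asked this question in an interview, and this is the solution I came up with. I was told it's not the most efficient solution, but I can't think of any other.
Here is the problem:
Given a byte array and a fixed 3-bit pattern 0b101, count the number of times the 3-bit pattern occurs in the array. The pattern can span byte boundaries, meaning it can be formed from bits in consecutive bytes.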
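The statement does not fix a bit order. A minimal brute-force sketch, assuming the convention the OP's code uses (the array is one bit stream, most significant bit of `array[0]` first), checks every starting position, so boundary-spanning windows are included. `count_101_reference` is a hypothetical name; the sketch is mainly useful as a correctness oracle for faster versions.

```cpp
#include <cstddef>
#include <cstdint>

// Brute-force reference (assumed bit order: MSB of array[0] first, then array[1], ...).
// Checks all 8*size - 2 windows, including the ones that straddle byte boundaries.
// O(1) extra space, O(8*size) time.
int count_101_reference(const std::uint8_t* array, std::size_t size) {
    // i-th bit of the stream, MSB-first within each byte
    auto bit = [&](std::size_t i) {
        return (array[i / 8] >> (7 - i % 8)) & 1;
    };
    const std::size_t nbits = 8 * size;
    int count = 0;
    for (std::size_t i = 0; i + 3 <= nbits; ++i) {
        if (bit(i) == 1 && bit(i + 1) == 0 && bit(i + 2) == 1)
            ++count;
    }
    return count;
}
```

For example, `{0x01, 0x40}` is `00000001 01000000`, whose only match is the window made of the last bit of the first byte and the first two bits of the second.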
#include <cstddef>
#include <cstdint>
int count_3bit_pattern_occurrences(const uint8_t* array, size_t size, uint8_t pattern = 0b101) {
    // Check for input validation
    if (size == 0 || pattern > 0b111) {
        return 0;
    }
    int count = 0;
    uint32_t value = 0; // To hold the accumulated bits (MSB of each byte first)
    int bit_count = 0;  // Number of valid bits in value
    for (size_t i = 0; i < size; ++i) {
        value = (value << 8) | array[i];
        bit_count += 8;
        // Check every 3-bit window in value: the 6 inside the current byte plus,
        // from the second byte on, the 2 that start in the carried-over bits
        for (int shift = bit_count - 3; shift >= 0; --shift) {
            if (((value >> shift) & 0b111) == pattern) {
                ++count;
            }
        }
        // Prepare for the next iteration, keeping only the last 2 bits
        value &= 0b11;
        bit_count = 2;
    }
    return count;
}
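My goal was to minimize space usage by doing this in O(1) space. However, I'm not sure whether there is a more efficient approach than iterating over the byte array. Can it be optimized further, or am I missing something?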
As commenter @500 suggested, a lookup table is much faster than examining the value bit by bit. Here are examples of both.
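count_3bit_pattern_occurrences() is a sequential function similar to the OP's. count_3bit_pattern_occurrences_blocked() is a block-lookup version that processes 6 bits at a time: each step indexes a 256-entry byte lookup table with 8 bits and keeps the top 2 of those bits for the next lookup, so windows that span two blocks are not lost. The code uses constexpr to build the lookup table at compile time. In the benchmark below, the block lookup is roughly 5 times faster. The lookup table is only 256 bytes, so it fits in cache.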
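A quick count (the notation n for the number of bytes and k for the number of table lookups is ours, not the answer's) shows why indexing with 8 bits but shifting by only 6 visits every window exactly once. An n-byte stream has 8n bits and 8n − 2 windows. Each lookup counts the 6 windows that start at bits 0 to 5 of the low byte of `v` (they need only bits 0 to 7) and then drops 6 bits. The tail loop then sees the 8n − 6k leftover bits and checks `bit_count - 2` windows among them:

```latex
\underbrace{6k}_{\text{table lookups}} \;+\; \underbrace{(8n - 6k) - 2}_{\text{tail loop}} \;=\; 8n - 2
```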
#include <cstdint>
#include <vector>
#include <iostream>
#include <array>
#include <random>
#include <chrono>
// Accumulating stopwatch: each start()/stop() pair adds to the running total
class Timer {
    std::chrono::steady_clock::time_point snapTime;  // steady_clock is monotonic, unlike system_clock
    double cumTime = 0;
public:
    Timer() { reset(); }
    void reset() { cumTime = 0; start(); }
    void start() { snapTime = std::chrono::steady_clock::now(); }
    double stop() { cumTime += std::chrono::duration<double>(std::chrono::steady_clock::now() - snapTime).count(); return cumTime; }
};
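constexpr uint8_t pattern_3bit{ 0b101 }; // pattern is constexpr to enable compile time lookup table creation
// Sequential version (compile with C++17 or later). Reads each byte LSB-first:
// bit 0 of array[0] is the first bit of the stream, and one 3-bit window is checked per step.
constexpr int count_3bit_pattern_occurrences(const uint8_t* array, size_t size)
{
    if (size == 0)
        return 0;
    uint16_t v{ array[0] };           // bit buffer; bit 0 is the oldest unchecked bit
    size_t next_array_index = 1;
    size--;
    int total_matched{};
    for (int bit_count = 8; bit_count >= 3; bit_count--)
    {
        if (bit_count == 8 && size > 0)  // refill: append the next byte above the 8 remaining bits
        {
            v |= array[next_array_index++] << 8;
            bit_count = 16;
            size--;
        }
        if ((v & 7) == pattern_3bit)
            total_matched++;
        v >>= 1;
    }
    return total_matched;
}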
int count_3bit_pattern_occurrences_blocked(const uint8_t* array, size_t size)
{
    // lambda to initialize lookup array: lookup[b] is the number of matches
    // among the 6 windows that start at bits 0..5 of byte b
    auto init = []() {
        std::array<uint8_t, 256> lookup{};  // must be initialized to be usable in a C++17 constexpr context
        for (int i = 0; i < 256; i++)
        {
            uint8_t x = static_cast<uint8_t>(i);
            lookup[i] = static_cast<uint8_t>(count_3bit_pattern_occurrences(&x, 1));
        }
        return lookup;
    };
    constexpr std::array<uint8_t, 256> lookup = init();  // evaluated at compile time
    if (size == 0)
        return 0;
    uint16_t v{ array[0] };
    int bit_count = 8;                // number of valid bits in v
    size_t next_array_index = 1;
    size--;
    int total_matched{};
    for (;;)
    {
        if (size == 0)  // if we are at the end of the byte array, decode w/o lookup
        {
            for (; bit_count >= 3; bit_count--)
            {
                if ((v & 7) == pattern_3bit)
                    total_matched++;
                v >>= 1;
            }
            return total_matched;
        }
        if (bit_count <= 8)  // keep at least 8 valid bits in v for the next lookup
        {
            v |= array[next_array_index++] << bit_count;
            bit_count += 8;
            size--;
        }
        total_matched += lookup[255 & v];  // while 8 bits are used for lookup, the msb 2 bits
        v >>= 6;                           // are preserved for the next block decode
        bit_count -= 6;
    }
    return total_matched;
}
int main()
{
    constexpr int loops{ 1'000'000 };
    Timer time[2];        // accumulate time for each pattern counter
    int cnt[2]{};         // accumulate total patterns detected
    std::mt19937 gen(1);  // Mersenne Twister engine, fixed seed
    std::uniform_int_distribution<> distrib_n(1, 2000), distrib_256(0, 255);
    for (int i = 0; i < loops; i++)
    {
        std::vector<uint8_t> v(distrib_n(gen));   // generate a vector between 1 and 2000 bytes
        for (size_t ii = 0; ii < v.size(); ii++)  // fill the vector with random bytes
            v[ii] = static_cast<uint8_t>(distrib_256(gen));
        // track time and add patterns detected for each decoder
        time[0].start(); cnt[0] += count_3bit_pattern_occurrences(v.data(), v.size()); time[0].stop();
        time[1].start(); cnt[1] += count_3bit_pattern_occurrences_blocked(v.data(), v.size()); time[1].stop();
    }
    std::cout << "Sequential Execution Time:" << time[0].stop() << " pattern count:" << cnt[0] << '\n';
    std::cout << "Block Execution Time:" << time[1].stop() << " pattern count:" << cnt[1] << '\n';
    // On my machine, x64, opt. MSVC
    // Sequential Execution Time : 9.4333 pattern count : 999171615
    // Block Execution Time : 1.90501 pattern count : 999171615
}