计算字节数组中特定 3 位模式的出现次数

问题描述 投票:0回答:1

我在面试时被问到这个问题,这就是我想出的解决方案。有人告诉我这不是最有效的解决方案,但我想不出任何其他解决方案。

这就是问题所在:

给定一个字节数组和一个固定的 3 位模式 0b101,我们需要计算该 3 位模式在数组中出现的次数。该模式可以跨越字节边界,这意味着它可以使用连续字节中的位来形成。

// Counts how many times a 3-bit pattern occurs in the bit stream formed by
// `array`, reading each byte most-significant bit first and treating
// consecutive bytes as one continuous stream (so a match may straddle a byte
// boundary). Overlapping matches are each counted.
//
// array   : pointer to the bytes to scan; a null pointer yields 0.
// size    : number of bytes in `array`; 0 yields 0.
// pattern : the 3-bit pattern to look for (0..7); larger values yield 0.
//
// Returns the number of 3-bit windows equal to `pattern`.
int count_3bit_pattern_occurrences(const uint8_t* array, size_t size, uint8_t pattern = 0b101) {
    // Input validation
    if (array == nullptr || size == 0 || pattern > 0b111) {
        return 0;
    }

    int count = 0;
    uint32_t window = 0;  // carried bits (at most 2) followed by the current byte
    int bit_count = 0;    // number of valid low bits in `window`

    for (size_t i = 0; i < size; ++i) {
        window = (window << 8) | array[i];
        bit_count += 8;

        // Examine every 3-bit window that ends inside the current byte. For
        // the first byte that is 6 windows; for later bytes it is 8, because
        // the 2 carried bits add the two windows that cross the byte boundary.
        for (int shift = bit_count - 3; shift >= 0; --shift) {
            if (((window >> shift) & 0b111u) == pattern) {
                ++count;
            }
        }

        // Keep only the last 2 bits: they are the only ones that can still
        // take part in a window together with bits of the next byte.
        window &= 0b11u;
        bit_count = 2;
    }

    return count;
}

我的目标是通过以 O(1) 空间复杂度实现这一点来最大限度地减少空间使用。但是,我不确定是否有比迭代字节数组更有效的方法。是否可以进一步优化,或者我错过了什么?

c++ bit-manipulation
1个回答
0
投票

正如评论者 @500 所建议的,查找表比逐位查看值要快得多。这是两者的示例。

count_3bit_pattern_occurrences()
是一个类似于 OP 的顺序函数。
count_3bit_pattern_occurrences_blocked()
是按块查表的版本,每次处理 6 位。它使用一个 256 项(每项 1 字节)的查找表:每次以当前的低 8 位作为索引查表,然后只右移 6 位,从而保留高 2 位,让跨越块边界的模式在下一次查找中被统计到。查找表借助 constexpr 在编译期生成。块查找版本的速度大约快 5 倍。查找表大小为 256 字节,可以完全放入缓存。

#include <cstdint>
#include <vector>
#include <iostream>
#include <array>
#include <random>
#include <chrono>

// Accumulating stopwatch. Each start()/stop() pair adds the elapsed wall time
// (in seconds) to a running total; reset() clears the total and starts a new
// interval.
class Timer {
    // steady_clock is monotonic, so intervals are not distorted by system
    // clock adjustments while the benchmark runs.
    std::chrono::steady_clock::time_point snapTime;
    double cumTime = 0;     // accumulated seconds over all completed intervals
    bool running = false;   // true between start() and the matching stop()
public:
    // Constructs a timer with a zero total and an interval already started.
    Timer() { reset(); }

    // Clears the accumulated total and begins a new interval.
    void reset() { cumTime = 0; start(); }

    // Begins (or restarts) the current interval at the present time.
    void start() { snapTime = std::chrono::steady_clock::now(); running = true; }

    // Ends the current interval, adds it to the total and returns the total in
    // seconds. Calling stop() again without an intervening start() adds
    // nothing, so it can be used to read the total without double counting.
    double stop() {
        if (running) {
            cumTime += std::chrono::duration<double>(std::chrono::steady_clock::now() - snapTime).count();
            running = false;
        }
        return cumTime;
    }
};

constexpr uint8_t pattern_3bit{ 0b101 }; // pattern is constexpr to enable compile time lookup table creation

// Counts the occurrences of pattern_3bit in the bit stream formed by `array`,
// one bit position at a time. Bytes are consumed least-significant bit first,
// and each following byte is appended above the previous one, so matches that
// straddle a byte boundary are counted. Overlapping matches are each counted.
//
// array : pointer to the bytes to scan; a null pointer yields 0.
// size  : number of bytes in `array`; 0 yields 0.
//
// Returns the number of 3-bit windows equal to pattern_3bit. Usable in
// constant expressions.
constexpr int count_3bit_pattern_occurrences(const uint8_t* array, size_t size)
{
    // Without this guard an empty input would read array[0] and wrap `size`
    // around to SIZE_MAX, walking off the end of the buffer.
    if (array == nullptr || size == 0)
        return 0;

    uint16_t v{ array[0] };         // sliding window: low bits are examined next
    size_t next_array_index = 1;
    size--;                         // from here on: bytes not yet loaded into v
    int total_matched{};

    // bit_count is the number of valid bits left in v; a 3-bit window needs 3.
    for (int bit_count = 8; bit_count >= 3; bit_count--)
    {
        // Exactly one byte left in v and more input available: place the next
        // byte above it so windows crossing the boundary can be examined.
        if (bit_count == 8 && size > 0)
        {
            v = static_cast<uint16_t>(v | (array[next_array_index++] << 8));
            bit_count = 16;
            size--;
        }
        if ((v & 7) == pattern_3bit)
            total_matched++;
        v >>= 1;
    }
    return total_matched;
}

// Counts the occurrences of pattern_3bit in the bit stream formed by `array`
// using a 256-entry lookup table, consuming 6 bit positions per lookup. Uses
// the same bit order as count_3bit_pattern_occurrences (least-significant bit
// of each byte first, following bytes appended above).
//
// array : pointer to the bytes to scan; a null pointer yields 0.
// size  : number of bytes in `array`; 0 yields 0.
//
// Returns the number of 3-bit windows equal to pattern_3bit.
int count_3bit_pattern_occurrences_blocked(const uint8_t* array, size_t size)
{
    // Without this guard an empty input would read array[0] and wrap `size`
    // around to SIZE_MAX.
    if (array == nullptr || size == 0)
        return 0;

    // Builds the table at compile time: entry i is the number of matches among
    // the 6 windows that fit inside the single byte i.
    constexpr auto init = []() {
        // Value-initialised: an uninitialised local is not permitted in a
        // constant expression before C++20.
        std::array<uint8_t, 256> lookup{};
        for (int i = 0; i < 256; i++)
        {
            const uint8_t x = static_cast<uint8_t>(i);
            lookup[i] = static_cast<uint8_t>(count_3bit_pattern_occurrences(&x, 1));
        }
        return lookup;
    };

    // static: the table is materialised once rather than on every call.
    static constexpr std::array<uint8_t, 256> lookup = init();

    uint16_t v{ array[0] };         // sliding window: low bits are examined next
    int bit_count = 8;              // number of valid bits currently in v
    size_t next_array_index = 1;
    size--;                         // from here on: bytes not yet loaded into v
    int total_matched{};

    while (size != 0)
    {
        if (bit_count <= 8) // keep at least 8 valid bits available for the lookup
        {
            v = static_cast<uint16_t>(v | (array[next_array_index++] << bit_count));
            bit_count += 8;
            size--;
        }
        // 8 bits index the table, but only 6 are consumed: the top 2 bits stay
        // in v so windows starting in them are counted by the next step.
        total_matched += lookup[v & 0xFF];
        v >>= 6;
        bit_count -= 6;
    }

    // All input bytes are loaded; decode the remaining bits without the table.
    for (; bit_count >= 3; bit_count--)
    {
        if ((v & 7) == pattern_3bit)
            total_matched++;
        v >>= 1;
    }
    return total_matched;
}


// Benchmark driver: feeds the same randomly generated byte vectors to the
// sequential and the blocked counter, accumulating the time spent in each and
// the total number of matches, then prints both so the counts can be compared.
int main()
{
    constexpr int loops{ 1'000'000 };
    Timer time[2];          // accumulate time for pattern counters
    double elapsed[2]{};    // latest accumulated total returned by each timer
    int cnt[2]{};           // accumulate total patterns detected
    std::mt19937 gen(1);    // mersenne_twister_engine, fixed seed for repeatable input
    std::uniform_int_distribution<> distrib_n(1, 2000), distrib_256(0, 255);
    for (int i = 0; i < loops; i++)
    {
        std::vector<uint8_t> v(distrib_n(gen));     // generate a vector between 1 and 2000 bytes
        for (auto& byte : v)                        // fill the vector with random bytes
            byte = static_cast<uint8_t>(distrib_256(gen));

        // track time and add patterns detected for each decoder
        time[0].start();
        cnt[0] += count_3bit_pattern_occurrences(v.data(), v.size());
        elapsed[0] = time[0].stop();

        time[1].start();
        cnt[1] += count_3bit_pattern_occurrences_blocked(v.data(), v.size());
        elapsed[1] = time[1].stop();
    }
    // Print the totals captured inside the loop. Calling stop() again here
    // would add the time since each timer's last start() a second time.
    std::cout << "Sequential Execution Time:" << elapsed[0] << " pattern count:" << cnt[0] << '\n';
    std::cout << "Block Execution Time:" << elapsed[1] << " pattern count:" << cnt[1] << '\n';
    // On my machine, x64, opt. MSVC
    //  Sequential Execution Time : 9.4333 pattern count : 999171615
    //  Block Execution Time : 1.90501 pattern count : 999171615
}
© www.soinside.com 2019 - 2024. All rights reserved.