使用 Boost.Beast 连续循环处理多个 WebSocket 写入请求

问题描述 投票:0回答:1
template<typename Derived>
class Websocket: public std::enable_shared_from_this<Websocket<Derived>>
{
public:
    using InternalWSType = boost::beast::websocket::stream<
        boost::asio::ssl::stream<boost::asio::ip::tcp::socket>>;

    using SSLHandshakeConfigurator = std::function<void(Websocket *, InternalWSType &)>;

    explicit Websocket(boost::asio::io_context &io_context)
        : m_io_context { io_context }
        , m_resolver { m_io_context }
        , m_ws { m_io_context, m_ssl_context }
    {
        m_response.reserve(100'000'000);
    }

    Websocket(
        boost::asio::io_context &io_context,
        SSLHandshakeConfigurator ssl_handshake_configurator
    )
        : m_io_context { io_context }
        , m_resolver { m_io_context }
        , m_ws { m_io_context, m_ssl_context }
        , m_ssl_handshake_configurator(std::move(ssl_handshake_configurator))
    {
        m_ws.read_message_max(0);
        m_ws.auto_fragment(false);

        m_response.reserve(100'000'000);
    }

    virtual ~Websocket() = default;

    Websocket(const Websocket &) = delete;

    Websocket(Websocket &&) noexcept = delete;

    auto operator=(const Websocket &) -> Websocket & = delete;

    auto operator=(Websocket &&) noexcept -> Websocket & = delete;

    void async_start(std::string_view host, std::string_view port, std::string_view target)
    {
        m_host = std::string { host };
        m_target = std::string { target };

        // Look up the domain name
        m_resolver.async_resolve(
            m_host,
            std::string { port },
            [this](
                boost::system::error_code error_code,
                boost::asio::ip::tcp::resolver::results_type res
            ) {
                if (error_code)
                {
                    // handle
                    SPDLOG_ERROR("Failed to async_start");
                    SPDLOG_ERROR("    error_code.message: {}", error_code.message());
                    SPDLOG_ERROR("    error_code.what: {}", error_code.what());
                }
                else
                {
                    async_connect(std::move(res));
                }
            }
        );
    }

    void async_stop()
    {
        m_stop_requested = true;
        auto holder = this->shared_from_this();

        if (m_ws.next_layer().next_layer().is_open())
        {
            m_ws.async_close(
                boost::beast::websocket::close_code::normal,
                [holder = std::move(holder)](const boost::system::error_code &) {}
            );
        }
    }

    void send_message(std::string_view message)
    {
        m_ws.async_write(
            boost::beast::net::buffer(std::string { message }),
            [self = this->shared_from_this()](boost::system::error_code error_code, size_t) {
                if (error_code)
                {
                    SPDLOG_ERROR("Failed to send message");
                    SPDLOG_ERROR("    error_code.message: {}", error_code.message());
                    SPDLOG_ERROR("    error_code.what: {}", error_code.what());
                }
            }
        );
    }

    void poll()
    {
        m_io_context.poll();
    }

    void run()
    {
        m_io_context.run();
    }


protected:
    [[nodiscard]] auto get_io_context() const -> boost::asio::io_context &
    {
        return m_io_context;
    }

    virtual void send_subscribe_message() {};

private:
    void async_connect(boost::asio::ip::tcp::resolver::results_type result)
    {
        if (!SSL_set_tlsext_host_name(m_ws.next_layer().native_handle(), m_host.c_str()))
        {
            SPDLOG_ERROR("Boost::Beast error: async connect");
            return;
        }

        boost::asio::async_connect(
            m_ws.next_layer().next_layer(),
            result.begin(),
            result.end(),
            [this](boost::system::error_code error_code, boost::asio::ip::tcp::resolver::iterator) {
                if (error_code)
                {
                    SPDLOG_ERROR("Failed to async connect");
                    SPDLOG_ERROR("    error_code.message: {}", error_code.message());
                    SPDLOG_ERROR("    error_code.what: {}", error_code.what());

                    if (!m_stop_requested)
                    {
                        // handle the beast error
                    }
                }
                else
                {
                    on_connected();
                }
            }
        );
    }

    void on_connected()
    {
        m_ws.next_layer().async_handshake(
            boost::asio::ssl::stream_base::client,
            [this](boost::system::error_code error_code) {
                if (error_code)
                {
                    if (!m_stop_requested)
                    {
                    }
                }
                else
                {
                    on_async_ssl_handshake();
                }
            }
        );
    }

    void on_async_ssl_handshake()
    {
        if (m_ssl_handshake_configurator)
        {
            m_ssl_handshake_configurator(this, m_ws);
        }

        m_ws.async_handshake(m_host, m_target, [this](boost::system::error_code error_code) {
            if (!error_code)
                send_subscribe_message();
            else
            {
                SPDLOG_ERROR("Failed to async ssl handshake");
                SPDLOG_ERROR("    error_code.message: {}", error_code.message());
                SPDLOG_ERROR("    error_code.what: {}", error_code.what());
            }


            start_read(error_code);
        });
    }

    void start_read(boost::system::error_code error_code)
    {
        if (error_code)
        {
            if (!m_stop_requested)
            {
                SPDLOG_ERROR("Failed to start_read");
                SPDLOG_ERROR("    error_code.message: {}", error_code.message());
                SPDLOG_ERROR("    error_code.what: {}", error_code.what());
            }

            async_stop();

            return;
        }

        m_buffer.prepare(512'000'000);
        m_ws.async_read(m_buffer, [this](boost::system::error_code error_code, size_t size) {
            on_read(error_code, size);
        });
    }

    void on_read(boost::system::error_code error_code, [[maybe_unused]] size_t size)
    {
        if (error_code)
        {
            if (!m_stop_requested)
            {
                SPDLOG_ERROR("Failed to on_read");
                SPDLOG_ERROR("    error_code.message: {}", error_code.message());
                SPDLOG_ERROR("    error_code.what: {}", error_code.what());
                SPDLOG_ERROR("    reason: {}", m_ws.reason().reason.data());
                SPDLOG_ERROR("    reason_code: {}", m_ws.reason().code);
            }

            return;
        }

        m_response.clear();

        for (const auto &bytes : m_buffer.data())
        {
            m_response.append(static_cast<const char *>(bytes.data()), bytes.size());
        }
        m_buffer.consume(m_buffer.size());

        static_cast<Derived *>(this)->handle_message(m_response);

        // restart
        start_read(boost::system::error_code {});
    }

    boost::asio::io_context &m_io_context;
    boost::asio::ssl::context m_ssl_context { boost::asio::ssl::context::sslv23_client };
    boost::asio::ip::tcp::resolver m_resolver;
    InternalWSType m_ws;
    boost::beast::multi_buffer m_buffer;
    std::string m_response;
    std::string m_host;
    std::string m_target;
    bool m_stop_requested {};

    SSLHandshakeConfigurator m_ssl_handshake_configurator;
};

class EntryPoint final: public Websocket<EntryPoint>
{
public:
    EntryPoint(boost::asio::io_context &io_context): Websocket<EntryPoint>(io_context)
    {
    }


    void handle_message(std::string_view response)
    {
        // handle the response
    }

private:
    void subscribe()
    {
        async_start("url", "443", "/ws/private/orders/control");
    }
};

auto main() -> int32_t {
    std::cout << "entrypoint" << std::endl;
    boost::asio::io_context ioctx;
    auto entry_point_ws = std::make_shared<EntryPoint>(ioctx);
    entry_point_ws->async_start();
    
    while (true) {
        auto pseudo_request = "pseudo_request_popped_from_queue_pushed_by_another_thread";
        entry_point_ws->send_message("pseudo-request");
        ioctx.poll();
    }
}

我的理解是,对于 WebSocket 连接,同时不能有多个活动的

async_write
调用。但是,在我的主循环中,我不断从队列中弹出请求并调用
send_message
,然后调用 poll。

我的问题是:

当一次只能有一个活动的

async_write
时,我应该如何构建我的程序来处理多个写入请求?我想确保发送队列中的所有消息,而不会由于重叠的
async_write
调用而丢失任何消息。

任何有关如何对这些写入操作进行排队或重组程序以处理这种情况的建议将不胜感激。

我可能的解决方案:

在单独的线程中运行 io_context::run ,并从主线程将请求推送到无锁队列,然后将其排队到 async_write 并从另一个队列返回响应,该队列将是 spsc 无锁队列,所有响应都将在其中.

代码:https://gist.github.com/Naseefabu/966d4469980977450f2746db73a43065

c++ networking boost-asio boost-beast
1个回答
0
投票

你的问题要大得多。您的整个写入都会调用 UB,因为根据定义,您将过时的缓冲区传递给

async_write
:

asio::buffer(std::string{message})

这是一个临时字符串对象,根据定义,它会在操作完成之前被破坏。

您通常所做的是将消息移动到 Websocket 本地的队列中:

void send_message(std::string_view message) {
    m_outgoing_messages.emplace_back(message);
    if (m_outgoing_messages.size() == 1) {
        do_write_loop();
    }
}

然后

私人: std::dequestd::string m_outgoing_messages;

void do_write_loop() {
    if (m_outgoing_messages.empty()) {
        return;
    }

    m_ws.async_write(asio::buffer(m_outgoing_messages.front()),
                     [this, self = this->shared_from_this()](error_code ec, size_t) {
                         if (ec) {
                         SPDLOG_ERROR("Failed to send message(s)");
                             SPDLOG_ERROR("    error_code.message: {}", ec.message());
                             SPDLOG_ERROR("    error_code.what: {}", ec.what());
                         } else {
                             m_outgoing_messages.pop_front();
                             do_write_loop();
                         }
                     });
}

还有很多不相关的事情需要改进。也许您想查看频道

至少考虑使用执行器消除与

io_context&
的耦合。

以下是一些可能帮助您实现目标的简化/概括:

住在Coliru

#include <boost/asio.hpp>
#include <boost/beast.hpp>
#include <boost/beast/ssl.hpp>
#include <deque>
#include <fmt/format.h>
#include <iostream>
namespace spdlog {
    class logger {
      public:
        template <typename... Args>
        static constexpr void error(std::string_view level, std::string_view format, Args const&... args) {
            std::cerr << level << "\t" << fmt::format(fmt::runtime(format), args...) << std::endl;
        }
    };
    constexpr logger instance{};
} // namespace spdlog

#define SPDLOG_ERROR(...) spdlog::instance.error("ERR", __VA_ARGS__)
#define SPDLOG_WARN(...) spdlog::instance.error("WRN", __VA_ARGS__)
#define SPDLOG_INFO(...) spdlog::instance.error("INF", __VA_ARGS__)
#define TRACE SPDLOG_INFO("{}:{}", __PRETTY_FUNCTION__, __LINE__)

namespace beast     = boost::beast;
namespace asio      = boost::asio;
namespace ssl       = boost::asio::ssl;
namespace websocket = boost::beast::websocket;
using tcp           = boost::asio::ip::tcp;
using beast::error_code;

template <typename Derived> class Websocket : public std::enable_shared_from_this<Websocket<Derived>> {
  public:
    using Stream          = websocket::stream<ssl::stream<tcp::socket>>;
    using SSLConfigurator = std::function<void(Websocket*, Stream&)>;

    explicit Websocket(asio::any_io_executor ex, SSLConfigurator ssl_handshake_configurator = {})
        : m_resolver{ex}
        , m_ws{ex, m_ssl_context}
        , m_ssl_handshake_configurator(std::move(ssl_handshake_configurator)) {
        TRACE;
        m_ws.read_message_max(0);
        m_ws.auto_fragment(false);

        m_response.reserve(100'000'000);
    }

    virtual ~Websocket() {
        TRACE;
    }

    void start(std::string_view host, std::string_view port, std::string_view target) {
        TRACE;
        asio::post( //
            m_ws.get_executor(),
            [this, self = shared_from_this(), host = std::string{host}, port = std::string{port},
             target = std::string{target}]() mutable { //
                do_start(host, port, target);
            });
    }

    void stop() {
        TRACE;
        asio::post(m_ws.get_executor(), [this, self = shared_from_this()] { do_stop(); });
    }

    void send_message(std::string_view message) {
        return;
        TRACE;
        asio::post(m_ws.get_executor(),
                   [this, self = shared_from_this(), m = std::string{message}]() mutable {
                       do_send_message(std::move(m));
                   });
    }

  protected:
    virtual void send_subscribe_message() {
        TRACE; // send the subscribe message
    };

  private:
    using std::enable_shared_from_this<Websocket<Derived>>::shared_from_this;

    void do_start(std::string_view host, std::string_view port, std::string_view target) { // on the strand
        TRACE;
        m_host   = std::string{host};
        m_target = std::string{target};
        SPDLOG_INFO("Connecting to {}:{}", m_host, port);

        // Look up the domain name
        m_resolver.async_resolve( //
            m_host, port, [this, self = shared_from_this()](error_code ec, tcp::resolver::results_type res) {
                if (!check_fail(ec, "async_start")) {
                    async_connect(std::move(res));
                }
            });
    }

    void do_stop() {
        TRACE;
        if (m_stop_requested.exchange(true))
            return;

        if (m_ws.next_layer().next_layer().is_open()) {
            m_ws.async_close(websocket::close_code::normal, [self = shared_from_this()](error_code) {});
        }
    }

    std::deque<std::string> m_outgoing_messages;

    void do_send_message(std::string_view message) { // on the strand
        TRACE;
        m_outgoing_messages.emplace_back(message);
        if (m_outgoing_messages.size() == 1) {
            do_write_loop();
        }
    }

    void do_write_loop() { // on the strand
        TRACE;
        if (m_outgoing_messages.empty()) {
            return;
        }

        m_ws.async_write( //
            asio::buffer(m_outgoing_messages.front()),
            [this, self = shared_from_this()](error_code ec, size_t) {
                if (!check_fail(ec, "do_write_loop")) {
                    m_outgoing_messages.pop_front();
                    do_write_loop();
                }
            });
    }

    void async_connect(tcp::resolver::results_type result) {
        TRACE;
        if (!SSL_set_tlsext_host_name(m_ws.next_layer().native_handle(), m_host.c_str())) {
            SPDLOG_ERROR("Boost::Beast error: async connect");
            return;
        }

        asio::async_connect( //
            m_ws.next_layer().next_layer(), result,
            [this, self = shared_from_this()](error_code ec, tcp::endpoint) {
                if (check_fail(ec, "async_connect")) {
                    if (!m_stop_requested) {
                        // handle the beast error
                    }
                } else {
                    on_connected();
                }
            });
    }

    void on_connected() {
        TRACE;
        m_ws.next_layer().async_handshake( //
            ssl::stream_base::client, [this, self = shared_from_this()](error_code ec) {
                if (!check_fail(ec, "ssl handshake")) {
                    on_ssl_handshake();
                }
            });
    }

    void on_ssl_handshake() {
        TRACE;
        if (m_ssl_handshake_configurator) {
            m_ssl_handshake_configurator(this, m_ws);
        }

        m_ws.async_handshake(m_host, m_target, [this, self = shared_from_this()](error_code ec) {
            if (!check_fail(ec, "ws handshake")) {
                TRACE;
                send_subscribe_message();
                start_read(ec);
            }
        });
    }

    void start_read(error_code ec) {
        TRACE;
        if (check_fail(ec, "start_read")) {
            return stop();
        }

        m_buffer.prepare(512'000'000);
        m_ws.async_read(                                                    //
            m_buffer,                                                       //
            [this, self = shared_from_this()](error_code ec, size_t size) { //
                on_read(ec, size);
            });
    }

    void on_read(error_code ec, [[maybe_unused]] size_t size) {
        TRACE;
        if (check_fail(ec, "on_read")) {
            SPDLOG_ERROR("    reason: {}", m_ws.reason().reason.data());
            SPDLOG_ERROR("    reason_code: {}", m_ws.reason().code);
            return stop();
        }

        m_response.clear();

        for (auto const& bytes : m_buffer.data()) {
            m_response.append(static_cast<char const*>(bytes.data()), bytes.size());
        }
        m_buffer.consume(m_buffer.size());

        static_cast<Derived*>(this)->handle_message(m_response);

        // restart
        start_read(error_code{});
    }

    bool check_fail(error_code ec, std::string_view task) const {
        if (m_stop_requested)
            return true;

        if (ec) {
            SPDLOG_ERROR("Failed to {}", task);
            SPDLOG_ERROR("    error_code.message: {}", ec.message());
            SPDLOG_ERROR("    error_code.what: {}", ec.what());
        }
        return ec.failed();
    }

    ssl::context        m_ssl_context{ssl::context::sslv23_client};
    tcp::resolver       m_resolver;
    Stream              m_ws;
    beast::multi_buffer m_buffer;
    std::string         m_response;
    std::string         m_host, m_target;
    std::atomic_bool    m_stop_requested{false};

    SSLConfigurator m_ssl_handshake_configurator;
};

class EntryPoint final : public Websocket<EntryPoint> {
  public:
    EntryPoint(asio::any_io_executor ex) : Websocket<EntryPoint>(ex) {}

    void handle_message(std::string_view /*response*/) {
        TRACE;
        // handle the response
    }

    void subscribe() {
        TRACE;
        start("localhost", "1443", "/ws/private/orders/control");
    }
};

static_assert(not std::is_copy_constructible_v<EntryPoint>);
static_assert(not std::is_move_assignable_v<EntryPoint>);
static_assert(not std::is_copy_assignable_v<EntryPoint>);

using namespace std::chrono_literals;
using std::this_thread::sleep_for;

int main() {
    std::cout << "entrypoint" << std::endl;
    asio::thread_pool ioc(1);

    auto ws = std::make_shared<EntryPoint>(make_strand(ioc));
    ws->subscribe();

    for (sleep_for(1s);; sleep_for(100ms)) {
        ws->send_message("pseudo_request_popped_from_queue_pushed_by_another_thread");
    }

    ioc.join();
}

我无法全面测试,因为我无法足够快地设置wss服务器。

© www.soinside.com 2019 - 2024. All rights reserved.