From f17a55dc7fc44d1516db445550f55ed31e7534fa Mon Sep 17 00:00:00 2001 From: Mikko Rasa Date: Sun, 15 Jan 2023 23:32:10 +0200 Subject: [PATCH] Redesign Communicator to support multiple protocols This makes it easier for a library to provide a base protocol and allow applications to add their own protocols on top. It's also used for the Communicator's internal handshake protocol. --- source/net/communicator.cpp | 227 +++++++++++++++++++++++++----------- source/net/communicator.h | 42 +++++-- 2 files changed, 193 insertions(+), 76 deletions(-) diff --git a/source/net/communicator.cpp b/source/net/communicator.cpp index 51202b5..73b10c2 100644 --- a/source/net/communicator.cpp +++ b/source/net/communicator.cpp @@ -8,7 +8,15 @@ namespace { using namespace Msp::Net; -struct Handshake +// Sent when a protocol is added locally but isn't known to the other end yet +struct PrepareProtocol +{ + uint64_t hash; + uint16_t base; +}; + +// Sent to confirm that a protocol is known to both peers +struct AcceptProtocol { uint64_t hash; }; @@ -20,67 +28,107 @@ public: HandshakeProtocol(); }; -HandshakeProtocol::HandshakeProtocol(): - Protocol(0x7F00) +HandshakeProtocol::HandshakeProtocol() { - add(&Handshake::hash); + add(&PrepareProtocol::hash, &PrepareProtocol::base); + add(&AcceptProtocol::hash); } +} -class HandshakeReceiver: public PacketReceiver -{ -private: - uint64_t hash = 0; -public: - uint64_t get_hash() const { return hash; } - void receive(const Handshake &) override; -}; +namespace Msp { +namespace Net { -void HandshakeReceiver::receive(const Handshake &shake) +struct Communicator::Handshake: public PacketReceiver, + public PacketReceiver { - hash = shake.hash; -} + Communicator &communicator; + HandshakeProtocol protocol; -} + Handshake(Communicator &c): communicator(c) { } + void receive(const PrepareProtocol &) override; + void receive(const AcceptProtocol &) override; +}; -namespace Msp { -namespace Net { -Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r): +Communicator::Communicator(StreamSocket &s): socket(s), - protocol(p), - receiver(r), + handshake(new Handshake(*this)), in_buf(new char[buf_size]), in_begin(in_buf), in_end(in_buf), out_buf(new char[buf_size]) { socket.signal_data_available.connect(sigc::mem_fun(this, &Communicator::data_available)); + + protocols.emplace_back(0, ref(handshake->protocol), ref(*handshake)); + if(socket.is_connected()) + prepare_protocol(protocols.back()); + else + socket.signal_connect_finished.connect(sigc::mem_fun(this, &Communicator::connect_finished)); +} + +Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r): + Communicator(s) +{ + add_protocol(p, r); } Communicator::~Communicator() { delete[] in_buf; delete[] out_buf; + delete handshake; } -void Communicator::initiate_handshake() +void Communicator::add_protocol(const Protocol &proto, ReceiverBase &recv) { - if(handshake_status!=0) - throw sequence_error("handshaking already done"); + if(!good) + throw sequence_error("connection aborted"); + + unsigned max_id = proto.get_max_packet_id(); + if(!max_id) + throw invalid_argument("Communicator::add_protocol"); - send_handshake(); - handshake_status = 1; + uint64_t hash = proto.get_hash(); + auto i = find_member(protocols, hash, &ActiveProtocol::hash); + if(i==protocols.end()) + { + const ActiveProtocol &last = protocols.back(); + if(!last.protocol) + throw sequence_error("previous protocol is incomplete"); + unsigned base = last.base; + base += (last.protocol->get_max_packet_id()+0xFF)&~0xFF; + + if(base+max_id>std::numeric_limits::max()) + throw invalid_state("Communicator::add_protocol"); + + protocols.emplace_back(base, proto, recv); + + if(socket.is_connected() && protocols.front().ready) + prepare_protocol(protocols.back()); + } + else if(!i->protocol) + { + i->protocol = &proto; + i->last = i->base+max_id; + i->receiver = &recv; + accept_protocol(*i); + } +} + +bool Communicator::is_protocol_ready(const Protocol &proto) const +{ + auto i = find_member(protocols, &proto, &ActiveProtocol::protocol); + return (i!=protocols.end() && i->ready); } void Communicator::send_data(size_t size) { if(!good) throw sequence_error("connection aborted"); - if(handshake_status!=2) - throw sequence_error("handshake incomplete"); try { @@ -95,6 +143,14 @@ void Communicator::send_data(size_t size) } } +void Communicator::connect_finished(const exception *exc) +{ + if(exc) + good = false; + else + prepare_protocol(protocols.front()); +} + void Communicator::data_available() { if(!good) @@ -103,31 +159,7 @@ void Communicator::data_available() try { in_end += socket.read(in_end, in_buf+buf_size-in_end); - - bool more = true; - while(more) - { - if(handshake_status==2) - more = receive_packet(protocol, receiver); - else - { - HandshakeProtocol hsproto; - HandshakeReceiver hsrecv; - if((more = receive_packet(hsproto, hsrecv))) - { - if(handshake_status==0) - send_handshake(); - - if(hsrecv.get_hash()==protocol.get_hash()) - { - handshake_status = 2; - signal_handshake_done.emit(); - } - else - throw incompatible_protocol("hash mismatch"); - } - } - } + while(receive_packet()) ; } catch(const exception &e) { @@ -138,37 +170,98 @@ void Communicator::data_available() } } -bool Communicator::receive_packet(const Protocol &proto, ReceiverBase &recv) +bool Communicator::receive_packet() { - int psz = proto.get_packet_size(in_begin, in_end-in_begin); - if(psz && psz<=in_end-in_begin) + Protocol::PacketHeader header; + size_t available = in_end-in_begin; + if(handshake->protocol.get_packet_header(header, in_begin, available) && header.length<=available) { + auto i = lower_bound_member(protocols, header.type, &ActiveProtocol::last); + if(i==protocols.end() || header.typebase || header.type>i->last) + throw key_error(header.type); + char *pkt = in_begin; - in_begin += psz; - proto.dispatch(recv, pkt, psz); + in_begin += header.length; + i->protocol->dispatch(*i->receiver, pkt, header.length, i->base); return true; } else { if(in_end==in_buf+buf_size) { - size_t used = in_end-in_begin; - memmove(in_buf, in_begin, used); + memmove(in_buf, in_begin, available); in_begin = in_buf; - in_end = in_begin+used; + in_end = in_begin+available; } return false; } } -void Communicator::send_handshake() +void Communicator::prepare_protocol(const ActiveProtocol &proto) +{ + PrepareProtocol prepare; + prepare.hash = proto.hash; + prepare.base = proto.base; + /* Use send_data() directly because this function is called to prepare the + handshake protocol too and send() would fail readiness check. */ + send_data(handshake->protocol.serialize(prepare, out_buf, buf_size)); +} + +void Communicator::accept_protocol(ActiveProtocol &proto) +{ + proto.accepted = true; + + AcceptProtocol accept; + accept.hash = proto.hash; + send_data(handshake->protocol.serialize(accept, out_buf, buf_size)); +} + + +Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, const Protocol &p, ReceiverBase &r): + hash(p.get_hash()), + base(b), + last(base+p.get_max_packet_id()), + protocol(&p), + receiver(&r) +{ } + +Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, uint64_t h): + hash(h), + base(b), + last(base) +{ } + + +void Communicator::Handshake::receive(const PrepareProtocol &prepare) +{ + auto i = lower_bound_member(communicator.protocols, prepare.base, &ActiveProtocol::base); + if(i!=communicator.protocols.end() && i->base==prepare.base) + communicator.accept_protocol(*i); + else + communicator.protocols.emplace(i, prepare.base, prepare.hash); +} + +void Communicator::Handshake::receive(const AcceptProtocol &accept) { - Handshake shake; - shake.hash = protocol.get_hash(); + auto i = find_member(communicator.protocols, accept.hash, &ActiveProtocol::hash); + if(i==communicator.protocols.end()) + throw key_error(accept.hash); + + + if(i->ready) + return; - HandshakeProtocol hsproto; - size_t size = hsproto.serialize(shake, out_buf, buf_size); - socket.write(out_buf, size); + i->ready = true; + if(!i->accepted) + communicator.accept_protocol(*i); + if(i->protocol==&protocol) + { + for(const ActiveProtocol &p: communicator.protocols) + if(!p.ready) + communicator.prepare_protocol(p); + } + else + communicator.signal_protocol_ready.emit(*i->protocol); } } // namespace Net diff --git a/source/net/communicator.h b/source/net/communicator.h index e9f646a..eb9893d 100644 --- a/source/net/communicator.h +++ b/source/net/communicator.h @@ -28,14 +28,29 @@ public: class MSPNET_API Communicator: public NonCopyable { public: - sigc::signal signal_handshake_done; + sigc::signal signal_protocol_ready; sigc::signal signal_error; private: + struct ActiveProtocol + { + std::uint64_t hash = 0; + std::uint16_t base = 0; + std::uint16_t last = 0; + bool accepted = false; + bool ready = false; + const Protocol *protocol = nullptr; + ReceiverBase *receiver = nullptr; + + ActiveProtocol(std::uint16_t, const Protocol &, ReceiverBase &); + ActiveProtocol(std::uint16_t, std::uint64_t); + }; + + struct Handshake; + StreamSocket &socket; - const Protocol &protocol; - ReceiverBase &receiver; - int handshake_status = 0; + std::vector protocols; + Handshake *handshake = nullptr; std::size_t buf_size = 65536; char *in_buf = nullptr; char *in_begin = nullptr; @@ -44,11 +59,12 @@ private: bool good = true; public: + Communicator(StreamSocket &); Communicator(StreamSocket &, const Protocol &, ReceiverBase &); ~Communicator(); - void initiate_handshake(); - bool is_handshake_done() const { return handshake_status==2; } + void add_protocol(const Protocol &, ReceiverBase &); + bool is_protocol_ready(const Protocol &) const; template void send(const P &); @@ -56,15 +72,23 @@ public: private: void send_data(std::size_t); + void connect_finished(const std::exception *); void data_available(); - bool receive_packet(const Protocol &, ReceiverBase &); - void send_handshake(); + bool receive_packet(); + + void prepare_protocol(const ActiveProtocol &); + void accept_protocol(ActiveProtocol &); }; template void Communicator::send(const P &pkt) { - send_data(protocol.serialize(pkt, out_buf, buf_size)); + auto i = find_if(protocols, [](const ActiveProtocol &p){ return p.protocol && p.protocol->has_packet

(); }); + if(i==protocols.end()) + throw key_error(typeid(P).name()); + else if(!i->ready) + throw sequence_error("protocol not ready"); + send_data(i->protocol->serialize(pkt, out_buf, buf_size, i->base)); } } // namespace Net -- 2.45.2