]> git.tdb.fi Git - libs/net.git/commitdiff
Redesign Communicator to support multiple protocols
authorMikko Rasa <tdb@tdb.fi>
Sun, 15 Jan 2023 21:32:10 +0000 (23:32 +0200)
committerMikko Rasa <tdb@tdb.fi>
Sun, 15 Jan 2023 21:32:10 +0000 (23:32 +0200)
This makes it easier for a library to provide a base protocol and allow
applications to add their own protocols on top.  It's also used for the
Communicator's internal handshake protocol.

source/net/communicator.cpp
source/net/communicator.h

index 51202b5f659c07b733e5e93dbb3750142b4b884e..73b10c2881b71c7c5c51302b3d97c521b0c4a2ad 100644 (file)
@@ -8,7 +8,15 @@ namespace {
 
 using namespace Msp::Net;
 
-struct Handshake
+// Sent when a protocol is added locally but isn't known to the other end yet
+struct PrepareProtocol
+{
+       uint64_t hash;
+       uint16_t base;
+};
+
+// Sent to confirm that a protocol is known to both peers
+struct AcceptProtocol
 {
        uint64_t hash;
 };
@@ -20,67 +28,107 @@ public:
        HandshakeProtocol();
 };
 
-HandshakeProtocol::HandshakeProtocol():
-       Protocol(0x7F00)
+HandshakeProtocol::HandshakeProtocol()
 {
-       add<Handshake>(&Handshake::hash);
+       add<PrepareProtocol>(&PrepareProtocol::hash, &PrepareProtocol::base);
+       add<AcceptProtocol>(&AcceptProtocol::hash);
 }
 
+}
 
-class HandshakeReceiver: public PacketReceiver<Handshake>
-{
-private:
-       uint64_t hash = 0;
 
-public:
-       uint64_t get_hash() const { return hash; }
-       void receive(const Handshake &) override;
-};
+namespace Msp {
+namespace Net {
 
-void HandshakeReceiver::receive(const Handshake &shake)
+struct Communicator::Handshake: public PacketReceiver<PrepareProtocol>,
+       public PacketReceiver<AcceptProtocol>
 {
-       hash = shake.hash;
-}
+       Communicator &communicator;
+       HandshakeProtocol protocol;
 
-}
+       Handshake(Communicator &c): communicator(c) { }
 
+       void receive(const PrepareProtocol &) override;
+       void receive(const AcceptProtocol &) override;
+};
 
-namespace Msp {
-namespace Net {
 
-Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r):
+Communicator::Communicator(StreamSocket &s):
        socket(s),
-       protocol(p),
-       receiver(r),
+       handshake(new Handshake(*this)),
        in_buf(new char[buf_size]),
        in_begin(in_buf),
        in_end(in_buf),
        out_buf(new char[buf_size])
 {
        socket.signal_data_available.connect(sigc::mem_fun(this, &Communicator::data_available));
+
+       protocols.emplace_back(0, ref(handshake->protocol), ref(*handshake));
+       if(socket.is_connected())
+               prepare_protocol(protocols.back());
+       else
+               socket.signal_connect_finished.connect(sigc::mem_fun(this, &Communicator::connect_finished));
+}
+
+Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r):
+       Communicator(s)
+{
+       add_protocol(p, r);
 }
 
 Communicator::~Communicator()
 {
        delete[] in_buf;
        delete[] out_buf;
+       delete handshake;
 }
 
-void Communicator::initiate_handshake()
+void Communicator::add_protocol(const Protocol &proto, ReceiverBase &recv)
 {
-       if(handshake_status!=0)
-               throw sequence_error("handshaking already done");
+       if(!good)
+               throw sequence_error("connection aborted");
+
+       unsigned max_id = proto.get_max_packet_id();
+       if(!max_id)
+               throw invalid_argument("Communicator::add_protocol");
 
-       send_handshake();
-       handshake_status = 1;
+       uint64_t hash = proto.get_hash();
+       auto i = find_member(protocols, hash, &ActiveProtocol::hash);
+       if(i==protocols.end())
+       {
+               const ActiveProtocol &last = protocols.back();
+               if(!last.protocol)
+                       throw sequence_error("previous protocol is incomplete");
+               unsigned base = last.base;
+               base += (last.protocol->get_max_packet_id()+0xFF)&~0xFF;
+
+               if(base+max_id>std::numeric_limits<std::uint16_t>::max())
+                       throw invalid_state("Communicator::add_protocol");
+
+               protocols.emplace_back(base, proto, recv);
+
+               if(socket.is_connected() && protocols.front().ready)
+                       prepare_protocol(protocols.back());
+       }
+       else if(!i->protocol)
+       {
+               i->protocol = &proto;
+               i->last = i->base+max_id;
+               i->receiver = &recv;
+               accept_protocol(*i);
+       }
+}
+
+bool Communicator::is_protocol_ready(const Protocol &proto) const
+{
+       auto i = find_member(protocols, &proto, &ActiveProtocol::protocol);
+       return (i!=protocols.end() && i->ready);
 }
 
 void Communicator::send_data(size_t size)
 {
        if(!good)
                throw sequence_error("connection aborted");
-       if(handshake_status!=2)
-               throw sequence_error("handshake incomplete");
 
        try
        {
@@ -95,6 +143,14 @@ void Communicator::send_data(size_t size)
        }
 }
 
+void Communicator::connect_finished(const exception *exc)
+{
+       if(exc)
+               good = false;
+       else
+               prepare_protocol(protocols.front());
+}
+
 void Communicator::data_available()
 {
        if(!good)
@@ -103,31 +159,7 @@ void Communicator::data_available()
        try
        {
                in_end += socket.read(in_end, in_buf+buf_size-in_end);
-
-               bool more = true;
-               while(more)
-               {
-                       if(handshake_status==2)
-                               more = receive_packet(protocol, receiver);
-                       else
-                       {
-                               HandshakeProtocol hsproto;
-                               HandshakeReceiver hsrecv;
-                               if((more = receive_packet(hsproto, hsrecv)))
-                               {
-                                       if(handshake_status==0)
-                                               send_handshake();
-
-                                       if(hsrecv.get_hash()==protocol.get_hash())
-                                       {
-                                               handshake_status = 2;
-                                               signal_handshake_done.emit();
-                                       }
-                                       else
-                                               throw incompatible_protocol("hash mismatch");
-                               }
-                       }
-               }
+               while(receive_packet()) ;
        }
        catch(const exception &e)
        {
@@ -138,37 +170,98 @@ void Communicator::data_available()
        }
 }
 
-bool Communicator::receive_packet(const Protocol &proto, ReceiverBase &recv)
+bool Communicator::receive_packet()
 {
-       int psz = proto.get_packet_size(in_begin, in_end-in_begin);
-       if(psz && psz<=in_end-in_begin)
+       Protocol::PacketHeader header;
+       size_t available = in_end-in_begin;
+       if(handshake->protocol.get_packet_header(header, in_begin, available) && header.length<=available)
        {
+               auto i = lower_bound_member(protocols, header.type, &ActiveProtocol::last);
+               if(i==protocols.end() || header.type<i->base || header.type>i->last)
+                       throw key_error(header.type);
+
                char *pkt = in_begin;
-               in_begin += psz;
-               proto.dispatch(recv, pkt, psz);
+               in_begin += header.length;
+               i->protocol->dispatch(*i->receiver, pkt, header.length, i->base);
                return true;
        }
        else
        {
                if(in_end==in_buf+buf_size)
                {
-                       size_t used = in_end-in_begin;
-                       memmove(in_buf, in_begin, used);
+                       memmove(in_buf, in_begin, available);
                        in_begin = in_buf;
-                       in_end = in_begin+used;
+                       in_end = in_begin+available;
                }
                return false;
        }
 }
 
-void Communicator::send_handshake()
+void Communicator::prepare_protocol(const ActiveProtocol &proto)
+{
+       PrepareProtocol prepare;
+       prepare.hash = proto.hash;
+       prepare.base = proto.base;
+       /* Use send_data() directly because this function is called to prepare the
+       handshake protocol too and send() would fail readiness check. */
+       send_data(handshake->protocol.serialize(prepare, out_buf, buf_size));
+}
+
+void Communicator::accept_protocol(ActiveProtocol &proto)
+{
+       proto.accepted = true;
+
+       AcceptProtocol accept;
+       accept.hash = proto.hash;
+       send_data(handshake->protocol.serialize(accept, out_buf, buf_size));
+}
+
+
+Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, const Protocol &p, ReceiverBase &r):
+       hash(p.get_hash()),
+       base(b),
+       last(base+p.get_max_packet_id()),
+       protocol(&p),
+       receiver(&r)
+{ }
+
+Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, uint64_t h):
+       hash(h),
+       base(b),
+       last(base)
+{ }
+
+
+void Communicator::Handshake::receive(const PrepareProtocol &prepare)
+{
+       auto i = lower_bound_member(communicator.protocols, prepare.base, &ActiveProtocol::base);
+       if(i!=communicator.protocols.end() && i->base==prepare.base)
+               communicator.accept_protocol(*i);
+       else
+               communicator.protocols.emplace(i, prepare.base, prepare.hash);
+}
+
+void Communicator::Handshake::receive(const AcceptProtocol &accept)
 {
-       Handshake shake;
-       shake.hash = protocol.get_hash();
+       auto i = find_member(communicator.protocols, accept.hash, &ActiveProtocol::hash);
+       if(i==communicator.protocols.end())
+               throw key_error(accept.hash);
+
+
+       if(i->ready)
+               return;
 
-       HandshakeProtocol hsproto;
-       size_t size = hsproto.serialize(shake, out_buf, buf_size);
-       socket.write(out_buf, size);
+       i->ready = true;
+       if(!i->accepted)
+               communicator.accept_protocol(*i);
+       if(i->protocol==&protocol)
+       {
+               for(const ActiveProtocol &p: communicator.protocols)
+                       if(!p.ready)
+                               communicator.prepare_protocol(p);
+       }
+       else
+               communicator.signal_protocol_ready.emit(*i->protocol);
 }
 
 } // namespace Net
index e9f646a59a62ba1f0796a36cf57f2d27454d849d..eb9893d861a2ca8eaa3679aeedcddcd6fde24ea4 100644 (file)
@@ -28,14 +28,29 @@ public:
 class MSPNET_API Communicator: public NonCopyable
 {
 public:
-       sigc::signal<void> signal_handshake_done;
+       sigc::signal<void, const Protocol &> signal_protocol_ready;
        sigc::signal<void, const std::exception &> signal_error;
 
 private:
+       struct ActiveProtocol
+       {
+               std::uint64_t hash = 0;
+               std::uint16_t base = 0;
+               std::uint16_t last = 0;
+               bool accepted = false;
+               bool ready = false;
+               const Protocol *protocol = nullptr;
+               ReceiverBase *receiver = nullptr;
+
+               ActiveProtocol(std::uint16_t, const Protocol &, ReceiverBase &);
+               ActiveProtocol(std::uint16_t, std::uint64_t);
+       };
+
+       struct Handshake;
+
        StreamSocket &socket;
-       const Protocol &protocol;
-       ReceiverBase &receiver;
-       int handshake_status = 0;
+       std::vector<ActiveProtocol> protocols;
+       Handshake *handshake = nullptr;
        std::size_t buf_size = 65536;
        char *in_buf = nullptr;
        char *in_begin = nullptr;
@@ -44,11 +59,12 @@ private:
        bool good = true;
 
 public:
+       Communicator(StreamSocket &);
        Communicator(StreamSocket &, const Protocol &, ReceiverBase &);
        ~Communicator();
 
-       void initiate_handshake();
-       bool is_handshake_done() const { return handshake_status==2; }
+       void add_protocol(const Protocol &, ReceiverBase &);
+       bool is_protocol_ready(const Protocol &) const;
 
        template<typename P>
        void send(const P &);
@@ -56,15 +72,23 @@ public:
 private:
        void send_data(std::size_t);
 
+       void connect_finished(const std::exception *);
        void data_available();
-       bool receive_packet(const Protocol &, ReceiverBase &);
-       void send_handshake();
+       bool receive_packet();
+
+       void prepare_protocol(const ActiveProtocol &);
+       void accept_protocol(ActiveProtocol &);
 };
 
 template<typename P>
 void Communicator::send(const P &pkt)
 {
-       send_data(protocol.serialize(pkt, out_buf, buf_size));
+       auto i = find_if(protocols, [](const ActiveProtocol &p){ return p.protocol && p.protocol->has_packet<P>(); });
+       if(i==protocols.end())
+               throw key_error(typeid(P).name());
+       else if(!i->ready)
+               throw sequence_error("protocol not ready");
+       send_data(i->protocol->serialize(pkt, out_buf, buf_size, i->base));
 }
 
 } // namespace Net