]> git.tdb.fi Git - libs/net.git/blobdiff - source/net/communicator.cpp
Add a dynamic receiver class for more flexible packet handling
[libs/net.git] / source / net / communicator.cpp
index c9b277d3748de137381b0df792d5230b0e8e0c62..73b10c2881b71c7c5c51302b3d97c521b0c4a2ad 100644 (file)
@@ -1,13 +1,24 @@
-#include <cstring>
 #include "communicator.h"
+#include <cstring>
+#include "streamsocket.h"
+
+using namespace std;
 
 namespace {
 
 using namespace Msp::Net;
 
-struct Handshake
+// Sent when a protocol is added locally but isn't known to the other end yet
+struct PrepareProtocol
+{
+       uint64_t hash;
+       uint16_t base;
+};
+
+// Sent to confirm that a protocol is known to both peers
+struct AcceptProtocol
 {
-       unsigned hash;
+       uint64_t hash;
 };
 
 
@@ -17,67 +28,127 @@ public:
        HandshakeProtocol();
 };
 
-HandshakeProtocol::HandshakeProtocol():
-       Protocol(0x7F00)
+HandshakeProtocol::HandshakeProtocol()
 {
-       add<Handshake>()(&Handshake::hash);
+       add<PrepareProtocol>(&PrepareProtocol::hash, &PrepareProtocol::base);
+       add<AcceptProtocol>(&AcceptProtocol::hash);
 }
 
+}
 
-class HandshakeReceiver: public PacketReceiver<Handshake>
-{
-private:
-       unsigned hash;
-
-public:
-       HandshakeReceiver();
-       unsigned get_hash() const { return hash; }
-       virtual void receive(const Handshake &);
-};
 
-HandshakeReceiver::HandshakeReceiver():
-       hash(0)
-{ }
+namespace Msp {
+namespace Net {
 
-void HandshakeReceiver::receive(const Handshake &shake)
+struct Communicator::Handshake: public PacketReceiver<PrepareProtocol>,
+       public PacketReceiver<AcceptProtocol>
 {
-       hash = shake.hash;
-}
+       Communicator &communicator;
+       HandshakeProtocol protocol;
 
-}
+       Handshake(Communicator &c): communicator(c) { }
 
+       void receive(const PrepareProtocol &) override;
+       void receive(const AcceptProtocol &) override;
+};
 
-namespace Msp {
-namespace Net {
 
-Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r):
+Communicator::Communicator(StreamSocket &s):
        socket(s),
-       protocol(p),
-       receiver(r),
-       handshake_status(0),
-       buf_size(1024),
+       handshake(new Handshake(*this)),
        in_buf(new char[buf_size]),
        in_begin(in_buf),
        in_end(in_buf),
-       out_buf(new char[buf_size]),
-       good(true)
+       out_buf(new char[buf_size])
 {
        socket.signal_data_available.connect(sigc::mem_fun(this, &Communicator::data_available));
+
+       protocols.emplace_back(0, ref(handshake->protocol), ref(*handshake));
+       if(socket.is_connected())
+               prepare_protocol(protocols.back());
+       else
+               socket.signal_connect_finished.connect(sigc::mem_fun(this, &Communicator::connect_finished));
+}
+
+Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r):
+       Communicator(s)
+{
+       add_protocol(p, r);
 }
 
 Communicator::~Communicator()
 {
        delete[] in_buf;
        delete[] out_buf;
+       delete handshake;
+}
+
+void Communicator::add_protocol(const Protocol &proto, ReceiverBase &recv)
+{
+       if(!good)
+               throw sequence_error("connection aborted");
+
+       unsigned max_id = proto.get_max_packet_id();
+       if(!max_id)
+               throw invalid_argument("Communicator::add_protocol");
+
+       uint64_t hash = proto.get_hash();
+       auto i = find_member(protocols, hash, &ActiveProtocol::hash);
+       if(i==protocols.end())
+       {
+               const ActiveProtocol &last = protocols.back();
+               if(!last.protocol)
+                       throw sequence_error("previous protocol is incomplete");
+               unsigned base = last.base;
+               base += (last.protocol->get_max_packet_id()+0xFF)&~0xFF;
+
+               if(base+max_id>std::numeric_limits<std::uint16_t>::max())
+                       throw invalid_state("Communicator::add_protocol");
+
+               protocols.emplace_back(base, proto, recv);
+
+               if(socket.is_connected() && protocols.front().ready)
+                       prepare_protocol(protocols.back());
+       }
+       else if(!i->protocol)
+       {
+               i->protocol = &proto;
+               i->last = i->base+max_id;
+               i->receiver = &recv;
+               accept_protocol(*i);
+       }
+}
+
+bool Communicator::is_protocol_ready(const Protocol &proto) const
+{
+       auto i = find_member(protocols, &proto, &ActiveProtocol::protocol);
+       return (i!=protocols.end() && i->ready);
 }
 
-void Communicator::initiate_handshake()
+void Communicator::send_data(size_t size)
 {
-       if(handshake_status!=0)
-               throw sequence_error("handshaking already done");
+       if(!good)
+               throw sequence_error("connection aborted");
+
+       try
+       {
+               socket.write(out_buf, size);
+       }
+       catch(const std::exception &e)
+       {
+               good = false;
+               if(signal_error.empty())
+                       throw;
+               signal_error.emit(e);
+       }
+}
 
-       send_handshake();
-       handshake_status = 1;
+void Communicator::connect_finished(const exception *exc)
+{
+       if(exc)
+               good = false;
+       else
+               prepare_protocol(protocols.front());
 }
 
 void Communicator::data_available()
@@ -85,73 +156,112 @@ void Communicator::data_available()
        if(!good)
                return;
 
-       in_end += socket.read(in_end, in_buf+buf_size-in_end);
        try
        {
-               bool more = true;
-               while(more)
-               {
-                       if(handshake_status==2)
-                       {
-                               more = receive_packet(protocol, receiver);
-                       }
-                       else
-                       {
-                               HandshakeProtocol hsproto;
-                               HandshakeReceiver hsrecv;
-                               if((more = receive_packet(hsproto, hsrecv)))
-                               {
-                                       if(hsrecv.get_hash()==protocol.get_hash())
-                                       {
-                                               if(handshake_status==0)
-                                                       send_handshake();
-                                               handshake_status = 2;
-                                               signal_handshake_done.emit();
-                                       }
-                                       else
-                                               good = false;
-                               }
-                       }
-               }
+               in_end += socket.read(in_end, in_buf+buf_size-in_end);
+               while(receive_packet()) ;
        }
-       catch(...)
+       catch(const exception &e)
        {
                good = false;
-               throw;
+               if(signal_error.empty())
+                       throw;
+               signal_error.emit(e);
        }
 }
 
-bool Communicator::receive_packet(const Protocol &proto, ReceiverBase &recv)
+bool Communicator::receive_packet()
 {
-       int psz = proto.get_packet_size(in_begin, in_end-in_begin);
-       if(psz && psz<=in_end-in_begin)
+       Protocol::PacketHeader header;
+       size_t available = in_end-in_begin;
+       if(handshake->protocol.get_packet_header(header, in_begin, available) && header.length<=available)
        {
+               auto i = lower_bound_member(protocols, header.type, &ActiveProtocol::last);
+               if(i==protocols.end() || header.type<i->base || header.type>i->last)
+                       throw key_error(header.type);
+
                char *pkt = in_begin;
-               in_begin += psz;
-               proto.disassemble(recv, pkt, psz);
+               in_begin += header.length;
+               i->protocol->dispatch(*i->receiver, pkt, header.length, i->base);
                return true;
        }
        else
        {
                if(in_end==in_buf+buf_size)
                {
-                       unsigned used = in_end-in_begin;
-                       memmove(in_buf, in_begin, used);
+                       memmove(in_buf, in_begin, available);
                        in_begin = in_buf;
-                       in_end = in_begin+used;
+                       in_end = in_begin+available;
                }
                return false;
        }
 }
 
-void Communicator::send_handshake()
+void Communicator::prepare_protocol(const ActiveProtocol &proto)
+{
+       PrepareProtocol prepare;
+       prepare.hash = proto.hash;
+       prepare.base = proto.base;
+       /* Use send_data() directly because this function is called to prepare the
+       handshake protocol too and send() would fail readiness check. */
+       send_data(handshake->protocol.serialize(prepare, out_buf, buf_size));
+}
+
+void Communicator::accept_protocol(ActiveProtocol &proto)
 {
-       Handshake shake;
-       shake.hash = protocol.get_hash();
+       proto.accepted = true;
 
-       HandshakeProtocol hsproto;
-       unsigned size = hsproto.assemble(shake, out_buf, buf_size);
-       socket.write(out_buf, size);
+       AcceptProtocol accept;
+       accept.hash = proto.hash;
+       send_data(handshake->protocol.serialize(accept, out_buf, buf_size));
+}
+
+
+Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, const Protocol &p, ReceiverBase &r):
+       hash(p.get_hash()),
+       base(b),
+       last(base+p.get_max_packet_id()),
+       protocol(&p),
+       receiver(&r)
+{ }
+
+Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, uint64_t h):
+       hash(h),
+       base(b),
+       last(base)
+{ }
+
+
+void Communicator::Handshake::receive(const PrepareProtocol &prepare)
+{
+       auto i = lower_bound_member(communicator.protocols, prepare.base, &ActiveProtocol::base);
+       if(i!=communicator.protocols.end() && i->base==prepare.base)
+               communicator.accept_protocol(*i);
+       else
+               communicator.protocols.emplace(i, prepare.base, prepare.hash);
+}
+
+void Communicator::Handshake::receive(const AcceptProtocol &accept)
+{
+       auto i = find_member(communicator.protocols, accept.hash, &ActiveProtocol::hash);
+       if(i==communicator.protocols.end())
+               throw key_error(accept.hash);
+
+
+       if(i->ready)
+               return;
+
+       i->ready = true;
+       if(!i->accepted)
+               communicator.accept_protocol(*i);
+       if(i->protocol==&protocol)
+       {
+               for(const ActiveProtocol &p: communicator.protocols)
+                       if(!p.ready)
+                               communicator.prepare_protocol(p);
+       }
+       else
+               communicator.signal_protocol_ready.emit(*i->protocol);
 }
 
 } // namespace Net