X-Git-Url: http://git.tdb.fi/?p=libs%2Fnet.git;a=blobdiff_plain;f=source%2Fnet%2Fcommunicator.cpp;h=22a79631d9a214741fdbf71721d8e6322986ea4e;hp=b30a8e965c388c8457b26ae454ce6c399fe5f6fc;hb=HEAD;hpb=3c2a877580e234df5fcbe06bf2850cd29f875e28 diff --git a/source/net/communicator.cpp b/source/net/communicator.cpp index b30a8e9..73b10c2 100644 --- a/source/net/communicator.cpp +++ b/source/net/communicator.cpp @@ -1,5 +1,6 @@ -#include #include "communicator.h" +#include +#include "streamsocket.h" using namespace std; @@ -7,9 +8,17 @@ namespace { using namespace Msp::Net; -struct Handshake +// Sent when a protocol is added locally but isn't known to the other end yet +struct PrepareProtocol +{ + uint64_t hash; + uint16_t base; +}; + +// Sent to confirm that a protocol is known to both peers +struct AcceptProtocol { - unsigned hash; + uint64_t hash; }; @@ -19,67 +28,127 @@ public: HandshakeProtocol(); }; -HandshakeProtocol::HandshakeProtocol(): - Protocol(0x7F00) +HandshakeProtocol::HandshakeProtocol() { - add()(&Handshake::hash); + add(&PrepareProtocol::hash, &PrepareProtocol::base); + add(&AcceptProtocol::hash); } +} -class HandshakeReceiver: public PacketReceiver -{ -private: - unsigned hash; - -public: - HandshakeReceiver(); - unsigned get_hash() const { return hash; } - virtual void receive(const Handshake &); -}; -HandshakeReceiver::HandshakeReceiver(): - hash(0) -{ } +namespace Msp { +namespace Net { -void HandshakeReceiver::receive(const Handshake &shake) +struct Communicator::Handshake: public PacketReceiver, + public PacketReceiver { - hash = shake.hash; -} + Communicator &communicator; + HandshakeProtocol protocol; -} + Handshake(Communicator &c): communicator(c) { } + void receive(const PrepareProtocol &) override; + void receive(const AcceptProtocol &) override; +}; -namespace Msp { -namespace Net { -Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r): +Communicator::Communicator(StreamSocket &s): socket(s), - protocol(p), - receiver(r), - handshake_status(0), - buf_size(1024), + handshake(new Handshake(*this)), in_buf(new char[buf_size]), in_begin(in_buf), in_end(in_buf), - out_buf(new char[buf_size]), - good(true) + out_buf(new char[buf_size]) { socket.signal_data_available.connect(sigc::mem_fun(this, &Communicator::data_available)); + + protocols.emplace_back(0, ref(handshake->protocol), ref(*handshake)); + if(socket.is_connected()) + prepare_protocol(protocols.back()); + else + socket.signal_connect_finished.connect(sigc::mem_fun(this, &Communicator::connect_finished)); +} + +Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r): + Communicator(s) +{ + add_protocol(p, r); } Communicator::~Communicator() { delete[] in_buf; delete[] out_buf; + delete handshake; } -void Communicator::initiate_handshake() +void Communicator::add_protocol(const Protocol &proto, ReceiverBase &recv) { - if(handshake_status!=0) - throw sequence_error("handshaking already done"); + if(!good) + throw sequence_error("connection aborted"); + + unsigned max_id = proto.get_max_packet_id(); + if(!max_id) + throw invalid_argument("Communicator::add_protocol"); - send_handshake(); - handshake_status = 1; + uint64_t hash = proto.get_hash(); + auto i = find_member(protocols, hash, &ActiveProtocol::hash); + if(i==protocols.end()) + { + const ActiveProtocol &last = protocols.back(); + if(!last.protocol) + throw sequence_error("previous protocol is incomplete"); + unsigned base = last.base; + base += (last.protocol->get_max_packet_id()+0xFF)&~0xFF; + + if(base+max_id>std::numeric_limits::max()) + throw invalid_state("Communicator::add_protocol"); + + protocols.emplace_back(base, proto, recv); + + if(socket.is_connected() && protocols.front().ready) + prepare_protocol(protocols.back()); + } + else if(!i->protocol) + { + i->protocol = &proto; + i->last = i->base+max_id; + i->receiver = &recv; + accept_protocol(*i); + } +} + +bool Communicator::is_protocol_ready(const Protocol &proto) const +{ + auto i = find_member(protocols, &proto, &ActiveProtocol::protocol); + return (i!=protocols.end() && i->ready); +} + +void Communicator::send_data(size_t size) +{ + if(!good) + throw sequence_error("connection aborted"); + + try + { + socket.write(out_buf, size); + } + catch(const std::exception &e) + { + good = false; + if(signal_error.empty()) + throw; + signal_error.emit(e); + } +} + +void Communicator::connect_finished(const exception *exc) +{ + if(exc) + good = false; + else + prepare_protocol(protocols.front()); } void Communicator::data_available() @@ -90,32 +159,7 @@ void Communicator::data_available() try { in_end += socket.read(in_end, in_buf+buf_size-in_end); - - bool more = true; - while(more) - { - if(handshake_status==2) - { - more = receive_packet(protocol, receiver); - } - else - { - HandshakeProtocol hsproto; - HandshakeReceiver hsrecv; - if((more = receive_packet(hsproto, hsrecv))) - { - if(hsrecv.get_hash()==protocol.get_hash()) - { - if(handshake_status==0) - send_handshake(); - handshake_status = 2; - signal_handshake_done.emit(); - } - else - good = false; - } - } - } + while(receive_packet()) ; } catch(const exception &e) { @@ -126,37 +170,98 @@ void Communicator::data_available() } } -bool Communicator::receive_packet(const Protocol &proto, ReceiverBase &recv) +bool Communicator::receive_packet() { - int psz = proto.get_packet_size(in_begin, in_end-in_begin); - if(psz && psz<=in_end-in_begin) + Protocol::PacketHeader header; + size_t available = in_end-in_begin; + if(handshake->protocol.get_packet_header(header, in_begin, available) && header.length<=available) { + auto i = lower_bound_member(protocols, header.type, &ActiveProtocol::last); + if(i==protocols.end() || header.typebase || header.type>i->last) + throw key_error(header.type); + char *pkt = in_begin; - in_begin += psz; - proto.dispatch(recv, pkt, psz); + in_begin += header.length; + i->protocol->dispatch(*i->receiver, pkt, header.length, i->base); return true; } else { if(in_end==in_buf+buf_size) { - unsigned used = in_end-in_begin; - memmove(in_buf, in_begin, used); + memmove(in_buf, in_begin, available); in_begin = in_buf; - in_end = in_begin+used; + in_end = in_begin+available; } return false; } } -void Communicator::send_handshake() +void Communicator::prepare_protocol(const ActiveProtocol &proto) +{ + PrepareProtocol prepare; + prepare.hash = proto.hash; + prepare.base = proto.base; + /* Use send_data() directly because this function is called to prepare the + handshake protocol too and send() would fail readiness check. */ + send_data(handshake->protocol.serialize(prepare, out_buf, buf_size)); +} + +void Communicator::accept_protocol(ActiveProtocol &proto) { - Handshake shake; - shake.hash = protocol.get_hash(); + proto.accepted = true; - HandshakeProtocol hsproto; - unsigned size = hsproto.serialize(shake, out_buf, buf_size); - socket.write(out_buf, size); + AcceptProtocol accept; + accept.hash = proto.hash; + send_data(handshake->protocol.serialize(accept, out_buf, buf_size)); +} + + +Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, const Protocol &p, ReceiverBase &r): + hash(p.get_hash()), + base(b), + last(base+p.get_max_packet_id()), + protocol(&p), + receiver(&r) +{ } + +Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, uint64_t h): + hash(h), + base(b), + last(base) +{ } + + +void Communicator::Handshake::receive(const PrepareProtocol &prepare) +{ + auto i = lower_bound_member(communicator.protocols, prepare.base, &ActiveProtocol::base); + if(i!=communicator.protocols.end() && i->base==prepare.base) + communicator.accept_protocol(*i); + else + communicator.protocols.emplace(i, prepare.base, prepare.hash); +} + +void Communicator::Handshake::receive(const AcceptProtocol &accept) +{ + auto i = find_member(communicator.protocols, accept.hash, &ActiveProtocol::hash); + if(i==communicator.protocols.end()) + throw key_error(accept.hash); + + + if(i->ready) + return; + + i->ready = true; + if(!i->accepted) + communicator.accept_protocol(*i); + if(i->protocol==&protocol) + { + for(const ActiveProtocol &p: communicator.protocols) + if(!p.ready) + communicator.prepare_protocol(p); + } + else + communicator.signal_protocol_ready.emit(*i->protocol); } } // namespace Net