using namespace Msp::Net;
-struct Handshake
+// Sent when a protocol is added locally but isn't known to the other end yet
+struct PrepareProtocol
+{
+ uint64_t hash;
+ uint16_t base;
+};
+
+// Sent to confirm that a protocol is known to both peers
+struct AcceptProtocol
{
uint64_t hash;
};
HandshakeProtocol();
};
-HandshakeProtocol::HandshakeProtocol():
- Protocol(0x7F00)
+HandshakeProtocol::HandshakeProtocol()
{
- add<Handshake>(&Handshake::hash);
+ add<PrepareProtocol>(&PrepareProtocol::hash, &PrepareProtocol::base);
+ add<AcceptProtocol>(&AcceptProtocol::hash);
}
+}
-class HandshakeReceiver: public PacketReceiver<Handshake>
-{
-private:
- uint64_t hash = 0;
-public:
- uint64_t get_hash() const { return hash; }
- void receive(const Handshake &) override;
-};
+namespace Msp {
+namespace Net {
-void HandshakeReceiver::receive(const Handshake &shake)
+struct Communicator::Handshake: public PacketReceiver<PrepareProtocol>,
+ public PacketReceiver<AcceptProtocol>
{
- hash = shake.hash;
-}
+ Communicator &communicator;
+ HandshakeProtocol protocol;
-}
+ Handshake(Communicator &c): communicator(c) { }
+ void receive(const PrepareProtocol &) override;
+ void receive(const AcceptProtocol &) override;
+};
-namespace Msp {
-namespace Net {
-Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r):
+Communicator::Communicator(StreamSocket &s):
socket(s),
- protocol(p),
- receiver(r),
+ handshake(new Handshake(*this)),
in_buf(new char[buf_size]),
in_begin(in_buf),
in_end(in_buf),
out_buf(new char[buf_size])
{
socket.signal_data_available.connect(sigc::mem_fun(this, &Communicator::data_available));
+
+ protocols.emplace_back(0, ref(handshake->protocol), ref(*handshake));
+ if(socket.is_connected())
+ prepare_protocol(protocols.back());
+ else
+ socket.signal_connect_finished.connect(sigc::mem_fun(this, &Communicator::connect_finished));
+}
+
+Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r):
+ Communicator(s)
+{
+ add_protocol(p, r);
}
Communicator::~Communicator()
{
delete[] in_buf;
delete[] out_buf;
+ delete handshake;
}
-void Communicator::initiate_handshake()
+void Communicator::add_protocol(const Protocol &proto, ReceiverBase &recv)
{
- if(handshake_status!=0)
- throw sequence_error("handshaking already done");
+ if(!good)
+ throw sequence_error("connection aborted");
+
+ unsigned max_id = proto.get_max_packet_id();
+ if(!max_id)
+ throw invalid_argument("Communicator::add_protocol");
- send_handshake();
- handshake_status = 1;
+ uint64_t hash = proto.get_hash();
+ auto i = find_member(protocols, hash, &ActiveProtocol::hash);
+ if(i==protocols.end())
+ {
+ const ActiveProtocol &last = protocols.back();
+ if(!last.protocol)
+ throw sequence_error("previous protocol is incomplete");
+ unsigned base = last.base;
+ base += (last.protocol->get_max_packet_id()+0xFF)&~0xFF;
+
+ if(base+max_id>std::numeric_limits<std::uint16_t>::max())
+ throw invalid_state("Communicator::add_protocol");
+
+ protocols.emplace_back(base, proto, recv);
+
+ if(socket.is_connected() && protocols.front().ready)
+ prepare_protocol(protocols.back());
+ }
+ else if(!i->protocol)
+ {
+ i->protocol = &proto;
+ i->last = i->base+max_id;
+ i->receiver = &recv;
+ accept_protocol(*i);
+ }
+}
+
+bool Communicator::is_protocol_ready(const Protocol &proto) const
+{
+ auto i = find_member(protocols, &proto, &ActiveProtocol::protocol);
+ return (i!=protocols.end() && i->ready);
}
void Communicator::send_data(size_t size)
{
if(!good)
throw sequence_error("connection aborted");
- if(handshake_status!=2)
- throw sequence_error("handshake incomplete");
try
{
}
}
+void Communicator::connect_finished(const exception *exc)
+{
+ if(exc)
+ good = false;
+ else
+ prepare_protocol(protocols.front());
+}
+
void Communicator::data_available()
{
if(!good)
try
{
in_end += socket.read(in_end, in_buf+buf_size-in_end);
-
- bool more = true;
- while(more)
- {
- if(handshake_status==2)
- more = receive_packet(protocol, receiver);
- else
- {
- HandshakeProtocol hsproto;
- HandshakeReceiver hsrecv;
- if((more = receive_packet(hsproto, hsrecv)))
- {
- if(handshake_status==0)
- send_handshake();
-
- if(hsrecv.get_hash()==protocol.get_hash())
- {
- handshake_status = 2;
- signal_handshake_done.emit();
- }
- else
- throw incompatible_protocol("hash mismatch");
- }
- }
- }
+ while(receive_packet()) ;
}
catch(const exception &e)
{
}
}
-bool Communicator::receive_packet(const Protocol &proto, ReceiverBase &recv)
+bool Communicator::receive_packet()
{
- int psz = proto.get_packet_size(in_begin, in_end-in_begin);
- if(psz && psz<=in_end-in_begin)
+ Protocol::PacketHeader header;
+ size_t available = in_end-in_begin;
+ if(handshake->protocol.get_packet_header(header, in_begin, available) && header.length<=available)
{
+ auto i = lower_bound_member(protocols, header.type, &ActiveProtocol::last);
+ if(i==protocols.end() || header.type<i->base || header.type>i->last)
+ throw key_error(header.type);
+
char *pkt = in_begin;
- in_begin += psz;
- proto.dispatch(recv, pkt, psz);
+ in_begin += header.length;
+ i->protocol->dispatch(*i->receiver, pkt, header.length, i->base);
return true;
}
else
{
if(in_end==in_buf+buf_size)
{
- size_t used = in_end-in_begin;
- memmove(in_buf, in_begin, used);
+ memmove(in_buf, in_begin, available);
in_begin = in_buf;
- in_end = in_begin+used;
+ in_end = in_begin+available;
}
return false;
}
}
-void Communicator::send_handshake()
+void Communicator::prepare_protocol(const ActiveProtocol &proto)
+{
+ PrepareProtocol prepare;
+ prepare.hash = proto.hash;
+ prepare.base = proto.base;
+ /* Use send_data() directly because this function is called to prepare the
+ handshake protocol too and send() would fail readiness check. */
+ send_data(handshake->protocol.serialize(prepare, out_buf, buf_size));
+}
+
+void Communicator::accept_protocol(ActiveProtocol &proto)
+{
+ proto.accepted = true;
+
+ AcceptProtocol accept;
+ accept.hash = proto.hash;
+ send_data(handshake->protocol.serialize(accept, out_buf, buf_size));
+}
+
+
+Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, const Protocol &p, ReceiverBase &r):
+ hash(p.get_hash()),
+ base(b),
+ last(base+p.get_max_packet_id()),
+ protocol(&p),
+ receiver(&r)
+{ }
+
+Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, uint64_t h):
+ hash(h),
+ base(b),
+ last(base)
+{ }
+
+
+void Communicator::Handshake::receive(const PrepareProtocol &prepare)
+{
+ auto i = lower_bound_member(communicator.protocols, prepare.base, &ActiveProtocol::base);
+ if(i!=communicator.protocols.end() && i->base==prepare.base)
+ communicator.accept_protocol(*i);
+ else
+ communicator.protocols.emplace(i, prepare.base, prepare.hash);
+}
+
+void Communicator::Handshake::receive(const AcceptProtocol &accept)
{
- Handshake shake;
- shake.hash = protocol.get_hash();
+ auto i = find_member(communicator.protocols, accept.hash, &ActiveProtocol::hash);
+ if(i==communicator.protocols.end())
+ throw key_error(accept.hash);
+
+
+ if(i->ready)
+ return;
- HandshakeProtocol hsproto;
- size_t size = hsproto.serialize(shake, out_buf, buf_size);
- socket.write(out_buf, size);
+ i->ready = true;
+ if(!i->accepted)
+ communicator.accept_protocol(*i);
+ if(i->protocol==&protocol)
+ {
+ for(const ActiveProtocol &p: communicator.protocols)
+ if(!p.ready)
+ communicator.prepare_protocol(p);
+ }
+ else
+ communicator.signal_protocol_ready.emit(*i->protocol);
}
} // namespace Net
class MSPNET_API Communicator: public NonCopyable
{
public:
- sigc::signal<void> signal_handshake_done;
+ sigc::signal<void, const Protocol &> signal_protocol_ready;
sigc::signal<void, const std::exception &> signal_error;
private:
+ struct ActiveProtocol
+ {
+ std::uint64_t hash = 0;
+ std::uint16_t base = 0;
+ std::uint16_t last = 0;
+ bool accepted = false;
+ bool ready = false;
+ const Protocol *protocol = nullptr;
+ ReceiverBase *receiver = nullptr;
+
+ ActiveProtocol(std::uint16_t, const Protocol &, ReceiverBase &);
+ ActiveProtocol(std::uint16_t, std::uint64_t);
+ };
+
+ struct Handshake;
+
StreamSocket &socket;
- const Protocol &protocol;
- ReceiverBase &receiver;
- int handshake_status = 0;
+ std::vector<ActiveProtocol> protocols;
+ Handshake *handshake = nullptr;
std::size_t buf_size = 65536;
char *in_buf = nullptr;
char *in_begin = nullptr;
bool good = true;
public:
+ Communicator(StreamSocket &);
Communicator(StreamSocket &, const Protocol &, ReceiverBase &);
~Communicator();
- void initiate_handshake();
- bool is_handshake_done() const { return handshake_status==2; }
+ void add_protocol(const Protocol &, ReceiverBase &);
+ bool is_protocol_ready(const Protocol &) const;
template<typename P>
void send(const P &);
private:
void send_data(std::size_t);
+ void connect_finished(const std::exception *);
void data_available();
- bool receive_packet(const Protocol &, ReceiverBase &);
- void send_handshake();
+ bool receive_packet();
+
+ void prepare_protocol(const ActiveProtocol &);
+ void accept_protocol(ActiveProtocol &);
};
template<typename P>
void Communicator::send(const P &pkt)
{
- send_data(protocol.serialize(pkt, out_buf, buf_size));
+ auto i = find_if(protocols, [](const ActiveProtocol &p){ return p.protocol && p.protocol->has_packet<P>(); });
+ if(i==protocols.end())
+ throw key_error(typeid(P).name());
+ else if(!i->ready)
+ throw sequence_error("protocol not ready");
+ send_data(i->protocol->serialize(pkt, out_buf, buf_size, i->base));
}
} // namespace Net