1 #include "communicator.h"
3 #include "streamsocket.h"
9 using namespace Msp::Net;
11 // Sent when a protocol is added locally but isn't known to the other end yet
// The hash and base members are filled in by Communicator::prepare_protocol()
// and read back in Communicator::Handshake::receive().
12 struct PrepareProtocol
18 // Sent to confirm that a protocol is known to both peers
// Protocol through which PrepareProtocol and AcceptProtocol packets are
// serialized (see Communicator::prepare_protocol and accept_protocol).
25 class HandshakeProtocol: public Protocol
// Registers the two handshake packet types together with the members that
// make up each packet (hash and base for PrepareProtocol, hash only for
// AcceptProtocol).
31 HandshakeProtocol::HandshakeProtocol()
33 add<PrepareProtocol>(&PrepareProtocol::hash, &PrepareProtocol::base);
34 add<AcceptProtocol>(&AcceptProtocol::hash);
// Receiver for both handshake packet types.  Holds a reference back to the
// owning Communicator and the HandshakeProtocol instance used to serialize
// and parse handshake packets.
43 struct Communicator::Handshake: public PacketReceiver<PrepareProtocol>,
44 public PacketReceiver<AcceptProtocol>
46 Communicator &communicator;
47 HandshakeProtocol protocol;
49 Handshake(Communicator &c): communicator(c) { }
// Handlers for incoming handshake packets; defined near the end of the file.
51 void receive(const PrepareProtocol &) override;
52 void receive(const AcceptProtocol &) override;
// Creates a communicator on socket s: allocates the handshake state and the
// input/output buffers, hooks up the socket signals and registers the
// handshake protocol as the first entry of protocols, at base 0.
56 Communicator::Communicator(StreamSocket &s):
58 handshake(new Handshake(*this)),
// Both buffers are buf_size bytes, allocated with new[].
59 in_buf(new char[buf_size]),
62 out_buf(new char[buf_size])
64 socket.signal_data_available.connect(sigc::mem_fun(this, &Communicator::data_available));
// The handshake object acts as both the protocol holder and its receiver.
66 protocols.emplace_back(0, ref(handshake->protocol), ref(*handshake));
// On an already connected socket the handshake protocol is announced at once.
67 if(socket.is_connected())
68 prepare_protocol(protocols.back());
// NOTE(review): the line between the previous statement and this one is not
// visible here (presumably an else); connect_finished() then announces the
// handshake protocol once the connection completes -- confirm.
70 socket.signal_connect_finished.connect(sigc::mem_fun(this, &Communicator::connect_finished));
// Overload that additionally takes a protocol and its receiver.
// NOTE(review): initializer list and body are not visible in this extract;
// presumably delegates to the single-argument constructor and then calls
// add_protocol(p, r) -- confirm.
73 Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r):
// NOTE(review): body not visible in this extract.  The constructor allocates
// handshake, in_buf and out_buf with new/new[], so this is presumably where
// they are released -- confirm.
79 Communicator::~Communicator()
// Registers proto, with recv as its receiver, on this communicator.
//
// Throws sequence_error, invalid_argument or invalid_state as shown below.
// NOTE(review): the conditions guarding the first, second and third throw are
// on lines not visible in this extract.
86 void Communicator::add_protocol(const Protocol &proto, ReceiverBase &recv)
89 throw sequence_error("connection aborted");
91 unsigned max_id = proto.get_max_packet_id();
93 throw invalid_argument("Communicator::add_protocol");
// Protocols are identified by hash; check whether this one is already listed.
95 uint64_t hash = proto.get_hash();
96 auto i = find_member(protocols, hash, &ActiveProtocol::hash);
97 if(i==protocols.end())
// Not listed yet: place it after the range of the last registered protocol.
99 const ActiveProtocol &last = protocols.back();
101 throw sequence_error("previous protocol is incomplete");
102 unsigned base = last.base;
// Advance by the previous protocol's maximum packet id rounded up to a
// multiple of 0x100.
103 base += (last.protocol->get_max_packet_id()+0xFF)&~0xFF;
// The highest packet id of the new protocol must fit in 16 bits.
105 if(base+max_id>std::numeric_limits<std::uint16_t>::max())
106 throw invalid_state("Communicator::add_protocol");
108 protocols.emplace_back(base, proto, recv);
// Announce to the peer only when connected and the first entry of protocols
// (the handshake protocol, emplaced first by the constructor) is ready.
110 if(socket.is_connected() && protocols.front().ready)
111 prepare_protocol(protocols.back());
// An entry with this hash exists but has no protocol object attached --
// presumably a placeholder created when the peer announced it first (see
// Handshake::receive(const PrepareProtocol &)).  Attach the local protocol.
113 else if(!i->protocol)
115 i->protocol = &proto;
116 i->last = i->base+max_id;
// Returns true if proto (matched by address) is listed in protocols and its
// entry is marked ready; false otherwise.
122 bool Communicator::is_protocol_ready(const Protocol &proto) const
124 auto i = find_member(protocols, &proto, &ActiveProtocol::protocol);
125 return (i!=protocols.end() && i->ready);
// Writes the first size bytes of out_buf to the socket.
//
// Throws sequence_error("connection aborted"); the guarding condition is on a
// line not visible in this extract.
128 void Communicator::send_data(size_t size)
131 throw sequence_error("connection aborted");
135 socket.write(out_buf, size);
// Failures derived from std::exception are reported through signal_error.
137 catch(const std::exception &e)
// NOTE(review): the statement taken when signal_error has no handlers is not
// visible here (presumably a rethrow) -- confirm.
140 if(signal_error.empty())
142 signal_error.emit(e);
// Handler connected to socket.signal_connect_finished by the constructor.
// Announces the first entry of protocols to the peer.
// NOTE(review): how exc is examined is not visible in this extract;
// presumably the announcement is only made when exc is null -- confirm.
146 void Communicator::connect_finished(const exception *exc)
151 prepare_protocol(protocols.front());
// Handler connected to socket.signal_data_available by the constructor.
154 void Communicator::data_available()
// Read into the unused tail of in_buf, advancing in_end by the value returned
// from read().
161 in_end += socket.read(in_end, in_buf+buf_size-in_end);
// Keep calling receive_packet() until it returns false.
162 while(receive_packet()) ;
// Exceptions are reported through signal_error.
164 catch(const exception &e)
// NOTE(review): the statement taken when signal_error has no handlers is not
// visible here (presumably a rethrow) -- confirm.
167 if(signal_error.empty())
169 signal_error.emit(e);
// Tries to process one packet from the data between in_begin and in_end.
// NOTE(review): the return statements are not visible in this extract;
// data_available() loops while this returns true, so presumably true means a
// packet was dispatched -- confirm.
//
// Throws key_error when the packet type falls outside every protocol's range.
173 bool Communicator::receive_packet()
175 Protocol::PacketHeader header;
176 size_t available = in_end-in_begin;
// Proceed only if a header could be read and the whole packet is buffered.
177 if(handshake->protocol.get_packet_header(header, in_begin, available) && header.length<=available)
// Find the protocol whose [base, last] range contains the packet type.
179 auto i = lower_bound_member(protocols, header.type, &ActiveProtocol::last);
180 if(i==protocols.end() || header.type<i->base || header.type>i->last)
181 throw key_error(header.type);
// Advance in_begin past the packet before dispatching it.
183 char *pkt = in_begin;
184 in_begin += header.length;
185 i->protocol->dispatch(*i->receiver, pkt, header.length, i->base);
// The buffer is filled to its end: move the unprocessed bytes to the front.
// NOTE(review): the lines around the memmove are not visible here; in_begin
// is presumably reset to in_buf before in_end is recomputed -- confirm.
190 if(in_end==in_buf+buf_size)
192 memmove(in_buf, in_begin, available);
194 in_end = in_begin+available;
// Sends a PrepareProtocol packet carrying the hash and base of proto.
200 void Communicator::prepare_protocol(const ActiveProtocol &proto)
202 PrepareProtocol prepare;
203 prepare.hash = proto.hash;
204 prepare.base = proto.base;
205 /* Use send_data() directly because this function is called to prepare the
206 handshake protocol too and send() would fail readiness check. */
// The packet is serialized into out_buf; serialize()'s return value is passed
// to send_data() as the byte count.
207 send_data(handshake->protocol.serialize(prepare, out_buf, buf_size));
// Marks proto as accepted and sends an AcceptProtocol packet carrying its
// hash.
210 void Communicator::accept_protocol(ActiveProtocol &proto)
212 proto.accepted = true;
214 AcceptProtocol accept;
215 accept.hash = proto.hash;
// Sent with send_data() directly, as prepare_protocol() does.
216 send_data(handshake->protocol.serialize(accept, out_buf, buf_size));
// Entry for a protocol known locally: b is its base packet id, p the protocol
// and r its receiver.
// NOTE(review): only part of the initializer list is visible in this extract.
220 Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, const Protocol &p, ReceiverBase &r):
// Highest packet id belonging to this protocol.
223 last(base+p.get_max_packet_id()),
// Entry built from a base and a hash only; used by
// Handshake::receive(const PrepareProtocol &) with the values from the
// received packet.
// NOTE(review): initializer list not visible; presumably leaves the protocol
// pointer null until add_protocol() attaches one -- confirm.
228 Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, uint64_t h):
// Handles a PrepareProtocol packet from the peer.
235 void Communicator::Handshake::receive(const PrepareProtocol &prepare)
// Locate the entry for the announced base, or the position where it would go.
237 auto i = lower_bound_member(communicator.protocols, prepare.base, &ActiveProtocol::base);
// An entry with the same base already exists: accept it.
// NOTE(review): only base is compared in the visible lines; prepare.hash is
// not checked against the existing entry here.
238 if(i!=communicator.protocols.end() && i->base==prepare.base)
239 communicator.accept_protocol(*i);
// NOTE(review): the preceding line (presumably an else) is not visible.
// Inserts an entry holding only the announced base and hash at position i.
241 communicator.protocols.emplace(i, prepare.base, prepare.hash);
// Handles an AcceptProtocol packet from the peer.
//
// Throws key_error if no entry in protocols has the accepted hash.
244 void Communicator::Handshake::receive(const AcceptProtocol &accept)
246 auto i = find_member(communicator.protocols, accept.hash, &ActiveProtocol::hash);
247 if(i==communicator.protocols.end())
248 throw key_error(accept.hash);
// NOTE(review): several lines before this call are not visible in this
// extract, so the condition under which the acceptance is sent back (and
// where the entry is marked ready) cannot be seen here -- confirm.
256 communicator.accept_protocol(*i);
// The accepted protocol is the handshake protocol itself: announce entries of
// protocols to the peer.
// NOTE(review): the loop has a line that is not visible, presumably a filter
// on which entries are announced -- confirm.
257 if(i->protocol==&protocol)
259 for(const ActiveProtocol &p: communicator.protocols)
261 communicator.prepare_protocol(p);
// NOTE(review): whether this emission is unconditional or sits in an else
// branch cannot be determined from the visible lines.
264 communicator.signal_protocol_ready.emit(*i->protocol);