git.tdb.fi Git - libs/net.git/blob - source/net/communicator.cpp
Add a dynamic receiver class for more flexible packet handling
[libs/net.git] / source / net / communicator.cpp
1 #include "communicator.h"
2 #include <cstring>
3 #include "streamsocket.h"
4
5 using namespace std;
6
7 namespace {
8
9 using namespace Msp::Net;
10
11 // Sent when a protocol is added locally but isn't known to the other end yet
12 struct PrepareProtocol
13 {
14         uint64_t hash;
15         uint16_t base;
16 };
17
18 // Sent to confirm that a protocol is known to both peers
19 struct AcceptProtocol
20 {
21         uint64_t hash;
22 };
23
24
25 class HandshakeProtocol: public Protocol
26 {
27 public:
28         HandshakeProtocol();
29 };
30
31 HandshakeProtocol::HandshakeProtocol()
32 {
33         add<PrepareProtocol>(&PrepareProtocol::hash, &PrepareProtocol::base);
34         add<AcceptProtocol>(&AcceptProtocol::hash);
35 }
36
37 }
38
39
40 namespace Msp {
41 namespace Net {
42
43 struct Communicator::Handshake: public PacketReceiver<PrepareProtocol>,
44         public PacketReceiver<AcceptProtocol>
45 {
46         Communicator &communicator;
47         HandshakeProtocol protocol;
48
49         Handshake(Communicator &c): communicator(c) { }
50
51         void receive(const PrepareProtocol &) override;
52         void receive(const AcceptProtocol &) override;
53 };
54
55
56 Communicator::Communicator(StreamSocket &s):
57         socket(s),
58         handshake(new Handshake(*this)),
59         in_buf(new char[buf_size]),
60         in_begin(in_buf),
61         in_end(in_buf),
62         out_buf(new char[buf_size])
63 {
64         socket.signal_data_available.connect(sigc::mem_fun(this, &Communicator::data_available));
65
66         protocols.emplace_back(0, ref(handshake->protocol), ref(*handshake));
67         if(socket.is_connected())
68                 prepare_protocol(protocols.back());
69         else
70                 socket.signal_connect_finished.connect(sigc::mem_fun(this, &Communicator::connect_finished));
71 }
72
73 Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r):
74         Communicator(s)
75 {
76         add_protocol(p, r);
77 }
78
79 Communicator::~Communicator()
80 {
81         delete[] in_buf;
82         delete[] out_buf;
83         delete handshake;
84 }
85
86 void Communicator::add_protocol(const Protocol &proto, ReceiverBase &recv)
87 {
88         if(!good)
89                 throw sequence_error("connection aborted");
90
91         unsigned max_id = proto.get_max_packet_id();
92         if(!max_id)
93                 throw invalid_argument("Communicator::add_protocol");
94
95         uint64_t hash = proto.get_hash();
96         auto i = find_member(protocols, hash, &ActiveProtocol::hash);
97         if(i==protocols.end())
98         {
99                 const ActiveProtocol &last = protocols.back();
100                 if(!last.protocol)
101                         throw sequence_error("previous protocol is incomplete");
102                 unsigned base = last.base;
103                 base += (last.protocol->get_max_packet_id()+0xFF)&~0xFF;
104
105                 if(base+max_id>std::numeric_limits<std::uint16_t>::max())
106                         throw invalid_state("Communicator::add_protocol");
107
108                 protocols.emplace_back(base, proto, recv);
109
110                 if(socket.is_connected() && protocols.front().ready)
111                         prepare_protocol(protocols.back());
112         }
113         else if(!i->protocol)
114         {
115                 i->protocol = &proto;
116                 i->last = i->base+max_id;
117                 i->receiver = &recv;
118                 accept_protocol(*i);
119         }
120 }
121
122 bool Communicator::is_protocol_ready(const Protocol &proto) const
123 {
124         auto i = find_member(protocols, &proto, &ActiveProtocol::protocol);
125         return (i!=protocols.end() && i->ready);
126 }
127
128 void Communicator::send_data(size_t size)
129 {
130         if(!good)
131                 throw sequence_error("connection aborted");
132
133         try
134         {
135                 socket.write(out_buf, size);
136         }
137         catch(const std::exception &e)
138         {
139                 good = false;
140                 if(signal_error.empty())
141                         throw;
142                 signal_error.emit(e);
143         }
144 }
145
146 void Communicator::connect_finished(const exception *exc)
147 {
148         if(exc)
149                 good = false;
150         else
151                 prepare_protocol(protocols.front());
152 }
153
154 void Communicator::data_available()
155 {
156         if(!good)
157                 return;
158
159         try
160         {
161                 in_end += socket.read(in_end, in_buf+buf_size-in_end);
162                 while(receive_packet()) ;
163         }
164         catch(const exception &e)
165         {
166                 good = false;
167                 if(signal_error.empty())
168                         throw;
169                 signal_error.emit(e);
170         }
171 }
172
173 bool Communicator::receive_packet()
174 {
175         Protocol::PacketHeader header;
176         size_t available = in_end-in_begin;
177         if(handshake->protocol.get_packet_header(header, in_begin, available) && header.length<=available)
178         {
179                 auto i = lower_bound_member(protocols, header.type, &ActiveProtocol::last);
180                 if(i==protocols.end() || header.type<i->base || header.type>i->last)
181                         throw key_error(header.type);
182
183                 char *pkt = in_begin;
184                 in_begin += header.length;
185                 i->protocol->dispatch(*i->receiver, pkt, header.length, i->base);
186                 return true;
187         }
188         else
189         {
190                 if(in_end==in_buf+buf_size)
191                 {
192                         memmove(in_buf, in_begin, available);
193                         in_begin = in_buf;
194                         in_end = in_begin+available;
195                 }
196                 return false;
197         }
198 }
199
200 void Communicator::prepare_protocol(const ActiveProtocol &proto)
201 {
202         PrepareProtocol prepare;
203         prepare.hash = proto.hash;
204         prepare.base = proto.base;
205         /* Use send_data() directly because this function is called to prepare the
206         handshake protocol too and send() would fail readiness check. */
207         send_data(handshake->protocol.serialize(prepare, out_buf, buf_size));
208 }
209
210 void Communicator::accept_protocol(ActiveProtocol &proto)
211 {
212         proto.accepted = true;
213
214         AcceptProtocol accept;
215         accept.hash = proto.hash;
216         send_data(handshake->protocol.serialize(accept, out_buf, buf_size));
217 }
218
219
220 Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, const Protocol &p, ReceiverBase &r):
221         hash(p.get_hash()),
222         base(b),
223         last(base+p.get_max_packet_id()),
224         protocol(&p),
225         receiver(&r)
226 { }
227
228 Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, uint64_t h):
229         hash(h),
230         base(b),
231         last(base)
232 { }
233
234
235 void Communicator::Handshake::receive(const PrepareProtocol &prepare)
236 {
237         auto i = lower_bound_member(communicator.protocols, prepare.base, &ActiveProtocol::base);
238         if(i!=communicator.protocols.end() && i->base==prepare.base)
239                 communicator.accept_protocol(*i);
240         else
241                 communicator.protocols.emplace(i, prepare.base, prepare.hash);
242 }
243
244 void Communicator::Handshake::receive(const AcceptProtocol &accept)
245 {
246         auto i = find_member(communicator.protocols, accept.hash, &ActiveProtocol::hash);
247         if(i==communicator.protocols.end())
248                 throw key_error(accept.hash);
249
250
251         if(i->ready)
252                 return;
253
254         i->ready = true;
255         if(!i->accepted)
256                 communicator.accept_protocol(*i);
257         if(i->protocol==&protocol)
258         {
259                 for(const ActiveProtocol &p: communicator.protocols)
260                         if(!p.ready)
261                                 communicator.prepare_protocol(p);
262         }
263         else
264                 communicator.signal_protocol_ready.emit(*i->protocol);
265 }
266
267 } // namespace Net
268 } // namespace Msp