Move the definition of PacketTypeDef<T>::class_id to a separate header
[libs/net.git] / source / net / communicator.cpp
#include <cstring>
#include <stdexcept>
#include "communicator.h"
#include "protocol_impl.h"
#include "streamsocket.h"
5
6 using namespace std;
7
8 namespace {
9
10 using namespace Msp::Net;
11
12 struct Handshake
13 {
14         Msp::UInt64 hash;
15 };
16
17
18 class HandshakeProtocol: public Protocol
19 {
20 public:
21         HandshakeProtocol();
22 };
23
24 HandshakeProtocol::HandshakeProtocol():
25         Protocol(0x7F00)
26 {
27         add<Handshake>()(&Handshake::hash);
28 }
29
30
31 class HandshakeReceiver: public PacketReceiver<Handshake>
32 {
33 private:
34         Msp::UInt64 hash;
35
36 public:
37         HandshakeReceiver();
38         Msp::UInt64 get_hash() const { return hash; }
39         virtual void receive(const Handshake &);
40 };
41
42 HandshakeReceiver::HandshakeReceiver():
43         hash(0)
44 { }
45
46 void HandshakeReceiver::receive(const Handshake &shake)
47 {
48         hash = shake.hash;
49 }
50
51 }
52
53
54 namespace Msp {
55 namespace Net {
56
57 Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r):
58         socket(s),
59         protocol(p),
60         receiver(r),
61         handshake_status(0),
62         buf_size(65536),
63         in_buf(new char[buf_size]),
64         in_begin(in_buf),
65         in_end(in_buf),
66         out_buf(new char[buf_size]),
67         good(true)
68 {
69         socket.signal_data_available.connect(sigc::mem_fun(this, &Communicator::data_available));
70 }
71
72 Communicator::~Communicator()
73 {
74         delete[] in_buf;
75         delete[] out_buf;
76 }
77
78 void Communicator::initiate_handshake()
79 {
80         if(handshake_status!=0)
81                 throw sequence_error("handshaking already done");
82
83         send_handshake();
84         handshake_status = 1;
85 }
86
87 void Communicator::send_data(unsigned size)
88 {
89         if(!good)
90                 throw sequence_error("connection aborted");
91         if(handshake_status!=2)
92                 throw sequence_error("handshake incomplete");
93
94         try
95         {
96                 socket.write(out_buf, size);
97         }
98         catch(const std::exception &e)
99         {
100                 good = false;
101                 if(signal_error.empty())
102                         throw;
103                 signal_error.emit(e);
104         }
105 }
106
107 void Communicator::data_available()
108 {
109         if(!good)
110                 return;
111
112         try
113         {
114                 in_end += socket.read(in_end, in_buf+buf_size-in_end);
115
116                 bool more = true;
117                 while(more)
118                 {
119                         if(handshake_status==2)
120                                 more = receive_packet(protocol, receiver);
121                         else
122                         {
123                                 HandshakeProtocol hsproto;
124                                 HandshakeReceiver hsrecv;
125                                 if((more = receive_packet(hsproto, hsrecv)))
126                                 {
127                                         if(handshake_status==0)
128                                                 send_handshake();
129
130                                         if(hsrecv.get_hash()==protocol.get_hash())
131                                         {
132                                                 handshake_status = 2;
133                                                 signal_handshake_done.emit();
134                                         }
135                                         else
136                                                 throw incompatible_protocol("hash mismatch");
137                                 }
138                         }
139                 }
140         }
141         catch(const exception &e)
142         {
143                 good = false;
144                 if(signal_error.empty())
145                         throw;
146                 signal_error.emit(e);
147         }
148 }
149
150 bool Communicator::receive_packet(const Protocol &proto, ReceiverBase &recv)
151 {
152         int psz = proto.get_packet_size(in_begin, in_end-in_begin);
153         if(psz && psz<=in_end-in_begin)
154         {
155                 char *pkt = in_begin;
156                 in_begin += psz;
157                 proto.dispatch(recv, pkt, psz);
158                 return true;
159         }
160         else
161         {
162                 if(in_end==in_buf+buf_size)
163                 {
164                         unsigned used = in_end-in_begin;
165                         memmove(in_buf, in_begin, used);
166                         in_begin = in_buf;
167                         in_end = in_begin+used;
168                 }
169                 return false;
170         }
171 }
172
173 void Communicator::send_handshake()
174 {
175         Handshake shake;
176         shake.hash = protocol.get_hash();
177
178         HandshakeProtocol hsproto;
179         unsigned size = hsproto.serialize(shake, out_buf, buf_size);
180         socket.write(out_buf, size);
181 }
182
183 } // namespace Net
184 } // namespace Msp