]> git.tdb.fi Git - libs/net.git/blob - source/net/communicator.cpp
Include the matching header first in .cpp files
[libs/net.git] / source / net / communicator.cpp
1 #include "communicator.h"
2 #include <cstring>
3 #include "streamsocket.h"
4
5 using namespace std;
6
7 namespace {
8
9 using namespace Msp::Net;
10
/**
Payload of the handshake packet: a single 64-bit hash value.  The hash
defaults to zero so that a default-constructed Handshake never carries an
indeterminate value.
*/
struct Handshake
{
        uint64_t hash = 0;
};
15
16
17 class HandshakeProtocol: public Protocol
18 {
19 public:
20         HandshakeProtocol();
21 };
22
23 HandshakeProtocol::HandshakeProtocol():
24         Protocol(0x7F00)
25 {
26         add<Handshake>(&Handshake::hash);
27 }
28
29
30 class HandshakeReceiver: public PacketReceiver<Handshake>
31 {
32 private:
33         uint64_t hash = 0;
34
35 public:
36         uint64_t get_hash() const { return hash; }
37         void receive(const Handshake &) override;
38 };
39
40 void HandshakeReceiver::receive(const Handshake &shake)
41 {
42         hash = shake.hash;
43 }
44
45 }
46
47
48 namespace Msp {
49 namespace Net {
50
51 Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r):
52         socket(s),
53         protocol(p),
54         receiver(r),
55         in_buf(new char[buf_size]),
56         in_begin(in_buf),
57         in_end(in_buf),
58         out_buf(new char[buf_size])
59 {
60         socket.signal_data_available.connect(sigc::mem_fun(this, &Communicator::data_available));
61 }
62
63 Communicator::~Communicator()
64 {
65         delete[] in_buf;
66         delete[] out_buf;
67 }
68
69 void Communicator::initiate_handshake()
70 {
71         if(handshake_status!=0)
72                 throw sequence_error("handshaking already done");
73
74         send_handshake();
75         handshake_status = 1;
76 }
77
78 void Communicator::send_data(size_t size)
79 {
80         if(!good)
81                 throw sequence_error("connection aborted");
82         if(handshake_status!=2)
83                 throw sequence_error("handshake incomplete");
84
85         try
86         {
87                 socket.write(out_buf, size);
88         }
89         catch(const std::exception &e)
90         {
91                 good = false;
92                 if(signal_error.empty())
93                         throw;
94                 signal_error.emit(e);
95         }
96 }
97
98 void Communicator::data_available()
99 {
100         if(!good)
101                 return;
102
103         try
104         {
105                 in_end += socket.read(in_end, in_buf+buf_size-in_end);
106
107                 bool more = true;
108                 while(more)
109                 {
110                         if(handshake_status==2)
111                                 more = receive_packet(protocol, receiver);
112                         else
113                         {
114                                 HandshakeProtocol hsproto;
115                                 HandshakeReceiver hsrecv;
116                                 if((more = receive_packet(hsproto, hsrecv)))
117                                 {
118                                         if(handshake_status==0)
119                                                 send_handshake();
120
121                                         if(hsrecv.get_hash()==protocol.get_hash())
122                                         {
123                                                 handshake_status = 2;
124                                                 signal_handshake_done.emit();
125                                         }
126                                         else
127                                                 throw incompatible_protocol("hash mismatch");
128                                 }
129                         }
130                 }
131         }
132         catch(const exception &e)
133         {
134                 good = false;
135                 if(signal_error.empty())
136                         throw;
137                 signal_error.emit(e);
138         }
139 }
140
141 bool Communicator::receive_packet(const Protocol &proto, ReceiverBase &recv)
142 {
143         int psz = proto.get_packet_size(in_begin, in_end-in_begin);
144         if(psz && psz<=in_end-in_begin)
145         {
146                 char *pkt = in_begin;
147                 in_begin += psz;
148                 proto.dispatch(recv, pkt, psz);
149                 return true;
150         }
151         else
152         {
153                 if(in_end==in_buf+buf_size)
154                 {
155                         size_t used = in_end-in_begin;
156                         memmove(in_buf, in_begin, used);
157                         in_begin = in_buf;
158                         in_end = in_begin+used;
159                 }
160                 return false;
161         }
162 }
163
164 void Communicator::send_handshake()
165 {
166         Handshake shake;
167         shake.hash = protocol.get_hash();
168
169         HandshakeProtocol hsproto;
170         size_t size = hsproto.serialize(shake, out_buf, buf_size);
171         socket.write(out_buf, size);
172 }
173
174 } // namespace Net
175 } // namespace Msp