]> git.tdb.fi Git - libs/net.git/blob - source/net/protocol.cpp
Fix a length calculation bug in Protocol::dispatch
[libs/net.git] / source / net / protocol.cpp
1 #include <cstring>
2 #include <string>
3 #include <msp/core/maputils.h>
4 #include <msp/strings/format.h>
5 #include <msp/strings/lexicalcast.h>
6 #include "protocol.h"
7
8 using namespace std;
9
10 namespace Msp {
11 namespace Net {
12
13 Protocol::Protocol(unsigned npi):
14         header_def(0),
15         next_packet_id(npi)
16 {
17         PacketDefBuilder<PacketHeader, Serializer<PacketHeader>>(*this, header_def, Serializer<PacketHeader>())
18                 (&PacketHeader::type)(&PacketHeader::length);
19 }
20
21 Protocol::~Protocol()
22 {
23         for(auto &kvp: packet_class_defs)
24                 delete kvp.second;
25 }
26
27 unsigned Protocol::get_next_packet_class_id()
28 {
29         static unsigned next_id = 1;
30         return next_id++;
31 }
32
33 void Protocol::add_packet(PacketDefBase *pdef)
34 {
35         PacketDefBase *&ptr = packet_class_defs[pdef->get_class_id()];
36         if(ptr)
37         {
38                 packet_id_defs.erase(ptr->get_id());
39                 delete ptr;
40         }
41         ptr = pdef;
42         if(unsigned id = pdef->get_id())
43                 packet_id_defs[id] = pdef;
44 }
45
46 const Protocol::PacketDefBase &Protocol::get_packet_by_class_id(unsigned id) const
47 {
48         return *get_item(packet_class_defs, id);
49 }
50
51 const Protocol::PacketDefBase &Protocol::get_packet_by_id(unsigned id) const
52 {
53         return *get_item(packet_id_defs, id);
54 }
55
56 size_t Protocol::dispatch(ReceiverBase &rcv, const char *buf, size_t size) const
57 {
58         PacketHeader header;
59         const char *ptr = header_def.deserialize(header, buf, buf+size);
60         if(header.length>size)
61                 throw bad_packet("truncated");
62         const PacketDefBase &pdef = get_packet_by_id(header.type);
63         ptr = pdef.dispatch(rcv, ptr, ptr+header.length);
64         return ptr-buf;
65 }
66
67 size_t Protocol::get_packet_size(const char *buf, size_t size) const
68 {
69         if(size<4)
70                 return 0;
71         PacketHeader header;
72         header_def.deserialize(header, buf, buf+size);
73         return header.length;
74 }
75
76 uint64_t Protocol::get_hash() const
77 {
78         uint64_t result = hash<64>(packet_id_defs.size());
79         for(auto &kvp: packet_id_defs)
80         {
81                 hash_update<64>(result, kvp.first);
82                 hash_update<64>(result, kvp.second->get_hash());
83         }
84         return result;
85 }
86
87
88 /* TODO These assumes the machine is little-endian; are there any relevant
89 big-endian platforms these days? */
90 template<typename T>
91 char *Protocol::BasicSerializer<T>::serialize(const T &value, char *buf, char *end) const
92 {
93         if(end-buf<static_cast<int>(sizeof(T)))
94                 throw buffer_error("overflow");
95
96         const char *ptr = reinterpret_cast<const char *>(&value)+sizeof(T);
97         for(size_t i=0; i<sizeof(T); ++i)
98                 *buf++ = *--ptr;
99
100         return buf;
101 }
102
103 template<typename T>
104 const char *Protocol::BasicSerializer<T>::deserialize(T &value, const char *buf, const char *end) const
105 {
106         if(end-buf<static_cast<int>(sizeof(T)))
107                 throw buffer_error("underflow");
108
109         char *ptr = reinterpret_cast<char *>(&value)+sizeof(T);
110         for(size_t i=0; i<sizeof(T); ++i)
111                 *--ptr = *buf++;
112
113         return buf;
114 }
115
116 template char *Protocol::BasicSerializer<bool>::serialize(const bool &, char *, char *) const;
117 template char *Protocol::BasicSerializer<int8_t>::serialize(const int8_t &, char *, char *) const;
118 template char *Protocol::BasicSerializer<int16_t>::serialize(const int16_t &, char *, char *) const;
119 template char *Protocol::BasicSerializer<int32_t>::serialize(const int32_t &, char *, char *) const;
120 template char *Protocol::BasicSerializer<int64_t>::serialize(const int64_t &, char *, char *) const;
121 template char *Protocol::BasicSerializer<uint8_t>::serialize(const uint8_t &, char *, char *) const;
122 template char *Protocol::BasicSerializer<uint16_t>::serialize(const uint16_t &, char *, char *) const;
123 template char *Protocol::BasicSerializer<uint32_t>::serialize(const uint32_t &, char *, char *) const;
124 template char *Protocol::BasicSerializer<uint64_t>::serialize(const uint64_t &, char *, char *) const;
125 template char *Protocol::BasicSerializer<float>::serialize(const float &, char *, char *) const;
126 template char *Protocol::BasicSerializer<double>::serialize(const double &, char *, char *) const;
127 template const char *Protocol::BasicSerializer<bool>::deserialize(bool &, const char *, const char *) const;
128 template const char *Protocol::BasicSerializer<int8_t>::deserialize(int8_t &, const char *, const char *) const;
129 template const char *Protocol::BasicSerializer<int16_t>::deserialize(int16_t &, const char *, const char *) const;
130 template const char *Protocol::BasicSerializer<int32_t>::deserialize(int32_t &, const char *, const char *) const;
131 template const char *Protocol::BasicSerializer<int64_t>::deserialize(int64_t &, const char *, const char *) const;
132 template const char *Protocol::BasicSerializer<uint8_t>::deserialize(uint8_t &, const char *, const char *) const;
133 template const char *Protocol::BasicSerializer<uint16_t>::deserialize(uint16_t &, const char *, const char *) const;
134 template const char *Protocol::BasicSerializer<uint32_t>::deserialize(uint32_t &, const char *, const char *) const;
135 template const char *Protocol::BasicSerializer<uint64_t>::deserialize(uint64_t &, const char *, const char *) const;
136 template const char *Protocol::BasicSerializer<float>::deserialize(float &, const char *, const char *) const;
137 template const char *Protocol::BasicSerializer<double>::deserialize(double &, const char *, const char *) const;
138
139
140 Protocol::StringSerializer::StringSerializer(const Protocol &p):
141         length_serializer(p)
142 { }
143
144 char *Protocol::StringSerializer::serialize(const string &str, char *buf, char *end) const
145 {
146         buf = length_serializer.serialize(str.size(), buf, end);
147         if(end-buf<static_cast<int>(str.size()))
148                 throw buffer_error("overflow");
149         copy(str.begin(), str.end(), buf);
150         return buf+str.size();
151 }
152
153 const char *Protocol::StringSerializer::deserialize(string &str, const char *buf, const char *end) const
154 {
155         uint16_t length;
156         buf = length_serializer.deserialize(length, buf, end);
157         if(end-buf<static_cast<int>(length))
158                 throw buffer_error("underflow");
159         str.assign(buf, buf+length);
160         return buf+length;
161 }
162
163
164 Protocol::PacketDefBase::PacketDefBase(unsigned i):
165         id(i)
166 { }
167
168
169 Protocol::PacketHeader::PacketHeader(uint16_t t, uint16_t l):
170         type(t),
171         length(l)
172 { }
173
174 } // namespace Net
175 } // namespace Msp