]> git.tdb.fi Git - libs/net.git/blobdiff - source/net/protocol.cpp
Add a dynamic receiver class for more flexible packet handling
[libs/net.git] / source / net / protocol.cpp
index d56ac77a09f70eff7137da98beef2eb7112d8978..5e82a80d7daaf17fb130752fd870380e2b78c329 100644 (file)
+#include "protocol.h"
 #include <cstring>
 #include <string>
 #include <msp/core/maputils.h>
+#include <msp/strings/format.h>
 #include <msp/strings/lexicalcast.h>
-#include "protocol.h"
 
 using namespace std;
 
-namespace {
+namespace Msp {
+namespace Net {
 
-using Msp::Net::buffer_error;
+Protocol::Protocol():
+       header_def(0)
+{
+       PacketDefBuilder<PacketHeader, Serializer<PacketHeader>>(*this, header_def, Serializer<PacketHeader>())
+               .fields(&PacketHeader::type, &PacketHeader::length);
+}
 
-template<typename T>
-class Assembler
+unsigned Protocol::get_next_packet_class_id()
 {
-public:
-       static char *assemble(const T &v, char *, char *);
-       static const char *disassemble(T &, const char *, const char *);
-};
+       static unsigned next_id = 1;
+       return next_id++;
+}
 
-template<typename T>
-class Assembler<vector<T> >
+void Protocol::add_packet(unique_ptr<PacketDefBase> pdef)
 {
-public:
-       static char *assemble(const vector<T> &v, char *, char *);
-       static const char *disassemble(vector<T> &, const char *, const char *);
-};
+       unique_ptr<PacketDefBase> &ptr = packet_class_defs[pdef->get_class_id()];
+       if(ptr)
+               packet_id_defs.erase(ptr->get_id());
+       ptr = move(pdef);
+       if(unsigned id = ptr->get_id())
+               packet_id_defs[id] = ptr.get();
+}
 
-template<typename T>
-char *Assembler<T>::assemble(const T &v, char *data, char *end)
+const Protocol::PacketDefBase &Protocol::get_packet_by_class_id(unsigned id) const
 {
-       // XXX Assumes little-endian
-       const char *ptr = reinterpret_cast<const char *>(&v)+sizeof(T);
-       for(unsigned i=0; i<sizeof(T); ++i)
-       {
-               if(data==end)
-                       throw buffer_error("overflow");
-               *data++ = *--ptr;
-       }
-       return data;
+       return *get_item(packet_class_defs, id);
 }
 
-template<>
-char *Assembler<string>::assemble(const string &v, char *data, char *end)
+const Protocol::PacketDefBase &Protocol::get_packet_by_id(unsigned id) const
 {
-       data = Assembler<unsigned short>::assemble(v.size(), data, end);
-       if(end-data<static_cast<int>(v.size()))
-               throw buffer_error("overflow");
-       memcpy(data, v.data(), v.size());
-       return data+v.size();
+       return *get_item(packet_id_defs, id);
 }
 
-template<typename T>
-char *Assembler<vector<T> >::assemble(const vector<T> &v, char *data, char *end)
+unsigned Protocol::get_max_packet_id() const
 {
-       data = Assembler<unsigned short>::assemble(v.size(), data, end);
-       for(typename vector<T>::const_iterator i=v.begin(); i!=v.end(); ++i)
-               data = Assembler<T>::assemble(*i, data, end);
-       return data;
+       if(packet_id_defs.empty())
+               return 0;
+       return prev(packet_id_defs.end())->first;
 }
 
-template<typename T>
-const char *Assembler<T>::disassemble(T &v, const char *data, const char *end)
+size_t Protocol::dispatch(ReceiverBase &rcv, const char *buf, size_t size, unsigned base_id) const
 {
-       char *ptr = reinterpret_cast<char *>(&v)+sizeof(T);
-       for(unsigned i=0; i<sizeof(T); ++i)
+       PacketHeader header;
+       const char *ptr = header_def.deserialize(header, buf, buf+size);
+       if(header.length>size)
+               throw bad_packet("truncated");
+       const PacketDefBase &pdef = get_packet_by_id(header.type-base_id);
+       if(DynamicReceiver *drcv = dynamic_cast<DynamicReceiver *>(&rcv))
        {
-               if(data==end)
-                       throw buffer_error("underflow");
-               *--ptr = *data++;
+               Variant pkt;
+               ptr = pdef.deserialize(pkt, ptr, ptr+header.length);
+               drcv->receive(pdef.get_id(), pkt);
        }
-       return data;
+       else
+               ptr = pdef.dispatch(rcv, ptr, ptr+header.length);
+       return ptr-buf;
 }
 
-template<>
-const char *Assembler<string>::disassemble(string &v, const char *data, const char *end)
+bool Protocol::get_packet_header(PacketHeader &header, const char *buf, size_t size) const
 {
-       unsigned short size;
-       data = Assembler<unsigned short>::disassemble(size, data, end);
-       if(end-data<size)
-               throw buffer_error("underflow");
-       v.assign(data, data+size);
-       return data+size;
+       if(size<4)
+               return false;
+       header_def.deserialize(header, buf, buf+size);
+       return true;
 }
 
-template<typename T>
-const char *Assembler<vector<T> >::disassemble(vector<T> &v, const char *data, const char *end)
+size_t Protocol::get_packet_size(const char *buf, size_t size) const
+{
+       PacketHeader header;
+       return (get_packet_header(header, buf, size) ? header.length : 0);
+}
+
+uint64_t Protocol::get_hash() const
 {
-       /* We assume that the vector is in pristine state - this holds because the
-       only code path leading here is from PacketDef<P>::disassemble, which creates
-       a new packet. */
-       unsigned short size;
-       data = Assembler<unsigned short>::disassemble(size, data, end);
-       for(unsigned i=0; i<size; ++i)
+       uint64_t result = hash<64>(packet_id_defs.size());
+       for(auto &kvp: packet_id_defs)
        {
-               T u;
-               data = Assembler<T>::disassemble(u, data, end);
-               v.push_back(u);
+               hash_update<64>(result, kvp.first);
+               hash_update<64>(result, kvp.second->get_hash());
        }
-       return data;
+       return result;
 }
 
-}
 
-namespace Msp {
-namespace Net {
+/* TODO These assumes the machine is little-endian; are there any relevant
+big-endian platforms these days? */
+template<typename T>
+char *Protocol::BasicSerializer<T>::serialize(const T &value, char *buf, char *end) const
+{
+       if(end-buf<static_cast<int>(sizeof(T)))
+               throw buffer_error("overflow");
 
-Protocol::Protocol(unsigned npi):
-       next_packet_id(npi)
-{ }
+       const char *ptr = reinterpret_cast<const char *>(&value)+sizeof(T);
+       for(size_t i=0; i<sizeof(T); ++i)
+               *buf++ = *--ptr;
 
-Protocol::~Protocol()
-{
-       for(map<unsigned, PacketDefBase *>::iterator i=packet_class_defs.begin(); i!=packet_class_defs.end(); ++i)
-               delete i->second;
+       return buf;
 }
 
-void Protocol::add_packet(PacketDefBase &pdef)
+template<typename T>
+const char *Protocol::BasicSerializer<T>::deserialize(T &value, const char *buf, const char *end) const
 {
-       PacketDefBase *&ptr = packet_class_defs[pdef.get_class_id()];
-       if(ptr)
-               delete ptr;
-       ptr = &pdef;
-       packet_id_defs[pdef.get_id()] = &pdef;
-}
+       if(end-buf<static_cast<int>(sizeof(T)))
+               throw buffer_error("underflow");
 
-const Protocol::PacketDefBase &Protocol::get_packet_by_class(unsigned id) const
-{
-       return *get_item(packet_class_defs, id);
-}
+       char *ptr = reinterpret_cast<char *>(&value)+sizeof(T);
+       for(size_t i=0; i<sizeof(T); ++i)
+               *--ptr = *buf++;
 
-const Protocol::PacketDefBase &Protocol::get_packet_by_id(unsigned id) const
-{
-       return *get_item(packet_id_defs, id);
+       return buf;
 }
 
-unsigned Protocol::disassemble(ReceiverBase &rcv, const char *data, unsigned size) const
-{
-       const unsigned char *udata = reinterpret_cast<const unsigned char *>(data);
-       unsigned id = (udata[0]<<8)+udata[1];
-       unsigned psz = (udata[2]<<8)+udata[3];
-       if(psz>size)
-               throw bad_packet("truncated");
-       const PacketDefBase &pdef = get_packet_by_id(id);
-       const char *ptr = pdef.disassemble(rcv, data+4, data+psz);
-       return ptr-data;
-}
+template char *Protocol::BasicSerializer<bool>::serialize(const bool &, char *, char *) const;
+template char *Protocol::BasicSerializer<int8_t>::serialize(const int8_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<int16_t>::serialize(const int16_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<int32_t>::serialize(const int32_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<int64_t>::serialize(const int64_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<uint8_t>::serialize(const uint8_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<uint16_t>::serialize(const uint16_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<uint32_t>::serialize(const uint32_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<uint64_t>::serialize(const uint64_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<float>::serialize(const float &, char *, char *) const;
+template char *Protocol::BasicSerializer<double>::serialize(const double &, char *, char *) const;
+template const char *Protocol::BasicSerializer<bool>::deserialize(bool &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<int8_t>::deserialize(int8_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<int16_t>::deserialize(int16_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<int32_t>::deserialize(int32_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<int64_t>::deserialize(int64_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<uint8_t>::deserialize(uint8_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<uint16_t>::deserialize(uint16_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<uint32_t>::deserialize(uint32_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<uint64_t>::deserialize(uint64_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<float>::deserialize(float &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<double>::deserialize(double &, const char *, const char *) const;
+
+
+Protocol::StringSerializer::StringSerializer(const Protocol &p):
+       length_serializer(p)
+{ }
 
-unsigned Protocol::get_packet_size(const char *data, unsigned size) const
+char *Protocol::StringSerializer::serialize(const string &str, char *buf, char *end) const
 {
-       if(size<4)
-               return 0;
-       const unsigned char *udata = reinterpret_cast<const unsigned char *>(data);
-       return (udata[2]<<8)+udata[3];
+       buf = length_serializer.serialize(str.size(), buf, end);
+       if(end-buf<static_cast<int>(str.size()))
+               throw buffer_error("overflow");
+       copy(str.begin(), str.end(), buf);
+       return buf+str.size();
 }
 
-unsigned Protocol::get_hash() const
+const char *Protocol::StringSerializer::deserialize(string &str, const char *buf, const char *end) const
 {
-       // TODO
-       return 123;
+       uint16_t length;
+       buf = length_serializer.deserialize(length, buf, end);
+       if(end-buf<static_cast<int>(length))
+               throw buffer_error("underflow");
+       str.assign(buf, buf+length);
+       return buf+length;
 }
 
-void Protocol::assemble_header(char *buf, unsigned id, unsigned size)
-{
-       buf[0] = (id>>8)&0xFF;
-       buf[1] = id&0xFF;
-       buf[2] = (size>>8)&0xFF;
-       buf[3] = size&0xFF;
-}
 
-template<typename T>
-char *Protocol::assemble_field(const T &v, char *d, char *e)
-{ return Assembler<T>::assemble(v, d, e); }
-
-template char *Protocol::assemble_field<>(const char &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const signed char &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const unsigned char &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const short &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const unsigned short &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const int &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const unsigned &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const long &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const unsigned long &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const float &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const double &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const string &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const vector<char> &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const vector<signed char> &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const vector<unsigned char> &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const vector<short> &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const vector<unsigned short> &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const vector<int> &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const vector<unsigned> &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const vector<long> &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const vector<unsigned long> &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const vector<float> &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const vector<double> &v, char *d, char *e);
-template char *Protocol::assemble_field<>(const vector<string> &v, char *d, char *e);
+Protocol::PacketDefBase::PacketDefBase(unsigned i):
+       id(i)
+{ }
 
-template<typename T>
-const char *Protocol::disassemble_field(T &v, const char *d, const char *e)
-{ return Assembler<T>::disassemble(v, d, e); }
-
-template const char *Protocol::disassemble_field<>(char &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(signed char &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(unsigned char &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(short &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(unsigned short &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(int &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(unsigned &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(long &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(unsigned long &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(float &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(double &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(string &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(vector<char> &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(vector<signed char> &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(vector<unsigned char> &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(vector<short> &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(vector<unsigned short> &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(vector<int> &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(vector<unsigned> &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(vector<long> &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(vector<unsigned long> &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(vector<float> &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(vector<double> &v, const char *d, const char *e);
-template const char *Protocol::disassemble_field<>(vector<string> &v, const char *d, const char *e);
-
-unsigned Protocol::PacketDefBase::next_class_id = 1;
+
+Protocol::PacketHeader::PacketHeader(uint16_t t, uint16_t l):
+       type(t),
+       length(l)
+{ }
 
 } // namespace Net
 } // namespace Msp