]> git.tdb.fi Git - libs/net.git/commitdiff
Add protocol framework
authorMikko Rasa <tdb@tdb.fi>
Mon, 2 Mar 2009 09:11:08 +0000 (09:11 +0000)
committerMikko Rasa <tdb@tdb.fi>
Mon, 2 Mar 2009 09:11:08 +0000 (09:11 +0000)
Set event mask in StreamSocket when connect finishes instantly

source/communicator.cpp [new file with mode: 0644]
source/communicator.h [new file with mode: 0644]
source/protocol.cpp [new file with mode: 0644]
source/protocol.h [new file with mode: 0644]
source/receiver.h [new file with mode: 0644]
source/streamsocket.cpp

diff --git a/source/communicator.cpp b/source/communicator.cpp
new file mode 100644 (file)
index 0000000..b0b3b45
--- /dev/null
@@ -0,0 +1,161 @@
+/* $Id$
+
+This file is part of libmspnet
+Copyright © 2009  Mikkosoft Productions, Mikko Rasa
+Distributed under the LGPL
+*/
+
+#include <cstring>
+#include "communicator.h"
+
+namespace {
+
+using namespace Msp::Net;
+
+struct Handshake
+{
+       unsigned hash;
+};
+
+
+class HandshakeProtocol: public Protocol
+{
+public:
+       HandshakeProtocol();
+};
+
+HandshakeProtocol::HandshakeProtocol():
+       Protocol(0x7F00)
+{
+       add<Handshake>()(&Handshake::hash);
+}
+
+
+class HandshakeReceiver: public PacketReceiver<Handshake>
+{
+private:
+       unsigned hash;
+
+public:
+       HandshakeReceiver();
+       unsigned get_hash() const { return hash; }
+       virtual void receive(const Handshake &);
+};
+
+HandshakeReceiver::HandshakeReceiver():
+       hash(0)
+{ }
+
+void HandshakeReceiver::receive(const Handshake &shake)
+{
+       hash=shake.hash;
+}
+
+}
+
+
+namespace Msp {
+namespace Net {
+
+Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r):
+       socket(s),
+       protocol(p),
+       receiver(r),
+       handshake_status(0),
+       buf_size(1024),
+       in_buf(new char[buf_size]),
+       in_begin(in_buf),
+       in_end(in_buf),
+       out_buf(new char[buf_size])
+{
+       socket.signal_data_available.connect(sigc::mem_fun(this, &Communicator::data_available));
+}
+
+Communicator::~Communicator()
+{
+       delete[] in_buf;
+       delete[] out_buf;
+}
+
+void Communicator::initiate_handshake()
+{
+       if(handshake_status!=0)
+               throw InvalidState("Handshaking is already underway or done");
+
+       send_handshake();
+       handshake_status=1;
+}
+
+void Communicator::data_available()
+{
+       in_end+=socket.read(in_end, in_buf+buf_size-in_end);
+       try
+       {
+               bool more=true;
+               while(more)
+               {
+                       if(handshake_status==2)
+                       {
+                               more=receive_packet(protocol, receiver);
+                       }
+                       else
+                       {
+                               HandshakeProtocol hsproto;
+                               HandshakeReceiver hsrecv;
+                               if((more=receive_packet(hsproto, hsrecv)))
+                               {
+                                       if(hsrecv.get_hash()==protocol.get_hash())
+                                       {
+                                               if(handshake_status==0)
+                                                       send_handshake();
+                                               handshake_status=2;
+                                               signal_handshake_done.emit();
+                                       }
+                                       else
+                                               socket.close();
+                               }
+                       }
+               }
+       }
+       catch(...)
+       {
+               socket.close();
+               throw;
+       }
+}
+
+bool Communicator::receive_packet(const Protocol &proto, ReceiverBase &recv)
+{
+       int psz=proto.get_packet_size(in_begin, in_end-in_begin);
+       if(psz && psz<=in_end-in_begin)
+       {
+               char *pkt=in_begin;
+               in_begin+=psz;
+               proto.disassemble(recv, pkt, psz);
+               return true;
+       }
+       else
+       {
+               if(in_end==in_buf+buf_size)
+               {
+                       unsigned used=in_end-in_begin;
+                       memmove(in_buf, in_begin, used);
+                       in_begin=in_buf;
+                       in_end=in_begin+used;
+               }
+               return false;
+       }
+}
+
+void Communicator::send_handshake()
+{
+       Handshake shake;
+       shake.hash=protocol.get_hash();
+
+       HandshakeProtocol hsproto;
+       unsigned size=hsproto.assemble(shake, out_buf, buf_size);
+       socket.write(out_buf, size);
+}
+
+} // namespace Net
+} // namespace Msp
diff --git a/source/communicator.h b/source/communicator.h
new file mode 100644 (file)
index 0000000..144f872
--- /dev/null
@@ -0,0 +1,57 @@
+/* $Id$
+
+This file is part of libmspnet
+Copyright © 2009  Mikkosoft Productions, Mikko Rasa
+Distributed under the LGPL
+*/
+
+#ifndef MSP_NET_COMMUNICATOR_H_
+#define MSP_NET_COMMUNICATOR_H_
+
+#include "protocol.h"
+#include "streamsocket.h"
+
+namespace Msp {
+namespace Net {
+
+class Communicator
+{
+public:
+       sigc::signal<void> signal_handshake_done;
+
+private:
+       StreamSocket &socket;
+       const Protocol &protocol;
+       ReceiverBase &receiver;
+       int handshake_status;
+       unsigned buf_size;
+       char *in_buf;
+       char *in_begin;
+       char *in_end;
+       char *out_buf;
+
+public:
+       Communicator(StreamSocket &, const Protocol &, ReceiverBase &);
+       ~Communicator();
+
+       void initiate_handshake();
+
+       template<typename P>
+       void send(const P &pkt)
+       {
+               if(handshake_status!=2)
+                       throw InvalidState("Handshaking is not done");
+               unsigned size=protocol.assemble(pkt, out_buf, buf_size);
+               socket.write(out_buf, size);
+       }
+
+private:
+       void data_available();
+       bool receive_packet(const Protocol &, ReceiverBase &);
+       void send_handshake();
+};
+
+} // namespace Net
+} // namespace Msp
+
+#endif
diff --git a/source/protocol.cpp b/source/protocol.cpp
new file mode 100644 (file)
index 0000000..1a21c9e
--- /dev/null
@@ -0,0 +1,242 @@
+/* $Id$
+
+This file is part of libmspnet
+Copyright © 2009  Mikkosoft Productions, Mikko Rasa
+Distributed under the LGPL
+*/
+
+#include <cstring>
+#include <string>
+#include <msp/strings/lexicalcast.h>
+#include "protocol.h"
+
+using namespace std;
+
+namespace {
+
+template<typename T>
+class Assembler
+{
+public:
+       static char *assemble(const T &v, char *, char *);
+       static const char *disassemble(T &, const char *, const char *);
+};
+
+template<typename T>
+class Assembler<vector<T> >
+{
+public:
+       static char *assemble(const vector<T> &v, char *, char *);
+       static const char *disassemble(vector<T> &, const char *, const char *);
+};
+
+template<typename T>
+char *Assembler<T>::assemble(const T &v, char *data, char *end)
+{
+       // XXX Assumes little-endian
+       const char *ptr=reinterpret_cast<const char *>(&v)+sizeof(T);
+       for(unsigned i=0; i<sizeof(T); ++i)
+       {
+               if(data==end)
+                       throw Msp::Exception("Out of buffer space");
+               *data++=*--ptr;
+       }
+       return data;
+}
+
+template<>
+char *Assembler<string>::assemble(const string &v, char *data, char *end)
+{
+       data=Assembler<unsigned short>::assemble(v.size(), data, end);
+       if(end-data<static_cast<int>(v.size()))
+               throw Msp::Exception("Out of buffer space");
+       memcpy(data, v.data(), v.size());
+       return data+v.size();
+}
+
+template<typename T>
+char *Assembler<vector<T> >::assemble(const vector<T> &v, char *data, char *end)
+{
+       data=Assembler<unsigned short>::assemble(v.size(), data, end);
+       for(typename vector<T>::const_iterator i=v.begin(); i!=v.end(); ++i)
+               data=Assembler<T>::assemble(*i, data, end);
+       return data;
+}
+
+template<typename T>
+const char *Assembler<T>::disassemble(T &v, const char *data, const char *end)
+{
+       char *ptr=reinterpret_cast<char *>(&v)+sizeof(T);
+       for(unsigned i=0; i<sizeof(T); ++i)
+       {
+               if(data==end)
+                       throw Msp::Exception("Premature end of data");
+               *--ptr=*data++;
+       }
+       return data;
+}
+
+template<>
+const char *Assembler<string>::disassemble(string &v, const char *data, const char *end)
+{
+       unsigned short size;
+       data=Assembler<unsigned short>::disassemble(size, data, end);
+       if(end-data<size)
+               throw Msp::Exception("Premature end of data");
+       v.assign(data, data+size);
+       return data+size;
+}
+
+template<typename T>
+const char *Assembler<vector<T> >::disassemble(vector<T> &v, const char *data, const char *end)
+{
+       /* We assume that the vector is in pristine state - this holds because the
+       only code path leading here is from PacketDef<P>::disassemble, which creates
+       a new packet. */
+       unsigned short size;
+       data=Assembler<unsigned short>::disassemble(size, data, end);
+       for(unsigned i=0; i<size; ++i)
+       {
+               T u;
+               data=Assembler<T>::disassemble(u, data, end);
+               v.push_back(u);
+       }
+       return data;
+}
+
+}
+
+namespace Msp {
+namespace Net {
+
+Protocol::Protocol(unsigned npi):
+       next_packet_id(npi)
+{ }
+
+Protocol::~Protocol()
+{
+       for(map<unsigned, PacketDefBase *>::iterator i=packet_class_defs.begin(); i!=packet_class_defs.end(); ++i)
+               delete i->second;
+}
+
+void Protocol::add_packet(PacketDefBase &pdef)
+{
+       PacketDefBase *&ptr=packet_class_defs[pdef.get_class_id()];
+       if(ptr)
+               delete ptr;
+       ptr=&pdef;
+       packet_id_defs[pdef.get_id()]=&pdef;
+}
+
+const Protocol::PacketDefBase &Protocol::get_packet_by_class(unsigned id) const
+{
+       PacketMap::const_iterator i=packet_class_defs.find(id);
+       if(i==packet_class_defs.end())
+               throw KeyError("Unknown packet class", lexical_cast(id));
+       return *i->second;
+}
+
+const Protocol::PacketDefBase &Protocol::get_packet_by_id(unsigned id) const
+{
+       PacketMap::const_iterator i=packet_id_defs.find(id);
+       if(i==packet_id_defs.end())
+               throw KeyError("Unknown packet ID", lexical_cast(id));
+       return *i->second;
+}
+
+unsigned Protocol::disassemble(ReceiverBase &rcv, const char *data, unsigned size) const
+{
+       const unsigned char *udata=reinterpret_cast<const unsigned char *>(data);
+       unsigned id=(udata[0]<<8)+udata[1];
+       unsigned psz=(udata[2]<<8)+udata[3];
+       if(psz>size)
+               throw InvalidParameterValue("Not enough data for packet");
+       const PacketDefBase &pdef=get_packet_by_id(id);
+       const char *ptr=pdef.disassemble(rcv, data+4, data+psz);
+       return ptr-data;
+}
+
+unsigned Protocol::get_packet_size(const char *data, unsigned size) const
+{
+       if(size<4)
+               return 0;
+       const unsigned char *udata=reinterpret_cast<const unsigned char *>(data);
+       return (udata[2]<<8)+udata[3];
+}
+
+unsigned Protocol::get_hash() const
+{
+       // TODO
+       return 123;
+}
+
+void Protocol::assemble_header(char *buf, unsigned id, unsigned size)
+{
+       buf[0]=(id>>8)&0xFF;
+       buf[1]=id&0xFF;
+       buf[2]=(size>>8)&0xFF;
+       buf[3]=size&0xFF;
+}
+
+template<typename T>
+char *Protocol::assemble_field(const T &v, char *d, char *e)
+{ return Assembler<T>::assemble(v, d, e); }
+
+template char *Protocol::assemble_field<>(const char &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const signed char &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const unsigned char &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const short &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const unsigned short &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const int &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const unsigned &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const long &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const unsigned long &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const float &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const double &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const string &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const vector<char> &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const vector<signed char> &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const vector<unsigned char> &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const vector<short> &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const vector<unsigned short> &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const vector<int> &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const vector<unsigned> &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const vector<long> &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const vector<unsigned long> &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const vector<float> &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const vector<double> &v, char *d, char *e);
+template char *Protocol::assemble_field<>(const vector<string> &v, char *d, char *e);
+
+template<typename T>
+const char *Protocol::disassemble_field(T &v, const char *d, const char *e)
+{ return Assembler<T>::disassemble(v, d, e); }
+
+template const char *Protocol::disassemble_field<>(char &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(signed char &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(unsigned char &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(short &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(unsigned short &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(int &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(unsigned &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(long &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(unsigned long &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(float &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(double &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(string &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(vector<char> &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(vector<signed char> &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(vector<unsigned char> &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(vector<short> &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(vector<unsigned short> &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(vector<int> &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(vector<unsigned> &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(vector<long> &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(vector<unsigned long> &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(vector<float> &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(vector<double> &v, const char *d, const char *e);
+template const char *Protocol::disassemble_field<>(vector<string> &v, const char *d, const char *e);
+
+unsigned Protocol::PacketDefBase::next_class_id=1;
+
+} // namespace Net
+} // namespace Msp
diff --git a/source/protocol.h b/source/protocol.h
new file mode 100644 (file)
index 0000000..082687b
--- /dev/null
@@ -0,0 +1,161 @@
+/* $Id$
+
+This file is part of libmspnet
+Copyright © 2009  Mikkosoft Productions, Mikko Rasa
+Distributed under the LGPL
+*/
+
+#ifndef MSP_NET_PROTOCOL_H_
+#define MSP_NET_PROTOCOL_H_
+
+#include <map>
+#include <vector>
+#include <msp/core/except.h>
+#include "receiver.h"
+
+namespace Msp {
+namespace Net {
+
+class Protocol
+{
+private:
+       class PacketDefBase
+       {
+       protected:
+               unsigned id;
+
+               PacketDefBase(unsigned i): id(i) { }
+       public:
+               virtual ~PacketDefBase() { }
+               virtual unsigned get_class_id() const =0;
+               unsigned get_id() const { return id; }
+               virtual const char *disassemble(ReceiverBase &, const char *, const char *) const =0;
+
+               static unsigned next_class_id;
+       };
+
+       template<typename P>
+       class FieldBase
+       {
+       protected:
+               FieldBase() { }
+       public:
+               virtual ~FieldBase() { }
+               virtual char *assemble(const P &, char *, char *) const =0;
+               virtual const char *disassemble(P &, const char *, const char *) const =0;
+       };
+
+       template<typename P, typename T>
+       class Field: public FieldBase<P>
+       {
+       private:
+               T P::*ptr;
+
+       public:
+               Field(T P::*p): ptr(p) { }
+
+               virtual char *assemble(const P &p, char *d, char *e) const
+               { return assemble_field(p.*ptr, d, e); }
+
+               virtual const char *disassemble(P &p, const char *d, const char *e) const
+               { return disassemble_field(p.*ptr, d, e); }
+       };
+
+protected:
+       template<typename P>
+       class PacketDef: public PacketDefBase
+       {
+       private:
+               std::vector<FieldBase<P> *> fields;
+
+       public:
+               PacketDef(unsigned i): PacketDefBase(i)
+               { if(!class_id) class_id=next_class_id++; }
+
+               virtual unsigned get_class_id() const { return class_id; }
+
+               template<typename T>
+               PacketDef &operator()(T P::*p)
+               { fields.push_back(new Field<P, T>(p)); return *this; }
+
+               char *assemble(const P &p, char *d, char *e) const
+               {
+                       for(typename std::vector<FieldBase<P> *>::const_iterator i=fields.begin(); i!=fields.end(); ++i)
+                               d=(*i)->assemble(p, d, e);
+                       return d;
+               }
+
+               const char *disassemble(ReceiverBase &r, const char *d, const char *e) const
+               {
+                       PacketReceiver<P> *prcv=dynamic_cast<PacketReceiver<P> *>(&r);
+                       if(!prcv)
+                               throw Exception("Packet type not supported by receiver");
+                       P pkt;
+                       for(typename std::vector<FieldBase<P> *>::const_iterator i=fields.begin(); i!=fields.end(); ++i)
+                               d=(*i)->disassemble(pkt, d, e);
+                       prcv->receive(pkt);
+                       return d;
+               }
+
+               static unsigned class_id;
+       };
+
+       typedef std::map<unsigned, PacketDefBase *> PacketMap;
+
+       unsigned next_packet_id;
+       PacketMap packet_class_defs;
+       PacketMap packet_id_defs;
+
+       Protocol(unsigned =1);
+public:
+       ~Protocol();
+
+private:
+       void add_packet(PacketDefBase &);
+
+protected:
+       template<typename P>
+       PacketDef<P> &add()
+       {
+               PacketDef<P> *pdef=new PacketDef<P>(next_packet_id++);
+               add_packet(*pdef);
+               return *pdef;
+       }
+
+       const PacketDefBase &get_packet_by_class(unsigned) const;
+       const PacketDefBase &get_packet_by_id(unsigned) const;
+
+public:
+       template<typename P>
+       unsigned assemble(const P &pkt, char *buf, unsigned size) const
+       {
+               unsigned id=PacketDef<P>::class_id;
+               const PacketDef<P> &pdef=static_cast<const PacketDef<P> &>(get_packet_by_class(id));
+               char *ptr=pdef.assemble(pkt, buf+4, buf+size);
+               assemble_header(buf, pdef.get_id(), (size=ptr-buf));
+               return size;
+       }
+
+       unsigned disassemble(ReceiverBase &, const char *, unsigned) const;
+
+       unsigned get_packet_size(const char *, unsigned) const;
+
+       unsigned get_hash() const;
+
+private:
+       static void assemble_header(char *, unsigned, unsigned);
+
+       template<typename T>
+       static char *assemble_field(const T &, char *, char *);
+
+       template<typename T>
+       static const char *disassemble_field(T &, const char *, const char *);
+};
+
+template<typename P>
+unsigned Protocol::PacketDef<P>::class_id=0;
+
+} // namespace Net
+} // namespace Msp
+
+#endif
diff --git a/source/receiver.h b/source/receiver.h
new file mode 100644 (file)
index 0000000..33917cd
--- /dev/null
@@ -0,0 +1,34 @@
+/* $Id$
+
+This file is part of libmspnet
+Copyright © 2009  Mikkosoft Productions, Mikko Rasa
+Distributed under the LGPL
+*/
+
+#ifndef MSP_NET_RECEIVER_H_
+#define MSP_NET_RECEIVER_H_
+
+namespace Msp {
+namespace Net {
+
+class ReceiverBase
+{
+protected:
+       ReceiverBase() { }
+public:
+       virtual ~ReceiverBase() { }
+};
+
+template<typename P>
+class PacketReceiver: public virtual ReceiverBase
+{
+protected:
+       PacketReceiver() { }
+public:
+       virtual void receive(const P &) =0;
+};
+
+} // namespace Net
+} // namespace Msp
+
+#endif
index 30c7efd77a35d2b257785e7af1ddb6b80f0156cd..48d06a79ccebab9809d3c2b2cf65988c93033a5c 100644 (file)
@@ -1,7 +1,7 @@
 /* $Id$
 
 This file is part of libmspnet
-Copyright © 2008  Mikkosoft Productions, Mikko Rasa
+Copyright © 2008-2009  Mikkosoft Productions, Mikko Rasa
 Distributed under the LGPL
 */
 
@@ -144,6 +144,7 @@ int StreamSocket::connect(const SockAddr &addr)
        if(err==0)
        {
                connected=true;
+               set_events(IO::P_INPUT);
                signal_connect_finished.emit(0);
        }