]> git.tdb.fi Git - libs/net.git/commitdiff
Add a dynamic receiver class for more flexible packet handling master
authorMikko Rasa <tdb@tdb.fi>
Thu, 1 Jun 2023 07:31:43 +0000 (10:31 +0300)
committerMikko Rasa <tdb@tdb.fi>
Thu, 1 Jun 2023 07:31:43 +0000 (10:31 +0300)
It's useful for middleware libraries which may not know all packet types
at compile time, or if different packets need to be routed to different
receivers.

65 files changed:
Build
examples/httpget.cpp
examples/netcat.cpp
source/http/client.cpp
source/http/client.h
source/http/formdata.cpp
source/http/formdata.h
source/http/header.cpp
source/http/header.h
source/http/message.cpp
source/http/message.h
source/http/request.cpp
source/http/request.h
source/http/response.cpp
source/http/response.h
source/http/server.cpp
source/http/server.h
source/http/status.h
source/http/submessage.h
source/http/utils.cpp
source/http/utils.h
source/http/version.cpp
source/http/version.h
source/net/clientsocket.cpp
source/net/clientsocket.h
source/net/communicator.cpp
source/net/communicator.h
source/net/constants.cpp [deleted file]
source/net/constants.h
source/net/datagramsocket.cpp
source/net/datagramsocket.h
source/net/inet.cpp
source/net/inet.h
source/net/inet6.cpp
source/net/inet6.h
source/net/mspnet_api.h [new file with mode: 0644]
source/net/protocol.cpp
source/net/protocol.h
source/net/protocol_impl.h [new file with mode: 0644]
source/net/receiver.cpp [new file with mode: 0644]
source/net/receiver.h
source/net/resolve.cpp
source/net/resolve.h
source/net/serversocket.cpp
source/net/serversocket.h
source/net/sockaddr.cpp
source/net/sockaddr.h
source/net/sockaddr_private.h
source/net/socket.cpp
source/net/socket.h
source/net/socket_private.h
source/net/streamserversocket.cpp
source/net/streamserversocket.h
source/net/streamsocket.cpp
source/net/streamsocket.h
source/net/unix.cpp
source/net/unix.h
source/net/unix/socket.cpp
source/net/unix/unix.cpp
source/net/windows/socket.cpp
source/net/windows/unix.cpp
tests/.gitignore [new file with mode: 0644]
tests/Build [new file with mode: 0644]
tests/http_header.cpp [new file with mode: 0644]
tests/protocol.cpp [new file with mode: 0644]

diff --git a/Build b/Build
index 404ddbf42f6d8eb5a4426d17105c344f3b3a5ab3..44bf9c0bcba88d59f726b1f29b63c23ebe7d88f8 100644 (file)
--- a/Build
+++ b/Build
@@ -1,5 +1,7 @@
 package "mspnet"
 {
+       version "1.0";
+
        require "sigc++-2.0";
        require "mspcore";
        if_arch "windows"
@@ -10,6 +12,11 @@ package "mspnet"
                };
        };
 
+       build_info
+       {
+               standard CXX "c++14";
+       };
+
        library "mspnet"
        {
                source "source/net";
index a4314cef6e3ec12ee83114d36f3e6bd0a823ae71..18ef79933791849bd79dddaf307bf76285fa0e6a 100644 (file)
@@ -28,13 +28,8 @@ HttpGet::HttpGet(int argc, char **argv):
 {
        GetOpt getopt;
        getopt.add_option('v', "verbose", verbose, GetOpt::NO_ARG);
+       getopt.add_argument("url", url, GetOpt::REQUIRED_ARG);
        getopt(argc, argv);
-
-       const vector<string> &args = getopt.get_args();
-       if(args.empty())
-               throw usage_error("No URL");
-
-       url = args.front();
 }
 
 int HttpGet::main()
index 9bbad87b8d0fa944a636cf52e28ba37f67d7e8ee..cd549f4a045e9486f92d9c01da5fe04edda9c16a 100644 (file)
@@ -34,16 +34,15 @@ NetCat::NetCat(int argc, char **argv):
        server_sock(0),
        sock(0)
 {
+       string host_name;
+
        GetOpt getopt;
        getopt.add_option('6', "ipv6",   ipv6,   GetOpt::NO_ARG);
        getopt.add_option('l', "listen", listen, GetOpt::NO_ARG);
+       getopt.add_argument("host", host_name, GetOpt::REQUIRED_ARG);
        getopt(argc, argv);
 
-       const vector<string> &args = getopt.get_args();
-       if(args.empty())
-               throw usage_error("host argument missing");
-
-       RefPtr<Net::SockAddr> addr = Net::resolve(args.front(), (ipv6 ? Net::INET6 : Net::INET));
+       RefPtr<Net::SockAddr> addr = Net::resolve(host_name, (ipv6 ? Net::INET6 : Net::INET));
        if(!listen)
        {
                sock = new Net::StreamSocket(addr->get_family());
index 6f221810c922a78f85289934b36ce4d6b71b2100..1c4c953e5d5d5fb31b3243e13a2fa6d728d12a33 100644 (file)
@@ -1,7 +1,7 @@
-#include <msp/core/refptr.h>
-#include <msp/net/resolve.h>
-#include <msp/time/units.h>
 #include "client.h"
+#include <msp/core/except.h>
+#include <msp/net/resolve.h>
+#include <msp/time/timedelta.h>
 #include "request.h"
 #include "response.h"
 
@@ -10,22 +10,8 @@ using namespace std;
 namespace Msp {
 namespace Http {
 
-Client::Client():
-       sock(0),
-       event_disp(0),
-       resolver(0),
-       resolve_listener(0),
-       resolve_tag(0),
-       user_agent("libmsphttp/0.1"),
-       request(0),
-       response(0)
-{ }
-
 Client::~Client()
 {
-       delete sock;
-       delete request;
-       delete response;
 }
 
 void Client::use_event_dispatcher(IO::EventDispatcher *ed)
@@ -40,30 +26,25 @@ void Client::use_event_dispatcher(IO::EventDispatcher *ed)
 void Client::use_resolver(Net::Resolver *r)
 {
        if(resolver)
-       {
-               delete resolve_listener;
-               resolve_listener = 0;
-       }
+               resolve_listener.reset();
 
        resolver = r;
        if(resolver)
-               resolve_listener = new ResolveListener(*this);
+               resolve_listener = make_unique<ResolveListener>(*this);
 }
 
 void Client::start_request(const Request &r)
 {
        if(request)
-               throw client_busy();
+               throw invalid_state("already processing a request");
 
-       delete sock;
-       sock = 0;
+       sock.reset();
 
-       request = new Request(r);
+       request = make_unique<Request>(r);
        if(!user_agent.empty())
                request->set_header("User-Agent", user_agent);
 
-       delete response;
-       response = 0;
+       response.reset();
        in_buf.clear();
 
        string host = r.get_header("Host");
@@ -73,7 +54,7 @@ void Client::start_request(const Request &r)
                resolve_tag = resolver->resolve(host);
        else
        {
-               RefPtr<Net::SockAddr> addr = Net::resolve(host);
+               unique_ptr<Net::SockAddr> addr(Net::resolve(host));
                address_resolved(resolve_tag, *addr);
        }
 }
@@ -82,7 +63,7 @@ const Response *Client::get_url(const std::string &url)
 {
        start_request(Request::from_url(url));
        wait_response();
-       return response;
+       return response.get();
 }
 
 void Client::tick()
@@ -97,10 +78,8 @@ void Client::tick()
        {
                signal_response_complete.emit(*response);
 
-               delete sock;
-               sock = 0;
-               delete request;
-               request = 0;
+               sock.reset();
+               request.reset();
        }
 }
 
@@ -112,10 +91,8 @@ void Client::wait_response()
 
 void Client::abort()
 {
-       delete sock;
-       sock = 0;
-       delete request;
-       request = 0;
+       sock.reset();
+       request.reset();
 }
 
 void Client::address_resolved(unsigned tag, const Net::SockAddr &addr)
@@ -124,7 +101,7 @@ void Client::address_resolved(unsigned tag, const Net::SockAddr &addr)
                return;
        resolve_tag = 0;
 
-       sock = new Net::StreamSocket(addr.get_family());
+       sock = make_unique<Net::StreamSocket>(addr.get_family());
        sock->set_block(false);
 
        sock->signal_data_available.connect(sigc::mem_fun(this, &Client::data_available));
@@ -141,23 +118,37 @@ void Client::resolve_failed(unsigned tag, const exception &err)
                return;
        resolve_tag = 0;
 
-       signal_socket_error.emit(err);
+       request.reset();
 
-       delete request;
-       request = 0;
+       if(signal_socket_error.empty())
+               throw err;
+       signal_socket_error.emit(err);
 }
 
 void Client::connect_finished(const exception *err)
 {
        if(err)
        {
-               signal_socket_error.emit(*err);
+               request.reset();
 
-               delete request;
-               request = 0;
+               if(signal_socket_error.empty())
+                       throw *err;
+               signal_socket_error.emit(*err);
        }
        else
-               sock->write(request->str());
+       {
+               try
+               {
+                       sock->write(request->str());
+               }
+               catch(const exception &e)
+               {
+                       if(signal_socket_error.empty())
+                               throw;
+                       signal_socket_error.emit(e);
+                       return;
+               }
+       }
 }
 
 void Client::data_available()
@@ -170,6 +161,8 @@ void Client::data_available()
        }
        catch(const exception &e)
        {
+               if(signal_socket_error.empty())
+                       throw;
                signal_socket_error.emit(e);
                return;
        }
@@ -182,7 +175,7 @@ void Client::data_available()
        {
                if(in_buf.find("\r\n\r\n")!=string::npos || in_buf.find("\n\n")!=string::npos)
                {
-                       response = new Response(Response::parse(in_buf));
+                       response = make_unique<Response>(Response::parse(in_buf));
                        response->set_user_data(request->get_user_data());
                        in_buf = string();
                }
@@ -197,8 +190,7 @@ void Client::data_available()
        {
                signal_response_complete.emit(*response);
 
-               delete request;
-               request = 0;
+               request.reset();
        }
 }
 
index ebb73216a4eacfbbdfce4cc511e86f1e771395ef..b1f097d8852a4517c501a7d996ee077a97e753fc 100644 (file)
@@ -1,26 +1,21 @@
 #ifndef MSP_HTTP_CLIENT_H_
 #define MSP_HTTP_CLIENT_H_
 
+#include <memory>
 #include <string>
 #include <sigc++/signal.h>
 #include <msp/io/eventdispatcher.h>
+#include <msp/net/mspnet_api.h>
 #include <msp/net/resolve.h>
 #include <msp/net/streamsocket.h>
 
 namespace Msp {
 namespace Http {
 
-class client_busy: public std::logic_error
-{
-public:
-       client_busy(): std::logic_error(std::string()) { }
-       virtual ~client_busy() throw() { }
-};
-
 class Request;
 class Response;
 
-class Client
+class MSPNET_API Client
 {
 public:
        sigc::signal<void, const Response &> signal_response_complete;
@@ -37,20 +32,17 @@ private:
                void resolve_failed(unsigned, const std::exception &);
        };
 
-       Net::StreamSocket *sock;
-       IO::EventDispatcher *event_disp;
-       Net::Resolver *resolver;
-       ResolveListener *resolve_listener;
-       unsigned resolve_tag;
-       std::string user_agent;
-       Request *request;
-       Response *response;
+       std::unique_ptr<Net::StreamSocket> sock;
+       IO::EventDispatcher *event_disp = nullptr;
+       Net::Resolver *resolver = nullptr;
+       std::unique_ptr<ResolveListener> resolve_listener;
+       unsigned resolve_tag = 0;
+       std::string user_agent = "libmspnet/1.0";
+       std::unique_ptr<Request> request;
+       std::unique_ptr<Response> response;
        std::string in_buf;
 
-       Client(const Client &);
-       Client &operator=(const Client &);
 public:
-       Client();
        ~Client();
 
        void use_event_dispatcher(IO::EventDispatcher *);
@@ -61,8 +53,8 @@ public:
        void tick();
        void wait_response();
        void abort();
-       const Request *get_request() const { return request; }
-       const Response *get_response() const { return response; }
+       const Request *get_request() const { return request.get(); }
+       const Response *get_response() const { return response.get(); }
 private:
        void address_resolved(unsigned, const Net::SockAddr &);
        void resolve_failed(unsigned, const std::exception &);
index 182c5bbbdc7bfc1890cbf6e74380f6ac785902c1..4f2b7b3fa4de926e983f173d8dd7fdc386ce8fab 100644 (file)
@@ -1,9 +1,9 @@
+#include "formdata.h"
 #include <msp/core/maputils.h>
 #include "header.h"
 #include "request.h"
 #include "submessage.h"
 #include "utils.h"
-#include "formdata.h"
 
 using namespace std;
 
@@ -53,15 +53,15 @@ void FormData::parse_multipart(const Request &req, const string &boundary)
 
                if(is_boundary)
                {
-                       /* The CRLF preceding the boundary delimiter is treated as part
-                       of the delimiter as per RFC 2046 */
-                       string::size_type part_end = line_start-1;
-                       if(content[part_end-1]=='\r')
-                               --part_end;
-
                        if(part_start>0)
                        {
-                               SubMessage part = SubMessage::parse(content.substr(part_start, line_start-part_start));
+                               /* The CRLF preceding the boundary delimiter is treated as part
+                               of the delimiter as per RFC 2046 */
+                               string::size_type part_end = line_start-1;
+                               if(content[part_end-1]=='\r')
+                                       --part_end;
+
+                               SubMessage part = SubMessage::parse(content.substr(part_start, part_end-part_start));
                                Header content_disposition(part, "Content-Disposition");
                                const Header::Value &cd_value = content_disposition.values.at(0);
                                if(cd_value.value=="form-data")
@@ -72,10 +72,10 @@ void FormData::parse_multipart(const Request &req, const string &boundary)
                        }
 
                        part_start = lf+1;
-               }
 
-               if(!content.compare(line_start+2+boundary.size(), 2, "--"))
-                       break;
+                       if(!content.compare(line_start+2+boundary.size(), 2, "--"))
+                               break;
+               }
 
                line_start = lf+1;
        }
@@ -83,7 +83,7 @@ void FormData::parse_multipart(const Request &req, const string &boundary)
 
 const string &FormData::get_value(const string &key) const
 {
-       map<string, string>::const_iterator i = fields.find(key);
+       auto i = fields.find(key);
        if(i==fields.end())
        {
                static string dummy;
index 32328910b92c55eab01ce8cc3f1e2108ca56d28c..cf78acd30b387b8e12913826478d4a5700fb1f3a 100644 (file)
@@ -3,13 +3,14 @@
 
 #include <map>
 #include <string>
+#include <msp/net/mspnet_api.h>
 
 namespace Msp {
 namespace Http {
 
 class Request;
 
-class FormData
+class MSPNET_API FormData
 {
 private:
        std::map<std::string, std::string> fields;
index 02c82ca08001c10f898102afd1595a980ee3c8b8..a112e91b6246155139743bc063873aae5e2c41bd 100644 (file)
@@ -1,6 +1,6 @@
+#include "header.h"
 #include <stdexcept>
 #include <msp/strings/utils.h>
-#include "header.h"
 #include "message.h"
 
 using namespace std;
@@ -8,37 +8,74 @@ using namespace std;
 namespace Msp {
 namespace Http {
 
-Header::Header(const Message &msg, const string &n):
+Header::Header(const Message &msg, const string &n, Style s):
        name(n),
+       style(s),
        raw_value(msg.get_header(name))
 {
        parse();
 }
 
-Header::Header(const string &n, const string &rv):
+Header::Header(const string &n, const string &rv, Style s):
        name(n),
+       style(s),
        raw_value(rv)
 {
        parse();
 }
 
+Header::Style Header::get_default_style(const string &name)
+{
+       if(!strcasecmp(name, "content-disposition"))
+               return VALUE_WITH_ATTRIBUTES;
+       else if(!strcasecmp(name, "content-type"))
+               return VALUE_WITH_ATTRIBUTES;
+       else if(!strcasecmp(name, "cookie"))
+               return KEY_VALUE_LIST;
+       else if(!strcasecmp(name, "set-cookie"))
+               return VALUE_WITH_ATTRIBUTES;
+       else
+               return SINGLE_VALUE;
+}
+
 void Header::parse()
 {
-       string::const_iterator i = raw_value.begin();
-       while(i!=raw_value.end())
+       if(style==DEFAULT)
+               style = get_default_style(name);
+
+       if(style==SINGLE_VALUE)
        {
                Value value;
+               value.value = strip(raw_value);
+               values.push_back(value);
+               return;
+       }
 
-               string::const_iterator start = i;
-               for(; (i!=raw_value.end() && *i!=';' && *i!=','); ++i) ;
-               value.value = strip(string(start, i));
-               if(value.value.empty())
-                       throw invalid_argument("Header::parse");
+       char value_sep = (style==VALUE_WITH_ATTRIBUTES ? 0 : ',');
 
-               while(i!=raw_value.end() && *i!=',')
+       auto i = raw_value.cbegin();
+       while(i!=raw_value.cend())
+       {
+               Value value;
+
+               auto start = i;
+               if(style==KEY_VALUE_LIST)
+                       value.value = name;
+               else
+               {
+                       for(; (i!=raw_value.end() && *i!=';' && *i!=value_sep); ++i) ;
+                       value.value = strip(string(start, i));
+                       if(value.value.empty())
+                               throw invalid_argument("Header::parse");
+               }
+
+               while(i!=raw_value.end() && (*i!=',' || style==KEY_VALUE_LIST) && style!=LIST)
                {
-                       start = ++i;
-                       for(; (i!=raw_value.end() && *i!=';' && *i!=',' && *i!='='); ++i) ;
+                       if(*i==';' || *i==',')
+                               ++i;
+
+                       start = i;
+                       for(; (i!=raw_value.end() && *i!=';' && *i!=value_sep && *i!='='); ++i) ;
                        string pname = strip(string(start, i));
                        if(pname.empty())
                                throw invalid_argument("Header::parse");
@@ -47,7 +84,7 @@ void Header::parse()
                        if(i!=raw_value.end() && *i=='=')
                        {
                                for(++i; (i!=raw_value.end() && isspace(*i)); ++i) ;
-                               if(i==raw_value.end() || *i==';' || *i==',')
+                               if(i==raw_value.end() || *i==';' || *i==value_sep)
                                        throw invalid_argument("Header::parse");
 
                                if(*i=='"')
@@ -59,14 +96,14 @@ void Header::parse()
 
                                        pvalue = string(start, i);
 
-                                       for(++i; (i!=raw_value.end() && *i!=';' && *i!=','); ++i)
+                                       for(++i; (i!=raw_value.end() && *i!=';' && *i!=value_sep); ++i)
                                                if(!isspace(*i))
                                                        throw invalid_argument("Header::parse");
                                }
                                else
                                {
                                        start = i;
-                                       for(; (i!=raw_value.end() && *i!=';' && *i!=','); ++i) ;
+                                       for(; (i!=raw_value.end() && *i!=';' && *i!=value_sep); ++i) ;
                                        pvalue = strip(string(start, i));
                                }
                        }
@@ -75,6 +112,9 @@ void Header::parse()
                }
 
                values.push_back(value);
+
+               if(i!=raw_value.end() && (*i==';' || *i==','))
+                       ++i;
        }
 }
 
index 202b844110b1044d360eec00baaa30d6db8f94ad..1beb95f01f5327ff6ed40a0dc2ffde5cda35279d 100644 (file)
@@ -4,14 +4,25 @@
 #include <map>
 #include <string>
 #include <vector>
+#include <msp/net/mspnet_api.h>
 
 namespace Msp {
 namespace Http {
 
 class Message;
 
-struct Header
+struct MSPNET_API Header
 {
+       enum Style
+       {
+               DEFAULT,
+               SINGLE_VALUE,
+               LIST,
+               KEY_VALUE_LIST,
+               VALUE_WITH_ATTRIBUTES,
+               LIST_WITH_ATTRIBUTES
+       };
+
        struct Value
        {
                std::string value;
@@ -19,12 +30,15 @@ struct Header
        };
 
        std::string name;
+       Style style;
        std::string raw_value;
        std::vector<Value> values;
 
-       Header() { }
-       Header(const Message &, const std::string &);
-       Header(const std::string &, const std::string &);
+       Header() = default;
+       Header(const Message &, const std::string &, Style = DEFAULT);
+       Header(const std::string &, const std::string &, Style = DEFAULT);
+
+       static Style get_default_style(const std::string &);
 
        void parse();
 };
index b86f95495eb267639bbc1bdf9ee58662dd198bf1..a6ab4653de9dec58c7c186ea5ecd5c6138545102 100644 (file)
@@ -1,20 +1,14 @@
+#include "message.h"
 #include <cstdlib>
 #include <msp/core/maputils.h>
 #include <msp/strings/format.h>
 #include <msp/strings/utils.h>
-#include "message.h"
 
 using namespace std;
 
 namespace Msp {
 namespace Http {
 
-Message::Message():
-       http_version(0x11),
-       chunk_length(0),
-       complete(false)
-{ }
-
 void Message::set_header(const string &hdr, const string &val)
 {
        headers[normalize_header_name(hdr)] = val;
@@ -48,7 +42,7 @@ unsigned Message::parse_content(const string &d)
        if(complete)
                return 0;
 
-       HeaderMap::const_iterator i = headers.find("Content-Length");
+       auto i = headers.find("Content-Length");
        if(i!=headers.end())
        {
                string::size_type needed = lexical_cast<string::size_type>(i->second)-content.size();
@@ -120,9 +114,9 @@ string Message::str_common() const
 {
        string result;
 
-       for(HeaderMap::const_iterator i=headers.begin(); i!=headers.end(); ++i)
-               if(i->first[0]!='-')
-                       result += format("%s: %s\r\n", i->first, i->second);
+       for(auto &kvp: headers)
+               if(kvp.first[0]!='-')
+                       result += format("%s: %s\r\n", kvp.first, kvp.second);
        result += "\r\n";
        result += content;
 
@@ -133,17 +127,17 @@ string Message::normalize_header_name(const string &hdr) const
 {
        string result = hdr;
        bool upper = true;
-       for(string::iterator i=result.begin(); i!=result.end(); ++i)
+       for(char &c: result)
        {
-               if(*i=='-')
+               if(c=='-')
                        upper = true;
                else if(upper)
                {
-                       *i = toupper(*i);
+                       c = toupper(static_cast<unsigned char>(c));
                        upper = false;
                }
                else
-                       *i = tolower(*i);
+                       c = tolower(static_cast<unsigned char>(c));
        }
        return result;
 }
index e5253606b05e034c0ab9a208a6cd6cd32e39810d..be2224d88af4480704bd1d610e3455d14ad6218b 100644 (file)
@@ -4,26 +4,27 @@
 #include <map>
 #include <string>
 #include <msp/core/variant.h>
+#include <msp/net/mspnet_api.h>
 #include "version.h"
 
 namespace Msp {
 namespace Http {
 
-class Message
+class MSPNET_API Message
 {
 protected:
        typedef std::map<std::string, std::string> HeaderMap;
 
-       Version http_version;
+       Version http_version = 0x11;
        HeaderMap headers;
        std::string content;
-       std::string::size_type chunk_length;
-       bool complete;
+       std::string::size_type chunk_length = 0;
+       bool complete = false;
        Variant user_data;
 
-       Message();
+       Message() = default;
 public:
-       virtual ~Message() { }
+       virtual ~Message() = default;
 
        void set_header(const std::string &, const std::string &);
        bool has_header(const std::string &) const;
index 54b0beeaa9cb4f7e908f995a2d0d328bc1bf91a0..685de332371749d35b7665e7619e399c39726f29 100644 (file)
@@ -1,7 +1,7 @@
+#include "request.h"
 #include <msp/strings/format.h>
 #include <msp/strings/regex.h>
 #include <msp/strings/utils.h>
-#include "request.h"
 #include "utils.h"
 
 using namespace std;
@@ -28,6 +28,8 @@ string Request::str() const
 Request Request::parse(const string &str)
 {
        string::size_type lf = str.find('\n');
+       if(lf==0)
+               throw invalid_argument("Request::parse");
        vector<string> parts = split(str.substr(0, lf-(str[lf-1]=='\r')), ' ', 2);
        if(parts.size()<3)
                throw invalid_argument("Request::parse");
@@ -51,11 +53,7 @@ Request Request::from_url(const string &str)
        string path = urlencode(url.path);
        if(path.empty())
                path = "/";
-       if(!url.query.empty())
-       {
-               path += '?';
-               path += url.query;
-       }
+       append(path, "?", url.query);
 
        Request result("GET", path);
        result.set_header("Host", url.host);
index ec5030e7e543b6e04c004df30f4ddf8937add09d..11af9936c662d0e4af688b15cc3141bf0fa6e095 100644 (file)
@@ -2,12 +2,13 @@
 #define MSP_HTTP_REQUEST_H_
 
 #include <string>
+#include <msp/net/mspnet_api.h>
 #include "message.h"
 
 namespace Msp {
 namespace Http {
 
-class Request: public Message
+class MSPNET_API Request: public Message
 {
 private:
        std::string method;
@@ -17,7 +18,7 @@ public:
        Request(const std::string &, const std::string &);
        const std::string &get_method() const { return method; }
        const std::string &get_path() const { return path; }
-       virtual std::string str() const;
+       std::string str() const override;
 
        static Request parse(const std::string &);
        static Request from_url(const std::string &);
index 739a20f3da2fa2ba6503e13b216ba842b9403ec3..6005d35ae6e1e26bf62837262485302e5d3df117 100644 (file)
@@ -1,6 +1,6 @@
+#include "response.h"
 #include <msp/strings/format.h>
 #include <msp/strings/utils.h>
-#include "response.h"
 
 using namespace std;
 
@@ -24,7 +24,9 @@ Response Response::parse(const string &str)
        Response result;
 
        string::size_type lf = str.find('\n');
-       vector<string> parts = split(str.substr(0, lf), ' ', 2);
+       if(lf==0)
+               throw invalid_argument("Response::parse");
+       vector<string> parts = split(str.substr(0, lf-(str[lf-1]=='\r')), ' ', 2);
        if(parts.size()<2)
                throw invalid_argument("Response::parse");
 
index bb471f0de4fe80879ea7da223152b8ee4e671486..8644b7228a11b1b3ce5652c7998ec474a28109ba 100644 (file)
@@ -1,22 +1,23 @@
 #ifndef MSP_HTTP_RESPONSE_H_
 #define MSP_HTTP_RESPONSE_H_
 
+#include <msp/net/mspnet_api.h>
 #include "message.h"
 #include "status.h"
 
 namespace Msp {
 namespace Http {
 
-class Response: public Message
+class MSPNET_API Response: public Message
 {
 private:
        Status status;
 
-       Response() { }
+       Response() = default;
 public:
        Response(Status);
        Status get_status() const { return status; }
-       virtual std::string str() const;
+       std::string str() const override;
 
        static Response parse(const std::string &);
 };
index f0474fd18e993c57de16b1a3f51c07dc260f069f..04bae0b4d8c24f7f9e72177c923c19fc2234083c 100644 (file)
@@ -1,6 +1,8 @@
+#include "server.h"
 #include <exception>
+#include <typeinfo>
 #include <msp/core/maputils.h>
-#include <msp/core/refptr.h>
+#include <msp/debug/demangle.h>
 #include <msp/net/inet.h>
 #include <msp/net/resolve.h>
 #include <msp/net/streamsocket.h>
 #include <msp/strings/utils.h>
 #include "request.h"
 #include "response.h"
-#include "server.h"
 
 using namespace std;
 
 namespace Msp {
 namespace Http {
 
+Server::Server():
+       sock(Net::INET6)
+{ }
+
 Server::Server(unsigned port):
-       sock(Net::INET),
-       event_disp(0)
+       sock(Net::INET6)
 {
-       sock.signal_data_available.connect(sigc::mem_fun(this, &Server::data_available));
-       RefPtr<Net::SockAddr> addr = Net::resolve("*", format("%d", port));
-       sock.listen(*addr, 8);
+       listen(port);
 }
 
 // Avoid emitting sigc::signal destructor in files including server.h
@@ -29,6 +31,13 @@ Server::~Server()
 {
 }
 
+void Server::listen(unsigned port)
+{
+       unique_ptr<Net::SockAddr> addr(Net::resolve("*", format("%d", port), Net::INET6));
+       sock.listen(*addr, 8);
+       sock.signal_data_available.connect(sigc::mem_fun(this, &Server::data_available));
+}
+
 unsigned Server::get_port() const
 {
        const Net::SockAddr &addr = sock.get_local_address();
@@ -42,15 +51,15 @@ void Server::use_event_dispatcher(IO::EventDispatcher *ed)
        if(event_disp)
        {
                event_disp->remove(sock);
-               for(list<Client>::iterator i=clients.begin(); i!=clients.end(); ++i)
-                       event_disp->remove(*i->sock);
+               for(Client &c: clients)
+                       event_disp->remove(*c.sock);
        }
        event_disp = ed;
        if(event_disp)
        {
                event_disp->add(sock);
-               for(list<Client>::iterator i=clients.begin(); i!=clients.end(); ++i)
-                       event_disp->add(*i->sock);
+               for(Client &c: clients)
+                       event_disp->add(*c.sock);
        }
 }
 
@@ -71,66 +80,98 @@ void Server::cancel_keepalive(Response &resp)
        get_client_by_response(resp).keepalive = false;
 }
 
+void Server::close_connections(const Time::TimeDelta &timeout)
+{
+       IO::Poller poller;
+       for(Client &c: clients)
+       {
+               c.sock->shutdown(IO::M_WRITE);
+               poller.set_object(*c.sock, IO::P_INPUT);
+       }
+
+       while(!clients.empty() && poller.poll(timeout))
+       {
+               for(const IO::Poller::PolledObject &p: poller.get_result())
+                       for(auto j=clients.begin(); j!=clients.end(); ++j)
+                               if(j->sock.get()==p.object)
+                               {
+                                       poller.set_object(*j->sock, IO::P_NONE);
+                                       clients.erase(j);
+                                       break;
+                               }
+       }
+}
+
 void Server::data_available()
 {
-       Net::StreamSocket *csock = sock.accept();
-       clients.push_back(Client(csock));
-       csock->signal_data_available.connect(sigc::bind(sigc::mem_fun(this, &Server::client_data_available), sigc::ref(clients.back())));
-       csock->signal_end_of_file.connect(sigc::bind(sigc::mem_fun(this, &Server::client_end_of_file), sigc::ref(clients.back())));
+       unique_ptr<Net::StreamSocket> csock(sock.accept());
+       clients.emplace_back(move(csock));
+       Client &cl = clients.back();
+       cl.sock->signal_data_available.connect(sigc::bind(sigc::mem_fun(this, &Server::client_data_available), sigc::ref(clients.back())));
+       cl.sock->signal_end_of_file.connect(sigc::bind(sigc::mem_fun(this, &Server::client_end_of_file), sigc::ref(clients.back())));
        if(event_disp)
-               event_disp->add(*csock);
+               event_disp->add(*cl.sock);
 }
 
 void Server::client_data_available(Client &cl)
 {
-       for(list<Client>::iterator i=clients.begin(); i!=clients.end(); ++i)
+       for(auto i=clients.begin(); i!=clients.end(); ++i)
                if(i->stale && &*i!=&cl)
                {
                        clients.erase(i);
                        break;
                }
 
-       char rbuf[4096];
-       unsigned len = cl.sock->read(rbuf, sizeof(rbuf));
-       if(cl.stale)
+       try
+       {
+               char rbuf[4096];
+               unsigned len = cl.sock->read(rbuf, sizeof(rbuf));
+               if(cl.stale)
+                       return;
+               cl.in_buf.append(rbuf, len);
+       }
+       catch(const exception &)
+       {
+               cl.stale = true;
                return;
-       cl.in_buf.append(rbuf, len);
+       }
 
-       RefPtr<Response> response;
+       unique_ptr<Response> response;
        if(!cl.request)
        {
                if(cl.in_buf.find("\r\n\r\n")!=string::npos || cl.in_buf.find("\n\n")!=string::npos)
                {
                        try
                        {
-                               cl.request = new Request(Request::parse(cl.in_buf));
+                               cl.request = make_unique<Request>(Request::parse(cl.in_buf));
 
                                string addr_str = cl.sock->get_peer_address().str();
-                               string::size_type colon = addr_str.find(':');
+                               string::size_type colon = addr_str.find(':', (addr_str[0]=='[' ? addr_str.find(']')+1 : 0));
                                cl.request->set_header("-Client-Host", addr_str.substr(0, colon));
 
                                if(cl.request->get_method()!="GET" && cl.request->get_method()!="POST")
                                {
-                                       response = new Response(NOT_IMPLEMENTED);
+                                       response = make_unique<Response>(NOT_IMPLEMENTED);
                                        response->add_content("Method not implemented\n");
                                }
                                else if(cl.request->get_path()[0]!='/')
                                {
-                                       response = new Response(BAD_REQUEST);
+                                       response = make_unique<Response>(BAD_REQUEST);
                                        response->add_content("Path must be absolute\n");
                                }
                        }
                        catch(const exception &e)
                        {
-                               response = new Response(BAD_REQUEST);
-                               response->add_content(e.what());
+                               response = make_unique<Response>(BAD_REQUEST);
+                               response->add_content(format("An error occurred while parsing request headers:\ntype: %s\nwhat: %s",
+                                       Debug::demangle(typeid(e).name()), e.what()));
                        }
                        cl.in_buf = string();
                }
        }
        else
        {
-               len = cl.request->parse_content(cl.in_buf);
+               unsigned len = cl.request->parse_content(cl.in_buf);
                cl.in_buf.erase(0, len);
        }
 
@@ -140,31 +181,30 @@ void Server::client_data_available(Client &cl)
                if(cl.request->has_header("Connection"))
                        cl.keepalive = !strcasecmp(cl.request->get_header("Connection"), "keep-alive");
 
-               response = new Response(NONE);
+               response = make_unique<Response>(NONE);
                try
                {
-                       cl.response = response.get();
-                       responses[cl.response] = &cl;
-                       signal_request.emit(*cl.request, *response);
-                       if(cl.async)
-                               response.release();
-                       else
+                       cl.response = move(response);
+                       responses[cl.response.get()] = &cl;
+                       signal_request.emit(*cl.request, *cl.response);
+                       if(!cl.async)
                        {
-                               responses.erase(cl.response);
-                               cl.response = 0;
+                               responses.erase(cl.response.get());
+                               response = move(cl.response);
                                if(response->get_status()==NONE)
                                {
-                                       response = new Response(NOT_FOUND);
+                                       response = make_unique<Response>(NOT_FOUND);
                                        response->add_content("The requested resource was not found\n");
                                }
                        }
                }
                catch(const exception &e)
                {
-                       responses.erase(cl.response);
-                       cl.response = 0;
-                       response = new Response(INTERNAL_ERROR);
-                       response->add_content(e.what());
+                       responses.erase(cl.response.get());
+                       cl.response.reset();
+                       response = make_unique<Response>(INTERNAL_ERROR);
+                       response->add_content(format("An error occurred while processing the request:\ntype: %s\nwhat: %s",
+                               Debug::demangle(typeid(e).name()), e.what()));
                }
        }
 
@@ -176,14 +216,22 @@ void Server::send_response(Client &cl, Response &resp)
 {
        if(cl.keepalive)
                resp.set_header("Connection", "keep-alive");
-       cl.sock->write(resp.str());
+
+       try
+       {
+               cl.sock->write(resp.str());
+       }
+       catch(const exception &)
+       {
+               cl.stale = true;
+               return;
+       }
+
        cl.async = false;
        if(cl.keepalive)
        {
-               delete cl.request;
-               cl.request = 0;
-               delete cl.response;
-               cl.response = 0;
+               cl.request.reset();
+               cl.response.reset();
        }
        else
        {
@@ -203,20 +251,9 @@ Server::Client &Server::get_client_by_response(Response &resp)
 }
 
 
-Server::Client::Client(RefPtr<Net::StreamSocket> s):
-       sock(s),
-       request(0),
-       response(0),
-       keepalive(false),
-       async(false),
-       stale(false)
+Server::Client::Client(unique_ptr<Net::StreamSocket> s):
+       sock(move(s))
 { }
 
-Server::Client::~Client()
-{
-       delete request;
-       delete response;
-}
-
 } // namespace Http
 } // namespace Msp
index 3679013dda51c22c9f76438298b870e069bcd6e1..8665fa3c88f5c0777f25e11ce578f23d2a8d67ff 100644 (file)
@@ -1,9 +1,10 @@
 #ifndef MSP_HTTP_SERVER_H_
 #define MSP_HTTP_SERVER_H_
 
-#include <msp/core/refptr.h>
 #include <msp/io/eventdispatcher.h>
+#include <msp/net/mspnet_api.h>
 #include <msp/net/streamserversocket.h>
+#include <msp/time/timedelta.h>
 
 namespace Msp {
 namespace Http {
@@ -11,7 +12,7 @@ namespace Http {
 class Request;
 class Response;
 
-class Server
+class MSPNET_API Server
 {
 public:
        sigc::signal<void, const Request &, Response &> signal_request;
@@ -19,32 +20,34 @@ public:
 private:
        struct Client
        {
-               RefPtr<Net::StreamSocket> sock;
+               std::unique_ptr<Net::StreamSocket> sock;
                std::string in_buf;
-               Request *request;
-               Response *response;
-               bool keepalive;
-               bool async;
-               bool stale;
-
-               Client(RefPtr<Net::StreamSocket>);
-               ~Client();
+               std::unique_ptr<Request> request;
+               std::unique_ptr<Response> response;
+               bool keepalive = false;
+               bool async = false;
+               bool stale = false;
+
+               Client(std::unique_ptr<Net::StreamSocket>);
        };
 
        Net::StreamServerSocket sock;
        std::list<Client> clients;
        std::map<Response *, Client *> responses;
-       IO::EventDispatcher *event_disp;
+       IO::EventDispatcher *event_disp = nullptr;
 
 public:
+       Server();
        Server(unsigned);
        ~Server();
 
+       void listen(unsigned);
        unsigned get_port() const;
        void use_event_dispatcher(IO::EventDispatcher *);
        void delay_response(Response &);
        void submit_response(Response &);
        void cancel_keepalive(Response &);
+       void close_connections(const Time::TimeDelta &);
 private:
        void data_available();
        void client_data_available(Client &);
index 5e7d54b19fc994b7331dfbad7095e4e2adef463d..5f4656e26fc706b0cd972332178b8becc630808b 100644 (file)
@@ -2,6 +2,7 @@
 #define MSP_HTTPSERVER_STATUS_H_
 
 #include <ostream>
+#include <msp/net/mspnet_api.h>
 
 namespace Msp {
 namespace Http {
@@ -20,7 +21,7 @@ enum Status
        NOT_IMPLEMENTED = 501
 };
 
-extern std::ostream &operator<<(std::ostream &, Status);
+MSPNET_API std::ostream &operator<<(std::ostream &, Status);
 
 } // namespace Http
 } // namespace Msp
index f60276a9267c7c6997215e7a01db49e01cf5a137..7ef9e376ddb5334f6e39daa69572c2852b29c2ab 100644 (file)
@@ -9,10 +9,10 @@ namespace Http {
 class SubMessage: public Message
 {
 private:
-       SubMessage() { }
+       SubMessage() = default;
 
 public:
-       virtual std::string str() const;
+       std::string str() const override;
 
        static SubMessage parse(const std::string &);
 };
index 41fabeff80cb83c83e1ffd430b328073a27e1eb0..7e1037e693d42094679d59ec905375eb6cd81ffe 100644 (file)
@@ -1,8 +1,8 @@
+#include "utils.h"
 #include <algorithm>
 #include <msp/strings/format.h>
 #include <msp/strings/regex.h>
 #include <msp/strings/utils.h>
-#include "utils.h"
 
 using namespace std;
 
@@ -31,12 +31,13 @@ namespace Http {
 string urlencode(const string &str, EncodeLevel level)
 {
        string result;
-       for(string::const_iterator i=str.begin(); i!=str.end(); ++i)
+       result.reserve(str.size());
+       for(char c: str)
        {
-               if(is_reserved(*i, level))
-                       result += format("%%%02X", *i);
+               if(is_reserved(c, level))
+                       result += format("%%%02X", c);
                else
-                       result += *i;
+                       result += c;
        }
        return result;
 }
@@ -44,14 +45,15 @@ string urlencode(const string &str, EncodeLevel level)
 string urlencode_plus(const string &str, EncodeLevel level)
 {
        string result;
-       for(string::const_iterator i=str.begin(); i!=str.end(); ++i)
+       result.reserve(str.size());
+       for(char c: str)
        {
-               if(*i==' ')
+               if(c==' ')
                        result += '+';
-               else if(is_reserved(*i, level))
-                       result += format("%%%02X", *i);
+               else if(is_reserved(c, level))
+                       result += format("%%%02X", c);
                else
-                       result += *i;
+                       result += c;
        }
        return result;
 }
@@ -79,7 +81,7 @@ string urldecode(const string &str)
 
 Url parse_url(const string &str)
 {
-       static Regex r_url("^(([a-z]+)://)?([a-zA-Z0-9-]+(\\.[a-zA-Z0-9-]+)*(:[0-9]+)?)?(/[^?#]*)?(\\?([^#]+))?(#(.*))?$");
+       static Regex r_url("^(([a-z]+)://)?([a-zA-Z0-9-]+(\\.[a-zA-Z0-9-]+)*(:[0-9]+)?)?(/[^?#]*)?(\\?([^#]*))?(#(.*))?$");
        if(RegMatch m = r_url.match(str))
        {
                Url url;
@@ -104,29 +106,20 @@ string build_url(const Url &url)
                str += url.scheme+"://";
        str += url.host;
        str += urlencode(url.path);
-       if(!url.query.empty())
-       {
-               str += '?';
-               str += url.query;
-       }
-       if(!url.fragment.empty())
-       {
-               str += '#';
-               str += url.fragment;
-       }
+       append(str, "?", url.query);
+       append(str, "#", url.fragment);
        return str;
 }
 
 Query parse_query(const std::string &str)
 {
-       vector<string> parts = split(str, '&');
        Query query;
-       for(vector<string>::const_iterator i=parts.begin(); i!=parts.end(); ++i)
+       for(const string &p: split(str, '&'))
        {
-               string::size_type equals = i->find('=');
-               string &value = query[urldecode(i->substr(0, equals))];
+               string::size_type equals = p.find('=');
+               string &value = query[urldecode(p.substr(0, equals))];
                if(equals!=string::npos)
-                       value = urldecode(i->substr(equals+1));
+                       value = urldecode(p.substr(equals+1));
        }
        return query;
 }
@@ -134,13 +127,11 @@ Query parse_query(const std::string &str)
 string build_query(const Query &query)
 {
        string str;
-       for(Query::const_iterator i=query.begin(); i!=query.end(); ++i)
+       for(const auto &kvp: query)
        {
-               if(i!=query.begin())
-                       str += '&';
-               str += urlencode_plus(i->first);
+               append(str, "&", urlencode_plus(kvp.first));
                str += '=';
-               str += urlencode_plus(i->second);
+               str += urlencode_plus(kvp.second);
        }
        return str;
 }
index a25a748a6aa7464409c20fb1c67702b9b3e367a1..2920d2d7ea5e86885d61f952af8d768c4d3534f2 100644 (file)
@@ -3,6 +3,7 @@
 
 #include <map>
 #include <string>
+#include <msp/net/mspnet_api.h>
 
 namespace Msp {
 namespace Http {
@@ -25,13 +26,13 @@ struct Url
 
 typedef std::map<std::string, std::string> Query;
 
-std::string urlencode(const std::string &, EncodeLevel =SAFE);
-std::string urlencode_plus(const std::string &, EncodeLevel =SAFE);
-std::string urldecode(const std::string &);
-Url parse_url(const std::string &);
-std::string build_url(const Url &);
-Query parse_query(const std::string &);
-std::string build_query(const Query &);
+MSPNET_API std::string urlencode(const std::string &, EncodeLevel =SAFE);
+MSPNET_API std::string urlencode_plus(const std::string &, EncodeLevel =SAFE);
+MSPNET_API std::string urldecode(const std::string &);
+MSPNET_API Url parse_url(const std::string &);
+MSPNET_API std::string build_url(const Url &);
+MSPNET_API Query parse_query(const std::string &);
+MSPNET_API std::string build_query(const Query &);
 
 } // namespace Http
 } // namespace Msp
index c9fd1c91bd83c4d352c36fdbce9e7c8770c18fb9..d6f846c68c80fc30b118f96a6e4fa025aa2ea5db 100644 (file)
@@ -1,7 +1,7 @@
+#include "version.h"
 #include <msp/strings/format.h>
 #include <msp/strings/lexicalcast.h>
 #include <msp/strings/regex.h>
-#include "version.h"
 
 using namespace std;
 
index 19a50b91097ea545de16deda5df671bfe989cd63..665e619ee4184b51277155f886d9bdea85b16fe2 100644 (file)
@@ -2,14 +2,15 @@
 #define MSP_HTTP_MISC_H_
 
 #include <string>
+#include <msp/net/mspnet_api.h>
 
 namespace Msp {
 namespace Http {
 
 typedef unsigned Version;
 
-Version parse_version(const std::string &);
-std::string version_str(Version);
+MSPNET_API Version parse_version(const std::string &);
+MSPNET_API std::string version_str(Version);
 
 } // namespace Http
 } // namespace Msp
index 04d07000a7bb1286c105a1697fdc71f0bc880901..51a8dfa01b7913485397438ddcc3c24b5e9ae35b 100644 (file)
@@ -1,21 +1,19 @@
 #include "platform_api.h"
-#include <msp/core/systemerror.h>
 #include "clientsocket.h"
+#include <msp/core/systemerror.h>
 #include "socket_private.h"
 
+using namespace std;
+
 namespace Msp {
 namespace Net {
 
 ClientSocket::ClientSocket(Family af, int type, int proto):
-       Socket(af, type, proto),
-       connecting(false),
-       connected(false),
-       peer_addr(0)
+       Socket(af, type, proto)
 { }
 
 ClientSocket::ClientSocket(const Private &p, const SockAddr &paddr):
        Socket(p),
-       connecting(false),
        connected(true),
        peer_addr(paddr.copy())
 { }
@@ -23,8 +21,6 @@ ClientSocket::ClientSocket(const Private &p, const SockAddr &paddr):
 ClientSocket::~ClientSocket()
 {
        signal_flush_required.emit();
-
-       delete peer_addr;
 }
 
 void ClientSocket::shutdown(IO::Mode m)
@@ -55,12 +51,12 @@ void ClientSocket::shutdown(IO::Mode m)
 
 const SockAddr &ClientSocket::get_peer_address() const
 {
-       if(peer_addr==0)
+       if(!peer_addr)
                throw bad_socket_state("not connected");
        return *peer_addr;
 }
 
-unsigned ClientSocket::do_write(const char *buf, unsigned size)
+size_t ClientSocket::do_write(const char *buf, size_t size)
 {
        check_access(IO::M_WRITE);
        if(!connected)
@@ -72,24 +68,28 @@ unsigned ClientSocket::do_write(const char *buf, unsigned size)
        return check_sys_error(::send(priv->handle, buf, size, 0), "send");
 }
 
-unsigned ClientSocket::do_read(char *buf, unsigned size)
+size_t ClientSocket::do_read(char *buf, size_t size)
 {
        check_access(IO::M_READ);
        if(!connected)
                throw bad_socket_state("not connected");
 
+       // XXX This breaks level-triggered semantics on Windows
        if(size==0)
                return 0;
 
-       unsigned ret = check_sys_error(::recv(priv->handle, buf, size, 0), "recv");
-       if(ret==0 && !eof_flag)
+       make_signed<size_t>::type ret = ::recv(priv->handle, buf, size, 0);
+       if(ret==0)
        {
-               eof_flag = true;
-               signal_end_of_file.emit();
-               set_socket_events(S_NONE);
+               if(!eof_flag)
+               {
+                       set_socket_events(S_NONE);
+                       set_eof();
+               }
+               return 0;
        }
 
-       return ret;
+       return check_sys_error(ret, "recv");
 }
 
 } // namespace Net
index 0b570cd7afa634357274d620a5b8980d22f025f9..80d927a330c63847c8a9cdfe88643681dad75680 100644 (file)
@@ -1,6 +1,7 @@
 #ifndef MSP_NET_CLIENTSOCKET_H_
 #define MSP_NET_CLIENTSOCKET_H_
 
+#include "mspnet_api.h"
 #include "socket.h"
 
 namespace Msp {
@@ -9,16 +10,16 @@ namespace Net {
 /**
 ClientSockets are used for sending and receiving data over the network.
 */
-class ClientSocket: public Socket
+class MSPNET_API ClientSocket: public Socket
 {
 public:
        /** Emitted when the socket finishes connecting. */
        sigc::signal<void, const std::exception *> signal_connect_finished;
 
 protected:
-       bool connecting;
-       bool connected;
-       SockAddr *peer_addr;
+       bool connecting = false;
+       bool connected = false;
+       std::unique_ptr<SockAddr> peer_addr;
 
        ClientSocket(const Private &, const SockAddr &);
        ClientSocket(Family, int, int);
@@ -42,8 +43,8 @@ public:
 
        const SockAddr &get_peer_address() const;
 protected:
-       virtual unsigned do_write(const char *, unsigned);
-       virtual unsigned do_read(char *, unsigned);
+       std::size_t do_write(const char *, std::size_t) override;
+       std::size_t do_read(char *, std::size_t) override;
 };
 
 } // namespace Net
index 3d6353a2158fb2d62fbe6df164085a1b2ea7a417..73b10c2881b71c7c5c51302b3d97c521b0c4a2ad 100644 (file)
@@ -1,5 +1,5 @@
-#include <cstring>
 #include "communicator.h"
+#include <cstring>
 #include "streamsocket.h"
 
 using namespace std;
@@ -8,9 +8,17 @@ namespace {
 
 using namespace Msp::Net;
 
-struct Handshake
+// Sent when a protocol is added locally but isn't known to the other end yet
+struct PrepareProtocol
+{
+       uint64_t hash;
+       uint16_t base;
+};
+
+// Sent to confirm that a protocol is known to both peers
+struct AcceptProtocol
 {
-       Msp::UInt64 hash;
+       uint64_t hash;
 };
 
 
@@ -20,75 +28,107 @@ public:
        HandshakeProtocol();
 };
 
-HandshakeProtocol::HandshakeProtocol():
-       Protocol(0x7F00)
+HandshakeProtocol::HandshakeProtocol()
 {
-       add<Handshake>()(&Handshake::hash);
+       add<PrepareProtocol>(&PrepareProtocol::hash, &PrepareProtocol::base);
+       add<AcceptProtocol>(&AcceptProtocol::hash);
 }
 
+}
 
-class HandshakeReceiver: public PacketReceiver<Handshake>
-{
-private:
-       Msp::UInt64 hash;
-
-public:
-       HandshakeReceiver();
-       Msp::UInt64 get_hash() const { return hash; }
-       virtual void receive(const Handshake &);
-};
 
-HandshakeReceiver::HandshakeReceiver():
-       hash(0)
-{ }
+namespace Msp {
+namespace Net {
 
-void HandshakeReceiver::receive(const Handshake &shake)
+struct Communicator::Handshake: public PacketReceiver<PrepareProtocol>,
+       public PacketReceiver<AcceptProtocol>
 {
-       hash = shake.hash;
-}
+       Communicator &communicator;
+       HandshakeProtocol protocol;
 
-}
+       Handshake(Communicator &c): communicator(c) { }
 
+       void receive(const PrepareProtocol &) override;
+       void receive(const AcceptProtocol &) override;
+};
 
-namespace Msp {
-namespace Net {
 
-Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r):
+Communicator::Communicator(StreamSocket &s):
        socket(s),
-       protocol(p),
-       receiver(r),
-       handshake_status(0),
-       buf_size(65536),
+       handshake(new Handshake(*this)),
        in_buf(new char[buf_size]),
        in_begin(in_buf),
        in_end(in_buf),
-       out_buf(new char[buf_size]),
-       good(true)
+       out_buf(new char[buf_size])
 {
        socket.signal_data_available.connect(sigc::mem_fun(this, &Communicator::data_available));
+
+       protocols.emplace_back(0, ref(handshake->protocol), ref(*handshake));
+       if(socket.is_connected())
+               prepare_protocol(protocols.back());
+       else
+               socket.signal_connect_finished.connect(sigc::mem_fun(this, &Communicator::connect_finished));
+}
+
+Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r):
+       Communicator(s)
+{
+       add_protocol(p, r);
 }
 
 Communicator::~Communicator()
 {
        delete[] in_buf;
        delete[] out_buf;
+       delete handshake;
 }
 
-void Communicator::initiate_handshake()
+void Communicator::add_protocol(const Protocol &proto, ReceiverBase &recv)
 {
-       if(handshake_status!=0)
-               throw sequence_error("handshaking already done");
+       if(!good)
+               throw sequence_error("connection aborted");
+
+       unsigned max_id = proto.get_max_packet_id();
+       if(!max_id)
+               throw invalid_argument("Communicator::add_protocol");
+
+       uint64_t hash = proto.get_hash();
+       auto i = find_member(protocols, hash, &ActiveProtocol::hash);
+       if(i==protocols.end())
+       {
+               const ActiveProtocol &last = protocols.back();
+               if(!last.protocol)
+                       throw sequence_error("previous protocol is incomplete");
+               unsigned base = last.base;
+               base += (last.protocol->get_max_packet_id()+0xFF)&~0xFF;
+
+               if(base+max_id>std::numeric_limits<std::uint16_t>::max())
+                       throw invalid_state("Communicator::add_protocol");
+
+               protocols.emplace_back(base, proto, recv);
+
+               if(socket.is_connected() && protocols.front().ready)
+                       prepare_protocol(protocols.back());
+       }
+       else if(!i->protocol)
+       {
+               i->protocol = &proto;
+               i->last = i->base+max_id;
+               i->receiver = &recv;
+               accept_protocol(*i);
+       }
+}
 
-       send_handshake();
-       handshake_status = 1;
+bool Communicator::is_protocol_ready(const Protocol &proto) const
+{
+       auto i = find_member(protocols, &proto, &ActiveProtocol::protocol);
+       return (i!=protocols.end() && i->ready);
 }
 
-void Communicator::send_data(unsigned size)
+void Communicator::send_data(size_t size)
 {
        if(!good)
                throw sequence_error("connection aborted");
-       if(handshake_status!=2)
-               throw sequence_error("handshake incomplete");
 
        try
        {
@@ -103,6 +143,14 @@ void Communicator::send_data(unsigned size)
        }
 }
 
+void Communicator::connect_finished(const exception *exc)
+{
+       if(exc)
+               good = false;
+       else
+               prepare_protocol(protocols.front());
+}
+
 void Communicator::data_available()
 {
        if(!good)
@@ -111,31 +159,7 @@ void Communicator::data_available()
        try
        {
                in_end += socket.read(in_end, in_buf+buf_size-in_end);
-
-               bool more = true;
-               while(more)
-               {
-                       if(handshake_status==2)
-                               more = receive_packet(protocol, receiver);
-                       else
-                       {
-                               HandshakeProtocol hsproto;
-                               HandshakeReceiver hsrecv;
-                               if((more = receive_packet(hsproto, hsrecv)))
-                               {
-                                       if(handshake_status==0)
-                                               send_handshake();
-
-                                       if(hsrecv.get_hash()==protocol.get_hash())
-                                       {
-                                               handshake_status = 2;
-                                               signal_handshake_done.emit();
-                                       }
-                                       else
-                                               throw incompatible_protocol("hash mismatch");
-                               }
-                       }
-               }
+               while(receive_packet()) ;
        }
        catch(const exception &e)
        {
@@ -146,37 +170,98 @@ void Communicator::data_available()
        }
 }
 
-bool Communicator::receive_packet(const Protocol &proto, ReceiverBase &recv)
+bool Communicator::receive_packet()
 {
-       int psz = proto.get_packet_size(in_begin, in_end-in_begin);
-       if(psz && psz<=in_end-in_begin)
+       Protocol::PacketHeader header;
+       size_t available = in_end-in_begin;
+       if(handshake->protocol.get_packet_header(header, in_begin, available) && header.length<=available)
        {
+               auto i = lower_bound_member(protocols, header.type, &ActiveProtocol::last);
+               if(i==protocols.end() || header.type<i->base || header.type>i->last)
+                       throw key_error(header.type);
+
                char *pkt = in_begin;
-               in_begin += psz;
-               proto.dispatch(recv, pkt, psz);
+               in_begin += header.length;
+               i->protocol->dispatch(*i->receiver, pkt, header.length, i->base);
                return true;
        }
        else
        {
                if(in_end==in_buf+buf_size)
                {
-                       unsigned used = in_end-in_begin;
-                       memmove(in_buf, in_begin, used);
+                       memmove(in_buf, in_begin, available);
                        in_begin = in_buf;
-                       in_end = in_begin+used;
+                       in_end = in_begin+available;
                }
                return false;
        }
 }
 
-void Communicator::send_handshake()
+void Communicator::prepare_protocol(const ActiveProtocol &proto)
 {
-       Handshake shake;
-       shake.hash = protocol.get_hash();
+       PrepareProtocol prepare;
+       prepare.hash = proto.hash;
+       prepare.base = proto.base;
+       /* Use send_data() directly because this function is called to prepare the
+       handshake protocol too and send() would fail readiness check. */
+       send_data(handshake->protocol.serialize(prepare, out_buf, buf_size));
+}
+
+void Communicator::accept_protocol(ActiveProtocol &proto)
+{
+       proto.accepted = true;
+
+       AcceptProtocol accept;
+       accept.hash = proto.hash;
+       send_data(handshake->protocol.serialize(accept, out_buf, buf_size));
+}
+
 
-       HandshakeProtocol hsproto;
-       unsigned size = hsproto.serialize(shake, out_buf, buf_size);
-       socket.write(out_buf, size);
+Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, const Protocol &p, ReceiverBase &r):
+       hash(p.get_hash()),
+       base(b),
+       last(base+p.get_max_packet_id()),
+       protocol(&p),
+       receiver(&r)
+{ }
+
+Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, uint64_t h):
+       hash(h),
+       base(b),
+       last(base)
+{ }
+
+
+void Communicator::Handshake::receive(const PrepareProtocol &prepare)
+{
+       auto i = lower_bound_member(communicator.protocols, prepare.base, &ActiveProtocol::base);
+       if(i!=communicator.protocols.end() && i->base==prepare.base)
+               communicator.accept_protocol(*i);
+       else
+               communicator.protocols.emplace(i, prepare.base, prepare.hash);
+}
+
+void Communicator::Handshake::receive(const AcceptProtocol &accept)
+{
+       auto i = find_member(communicator.protocols, accept.hash, &ActiveProtocol::hash);
+       if(i==communicator.protocols.end())
+               throw key_error(accept.hash);
+
+
+       if(i->ready)
+               return;
+
+       i->ready = true;
+       if(!i->accepted)
+               communicator.accept_protocol(*i);
+       if(i->protocol==&protocol)
+       {
+               for(const ActiveProtocol &p: communicator.protocols)
+                       if(!p.ready)
+                               communicator.prepare_protocol(p);
+       }
+       else
+               communicator.signal_protocol_ready.emit(*i->protocol);
 }
 
 } // namespace Net
index 63114fd9ba18cd09caa2124cca6b2119b6e1a7e8..eb9893d861a2ca8eaa3679aeedcddcd6fde24ea4 100644 (file)
@@ -1,7 +1,10 @@
 #ifndef MSP_NET_COMMUNICATOR_H_
 #define MSP_NET_COMMUNICATOR_H_
 
+#include <msp/core/except.h>
+#include <msp/core/noncopyable.h>
 #include <sigc++/signal.h>
+#include "mspnet_api.h"
 #include "protocol.h"
 
 namespace Msp {
@@ -9,61 +12,83 @@ namespace Net {
 
 class StreamSocket;
 
-class sequence_error: public std::logic_error
+class MSPNET_API sequence_error: public invalid_state
 {
 public:
-       sequence_error(const std::string &w): std::logic_error(w) { }
-       virtual ~sequence_error() throw() { }
+       sequence_error(const std::string &w): invalid_state(w) { }
 };
 
-class incompatible_protocol: public std::runtime_error
+class MSPNET_API incompatible_protocol: public std::runtime_error
 {
 public:
        incompatible_protocol(const std::string &w): std::runtime_error(w) { }
-       virtual ~incompatible_protocol() throw() { }
 };
 
 
-class Communicator
+class MSPNET_API Communicator: public NonCopyable
 {
 public:
-       sigc::signal<void> signal_handshake_done;
+       sigc::signal<void, const Protocol &> signal_protocol_ready;
        sigc::signal<void, const std::exception &> signal_error;
 
 private:
+       struct ActiveProtocol
+       {
+               std::uint64_t hash = 0;
+               std::uint16_t base = 0;
+               std::uint16_t last = 0;
+               bool accepted = false;
+               bool ready = false;
+               const Protocol *protocol = nullptr;
+               ReceiverBase *receiver = nullptr;
+
+               ActiveProtocol(std::uint16_t, const Protocol &, ReceiverBase &);
+               ActiveProtocol(std::uint16_t, std::uint64_t);
+       };
+
+       struct Handshake;
+
        StreamSocket &socket;
-       const Protocol &protocol;
-       ReceiverBase &receiver;
-       int handshake_status;
-       unsigned buf_size;
-       char *in_buf;
-       char *in_begin;
-       char *in_end;
-       char *out_buf;
-       bool good;
+       std::vector<ActiveProtocol> protocols;
+       Handshake *handshake = nullptr;
+       std::size_t buf_size = 65536;
+       char *in_buf = nullptr;
+       char *in_begin = nullptr;
+       char *in_end = nullptr;
+       char *out_buf = nullptr;
+       bool good = true;
 
 public:
+       Communicator(StreamSocket &);
        Communicator(StreamSocket &, const Protocol &, ReceiverBase &);
        ~Communicator();
 
-       void initiate_handshake();
-       bool is_handshake_done() const { return handshake_status==2; }
+       void add_protocol(const Protocol &, ReceiverBase &);
+       bool is_protocol_ready(const Protocol &) const;
 
        template<typename P>
        void send(const P &);
 
 private:
-       void send_data(unsigned);
+       void send_data(std::size_t);
 
+       void connect_finished(const std::exception *);
        void data_available();
-       bool receive_packet(const Protocol &, ReceiverBase &);
-       void send_handshake();
+       bool receive_packet();
+
+       void prepare_protocol(const ActiveProtocol &);
+       void accept_protocol(ActiveProtocol &);
 };
 
 template<typename P>
 void Communicator::send(const P &pkt)
 {
-       send_data(protocol.serialize(pkt, out_buf, buf_size));
+       auto i = find_if(protocols, [](const ActiveProtocol &p){ return p.protocol && p.protocol->has_packet<P>(); });
+       if(i==protocols.end())
+               throw key_error(typeid(P).name());
+       else if(!i->ready)
+               throw sequence_error("protocol not ready");
+       send_data(i->protocol->serialize(pkt, out_buf, buf_size, i->base));
 }
 
 } // namespace Net
diff --git a/source/net/constants.cpp b/source/net/constants.cpp
deleted file mode 100644 (file)
index 5088e51..0000000
+++ /dev/null
@@ -1,35 +0,0 @@
-#include <stdexcept>
-#include "platform_api.h"
-#include "constants.h"
-
-using namespace std;
-
-namespace Msp {
-namespace Net {
-
-int family_to_sys(Family f)
-{
-       switch(f)
-       {
-       case UNSPEC: return AF_UNSPEC;
-       case INET:   return AF_INET;
-       case INET6:  return AF_INET6;
-       case UNIX:   return AF_UNIX;
-       default: throw invalid_argument("family_to_sys");
-       }
-}
-
-Family family_from_sys(int f)
-{
-       switch(f)
-       {
-       case AF_UNSPEC: return UNSPEC;
-       case AF_INET:   return INET;
-       case AF_INET6:  return INET6;
-       case AF_UNIX:   return UNIX;
-       default: throw invalid_argument("family_from_sys");
-       }
-}
-
-} // namespace Net
-} // namespace Msp
index 0d94661584d1027e2ffbe39b98c52a71da7f8974..2e6d544127d7dbb5db248844f0881424179abed2 100644 (file)
@@ -1,21 +1,8 @@
 #ifndef MSP_NET_CONSTANTS_H_
 #define MSP_NET_CONSTANTS_H_
 
-namespace Msp {
-namespace Net {
+#warning "This header is deprected and should not be used"
 
-enum Family
-{
-       UNSPEC,
-       INET,
-       INET6,
-       UNIX
-};
-
-int family_to_sys(Family);
-Family family_from_sys(int);
-
-} // namespace Net
-} // namespace Msp
+#include "sockaddr.h"
 
 #endif
index fe01ef5c6a6ffc7e2419750679bfaed347238fd3..54935545bbf984629602498c6d162d0e7170f75b 100644 (file)
@@ -1,8 +1,8 @@
 #include "platform_api.h"
+#include "datagramsocket.h"
 #include <msp/core/systemerror.h>
 #include <msp/io/handle_private.h>
 #include <msp/strings/format.h>
-#include "datagramsocket.h"
 #include "sockaddr_private.h"
 #include "socket_private.h"
 
@@ -20,20 +20,18 @@ bool DatagramSocket::connect(const SockAddr &addr)
        SockAddr::SysAddr sa = addr.to_sys();
        check_sys_connect_error(::connect(priv->handle, reinterpret_cast<const sockaddr *>(&sa.addr), sa.size));
 
-       delete peer_addr;
-       peer_addr = addr.copy();
+       peer_addr.reset(addr.copy());
 
-       delete local_addr;
        SockAddr::SysAddr lsa;
        getsockname(priv->handle, reinterpret_cast<sockaddr *>(&lsa.addr), &lsa.size);
-       local_addr = SockAddr::new_from_sys(lsa);
+       local_addr.reset(SockAddr::new_from_sys(lsa));
 
        connected = true;
 
        return true;
 }
 
-unsigned DatagramSocket::sendto(const char *buf, unsigned size, const SockAddr &addr)
+size_t DatagramSocket::sendto(const char *buf, size_t size, const SockAddr &addr)
 {
        if(size==0)
                return 0;
@@ -42,13 +40,13 @@ unsigned DatagramSocket::sendto(const char *buf, unsigned size, const SockAddr &
        return check_sys_error(::sendto(priv->handle, buf, size, 0, reinterpret_cast<const sockaddr *>(&sa.addr), sa.size), "sendto");
 }
 
-unsigned DatagramSocket::recvfrom(char *buf, unsigned size, SockAddr *&from_addr)
+size_t DatagramSocket::recvfrom(char *buf, size_t size, SockAddr *&from_addr)
 {
        if(size==0)
                return 0;
 
        SockAddr::SysAddr sa;
-       unsigned ret = check_sys_error(::recvfrom(priv->handle, buf, size, 0, reinterpret_cast<sockaddr *>(&sa.addr), &sa.size), "recvfrom");
+       size_t ret = check_sys_error(::recvfrom(priv->handle, buf, size, 0, reinterpret_cast<sockaddr *>(&sa.addr), &sa.size), "recvfrom");
        from_addr = SockAddr::new_from_sys(sa);
 
        return ret;
index 23ca296cbb27b93fb0735617209ff79f75d0066d..a978b9f5742ff2d0a56751aada207d9916bea9c2 100644 (file)
@@ -2,20 +2,21 @@
 #define MSP_NET_DATAGRAMSOCKET_H_
 
 #include "clientsocket.h"
+#include "mspnet_api.h"
 
 namespace Msp {
 namespace Net {
 
-class DatagramSocket: public ClientSocket
+class MSPNET_API DatagramSocket: public ClientSocket
 {
 public:
        DatagramSocket(Family, int = 0);
 
-       virtual bool connect(const SockAddr &);
-       virtual bool poll_connect(const Time::TimeDelta &) { return false; }
+       bool connect(const SockAddr &) override;
+       bool poll_connect(const Time::TimeDelta &) override { return false; }
 
-       unsigned sendto(const char *, unsigned, const SockAddr &);
-       unsigned recvfrom(char *, unsigned, SockAddr *&);
+       std::size_t sendto(const char *, std::size_t, const SockAddr &);
+       std::size_t recvfrom(char *, std::size_t, SockAddr *&);
 };
 
 } // namespace Net
index 8e75cd89dfb1eb930b3e55a0217a93e42ab03e15..17546e88dfe62dc6ec2ba273d808c47653a5e1b8 100644 (file)
@@ -1,6 +1,6 @@
 #include "platform_api.h"
-#include <msp/strings/format.h>
 #include "inet.h"
+#include <msp/strings/format.h>
 #include "sockaddr_private.h"
 
 using namespace std;
@@ -8,12 +8,6 @@ using namespace std;
 namespace Msp {
 namespace Net {
 
-InetAddr::InetAddr():
-       port(0)
-{
-       fill(addr, addr+4, 0);
-}
-
 InetAddr::InetAddr(const SysAddr &sa)
 {
        const sockaddr_in &sai = reinterpret_cast<const sockaddr_in &>(sa.addr);
@@ -22,6 +16,22 @@ InetAddr::InetAddr(const SysAddr &sa)
        port = ntohs(sai.sin_port);
 }
 
+InetAddr InetAddr::wildcard(unsigned port)
+{
+       InetAddr addr;
+       addr.port = port;
+       return addr;
+}
+
+InetAddr InetAddr::localhost(unsigned port)
+{
+       InetAddr addr;
+       addr.addr[0] = 127;
+       addr.addr[3] = 1;
+       addr.port = port;
+       return addr;
+}
+
 SockAddr::SysAddr InetAddr::to_sys() const
 {
        SysAddr sa;
index 639e64cbb1eca5ec8c9fc2e0d5247c91600adf38..d970559dab23fb628e61837275c1e0ba8ba807c5 100644 (file)
@@ -1,6 +1,7 @@
 #ifndef MSP_NET_INET_H_
 #define MSP_NET_INET_H_
 
+#include "mspnet_api.h"
 #include "sockaddr.h"
 
 namespace Msp {
@@ -9,23 +10,26 @@ namespace Net {
 /**
 Address class for IPv4 sockets.
 */
-class InetAddr: public SockAddr
+class MSPNET_API InetAddr: public SockAddr
 {
 private:
-       unsigned char addr[4];
-       unsigned port;
+       unsigned char addr[4] = { };
+       unsigned port = 0;
 
 public:
-       InetAddr();
+       InetAddr() = default;
        InetAddr(const SysAddr &);
 
-       virtual InetAddr *copy() const { return new InetAddr(*this); }
+       static InetAddr wildcard(unsigned);
+       static InetAddr localhost(unsigned);
 
-       virtual SysAddr to_sys() const;
+       InetAddr *copy() const override { return new InetAddr(*this); }
 
-       virtual Family get_family() const { return INET; }
+       SysAddr to_sys() const override;
+
+       Family get_family() const override { return INET; }
        unsigned get_port() const { return port; }
-       virtual std::string str() const;
+       std::string str() const override;
 };
 
 } // namespace Net
index 77d757082379f70e713d7587bbf0d61df675ff75..1b8fc39ce2628bd793a2148723e2e33868a1ad64 100644 (file)
@@ -1,6 +1,7 @@
+#include "inet6.h"
 #include "platform_api.h"
 #include <msp/strings/format.h>
-#include "inet6.h"
+#include <msp/strings/utils.h>
 #include "sockaddr_private.h"
 
 using namespace std;
@@ -8,12 +9,6 @@ using namespace std;
 namespace Msp {
 namespace Net {
 
-Inet6Addr::Inet6Addr():
-       port(0)
-{
-       fill(addr, addr+16, 0);
-}
-
 Inet6Addr::Inet6Addr(const SysAddr &sa)
 {
        const sockaddr_in6 &sai6 = reinterpret_cast<const sockaddr_in6 &>(sa.addr);
@@ -21,6 +16,21 @@ Inet6Addr::Inet6Addr(const SysAddr &sa)
        port = htons(sai6.sin6_port);
 }
 
+Inet6Addr Inet6Addr::wildcard(unsigned port)
+{
+       Inet6Addr addr;
+       addr.port = port;
+       return addr;
+}
+
+Inet6Addr Inet6Addr::localhost(unsigned port)
+{
+       Inet6Addr addr;
+       addr.addr[15] = 1;
+       addr.port = port;
+       return addr;
+}
+
 SockAddr::SysAddr Inet6Addr::to_sys() const
 {
        SysAddr sa;
@@ -40,9 +50,9 @@ string Inet6Addr::str() const
        string result = "[";
        for(unsigned i=0; i<16; i+=2)
        {
-               unsigned short part = (addr[i]<<8) | addr[i+1];
                if(i>0)
                        result += ':';
+               unsigned short part = (addr[i]<<8) | addr[i+1];
                result += format("%x", part);
        }
        result += ']';
index cc33e764936e4dbaaa4df8311774f6132407bbff..f6b0cf4187142b32938f828efcbed6c4d3a52ce9 100644 (file)
@@ -1,28 +1,32 @@
 #ifndef MSP_NET_INET6_H_
-#define NSP_NET_INET6_H_
+#define MSP_NET_INET6_H_
 
+#include "mspnet_api.h"
 #include "sockaddr.h"
 
 namespace Msp {
 namespace Net {
 
-class Inet6Addr: public SockAddr
+class MSPNET_API Inet6Addr: public SockAddr
 {
 private:
-       unsigned char addr[16];
-       unsigned port;
+       unsigned char addr[16] = { };
+       unsigned port = 0;
 
 public:
-       Inet6Addr();
+       Inet6Addr() = default;
        Inet6Addr(const SysAddr &);
 
-       virtual Inet6Addr *copy() const { return new Inet6Addr(*this); }
+       static Inet6Addr wildcard(unsigned);
+       static Inet6Addr localhost(unsigned);
 
-       virtual SysAddr to_sys() const;
+       Inet6Addr *copy() const override { return new Inet6Addr(*this); }
 
-       virtual Family get_family() const { return INET6; }
+       SysAddr to_sys() const override;
+
+       Family get_family() const override { return INET6; }
        unsigned get_port() const { return port; }
-       virtual std::string str() const;
+       std::string str() const override;
 };
 
 } // namespace Net
diff --git a/source/net/mspnet_api.h b/source/net/mspnet_api.h
new file mode 100644 (file)
index 0000000..eccdd41
--- /dev/null
@@ -0,0 +1,18 @@
+#ifndef MSP_NET_API_H_
+#define MSP_NET_API_H_
+
+#if defined(_WIN32)
+#if defined(MSPNET_BUILD)
+#define MSPNET_API __declspec(dllexport)
+#elif defined(MSPNET_IMPORT)
+#define MSPNET_API __declspec(dllimport)
+#else
+#define MSPNET_API
+#endif
+#elif defined(__GNUC__)
+#define MSPNET_API __attribute__((visibility("default")))
+#else
+#define MSPNET_API
+#endif
+
+#endif
index b24cef1f5e3f03ef067711271480c2eb1ab378cb..5e82a80d7daaf17fb130752fd870380e2b78c329 100644 (file)
@@ -1,41 +1,36 @@
+#include "protocol.h"
 #include <cstring>
 #include <string>
-#include <msp/core/hash.h>
 #include <msp/core/maputils.h>
 #include <msp/strings/format.h>
 #include <msp/strings/lexicalcast.h>
-#include "protocol.h"
 
 using namespace std;
 
 namespace Msp {
 namespace Net {
 
-Protocol::Protocol(unsigned npi):
-       header_def(0),
-       next_packet_id(npi)
+Protocol::Protocol():
+       header_def(0)
 {
-       PacketDefBuilder<PacketHeader, NullSerializer<PacketHeader> >(*this, header_def, NullSerializer<PacketHeader>())
-               (&PacketHeader::type)(&PacketHeader::length);
+       PacketDefBuilder<PacketHeader, Serializer<PacketHeader>>(*this, header_def, Serializer<PacketHeader>())
+               .fields(&PacketHeader::type, &PacketHeader::length);
 }
 
-Protocol::~Protocol()
+unsigned Protocol::get_next_packet_class_id()
 {
-       for(map<unsigned, PacketDefBase *>::iterator i=packet_class_defs.begin(); i!=packet_class_defs.end(); ++i)
-               delete i->second;
+       static unsigned next_id = 1;
+       return next_id++;
 }
 
-void Protocol::add_packet(PacketDefBase *pdef)
+void Protocol::add_packet(unique_ptr<PacketDefBase> pdef)
 {
-       PacketDefBase *&ptr = packet_class_defs[pdef->get_class_id()];
+       unique_ptr<PacketDefBase> &ptr = packet_class_defs[pdef->get_class_id()];
        if(ptr)
-       {
                packet_id_defs.erase(ptr->get_id());
-               delete ptr;
-       }
-       ptr = pdef;
-       if(unsigned id = pdef->get_id())
-               packet_id_defs[id] = pdef;
+       ptr = move(pdef);
+       if(unsigned id = ptr->get_id())
+               packet_id_defs[id] = ptr.get();
 }
 
 const Protocol::PacketDefBase &Protocol::get_packet_by_class_id(unsigned id) const
@@ -48,32 +43,54 @@ const Protocol::PacketDefBase &Protocol::get_packet_by_id(unsigned id) const
        return *get_item(packet_id_defs, id);
 }
 
-unsigned Protocol::dispatch(ReceiverBase &rcv, const char *buf, unsigned size) const
+unsigned Protocol::get_max_packet_id() const
+{
+       if(packet_id_defs.empty())
+               return 0;
+       return prev(packet_id_defs.end())->first;
+}
+
+size_t Protocol::dispatch(ReceiverBase &rcv, const char *buf, size_t size, unsigned base_id) const
 {
        PacketHeader header;
-       buf = header_def.deserialize(header, buf, buf+size);
+       const char *ptr = header_def.deserialize(header, buf, buf+size);
        if(header.length>size)
                throw bad_packet("truncated");
-       const PacketDefBase &pdef = get_packet_by_id(header.type);
-       const char *ptr = pdef.dispatch(rcv, buf, buf+header.length);
+       const PacketDefBase &pdef = get_packet_by_id(header.type-base_id);
+       if(DynamicReceiver *drcv = dynamic_cast<DynamicReceiver *>(&rcv))
+       {
+               Variant pkt;
+               ptr = pdef.deserialize(pkt, ptr, ptr+header.length);
+               drcv->receive(pdef.get_id(), pkt);
+       }
+       else
+               ptr = pdef.dispatch(rcv, ptr, ptr+header.length);
        return ptr-buf;
 }
 
-unsigned Protocol::get_packet_size(const char *buf, unsigned size) const
+bool Protocol::get_packet_header(PacketHeader &header, const char *buf, size_t size) const
 {
        if(size<4)
-               return 0;
-       PacketHeader header;
+               return false;
        header_def.deserialize(header, buf, buf+size);
-       return header.length;
+       return true;
+}
+
+size_t Protocol::get_packet_size(const char *buf, size_t size) const
+{
+       PacketHeader header;
+       return (get_packet_header(header, buf, size) ? header.length : 0);
 }
 
-UInt64 Protocol::get_hash() const
+uint64_t Protocol::get_hash() const
 {
-       string description;
-       for(PacketMap::const_iterator i=packet_id_defs.begin(); i!=packet_id_defs.end(); ++i)
-               description += format("%d:%s\n", i->first, i->second->describe());
-       return hash64(description);
+       uint64_t result = hash<64>(packet_id_defs.size());
+       for(auto &kvp: packet_id_defs)
+       {
+               hash_update<64>(result, kvp.first);
+               hash_update<64>(result, kvp.second->get_hash());
+       }
+       return result;
 }
 
 
@@ -86,7 +103,7 @@ char *Protocol::BasicSerializer<T>::serialize(const T &value, char *buf, char *e
                throw buffer_error("overflow");
 
        const char *ptr = reinterpret_cast<const char *>(&value)+sizeof(T);
-       for(unsigned i=0; i<sizeof(T); ++i)
+       for(size_t i=0; i<sizeof(T); ++i)
                *buf++ = *--ptr;
 
        return buf;
@@ -99,30 +116,32 @@ const char *Protocol::BasicSerializer<T>::deserialize(T &value, const char *buf,
                throw buffer_error("underflow");
 
        char *ptr = reinterpret_cast<char *>(&value)+sizeof(T);
-       for(unsigned i=0; i<sizeof(T); ++i)
+       for(size_t i=0; i<sizeof(T); ++i)
                *--ptr = *buf++;
 
        return buf;
 }
 
-template char *Protocol::BasicSerializer<Int8>::serialize(const Int8 &, char *, char *) const;
-template char *Protocol::BasicSerializer<Int16>::serialize(const Int16 &, char *, char *) const;
-template char *Protocol::BasicSerializer<Int32>::serialize(const Int32 &, char *, char *) const;
-template char *Protocol::BasicSerializer<Int64>::serialize(const Int64 &, char *, char *) const;
-template char *Protocol::BasicSerializer<UInt8>::serialize(const UInt8 &, char *, char *) const;
-template char *Protocol::BasicSerializer<UInt16>::serialize(const UInt16 &, char *, char *) const;
-template char *Protocol::BasicSerializer<UInt32>::serialize(const UInt32 &, char *, char *) const;
-template char *Protocol::BasicSerializer<UInt64>::serialize(const UInt64 &, char *, char *) const;
+template char *Protocol::BasicSerializer<bool>::serialize(const bool &, char *, char *) const;
+template char *Protocol::BasicSerializer<int8_t>::serialize(const int8_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<int16_t>::serialize(const int16_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<int32_t>::serialize(const int32_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<int64_t>::serialize(const int64_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<uint8_t>::serialize(const uint8_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<uint16_t>::serialize(const uint16_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<uint32_t>::serialize(const uint32_t &, char *, char *) const;
+template char *Protocol::BasicSerializer<uint64_t>::serialize(const uint64_t &, char *, char *) const;
 template char *Protocol::BasicSerializer<float>::serialize(const float &, char *, char *) const;
 template char *Protocol::BasicSerializer<double>::serialize(const double &, char *, char *) const;
-template const char *Protocol::BasicSerializer<Int8>::deserialize(Int8 &, const char *, const char *) const;
-template const char *Protocol::BasicSerializer<Int16>::deserialize(Int16 &, const char *, const char *) const;
-template const char *Protocol::BasicSerializer<Int32>::deserialize(Int32 &, const char *, const char *) const;
-template const char *Protocol::BasicSerializer<Int64>::deserialize(Int64 &, const char *, const char *) const;
-template const char *Protocol::BasicSerializer<UInt8>::deserialize(UInt8 &, const char *, const char *) const;
-template const char *Protocol::BasicSerializer<UInt16>::deserialize(UInt16 &, const char *, const char *) const;
-template const char *Protocol::BasicSerializer<UInt32>::deserialize(UInt32 &, const char *, const char *) const;
-template const char *Protocol::BasicSerializer<UInt64>::deserialize(UInt64 &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<bool>::deserialize(bool &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<int8_t>::deserialize(int8_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<int16_t>::deserialize(int16_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<int32_t>::deserialize(int32_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<int64_t>::deserialize(int64_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<uint8_t>::deserialize(uint8_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<uint16_t>::deserialize(uint16_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<uint32_t>::deserialize(uint32_t &, const char *, const char *) const;
+template const char *Protocol::BasicSerializer<uint64_t>::deserialize(uint64_t &, const char *, const char *) const;
 template const char *Protocol::BasicSerializer<float>::deserialize(float &, const char *, const char *) const;
 template const char *Protocol::BasicSerializer<double>::deserialize(double &, const char *, const char *) const;
 
@@ -142,7 +161,7 @@ char *Protocol::StringSerializer::serialize(const string &str, char *buf, char *
 
 const char *Protocol::StringSerializer::deserialize(string &str, const char *buf, const char *end) const
 {
-       UInt16 length;
+       uint16_t length;
        buf = length_serializer.deserialize(length, buf, end);
        if(end-buf<static_cast<int>(length))
                throw buffer_error("underflow");
@@ -151,19 +170,12 @@ const char *Protocol::StringSerializer::deserialize(string &str, const char *buf
 }
 
 
-unsigned Protocol::PacketDefBase::next_class_id = 1;
-
 Protocol::PacketDefBase::PacketDefBase(unsigned i):
        id(i)
 { }
 
 
-Protocol::PacketHeader::PacketHeader():
-       type(0),
-       length(0)
-{ }
-
-Protocol::PacketHeader::PacketHeader(UInt16 t, UInt16 l):
+Protocol::PacketHeader::PacketHeader(uint16_t t, uint16_t l):
        type(t),
        length(l)
 { }
index 7ff63cb1e7dd12cc2a8fb3e9ba1e092062bcc754..5051338f3911558c8e2ee19fe2d0665c0e10eec3 100644 (file)
@@ -1,87 +1,74 @@
 #ifndef MSP_NET_PROTOCOL_H_
 #define MSP_NET_PROTOCOL_H_
 
+#include <cstdint>
 #include <map>
+#include <memory>
 #include <stdexcept>
 #include <vector>
-#include <msp/core/inttypes.h>
+#include <msp/core/hash.h>
+#include "mspnet_api.h"
 #include "receiver.h"
 
 namespace Msp {
 namespace Net {
 
-class bad_packet: public std::runtime_error
+class MSPNET_API bad_packet: public std::runtime_error
 {
 public:
        bad_packet(const std::string &w): std::runtime_error(w) { }
-       virtual ~bad_packet() throw() { }
 };
 
 
-class buffer_error: public std::runtime_error
+class MSPNET_API buffer_error: public std::runtime_error
 {
 public:
        buffer_error(const std::string &w): std::runtime_error(w) { }
-       virtual ~buffer_error() throw() { }
 };
 
 
-class Protocol
+class MSPNET_API Protocol
 {
+public:
+       struct PacketHeader
+       {
+               std::uint16_t type = 0;
+               std::uint16_t length = 0;
+
+               PacketHeader() = default;
+               PacketHeader(std::uint16_t, std::uint16_t);
+       };
+
 private:
-       template<typename T, UInt8 K>
+       template<typename T, std::uint8_t K>
        struct BasicTraits;
 
        template<typename T>
        struct Traits;
 
        template<typename C>
-       struct CompoundTypeDef
-       {
-               virtual ~CompoundTypeDef() { }
-
-               virtual std::string describe() const = 0;
-               virtual char *serialize(const C &, char *, char *) const = 0;
-               virtual const char *deserialize(C &, const char *, const char *) const = 0;
-       };
-
-       template<typename C, typename S>
-       struct CompoundDef: public CompoundTypeDef<C>
-       {
-               S serializer;
-
-               CompoundDef(const S &);
-
-               virtual std::string describe() const;
-               virtual char *serialize(const C &, char *, char *) const;
-               virtual const char *deserialize(C &, const char *, const char *) const;
-       };
+       class Serializer;
 
        template<typename T>
        class BasicSerializer
        {
        public:
-               typedef T ValueType;
-
                BasicSerializer(const Protocol &) { }
 
-               std::string describe() const { return get_type_signature<T>(); }
+               std::uint64_t get_hash() const { return Traits<T>::signature; }
                char *serialize(const T &, char *, char *) const;
                const char *deserialize(T &, const char *, const char *) const;
        };
 
        class StringSerializer
        {
-       public:
-               typedef std::string ValueType;
-
        private:
-               BasicSerializer<UInt16> length_serializer;
+               BasicSerializer<std::uint16_t> length_serializer;
 
        public:
                StringSerializer(const Protocol &);
 
-               std::string describe() const { return get_type_signature<std::string>(); }
+               std::uint64_t get_hash() const;
                char *serialize(const std::string &, char *, char *) const;
                const char *deserialize(std::string &, const char *, const char *) const;
        };
@@ -89,17 +76,14 @@ private:
        template<typename A>
        class ArraySerializer
        {
-       public:
-               typedef A ValueType;
-
        private:
-               BasicSerializer<UInt16> length_serializer;
+               BasicSerializer<std::uint16_t> length_serializer;
                typename Traits<typename A::value_type>::Serializer element_serializer;
 
        public:
                ArraySerializer(const Protocol &);
 
-               std::string describe() const;
+               std::uint64_t get_hash() const;
                char *serialize(const A &, char *, char *) const;
                const char *deserialize(A &, const char *, const char *) const;
        };
@@ -107,57 +91,48 @@ private:
        template<typename C>
        class CompoundSerializer
        {
-       public:
-               typedef C ValueType;
-
        private:
-               const CompoundTypeDef<C> &def;
+               const Serializer<C> &serializer;
 
        public:
                CompoundSerializer(const Protocol &);
 
-               std::string describe() const { return def.describe(); }
+               std::uint64_t get_hash() const;
                char *serialize(const C &, char *, char *) const;
                const char *deserialize(C &, const char *, const char *) const;
        };
 
-       template<typename P, typename Head, typename S>
-       class Serializer: public Head
+       template<typename C, typename Head, typename T>
+       class FieldSerializer: public Head
        {
        public:
                template<typename N>
-               struct Next
-               {
-                       typedef Serializer<P, Serializer<P, Head, S>, typename Traits<N>::Serializer> Type;
-               };
+               using Next = FieldSerializer<C, FieldSerializer<C, Head, T>, N>;
 
        private:
-               typedef typename S::ValueType P::*Pointer;
-
-               Pointer ptr;
-               S ser;
+               T C::*ptr;
+               typename Traits<T>::Serializer ser;
 
        public:
-               Serializer(const Head &, Pointer, const Protocol &);
+               FieldSerializer(const Head &, T C::*, const Protocol &);
 
-               std::string describe() const;
-               char *serialize(const P &, char *, char *) const;
-               const char *deserialize(P &, const char *, const char *) const;
+               std::uint64_t get_hash() const;
+               char *serialize(const C &, char *, char *) const;
+               const char *deserialize(C &, const char *, const char *) const;
        };
 
-       template<typename P>
-       class NullSerializer
+       template<typename C>
+       class Serializer
        {
        public:
                template<typename N>
-               struct Next
-               {
-                       typedef Serializer<P, NullSerializer, typename Traits<N>::Serializer> Type;
-               };
-
-               std::string describe() const { return std::string(); }
-               char *serialize(const P &, char *b, char *) const { return b; }
-               const char *deserialize(P &, const char *b, const char *) const { return b; }
+               using Next = FieldSerializer<C, Serializer<C>, N>;
+
+               virtual ~Serializer() = default;
+
+               virtual std::uint64_t get_hash() const { return 0; }
+               virtual char *serialize(const C &, char *b, char *) const { return b; }
+               virtual const char *deserialize(C &, const char *b, const char *) const { return b; }
        };
 
        class PacketDefBase
@@ -165,14 +140,14 @@ private:
        protected:
                unsigned id;
 
-               static unsigned next_class_id;
-
                PacketDefBase(unsigned);
        public:
-               virtual ~PacketDefBase() { }
+               virtual ~PacketDefBase() = default;
+
                virtual unsigned get_class_id() const = 0;
                unsigned get_id() const { return id; }
-               virtual std::string describe() const = 0;
+               virtual std::uint64_t get_hash() const = 0;
+               virtual const char *deserialize(Variant &, const char *, const char *) const = 0;
                virtual const char *dispatch(ReceiverBase &, const char *, const char *) const = 0;
        };
 
@@ -180,26 +155,23 @@ private:
        class PacketTypeDef: public PacketDefBase
        {
        private:
-               CompoundTypeDef<P> *compound;
-
-               static unsigned class_id;
+               std::unique_ptr<Serializer<P>> serializer;
 
        public:
                PacketTypeDef(unsigned);
-               ~PacketTypeDef();
 
-               static unsigned get_static_class_id() { return class_id; }
-               virtual unsigned get_class_id() const { return class_id; }
+               unsigned get_class_id() const override { return get_packet_class_id<P>(); }
 
                template<typename S>
                void set_serializer(const S &);
 
-               const CompoundTypeDef<P> &get_compound() const { return *compound; }
+               const Serializer<P> &get_serializer() const { return *serializer; }
 
-               virtual std::string describe() const { return compound->describe(); }
+               std::uint64_t get_hash() const override { return serializer->get_hash(); }
                char *serialize(const P &, char *, char *) const;
                const char *deserialize(P &, const char *, const char *) const;
-               virtual const char *dispatch(ReceiverBase &, const char *, const char *) const;
+               const char *deserialize(Variant &, const char *, const char *) const override;
+               const char *dispatch(ReceiverBase &, const char *, const char *) const override;
        };
 
        template<typename P, typename S>
@@ -212,41 +184,36 @@ private:
 
        public:
                PacketDefBuilder(const Protocol &, PacketTypeDef<P> &, const S &);
-               
-               template<typename T>
-               PacketDefBuilder<P, typename S::template Next<T>::Type> operator()(T P::*);
-       };
 
-       struct PacketHeader
-       {
-               UInt16 type;
-               UInt16 length;
+               template<typename T>
+               PacketDefBuilder<P, typename S::template Next<T>> fields(T P::*);
 
-               PacketHeader();
-               PacketHeader(UInt16, UInt16);
+               template<typename T1, typename T2, typename... Rest>
+               auto fields(T1 P::*first, T2 P::*second, Rest P::*...rest) { return fields(first).fields(second, rest...); }
        };
 
-       typedef std::map<unsigned, PacketDefBase *> PacketMap;
-
        PacketTypeDef<PacketHeader> header_def;
-       unsigned next_packet_id;
-       PacketMap packet_class_defs;
-       PacketMap packet_id_defs;
+       unsigned next_packet_id = 1;
+       std::map<unsigned, std::unique_ptr<PacketDefBase>> packet_class_defs;
+       std::map<unsigned, PacketDefBase *> packet_id_defs;
 
 protected:
-       Protocol(unsigned = 1);
-public:
-       ~Protocol();
+       Protocol();
 
 private:
-       void add_packet(PacketDefBase *);
+       static unsigned get_next_packet_class_id();
 
-protected:
        template<typename P>
-       PacketDefBuilder<P, NullSerializer<P> > add(unsigned);
+       static unsigned get_packet_class_id();
 
+       void add_packet(std::unique_ptr<PacketDefBase>);
+
+protected:
        template<typename P>
-       PacketDefBuilder<P, NullSerializer<P> > add();
+       PacketDefBuilder<P, Serializer<P>> add();
+
+       template<typename P, typename T, typename... Rest>
+       auto add(T P::*field, Rest P::*...rest) { return add<P>().fields(field, rest...); }
 
        const PacketDefBase &get_packet_by_class_id(unsigned) const;
        const PacketDefBase &get_packet_by_id(unsigned) const;
@@ -256,123 +223,103 @@ protected:
 
 public:
        template<typename P>
-       unsigned serialize(const P &, char *, unsigned) const;
+       bool has_packet() const { return packet_class_defs.count(get_packet_class_id<P>()); }
 
-       unsigned get_packet_size(const char *, unsigned) const;
-       unsigned dispatch(ReceiverBase &, const char *, unsigned) const;
+       template<typename P>
+       unsigned get_packet_id() const { return get_item(packet_class_defs, get_packet_class_id<P>())->get_id(); }
 
-       UInt64 get_hash() const;
+       unsigned get_max_packet_id() const;
 
-private:
-       template<typename T>
-       static std::string get_type_signature();
+       template<typename P>
+       std::size_t serialize(const P &, char *, std::size_t, unsigned = 0) const;
+
+       bool get_packet_header(PacketHeader &, const char *, std::size_t) const;
+       std::size_t get_packet_size(const char *, std::size_t) const;
+       std::size_t dispatch(ReceiverBase &, const char *, std::size_t, unsigned = 0) const;
+
+       std::uint64_t get_hash() const;
 };
 
 
 template<typename P>
-Protocol::PacketDefBuilder<P, Protocol::NullSerializer<P> > Protocol::add(unsigned id)
+unsigned Protocol::get_packet_class_id()
 {
-       PacketTypeDef<P> *pdef = new PacketTypeDef<P>(id);
-       add_packet(pdef);
-       return PacketDefBuilder<P, NullSerializer<P> >(*this, *pdef, NullSerializer<P>());
+       static unsigned id = get_next_packet_class_id();
+       return id;
 }
 
 template<typename P>
-Protocol::PacketDefBuilder<P, Protocol::NullSerializer<P> > Protocol::add()
+Protocol::PacketDefBuilder<P, Protocol::Serializer<P>> Protocol::add()
 {
-       return add<P>(next_packet_id++);
+       std::unique_ptr<PacketTypeDef<P>> pdef = std::make_unique<PacketTypeDef<P>>(next_packet_id++);
+       PacketDefBuilder<P, Serializer<P>> next(*this, *pdef, Serializer<P>());
+       add_packet(move(pdef));
+       return next;
 }
 
 template<typename P>
 const Protocol::PacketTypeDef<P> &Protocol::get_packet_by_class() const
 {
-       const PacketDefBase &pdef = get_packet_by_class_id(PacketTypeDef<P>::get_static_class_id());
+       const PacketDefBase &pdef = get_packet_by_class_id(get_packet_class_id<P>());
        return static_cast<const PacketTypeDef<P> &>(pdef);
 }
 
 template<typename P>
-unsigned Protocol::serialize(const P &pkt, char *buf, unsigned size) const
+std::size_t Protocol::serialize(const P &pkt, char *buf, std::size_t size, unsigned base_id) const
 {
        const PacketTypeDef<P> &pdef = get_packet_by_class<P>();
+       if(!pdef.get_id())
+               throw std::invalid_argument("no packet id");
        char *ptr = pdef.serialize(pkt, buf+4, buf+size);
        size = ptr-buf;
-       header_def.serialize(PacketHeader(pdef.get_id(), size), buf, buf+4);
+       header_def.serialize(PacketHeader(base_id+pdef.get_id(), size), buf, buf+4);
        return size;
 }
 
-template<typename T>
-std::string Protocol::get_type_signature()
-{
-       const UInt16 sig = Traits<T>::signature;
-       std::string result;
-       result += sig&0xFF;
-       if(sig>=0x100)
-               result += '0'+(sig>>8);
-       return result;
-}
-
 
-template<typename T, UInt8 K>
+template<typename T, std::uint8_t K>
 struct Protocol::BasicTraits
 {
-       static const UInt16 signature = K | (sizeof(T)<<8);
+       static const std::uint16_t signature = K | (sizeof(T)<<8);
        typedef BasicSerializer<T> Serializer;
 };
 
 template<typename T>
 struct Protocol::Traits
 {
-       static const UInt16 signature = 'C';
+       static const std::uint16_t signature = 'C';
        typedef CompoundSerializer<T> Serializer;
 };
 
-template<> struct Protocol::Traits<Int8>: BasicTraits<Int8, 'I'> { };
-template<> struct Protocol::Traits<UInt8>: BasicTraits<UInt8, 'U'> { };
-template<> struct Protocol::Traits<Int16>: BasicTraits<Int16, 'I'> { };
-template<> struct Protocol::Traits<UInt16>: BasicTraits<UInt16, 'U'> { };
-template<> struct Protocol::Traits<Int32>: BasicTraits<Int32, 'I'> { };
-template<> struct Protocol::Traits<UInt32>: BasicTraits<UInt32, 'U'> { };
-template<> struct Protocol::Traits<Int64>: BasicTraits<Int64, 'I'> { };
-template<> struct Protocol::Traits<UInt64>: BasicTraits<UInt64, 'U'> { };
+template<> struct Protocol::Traits<bool>: BasicTraits<bool, 'B'> { };
+template<> struct Protocol::Traits<std::int8_t>: BasicTraits<std::int8_t, 'I'> { };
+template<> struct Protocol::Traits<std::uint8_t>: BasicTraits<std::uint8_t, 'U'> { };
+template<> struct Protocol::Traits<std::int16_t>: BasicTraits<std::int16_t, 'I'> { };
+template<> struct Protocol::Traits<std::uint16_t>: BasicTraits<std::uint16_t, 'U'> { };
+template<> struct Protocol::Traits<std::int32_t>: BasicTraits<std::int32_t, 'I'> { };
+template<> struct Protocol::Traits<std::uint32_t>: BasicTraits<std::uint32_t, 'U'> { };
+template<> struct Protocol::Traits<std::int64_t>: BasicTraits<std::int64_t, 'I'> { };
+template<> struct Protocol::Traits<std::uint64_t>: BasicTraits<std::uint64_t, 'U'> { };
 template<> struct Protocol::Traits<float>: BasicTraits<float, 'F'> { };
 template<> struct Protocol::Traits<double>: BasicTraits<double, 'F'> { };
 
 template<> struct Protocol::Traits<std::string>
 {
-       static const UInt16 signature = 'S';
+       static const std::uint16_t signature = 'S';
        typedef StringSerializer Serializer;
 };
 
 template<typename T>
-struct Protocol::Traits<std::vector<T> >
+struct Protocol::Traits<std::vector<T>>
 {
-       static const UInt16 signature = 'A';
-       typedef ArraySerializer<std::vector<T> > Serializer;
+       static const std::uint16_t signature = 'A';
+       typedef ArraySerializer<std::vector<T>> Serializer;
 };
 
 
-
-template<typename C, typename S>
-Protocol::CompoundDef<C, S>::CompoundDef(const S &s):
-       serializer(s)
-{ }
-
-template<typename C, typename S>
-std::string Protocol::CompoundDef<C, S>::describe() const
+inline std::uint64_t Protocol::StringSerializer::get_hash() const
 {
-       return "{"+serializer.describe()+"}";
-}
-
-template<typename C, typename S>
-char *Protocol::CompoundDef<C, S>::serialize(const C &com, char *buf, char *end) const
-{
-       return serializer.serialize(com, buf, end);
-}
-
-template<typename C, typename S>
-const char *Protocol::CompoundDef<C, S>::deserialize(C &com, const char *buf, const char *end) const
-{
-       return serializer.deserialize(com, buf, end);
+       return Traits<std::string>::signature;
 }
 
 
@@ -383,24 +330,24 @@ Protocol::ArraySerializer<A>::ArraySerializer(const Protocol &proto):
 { }
 
 template<typename A>
-std::string Protocol::ArraySerializer<A>::describe() const
+std::uint64_t Protocol::ArraySerializer<A>::get_hash() const
 {
-       return "["+element_serializer.describe()+"]";
+       return hash_round<64>(element_serializer.get_hash(), 'A');
 }
 
 template<typename A>
 char *Protocol::ArraySerializer<A>::serialize(const A &array, char *buf, char *end) const
 {
        buf = length_serializer.serialize(array.size(), buf, end);
-       for(typename A::const_iterator i=array.begin(); i!=array.end(); ++i)
-               buf = element_serializer.serialize(*i, buf, end);
+       for(const auto &e: array)
+               buf = element_serializer.serialize(e, buf, end);
        return buf;
 }
 
 template<typename A>
 const char *Protocol::ArraySerializer<A>::deserialize(A &array, const char *buf, const char *end) const
 {
-       UInt16 length;
+       std::uint16_t length;
        buf = length_serializer.deserialize(length, buf, end);
        array.resize(length);
        for(unsigned i=0; i<length; ++i)
@@ -411,86 +358,88 @@ const char *Protocol::ArraySerializer<A>::deserialize(A &array, const char *buf,
 
 template<typename C>
 Protocol::CompoundSerializer<C>::CompoundSerializer(const Protocol &proto):
-       def(proto.get_packet_by_class<C>().get_compound())
+       serializer(proto.get_packet_by_class<C>().get_serializer())
 { }
 
+template<typename C>
+std::uint64_t Protocol::CompoundSerializer<C>::get_hash() const
+{
+       return hash_round<64>(serializer.get_hash(), 'C');
+}
+
 template<typename C>
 char *Protocol::CompoundSerializer<C>::serialize(const C &com, char *buf, char *end) const
 {
-       return def.serialize(com, buf, end);
+       return serializer.serialize(com, buf, end);
 }
 
 template<typename C>
 const char *Protocol::CompoundSerializer<C>::deserialize(C &com, const char *buf, const char *end) const
 {
-       return def.deserialize(com, buf, end);
+       return serializer.deserialize(com, buf, end);
 }
 
 
-template<typename P, typename Head, typename S>
-Protocol::Serializer<P, Head, S>::Serializer(const Head &h, Pointer p, const Protocol &proto):
+template<typename C, typename Head, typename T>
+Protocol::FieldSerializer<C, Head, T>::FieldSerializer(const Head &h, T C::*p, const Protocol &proto):
        Head(h),
        ptr(p),
        ser(proto)
 { }
 
-template<typename P, typename Head, typename S>
-std::string Protocol::Serializer<P, Head, S>::describe() const
+template<typename C, typename Head, typename T>
+std::uint64_t Protocol::FieldSerializer<C, Head, T>::get_hash() const
 {
-       return Head::describe()+ser.describe();
+       return hash_update<64>(Head::get_hash(), ser.get_hash());
 }
 
-template<typename P, typename Head, typename S>
-char *Protocol::Serializer<P, Head, S>::serialize(const P &pkt, char *buf, char *end) const
+template<typename C, typename Head, typename T>
+char *Protocol::FieldSerializer<C, Head, T>::serialize(const C &com, char *buf, char *end) const
 {
-       buf = Head::serialize(pkt, buf, end);
-       return ser.serialize(pkt.*ptr, buf, end);
+       buf = Head::serialize(com, buf, end);
+       return ser.serialize(com.*ptr, buf, end);
 }
 
-template<typename P, typename Head, typename S>
-const char *Protocol::Serializer<P, Head, S>::deserialize(P &pkt, const char *buf, const char *end) const
+template<typename C, typename Head, typename T>
+const char *Protocol::FieldSerializer<C, Head, T>::deserialize(C &com, const char *buf, const char *end) const
 {
-       buf = Head::deserialize(pkt, buf, end);
-       return ser.deserialize(pkt.*ptr, buf, end);
+       buf = Head::deserialize(com, buf, end);
+       return ser.deserialize(com.*ptr, buf, end);
 }
 
 
-template<typename P>
-unsigned Protocol::PacketTypeDef<P>::class_id = 0;
-
 template<typename P>
 Protocol::PacketTypeDef<P>::PacketTypeDef(unsigned i):
        PacketDefBase(i),
-       compound(new CompoundDef<P, NullSerializer<P> >(NullSerializer<P>()))
-{
-       if(!class_id)
-               class_id = next_class_id++;
-}
-
-template<typename P>
-Protocol::PacketTypeDef<P>::~PacketTypeDef()
-{
-       delete compound;
-}
+       serializer(std::make_unique<Serializer<P>>())
+{ }
 
 template<typename P>
 template<typename S>
 void Protocol::PacketTypeDef<P>::set_serializer(const S &ser)
 {
-       delete compound;
-       compound = new CompoundDef<P, S>(ser);
+       serializer = std::make_unique<S>(ser);
 }
 
 template<typename P>
 char *Protocol::PacketTypeDef<P>::serialize(const P &pkt, char *buf, char *end) const
 {
-       return compound->serialize(pkt, buf, end);
+       return serializer->serialize(pkt, buf, end);
 }
 
 template<typename P>
 const char *Protocol::PacketTypeDef<P>::deserialize(P &pkt, const char *buf, const char *end) const
 {
-       return compound->deserialize(pkt, buf, end);
+       return serializer->deserialize(pkt, buf, end);
+}
+
+template<typename P>
+const char *Protocol::PacketTypeDef<P>::deserialize(Variant &var_pkt, const char *buf, const char *end) const
+{
+       P pkt;
+       const char *ptr = serializer->deserialize(pkt, buf, end);
+       var_pkt = std::move(pkt);
+       return ptr;
 }
 
 template<typename P>
@@ -515,11 +464,11 @@ Protocol::PacketDefBuilder<P, S>::PacketDefBuilder(const Protocol &p, PacketType
 
 template<typename P, typename S>
 template<typename T>
-Protocol::PacketDefBuilder<P, typename S::template Next<T>::Type> Protocol::PacketDefBuilder<P, S>::operator()(T P::*ptr)
+Protocol::PacketDefBuilder<P, typename S::template Next<T>> Protocol::PacketDefBuilder<P, S>::fields(T P::*ptr)
 {
-       typename S::template Next<T>::Type next_ser(serializer, ptr, protocol);
+       typename S::template Next<T> next_ser(serializer, ptr, protocol);
        pktdef.set_serializer(next_ser);
-       return PacketDefBuilder<P, typename S::template Next<T>::Type>(protocol, pktdef, next_ser);
+       return PacketDefBuilder<P, typename S::template Next<T>>(protocol, pktdef, next_ser);
 }
 
 } // namespace Net
diff --git a/source/net/protocol_impl.h b/source/net/protocol_impl.h
new file mode 100644 (file)
index 0000000..4eb2861
--- /dev/null
@@ -0,0 +1,6 @@
+#ifndef MSP_NET_PROTOCOL_IMPL_H_
+#define MSP_NET_PROTOCOL_IMPL_H_
+
+#warning "This header is deprected and should not be used"
+
+#endif
diff --git a/source/net/receiver.cpp b/source/net/receiver.cpp
new file mode 100644 (file)
index 0000000..46c12f3
--- /dev/null
@@ -0,0 +1,16 @@
+#include "receiver.h"
+
+namespace Msp {
+namespace Net {
+
+void DynamicDispatcher::receive(unsigned packet_id, const Variant &packet)
+{
+       auto i = lower_bound_member(targets, packet_id, &Target::packet_id);
+       if(i==targets.end() || i->packet_id!=packet_id)
+               throw key_error(packet_id);
+
+       i->func(*i->receiver, packet);
+}
+
+} // namespace Net
+} // namespace Msp
index 19e69e0d4c5926ceaa9c23ec518731fdcc82fcf3..a0a9f2913d0ed0e48fb6a47aa5793f04aa76ac99 100644 (file)
@@ -1,26 +1,81 @@
 #ifndef MSP_NET_RECEIVER_H_
 #define MSP_NET_RECEIVER_H_
 
+#include <vector>
+#include <msp/core/algorithm.h>
+#include <msp/core/maputils.h>
+#include <msp/core/variant.h>
+#include "mspnet_api.h"
+
 namespace Msp {
 namespace Net {
 
-class ReceiverBase
+class MSPNET_API ReceiverBase
 {
 protected:
-       ReceiverBase() { }
+       ReceiverBase() = default;
 public:
-       virtual ~ReceiverBase() { }
+       virtual ~ReceiverBase() = default;
 };
 
+
 template<typename P>
 class PacketReceiver: public virtual ReceiverBase
 {
 protected:
-       PacketReceiver() { }
+       PacketReceiver() = default;
 public:
        virtual void receive(const P &) = 0;
 };
 
+
+class MSPNET_API DynamicReceiver: public ReceiverBase
+{
+protected:
+       DynamicReceiver() = default;
+public:
+       virtual void receive(unsigned, const Variant &) = 0;
+};
+
+
+class MSPNET_API DynamicDispatcher: public DynamicReceiver
+{
+private:
+       using DispatchFunc = void(ReceiverBase &, const Variant &);
+
+       struct Target
+       {
+               unsigned packet_id;
+               ReceiverBase *receiver;
+               DispatchFunc *func;
+
+               Target(unsigned i, ReceiverBase &r, DispatchFunc *f): packet_id(i), receiver(&r), func(f) { }
+       };
+
+       std::vector<Target> targets;
+
+public:
+       template<typename P>
+       void add_receiver(unsigned, PacketReceiver<P> &);
+
+       void receive(unsigned, const Variant &) override;
+};
+
+
+template<typename P>
+void DynamicDispatcher::add_receiver(unsigned packet_id, PacketReceiver<P> &r)
+{
+       auto i = lower_bound_member(targets, packet_id, &Target::packet_id);
+       if(i!=targets.end() && i->packet_id==packet_id)
+               throw key_error(packet_id);
+
+       auto dispatch = [](ReceiverBase &receiver, const Variant &packet){
+               dynamic_cast<PacketReceiver<P> &>(receiver).receive(packet.value<P>());
+       };
+
+       targets.emplace(i, packet_id, r, +dispatch);
+}
+
 } // namespace Net
 } // namespace Msp
 
index 10e8c656797fedfdd685b2edb48ad5f20a217996..9ec5395df99c5b2504965dbb998029d0933f7a30 100644 (file)
@@ -1,9 +1,9 @@
-#include "platform_api.h"
+#include "resolve.h"
 #include <msp/core/systemerror.h>
 #include <msp/strings/format.h>
+#include "platform_api.h"
 #include "sockaddr_private.h"
 #include "socket.h"
-#include "resolve.h"
 
 using namespace std;
 
@@ -40,16 +40,16 @@ namespace Net {
 
 SockAddr *resolve(const string &host, const string &serv, Family family)
 {
-       const char *chost = (host.empty() ? 0 : host.c_str());
-       const char *cserv = (serv.empty() ? 0 : serv.c_str());
+       const char *chost = (host.empty() ? nullptr : host.c_str());
+       const char *cserv = (serv.empty() ? nullptr : serv.c_str());
        int flags = 0;
        if(host=="*")
        {
                flags = AI_PASSIVE;
-               chost = 0;
+               chost = nullptr;
        }
 
-       addrinfo hints = { flags, family_to_sys(family), 0, 0, 0, 0, 0, 0 };
+       addrinfo hints = { flags, family_to_sys(family), 0, 0, 0, nullptr, nullptr, nullptr };
        addrinfo *res;
 
        int err = getaddrinfo(chost, cserv, &hints, &res);
@@ -81,9 +81,7 @@ SockAddr *resolve(const string &str, Family family)
 }
 
 
-Resolver::Resolver():
-       event_disp(0),
-       next_tag(1)
+Resolver::Resolver()
 {
        thread.get_notify_pipe().signal_data_available.connect(sigc::mem_fun(this, &Resolver::task_done));
 }
@@ -104,7 +102,7 @@ unsigned Resolver::resolve(const string &host, const string &serv, Family family
        task.host = host;
        task.serv = serv;
        task.family = family;
-       thread.add_task(task);
+       thread.add_task(move(task));
        return task.tag;
 }
 
@@ -132,24 +130,23 @@ void Resolver::task_done()
                if(task->addr)
                        signal_address_resolved.emit(task->tag, *task->addr);
                else if(task->error)
+               {
+                       if(signal_resolve_failed.empty())
+                       {
+                               unique_ptr<runtime_error> err = move(task->error);
+                               thread.pop_complete_task();
+                               throw *err;
+                       }
                        signal_resolve_failed.emit(task->tag, *task->error);
+               }
                thread.pop_complete_task();
        }
 }
 
 
-Resolver::Task::Task():
-       tag(0),
-       family(UNSPEC),
-       addr(0),
-       error(0)
-{ }
-
-
 Resolver::WorkerThread::WorkerThread():
        Thread("Resolver"),
-       sem(1),
-       done(false)
+       sem(1)
 {
        launch();
 }
@@ -161,11 +158,11 @@ Resolver::WorkerThread::~WorkerThread()
        join();
 }
 
-void Resolver::WorkerThread::add_task(const Task &t)
+void Resolver::WorkerThread::add_task(Task &&t)
 {
        MutexLock lock(queue_mutex);
        bool was_starved = (queue.empty() || queue.back().is_complete());
-       queue.push_back(t);
+       queue.push_back(move(t));
        if(was_starved)
                sem.signal();
 }
@@ -176,18 +173,14 @@ Resolver::Task *Resolver::WorkerThread::get_complete_task()
        if(!queue.empty() && queue.front().is_complete())
                return &queue.front();
        else
-               return 0;
+               return nullptr;
 }
 
 void Resolver::WorkerThread::pop_complete_task()
 {
        MutexLock lock(queue_mutex);
        if(!queue.empty() && queue.front().is_complete())
-       {
-               delete queue.front().addr;
-               delete queue.front().error;
                queue.pop_front();
-       }
 }
 
 void Resolver::WorkerThread::main()
@@ -199,10 +192,10 @@ void Resolver::WorkerThread::main()
                        sem.wait();
                wait = false;
 
-               Task *task = 0;
+               Task *task = nullptr;
                {
                        MutexLock lock(queue_mutex);
-                       for(list<Task>::iterator i=queue.begin(); (!task && i!=queue.end()); ++i)
+                       for(auto i=queue.begin(); (!task && i!=queue.end()); ++i)
                                if(!i->is_complete())
                                        task = &*i;
                }
@@ -211,16 +204,16 @@ void Resolver::WorkerThread::main()
                {
                        try
                        {
-                               SockAddr *addr = Net::resolve(task->host, task->serv, task->family);
+                               unique_ptr<SockAddr> addr(Net::resolve(task->host, task->serv, task->family));
                                {
                                        MutexLock lock(queue_mutex);
-                                       task->addr = addr;
+                                       task->addr = move(addr);
                                }
                        }
                        catch(const runtime_error &e)
                        {
                                MutexLock lock(queue_mutex);
-                               task->error = new runtime_error(e);
+                               task->error = make_unique<runtime_error>(e);
                        }
                        notify_pipe.put(1);
                }
index 21737ff4794bb0def530bd364cae00a2291ec1c2..3071b771f569aa57f196c41cc4f8f9520c715130 100644 (file)
@@ -1,48 +1,47 @@
 #ifndef MSP_NET_RESOLVE_H_
 #define MSP_NET_RESOLVE_H_
 
+#include <deque>
+#include <memory>
 #include <string>
 #include <msp/core/mutex.h>
 #include <msp/core/semaphore.h>
 #include <msp/core/thread.h>
 #include <msp/io/eventdispatcher.h>
 #include <msp/io/pipe.h>
-#include "constants.h"
+#include "mspnet_api.h"
+#include "sockaddr.h"
 
 namespace Msp {
 namespace Net {
 
-class SockAddr;
-
 /** Resolves host and service names into a socket address.  If host is empty,
 the loopback address will be used.  If host is "*", the wildcard address will
 be used.  If service is empty, a socket address with a null service will be
 returned.  With the IP families, these are not very useful. */
-SockAddr *resolve(const std::string &, const std::string &, Family = UNSPEC);
+MSPNET_API SockAddr *resolve(const std::string &, const std::string &, Family = UNSPEC);
 
 /** And overload of resolve() that takes host and service as a single string,
 separated by a colon.  If the host part contains colons, such as is the case
 with a numeric IPv6 address, it must be enclosed in brackets. */
-SockAddr *resolve(const std::string &, Family = UNSPEC);
+MSPNET_API SockAddr *resolve(const std::string &, Family = UNSPEC);
 
 
 /**
 An asynchronous name resolver.  Blocking calls are performed in a thread and
 completion is notified with one of the two signals.
 */
-class Resolver
+class MSPNET_API Resolver
 {
 private:
        struct Task
        {
-               unsigned tag;
+               unsigned tag = 0;
                std::string host;
                std::string serv;
-               Family family;
-               SockAddr *addr;
-               std::runtime_error *error;
-
-               Task();
+               Family family = UNSPEC;
+               std::unique_ptr<SockAddr> addr;
+               std::unique_ptr<std::runtime_error> error;
 
                bool is_complete() const { return addr || error; }
        };
@@ -50,24 +49,24 @@ private:
        class WorkerThread: public Thread
        {
        private:
-               std::list<Task> queue;
+               std::deque<Task> queue;
                Mutex queue_mutex;
                Semaphore sem;
                IO::Pipe notify_pipe;
-               bool done;
+               bool done = false;
 
        public:
                WorkerThread();
                ~WorkerThread();
 
-               void add_task(const Task &);
+               void add_task(Task &&);
                Task *get_complete_task();
                void pop_complete_task();
 
                IO::Pipe &get_notify_pipe() { return notify_pipe; }
 
        private:
-               virtual void main();
+               void main() override;
        };
 
 public:
@@ -75,9 +74,9 @@ public:
        sigc::signal<void, unsigned, const std::exception &> signal_resolve_failed;
 
 private:
-       IO::EventDispatcher *event_disp;
+       IO::EventDispatcher *event_disp = nullptr;
        WorkerThread thread;
-       unsigned next_tag;
+       unsigned next_tag = 1;
 
 public:
        Resolver();
index 60284837b1b91b5a95f6e98f9bcc9c6f3c4b40c0..66876c534fcbf73e2317527b95eb03554a034174 100644 (file)
@@ -1,4 +1,5 @@
 #include "serversocket.h"
+#include <msp/core/except.h>
 
 using namespace std;
 
@@ -9,14 +10,14 @@ ServerSocket::ServerSocket(Family af, int type, int proto):
        Socket(af, type, proto)
 { }
 
-unsigned ServerSocket::do_write(const char *, unsigned)
+size_t ServerSocket::do_write(const char *, size_t)
 {
-       throw logic_error("can't write to ServerSocket");
+       throw unsupported("ServerSocket::write");
 }
 
-unsigned ServerSocket::do_read(char *, unsigned)
+size_t ServerSocket::do_read(char *, size_t)
 {
-       throw logic_error("can't read from ServerSocket");
+       throw unsupported("ServerSocket::read");
 }
 
 } // namespace Net
index 10375f03d15f0efc96c66f4d3d52dcde412795cd..8c082b3772f9dc44451d8bf84d281780bd817cfb 100644 (file)
@@ -1,6 +1,7 @@
 #ifndef MSP_NET_SERVERSOCKET_H_
 #define MSP_NET_SERVERSOCKET_H_
 
+#include "mspnet_api.h"
 #include "socket.h"
 
 namespace Msp {
@@ -12,7 +13,7 @@ class ClientSocket;
 ServerSockets are used to receive incoming connections.  They cannot be used
 for sending and receiving data.
 */
-class ServerSocket: public Socket
+class MSPNET_API ServerSocket: public Socket
 {
 protected:
        ServerSocket(Family, int, int);
@@ -22,8 +23,8 @@ public:
 
        virtual ClientSocket *accept() = 0;
 protected:
-       virtual unsigned do_write(const char *, unsigned);
-       virtual unsigned do_read(char *, unsigned);
+       std::size_t do_write(const char *, std::size_t) override;
+       std::size_t do_read(char *, std::size_t) override;
 };
 
 } // namespace Net
index db436fe17227218e193e71f7ad28b18cfac4cab0..a3d3eed306626df06991cbc61f9c14e428a476e5 100644 (file)
@@ -1,3 +1,4 @@
+#include "sockaddr.h"
 #include <stdexcept>
 #include "platform_api.h"
 #include "inet.h"
@@ -25,11 +26,35 @@ SockAddr *SockAddr::new_from_sys(const SysAddr &sa)
        }
 }
 
-SockAddr::SysAddr::SysAddr():
-       size(sizeof(sockaddr_storage))
+SockAddr::SysAddr::SysAddr()
 {
        addr.ss_family = AF_UNSPEC;
 }
 
+
+int family_to_sys(Family f)
+{
+       switch(f)
+       {
+       case UNSPEC: return AF_UNSPEC;
+       case INET:   return AF_INET;
+       case INET6:  return AF_INET6;
+       case UNIX:   return AF_UNIX;
+       default: throw invalid_argument("family_to_sys");
+       }
+}
+
+Family family_from_sys(int f)
+{
+       switch(f)
+       {
+       case AF_UNSPEC: return UNSPEC;
+       case AF_INET:   return INET;
+       case AF_INET6:  return INET6;
+       case AF_UNIX:   return UNIX;
+       default: throw invalid_argument("family_from_sys");
+       }
+}
+
 } // namespace Net
 } // namespace Msp
index aad5e29065e5695ff6b8794cf5dd95f5696e2dc3..931d4f4d6e5d50ea87f97620f609cbedc44def1d 100644 (file)
@@ -2,20 +2,33 @@
 #define MSP_NET_SOCKADDR_H_
 
 #include <string>
-#include "constants.h"
+#include "mspnet_api.h"
 
 namespace Msp {
 namespace Net {
 
-class SockAddr
+enum Family
+{
+       UNSPEC,
+       INET,
+       INET6,
+       UNIX
+};
+
+
+class MSPNET_API SockAddr
 {
 public:
        struct SysAddr;
 
 protected:
-       SockAddr() { }
+       SockAddr() = default;
+       SockAddr(const SockAddr &) = default;
+       SockAddr(SockAddr &&) = default;
+       SockAddr &operator=(const SockAddr &) = default;
+       SockAddr &operator=(SockAddr &&) = default;
 public:
-       virtual ~SockAddr() { }
+       virtual ~SockAddr() = default;
 
        virtual SockAddr *copy() const = 0;
 
index 909ef9255853e6a65f9669a8d688c1b75dba45de..2a3d14b27aba149099663404bc48e05e6a182010 100644 (file)
@@ -14,11 +14,15 @@ namespace Net {
 struct SockAddr::SysAddr
 {
        struct sockaddr_storage addr;
-       socklen_t size;
+       socklen_t size = sizeof(sockaddr_storage);
 
        SysAddr();
 };
 
+
+int family_to_sys(Family);
+Family family_from_sys(int);
+
 } // namespace Net
 } // namespace Msp
 
index e8e5eaf6275eddb27634a0d473377bed18570d8b..2a1a71c0a4b3aa404b616f09a7be6e1799660c27 100644 (file)
@@ -1,16 +1,17 @@
 #include "platform_api.h"
+#include "socket.h"
 #include <msp/core/systemerror.h>
 #include <msp/io/handle_private.h>
 #include "sockaddr_private.h"
-#include "socket.h"
 #include "socket_private.h"
 
+using namespace std;
+
 namespace Msp {
 namespace Net {
 
 Socket::Socket(const Private &p):
-       priv(new Private),
-       local_addr(0)
+       priv(make_unique<Private>())
 {
        mode = IO::M_RDWR;
 
@@ -18,18 +19,23 @@ Socket::Socket(const Private &p):
 
        SockAddr::SysAddr sa;
        getsockname(priv->handle, reinterpret_cast<sockaddr *>(&sa.addr), &sa.size);
-       local_addr = SockAddr::new_from_sys(sa);
+       local_addr.reset(SockAddr::new_from_sys(sa));
 
        platform_init();
 }
 
 Socket::Socket(Family af, int type, int proto):
-       priv(new Private),
-       local_addr(0)
+       priv(make_unique<Private>())
 {
        mode = IO::M_RDWR;
 
+#ifdef __linux__
+       type |= SOCK_CLOEXEC;
+#endif
        priv->handle = socket(family_to_sys(af), type, proto);
+#ifndef __linux__
+       set_inherit(false);
+#endif
 
        platform_init();
 }
@@ -37,20 +43,26 @@ Socket::Socket(Family af, int type, int proto):
 Socket::~Socket()
 {
        platform_cleanup();
-
-       delete local_addr;
-       delete priv;
 }
 
 void Socket::set_block(bool b)
 {
-       mode = (mode&~IO::M_NONBLOCK);
-       if(b)
-               mode = (mode|IO::M_NONBLOCK);
-
+       IO::adjust_mode(mode, IO::M_NONBLOCK, !b);
        priv->set_block(b);
 }
 
+void Socket::set_inherit(bool i)
+{
+       IO::adjust_mode(mode, IO::M_INHERIT, i);
+       priv->set_inherit(i);
+}
+
+const IO::Handle &Socket::get_handle(IO::Mode)
+{
+       // TODO could this be implemented somehow?
+       throw unsupported("Socket::get_handle");
+}
+
 const IO::Handle &Socket::get_event_handle()
 {
        return priv->event;
@@ -64,13 +76,12 @@ void Socket::bind(const SockAddr &addr)
        if(err==-1)
                throw system_error("bind");
 
-       delete local_addr;
-       local_addr = addr.copy();
+       local_addr.reset(addr.copy());
 }
 
 const SockAddr &Socket::get_local_address() const
 {
-       if(local_addr==0)
+       if(!local_addr)
                throw bad_socket_state("not bound");
        return *local_addr;
 }
index 1f60f9bb2276942ce31e7eaf69612e4e67641f02..b89c175bda040d2e22b17ba3335cdefd225616bb 100644 (file)
@@ -1,23 +1,24 @@
 #ifndef MSP_NET_SOCKET_H_
 #define MSP_NET_SOCKET_H_
 
+#include <memory>
+#include <msp/core/except.h>
 #include <msp/io/eventobject.h>
 #include <msp/io/handle.h>
-#include "constants.h"
+#include "mspnet_api.h"
 #include "sockaddr.h"
 
 namespace Msp {
 namespace Net {
 
-class bad_socket_state: public std::logic_error
+class MSPNET_API bad_socket_state: public invalid_state
 {
 public:
-       bad_socket_state(const std::string &w): std::logic_error(w) { }
-       virtual ~bad_socket_state() throw() { }
+       bad_socket_state(const std::string &w): invalid_state(w) { }
 };
 
 
-class Socket: public IO::EventObject
+class MSPNET_API Socket: public IO::EventObject
 {
 protected:
        enum SocketEvent
@@ -30,8 +31,8 @@ protected:
 
        struct Private;
 
-       Private *priv;
-       SockAddr *local_addr;
+       std::unique_ptr<Private> priv;
+       std::unique_ptr<SockAddr> local_addr;
 
        Socket(const Private &);
        Socket(Family, int, int);
@@ -41,13 +42,16 @@ private:
 public:
        ~Socket();
 
-       virtual void set_block(bool);
-       virtual const IO::Handle &get_event_handle();
+       void set_block(bool) override;
+       void set_inherit(bool) override;
+       const IO::Handle &get_handle(IO::Mode) override;
+       const IO::Handle &get_event_handle() override;
 
        /** Associates the socket with a local address.  There must be no existing
        users of the address. */
        void bind(const SockAddr &);
 
+       bool is_bound() const { return static_cast<bool>(local_addr); }
        const SockAddr &get_local_address() const;
 
        void set_timeout(const Time::TimeDelta &);
index 742877b4516514b718a623c0185a025a11ddb66f..bb9eff6f1c280fbce695dbe4afaf37454a7c7845 100644 (file)
@@ -22,11 +22,12 @@ struct Socket::Private
        IO::Handle event;
 
        void set_block(bool);
+       void set_inherit(bool);
        int set_option(int, int, const void *, socklen_t);
        int get_option(int, int, void *, socklen_t *);
 };
 
-unsigned check_sys_error(int, const char *);
+std::size_t check_sys_error(std::make_signed<std::size_t>::type, const char *);
 bool check_sys_connect_error(int);
 
 } // namespace Net
index 817145bcff15c024469e64f10cdcfe8516e4e551..8e655fc47ec08ddc6ed90640598959a81b31d49b 100644 (file)
@@ -1,11 +1,11 @@
 #include "platform_api.h"
+#include "streamserversocket.h"
 #include <msp/core/refptr.h>
 #include <msp/core/systemerror.h>
 #include <msp/io/handle_private.h>
 #include <msp/strings/format.h>
 #include "sockaddr_private.h"
 #include "socket_private.h"
-#include "streamserversocket.h"
 #include "streamsocket.h"
 
 using namespace std;
@@ -14,12 +14,14 @@ namespace Msp {
 namespace Net {
 
 StreamServerSocket::StreamServerSocket(Family af, int proto):
-       ServerSocket(af, SOCK_STREAM, proto),
-       listening(false)
+       ServerSocket(af, SOCK_STREAM, proto)
 { }
 
 void StreamServerSocket::listen(const SockAddr &addr, unsigned backlog)
 {
+       if(listening)
+               throw bad_socket_state("already listening");
+
        bind(addr);
 
        int err = ::listen(priv->handle, backlog);
index aa4868c121601899a340b01d9e85de70fa19a7a9..bbe1c91ca56038efcbf31c2d1bfacd78c66213fc 100644 (file)
@@ -1,22 +1,24 @@
 #ifndef MSP_NET_STREAMSERVERSOCKET_H_
 #define MSP_NET_STREAMSERVERSOCKET_H_
 
+#include "mspnet_api.h"
 #include "serversocket.h"
 #include "streamsocket.h"
 
 namespace Msp {
 namespace Net {
 
-class StreamServerSocket: public ServerSocket
+class MSPNET_API StreamServerSocket: public ServerSocket
 {
 private:
-       bool listening;
+       bool listening = false;
 
 public:
        StreamServerSocket(Family, int = 0);
 
-       virtual void listen(const SockAddr &, unsigned = 4);
-       virtual StreamSocket *accept();
+       void listen(const SockAddr &, unsigned = 4) override;
+       bool is_listening() const { return listening; }
+       StreamSocket *accept() override;
 };
 
 } // namespace Net
index 11887d398c4b09a894c088fb4b3656cd49bf8c45..9e2d79afe7e1b27dd87d81c5a6b2f7db2f6103b6 100644 (file)
@@ -1,11 +1,11 @@
 #include "platform_api.h"
+#include "streamsocket.h"
 #include <msp/core/systemerror.h>
 #include <msp/io/handle_private.h>
 #include <msp/io/poll.h>
 #include <msp/strings/format.h>
 #include "sockaddr_private.h"
 #include "socket_private.h"
-#include "streamsocket.h"
 
 namespace Msp {
 namespace Net {
@@ -34,13 +34,11 @@ bool StreamSocket::connect(const SockAddr &addr)
                set_socket_events(S_CONNECT);
        }
 
-       delete peer_addr;
-       peer_addr = addr.copy();
+       peer_addr.reset(addr.copy());
 
-       delete local_addr;
        SockAddr::SysAddr lsa;
        getsockname(priv->handle, reinterpret_cast<sockaddr *>(&lsa.addr), &lsa.size);
-       local_addr = SockAddr::new_from_sys(lsa);
+       local_addr.reset(SockAddr::new_from_sys(lsa));
 
        if(finished)
        {
@@ -99,10 +97,7 @@ void StreamSocket::on_event(IO::PollEvent ev)
                        signal_connect_finished.emit(0);
 
                if(err!=0)
-               {
-                       delete peer_addr;
-                       peer_addr = 0;
-               }
+                       peer_addr.reset();
 
                set_socket_events((err==0) ? S_INPUT : S_NONE);
        }
index 8b39e91afd358214168c977e299e01bc714d3e76..84b347dbc07da8ed04217da249248a7ea329c28f 100644 (file)
@@ -2,11 +2,12 @@
 #define MSP_NET_STREAMSOCKET_H_
 
 #include "clientsocket.h"
+#include "mspnet_api.h"
 
 namespace Msp {
 namespace Net {
 
-class StreamSocket: public ClientSocket
+class MSPNET_API StreamSocket: public ClientSocket
 {
        friend class StreamServerSocket;
 
@@ -23,12 +24,12 @@ public:
        If the socket is non-blocking, this function may return before the
        connection is fully established.  The caller must then use either the
        poll_connect function or an EventDispatcher to finish the process. */
-       virtual bool connect(const SockAddr &);
+       bool connect(const SockAddr &) override;
 
-       virtual bool poll_connect(const Time::TimeDelta &);
+       bool poll_connect(const Time::TimeDelta &) override;
 
 private:
-       void on_event(IO::PollEvent);
+       void on_event(IO::PollEvent) override;
 };
 
 } // namespace Net
index fe00020988bf92c1b915b7e6a7ef399ad7e9c4af..1da822b2da03be2a7dc8ad976ec2e82816d13a5a 100644 (file)
@@ -5,11 +5,6 @@ using namespace std;
 namespace Msp {
 namespace Net {
 
-UnixAddr::UnixAddr():
-       abstract(false)
-{
-}
-
 string UnixAddr::str() const
 {
        string result = "unix:";
index 821e915f0154d814163621ee4e47371ff70daf96..86f1dddf1ea9647bc8a55b43b156658d67677510 100644 (file)
@@ -1,28 +1,29 @@
 #ifndef MSP_NET_UNIX_H_
 #define MSP_NET_UNIX_H_
 
+#include "mspnet_api.h"
 #include "sockaddr.h"
 
 namespace Msp {
 namespace Net {
 
-class UnixAddr: public SockAddr
+class MSPNET_API UnixAddr: public SockAddr
 {
 private:
        std::string path;
-       bool abstract;
+       bool abstract = false;
 
 public:
-       UnixAddr();
+       UnixAddr() = default;
        UnixAddr(const SysAddr &);
        UnixAddr(const std::string &, bool = false);
 
-       virtual UnixAddr *copy() const { return new UnixAddr(*this); }
+       UnixAddr *copy() const override { return new UnixAddr(*this); }
 
-       virtual SysAddr to_sys() const;
+       SysAddr to_sys() const override;
 
-       virtual Family get_family() const { return UNIX; }
-       virtual std::string str() const;
+       Family get_family() const override { return UNIX; }
+       std::string str() const override;
 };
 
 } // namespace Net
index 02ef7c3c06c4999e7d62ebe838427918271f6d99..2c1e2ad8d355ac9e0fe9a31b5292f0709535c2f8 100644 (file)
@@ -1,20 +1,23 @@
+#include "platform_api.h"
+#include "socket.h"
 #include <cerrno>
 #include <unistd.h>
 #include <fcntl.h>
-#include "platform_api.h"
 #include <msp/core/systemerror.h>
 #include <msp/io/handle_private.h>
 #include <msp/time/rawtime_private.h>
 #include "sockaddr_private.h"
-#include "socket.h"
 #include "socket_private.h"
 
+using namespace std;
+
 namespace Msp {
 namespace Net {
 
 void Socket::platform_init()
 {
        *priv->event = priv->handle;
+       set_inherit(false);
 }
 
 void Socket::platform_cleanup()
@@ -37,7 +40,13 @@ void Socket::set_platform_events(unsigned)
 void Socket::Private::set_block(bool b)
 {
        int flags = fcntl(handle, F_GETFL);
-       fcntl(handle, F_SETFL, (flags&O_NONBLOCK)|(b?0:O_NONBLOCK));
+       fcntl(handle, F_SETFL, (flags&~O_NONBLOCK)|(b?0:O_NONBLOCK));
+}
+
+void Socket::Private::set_inherit(bool i)
+{
+       int flags = fcntl(handle, F_GETFD);
+       fcntl(handle, F_SETFD, (flags&~O_CLOEXEC)|(i?0:O_CLOEXEC));
 }
 
 int Socket::Private::set_option(int level, int optname, const void *optval, socklen_t optlen)
@@ -51,11 +60,11 @@ int Socket::Private::get_option(int level, int optname, void *optval, socklen_t
 }
 
 
-unsigned check_sys_error(int ret, const char *func)
+size_t check_sys_error(make_signed<size_t>::type ret, const char *func)
 {
        if(ret<0)
        {
-               if(errno==EAGAIN)
+               if(errno==EAGAIN || errno==EWOULDBLOCK)
                        return 0;
                else
                        throw system_error(func);
index 6bb34361584e773b38fa8309cb84c2e3eca365bd..648d53a303a946f5dc64e1be09481142f573f36c 100644 (file)
@@ -1,16 +1,15 @@
+#include "platform_api.h"
+#include "unix.h"
 #include <stdexcept>
 #include <sys/un.h>
-#include "platform_api.h"
 #include "sockaddr_private.h"
-#include "unix.h"
 
 using namespace std;
 
 namespace Msp {
 namespace Net {
 
-UnixAddr::UnixAddr(const SysAddr &sa):
-       abstract(false)
+UnixAddr::UnixAddr(const SysAddr &sa)
 {
        const sockaddr_un &sau = reinterpret_cast<const sockaddr_un &>(sa.addr);
        if(static_cast<size_t>(sa.size)>sizeof(sa_family_t))
@@ -25,7 +24,7 @@ UnixAddr::UnixAddr(const string &p, bool a):
        abstract(a)
 {
        if(sizeof(sa_family_t)+path.size()+1>sizeof(sockaddr_storage))
-               throw invalid_argument("UnixAddr");
+               throw invalid_argument("UnixAddr::UnixAddr");
 }
 
 SockAddr::SysAddr UnixAddr::to_sys() const
index f4fab44c478d491f0677b333dfc26a853113be78..eadc22af3a2de5355c349d16e201f8d1ec01b032 100644 (file)
@@ -1,11 +1,13 @@
-#include <iostream>
 #include "platform_api.h"
+#include "socket.h"
+#include <iostream>
 #include <msp/core/systemerror.h>
 #include <msp/io/handle_private.h>
 #include "sockaddr_private.h"
-#include "socket.h"
 #include "socket_private.h"
 
+using namespace std;
+
 namespace {
 
 class WinSockHelper
@@ -25,7 +27,7 @@ public:
        }
 };
 
-WinSockHelper wsh;
+unique_ptr<WinSockHelper> wsh;
 
 }
 
@@ -35,6 +37,8 @@ namespace Net {
 
 void Socket::platform_init()
 {
+       if(!wsh)
+               wsh = make_unique<WinSockHelper>();
        *priv->event = CreateEvent(0, false, false, 0);
 }
 
@@ -70,6 +74,10 @@ void Socket::Private::set_block(bool b)
        ioctlsocket(handle, FIONBIO, &flag);
 }
 
+void Socket::Private::set_inherit(bool)
+{
+}
+
 int Socket::Private::set_option(int level, int optname, const void *optval, socklen_t optlen)
 {
        return setsockopt(handle, level, optname, reinterpret_cast<const char *>(optval), optlen);
@@ -81,7 +89,7 @@ int Socket::Private::get_option(int level, int optname, void *optval, socklen_t
 }
 
 
-unsigned check_sys_error(int ret, const char *func)
+size_t check_sys_error(make_signed<size_t>::type ret, const char *func)
 {
        if(ret<0)
        {
index 2aa885708da84efb9c33afde5b4e5fa0639fdcb3..b38a258b24421f208f42e7366d1552bd401e204c 100644 (file)
@@ -1,17 +1,16 @@
-#include <stdexcept>
 #include "platform_api.h"
-#include "sockaddr_private.h"
 #include "unix.h"
+#include <msp/core/except.h>
+#include "sockaddr_private.h"
 
 using namespace std;
 
 namespace Msp {
 namespace Net {
 
-UnixAddr::UnixAddr(const SysAddr &):
-       abstract(false)
+UnixAddr::UnixAddr(const SysAddr &)
 {
-       throw logic_error("AF_UNIX not supported");
+       throw unsupported("AF_UNIX");
 }
 
 UnixAddr::UnixAddr(const string &p, bool a):
@@ -22,7 +21,7 @@ UnixAddr::UnixAddr(const string &p, bool a):
 
 SockAddr::SysAddr UnixAddr::to_sys() const
 {
-       throw logic_error("AF_UNIX not supported");
+       throw unsupported("AF_UNIX");
 }
 
 } // namespace Net
diff --git a/tests/.gitignore b/tests/.gitignore
new file mode 100644 (file)
index 0000000..ee4c926
--- /dev/null
@@ -0,0 +1 @@
+/test
diff --git a/tests/Build b/tests/Build
new file mode 100644 (file)
index 0000000..971a59d
--- /dev/null
@@ -0,0 +1,11 @@
+package "mspnet-tests"
+{
+       require "mspcore";
+       require "mspnet";
+       require "msptest";
+
+       program "test"
+       {
+               source ".";
+       };
+};
diff --git a/tests/http_header.cpp b/tests/http_header.cpp
new file mode 100644 (file)
index 0000000..e123714
--- /dev/null
@@ -0,0 +1,99 @@
+#include <msp/http/header.h>
+#include <msp/test/test.h>
+
+using namespace std;
+using namespace Msp;
+
+class HttpHeaderTests: public Test::RegisteredTest<HttpHeaderTests>
+{
+public:
+       HttpHeaderTests();
+
+       static const char *get_name() { return "Http::Header"; }
+
+private:
+       void single_value();
+       void list_of_values();
+       void key_value_list();
+       void value_with_attributes();
+       void list_with_attributes();
+       void quoted_values();
+};
+
+HttpHeaderTests::HttpHeaderTests()
+{
+       add(&HttpHeaderTests::single_value, "Single value");
+       add(&HttpHeaderTests::list_of_values, "List");
+       add(&HttpHeaderTests::key_value_list, "Key-value list");
+       add(&HttpHeaderTests::value_with_attributes, "Value with attributes");
+       add(&HttpHeaderTests::list_with_attributes, "List with attributes");
+       add(&HttpHeaderTests::quoted_values, "Quoted values");
+}
+
+void HttpHeaderTests::single_value()
+{
+       Http::Header header("test", "1;2 , 3 ", Http::Header::SINGLE_VALUE);
+       EXPECT_EQUAL(header.values.size(), 1);
+       EXPECT_EQUAL(header.values.front().value, "1;2 , 3");
+       EXPECT(header.values.front().parameters.empty());
+}
+
+void HttpHeaderTests::list_of_values()
+{
+       Http::Header header("test", "1;2 , 3 ", Http::Header::LIST);
+       EXPECT_EQUAL(header.values.size(), 3);
+       EXPECT_EQUAL(header.values[0].value, "1");
+       EXPECT_EQUAL(header.values[1].value, "2");
+       EXPECT_EQUAL(header.values[2].value, "3");
+}
+
+void HttpHeaderTests::key_value_list()
+{
+       Http::Header header("test", "a=1; b = 2 ,c=3 ", Http::Header::KEY_VALUE_LIST);
+       EXPECT_EQUAL(header.values.size(), 1);
+       EXPECT_EQUAL(header.values.front().parameters.count("a"), 1);
+       EXPECT_EQUAL(header.values.front().parameters["a"], "1");
+       EXPECT_EQUAL(header.values.front().parameters.count("b"), 1);
+       EXPECT_EQUAL(header.values.front().parameters["b"], "2");
+       EXPECT_EQUAL(header.values.front().parameters.count("c"), 1);
+       EXPECT_EQUAL(header.values.front().parameters["c"], "3");
+}
+
+void HttpHeaderTests::value_with_attributes()
+{
+       Http::Header header("test", "X;a=1, 2 ; b=3 ", Http::Header::VALUE_WITH_ATTRIBUTES);
+       EXPECT_EQUAL(header.values.size(), 1);
+       EXPECT_EQUAL(header.values.front().value, "X");
+       EXPECT_EQUAL(header.values.front().parameters.count("a"), 1);
+       EXPECT_EQUAL(header.values.front().parameters["a"], "1, 2");
+       EXPECT_EQUAL(header.values.front().parameters.count("b"), 1);
+       EXPECT_EQUAL(header.values.front().parameters["b"], "3");
+}
+
+void HttpHeaderTests::list_with_attributes()
+{
+       Http::Header header("test", "X;a= 1;b= 2 ,Y ; c=3 ", Http::Header::LIST_WITH_ATTRIBUTES);
+       EXPECT_EQUAL(header.values.size(), 2);
+       EXPECT_EQUAL(header.values[0].value, "X");
+       EXPECT_EQUAL(header.values[0].parameters.count("a"), 1);
+       EXPECT_EQUAL(header.values[0].parameters["a"], "1");
+       EXPECT_EQUAL(header.values[0].parameters.count("b"), 1);
+       EXPECT_EQUAL(header.values[0].parameters["b"], "2");
+       EXPECT_EQUAL(header.values[1].value, "Y");
+       EXPECT_EQUAL(header.values[1].parameters.count("c"), 1);
+       EXPECT_EQUAL(header.values[1].parameters["c"], "3");
+}
+
+void HttpHeaderTests::quoted_values()
+{
+       Http::Header header("test", "X;a= 1;b=\" ;2, \",Y ; c=\"3 \" ", Http::Header::LIST_WITH_ATTRIBUTES);
+       EXPECT_EQUAL(header.values.size(), 2);
+       EXPECT_EQUAL(header.values[0].value, "X");
+       EXPECT_EQUAL(header.values[0].parameters.count("a"), 1);
+       EXPECT_EQUAL(header.values[0].parameters["a"], "1");
+       EXPECT_EQUAL(header.values[0].parameters.count("b"), 1);
+       EXPECT_EQUAL(header.values[0].parameters["b"], " ;2, ");
+       EXPECT_EQUAL(header.values[1].value, "Y");
+       EXPECT_EQUAL(header.values[1].parameters.count("c"), 1);
+       EXPECT_EQUAL(header.values[1].parameters["c"], "3 ");
+}
diff --git a/tests/protocol.cpp b/tests/protocol.cpp
new file mode 100644 (file)
index 0000000..bca1161
--- /dev/null
@@ -0,0 +1,179 @@
+#include <msp/net/protocol.h>
+#include <msp/test/test.h>
+
+using namespace std;
+using namespace Msp;
+
+class ProtocolTests: public Test::RegisteredTest<ProtocolTests>
+{
+public:
+       ProtocolTests();
+
+       static const char *get_name() { return "Protocol"; }
+
+private:
+       void hash_match();
+       void buffer_overflow();
+       void truncated_packet();
+       void stub_header();
+
+       template<typename P>
+       void transmit(const P &, P &, size_t);
+
+       void transmit_int();
+       void transmit_string();
+       void transmit_array();
+       void transmit_composite();
+};
+
+
+class Protocol: public Msp::Net::Protocol
+{
+public:
+       Protocol();
+};
+
+struct Packet1
+{
+       uint32_t value;
+};
+
+struct Packet2
+{
+       std::string value;
+};
+
+struct Packet3
+{
+       std::vector<uint32_t> values;
+};
+
+struct Packet4
+{
+       Packet2 sub1;
+       std::vector<Packet3> sub2;
+};
+
+Protocol::Protocol()
+{
+       add<Packet1>(&Packet1::value);
+       add<Packet2>(&Packet2::value);
+       add<Packet3>(&Packet3::values);
+       add<Packet4>(&Packet4::sub1, &Packet4::sub2);
+}
+
+template<typename T>
+class Receiver: public Net::PacketReceiver<T>
+{
+private:
+       T &storage;
+
+public:
+       Receiver(T &s): storage(s) { }
+
+       void receive(const T &p) override { storage = p; }
+};
+
+
+ProtocolTests::ProtocolTests()
+{
+       add(&ProtocolTests::hash_match, "Hash match");
+       add(&ProtocolTests::buffer_overflow, "Serialization buffer overflow").expect_throw<Net::buffer_error>();
+       add(&ProtocolTests::truncated_packet, "Truncated packet").expect_throw<Net::bad_packet>();
+       add(&ProtocolTests::stub_header, "Stub header");
+       add(&ProtocolTests::transmit_int, "Integer transmission");
+       add(&ProtocolTests::transmit_string, "String transmission");
+       add(&ProtocolTests::transmit_array, "Array transmission");
+       add(&ProtocolTests::transmit_composite, "Composite transmission");
+}
+
+void ProtocolTests::hash_match()
+{
+       Protocol proto1;
+       Protocol proto2;
+       EXPECT_EQUAL(proto1.get_hash(), proto2.get_hash());
+}
+
+void ProtocolTests::buffer_overflow()
+{
+       Protocol proto;
+       Packet1 pkt = { 42 };
+       char buffer[7];
+       proto.serialize(pkt, buffer, sizeof(buffer));
+}
+
+void ProtocolTests::truncated_packet()
+{
+       Protocol proto;
+       Packet1 pkt = { 42 };
+       char buffer[16];
+       size_t len = proto.serialize(pkt, buffer, sizeof(buffer));
+       Receiver<Packet1> recv(pkt);
+       proto.dispatch(recv, buffer, len-1);
+}
+
+void ProtocolTests::stub_header()
+{
+       Protocol proto;
+       char buffer[3] = { 4, 0, 1 };
+       size_t len = proto.get_packet_size(buffer, sizeof(buffer));
+       EXPECT_EQUAL(len, 0);
+}
+
+template<typename P>
+void ProtocolTests::transmit(const P &pkt, P &rpkt, size_t expected_length)
+{
+       Protocol proto;
+       char buffer[128];
+       size_t len = proto.serialize(pkt, buffer, sizeof(buffer));
+       EXPECT_EQUAL(len, expected_length);
+
+       size_t rlen = proto.get_packet_size(buffer, sizeof(buffer));
+       EXPECT_EQUAL(rlen, len);
+
+       Receiver<P> recv(rpkt);
+       size_t dlen = proto.dispatch(recv, buffer, sizeof(buffer));
+       EXPECT_EQUAL(dlen, len);
+}
+
+void ProtocolTests::transmit_int()
+{
+       Packet1 pkt = { 42 };
+       Packet1 rpkt;
+       transmit(pkt, rpkt, 8);
+       EXPECT_EQUAL(rpkt.value, 42);
+}
+
+void ProtocolTests::transmit_string()
+{
+       Packet2 pkt = { "Hello" };
+       Packet2 rpkt;
+       transmit(pkt, rpkt, 11);
+       EXPECT_EQUAL(rpkt.value, "Hello");
+}
+
+void ProtocolTests::transmit_array()
+{
+       Packet3 pkt = {{ 2, 3, 5, 7, 11 }};
+       Packet3 rpkt;
+       transmit(pkt, rpkt, 26);
+       EXPECT_EQUAL(rpkt.values.size(), 5);
+       for(size_t i=0; i<pkt.values.size(); ++i)
+               EXPECT_EQUAL(rpkt.values[i], pkt.values[i]);
+}
+
+void ProtocolTests::transmit_composite()
+{
+       Packet4 pkt = { "Don't panic", { }};
+       pkt.sub2.emplace_back(Packet3{{ 2, 3, 5, 7, 11 }});
+       pkt.sub2.emplace_back(Packet3{{ 20, 10, 5, 16, 8, 4, 2, 1 }});
+       Packet4 rpkt;
+       transmit(pkt, rpkt, 75);
+       EXPECT_EQUAL(rpkt.sub1.value, "Don't panic");
+       EXPECT_EQUAL(rpkt.sub2.size(), 2);
+       EXPECT_EQUAL(rpkt.sub2[0].values.size(), 5);
+       EXPECT_EQUAL(rpkt.sub2[1].values.size(), 8);
+       for(size_t i=0; i<pkt.sub2.size(); ++i)
+               for(size_t j=0; j<pkt.sub2[i].values.size(); ++j)       
+                       EXPECT_EQUAL(rpkt.sub2[i].values[j], pkt.sub2[i].values[j]);
+}