From: Mikko Rasa Date: Thu, 1 Jun 2023 07:31:43 +0000 (+0300) Subject: Add a dynamic receiver class for more flexible packet handling X-Git-Url: http://git.tdb.fi/?p=libs%2Fnet.git;a=commitdiff_plain;h=HEAD;hp=2dfa05663dd67d4d7c68f96df0b1ab733b2063c2 Add a dynamic receiver class for more flexible packet handling It's useful for middleware libraries which may not know all packet types at compile time, or if different packets need to be routed to different receivers. --- diff --git a/Build b/Build index 404ddbf..44bf9c0 100644 --- a/Build +++ b/Build @@ -1,5 +1,7 @@ package "mspnet" { + version "1.0"; + require "sigc++-2.0"; require "mspcore"; if_arch "windows" @@ -10,6 +12,11 @@ package "mspnet" }; }; + build_info + { + standard CXX "c++14"; + }; + library "mspnet" { source "source/net"; diff --git a/examples/httpget.cpp b/examples/httpget.cpp index a4314ce..18ef799 100644 --- a/examples/httpget.cpp +++ b/examples/httpget.cpp @@ -28,13 +28,8 @@ HttpGet::HttpGet(int argc, char **argv): { GetOpt getopt; getopt.add_option('v', "verbose", verbose, GetOpt::NO_ARG); + getopt.add_argument("url", url, GetOpt::REQUIRED_ARG); getopt(argc, argv); - - const vector &args = getopt.get_args(); - if(args.empty()) - throw usage_error("No URL"); - - url = args.front(); } int HttpGet::main() diff --git a/examples/netcat.cpp b/examples/netcat.cpp index 9bbad87..cd549f4 100644 --- a/examples/netcat.cpp +++ b/examples/netcat.cpp @@ -34,16 +34,15 @@ NetCat::NetCat(int argc, char **argv): server_sock(0), sock(0) { + string host_name; + GetOpt getopt; getopt.add_option('6', "ipv6", ipv6, GetOpt::NO_ARG); getopt.add_option('l', "listen", listen, GetOpt::NO_ARG); + getopt.add_argument("host", host_name, GetOpt::REQUIRED_ARG); getopt(argc, argv); - const vector &args = getopt.get_args(); - if(args.empty()) - throw usage_error("host argument missing"); - - RefPtr addr = Net::resolve(args.front(), (ipv6 ? Net::INET6 : Net::INET)); + RefPtr addr = Net::resolve(host_name, (ipv6 ? Net::INET6 : Net::INET)); if(!listen) { sock = new Net::StreamSocket(addr->get_family()); diff --git a/source/http/client.cpp b/source/http/client.cpp index 6f22181..1c4c953 100644 --- a/source/http/client.cpp +++ b/source/http/client.cpp @@ -1,7 +1,7 @@ -#include -#include -#include #include "client.h" +#include +#include +#include #include "request.h" #include "response.h" @@ -10,22 +10,8 @@ using namespace std; namespace Msp { namespace Http { -Client::Client(): - sock(0), - event_disp(0), - resolver(0), - resolve_listener(0), - resolve_tag(0), - user_agent("libmsphttp/0.1"), - request(0), - response(0) -{ } - Client::~Client() { - delete sock; - delete request; - delete response; } void Client::use_event_dispatcher(IO::EventDispatcher *ed) @@ -40,30 +26,25 @@ void Client::use_event_dispatcher(IO::EventDispatcher *ed) void Client::use_resolver(Net::Resolver *r) { if(resolver) - { - delete resolve_listener; - resolve_listener = 0; - } + resolve_listener.reset(); resolver = r; if(resolver) - resolve_listener = new ResolveListener(*this); + resolve_listener = make_unique(*this); } void Client::start_request(const Request &r) { if(request) - throw client_busy(); + throw invalid_state("already processing a request"); - delete sock; - sock = 0; + sock.reset(); - request = new Request(r); + request = make_unique(r); if(!user_agent.empty()) request->set_header("User-Agent", user_agent); - delete response; - response = 0; + response.reset(); in_buf.clear(); string host = r.get_header("Host"); @@ -73,7 +54,7 @@ void Client::start_request(const Request &r) resolve_tag = resolver->resolve(host); else { - RefPtr addr = Net::resolve(host); + unique_ptr addr(Net::resolve(host)); address_resolved(resolve_tag, *addr); } } @@ -82,7 +63,7 @@ const Response *Client::get_url(const std::string &url) { start_request(Request::from_url(url)); wait_response(); - return response; + return response.get(); } void Client::tick() @@ -97,10 +78,8 @@ void Client::tick() { signal_response_complete.emit(*response); - delete sock; - sock = 0; - delete request; - request = 0; + sock.reset(); + request.reset(); } } @@ -112,10 +91,8 @@ void Client::wait_response() void Client::abort() { - delete sock; - sock = 0; - delete request; - request = 0; + sock.reset(); + request.reset(); } void Client::address_resolved(unsigned tag, const Net::SockAddr &addr) @@ -124,7 +101,7 @@ void Client::address_resolved(unsigned tag, const Net::SockAddr &addr) return; resolve_tag = 0; - sock = new Net::StreamSocket(addr.get_family()); + sock = make_unique(addr.get_family()); sock->set_block(false); sock->signal_data_available.connect(sigc::mem_fun(this, &Client::data_available)); @@ -141,23 +118,37 @@ void Client::resolve_failed(unsigned tag, const exception &err) return; resolve_tag = 0; - signal_socket_error.emit(err); + request.reset(); - delete request; - request = 0; + if(signal_socket_error.empty()) + throw err; + signal_socket_error.emit(err); } void Client::connect_finished(const exception *err) { if(err) { - signal_socket_error.emit(*err); + request.reset(); - delete request; - request = 0; + if(signal_socket_error.empty()) + throw *err; + signal_socket_error.emit(*err); } else - sock->write(request->str()); + { + try + { + sock->write(request->str()); + } + catch(const exception &e) + { + if(signal_socket_error.empty()) + throw; + signal_socket_error.emit(e); + return; + } + } } void Client::data_available() @@ -170,6 +161,8 @@ void Client::data_available() } catch(const exception &e) { + if(signal_socket_error.empty()) + throw; signal_socket_error.emit(e); return; } @@ -182,7 +175,7 @@ void Client::data_available() { if(in_buf.find("\r\n\r\n")!=string::npos || in_buf.find("\n\n")!=string::npos) { - response = new Response(Response::parse(in_buf)); + response = make_unique(Response::parse(in_buf)); response->set_user_data(request->get_user_data()); in_buf = string(); } @@ -197,8 +190,7 @@ void Client::data_available() { signal_response_complete.emit(*response); - delete request; - request = 0; + request.reset(); } } diff --git a/source/http/client.h b/source/http/client.h index ebb7321..b1f097d 100644 --- a/source/http/client.h +++ b/source/http/client.h @@ -1,26 +1,21 @@ #ifndef MSP_HTTP_CLIENT_H_ #define MSP_HTTP_CLIENT_H_ +#include #include #include #include +#include #include #include namespace Msp { namespace Http { -class client_busy: public std::logic_error -{ -public: - client_busy(): std::logic_error(std::string()) { } - virtual ~client_busy() throw() { } -}; - class Request; class Response; -class Client +class MSPNET_API Client { public: sigc::signal signal_response_complete; @@ -37,20 +32,17 @@ private: void resolve_failed(unsigned, const std::exception &); }; - Net::StreamSocket *sock; - IO::EventDispatcher *event_disp; - Net::Resolver *resolver; - ResolveListener *resolve_listener; - unsigned resolve_tag; - std::string user_agent; - Request *request; - Response *response; + std::unique_ptr sock; + IO::EventDispatcher *event_disp = nullptr; + Net::Resolver *resolver = nullptr; + std::unique_ptr resolve_listener; + unsigned resolve_tag = 0; + std::string user_agent = "libmspnet/1.0"; + std::unique_ptr request; + std::unique_ptr response; std::string in_buf; - Client(const Client &); - Client &operator=(const Client &); public: - Client(); ~Client(); void use_event_dispatcher(IO::EventDispatcher *); @@ -61,8 +53,8 @@ public: void tick(); void wait_response(); void abort(); - const Request *get_request() const { return request; } - const Response *get_response() const { return response; } + const Request *get_request() const { return request.get(); } + const Response *get_response() const { return response.get(); } private: void address_resolved(unsigned, const Net::SockAddr &); void resolve_failed(unsigned, const std::exception &); diff --git a/source/http/formdata.cpp b/source/http/formdata.cpp index 182c5bb..4f2b7b3 100644 --- a/source/http/formdata.cpp +++ b/source/http/formdata.cpp @@ -1,9 +1,9 @@ +#include "formdata.h" #include #include "header.h" #include "request.h" #include "submessage.h" #include "utils.h" -#include "formdata.h" using namespace std; @@ -53,15 +53,15 @@ void FormData::parse_multipart(const Request &req, const string &boundary) if(is_boundary) { - /* The CRLF preceding the boundary delimiter is treated as part - of the delimiter as per RFC 2046 */ - string::size_type part_end = line_start-1; - if(content[part_end-1]=='\r') - --part_end; - if(part_start>0) { - SubMessage part = SubMessage::parse(content.substr(part_start, line_start-part_start)); + /* The CRLF preceding the boundary delimiter is treated as part + of the delimiter as per RFC 2046 */ + string::size_type part_end = line_start-1; + if(content[part_end-1]=='\r') + --part_end; + + SubMessage part = SubMessage::parse(content.substr(part_start, part_end-part_start)); Header content_disposition(part, "Content-Disposition"); const Header::Value &cd_value = content_disposition.values.at(0); if(cd_value.value=="form-data") @@ -72,10 +72,10 @@ void FormData::parse_multipart(const Request &req, const string &boundary) } part_start = lf+1; - } - if(!content.compare(line_start+2+boundary.size(), 2, "--")) - break; + if(!content.compare(line_start+2+boundary.size(), 2, "--")) + break; + } line_start = lf+1; } @@ -83,7 +83,7 @@ void FormData::parse_multipart(const Request &req, const string &boundary) const string &FormData::get_value(const string &key) const { - map::const_iterator i = fields.find(key); + auto i = fields.find(key); if(i==fields.end()) { static string dummy; diff --git a/source/http/formdata.h b/source/http/formdata.h index 3232891..cf78acd 100644 --- a/source/http/formdata.h +++ b/source/http/formdata.h @@ -3,13 +3,14 @@ #include #include +#include namespace Msp { namespace Http { class Request; -class FormData +class MSPNET_API FormData { private: std::map fields; diff --git a/source/http/header.cpp b/source/http/header.cpp index 02c82ca..a112e91 100644 --- a/source/http/header.cpp +++ b/source/http/header.cpp @@ -1,6 +1,6 @@ +#include "header.h" #include #include -#include "header.h" #include "message.h" using namespace std; @@ -8,37 +8,74 @@ using namespace std; namespace Msp { namespace Http { -Header::Header(const Message &msg, const string &n): +Header::Header(const Message &msg, const string &n, Style s): name(n), + style(s), raw_value(msg.get_header(name)) { parse(); } -Header::Header(const string &n, const string &rv): +Header::Header(const string &n, const string &rv, Style s): name(n), + style(s), raw_value(rv) { parse(); } +Header::Style Header::get_default_style(const string &name) +{ + if(!strcasecmp(name, "content-disposition")) + return VALUE_WITH_ATTRIBUTES; + else if(!strcasecmp(name, "content-type")) + return VALUE_WITH_ATTRIBUTES; + else if(!strcasecmp(name, "cookie")) + return KEY_VALUE_LIST; + else if(!strcasecmp(name, "set-cookie")) + return VALUE_WITH_ATTRIBUTES; + else + return SINGLE_VALUE; +} + void Header::parse() { - string::const_iterator i = raw_value.begin(); - while(i!=raw_value.end()) + if(style==DEFAULT) + style = get_default_style(name); + + if(style==SINGLE_VALUE) { Value value; + value.value = strip(raw_value); + values.push_back(value); + return; + } - string::const_iterator start = i; - for(; (i!=raw_value.end() && *i!=';' && *i!=','); ++i) ; - value.value = strip(string(start, i)); - if(value.value.empty()) - throw invalid_argument("Header::parse"); + char value_sep = (style==VALUE_WITH_ATTRIBUTES ? 0 : ','); - while(i!=raw_value.end() && *i!=',') + auto i = raw_value.cbegin(); + while(i!=raw_value.cend()) + { + Value value; + + auto start = i; + if(style==KEY_VALUE_LIST) + value.value = name; + else + { + for(; (i!=raw_value.end() && *i!=';' && *i!=value_sep); ++i) ; + value.value = strip(string(start, i)); + if(value.value.empty()) + throw invalid_argument("Header::parse"); + } + + while(i!=raw_value.end() && (*i!=',' || style==KEY_VALUE_LIST) && style!=LIST) { - start = ++i; - for(; (i!=raw_value.end() && *i!=';' && *i!=',' && *i!='='); ++i) ; + if(*i==';' || *i==',') + ++i; + + start = i; + for(; (i!=raw_value.end() && *i!=';' && *i!=value_sep && *i!='='); ++i) ; string pname = strip(string(start, i)); if(pname.empty()) throw invalid_argument("Header::parse"); @@ -47,7 +84,7 @@ void Header::parse() if(i!=raw_value.end() && *i=='=') { for(++i; (i!=raw_value.end() && isspace(*i)); ++i) ; - if(i==raw_value.end() || *i==';' || *i==',') + if(i==raw_value.end() || *i==';' || *i==value_sep) throw invalid_argument("Header::parse"); if(*i=='"') @@ -59,14 +96,14 @@ void Header::parse() pvalue = string(start, i); - for(++i; (i!=raw_value.end() && *i!=';' && *i!=','); ++i) + for(++i; (i!=raw_value.end() && *i!=';' && *i!=value_sep); ++i) if(!isspace(*i)) throw invalid_argument("Header::parse"); } else { start = i; - for(; (i!=raw_value.end() && *i!=';' && *i!=','); ++i) ; + for(; (i!=raw_value.end() && *i!=';' && *i!=value_sep); ++i) ; pvalue = strip(string(start, i)); } } @@ -75,6 +112,9 @@ void Header::parse() } values.push_back(value); + + if(i!=raw_value.end() && (*i==';' || *i==',')) + ++i; } } diff --git a/source/http/header.h b/source/http/header.h index 202b844..1beb95f 100644 --- a/source/http/header.h +++ b/source/http/header.h @@ -4,14 +4,25 @@ #include #include #include +#include namespace Msp { namespace Http { class Message; -struct Header +struct MSPNET_API Header { + enum Style + { + DEFAULT, + SINGLE_VALUE, + LIST, + KEY_VALUE_LIST, + VALUE_WITH_ATTRIBUTES, + LIST_WITH_ATTRIBUTES + }; + struct Value { std::string value; @@ -19,12 +30,15 @@ struct Header }; std::string name; + Style style; std::string raw_value; std::vector values; - Header() { } - Header(const Message &, const std::string &); - Header(const std::string &, const std::string &); + Header() = default; + Header(const Message &, const std::string &, Style = DEFAULT); + Header(const std::string &, const std::string &, Style = DEFAULT); + + static Style get_default_style(const std::string &); void parse(); }; diff --git a/source/http/message.cpp b/source/http/message.cpp index b86f954..a6ab465 100644 --- a/source/http/message.cpp +++ b/source/http/message.cpp @@ -1,20 +1,14 @@ +#include "message.h" #include #include #include #include -#include "message.h" using namespace std; namespace Msp { namespace Http { -Message::Message(): - http_version(0x11), - chunk_length(0), - complete(false) -{ } - void Message::set_header(const string &hdr, const string &val) { headers[normalize_header_name(hdr)] = val; @@ -48,7 +42,7 @@ unsigned Message::parse_content(const string &d) if(complete) return 0; - HeaderMap::const_iterator i = headers.find("Content-Length"); + auto i = headers.find("Content-Length"); if(i!=headers.end()) { string::size_type needed = lexical_cast(i->second)-content.size(); @@ -120,9 +114,9 @@ string Message::str_common() const { string result; - for(HeaderMap::const_iterator i=headers.begin(); i!=headers.end(); ++i) - if(i->first[0]!='-') - result += format("%s: %s\r\n", i->first, i->second); + for(auto &kvp: headers) + if(kvp.first[0]!='-') + result += format("%s: %s\r\n", kvp.first, kvp.second); result += "\r\n"; result += content; @@ -133,17 +127,17 @@ string Message::normalize_header_name(const string &hdr) const { string result = hdr; bool upper = true; - for(string::iterator i=result.begin(); i!=result.end(); ++i) + for(char &c: result) { - if(*i=='-') + if(c=='-') upper = true; else if(upper) { - *i = toupper(*i); + c = toupper(static_cast(c)); upper = false; } else - *i = tolower(*i); + c = tolower(static_cast(c)); } return result; } diff --git a/source/http/message.h b/source/http/message.h index e525360..be2224d 100644 --- a/source/http/message.h +++ b/source/http/message.h @@ -4,26 +4,27 @@ #include #include #include +#include #include "version.h" namespace Msp { namespace Http { -class Message +class MSPNET_API Message { protected: typedef std::map HeaderMap; - Version http_version; + Version http_version = 0x11; HeaderMap headers; std::string content; - std::string::size_type chunk_length; - bool complete; + std::string::size_type chunk_length = 0; + bool complete = false; Variant user_data; - Message(); + Message() = default; public: - virtual ~Message() { } + virtual ~Message() = default; void set_header(const std::string &, const std::string &); bool has_header(const std::string &) const; diff --git a/source/http/request.cpp b/source/http/request.cpp index 54b0bee..685de33 100644 --- a/source/http/request.cpp +++ b/source/http/request.cpp @@ -1,7 +1,7 @@ +#include "request.h" #include #include #include -#include "request.h" #include "utils.h" using namespace std; @@ -28,6 +28,8 @@ string Request::str() const Request Request::parse(const string &str) { string::size_type lf = str.find('\n'); + if(lf==0) + throw invalid_argument("Request::parse"); vector parts = split(str.substr(0, lf-(str[lf-1]=='\r')), ' ', 2); if(parts.size()<3) throw invalid_argument("Request::parse"); @@ -51,11 +53,7 @@ Request Request::from_url(const string &str) string path = urlencode(url.path); if(path.empty()) path = "/"; - if(!url.query.empty()) - { - path += '?'; - path += url.query; - } + append(path, "?", url.query); Request result("GET", path); result.set_header("Host", url.host); diff --git a/source/http/request.h b/source/http/request.h index ec5030e..11af993 100644 --- a/source/http/request.h +++ b/source/http/request.h @@ -2,12 +2,13 @@ #define MSP_HTTP_REQUEST_H_ #include +#include #include "message.h" namespace Msp { namespace Http { -class Request: public Message +class MSPNET_API Request: public Message { private: std::string method; @@ -17,7 +18,7 @@ public: Request(const std::string &, const std::string &); const std::string &get_method() const { return method; } const std::string &get_path() const { return path; } - virtual std::string str() const; + std::string str() const override; static Request parse(const std::string &); static Request from_url(const std::string &); diff --git a/source/http/response.cpp b/source/http/response.cpp index 739a20f..6005d35 100644 --- a/source/http/response.cpp +++ b/source/http/response.cpp @@ -1,6 +1,6 @@ +#include "response.h" #include #include -#include "response.h" using namespace std; @@ -24,7 +24,9 @@ Response Response::parse(const string &str) Response result; string::size_type lf = str.find('\n'); - vector parts = split(str.substr(0, lf), ' ', 2); + if(lf==0) + throw invalid_argument("Response::parse"); + vector parts = split(str.substr(0, lf-(str[lf-1]=='\r')), ' ', 2); if(parts.size()<2) throw invalid_argument("Response::parse"); diff --git a/source/http/response.h b/source/http/response.h index bb471f0..8644b72 100644 --- a/source/http/response.h +++ b/source/http/response.h @@ -1,22 +1,23 @@ #ifndef MSP_HTTP_RESPONSE_H_ #define MSP_HTTP_RESPONSE_H_ +#include #include "message.h" #include "status.h" namespace Msp { namespace Http { -class Response: public Message +class MSPNET_API Response: public Message { private: Status status; - Response() { } + Response() = default; public: Response(Status); Status get_status() const { return status; } - virtual std::string str() const; + std::string str() const override; static Response parse(const std::string &); }; diff --git a/source/http/server.cpp b/source/http/server.cpp index f0474fd..04bae0b 100644 --- a/source/http/server.cpp +++ b/source/http/server.cpp @@ -1,6 +1,8 @@ +#include "server.h" #include +#include #include -#include +#include #include #include #include @@ -8,20 +10,20 @@ #include #include "request.h" #include "response.h" -#include "server.h" using namespace std; namespace Msp { namespace Http { +Server::Server(): + sock(Net::INET6) +{ } + Server::Server(unsigned port): - sock(Net::INET), - event_disp(0) + sock(Net::INET6) { - sock.signal_data_available.connect(sigc::mem_fun(this, &Server::data_available)); - RefPtr addr = Net::resolve("*", format("%d", port)); - sock.listen(*addr, 8); + listen(port); } // Avoid emitting sigc::signal destructor in files including server.h @@ -29,6 +31,13 @@ Server::~Server() { } +void Server::listen(unsigned port) +{ + unique_ptr addr(Net::resolve("*", format("%d", port), Net::INET6)); + sock.listen(*addr, 8); + sock.signal_data_available.connect(sigc::mem_fun(this, &Server::data_available)); +} + unsigned Server::get_port() const { const Net::SockAddr &addr = sock.get_local_address(); @@ -42,15 +51,15 @@ void Server::use_event_dispatcher(IO::EventDispatcher *ed) if(event_disp) { event_disp->remove(sock); - for(list::iterator i=clients.begin(); i!=clients.end(); ++i) - event_disp->remove(*i->sock); + for(Client &c: clients) + event_disp->remove(*c.sock); } event_disp = ed; if(event_disp) { event_disp->add(sock); - for(list::iterator i=clients.begin(); i!=clients.end(); ++i) - event_disp->add(*i->sock); + for(Client &c: clients) + event_disp->add(*c.sock); } } @@ -71,66 +80,98 @@ void Server::cancel_keepalive(Response &resp) get_client_by_response(resp).keepalive = false; } +void Server::close_connections(const Time::TimeDelta &timeout) +{ + IO::Poller poller; + for(Client &c: clients) + { + c.sock->shutdown(IO::M_WRITE); + poller.set_object(*c.sock, IO::P_INPUT); + } + + while(!clients.empty() && poller.poll(timeout)) + { + for(const IO::Poller::PolledObject &p: poller.get_result()) + for(auto j=clients.begin(); j!=clients.end(); ++j) + if(j->sock.get()==p.object) + { + poller.set_object(*j->sock, IO::P_NONE); + clients.erase(j); + break; + } + } +} + void Server::data_available() { - Net::StreamSocket *csock = sock.accept(); - clients.push_back(Client(csock)); - csock->signal_data_available.connect(sigc::bind(sigc::mem_fun(this, &Server::client_data_available), sigc::ref(clients.back()))); - csock->signal_end_of_file.connect(sigc::bind(sigc::mem_fun(this, &Server::client_end_of_file), sigc::ref(clients.back()))); + unique_ptr csock(sock.accept()); + clients.emplace_back(move(csock)); + Client &cl = clients.back(); + cl.sock->signal_data_available.connect(sigc::bind(sigc::mem_fun(this, &Server::client_data_available), sigc::ref(clients.back()))); + cl.sock->signal_end_of_file.connect(sigc::bind(sigc::mem_fun(this, &Server::client_end_of_file), sigc::ref(clients.back()))); if(event_disp) - event_disp->add(*csock); + event_disp->add(*cl.sock); } void Server::client_data_available(Client &cl) { - for(list::iterator i=clients.begin(); i!=clients.end(); ++i) + for(auto i=clients.begin(); i!=clients.end(); ++i) if(i->stale && &*i!=&cl) { clients.erase(i); break; } - char rbuf[4096]; - unsigned len = cl.sock->read(rbuf, sizeof(rbuf)); - if(cl.stale) + try + { + char rbuf[4096]; + unsigned len = cl.sock->read(rbuf, sizeof(rbuf)); + if(cl.stale) + return; + cl.in_buf.append(rbuf, len); + } + catch(const exception &) + { + cl.stale = true; return; - cl.in_buf.append(rbuf, len); + } - RefPtr response; + unique_ptr response; if(!cl.request) { if(cl.in_buf.find("\r\n\r\n")!=string::npos || cl.in_buf.find("\n\n")!=string::npos) { try { - cl.request = new Request(Request::parse(cl.in_buf)); + cl.request = make_unique(Request::parse(cl.in_buf)); string addr_str = cl.sock->get_peer_address().str(); - string::size_type colon = addr_str.find(':'); + string::size_type colon = addr_str.find(':', (addr_str[0]=='[' ? addr_str.find(']')+1 : 0)); cl.request->set_header("-Client-Host", addr_str.substr(0, colon)); if(cl.request->get_method()!="GET" && cl.request->get_method()!="POST") { - response = new Response(NOT_IMPLEMENTED); + response = make_unique(NOT_IMPLEMENTED); response->add_content("Method not implemented\n"); } else if(cl.request->get_path()[0]!='/') { - response = new Response(BAD_REQUEST); + response = make_unique(BAD_REQUEST); response->add_content("Path must be absolute\n"); } } catch(const exception &e) { - response = new Response(BAD_REQUEST); - response->add_content(e.what()); + response = make_unique(BAD_REQUEST); + response->add_content(format("An error occurred while parsing request headers:\ntype: %s\nwhat: %s", + Debug::demangle(typeid(e).name()), e.what())); } cl.in_buf = string(); } } else { - len = cl.request->parse_content(cl.in_buf); + unsigned len = cl.request->parse_content(cl.in_buf); cl.in_buf.erase(0, len); } @@ -140,31 +181,30 @@ void Server::client_data_available(Client &cl) if(cl.request->has_header("Connection")) cl.keepalive = !strcasecmp(cl.request->get_header("Connection"), "keep-alive"); - response = new Response(NONE); + response = make_unique(NONE); try { - cl.response = response.get(); - responses[cl.response] = &cl; - signal_request.emit(*cl.request, *response); - if(cl.async) - response.release(); - else + cl.response = move(response); + responses[cl.response.get()] = &cl; + signal_request.emit(*cl.request, *cl.response); + if(!cl.async) { - responses.erase(cl.response); - cl.response = 0; + responses.erase(cl.response.get()); + response = move(cl.response); if(response->get_status()==NONE) { - response = new Response(NOT_FOUND); + response = make_unique(NOT_FOUND); response->add_content("The requested resource was not found\n"); } } } catch(const exception &e) { - responses.erase(cl.response); - cl.response = 0; - response = new Response(INTERNAL_ERROR); - response->add_content(e.what()); + responses.erase(cl.response.get()); + cl.response.reset(); + response = make_unique(INTERNAL_ERROR); + response->add_content(format("An error occurred while processing the request:\ntype: %s\nwhat: %s", + Debug::demangle(typeid(e).name()), e.what())); } } @@ -176,14 +216,22 @@ void Server::send_response(Client &cl, Response &resp) { if(cl.keepalive) resp.set_header("Connection", "keep-alive"); - cl.sock->write(resp.str()); + + try + { + cl.sock->write(resp.str()); + } + catch(const exception &) + { + cl.stale = true; + return; + } + cl.async = false; if(cl.keepalive) { - delete cl.request; - cl.request = 0; - delete cl.response; - cl.response = 0; + cl.request.reset(); + cl.response.reset(); } else { @@ -203,20 +251,9 @@ Server::Client &Server::get_client_by_response(Response &resp) } -Server::Client::Client(RefPtr s): - sock(s), - request(0), - response(0), - keepalive(false), - async(false), - stale(false) +Server::Client::Client(unique_ptr s): + sock(move(s)) { } -Server::Client::~Client() -{ - delete request; - delete response; -} - } // namespace Http } // namespace Msp diff --git a/source/http/server.h b/source/http/server.h index 3679013..8665fa3 100644 --- a/source/http/server.h +++ b/source/http/server.h @@ -1,9 +1,10 @@ #ifndef MSP_HTTP_SERVER_H_ #define MSP_HTTP_SERVER_H_ -#include #include +#include #include +#include namespace Msp { namespace Http { @@ -11,7 +12,7 @@ namespace Http { class Request; class Response; -class Server +class MSPNET_API Server { public: sigc::signal signal_request; @@ -19,32 +20,34 @@ public: private: struct Client { - RefPtr sock; + std::unique_ptr sock; std::string in_buf; - Request *request; - Response *response; - bool keepalive; - bool async; - bool stale; - - Client(RefPtr); - ~Client(); + std::unique_ptr request; + std::unique_ptr response; + bool keepalive = false; + bool async = false; + bool stale = false; + + Client(std::unique_ptr); }; Net::StreamServerSocket sock; std::list clients; std::map responses; - IO::EventDispatcher *event_disp; + IO::EventDispatcher *event_disp = nullptr; public: + Server(); Server(unsigned); ~Server(); + void listen(unsigned); unsigned get_port() const; void use_event_dispatcher(IO::EventDispatcher *); void delay_response(Response &); void submit_response(Response &); void cancel_keepalive(Response &); + void close_connections(const Time::TimeDelta &); private: void data_available(); void client_data_available(Client &); diff --git a/source/http/status.h b/source/http/status.h index 5e7d54b..5f4656e 100644 --- a/source/http/status.h +++ b/source/http/status.h @@ -2,6 +2,7 @@ #define MSP_HTTPSERVER_STATUS_H_ #include +#include namespace Msp { namespace Http { @@ -20,7 +21,7 @@ enum Status NOT_IMPLEMENTED = 501 }; -extern std::ostream &operator<<(std::ostream &, Status); +MSPNET_API std::ostream &operator<<(std::ostream &, Status); } // namespace Http } // namespace Msp diff --git a/source/http/submessage.h b/source/http/submessage.h index f60276a..7ef9e37 100644 --- a/source/http/submessage.h +++ b/source/http/submessage.h @@ -9,10 +9,10 @@ namespace Http { class SubMessage: public Message { private: - SubMessage() { } + SubMessage() = default; public: - virtual std::string str() const; + std::string str() const override; static SubMessage parse(const std::string &); }; diff --git a/source/http/utils.cpp b/source/http/utils.cpp index 41fabef..7e1037e 100644 --- a/source/http/utils.cpp +++ b/source/http/utils.cpp @@ -1,8 +1,8 @@ +#include "utils.h" #include #include #include #include -#include "utils.h" using namespace std; @@ -31,12 +31,13 @@ namespace Http { string urlencode(const string &str, EncodeLevel level) { string result; - for(string::const_iterator i=str.begin(); i!=str.end(); ++i) + result.reserve(str.size()); + for(char c: str) { - if(is_reserved(*i, level)) - result += format("%%%02X", *i); + if(is_reserved(c, level)) + result += format("%%%02X", c); else - result += *i; + result += c; } return result; } @@ -44,14 +45,15 @@ string urlencode(const string &str, EncodeLevel level) string urlencode_plus(const string &str, EncodeLevel level) { string result; - for(string::const_iterator i=str.begin(); i!=str.end(); ++i) + result.reserve(str.size()); + for(char c: str) { - if(*i==' ') + if(c==' ') result += '+'; - else if(is_reserved(*i, level)) - result += format("%%%02X", *i); + else if(is_reserved(c, level)) + result += format("%%%02X", c); else - result += *i; + result += c; } return result; } @@ -79,7 +81,7 @@ string urldecode(const string &str) Url parse_url(const string &str) { - static Regex r_url("^(([a-z]+)://)?([a-zA-Z0-9-]+(\\.[a-zA-Z0-9-]+)*(:[0-9]+)?)?(/[^?#]*)?(\\?([^#]+))?(#(.*))?$"); + static Regex r_url("^(([a-z]+)://)?([a-zA-Z0-9-]+(\\.[a-zA-Z0-9-]+)*(:[0-9]+)?)?(/[^?#]*)?(\\?([^#]*))?(#(.*))?$"); if(RegMatch m = r_url.match(str)) { Url url; @@ -104,29 +106,20 @@ string build_url(const Url &url) str += url.scheme+"://"; str += url.host; str += urlencode(url.path); - if(!url.query.empty()) - { - str += '?'; - str += url.query; - } - if(!url.fragment.empty()) - { - str += '#'; - str += url.fragment; - } + append(str, "?", url.query); + append(str, "#", url.fragment); return str; } Query parse_query(const std::string &str) { - vector parts = split(str, '&'); Query query; - for(vector::const_iterator i=parts.begin(); i!=parts.end(); ++i) + for(const string &p: split(str, '&')) { - string::size_type equals = i->find('='); - string &value = query[urldecode(i->substr(0, equals))]; + string::size_type equals = p.find('='); + string &value = query[urldecode(p.substr(0, equals))]; if(equals!=string::npos) - value = urldecode(i->substr(equals+1)); + value = urldecode(p.substr(equals+1)); } return query; } @@ -134,13 +127,11 @@ Query parse_query(const std::string &str) string build_query(const Query &query) { string str; - for(Query::const_iterator i=query.begin(); i!=query.end(); ++i) + for(const auto &kvp: query) { - if(i!=query.begin()) - str += '&'; - str += urlencode_plus(i->first); + append(str, "&", urlencode_plus(kvp.first)); str += '='; - str += urlencode_plus(i->second); + str += urlencode_plus(kvp.second); } return str; } diff --git a/source/http/utils.h b/source/http/utils.h index a25a748..2920d2d 100644 --- a/source/http/utils.h +++ b/source/http/utils.h @@ -3,6 +3,7 @@ #include #include +#include namespace Msp { namespace Http { @@ -25,13 +26,13 @@ struct Url typedef std::map Query; -std::string urlencode(const std::string &, EncodeLevel =SAFE); -std::string urlencode_plus(const std::string &, EncodeLevel =SAFE); -std::string urldecode(const std::string &); -Url parse_url(const std::string &); -std::string build_url(const Url &); -Query parse_query(const std::string &); -std::string build_query(const Query &); +MSPNET_API std::string urlencode(const std::string &, EncodeLevel =SAFE); +MSPNET_API std::string urlencode_plus(const std::string &, EncodeLevel =SAFE); +MSPNET_API std::string urldecode(const std::string &); +MSPNET_API Url parse_url(const std::string &); +MSPNET_API std::string build_url(const Url &); +MSPNET_API Query parse_query(const std::string &); +MSPNET_API std::string build_query(const Query &); } // namespace Http } // namespace Msp diff --git a/source/http/version.cpp b/source/http/version.cpp index c9fd1c9..d6f846c 100644 --- a/source/http/version.cpp +++ b/source/http/version.cpp @@ -1,7 +1,7 @@ +#include "version.h" #include #include #include -#include "version.h" using namespace std; diff --git a/source/http/version.h b/source/http/version.h index 19a50b9..665e619 100644 --- a/source/http/version.h +++ b/source/http/version.h @@ -2,14 +2,15 @@ #define MSP_HTTP_MISC_H_ #include +#include namespace Msp { namespace Http { typedef unsigned Version; -Version parse_version(const std::string &); -std::string version_str(Version); +MSPNET_API Version parse_version(const std::string &); +MSPNET_API std::string version_str(Version); } // namespace Http } // namespace Msp diff --git a/source/net/clientsocket.cpp b/source/net/clientsocket.cpp index 04d0700..51a8dfa 100644 --- a/source/net/clientsocket.cpp +++ b/source/net/clientsocket.cpp @@ -1,21 +1,19 @@ #include "platform_api.h" -#include #include "clientsocket.h" +#include #include "socket_private.h" +using namespace std; + namespace Msp { namespace Net { ClientSocket::ClientSocket(Family af, int type, int proto): - Socket(af, type, proto), - connecting(false), - connected(false), - peer_addr(0) + Socket(af, type, proto) { } ClientSocket::ClientSocket(const Private &p, const SockAddr &paddr): Socket(p), - connecting(false), connected(true), peer_addr(paddr.copy()) { } @@ -23,8 +21,6 @@ ClientSocket::ClientSocket(const Private &p, const SockAddr &paddr): ClientSocket::~ClientSocket() { signal_flush_required.emit(); - - delete peer_addr; } void ClientSocket::shutdown(IO::Mode m) @@ -55,12 +51,12 @@ void ClientSocket::shutdown(IO::Mode m) const SockAddr &ClientSocket::get_peer_address() const { - if(peer_addr==0) + if(!peer_addr) throw bad_socket_state("not connected"); return *peer_addr; } -unsigned ClientSocket::do_write(const char *buf, unsigned size) +size_t ClientSocket::do_write(const char *buf, size_t size) { check_access(IO::M_WRITE); if(!connected) @@ -72,24 +68,28 @@ unsigned ClientSocket::do_write(const char *buf, unsigned size) return check_sys_error(::send(priv->handle, buf, size, 0), "send"); } -unsigned ClientSocket::do_read(char *buf, unsigned size) +size_t ClientSocket::do_read(char *buf, size_t size) { check_access(IO::M_READ); if(!connected) throw bad_socket_state("not connected"); + // XXX This breaks level-triggered semantics on Windows if(size==0) return 0; - unsigned ret = check_sys_error(::recv(priv->handle, buf, size, 0), "recv"); - if(ret==0 && !eof_flag) + make_signed::type ret = ::recv(priv->handle, buf, size, 0); + if(ret==0) { - eof_flag = true; - signal_end_of_file.emit(); - set_socket_events(S_NONE); + if(!eof_flag) + { + set_socket_events(S_NONE); + set_eof(); + } + return 0; } - return ret; + return check_sys_error(ret, "recv"); } } // namespace Net diff --git a/source/net/clientsocket.h b/source/net/clientsocket.h index 0b570cd..80d927a 100644 --- a/source/net/clientsocket.h +++ b/source/net/clientsocket.h @@ -1,6 +1,7 @@ #ifndef MSP_NET_CLIENTSOCKET_H_ #define MSP_NET_CLIENTSOCKET_H_ +#include "mspnet_api.h" #include "socket.h" namespace Msp { @@ -9,16 +10,16 @@ namespace Net { /** ClientSockets are used for sending and receiving data over the network. */ -class ClientSocket: public Socket +class MSPNET_API ClientSocket: public Socket { public: /** Emitted when the socket finishes connecting. */ sigc::signal signal_connect_finished; protected: - bool connecting; - bool connected; - SockAddr *peer_addr; + bool connecting = false; + bool connected = false; + std::unique_ptr peer_addr; ClientSocket(const Private &, const SockAddr &); ClientSocket(Family, int, int); @@ -42,8 +43,8 @@ public: const SockAddr &get_peer_address() const; protected: - virtual unsigned do_write(const char *, unsigned); - virtual unsigned do_read(char *, unsigned); + std::size_t do_write(const char *, std::size_t) override; + std::size_t do_read(char *, std::size_t) override; }; } // namespace Net diff --git a/source/net/communicator.cpp b/source/net/communicator.cpp index 3d6353a..73b10c2 100644 --- a/source/net/communicator.cpp +++ b/source/net/communicator.cpp @@ -1,5 +1,5 @@ -#include #include "communicator.h" +#include #include "streamsocket.h" using namespace std; @@ -8,9 +8,17 @@ namespace { using namespace Msp::Net; -struct Handshake +// Sent when a protocol is added locally but isn't known to the other end yet +struct PrepareProtocol +{ + uint64_t hash; + uint16_t base; +}; + +// Sent to confirm that a protocol is known to both peers +struct AcceptProtocol { - Msp::UInt64 hash; + uint64_t hash; }; @@ -20,75 +28,107 @@ public: HandshakeProtocol(); }; -HandshakeProtocol::HandshakeProtocol(): - Protocol(0x7F00) +HandshakeProtocol::HandshakeProtocol() { - add()(&Handshake::hash); + add(&PrepareProtocol::hash, &PrepareProtocol::base); + add(&AcceptProtocol::hash); } +} -class HandshakeReceiver: public PacketReceiver -{ -private: - Msp::UInt64 hash; - -public: - HandshakeReceiver(); - Msp::UInt64 get_hash() const { return hash; } - virtual void receive(const Handshake &); -}; -HandshakeReceiver::HandshakeReceiver(): - hash(0) -{ } +namespace Msp { +namespace Net { -void HandshakeReceiver::receive(const Handshake &shake) +struct Communicator::Handshake: public PacketReceiver, + public PacketReceiver { - hash = shake.hash; -} + Communicator &communicator; + HandshakeProtocol protocol; -} + Handshake(Communicator &c): communicator(c) { } + void receive(const PrepareProtocol &) override; + void receive(const AcceptProtocol &) override; +}; -namespace Msp { -namespace Net { -Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r): +Communicator::Communicator(StreamSocket &s): socket(s), - protocol(p), - receiver(r), - handshake_status(0), - buf_size(65536), + handshake(new Handshake(*this)), in_buf(new char[buf_size]), in_begin(in_buf), in_end(in_buf), - out_buf(new char[buf_size]), - good(true) + out_buf(new char[buf_size]) { socket.signal_data_available.connect(sigc::mem_fun(this, &Communicator::data_available)); + + protocols.emplace_back(0, ref(handshake->protocol), ref(*handshake)); + if(socket.is_connected()) + prepare_protocol(protocols.back()); + else + socket.signal_connect_finished.connect(sigc::mem_fun(this, &Communicator::connect_finished)); +} + +Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r): + Communicator(s) +{ + add_protocol(p, r); } Communicator::~Communicator() { delete[] in_buf; delete[] out_buf; + delete handshake; } -void Communicator::initiate_handshake() +void Communicator::add_protocol(const Protocol &proto, ReceiverBase &recv) { - if(handshake_status!=0) - throw sequence_error("handshaking already done"); + if(!good) + throw sequence_error("connection aborted"); + + unsigned max_id = proto.get_max_packet_id(); + if(!max_id) + throw invalid_argument("Communicator::add_protocol"); + + uint64_t hash = proto.get_hash(); + auto i = find_member(protocols, hash, &ActiveProtocol::hash); + if(i==protocols.end()) + { + const ActiveProtocol &last = protocols.back(); + if(!last.protocol) + throw sequence_error("previous protocol is incomplete"); + unsigned base = last.base; + base += (last.protocol->get_max_packet_id()+0xFF)&~0xFF; + + if(base+max_id>std::numeric_limits::max()) + throw invalid_state("Communicator::add_protocol"); + + protocols.emplace_back(base, proto, recv); + + if(socket.is_connected() && protocols.front().ready) + prepare_protocol(protocols.back()); + } + else if(!i->protocol) + { + i->protocol = &proto; + i->last = i->base+max_id; + i->receiver = &recv; + accept_protocol(*i); + } +} - send_handshake(); - handshake_status = 1; +bool Communicator::is_protocol_ready(const Protocol &proto) const +{ + auto i = find_member(protocols, &proto, &ActiveProtocol::protocol); + return (i!=protocols.end() && i->ready); } -void Communicator::send_data(unsigned size) +void Communicator::send_data(size_t size) { if(!good) throw sequence_error("connection aborted"); - if(handshake_status!=2) - throw sequence_error("handshake incomplete"); try { @@ -103,6 +143,14 @@ void Communicator::send_data(unsigned size) } } +void Communicator::connect_finished(const exception *exc) +{ + if(exc) + good = false; + else + prepare_protocol(protocols.front()); +} + void Communicator::data_available() { if(!good) @@ -111,31 +159,7 @@ void Communicator::data_available() try { in_end += socket.read(in_end, in_buf+buf_size-in_end); - - bool more = true; - while(more) - { - if(handshake_status==2) - more = receive_packet(protocol, receiver); - else - { - HandshakeProtocol hsproto; - HandshakeReceiver hsrecv; - if((more = receive_packet(hsproto, hsrecv))) - { - if(handshake_status==0) - send_handshake(); - - if(hsrecv.get_hash()==protocol.get_hash()) - { - handshake_status = 2; - signal_handshake_done.emit(); - } - else - throw incompatible_protocol("hash mismatch"); - } - } - } + while(receive_packet()) ; } catch(const exception &e) { @@ -146,37 +170,98 @@ void Communicator::data_available() } } -bool Communicator::receive_packet(const Protocol &proto, ReceiverBase &recv) +bool Communicator::receive_packet() { - int psz = proto.get_packet_size(in_begin, in_end-in_begin); - if(psz && psz<=in_end-in_begin) + Protocol::PacketHeader header; + size_t available = in_end-in_begin; + if(handshake->protocol.get_packet_header(header, in_begin, available) && header.length<=available) { + auto i = lower_bound_member(protocols, header.type, &ActiveProtocol::last); + if(i==protocols.end() || header.typebase || header.type>i->last) + throw key_error(header.type); + char *pkt = in_begin; - in_begin += psz; - proto.dispatch(recv, pkt, psz); + in_begin += header.length; + i->protocol->dispatch(*i->receiver, pkt, header.length, i->base); return true; } else { if(in_end==in_buf+buf_size) { - unsigned used = in_end-in_begin; - memmove(in_buf, in_begin, used); + memmove(in_buf, in_begin, available); in_begin = in_buf; - in_end = in_begin+used; + in_end = in_begin+available; } return false; } } -void Communicator::send_handshake() +void Communicator::prepare_protocol(const ActiveProtocol &proto) { - Handshake shake; - shake.hash = protocol.get_hash(); + PrepareProtocol prepare; + prepare.hash = proto.hash; + prepare.base = proto.base; + /* Use send_data() directly because this function is called to prepare the + handshake protocol too and send() would fail readiness check. */ + send_data(handshake->protocol.serialize(prepare, out_buf, buf_size)); +} + +void Communicator::accept_protocol(ActiveProtocol &proto) +{ + proto.accepted = true; + + AcceptProtocol accept; + accept.hash = proto.hash; + send_data(handshake->protocol.serialize(accept, out_buf, buf_size)); +} + - HandshakeProtocol hsproto; - unsigned size = hsproto.serialize(shake, out_buf, buf_size); - socket.write(out_buf, size); +Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, const Protocol &p, ReceiverBase &r): + hash(p.get_hash()), + base(b), + last(base+p.get_max_packet_id()), + protocol(&p), + receiver(&r) +{ } + +Communicator::ActiveProtocol::ActiveProtocol(uint16_t b, uint64_t h): + hash(h), + base(b), + last(base) +{ } + + +void Communicator::Handshake::receive(const PrepareProtocol &prepare) +{ + auto i = lower_bound_member(communicator.protocols, prepare.base, &ActiveProtocol::base); + if(i!=communicator.protocols.end() && i->base==prepare.base) + communicator.accept_protocol(*i); + else + communicator.protocols.emplace(i, prepare.base, prepare.hash); +} + +void Communicator::Handshake::receive(const AcceptProtocol &accept) +{ + auto i = find_member(communicator.protocols, accept.hash, &ActiveProtocol::hash); + if(i==communicator.protocols.end()) + throw key_error(accept.hash); + + + if(i->ready) + return; + + i->ready = true; + if(!i->accepted) + communicator.accept_protocol(*i); + if(i->protocol==&protocol) + { + for(const ActiveProtocol &p: communicator.protocols) + if(!p.ready) + communicator.prepare_protocol(p); + } + else + communicator.signal_protocol_ready.emit(*i->protocol); } } // namespace Net diff --git a/source/net/communicator.h b/source/net/communicator.h index 63114fd..eb9893d 100644 --- a/source/net/communicator.h +++ b/source/net/communicator.h @@ -1,7 +1,10 @@ #ifndef MSP_NET_COMMUNICATOR_H_ #define MSP_NET_COMMUNICATOR_H_ +#include +#include #include +#include "mspnet_api.h" #include "protocol.h" namespace Msp { @@ -9,61 +12,83 @@ namespace Net { class StreamSocket; -class sequence_error: public std::logic_error +class MSPNET_API sequence_error: public invalid_state { public: - sequence_error(const std::string &w): std::logic_error(w) { } - virtual ~sequence_error() throw() { } + sequence_error(const std::string &w): invalid_state(w) { } }; -class incompatible_protocol: public std::runtime_error +class MSPNET_API incompatible_protocol: public std::runtime_error { public: incompatible_protocol(const std::string &w): std::runtime_error(w) { } - virtual ~incompatible_protocol() throw() { } }; -class Communicator +class MSPNET_API Communicator: public NonCopyable { public: - sigc::signal signal_handshake_done; + sigc::signal signal_protocol_ready; sigc::signal signal_error; private: + struct ActiveProtocol + { + std::uint64_t hash = 0; + std::uint16_t base = 0; + std::uint16_t last = 0; + bool accepted = false; + bool ready = false; + const Protocol *protocol = nullptr; + ReceiverBase *receiver = nullptr; + + ActiveProtocol(std::uint16_t, const Protocol &, ReceiverBase &); + ActiveProtocol(std::uint16_t, std::uint64_t); + }; + + struct Handshake; + StreamSocket &socket; - const Protocol &protocol; - ReceiverBase &receiver; - int handshake_status; - unsigned buf_size; - char *in_buf; - char *in_begin; - char *in_end; - char *out_buf; - bool good; + std::vector protocols; + Handshake *handshake = nullptr; + std::size_t buf_size = 65536; + char *in_buf = nullptr; + char *in_begin = nullptr; + char *in_end = nullptr; + char *out_buf = nullptr; + bool good = true; public: + Communicator(StreamSocket &); Communicator(StreamSocket &, const Protocol &, ReceiverBase &); ~Communicator(); - void initiate_handshake(); - bool is_handshake_done() const { return handshake_status==2; } + void add_protocol(const Protocol &, ReceiverBase &); + bool is_protocol_ready(const Protocol &) const; template void send(const P &); private: - void send_data(unsigned); + void send_data(std::size_t); + void connect_finished(const std::exception *); void data_available(); - bool receive_packet(const Protocol &, ReceiverBase &); - void send_handshake(); + bool receive_packet(); + + void prepare_protocol(const ActiveProtocol &); + void accept_protocol(ActiveProtocol &); }; template void Communicator::send(const P &pkt) { - send_data(protocol.serialize(pkt, out_buf, buf_size)); + auto i = find_if(protocols, [](const ActiveProtocol &p){ return p.protocol && p.protocol->has_packet

(); }); + if(i==protocols.end()) + throw key_error(typeid(P).name()); + else if(!i->ready) + throw sequence_error("protocol not ready"); + send_data(i->protocol->serialize(pkt, out_buf, buf_size, i->base)); } } // namespace Net diff --git a/source/net/constants.cpp b/source/net/constants.cpp deleted file mode 100644 index 5088e51..0000000 --- a/source/net/constants.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include -#include "platform_api.h" -#include "constants.h" - -using namespace std; - -namespace Msp { -namespace Net { - -int family_to_sys(Family f) -{ - switch(f) - { - case UNSPEC: return AF_UNSPEC; - case INET: return AF_INET; - case INET6: return AF_INET6; - case UNIX: return AF_UNIX; - default: throw invalid_argument("family_to_sys"); - } -} - -Family family_from_sys(int f) -{ - switch(f) - { - case AF_UNSPEC: return UNSPEC; - case AF_INET: return INET; - case AF_INET6: return INET6; - case AF_UNIX: return UNIX; - default: throw invalid_argument("family_from_sys"); - } -} - -} // namespace Net -} // namespace Msp diff --git a/source/net/constants.h b/source/net/constants.h index 0d94661..2e6d544 100644 --- a/source/net/constants.h +++ b/source/net/constants.h @@ -1,21 +1,8 @@ #ifndef MSP_NET_CONSTANTS_H_ #define MSP_NET_CONSTANTS_H_ -namespace Msp { -namespace Net { +#warning "This header is deprected and should not be used" -enum Family -{ - UNSPEC, - INET, - INET6, - UNIX -}; - -int family_to_sys(Family); -Family family_from_sys(int); - -} // namespace Net -} // namespace Msp +#include "sockaddr.h" #endif diff --git a/source/net/datagramsocket.cpp b/source/net/datagramsocket.cpp index fe01ef5..5493554 100644 --- a/source/net/datagramsocket.cpp +++ b/source/net/datagramsocket.cpp @@ -1,8 +1,8 @@ #include "platform_api.h" +#include "datagramsocket.h" #include #include #include -#include "datagramsocket.h" #include "sockaddr_private.h" #include "socket_private.h" @@ -20,20 +20,18 @@ bool DatagramSocket::connect(const SockAddr &addr) SockAddr::SysAddr sa = addr.to_sys(); check_sys_connect_error(::connect(priv->handle, reinterpret_cast(&sa.addr), sa.size)); - delete peer_addr; - peer_addr = addr.copy(); + peer_addr.reset(addr.copy()); - delete local_addr; SockAddr::SysAddr lsa; getsockname(priv->handle, reinterpret_cast(&lsa.addr), &lsa.size); - local_addr = SockAddr::new_from_sys(lsa); + local_addr.reset(SockAddr::new_from_sys(lsa)); connected = true; return true; } -unsigned DatagramSocket::sendto(const char *buf, unsigned size, const SockAddr &addr) +size_t DatagramSocket::sendto(const char *buf, size_t size, const SockAddr &addr) { if(size==0) return 0; @@ -42,13 +40,13 @@ unsigned DatagramSocket::sendto(const char *buf, unsigned size, const SockAddr & return check_sys_error(::sendto(priv->handle, buf, size, 0, reinterpret_cast(&sa.addr), sa.size), "sendto"); } -unsigned DatagramSocket::recvfrom(char *buf, unsigned size, SockAddr *&from_addr) +size_t DatagramSocket::recvfrom(char *buf, size_t size, SockAddr *&from_addr) { if(size==0) return 0; SockAddr::SysAddr sa; - unsigned ret = check_sys_error(::recvfrom(priv->handle, buf, size, 0, reinterpret_cast(&sa.addr), &sa.size), "recvfrom"); + size_t ret = check_sys_error(::recvfrom(priv->handle, buf, size, 0, reinterpret_cast(&sa.addr), &sa.size), "recvfrom"); from_addr = SockAddr::new_from_sys(sa); return ret; diff --git a/source/net/datagramsocket.h b/source/net/datagramsocket.h index 23ca296..a978b9f 100644 --- a/source/net/datagramsocket.h +++ b/source/net/datagramsocket.h @@ -2,20 +2,21 @@ #define MSP_NET_DATAGRAMSOCKET_H_ #include "clientsocket.h" +#include "mspnet_api.h" namespace Msp { namespace Net { -class DatagramSocket: public ClientSocket +class MSPNET_API DatagramSocket: public ClientSocket { public: DatagramSocket(Family, int = 0); - virtual bool connect(const SockAddr &); - virtual bool poll_connect(const Time::TimeDelta &) { return false; } + bool connect(const SockAddr &) override; + bool poll_connect(const Time::TimeDelta &) override { return false; } - unsigned sendto(const char *, unsigned, const SockAddr &); - unsigned recvfrom(char *, unsigned, SockAddr *&); + std::size_t sendto(const char *, std::size_t, const SockAddr &); + std::size_t recvfrom(char *, std::size_t, SockAddr *&); }; } // namespace Net diff --git a/source/net/inet.cpp b/source/net/inet.cpp index 8e75cd8..17546e8 100644 --- a/source/net/inet.cpp +++ b/source/net/inet.cpp @@ -1,6 +1,6 @@ #include "platform_api.h" -#include #include "inet.h" +#include #include "sockaddr_private.h" using namespace std; @@ -8,12 +8,6 @@ using namespace std; namespace Msp { namespace Net { -InetAddr::InetAddr(): - port(0) -{ - fill(addr, addr+4, 0); -} - InetAddr::InetAddr(const SysAddr &sa) { const sockaddr_in &sai = reinterpret_cast(sa.addr); @@ -22,6 +16,22 @@ InetAddr::InetAddr(const SysAddr &sa) port = ntohs(sai.sin_port); } +InetAddr InetAddr::wildcard(unsigned port) +{ + InetAddr addr; + addr.port = port; + return addr; +} + +InetAddr InetAddr::localhost(unsigned port) +{ + InetAddr addr; + addr.addr[0] = 127; + addr.addr[3] = 1; + addr.port = port; + return addr; +} + SockAddr::SysAddr InetAddr::to_sys() const { SysAddr sa; diff --git a/source/net/inet.h b/source/net/inet.h index 639e64c..d970559 100644 --- a/source/net/inet.h +++ b/source/net/inet.h @@ -1,6 +1,7 @@ #ifndef MSP_NET_INET_H_ #define MSP_NET_INET_H_ +#include "mspnet_api.h" #include "sockaddr.h" namespace Msp { @@ -9,23 +10,26 @@ namespace Net { /** Address class for IPv4 sockets. */ -class InetAddr: public SockAddr +class MSPNET_API InetAddr: public SockAddr { private: - unsigned char addr[4]; - unsigned port; + unsigned char addr[4] = { }; + unsigned port = 0; public: - InetAddr(); + InetAddr() = default; InetAddr(const SysAddr &); - virtual InetAddr *copy() const { return new InetAddr(*this); } + static InetAddr wildcard(unsigned); + static InetAddr localhost(unsigned); - virtual SysAddr to_sys() const; + InetAddr *copy() const override { return new InetAddr(*this); } - virtual Family get_family() const { return INET; } + SysAddr to_sys() const override; + + Family get_family() const override { return INET; } unsigned get_port() const { return port; } - virtual std::string str() const; + std::string str() const override; }; } // namespace Net diff --git a/source/net/inet6.cpp b/source/net/inet6.cpp index 77d7570..1b8fc39 100644 --- a/source/net/inet6.cpp +++ b/source/net/inet6.cpp @@ -1,6 +1,7 @@ +#include "inet6.h" #include "platform_api.h" #include -#include "inet6.h" +#include #include "sockaddr_private.h" using namespace std; @@ -8,12 +9,6 @@ using namespace std; namespace Msp { namespace Net { -Inet6Addr::Inet6Addr(): - port(0) -{ - fill(addr, addr+16, 0); -} - Inet6Addr::Inet6Addr(const SysAddr &sa) { const sockaddr_in6 &sai6 = reinterpret_cast(sa.addr); @@ -21,6 +16,21 @@ Inet6Addr::Inet6Addr(const SysAddr &sa) port = htons(sai6.sin6_port); } +Inet6Addr Inet6Addr::wildcard(unsigned port) +{ + Inet6Addr addr; + addr.port = port; + return addr; +} + +Inet6Addr Inet6Addr::localhost(unsigned port) +{ + Inet6Addr addr; + addr.addr[15] = 1; + addr.port = port; + return addr; +} + SockAddr::SysAddr Inet6Addr::to_sys() const { SysAddr sa; @@ -40,9 +50,9 @@ string Inet6Addr::str() const string result = "["; for(unsigned i=0; i<16; i+=2) { - unsigned short part = (addr[i]<<8) | addr[i+1]; if(i>0) result += ':'; + unsigned short part = (addr[i]<<8) | addr[i+1]; result += format("%x", part); } result += ']'; diff --git a/source/net/inet6.h b/source/net/inet6.h index cc33e76..f6b0cf4 100644 --- a/source/net/inet6.h +++ b/source/net/inet6.h @@ -1,28 +1,32 @@ #ifndef MSP_NET_INET6_H_ -#define NSP_NET_INET6_H_ +#define MSP_NET_INET6_H_ +#include "mspnet_api.h" #include "sockaddr.h" namespace Msp { namespace Net { -class Inet6Addr: public SockAddr +class MSPNET_API Inet6Addr: public SockAddr { private: - unsigned char addr[16]; - unsigned port; + unsigned char addr[16] = { }; + unsigned port = 0; public: - Inet6Addr(); + Inet6Addr() = default; Inet6Addr(const SysAddr &); - virtual Inet6Addr *copy() const { return new Inet6Addr(*this); } + static Inet6Addr wildcard(unsigned); + static Inet6Addr localhost(unsigned); - virtual SysAddr to_sys() const; + Inet6Addr *copy() const override { return new Inet6Addr(*this); } - virtual Family get_family() const { return INET6; } + SysAddr to_sys() const override; + + Family get_family() const override { return INET6; } unsigned get_port() const { return port; } - virtual std::string str() const; + std::string str() const override; }; } // namespace Net diff --git a/source/net/mspnet_api.h b/source/net/mspnet_api.h new file mode 100644 index 0000000..eccdd41 --- /dev/null +++ b/source/net/mspnet_api.h @@ -0,0 +1,18 @@ +#ifndef MSP_NET_API_H_ +#define MSP_NET_API_H_ + +#if defined(_WIN32) +#if defined(MSPNET_BUILD) +#define MSPNET_API __declspec(dllexport) +#elif defined(MSPNET_IMPORT) +#define MSPNET_API __declspec(dllimport) +#else +#define MSPNET_API +#endif +#elif defined(__GNUC__) +#define MSPNET_API __attribute__((visibility("default"))) +#else +#define MSPNET_API +#endif + +#endif diff --git a/source/net/protocol.cpp b/source/net/protocol.cpp index b24cef1..5e82a80 100644 --- a/source/net/protocol.cpp +++ b/source/net/protocol.cpp @@ -1,41 +1,36 @@ +#include "protocol.h" #include #include -#include #include #include #include -#include "protocol.h" using namespace std; namespace Msp { namespace Net { -Protocol::Protocol(unsigned npi): - header_def(0), - next_packet_id(npi) +Protocol::Protocol(): + header_def(0) { - PacketDefBuilder >(*this, header_def, NullSerializer()) - (&PacketHeader::type)(&PacketHeader::length); + PacketDefBuilder>(*this, header_def, Serializer()) + .fields(&PacketHeader::type, &PacketHeader::length); } -Protocol::~Protocol() +unsigned Protocol::get_next_packet_class_id() { - for(map::iterator i=packet_class_defs.begin(); i!=packet_class_defs.end(); ++i) - delete i->second; + static unsigned next_id = 1; + return next_id++; } -void Protocol::add_packet(PacketDefBase *pdef) +void Protocol::add_packet(unique_ptr pdef) { - PacketDefBase *&ptr = packet_class_defs[pdef->get_class_id()]; + unique_ptr &ptr = packet_class_defs[pdef->get_class_id()]; if(ptr) - { packet_id_defs.erase(ptr->get_id()); - delete ptr; - } - ptr = pdef; - if(unsigned id = pdef->get_id()) - packet_id_defs[id] = pdef; + ptr = move(pdef); + if(unsigned id = ptr->get_id()) + packet_id_defs[id] = ptr.get(); } const Protocol::PacketDefBase &Protocol::get_packet_by_class_id(unsigned id) const @@ -48,32 +43,54 @@ const Protocol::PacketDefBase &Protocol::get_packet_by_id(unsigned id) const return *get_item(packet_id_defs, id); } -unsigned Protocol::dispatch(ReceiverBase &rcv, const char *buf, unsigned size) const +unsigned Protocol::get_max_packet_id() const +{ + if(packet_id_defs.empty()) + return 0; + return prev(packet_id_defs.end())->first; +} + +size_t Protocol::dispatch(ReceiverBase &rcv, const char *buf, size_t size, unsigned base_id) const { PacketHeader header; - buf = header_def.deserialize(header, buf, buf+size); + const char *ptr = header_def.deserialize(header, buf, buf+size); if(header.length>size) throw bad_packet("truncated"); - const PacketDefBase &pdef = get_packet_by_id(header.type); - const char *ptr = pdef.dispatch(rcv, buf, buf+header.length); + const PacketDefBase &pdef = get_packet_by_id(header.type-base_id); + if(DynamicReceiver *drcv = dynamic_cast(&rcv)) + { + Variant pkt; + ptr = pdef.deserialize(pkt, ptr, ptr+header.length); + drcv->receive(pdef.get_id(), pkt); + } + else + ptr = pdef.dispatch(rcv, ptr, ptr+header.length); return ptr-buf; } -unsigned Protocol::get_packet_size(const char *buf, unsigned size) const +bool Protocol::get_packet_header(PacketHeader &header, const char *buf, size_t size) const { if(size<4) - return 0; - PacketHeader header; + return false; header_def.deserialize(header, buf, buf+size); - return header.length; + return true; +} + +size_t Protocol::get_packet_size(const char *buf, size_t size) const +{ + PacketHeader header; + return (get_packet_header(header, buf, size) ? header.length : 0); } -UInt64 Protocol::get_hash() const +uint64_t Protocol::get_hash() const { - string description; - for(PacketMap::const_iterator i=packet_id_defs.begin(); i!=packet_id_defs.end(); ++i) - description += format("%d:%s\n", i->first, i->second->describe()); - return hash64(description); + uint64_t result = hash<64>(packet_id_defs.size()); + for(auto &kvp: packet_id_defs) + { + hash_update<64>(result, kvp.first); + hash_update<64>(result, kvp.second->get_hash()); + } + return result; } @@ -86,7 +103,7 @@ char *Protocol::BasicSerializer::serialize(const T &value, char *buf, char *e throw buffer_error("overflow"); const char *ptr = reinterpret_cast(&value)+sizeof(T); - for(unsigned i=0; i::deserialize(T &value, const char *buf, throw buffer_error("underflow"); char *ptr = reinterpret_cast(&value)+sizeof(T); - for(unsigned i=0; i::serialize(const Int8 &, char *, char *) const; -template char *Protocol::BasicSerializer::serialize(const Int16 &, char *, char *) const; -template char *Protocol::BasicSerializer::serialize(const Int32 &, char *, char *) const; -template char *Protocol::BasicSerializer::serialize(const Int64 &, char *, char *) const; -template char *Protocol::BasicSerializer::serialize(const UInt8 &, char *, char *) const; -template char *Protocol::BasicSerializer::serialize(const UInt16 &, char *, char *) const; -template char *Protocol::BasicSerializer::serialize(const UInt32 &, char *, char *) const; -template char *Protocol::BasicSerializer::serialize(const UInt64 &, char *, char *) const; +template char *Protocol::BasicSerializer::serialize(const bool &, char *, char *) const; +template char *Protocol::BasicSerializer::serialize(const int8_t &, char *, char *) const; +template char *Protocol::BasicSerializer::serialize(const int16_t &, char *, char *) const; +template char *Protocol::BasicSerializer::serialize(const int32_t &, char *, char *) const; +template char *Protocol::BasicSerializer::serialize(const int64_t &, char *, char *) const; +template char *Protocol::BasicSerializer::serialize(const uint8_t &, char *, char *) const; +template char *Protocol::BasicSerializer::serialize(const uint16_t &, char *, char *) const; +template char *Protocol::BasicSerializer::serialize(const uint32_t &, char *, char *) const; +template char *Protocol::BasicSerializer::serialize(const uint64_t &, char *, char *) const; template char *Protocol::BasicSerializer::serialize(const float &, char *, char *) const; template char *Protocol::BasicSerializer::serialize(const double &, char *, char *) const; -template const char *Protocol::BasicSerializer::deserialize(Int8 &, const char *, const char *) const; -template const char *Protocol::BasicSerializer::deserialize(Int16 &, const char *, const char *) const; -template const char *Protocol::BasicSerializer::deserialize(Int32 &, const char *, const char *) const; -template const char *Protocol::BasicSerializer::deserialize(Int64 &, const char *, const char *) const; -template const char *Protocol::BasicSerializer::deserialize(UInt8 &, const char *, const char *) const; -template const char *Protocol::BasicSerializer::deserialize(UInt16 &, const char *, const char *) const; -template const char *Protocol::BasicSerializer::deserialize(UInt32 &, const char *, const char *) const; -template const char *Protocol::BasicSerializer::deserialize(UInt64 &, const char *, const char *) const; +template const char *Protocol::BasicSerializer::deserialize(bool &, const char *, const char *) const; +template const char *Protocol::BasicSerializer::deserialize(int8_t &, const char *, const char *) const; +template const char *Protocol::BasicSerializer::deserialize(int16_t &, const char *, const char *) const; +template const char *Protocol::BasicSerializer::deserialize(int32_t &, const char *, const char *) const; +template const char *Protocol::BasicSerializer::deserialize(int64_t &, const char *, const char *) const; +template const char *Protocol::BasicSerializer::deserialize(uint8_t &, const char *, const char *) const; +template const char *Protocol::BasicSerializer::deserialize(uint16_t &, const char *, const char *) const; +template const char *Protocol::BasicSerializer::deserialize(uint32_t &, const char *, const char *) const; +template const char *Protocol::BasicSerializer::deserialize(uint64_t &, const char *, const char *) const; template const char *Protocol::BasicSerializer::deserialize(float &, const char *, const char *) const; template const char *Protocol::BasicSerializer::deserialize(double &, const char *, const char *) const; @@ -142,7 +161,7 @@ char *Protocol::StringSerializer::serialize(const string &str, char *buf, char * const char *Protocol::StringSerializer::deserialize(string &str, const char *buf, const char *end) const { - UInt16 length; + uint16_t length; buf = length_serializer.deserialize(length, buf, end); if(end-buf(length)) throw buffer_error("underflow"); @@ -151,19 +170,12 @@ const char *Protocol::StringSerializer::deserialize(string &str, const char *buf } -unsigned Protocol::PacketDefBase::next_class_id = 1; - Protocol::PacketDefBase::PacketDefBase(unsigned i): id(i) { } -Protocol::PacketHeader::PacketHeader(): - type(0), - length(0) -{ } - -Protocol::PacketHeader::PacketHeader(UInt16 t, UInt16 l): +Protocol::PacketHeader::PacketHeader(uint16_t t, uint16_t l): type(t), length(l) { } diff --git a/source/net/protocol.h b/source/net/protocol.h index 7ff63cb..5051338 100644 --- a/source/net/protocol.h +++ b/source/net/protocol.h @@ -1,87 +1,74 @@ #ifndef MSP_NET_PROTOCOL_H_ #define MSP_NET_PROTOCOL_H_ +#include #include +#include #include #include -#include +#include +#include "mspnet_api.h" #include "receiver.h" namespace Msp { namespace Net { -class bad_packet: public std::runtime_error +class MSPNET_API bad_packet: public std::runtime_error { public: bad_packet(const std::string &w): std::runtime_error(w) { } - virtual ~bad_packet() throw() { } }; -class buffer_error: public std::runtime_error +class MSPNET_API buffer_error: public std::runtime_error { public: buffer_error(const std::string &w): std::runtime_error(w) { } - virtual ~buffer_error() throw() { } }; -class Protocol +class MSPNET_API Protocol { +public: + struct PacketHeader + { + std::uint16_t type = 0; + std::uint16_t length = 0; + + PacketHeader() = default; + PacketHeader(std::uint16_t, std::uint16_t); + }; + private: - template + template struct BasicTraits; template struct Traits; template - struct CompoundTypeDef - { - virtual ~CompoundTypeDef() { } - - virtual std::string describe() const = 0; - virtual char *serialize(const C &, char *, char *) const = 0; - virtual const char *deserialize(C &, const char *, const char *) const = 0; - }; - - template - struct CompoundDef: public CompoundTypeDef - { - S serializer; - - CompoundDef(const S &); - - virtual std::string describe() const; - virtual char *serialize(const C &, char *, char *) const; - virtual const char *deserialize(C &, const char *, const char *) const; - }; + class Serializer; template class BasicSerializer { public: - typedef T ValueType; - BasicSerializer(const Protocol &) { } - std::string describe() const { return get_type_signature(); } + std::uint64_t get_hash() const { return Traits::signature; } char *serialize(const T &, char *, char *) const; const char *deserialize(T &, const char *, const char *) const; }; class StringSerializer { - public: - typedef std::string ValueType; - private: - BasicSerializer length_serializer; + BasicSerializer length_serializer; public: StringSerializer(const Protocol &); - std::string describe() const { return get_type_signature(); } + std::uint64_t get_hash() const; char *serialize(const std::string &, char *, char *) const; const char *deserialize(std::string &, const char *, const char *) const; }; @@ -89,17 +76,14 @@ private: template class ArraySerializer { - public: - typedef A ValueType; - private: - BasicSerializer length_serializer; + BasicSerializer length_serializer; typename Traits::Serializer element_serializer; public: ArraySerializer(const Protocol &); - std::string describe() const; + std::uint64_t get_hash() const; char *serialize(const A &, char *, char *) const; const char *deserialize(A &, const char *, const char *) const; }; @@ -107,57 +91,48 @@ private: template class CompoundSerializer { - public: - typedef C ValueType; - private: - const CompoundTypeDef &def; + const Serializer &serializer; public: CompoundSerializer(const Protocol &); - std::string describe() const { return def.describe(); } + std::uint64_t get_hash() const; char *serialize(const C &, char *, char *) const; const char *deserialize(C &, const char *, const char *) const; }; - template - class Serializer: public Head + template + class FieldSerializer: public Head { public: template - struct Next - { - typedef Serializer, typename Traits::Serializer> Type; - }; + using Next = FieldSerializer, N>; private: - typedef typename S::ValueType P::*Pointer; - - Pointer ptr; - S ser; + T C::*ptr; + typename Traits::Serializer ser; public: - Serializer(const Head &, Pointer, const Protocol &); + FieldSerializer(const Head &, T C::*, const Protocol &); - std::string describe() const; - char *serialize(const P &, char *, char *) const; - const char *deserialize(P &, const char *, const char *) const; + std::uint64_t get_hash() const; + char *serialize(const C &, char *, char *) const; + const char *deserialize(C &, const char *, const char *) const; }; - template - class NullSerializer + template + class Serializer { public: template - struct Next - { - typedef Serializer::Serializer> Type; - }; - - std::string describe() const { return std::string(); } - char *serialize(const P &, char *b, char *) const { return b; } - const char *deserialize(P &, const char *b, const char *) const { return b; } + using Next = FieldSerializer, N>; + + virtual ~Serializer() = default; + + virtual std::uint64_t get_hash() const { return 0; } + virtual char *serialize(const C &, char *b, char *) const { return b; } + virtual const char *deserialize(C &, const char *b, const char *) const { return b; } }; class PacketDefBase @@ -165,14 +140,14 @@ private: protected: unsigned id; - static unsigned next_class_id; - PacketDefBase(unsigned); public: - virtual ~PacketDefBase() { } + virtual ~PacketDefBase() = default; + virtual unsigned get_class_id() const = 0; unsigned get_id() const { return id; } - virtual std::string describe() const = 0; + virtual std::uint64_t get_hash() const = 0; + virtual const char *deserialize(Variant &, const char *, const char *) const = 0; virtual const char *dispatch(ReceiverBase &, const char *, const char *) const = 0; }; @@ -180,26 +155,23 @@ private: class PacketTypeDef: public PacketDefBase { private: - CompoundTypeDef

*compound; - - static unsigned class_id; + std::unique_ptr> serializer; public: PacketTypeDef(unsigned); - ~PacketTypeDef(); - static unsigned get_static_class_id() { return class_id; } - virtual unsigned get_class_id() const { return class_id; } + unsigned get_class_id() const override { return get_packet_class_id

(); } template void set_serializer(const S &); - const CompoundTypeDef

&get_compound() const { return *compound; } + const Serializer

&get_serializer() const { return *serializer; } - virtual std::string describe() const { return compound->describe(); } + std::uint64_t get_hash() const override { return serializer->get_hash(); } char *serialize(const P &, char *, char *) const; const char *deserialize(P &, const char *, const char *) const; - virtual const char *dispatch(ReceiverBase &, const char *, const char *) const; + const char *deserialize(Variant &, const char *, const char *) const override; + const char *dispatch(ReceiverBase &, const char *, const char *) const override; }; template @@ -212,41 +184,36 @@ private: public: PacketDefBuilder(const Protocol &, PacketTypeDef

&, const S &); - - template - PacketDefBuilder::Type> operator()(T P::*); - }; - struct PacketHeader - { - UInt16 type; - UInt16 length; + template + PacketDefBuilder> fields(T P::*); - PacketHeader(); - PacketHeader(UInt16, UInt16); + template + auto fields(T1 P::*first, T2 P::*second, Rest P::*...rest) { return fields(first).fields(second, rest...); } }; - typedef std::map PacketMap; - PacketTypeDef header_def; - unsigned next_packet_id; - PacketMap packet_class_defs; - PacketMap packet_id_defs; + unsigned next_packet_id = 1; + std::map> packet_class_defs; + std::map packet_id_defs; protected: - Protocol(unsigned = 1); -public: - ~Protocol(); + Protocol(); private: - void add_packet(PacketDefBase *); + static unsigned get_next_packet_class_id(); -protected: template - PacketDefBuilder > add(unsigned); + static unsigned get_packet_class_id(); + void add_packet(std::unique_ptr); + +protected: template - PacketDefBuilder > add(); + PacketDefBuilder> add(); + + template + auto add(T P::*field, Rest P::*...rest) { return add

().fields(field, rest...); } const PacketDefBase &get_packet_by_class_id(unsigned) const; const PacketDefBase &get_packet_by_id(unsigned) const; @@ -256,123 +223,103 @@ protected: public: template - unsigned serialize(const P &, char *, unsigned) const; + bool has_packet() const { return packet_class_defs.count(get_packet_class_id

()); } - unsigned get_packet_size(const char *, unsigned) const; - unsigned dispatch(ReceiverBase &, const char *, unsigned) const; + template + unsigned get_packet_id() const { return get_item(packet_class_defs, get_packet_class_id

())->get_id(); } - UInt64 get_hash() const; + unsigned get_max_packet_id() const; -private: - template - static std::string get_type_signature(); + template + std::size_t serialize(const P &, char *, std::size_t, unsigned = 0) const; + + bool get_packet_header(PacketHeader &, const char *, std::size_t) const; + std::size_t get_packet_size(const char *, std::size_t) const; + std::size_t dispatch(ReceiverBase &, const char *, std::size_t, unsigned = 0) const; + + std::uint64_t get_hash() const; }; template -Protocol::PacketDefBuilder > Protocol::add(unsigned id) +unsigned Protocol::get_packet_class_id() { - PacketTypeDef

*pdef = new PacketTypeDef

(id); - add_packet(pdef); - return PacketDefBuilder >(*this, *pdef, NullSerializer

()); + static unsigned id = get_next_packet_class_id(); + return id; } template -Protocol::PacketDefBuilder > Protocol::add() +Protocol::PacketDefBuilder> Protocol::add() { - return add

(next_packet_id++); + std::unique_ptr> pdef = std::make_unique>(next_packet_id++); + PacketDefBuilder> next(*this, *pdef, Serializer

()); + add_packet(move(pdef)); + return next; } template const Protocol::PacketTypeDef

&Protocol::get_packet_by_class() const { - const PacketDefBase &pdef = get_packet_by_class_id(PacketTypeDef

::get_static_class_id()); + const PacketDefBase &pdef = get_packet_by_class_id(get_packet_class_id

()); return static_cast &>(pdef); } template -unsigned Protocol::serialize(const P &pkt, char *buf, unsigned size) const +std::size_t Protocol::serialize(const P &pkt, char *buf, std::size_t size, unsigned base_id) const { const PacketTypeDef

&pdef = get_packet_by_class

(); + if(!pdef.get_id()) + throw std::invalid_argument("no packet id"); char *ptr = pdef.serialize(pkt, buf+4, buf+size); size = ptr-buf; - header_def.serialize(PacketHeader(pdef.get_id(), size), buf, buf+4); + header_def.serialize(PacketHeader(base_id+pdef.get_id(), size), buf, buf+4); return size; } -template -std::string Protocol::get_type_signature() -{ - const UInt16 sig = Traits::signature; - std::string result; - result += sig&0xFF; - if(sig>=0x100) - result += '0'+(sig>>8); - return result; -} - -template +template struct Protocol::BasicTraits { - static const UInt16 signature = K | (sizeof(T)<<8); + static const std::uint16_t signature = K | (sizeof(T)<<8); typedef BasicSerializer Serializer; }; template struct Protocol::Traits { - static const UInt16 signature = 'C'; + static const std::uint16_t signature = 'C'; typedef CompoundSerializer Serializer; }; -template<> struct Protocol::Traits: BasicTraits { }; -template<> struct Protocol::Traits: BasicTraits { }; -template<> struct Protocol::Traits: BasicTraits { }; -template<> struct Protocol::Traits: BasicTraits { }; -template<> struct Protocol::Traits: BasicTraits { }; -template<> struct Protocol::Traits: BasicTraits { }; -template<> struct Protocol::Traits: BasicTraits { }; -template<> struct Protocol::Traits: BasicTraits { }; +template<> struct Protocol::Traits: BasicTraits { }; +template<> struct Protocol::Traits: BasicTraits { }; +template<> struct Protocol::Traits: BasicTraits { }; +template<> struct Protocol::Traits: BasicTraits { }; +template<> struct Protocol::Traits: BasicTraits { }; +template<> struct Protocol::Traits: BasicTraits { }; +template<> struct Protocol::Traits: BasicTraits { }; +template<> struct Protocol::Traits: BasicTraits { }; +template<> struct Protocol::Traits: BasicTraits { }; template<> struct Protocol::Traits: BasicTraits { }; template<> struct Protocol::Traits: BasicTraits { }; template<> struct Protocol::Traits { - static const UInt16 signature = 'S'; + static const std::uint16_t signature = 'S'; typedef StringSerializer Serializer; }; template -struct Protocol::Traits > +struct Protocol::Traits> { - static const UInt16 signature = 'A'; - typedef ArraySerializer > Serializer; + static const std::uint16_t signature = 'A'; + typedef ArraySerializer> Serializer; }; - -template -Protocol::CompoundDef::CompoundDef(const S &s): - serializer(s) -{ } - -template -std::string Protocol::CompoundDef::describe() const +inline std::uint64_t Protocol::StringSerializer::get_hash() const { - return "{"+serializer.describe()+"}"; -} - -template -char *Protocol::CompoundDef::serialize(const C &com, char *buf, char *end) const -{ - return serializer.serialize(com, buf, end); -} - -template -const char *Protocol::CompoundDef::deserialize(C &com, const char *buf, const char *end) const -{ - return serializer.deserialize(com, buf, end); + return Traits::signature; } @@ -383,24 +330,24 @@ Protocol::ArraySerializer::ArraySerializer(const Protocol &proto): { } template -std::string Protocol::ArraySerializer::describe() const +std::uint64_t Protocol::ArraySerializer::get_hash() const { - return "["+element_serializer.describe()+"]"; + return hash_round<64>(element_serializer.get_hash(), 'A'); } template char *Protocol::ArraySerializer::serialize(const A &array, char *buf, char *end) const { buf = length_serializer.serialize(array.size(), buf, end); - for(typename A::const_iterator i=array.begin(); i!=array.end(); ++i) - buf = element_serializer.serialize(*i, buf, end); + for(const auto &e: array) + buf = element_serializer.serialize(e, buf, end); return buf; } template const char *Protocol::ArraySerializer::deserialize(A &array, const char *buf, const char *end) const { - UInt16 length; + std::uint16_t length; buf = length_serializer.deserialize(length, buf, end); array.resize(length); for(unsigned i=0; i::deserialize(A &array, const char *buf, template Protocol::CompoundSerializer::CompoundSerializer(const Protocol &proto): - def(proto.get_packet_by_class().get_compound()) + serializer(proto.get_packet_by_class().get_serializer()) { } +template +std::uint64_t Protocol::CompoundSerializer::get_hash() const +{ + return hash_round<64>(serializer.get_hash(), 'C'); +} + template char *Protocol::CompoundSerializer::serialize(const C &com, char *buf, char *end) const { - return def.serialize(com, buf, end); + return serializer.serialize(com, buf, end); } template const char *Protocol::CompoundSerializer::deserialize(C &com, const char *buf, const char *end) const { - return def.deserialize(com, buf, end); + return serializer.deserialize(com, buf, end); } -template -Protocol::Serializer::Serializer(const Head &h, Pointer p, const Protocol &proto): +template +Protocol::FieldSerializer::FieldSerializer(const Head &h, T C::*p, const Protocol &proto): Head(h), ptr(p), ser(proto) { } -template -std::string Protocol::Serializer::describe() const +template +std::uint64_t Protocol::FieldSerializer::get_hash() const { - return Head::describe()+ser.describe(); + return hash_update<64>(Head::get_hash(), ser.get_hash()); } -template -char *Protocol::Serializer::serialize(const P &pkt, char *buf, char *end) const +template +char *Protocol::FieldSerializer::serialize(const C &com, char *buf, char *end) const { - buf = Head::serialize(pkt, buf, end); - return ser.serialize(pkt.*ptr, buf, end); + buf = Head::serialize(com, buf, end); + return ser.serialize(com.*ptr, buf, end); } -template -const char *Protocol::Serializer::deserialize(P &pkt, const char *buf, const char *end) const +template +const char *Protocol::FieldSerializer::deserialize(C &com, const char *buf, const char *end) const { - buf = Head::deserialize(pkt, buf, end); - return ser.deserialize(pkt.*ptr, buf, end); + buf = Head::deserialize(com, buf, end); + return ser.deserialize(com.*ptr, buf, end); } -template -unsigned Protocol::PacketTypeDef

::class_id = 0; - template Protocol::PacketTypeDef

::PacketTypeDef(unsigned i): PacketDefBase(i), - compound(new CompoundDef >(NullSerializer

())) -{ - if(!class_id) - class_id = next_class_id++; -} - -template -Protocol::PacketTypeDef

::~PacketTypeDef() -{ - delete compound; -} + serializer(std::make_unique>()) +{ } template template void Protocol::PacketTypeDef

::set_serializer(const S &ser) { - delete compound; - compound = new CompoundDef(ser); + serializer = std::make_unique(ser); } template char *Protocol::PacketTypeDef

::serialize(const P &pkt, char *buf, char *end) const { - return compound->serialize(pkt, buf, end); + return serializer->serialize(pkt, buf, end); } template const char *Protocol::PacketTypeDef

::deserialize(P &pkt, const char *buf, const char *end) const { - return compound->deserialize(pkt, buf, end); + return serializer->deserialize(pkt, buf, end); +} + +template +const char *Protocol::PacketTypeDef

::deserialize(Variant &var_pkt, const char *buf, const char *end) const +{ + P pkt; + const char *ptr = serializer->deserialize(pkt, buf, end); + var_pkt = std::move(pkt); + return ptr; } template @@ -515,11 +464,11 @@ Protocol::PacketDefBuilder::PacketDefBuilder(const Protocol &p, PacketType template template -Protocol::PacketDefBuilder::Type> Protocol::PacketDefBuilder::operator()(T P::*ptr) +Protocol::PacketDefBuilder> Protocol::PacketDefBuilder::fields(T P::*ptr) { - typename S::template Next::Type next_ser(serializer, ptr, protocol); + typename S::template Next next_ser(serializer, ptr, protocol); pktdef.set_serializer(next_ser); - return PacketDefBuilder::Type>(protocol, pktdef, next_ser); + return PacketDefBuilder>(protocol, pktdef, next_ser); } } // namespace Net diff --git a/source/net/protocol_impl.h b/source/net/protocol_impl.h new file mode 100644 index 0000000..4eb2861 --- /dev/null +++ b/source/net/protocol_impl.h @@ -0,0 +1,6 @@ +#ifndef MSP_NET_PROTOCOL_IMPL_H_ +#define MSP_NET_PROTOCOL_IMPL_H_ + +#warning "This header is deprected and should not be used" + +#endif diff --git a/source/net/receiver.cpp b/source/net/receiver.cpp new file mode 100644 index 0000000..46c12f3 --- /dev/null +++ b/source/net/receiver.cpp @@ -0,0 +1,16 @@ +#include "receiver.h" + +namespace Msp { +namespace Net { + +void DynamicDispatcher::receive(unsigned packet_id, const Variant &packet) +{ + auto i = lower_bound_member(targets, packet_id, &Target::packet_id); + if(i==targets.end() || i->packet_id!=packet_id) + throw key_error(packet_id); + + i->func(*i->receiver, packet); +} + +} // namespace Net +} // namespace Msp diff --git a/source/net/receiver.h b/source/net/receiver.h index 19e69e0..a0a9f29 100644 --- a/source/net/receiver.h +++ b/source/net/receiver.h @@ -1,26 +1,81 @@ #ifndef MSP_NET_RECEIVER_H_ #define MSP_NET_RECEIVER_H_ +#include +#include +#include +#include +#include "mspnet_api.h" + namespace Msp { namespace Net { -class ReceiverBase +class MSPNET_API ReceiverBase { protected: - ReceiverBase() { } + ReceiverBase() = default; public: - virtual ~ReceiverBase() { } + virtual ~ReceiverBase() = default; }; + template class PacketReceiver: public virtual ReceiverBase { protected: - PacketReceiver() { } + PacketReceiver() = default; public: virtual void receive(const P &) = 0; }; + +class MSPNET_API DynamicReceiver: public ReceiverBase +{ +protected: + DynamicReceiver() = default; +public: + virtual void receive(unsigned, const Variant &) = 0; +}; + + +class MSPNET_API DynamicDispatcher: public DynamicReceiver +{ +private: + using DispatchFunc = void(ReceiverBase &, const Variant &); + + struct Target + { + unsigned packet_id; + ReceiverBase *receiver; + DispatchFunc *func; + + Target(unsigned i, ReceiverBase &r, DispatchFunc *f): packet_id(i), receiver(&r), func(f) { } + }; + + std::vector targets; + +public: + template + void add_receiver(unsigned, PacketReceiver

&); + + void receive(unsigned, const Variant &) override; +}; + + +template +void DynamicDispatcher::add_receiver(unsigned packet_id, PacketReceiver

&r) +{ + auto i = lower_bound_member(targets, packet_id, &Target::packet_id); + if(i!=targets.end() && i->packet_id==packet_id) + throw key_error(packet_id); + + auto dispatch = [](ReceiverBase &receiver, const Variant &packet){ + dynamic_cast &>(receiver).receive(packet.value

()); + }; + + targets.emplace(i, packet_id, r, +dispatch); +} + } // namespace Net } // namespace Msp diff --git a/source/net/resolve.cpp b/source/net/resolve.cpp index 10e8c65..9ec5395 100644 --- a/source/net/resolve.cpp +++ b/source/net/resolve.cpp @@ -1,9 +1,9 @@ -#include "platform_api.h" +#include "resolve.h" #include #include +#include "platform_api.h" #include "sockaddr_private.h" #include "socket.h" -#include "resolve.h" using namespace std; @@ -40,16 +40,16 @@ namespace Net { SockAddr *resolve(const string &host, const string &serv, Family family) { - const char *chost = (host.empty() ? 0 : host.c_str()); - const char *cserv = (serv.empty() ? 0 : serv.c_str()); + const char *chost = (host.empty() ? nullptr : host.c_str()); + const char *cserv = (serv.empty() ? nullptr : serv.c_str()); int flags = 0; if(host=="*") { flags = AI_PASSIVE; - chost = 0; + chost = nullptr; } - addrinfo hints = { flags, family_to_sys(family), 0, 0, 0, 0, 0, 0 }; + addrinfo hints = { flags, family_to_sys(family), 0, 0, 0, nullptr, nullptr, nullptr }; addrinfo *res; int err = getaddrinfo(chost, cserv, &hints, &res); @@ -81,9 +81,7 @@ SockAddr *resolve(const string &str, Family family) } -Resolver::Resolver(): - event_disp(0), - next_tag(1) +Resolver::Resolver() { thread.get_notify_pipe().signal_data_available.connect(sigc::mem_fun(this, &Resolver::task_done)); } @@ -104,7 +102,7 @@ unsigned Resolver::resolve(const string &host, const string &serv, Family family task.host = host; task.serv = serv; task.family = family; - thread.add_task(task); + thread.add_task(move(task)); return task.tag; } @@ -132,24 +130,23 @@ void Resolver::task_done() if(task->addr) signal_address_resolved.emit(task->tag, *task->addr); else if(task->error) + { + if(signal_resolve_failed.empty()) + { + unique_ptr err = move(task->error); + thread.pop_complete_task(); + throw *err; + } signal_resolve_failed.emit(task->tag, *task->error); + } thread.pop_complete_task(); } } -Resolver::Task::Task(): - tag(0), - family(UNSPEC), - addr(0), - error(0) -{ } - - Resolver::WorkerThread::WorkerThread(): Thread("Resolver"), - sem(1), - done(false) + sem(1) { launch(); } @@ -161,11 +158,11 @@ Resolver::WorkerThread::~WorkerThread() join(); } -void Resolver::WorkerThread::add_task(const Task &t) +void Resolver::WorkerThread::add_task(Task &&t) { MutexLock lock(queue_mutex); bool was_starved = (queue.empty() || queue.back().is_complete()); - queue.push_back(t); + queue.push_back(move(t)); if(was_starved) sem.signal(); } @@ -176,18 +173,14 @@ Resolver::Task *Resolver::WorkerThread::get_complete_task() if(!queue.empty() && queue.front().is_complete()) return &queue.front(); else - return 0; + return nullptr; } void Resolver::WorkerThread::pop_complete_task() { MutexLock lock(queue_mutex); if(!queue.empty() && queue.front().is_complete()) - { - delete queue.front().addr; - delete queue.front().error; queue.pop_front(); - } } void Resolver::WorkerThread::main() @@ -199,10 +192,10 @@ void Resolver::WorkerThread::main() sem.wait(); wait = false; - Task *task = 0; + Task *task = nullptr; { MutexLock lock(queue_mutex); - for(list::iterator i=queue.begin(); (!task && i!=queue.end()); ++i) + for(auto i=queue.begin(); (!task && i!=queue.end()); ++i) if(!i->is_complete()) task = &*i; } @@ -211,16 +204,16 @@ void Resolver::WorkerThread::main() { try { - SockAddr *addr = Net::resolve(task->host, task->serv, task->family); + unique_ptr addr(Net::resolve(task->host, task->serv, task->family)); { MutexLock lock(queue_mutex); - task->addr = addr; + task->addr = move(addr); } } catch(const runtime_error &e) { MutexLock lock(queue_mutex); - task->error = new runtime_error(e); + task->error = make_unique(e); } notify_pipe.put(1); } diff --git a/source/net/resolve.h b/source/net/resolve.h index 21737ff..3071b77 100644 --- a/source/net/resolve.h +++ b/source/net/resolve.h @@ -1,48 +1,47 @@ #ifndef MSP_NET_RESOLVE_H_ #define MSP_NET_RESOLVE_H_ +#include +#include #include #include #include #include #include #include -#include "constants.h" +#include "mspnet_api.h" +#include "sockaddr.h" namespace Msp { namespace Net { -class SockAddr; - /** Resolves host and service names into a socket address. If host is empty, the loopback address will be used. If host is "*", the wildcard address will be used. If service is empty, a socket address with a null service will be returned. With the IP families, these are not very useful. */ -SockAddr *resolve(const std::string &, const std::string &, Family = UNSPEC); +MSPNET_API SockAddr *resolve(const std::string &, const std::string &, Family = UNSPEC); /** And overload of resolve() that takes host and service as a single string, separated by a colon. If the host part contains colons, such as is the case with a numeric IPv6 address, it must be enclosed in brackets. */ -SockAddr *resolve(const std::string &, Family = UNSPEC); +MSPNET_API SockAddr *resolve(const std::string &, Family = UNSPEC); /** An asynchronous name resolver. Blocking calls are performed in a thread and completion is notified with one of the two signals. */ -class Resolver +class MSPNET_API Resolver { private: struct Task { - unsigned tag; + unsigned tag = 0; std::string host; std::string serv; - Family family; - SockAddr *addr; - std::runtime_error *error; - - Task(); + Family family = UNSPEC; + std::unique_ptr addr; + std::unique_ptr error; bool is_complete() const { return addr || error; } }; @@ -50,24 +49,24 @@ private: class WorkerThread: public Thread { private: - std::list queue; + std::deque queue; Mutex queue_mutex; Semaphore sem; IO::Pipe notify_pipe; - bool done; + bool done = false; public: WorkerThread(); ~WorkerThread(); - void add_task(const Task &); + void add_task(Task &&); Task *get_complete_task(); void pop_complete_task(); IO::Pipe &get_notify_pipe() { return notify_pipe; } private: - virtual void main(); + void main() override; }; public: @@ -75,9 +74,9 @@ public: sigc::signal signal_resolve_failed; private: - IO::EventDispatcher *event_disp; + IO::EventDispatcher *event_disp = nullptr; WorkerThread thread; - unsigned next_tag; + unsigned next_tag = 1; public: Resolver(); diff --git a/source/net/serversocket.cpp b/source/net/serversocket.cpp index 6028483..66876c5 100644 --- a/source/net/serversocket.cpp +++ b/source/net/serversocket.cpp @@ -1,4 +1,5 @@ #include "serversocket.h" +#include using namespace std; @@ -9,14 +10,14 @@ ServerSocket::ServerSocket(Family af, int type, int proto): Socket(af, type, proto) { } -unsigned ServerSocket::do_write(const char *, unsigned) +size_t ServerSocket::do_write(const char *, size_t) { - throw logic_error("can't write to ServerSocket"); + throw unsupported("ServerSocket::write"); } -unsigned ServerSocket::do_read(char *, unsigned) +size_t ServerSocket::do_read(char *, size_t) { - throw logic_error("can't read from ServerSocket"); + throw unsupported("ServerSocket::read"); } } // namespace Net diff --git a/source/net/serversocket.h b/source/net/serversocket.h index 10375f0..8c082b3 100644 --- a/source/net/serversocket.h +++ b/source/net/serversocket.h @@ -1,6 +1,7 @@ #ifndef MSP_NET_SERVERSOCKET_H_ #define MSP_NET_SERVERSOCKET_H_ +#include "mspnet_api.h" #include "socket.h" namespace Msp { @@ -12,7 +13,7 @@ class ClientSocket; ServerSockets are used to receive incoming connections. They cannot be used for sending and receiving data. */ -class ServerSocket: public Socket +class MSPNET_API ServerSocket: public Socket { protected: ServerSocket(Family, int, int); @@ -22,8 +23,8 @@ public: virtual ClientSocket *accept() = 0; protected: - virtual unsigned do_write(const char *, unsigned); - virtual unsigned do_read(char *, unsigned); + std::size_t do_write(const char *, std::size_t) override; + std::size_t do_read(char *, std::size_t) override; }; } // namespace Net diff --git a/source/net/sockaddr.cpp b/source/net/sockaddr.cpp index db436fe..a3d3eed 100644 --- a/source/net/sockaddr.cpp +++ b/source/net/sockaddr.cpp @@ -1,3 +1,4 @@ +#include "sockaddr.h" #include #include "platform_api.h" #include "inet.h" @@ -25,11 +26,35 @@ SockAddr *SockAddr::new_from_sys(const SysAddr &sa) } } -SockAddr::SysAddr::SysAddr(): - size(sizeof(sockaddr_storage)) +SockAddr::SysAddr::SysAddr() { addr.ss_family = AF_UNSPEC; } + +int family_to_sys(Family f) +{ + switch(f) + { + case UNSPEC: return AF_UNSPEC; + case INET: return AF_INET; + case INET6: return AF_INET6; + case UNIX: return AF_UNIX; + default: throw invalid_argument("family_to_sys"); + } +} + +Family family_from_sys(int f) +{ + switch(f) + { + case AF_UNSPEC: return UNSPEC; + case AF_INET: return INET; + case AF_INET6: return INET6; + case AF_UNIX: return UNIX; + default: throw invalid_argument("family_from_sys"); + } +} + } // namespace Net } // namespace Msp diff --git a/source/net/sockaddr.h b/source/net/sockaddr.h index aad5e29..931d4f4 100644 --- a/source/net/sockaddr.h +++ b/source/net/sockaddr.h @@ -2,20 +2,33 @@ #define MSP_NET_SOCKADDR_H_ #include -#include "constants.h" +#include "mspnet_api.h" namespace Msp { namespace Net { -class SockAddr +enum Family +{ + UNSPEC, + INET, + INET6, + UNIX +}; + + +class MSPNET_API SockAddr { public: struct SysAddr; protected: - SockAddr() { } + SockAddr() = default; + SockAddr(const SockAddr &) = default; + SockAddr(SockAddr &&) = default; + SockAddr &operator=(const SockAddr &) = default; + SockAddr &operator=(SockAddr &&) = default; public: - virtual ~SockAddr() { } + virtual ~SockAddr() = default; virtual SockAddr *copy() const = 0; diff --git a/source/net/sockaddr_private.h b/source/net/sockaddr_private.h index 909ef92..2a3d14b 100644 --- a/source/net/sockaddr_private.h +++ b/source/net/sockaddr_private.h @@ -14,11 +14,15 @@ namespace Net { struct SockAddr::SysAddr { struct sockaddr_storage addr; - socklen_t size; + socklen_t size = sizeof(sockaddr_storage); SysAddr(); }; + +int family_to_sys(Family); +Family family_from_sys(int); + } // namespace Net } // namespace Msp diff --git a/source/net/socket.cpp b/source/net/socket.cpp index e8e5eaf..2a1a71c 100644 --- a/source/net/socket.cpp +++ b/source/net/socket.cpp @@ -1,16 +1,17 @@ #include "platform_api.h" +#include "socket.h" #include #include #include "sockaddr_private.h" -#include "socket.h" #include "socket_private.h" +using namespace std; + namespace Msp { namespace Net { Socket::Socket(const Private &p): - priv(new Private), - local_addr(0) + priv(make_unique()) { mode = IO::M_RDWR; @@ -18,18 +19,23 @@ Socket::Socket(const Private &p): SockAddr::SysAddr sa; getsockname(priv->handle, reinterpret_cast(&sa.addr), &sa.size); - local_addr = SockAddr::new_from_sys(sa); + local_addr.reset(SockAddr::new_from_sys(sa)); platform_init(); } Socket::Socket(Family af, int type, int proto): - priv(new Private), - local_addr(0) + priv(make_unique()) { mode = IO::M_RDWR; +#ifdef __linux__ + type |= SOCK_CLOEXEC; +#endif priv->handle = socket(family_to_sys(af), type, proto); +#ifndef __linux__ + set_inherit(false); +#endif platform_init(); } @@ -37,20 +43,26 @@ Socket::Socket(Family af, int type, int proto): Socket::~Socket() { platform_cleanup(); - - delete local_addr; - delete priv; } void Socket::set_block(bool b) { - mode = (mode&~IO::M_NONBLOCK); - if(b) - mode = (mode|IO::M_NONBLOCK); - + IO::adjust_mode(mode, IO::M_NONBLOCK, !b); priv->set_block(b); } +void Socket::set_inherit(bool i) +{ + IO::adjust_mode(mode, IO::M_INHERIT, i); + priv->set_inherit(i); +} + +const IO::Handle &Socket::get_handle(IO::Mode) +{ + // TODO could this be implemented somehow? + throw unsupported("Socket::get_handle"); +} + const IO::Handle &Socket::get_event_handle() { return priv->event; @@ -64,13 +76,12 @@ void Socket::bind(const SockAddr &addr) if(err==-1) throw system_error("bind"); - delete local_addr; - local_addr = addr.copy(); + local_addr.reset(addr.copy()); } const SockAddr &Socket::get_local_address() const { - if(local_addr==0) + if(!local_addr) throw bad_socket_state("not bound"); return *local_addr; } diff --git a/source/net/socket.h b/source/net/socket.h index 1f60f9b..b89c175 100644 --- a/source/net/socket.h +++ b/source/net/socket.h @@ -1,23 +1,24 @@ #ifndef MSP_NET_SOCKET_H_ #define MSP_NET_SOCKET_H_ +#include +#include #include #include -#include "constants.h" +#include "mspnet_api.h" #include "sockaddr.h" namespace Msp { namespace Net { -class bad_socket_state: public std::logic_error +class MSPNET_API bad_socket_state: public invalid_state { public: - bad_socket_state(const std::string &w): std::logic_error(w) { } - virtual ~bad_socket_state() throw() { } + bad_socket_state(const std::string &w): invalid_state(w) { } }; -class Socket: public IO::EventObject +class MSPNET_API Socket: public IO::EventObject { protected: enum SocketEvent @@ -30,8 +31,8 @@ protected: struct Private; - Private *priv; - SockAddr *local_addr; + std::unique_ptr priv; + std::unique_ptr local_addr; Socket(const Private &); Socket(Family, int, int); @@ -41,13 +42,16 @@ private: public: ~Socket(); - virtual void set_block(bool); - virtual const IO::Handle &get_event_handle(); + void set_block(bool) override; + void set_inherit(bool) override; + const IO::Handle &get_handle(IO::Mode) override; + const IO::Handle &get_event_handle() override; /** Associates the socket with a local address. There must be no existing users of the address. */ void bind(const SockAddr &); + bool is_bound() const { return static_cast(local_addr); } const SockAddr &get_local_address() const; void set_timeout(const Time::TimeDelta &); diff --git a/source/net/socket_private.h b/source/net/socket_private.h index 742877b..bb9eff6 100644 --- a/source/net/socket_private.h +++ b/source/net/socket_private.h @@ -22,11 +22,12 @@ struct Socket::Private IO::Handle event; void set_block(bool); + void set_inherit(bool); int set_option(int, int, const void *, socklen_t); int get_option(int, int, void *, socklen_t *); }; -unsigned check_sys_error(int, const char *); +std::size_t check_sys_error(std::make_signed::type, const char *); bool check_sys_connect_error(int); } // namespace Net diff --git a/source/net/streamserversocket.cpp b/source/net/streamserversocket.cpp index 817145b..8e655fc 100644 --- a/source/net/streamserversocket.cpp +++ b/source/net/streamserversocket.cpp @@ -1,11 +1,11 @@ #include "platform_api.h" +#include "streamserversocket.h" #include #include #include #include #include "sockaddr_private.h" #include "socket_private.h" -#include "streamserversocket.h" #include "streamsocket.h" using namespace std; @@ -14,12 +14,14 @@ namespace Msp { namespace Net { StreamServerSocket::StreamServerSocket(Family af, int proto): - ServerSocket(af, SOCK_STREAM, proto), - listening(false) + ServerSocket(af, SOCK_STREAM, proto) { } void StreamServerSocket::listen(const SockAddr &addr, unsigned backlog) { + if(listening) + throw bad_socket_state("already listening"); + bind(addr); int err = ::listen(priv->handle, backlog); diff --git a/source/net/streamserversocket.h b/source/net/streamserversocket.h index aa4868c..bbe1c91 100644 --- a/source/net/streamserversocket.h +++ b/source/net/streamserversocket.h @@ -1,22 +1,24 @@ #ifndef MSP_NET_STREAMSERVERSOCKET_H_ #define MSP_NET_STREAMSERVERSOCKET_H_ +#include "mspnet_api.h" #include "serversocket.h" #include "streamsocket.h" namespace Msp { namespace Net { -class StreamServerSocket: public ServerSocket +class MSPNET_API StreamServerSocket: public ServerSocket { private: - bool listening; + bool listening = false; public: StreamServerSocket(Family, int = 0); - virtual void listen(const SockAddr &, unsigned = 4); - virtual StreamSocket *accept(); + void listen(const SockAddr &, unsigned = 4) override; + bool is_listening() const { return listening; } + StreamSocket *accept() override; }; } // namespace Net diff --git a/source/net/streamsocket.cpp b/source/net/streamsocket.cpp index 11887d3..9e2d79a 100644 --- a/source/net/streamsocket.cpp +++ b/source/net/streamsocket.cpp @@ -1,11 +1,11 @@ #include "platform_api.h" +#include "streamsocket.h" #include #include #include #include #include "sockaddr_private.h" #include "socket_private.h" -#include "streamsocket.h" namespace Msp { namespace Net { @@ -34,13 +34,11 @@ bool StreamSocket::connect(const SockAddr &addr) set_socket_events(S_CONNECT); } - delete peer_addr; - peer_addr = addr.copy(); + peer_addr.reset(addr.copy()); - delete local_addr; SockAddr::SysAddr lsa; getsockname(priv->handle, reinterpret_cast(&lsa.addr), &lsa.size); - local_addr = SockAddr::new_from_sys(lsa); + local_addr.reset(SockAddr::new_from_sys(lsa)); if(finished) { @@ -99,10 +97,7 @@ void StreamSocket::on_event(IO::PollEvent ev) signal_connect_finished.emit(0); if(err!=0) - { - delete peer_addr; - peer_addr = 0; - } + peer_addr.reset(); set_socket_events((err==0) ? S_INPUT : S_NONE); } diff --git a/source/net/streamsocket.h b/source/net/streamsocket.h index 8b39e91..84b347d 100644 --- a/source/net/streamsocket.h +++ b/source/net/streamsocket.h @@ -2,11 +2,12 @@ #define MSP_NET_STREAMSOCKET_H_ #include "clientsocket.h" +#include "mspnet_api.h" namespace Msp { namespace Net { -class StreamSocket: public ClientSocket +class MSPNET_API StreamSocket: public ClientSocket { friend class StreamServerSocket; @@ -23,12 +24,12 @@ public: If the socket is non-blocking, this function may return before the connection is fully established. The caller must then use either the poll_connect function or an EventDispatcher to finish the process. */ - virtual bool connect(const SockAddr &); + bool connect(const SockAddr &) override; - virtual bool poll_connect(const Time::TimeDelta &); + bool poll_connect(const Time::TimeDelta &) override; private: - void on_event(IO::PollEvent); + void on_event(IO::PollEvent) override; }; } // namespace Net diff --git a/source/net/unix.cpp b/source/net/unix.cpp index fe00020..1da822b 100644 --- a/source/net/unix.cpp +++ b/source/net/unix.cpp @@ -5,11 +5,6 @@ using namespace std; namespace Msp { namespace Net { -UnixAddr::UnixAddr(): - abstract(false) -{ -} - string UnixAddr::str() const { string result = "unix:"; diff --git a/source/net/unix.h b/source/net/unix.h index 821e915..86f1ddd 100644 --- a/source/net/unix.h +++ b/source/net/unix.h @@ -1,28 +1,29 @@ #ifndef MSP_NET_UNIX_H_ #define MSP_NET_UNIX_H_ +#include "mspnet_api.h" #include "sockaddr.h" namespace Msp { namespace Net { -class UnixAddr: public SockAddr +class MSPNET_API UnixAddr: public SockAddr { private: std::string path; - bool abstract; + bool abstract = false; public: - UnixAddr(); + UnixAddr() = default; UnixAddr(const SysAddr &); UnixAddr(const std::string &, bool = false); - virtual UnixAddr *copy() const { return new UnixAddr(*this); } + UnixAddr *copy() const override { return new UnixAddr(*this); } - virtual SysAddr to_sys() const; + SysAddr to_sys() const override; - virtual Family get_family() const { return UNIX; } - virtual std::string str() const; + Family get_family() const override { return UNIX; } + std::string str() const override; }; } // namespace Net diff --git a/source/net/unix/socket.cpp b/source/net/unix/socket.cpp index 02ef7c3..2c1e2ad 100644 --- a/source/net/unix/socket.cpp +++ b/source/net/unix/socket.cpp @@ -1,20 +1,23 @@ +#include "platform_api.h" +#include "socket.h" #include #include #include -#include "platform_api.h" #include #include #include #include "sockaddr_private.h" -#include "socket.h" #include "socket_private.h" +using namespace std; + namespace Msp { namespace Net { void Socket::platform_init() { *priv->event = priv->handle; + set_inherit(false); } void Socket::platform_cleanup() @@ -37,7 +40,13 @@ void Socket::set_platform_events(unsigned) void Socket::Private::set_block(bool b) { int flags = fcntl(handle, F_GETFL); - fcntl(handle, F_SETFL, (flags&O_NONBLOCK)|(b?0:O_NONBLOCK)); + fcntl(handle, F_SETFL, (flags&~O_NONBLOCK)|(b?0:O_NONBLOCK)); +} + +void Socket::Private::set_inherit(bool i) +{ + int flags = fcntl(handle, F_GETFD); + fcntl(handle, F_SETFD, (flags&~O_CLOEXEC)|(i?0:O_CLOEXEC)); } int Socket::Private::set_option(int level, int optname, const void *optval, socklen_t optlen) @@ -51,11 +60,11 @@ int Socket::Private::get_option(int level, int optname, void *optval, socklen_t } -unsigned check_sys_error(int ret, const char *func) +size_t check_sys_error(make_signed::type ret, const char *func) { if(ret<0) { - if(errno==EAGAIN) + if(errno==EAGAIN || errno==EWOULDBLOCK) return 0; else throw system_error(func); diff --git a/source/net/unix/unix.cpp b/source/net/unix/unix.cpp index 6bb3436..648d53a 100644 --- a/source/net/unix/unix.cpp +++ b/source/net/unix/unix.cpp @@ -1,16 +1,15 @@ +#include "platform_api.h" +#include "unix.h" #include #include -#include "platform_api.h" #include "sockaddr_private.h" -#include "unix.h" using namespace std; namespace Msp { namespace Net { -UnixAddr::UnixAddr(const SysAddr &sa): - abstract(false) +UnixAddr::UnixAddr(const SysAddr &sa) { const sockaddr_un &sau = reinterpret_cast(sa.addr); if(static_cast(sa.size)>sizeof(sa_family_t)) @@ -25,7 +24,7 @@ UnixAddr::UnixAddr(const string &p, bool a): abstract(a) { if(sizeof(sa_family_t)+path.size()+1>sizeof(sockaddr_storage)) - throw invalid_argument("UnixAddr"); + throw invalid_argument("UnixAddr::UnixAddr"); } SockAddr::SysAddr UnixAddr::to_sys() const diff --git a/source/net/windows/socket.cpp b/source/net/windows/socket.cpp index f4fab44..eadc22a 100644 --- a/source/net/windows/socket.cpp +++ b/source/net/windows/socket.cpp @@ -1,11 +1,13 @@ -#include #include "platform_api.h" +#include "socket.h" +#include #include #include #include "sockaddr_private.h" -#include "socket.h" #include "socket_private.h" +using namespace std; + namespace { class WinSockHelper @@ -25,7 +27,7 @@ public: } }; -WinSockHelper wsh; +unique_ptr wsh; } @@ -35,6 +37,8 @@ namespace Net { void Socket::platform_init() { + if(!wsh) + wsh = make_unique(); *priv->event = CreateEvent(0, false, false, 0); } @@ -70,6 +74,10 @@ void Socket::Private::set_block(bool b) ioctlsocket(handle, FIONBIO, &flag); } +void Socket::Private::set_inherit(bool) +{ +} + int Socket::Private::set_option(int level, int optname, const void *optval, socklen_t optlen) { return setsockopt(handle, level, optname, reinterpret_cast(optval), optlen); @@ -81,7 +89,7 @@ int Socket::Private::get_option(int level, int optname, void *optval, socklen_t } -unsigned check_sys_error(int ret, const char *func) +size_t check_sys_error(make_signed::type ret, const char *func) { if(ret<0) { diff --git a/source/net/windows/unix.cpp b/source/net/windows/unix.cpp index 2aa8857..b38a258 100644 --- a/source/net/windows/unix.cpp +++ b/source/net/windows/unix.cpp @@ -1,17 +1,16 @@ -#include #include "platform_api.h" -#include "sockaddr_private.h" #include "unix.h" +#include +#include "sockaddr_private.h" using namespace std; namespace Msp { namespace Net { -UnixAddr::UnixAddr(const SysAddr &): - abstract(false) +UnixAddr::UnixAddr(const SysAddr &) { - throw logic_error("AF_UNIX not supported"); + throw unsupported("AF_UNIX"); } UnixAddr::UnixAddr(const string &p, bool a): @@ -22,7 +21,7 @@ UnixAddr::UnixAddr(const string &p, bool a): SockAddr::SysAddr UnixAddr::to_sys() const { - throw logic_error("AF_UNIX not supported"); + throw unsupported("AF_UNIX"); } } // namespace Net diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 0000000..ee4c926 --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +/test diff --git a/tests/Build b/tests/Build new file mode 100644 index 0000000..971a59d --- /dev/null +++ b/tests/Build @@ -0,0 +1,11 @@ +package "mspnet-tests" +{ + require "mspcore"; + require "mspnet"; + require "msptest"; + + program "test" + { + source "."; + }; +}; diff --git a/tests/http_header.cpp b/tests/http_header.cpp new file mode 100644 index 0000000..e123714 --- /dev/null +++ b/tests/http_header.cpp @@ -0,0 +1,99 @@ +#include +#include + +using namespace std; +using namespace Msp; + +class HttpHeaderTests: public Test::RegisteredTest +{ +public: + HttpHeaderTests(); + + static const char *get_name() { return "Http::Header"; } + +private: + void single_value(); + void list_of_values(); + void key_value_list(); + void value_with_attributes(); + void list_with_attributes(); + void quoted_values(); +}; + +HttpHeaderTests::HttpHeaderTests() +{ + add(&HttpHeaderTests::single_value, "Single value"); + add(&HttpHeaderTests::list_of_values, "List"); + add(&HttpHeaderTests::key_value_list, "Key-value list"); + add(&HttpHeaderTests::value_with_attributes, "Value with attributes"); + add(&HttpHeaderTests::list_with_attributes, "List with attributes"); + add(&HttpHeaderTests::quoted_values, "Quoted values"); +} + +void HttpHeaderTests::single_value() +{ + Http::Header header("test", "1;2 , 3 ", Http::Header::SINGLE_VALUE); + EXPECT_EQUAL(header.values.size(), 1); + EXPECT_EQUAL(header.values.front().value, "1;2 , 3"); + EXPECT(header.values.front().parameters.empty()); +} + +void HttpHeaderTests::list_of_values() +{ + Http::Header header("test", "1;2 , 3 ", Http::Header::LIST); + EXPECT_EQUAL(header.values.size(), 3); + EXPECT_EQUAL(header.values[0].value, "1"); + EXPECT_EQUAL(header.values[1].value, "2"); + EXPECT_EQUAL(header.values[2].value, "3"); +} + +void HttpHeaderTests::key_value_list() +{ + Http::Header header("test", "a=1; b = 2 ,c=3 ", Http::Header::KEY_VALUE_LIST); + EXPECT_EQUAL(header.values.size(), 1); + EXPECT_EQUAL(header.values.front().parameters.count("a"), 1); + EXPECT_EQUAL(header.values.front().parameters["a"], "1"); + EXPECT_EQUAL(header.values.front().parameters.count("b"), 1); + EXPECT_EQUAL(header.values.front().parameters["b"], "2"); + EXPECT_EQUAL(header.values.front().parameters.count("c"), 1); + EXPECT_EQUAL(header.values.front().parameters["c"], "3"); +} + +void HttpHeaderTests::value_with_attributes() +{ + Http::Header header("test", "X;a=1, 2 ; b=3 ", Http::Header::VALUE_WITH_ATTRIBUTES); + EXPECT_EQUAL(header.values.size(), 1); + EXPECT_EQUAL(header.values.front().value, "X"); + EXPECT_EQUAL(header.values.front().parameters.count("a"), 1); + EXPECT_EQUAL(header.values.front().parameters["a"], "1, 2"); + EXPECT_EQUAL(header.values.front().parameters.count("b"), 1); + EXPECT_EQUAL(header.values.front().parameters["b"], "3"); +} + +void HttpHeaderTests::list_with_attributes() +{ + Http::Header header("test", "X;a= 1;b= 2 ,Y ; c=3 ", Http::Header::LIST_WITH_ATTRIBUTES); + EXPECT_EQUAL(header.values.size(), 2); + EXPECT_EQUAL(header.values[0].value, "X"); + EXPECT_EQUAL(header.values[0].parameters.count("a"), 1); + EXPECT_EQUAL(header.values[0].parameters["a"], "1"); + EXPECT_EQUAL(header.values[0].parameters.count("b"), 1); + EXPECT_EQUAL(header.values[0].parameters["b"], "2"); + EXPECT_EQUAL(header.values[1].value, "Y"); + EXPECT_EQUAL(header.values[1].parameters.count("c"), 1); + EXPECT_EQUAL(header.values[1].parameters["c"], "3"); +} + +void HttpHeaderTests::quoted_values() +{ + Http::Header header("test", "X;a= 1;b=\" ;2, \",Y ; c=\"3 \" ", Http::Header::LIST_WITH_ATTRIBUTES); + EXPECT_EQUAL(header.values.size(), 2); + EXPECT_EQUAL(header.values[0].value, "X"); + EXPECT_EQUAL(header.values[0].parameters.count("a"), 1); + EXPECT_EQUAL(header.values[0].parameters["a"], "1"); + EXPECT_EQUAL(header.values[0].parameters.count("b"), 1); + EXPECT_EQUAL(header.values[0].parameters["b"], " ;2, "); + EXPECT_EQUAL(header.values[1].value, "Y"); + EXPECT_EQUAL(header.values[1].parameters.count("c"), 1); + EXPECT_EQUAL(header.values[1].parameters["c"], "3 "); +} diff --git a/tests/protocol.cpp b/tests/protocol.cpp new file mode 100644 index 0000000..bca1161 --- /dev/null +++ b/tests/protocol.cpp @@ -0,0 +1,179 @@ +#include +#include + +using namespace std; +using namespace Msp; + +class ProtocolTests: public Test::RegisteredTest +{ +public: + ProtocolTests(); + + static const char *get_name() { return "Protocol"; } + +private: + void hash_match(); + void buffer_overflow(); + void truncated_packet(); + void stub_header(); + + template + void transmit(const P &, P &, size_t); + + void transmit_int(); + void transmit_string(); + void transmit_array(); + void transmit_composite(); +}; + + +class Protocol: public Msp::Net::Protocol +{ +public: + Protocol(); +}; + +struct Packet1 +{ + uint32_t value; +}; + +struct Packet2 +{ + std::string value; +}; + +struct Packet3 +{ + std::vector values; +}; + +struct Packet4 +{ + Packet2 sub1; + std::vector sub2; +}; + +Protocol::Protocol() +{ + add(&Packet1::value); + add(&Packet2::value); + add(&Packet3::values); + add(&Packet4::sub1, &Packet4::sub2); +} + +template +class Receiver: public Net::PacketReceiver +{ +private: + T &storage; + +public: + Receiver(T &s): storage(s) { } + + void receive(const T &p) override { storage = p; } +}; + + +ProtocolTests::ProtocolTests() +{ + add(&ProtocolTests::hash_match, "Hash match"); + add(&ProtocolTests::buffer_overflow, "Serialization buffer overflow").expect_throw(); + add(&ProtocolTests::truncated_packet, "Truncated packet").expect_throw(); + add(&ProtocolTests::stub_header, "Stub header"); + add(&ProtocolTests::transmit_int, "Integer transmission"); + add(&ProtocolTests::transmit_string, "String transmission"); + add(&ProtocolTests::transmit_array, "Array transmission"); + add(&ProtocolTests::transmit_composite, "Composite transmission"); +} + +void ProtocolTests::hash_match() +{ + Protocol proto1; + Protocol proto2; + EXPECT_EQUAL(proto1.get_hash(), proto2.get_hash()); +} + +void ProtocolTests::buffer_overflow() +{ + Protocol proto; + Packet1 pkt = { 42 }; + char buffer[7]; + proto.serialize(pkt, buffer, sizeof(buffer)); +} + +void ProtocolTests::truncated_packet() +{ + Protocol proto; + Packet1 pkt = { 42 }; + char buffer[16]; + size_t len = proto.serialize(pkt, buffer, sizeof(buffer)); + Receiver recv(pkt); + proto.dispatch(recv, buffer, len-1); +} + +void ProtocolTests::stub_header() +{ + Protocol proto; + char buffer[3] = { 4, 0, 1 }; + size_t len = proto.get_packet_size(buffer, sizeof(buffer)); + EXPECT_EQUAL(len, 0); +} + +template +void ProtocolTests::transmit(const P &pkt, P &rpkt, size_t expected_length) +{ + Protocol proto; + char buffer[128]; + size_t len = proto.serialize(pkt, buffer, sizeof(buffer)); + EXPECT_EQUAL(len, expected_length); + + size_t rlen = proto.get_packet_size(buffer, sizeof(buffer)); + EXPECT_EQUAL(rlen, len); + + Receiver

recv(rpkt); + size_t dlen = proto.dispatch(recv, buffer, sizeof(buffer)); + EXPECT_EQUAL(dlen, len); +} + +void ProtocolTests::transmit_int() +{ + Packet1 pkt = { 42 }; + Packet1 rpkt; + transmit(pkt, rpkt, 8); + EXPECT_EQUAL(rpkt.value, 42); +} + +void ProtocolTests::transmit_string() +{ + Packet2 pkt = { "Hello" }; + Packet2 rpkt; + transmit(pkt, rpkt, 11); + EXPECT_EQUAL(rpkt.value, "Hello"); +} + +void ProtocolTests::transmit_array() +{ + Packet3 pkt = {{ 2, 3, 5, 7, 11 }}; + Packet3 rpkt; + transmit(pkt, rpkt, 26); + EXPECT_EQUAL(rpkt.values.size(), 5); + for(size_t i=0; i