From: Mikko Rasa Date: Wed, 10 Aug 2011 18:14:42 +0000 (+0300) Subject: Merge branch 'http-master' X-Git-Url: http://git.tdb.fi/?p=libs%2Fnet.git;a=commitdiff_plain;h=debe1004676d5431e571d9c4361072661dcc88c4;hp=cf8d2e48581eeb8f1b83e8c48321a0bc2ffa6d83 Merge branch 'http-master' Conflicts: .gitignore Build --- diff --git a/.gitignore b/.gitignore index e6bccfe..bee94e5 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,10 @@ /.options.* /.profile /debug -/libmsphttp.a -/libmsphttp.so -/msphttp.pc +/libmspnet.a +/libmspnet.so +/mspnet.pc +/netcat +/pc-32-windows +/release /temp diff --git a/Build b/Build index f602417..1ed9db8 100644 --- a/Build +++ b/Build @@ -1,15 +1,39 @@ -package "msphttp" +package "mspnet" { - description "HTTP client and server library"; - version "0.1"; + require "mspcore"; + if "arch=win32" + { + build_info + { + library "ws2_32"; + }; + }; + + headers "msp/net" + { + source "source/net"; + install true; + }; - require "mspnet"; - require "mspstrings"; + headers "msp/http" + { + source "source/http"; + install true; + }; - library "msphttp" + library "mspnet" { - source "source"; + source "source/net"; + source "source/http"; install true; - install_headers "msp/http"; + }; + + program "netcat" + { + source "examples/netcat.cpp"; + build_info + { + library "mspnet"; + }; }; }; diff --git a/examples/netcat.cpp b/examples/netcat.cpp new file mode 100644 index 0000000..9bbad87 --- /dev/null +++ b/examples/netcat.cpp @@ -0,0 +1,98 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace Msp; + +class NetCat: public RegisteredApplication +{ +private: + bool ipv6; + bool listen; + Net::StreamServerSocket *server_sock; + Net::StreamSocket *sock; + IO::EventDispatcher event_disp; + +public: + NetCat(int, char **); + +private: + virtual void tick(); + void net_data_available(); + void console_data_available(); +}; + +NetCat::NetCat(int argc, char **argv): + ipv6(false), + listen(false), + server_sock(0), + sock(0) +{ + GetOpt getopt; + getopt.add_option('6', "ipv6", ipv6, GetOpt::NO_ARG); + getopt.add_option('l', "listen", listen, GetOpt::NO_ARG); + getopt(argc, argv); + + const vector &args = getopt.get_args(); + if(args.empty()) + throw usage_error("host argument missing"); + + RefPtr addr = Net::resolve(args.front(), (ipv6 ? Net::INET6 : Net::INET)); + if(!listen) + { + sock = new Net::StreamSocket(addr->get_family()); + sock->connect(*addr); + event_disp.add(*sock); + sock->signal_data_available.connect(sigc::mem_fun(this, &NetCat::net_data_available)); + } + else + { + server_sock = new Net::StreamServerSocket(addr->get_family()); + server_sock->listen(*addr); + event_disp.add(*server_sock); + server_sock->signal_data_available.connect(sigc::mem_fun(this, &NetCat::net_data_available)); + } + + event_disp.add(IO::cin); + IO::cin.signal_data_available.connect(sigc::mem_fun(this, &NetCat::console_data_available)); +} + +void NetCat::tick() +{ + event_disp.tick(); + if(server_sock && sock) + { + delete server_sock; + server_sock = 0; + } +} + +void NetCat::net_data_available() +{ + if(server_sock) + { + sock = server_sock->accept(); + event_disp.add(*sock); + sock->signal_data_available.connect(sigc::mem_fun(this, &NetCat::net_data_available)); + } + else + { + char buf[1024]; + unsigned len = sock->read(buf, sizeof(buf)); + IO::cout.write(buf, len); + } +} + +void NetCat::console_data_available() +{ + char buf[1024]; + unsigned len = IO::cin.read(buf, sizeof(buf)); + if(sock) + sock->write(buf, len); +} diff --git a/source/net/clientsocket.cpp b/source/net/clientsocket.cpp new file mode 100644 index 0000000..80303de --- /dev/null +++ b/source/net/clientsocket.cpp @@ -0,0 +1,89 @@ +#ifdef WIN32 +#include +#else +#include +#include +#endif +#include +#include "clientsocket.h" +#include "socket_private.h" + +namespace Msp { +namespace Net { + +ClientSocket::ClientSocket(Family af, int type, int proto): + Socket(af, type, proto), + connecting(false), + connected(false), + peer_addr(0) +{ } + +ClientSocket::ClientSocket(const Private &p, const SockAddr &paddr): + Socket(p), + connecting(false), + connected(true), + peer_addr(paddr.copy()) +{ } + +ClientSocket::~ClientSocket() +{ + signal_flush_required.emit(); + + delete peer_addr; +} + +const SockAddr &ClientSocket::get_peer_address() const +{ + if(peer_addr==0) + throw bad_socket_state("not connected"); + return *peer_addr; +} + +unsigned ClientSocket::do_write(const char *buf, unsigned size) +{ + if(!connected) + throw bad_socket_state("not connected"); + + if(size==0) + return 0; + + int ret = ::send(priv->handle, buf, size, 0); + if(ret<0) + { + if(errno==EAGAIN) + return 0; + else + throw system_error("send"); + } + + return ret; +} + +unsigned ClientSocket::do_read(char *buf, unsigned size) +{ + if(!connected) + throw bad_socket_state("not connected"); + + if(size==0) + return 0; + + int ret = ::recv(priv->handle, buf, size, 0); + if(ret<0) + { + if(errno==EAGAIN) + return 0; + else + throw system_error("recv"); + } + else if(ret==0 && !eof_flag) + { + eof_flag = true; + signal_end_of_file.emit(); + set_events(IO::P_NONE); + } + + return ret; +} + +} // namespace Net +} // namespace Msp diff --git a/source/net/clientsocket.h b/source/net/clientsocket.h new file mode 100644 index 0000000..db684f4 --- /dev/null +++ b/source/net/clientsocket.h @@ -0,0 +1,50 @@ +#ifndef MSP_NET_CLIENTSOCKET_H_ +#define MSP_NET_CLIENTSOCKET_H_ + +#include "socket.h" + +namespace Msp { +namespace Net { + +/** +ClientSockets are used for sending and receiving data over the network. +*/ +class ClientSocket: public Socket +{ +public: + /** Emitted when the socket finishes connecting. */ + sigc::signal signal_connect_finished; + +protected: + bool connecting; + bool connected; + SockAddr *peer_addr; + + ClientSocket(const Private &, const SockAddr &); + ClientSocket(Family, int, int); +public: + virtual ~ClientSocket(); + + /** Connects to a remote address. Exact semantics depend on the socket + type. Returns true if the connection was established, false if it's in + progress. */ + virtual bool connect(const SockAddr &) = 0; + + /** Checks the status of a connection being established. Returns true if + the connection was established successfully, false if it's still in + progress. If an error occurred, an exception is thrown. */ + virtual bool poll_connect(const Time::TimeDelta &) = 0; + + bool is_connecting() const { return connecting; } + bool is_connected() const { return connected; } + + const SockAddr &get_peer_address() const; +protected: + virtual unsigned do_write(const char *, unsigned); + virtual unsigned do_read(char *, unsigned); +}; + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/communicator.cpp b/source/net/communicator.cpp new file mode 100644 index 0000000..c9b277d --- /dev/null +++ b/source/net/communicator.cpp @@ -0,0 +1,158 @@ +#include +#include "communicator.h" + +namespace { + +using namespace Msp::Net; + +struct Handshake +{ + unsigned hash; +}; + + +class HandshakeProtocol: public Protocol +{ +public: + HandshakeProtocol(); +}; + +HandshakeProtocol::HandshakeProtocol(): + Protocol(0x7F00) +{ + add()(&Handshake::hash); +} + + +class HandshakeReceiver: public PacketReceiver +{ +private: + unsigned hash; + +public: + HandshakeReceiver(); + unsigned get_hash() const { return hash; } + virtual void receive(const Handshake &); +}; + +HandshakeReceiver::HandshakeReceiver(): + hash(0) +{ } + +void HandshakeReceiver::receive(const Handshake &shake) +{ + hash = shake.hash; +} + +} + + +namespace Msp { +namespace Net { + +Communicator::Communicator(StreamSocket &s, const Protocol &p, ReceiverBase &r): + socket(s), + protocol(p), + receiver(r), + handshake_status(0), + buf_size(1024), + in_buf(new char[buf_size]), + in_begin(in_buf), + in_end(in_buf), + out_buf(new char[buf_size]), + good(true) +{ + socket.signal_data_available.connect(sigc::mem_fun(this, &Communicator::data_available)); +} + +Communicator::~Communicator() +{ + delete[] in_buf; + delete[] out_buf; +} + +void Communicator::initiate_handshake() +{ + if(handshake_status!=0) + throw sequence_error("handshaking already done"); + + send_handshake(); + handshake_status = 1; +} + +void Communicator::data_available() +{ + if(!good) + return; + + in_end += socket.read(in_end, in_buf+buf_size-in_end); + try + { + bool more = true; + while(more) + { + if(handshake_status==2) + { + more = receive_packet(protocol, receiver); + } + else + { + HandshakeProtocol hsproto; + HandshakeReceiver hsrecv; + if((more = receive_packet(hsproto, hsrecv))) + { + if(hsrecv.get_hash()==protocol.get_hash()) + { + if(handshake_status==0) + send_handshake(); + handshake_status = 2; + signal_handshake_done.emit(); + } + else + good = false; + } + } + } + } + catch(...) + { + good = false; + throw; + } +} + +bool Communicator::receive_packet(const Protocol &proto, ReceiverBase &recv) +{ + int psz = proto.get_packet_size(in_begin, in_end-in_begin); + if(psz && psz<=in_end-in_begin) + { + char *pkt = in_begin; + in_begin += psz; + proto.disassemble(recv, pkt, psz); + return true; + } + else + { + if(in_end==in_buf+buf_size) + { + unsigned used = in_end-in_begin; + memmove(in_buf, in_begin, used); + in_begin = in_buf; + in_end = in_begin+used; + } + return false; + } +} + +void Communicator::send_handshake() +{ + Handshake shake; + shake.hash = protocol.get_hash(); + + HandshakeProtocol hsproto; + unsigned size = hsproto.assemble(shake, out_buf, buf_size); + socket.write(out_buf, size); +} + +} // namespace Net +} // namespace Msp diff --git a/source/net/communicator.h b/source/net/communicator.h new file mode 100644 index 0000000..8530db1 --- /dev/null +++ b/source/net/communicator.h @@ -0,0 +1,62 @@ +#ifndef MSP_NET_COMMUNICATOR_H_ +#define MSP_NET_COMMUNICATOR_H_ + +#include "protocol.h" +#include "streamsocket.h" + +namespace Msp { +namespace Net { + +class sequence_error: public std::logic_error +{ +public: + sequence_error(const std::string &w): std::logic_error(w) { } + virtual ~sequence_error() throw() { } +}; + + +class Communicator +{ +public: + sigc::signal signal_handshake_done; + +private: + StreamSocket &socket; + const Protocol &protocol; + ReceiverBase &receiver; + int handshake_status; + unsigned buf_size; + char *in_buf; + char *in_begin; + char *in_end; + char *out_buf; + bool good; + +public: + Communicator(StreamSocket &, const Protocol &, ReceiverBase &); + ~Communicator(); + + void initiate_handshake(); + bool is_handshake_done() const { return handshake_status==2; } + + template + void send(const P &pkt) + { + if(!good) + throw sequence_error("connection aborted"); + if(handshake_status!=2) + throw sequence_error("handshaking not done"); + unsigned size = protocol.assemble(pkt, out_buf, buf_size); + socket.write(out_buf, size); + } + +private: + void data_available(); + bool receive_packet(const Protocol &, ReceiverBase &); + void send_handshake(); +}; + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/constants.cpp b/source/net/constants.cpp new file mode 100644 index 0000000..e57897d --- /dev/null +++ b/source/net/constants.cpp @@ -0,0 +1,39 @@ +#include +#ifdef WIN32 +#include +#else +#include +#endif +#include "constants.h" + +using namespace std; + +namespace Msp { +namespace Net { + +int family_to_sys(Family f) +{ + switch(f) + { + case UNSPEC: return AF_UNSPEC; + case INET: return AF_INET; + case INET6: return AF_INET6; + case UNIX: return AF_UNIX; + default: throw invalid_argument("family_to_sys"); + } +} + +Family family_from_sys(int f) +{ + switch(f) + { + case AF_UNSPEC: return UNSPEC; + case AF_INET: return INET; + case AF_INET6: return INET6; + case AF_UNIX: return UNIX; + default: throw invalid_argument("family_from_sys"); + } +} + +} // namespace Net +} // namespace Msp diff --git a/source/net/constants.h b/source/net/constants.h new file mode 100644 index 0000000..0d94661 --- /dev/null +++ b/source/net/constants.h @@ -0,0 +1,21 @@ +#ifndef MSP_NET_CONSTANTS_H_ +#define MSP_NET_CONSTANTS_H_ + +namespace Msp { +namespace Net { + +enum Family +{ + UNSPEC, + INET, + INET6, + UNIX +}; + +int family_to_sys(Family); +Family family_from_sys(int); + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/datagramsocket.cpp b/source/net/datagramsocket.cpp new file mode 100644 index 0000000..d24ca4e --- /dev/null +++ b/source/net/datagramsocket.cpp @@ -0,0 +1,100 @@ +#include +#include +#include +#include +#include "datagramsocket.h" +#include "sockaddr_private.h" +#include "socket_private.h" + +namespace Msp { +namespace Net { + +DatagramSocket::DatagramSocket(Family f, int p): + ClientSocket(f, SOCK_DGRAM, p) +{ +#ifdef WIN32 + WSAEventSelect(priv->handle, *priv->event, FD_READ|FD_CLOSE); +#endif + set_events(IO::P_INPUT); +} + +bool DatagramSocket::connect(const SockAddr &addr) +{ + SockAddr::SysAddr sa = addr.to_sys(); + + int err = ::connect(priv->handle, reinterpret_cast(&sa.addr), sa.size); + if(err==-1) + { +#ifdef WIN32 + throw system_error("connect", WSAGetLastError()); +#else + throw system_error("connect"); +#endif + } + + delete peer_addr; + peer_addr = addr.copy(); + + delete local_addr; + SockAddr::SysAddr lsa; + getsockname(priv->handle, reinterpret_cast(&lsa.addr), &lsa.size); + local_addr = SockAddr::new_from_sys(lsa); + + connected = true; + + return true; +} + +unsigned DatagramSocket::sendto(const char *buf, unsigned size, const SockAddr &addr) +{ + if(size==0) + return 0; + + SockAddr::SysAddr sa = addr.to_sys(); + + int ret = ::sendto(priv->handle, buf, size, 0, reinterpret_cast(&sa.addr), sa.size); + if(ret<0) + { + if(errno==EAGAIN) + return 0; + else + { +#ifdef WIN32 + throw system_error("sendto", WSAGetLastError()); +#else + throw system_error("sendto"); +#endif + } + } + + return ret; +} + +unsigned DatagramSocket::recvfrom(char *buf, unsigned size, SockAddr *&from_addr) +{ + if(size==0) + return 0; + + SockAddr::SysAddr sa; + int ret = ::recvfrom(priv->handle, buf, size, 0, reinterpret_cast(&sa.addr), &sa.size); + if(ret<0) + { + if(errno==EAGAIN) + return 0; + else + { +#ifdef WIN32 + throw system_error("recvfrom", WSAGetLastError()); +#else + throw system_error("recvfrom"); +#endif + } + } + + from_addr = SockAddr::new_from_sys(sa); + + return ret; +} + +} // namespace Net +} // namespace Msp diff --git a/source/net/datagramsocket.h b/source/net/datagramsocket.h new file mode 100644 index 0000000..23ca296 --- /dev/null +++ b/source/net/datagramsocket.h @@ -0,0 +1,24 @@ +#ifndef MSP_NET_DATAGRAMSOCKET_H_ +#define MSP_NET_DATAGRAMSOCKET_H_ + +#include "clientsocket.h" + +namespace Msp { +namespace Net { + +class DatagramSocket: public ClientSocket +{ +public: + DatagramSocket(Family, int = 0); + + virtual bool connect(const SockAddr &); + virtual bool poll_connect(const Time::TimeDelta &) { return false; } + + unsigned sendto(const char *, unsigned, const SockAddr &); + unsigned recvfrom(char *, unsigned, SockAddr *&); +}; + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/inet.cpp b/source/net/inet.cpp new file mode 100644 index 0000000..29e24aa --- /dev/null +++ b/source/net/inet.cpp @@ -0,0 +1,52 @@ +#ifdef WIN32 +#include +#else +#include +#endif +#include +#include "inet.h" +#include "sockaddr_private.h" + +using namespace std; + +namespace Msp { +namespace Net { + +InetAddr::InetAddr(): + port(0) +{ + fill(addr, addr+4, 0); +} + +InetAddr::InetAddr(const SysAddr &sa) +{ + const sockaddr_in &sai = reinterpret_cast(sa.addr); + addr[0] = sai.sin_addr.s_addr>>24; + addr[1] = sai.sin_addr.s_addr>>16; + addr[2] = sai.sin_addr.s_addr>>8; + addr[3] = sai.sin_addr.s_addr; + port = ntohs(sai.sin_port); +} + +SockAddr::SysAddr InetAddr::to_sys() const +{ + SysAddr sa; + sa.size = sizeof(sockaddr_in); + sockaddr_in &sai = reinterpret_cast(sa.addr); + sai.sin_family = AF_INET; + sai.sin_addr.s_addr = (addr[0]<<24) | (addr[1]<<16) | (addr[2]<<8) | (addr[3]); + sai.sin_port = htons(port); + + return sa; +} + +string InetAddr::str() const +{ + string result = format("%d.%d.%d.%d", addr[0], addr[1], addr[2], addr[3]); + if(port) + result += format(":%d", port); + return result; +} + +} // namespace Net +} // namespace Msp diff --git a/source/net/inet.h b/source/net/inet.h new file mode 100644 index 0000000..639e64c --- /dev/null +++ b/source/net/inet.h @@ -0,0 +1,34 @@ +#ifndef MSP_NET_INET_H_ +#define MSP_NET_INET_H_ + +#include "sockaddr.h" + +namespace Msp { +namespace Net { + +/** +Address class for IPv4 sockets. +*/ +class InetAddr: public SockAddr +{ +private: + unsigned char addr[4]; + unsigned port; + +public: + InetAddr(); + InetAddr(const SysAddr &); + + virtual InetAddr *copy() const { return new InetAddr(*this); } + + virtual SysAddr to_sys() const; + + virtual Family get_family() const { return INET; } + unsigned get_port() const { return port; } + virtual std::string str() const; +}; + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/inet6.cpp b/source/net/inet6.cpp new file mode 100644 index 0000000..79da1d2 --- /dev/null +++ b/source/net/inet6.cpp @@ -0,0 +1,61 @@ +#ifdef WIN32 +#include +#include +#else +#include +#endif +#include +#include "inet6.h" +#include "sockaddr_private.h" + +using namespace std; + +namespace Msp { +namespace Net { + +Inet6Addr::Inet6Addr(): + port(0) +{ + fill(addr, addr+16, 0); +} + +Inet6Addr::Inet6Addr(const SysAddr &sa) +{ + const sockaddr_in6 &sai6 = reinterpret_cast(sa.addr); + std::copy(sai6.sin6_addr.s6_addr, sai6.sin6_addr.s6_addr+16, addr); + port = htons(sai6.sin6_port); +} + +SockAddr::SysAddr Inet6Addr::to_sys() const +{ + SysAddr sa; + sa.size = sizeof(sockaddr_in6); + sockaddr_in6 &sai6 = reinterpret_cast(sa.addr); + sai6.sin6_family = AF_INET6; + std::copy(addr, addr+16, sai6.sin6_addr.s6_addr); + sai6.sin6_port = htons(port); + sai6.sin6_flowinfo = 0; + sai6.sin6_scope_id = 0; + + return sa; +} + +string Inet6Addr::str() const +{ + string result = "["; + for(unsigned i=0; i<16; i+=2) + { + unsigned short part = (addr[i]<<8) | addr[i+1]; + if(i>0) + result += ':'; + result += format("%x", part); + } + result += ']'; + if(port) + result += format(":%d", port); + + return result; +} + +} // namespace Net +} // namespace Msp diff --git a/source/net/inet6.h b/source/net/inet6.h new file mode 100644 index 0000000..cc33e76 --- /dev/null +++ b/source/net/inet6.h @@ -0,0 +1,31 @@ +#ifndef MSP_NET_INET6_H_ +#define NSP_NET_INET6_H_ + +#include "sockaddr.h" + +namespace Msp { +namespace Net { + +class Inet6Addr: public SockAddr +{ +private: + unsigned char addr[16]; + unsigned port; + +public: + Inet6Addr(); + Inet6Addr(const SysAddr &); + + virtual Inet6Addr *copy() const { return new Inet6Addr(*this); } + + virtual SysAddr to_sys() const; + + virtual Family get_family() const { return INET6; } + unsigned get_port() const { return port; } + virtual std::string str() const; +}; + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/protocol.cpp b/source/net/protocol.cpp new file mode 100644 index 0000000..d56ac77 --- /dev/null +++ b/source/net/protocol.cpp @@ -0,0 +1,232 @@ +#include +#include +#include +#include +#include "protocol.h" + +using namespace std; + +namespace { + +using Msp::Net::buffer_error; + +template +class Assembler +{ +public: + static char *assemble(const T &v, char *, char *); + static const char *disassemble(T &, const char *, const char *); +}; + +template +class Assembler > +{ +public: + static char *assemble(const vector &v, char *, char *); + static const char *disassemble(vector &, const char *, const char *); +}; + +template +char *Assembler::assemble(const T &v, char *data, char *end) +{ + // XXX Assumes little-endian + const char *ptr = reinterpret_cast(&v)+sizeof(T); + for(unsigned i=0; i +char *Assembler::assemble(const string &v, char *data, char *end) +{ + data = Assembler::assemble(v.size(), data, end); + if(end-data(v.size())) + throw buffer_error("overflow"); + memcpy(data, v.data(), v.size()); + return data+v.size(); +} + +template +char *Assembler >::assemble(const vector &v, char *data, char *end) +{ + data = Assembler::assemble(v.size(), data, end); + for(typename vector::const_iterator i=v.begin(); i!=v.end(); ++i) + data = Assembler::assemble(*i, data, end); + return data; +} + +template +const char *Assembler::disassemble(T &v, const char *data, const char *end) +{ + char *ptr = reinterpret_cast(&v)+sizeof(T); + for(unsigned i=0; i +const char *Assembler::disassemble(string &v, const char *data, const char *end) +{ + unsigned short size; + data = Assembler::disassemble(size, data, end); + if(end-data +const char *Assembler >::disassemble(vector &v, const char *data, const char *end) +{ + /* We assume that the vector is in pristine state - this holds because the + only code path leading here is from PacketDef

::disassemble, which creates + a new packet. */ + unsigned short size; + data = Assembler::disassemble(size, data, end); + for(unsigned i=0; i::disassemble(u, data, end); + v.push_back(u); + } + return data; +} + +} + +namespace Msp { +namespace Net { + +Protocol::Protocol(unsigned npi): + next_packet_id(npi) +{ } + +Protocol::~Protocol() +{ + for(map::iterator i=packet_class_defs.begin(); i!=packet_class_defs.end(); ++i) + delete i->second; +} + +void Protocol::add_packet(PacketDefBase &pdef) +{ + PacketDefBase *&ptr = packet_class_defs[pdef.get_class_id()]; + if(ptr) + delete ptr; + ptr = &pdef; + packet_id_defs[pdef.get_id()] = &pdef; +} + +const Protocol::PacketDefBase &Protocol::get_packet_by_class(unsigned id) const +{ + return *get_item(packet_class_defs, id); +} + +const Protocol::PacketDefBase &Protocol::get_packet_by_id(unsigned id) const +{ + return *get_item(packet_id_defs, id); +} + +unsigned Protocol::disassemble(ReceiverBase &rcv, const char *data, unsigned size) const +{ + const unsigned char *udata = reinterpret_cast(data); + unsigned id = (udata[0]<<8)+udata[1]; + unsigned psz = (udata[2]<<8)+udata[3]; + if(psz>size) + throw bad_packet("truncated"); + const PacketDefBase &pdef = get_packet_by_id(id); + const char *ptr = pdef.disassemble(rcv, data+4, data+psz); + return ptr-data; +} + +unsigned Protocol::get_packet_size(const char *data, unsigned size) const +{ + if(size<4) + return 0; + const unsigned char *udata = reinterpret_cast(data); + return (udata[2]<<8)+udata[3]; +} + +unsigned Protocol::get_hash() const +{ + // TODO + return 123; +} + +void Protocol::assemble_header(char *buf, unsigned id, unsigned size) +{ + buf[0] = (id>>8)&0xFF; + buf[1] = id&0xFF; + buf[2] = (size>>8)&0xFF; + buf[3] = size&0xFF; +} + +template +char *Protocol::assemble_field(const T &v, char *d, char *e) +{ return Assembler::assemble(v, d, e); } + +template char *Protocol::assemble_field<>(const char &v, char *d, char *e); +template char *Protocol::assemble_field<>(const signed char &v, char *d, char *e); +template char *Protocol::assemble_field<>(const unsigned char &v, char *d, char *e); +template char *Protocol::assemble_field<>(const short &v, char *d, char *e); +template char *Protocol::assemble_field<>(const unsigned short &v, char *d, char *e); +template char *Protocol::assemble_field<>(const int &v, char *d, char *e); +template char *Protocol::assemble_field<>(const unsigned &v, char *d, char *e); +template char *Protocol::assemble_field<>(const long &v, char *d, char *e); +template char *Protocol::assemble_field<>(const unsigned long &v, char *d, char *e); +template char *Protocol::assemble_field<>(const float &v, char *d, char *e); +template char *Protocol::assemble_field<>(const double &v, char *d, char *e); +template char *Protocol::assemble_field<>(const string &v, char *d, char *e); +template char *Protocol::assemble_field<>(const vector &v, char *d, char *e); +template char *Protocol::assemble_field<>(const vector &v, char *d, char *e); +template char *Protocol::assemble_field<>(const vector &v, char *d, char *e); +template char *Protocol::assemble_field<>(const vector &v, char *d, char *e); +template char *Protocol::assemble_field<>(const vector &v, char *d, char *e); +template char *Protocol::assemble_field<>(const vector &v, char *d, char *e); +template char *Protocol::assemble_field<>(const vector &v, char *d, char *e); +template char *Protocol::assemble_field<>(const vector &v, char *d, char *e); +template char *Protocol::assemble_field<>(const vector &v, char *d, char *e); +template char *Protocol::assemble_field<>(const vector &v, char *d, char *e); +template char *Protocol::assemble_field<>(const vector &v, char *d, char *e); +template char *Protocol::assemble_field<>(const vector &v, char *d, char *e); + +template +const char *Protocol::disassemble_field(T &v, const char *d, const char *e) +{ return Assembler::disassemble(v, d, e); } + +template const char *Protocol::disassemble_field<>(char &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(signed char &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(unsigned char &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(short &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(unsigned short &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(int &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(unsigned &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(long &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(unsigned long &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(float &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(double &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(string &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(vector &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(vector &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(vector &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(vector &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(vector &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(vector &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(vector &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(vector &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(vector &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(vector &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(vector &v, const char *d, const char *e); +template const char *Protocol::disassemble_field<>(vector &v, const char *d, const char *e); + +unsigned Protocol::PacketDefBase::next_class_id = 1; + +} // namespace Net +} // namespace Msp diff --git a/source/net/protocol.h b/source/net/protocol.h new file mode 100644 index 0000000..75b3514 --- /dev/null +++ b/source/net/protocol.h @@ -0,0 +1,176 @@ +#ifndef MSP_NET_PROTOCOL_H_ +#define MSP_NET_PROTOCOL_H_ + +#include +#include +#include +#include "receiver.h" + +namespace Msp { +namespace Net { + +class bad_packet: public std::runtime_error +{ +public: + bad_packet(const std::string &w): std::runtime_error(w) { } + virtual ~bad_packet() throw() { } +}; + + +class buffer_error: public std::runtime_error +{ +public: + buffer_error(const std::string &w): std::runtime_error(w) { } + virtual ~buffer_error() throw() { } +}; + + +class Protocol +{ +private: + class PacketDefBase + { + protected: + unsigned id; + + PacketDefBase(unsigned i): id(i) { } + public: + virtual ~PacketDefBase() { } + virtual unsigned get_class_id() const = 0; + unsigned get_id() const { return id; } + virtual const char *disassemble(ReceiverBase &, const char *, const char *) const = 0; + + static unsigned next_class_id; + }; + + template + class FieldBase + { + protected: + FieldBase() { } + public: + virtual ~FieldBase() { } + virtual char *assemble(const P &, char *, char *) const = 0; + virtual const char *disassemble(P &, const char *, const char *) const = 0; + }; + + template + class Field: public FieldBase

+ { + private: + T P::*ptr; + + public: + Field(T P::*p): ptr(p) { } + + virtual char *assemble(const P &p, char *d, char *e) const + { return assemble_field(p.*ptr, d, e); } + + virtual const char *disassemble(P &p, const char *d, const char *e) const + { return disassemble_field(p.*ptr, d, e); } + }; + +protected: + template + class PacketDef: public PacketDefBase + { + private: + std::vector *> fields; + + public: + PacketDef(unsigned i): PacketDefBase(i) + { if(!class_id) class_id = next_class_id++; } + + ~PacketDef() + { + for(typename std::vector *>::const_iterator i=fields.begin(); i!=fields.end(); ++i) + delete *i; + } + + virtual unsigned get_class_id() const { return class_id; } + + template + PacketDef &operator()(T P::*p) + { fields.push_back(new Field(p)); return *this; } + + char *assemble(const P &p, char *d, char *e) const + { + for(typename std::vector *>::const_iterator i=fields.begin(); i!=fields.end(); ++i) + d = (*i)->assemble(p, d, e); + return d; + } + + const char *disassemble(ReceiverBase &r, const char *d, const char *e) const + { + PacketReceiver

*prcv = dynamic_cast *>(&r); + if(!prcv) + throw bad_packet("unsupported"); + P pkt; + for(typename std::vector *>::const_iterator i=fields.begin(); i!=fields.end(); ++i) + d = (*i)->disassemble(pkt, d, e); + prcv->receive(pkt); + return d; + } + + static unsigned class_id; + }; + + typedef std::map PacketMap; + + unsigned next_packet_id; + PacketMap packet_class_defs; + PacketMap packet_id_defs; + + Protocol(unsigned = 1); +public: + ~Protocol(); + +private: + void add_packet(PacketDefBase &); + +protected: + template + PacketDef

&add() + { + PacketDef

*pdef = new PacketDef

(next_packet_id++); + add_packet(*pdef); + return *pdef; + } + + const PacketDefBase &get_packet_by_class(unsigned) const; + const PacketDefBase &get_packet_by_id(unsigned) const; + +public: + template + unsigned assemble(const P &pkt, char *buf, unsigned size) const + { + unsigned id = PacketDef

::class_id; + const PacketDef

&pdef = static_cast &>(get_packet_by_class(id)); + char *ptr = pdef.assemble(pkt, buf+4, buf+size); + assemble_header(buf, pdef.get_id(), (size = ptr-buf)); + return size; + } + + unsigned disassemble(ReceiverBase &, const char *, unsigned) const; + + unsigned get_packet_size(const char *, unsigned) const; + + unsigned get_hash() const; + +private: + static void assemble_header(char *, unsigned, unsigned); + + template + static char *assemble_field(const T &, char *, char *); + + template + static const char *disassemble_field(T &, const char *, const char *); +}; + +template +unsigned Protocol::PacketDef

::class_id = 0; + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/receiver.h b/source/net/receiver.h new file mode 100644 index 0000000..19e69e0 --- /dev/null +++ b/source/net/receiver.h @@ -0,0 +1,27 @@ +#ifndef MSP_NET_RECEIVER_H_ +#define MSP_NET_RECEIVER_H_ + +namespace Msp { +namespace Net { + +class ReceiverBase +{ +protected: + ReceiverBase() { } +public: + virtual ~ReceiverBase() { } +}; + +template +class PacketReceiver: public virtual ReceiverBase +{ +protected: + PacketReceiver() { } +public: + virtual void receive(const P &) = 0; +}; + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/resolve.cpp b/source/net/resolve.cpp new file mode 100644 index 0000000..e8aad66 --- /dev/null +++ b/source/net/resolve.cpp @@ -0,0 +1,90 @@ +#ifdef WIN32 +#define _WIN32_WINNT 0x0501 +#include +#else +#include +#endif +#include +#include +#include "sockaddr_private.h" +#include "socket.h" +#include "resolve.h" + +using namespace std; + +namespace Msp { +namespace Net { + +SockAddr *resolve(const string &host, const string &serv, Family family) +{ + const char *chost = (host.empty() ? 0 : host.c_str()); + const char *cserv = (serv.empty() ? 0 : serv.c_str()); + unsigned flags = 0; + if(host=="*") + { + flags = AI_PASSIVE; + chost = 0; + } + + addrinfo hints = { flags, family_to_sys(family), 0, 0, 0, 0, 0, 0 }; + addrinfo *res; + + int err = getaddrinfo(chost, cserv, &hints, &res); + if(err==0) + { + SockAddr::SysAddr sa; + sa.size = res->ai_addrlen; + const char *sptr = reinterpret_cast(res->ai_addr); + char *dptr = reinterpret_cast(&sa.addr); + copy(sptr, sptr+res->ai_addrlen, dptr); + SockAddr *addr = SockAddr::new_from_sys(sa); + freeaddrinfo(res); + return addr; + } + else +#ifdef WIN32 + throw system_error("getaddrinfo", WSAGetLastError()); +#else + throw system_error("getaddrinfo", gai_strerror(err)); +#endif +} + +SockAddr *resolve(const string &str, Family family) +{ + string host, serv; + if(str[0]=='[') + { + unsigned bracket = str.find(']'); + host = str.substr(1, bracket-1); + unsigned colon = str.find(':', bracket); + if(colon!=string::npos) + serv = str.substr(colon+1); + } + else + { + unsigned colon = str.find(':'); + if(colon!=string::npos) + { + host = str.substr(0, colon); + serv = str.substr(colon+1); + } + else + host = str; + } + + return resolve(host, serv, family); +} + + /*sockaddr sa; + unsigned size = fill_sockaddr(sa); + char hst[128]; + char srv[128]; + int err = getnameinfo(&sa, size, hst, 128, srv, 128, 0); + if(err==0) + { + host = hst; + serv = srv; + }*/ + +} // namespace Net +} // namespace Msp diff --git a/source/net/resolve.h b/source/net/resolve.h new file mode 100644 index 0000000..256455f --- /dev/null +++ b/source/net/resolve.h @@ -0,0 +1,26 @@ +#ifndef MSP_NET_RESOLVE_H_ +#define MSP_NET_RESOLVE_H_ + +#include +#include "constants.h" + +namespace Msp { +namespace Net { + +class SockAddr; + +/** Resolves host and service names into a socket address. If host is empty, +the loopback address will be used. If host is "*", the wildcard address will +be used. If service is empty, a socket address with a null service will be +returned. With the IP families, these are not very useful. */ +SockAddr *resolve(const std::string &, const std::string &, Family = UNSPEC); + +/** And overload of resolve() that takes host and service as a single string, +separated by a colon. If the host part contains colons, such as is the case +with a numeric IPv6 address, it must be enclosed in brackets. */ +SockAddr *resolve(const std::string &, Family = UNSPEC); + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/serversocket.cpp b/source/net/serversocket.cpp new file mode 100644 index 0000000..6028483 --- /dev/null +++ b/source/net/serversocket.cpp @@ -0,0 +1,23 @@ +#include "serversocket.h" + +using namespace std; + +namespace Msp { +namespace Net { + +ServerSocket::ServerSocket(Family af, int type, int proto): + Socket(af, type, proto) +{ } + +unsigned ServerSocket::do_write(const char *, unsigned) +{ + throw logic_error("can't write to ServerSocket"); +} + +unsigned ServerSocket::do_read(char *, unsigned) +{ + throw logic_error("can't read from ServerSocket"); +} + +} // namespace Net +} // namespace Msp diff --git a/source/net/serversocket.h b/source/net/serversocket.h new file mode 100644 index 0000000..10375f0 --- /dev/null +++ b/source/net/serversocket.h @@ -0,0 +1,32 @@ +#ifndef MSP_NET_SERVERSOCKET_H_ +#define MSP_NET_SERVERSOCKET_H_ + +#include "socket.h" + +namespace Msp { +namespace Net { + +class ClientSocket; + +/** +ServerSockets are used to receive incoming connections. They cannot be used +for sending and receiving data. +*/ +class ServerSocket: public Socket +{ +protected: + ServerSocket(Family, int, int); + +public: + virtual void listen(const SockAddr &, unsigned = 4) = 0; + + virtual ClientSocket *accept() = 0; +protected: + virtual unsigned do_write(const char *, unsigned); + virtual unsigned do_read(char *, unsigned); +}; + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/sockaddr.cpp b/source/net/sockaddr.cpp new file mode 100644 index 0000000..9b41f81 --- /dev/null +++ b/source/net/sockaddr.cpp @@ -0,0 +1,31 @@ +#include +#include "inet.h" +#include "inet6.h" +#include "sockaddr_private.h" + +using namespace std; + +namespace Msp { +namespace Net { + +SockAddr *SockAddr::new_from_sys(const SysAddr &sa) +{ + switch(sa.addr.ss_family) + { + case AF_INET: + return new InetAddr(sa); + case AF_INET6: + return new Inet6Addr(sa); + default: + throw invalid_argument("SockAddr::create"); + } +} + +SockAddr::SysAddr::SysAddr(): + size(sizeof(sockaddr_storage)) +{ + addr.ss_family = AF_UNSPEC; +} + +} // namespace Net +} // namespace Msp diff --git a/source/net/sockaddr.h b/source/net/sockaddr.h new file mode 100644 index 0000000..aad5e29 --- /dev/null +++ b/source/net/sockaddr.h @@ -0,0 +1,32 @@ +#ifndef MSP_NET_SOCKADDR_H_ +#define MSP_NET_SOCKADDR_H_ + +#include +#include "constants.h" + +namespace Msp { +namespace Net { + +class SockAddr +{ +public: + struct SysAddr; + +protected: + SockAddr() { } +public: + virtual ~SockAddr() { } + + virtual SockAddr *copy() const = 0; + + static SockAddr *new_from_sys(const SysAddr &); + virtual SysAddr to_sys() const = 0; + + virtual Family get_family() const = 0; + virtual std::string str() const = 0; +}; + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/sockaddr_private.h b/source/net/sockaddr_private.h new file mode 100644 index 0000000..16af063 --- /dev/null +++ b/source/net/sockaddr_private.h @@ -0,0 +1,29 @@ +#ifndef MSP_NET_SOCKADDR_PRIVATE_H_ +#define MSP_NET_SOCKADDR_PRIVATE_H_ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "sockaddr.h" + +namespace Msp { +namespace Net { + +struct SockAddr::SysAddr +{ + struct sockaddr_storage addr; +#ifdef WIN32 + int size; +#else + socklen_t size; +#endif + + SysAddr(); +}; + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/socket.cpp b/source/net/socket.cpp new file mode 100644 index 0000000..5af2b4f --- /dev/null +++ b/source/net/socket.cpp @@ -0,0 +1,159 @@ +#ifndef WIN32 +#include +#include +#include +#endif +#include +#include +#include +#include +#include +#include +#include "sockaddr_private.h" +#include "socket.h" +#include "socket_private.h" + +namespace { + +#ifdef WIN32 +class WinSockHelper +{ +public: + WinSockHelper() + { + WSADATA wsa_data; + int err = WSAStartup(0x0002, &wsa_data); + if(err) + std::cerr<<"Failed to initialize WinSock: "<handle = p.handle; + + SockAddr::SysAddr sa; + getsockname(priv->handle, reinterpret_cast(&sa.addr), &sa.size); + local_addr = SockAddr::new_from_sys(sa); + +#ifdef WIN32 + *priv->event = CreateEvent(0, false, false, 0); +#else + *priv->event = priv->handle; +#endif +} + +Socket::Socket(Family af, int type, int proto): + priv(new Private), + local_addr(0) +{ + priv->handle = socket(family_to_sys(af), type, proto); + +#ifdef WIN32 + *priv->event = CreateEvent(0, false, false, 0); +#else + *priv->event = priv->handle; +#endif +} + +Socket::~Socket() +{ +#ifdef WIN32 + closesocket(priv->handle); + CloseHandle(*priv->event); +#else + ::close(priv->handle); +#endif + + delete local_addr; + delete priv; +} + +void Socket::set_block(bool b) +{ + mode = (mode&~IO::M_NONBLOCK); + if(b) + mode = (mode|IO::M_NONBLOCK); + +#ifdef WIN32 + u_long flag = !b; + ioctlsocket(priv->handle, FIONBIO, &flag); +#else + int flags = fcntl(priv->handle, F_GETFL); + fcntl(priv->handle, F_SETFL, (flags&O_NONBLOCK)|(b?0:O_NONBLOCK)); +#endif +} + +const IO::Handle &Socket::get_event_handle() +{ + return priv->event; +} + + +void Socket::bind(const SockAddr &addr) +{ + SockAddr::SysAddr sa = addr.to_sys(); + + int err = ::bind(priv->handle, reinterpret_cast(&sa.addr), sa.size); + if(err==-1) + throw system_error("bind"); + + delete local_addr; + local_addr = addr.copy(); +} + +const SockAddr &Socket::get_local_address() const +{ + if(local_addr==0) + throw bad_socket_state("not bound"); + return *local_addr; +} + +void Socket::set_timeout(const Time::TimeDelta &timeout) +{ +#ifndef WIN32 + timeval tv = Time::rawtime_to_timeval(timeout.raw()); + set_option(SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(timeval)); + set_option(SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(timeval)); +#else + DWORD msecs = static_cast(timeout/Time::msec); + set_option(SOL_SOCKET, SO_RCVTIMEO, &msecs, sizeof(DWORD)); + set_option(SOL_SOCKET, SO_SNDTIMEO, &msecs, sizeof(DWORD)); +#endif +} + +int Socket::set_option(int level, int optname, const void *optval, socklen_t optlen) +{ +#ifdef WIN32 + return setsockopt(priv->handle, level, optname, reinterpret_cast(optval), optlen); +#else + return setsockopt(priv->handle, level, optname, optval, optlen); +#endif +} + +int Socket::get_option(int level, int optname, void *optval, socklen_t *optlen) const +{ +#ifdef WIN32 + return getsockopt(priv->handle, level, optname, reinterpret_cast(optval), optlen); +#else + return getsockopt(priv->handle, level, optname, optval, optlen); +#endif +} + +} // namespace Net +} // namespace Msp diff --git a/source/net/socket.h b/source/net/socket.h new file mode 100644 index 0000000..f975884 --- /dev/null +++ b/source/net/socket.h @@ -0,0 +1,56 @@ +#ifndef MSP_NET_SOCKET_H_ +#define MSP_NET_SOCKET_H_ + +#include +#include +#include "constants.h" +#include "sockaddr.h" + +namespace Msp { +namespace Net { + +#ifdef WIN32 +typedef int socklen_t; +#endif + + +class bad_socket_state: public std::logic_error +{ +public: + bad_socket_state(const std::string &w): std::logic_error(w) { } + virtual ~bad_socket_state() throw() { } +}; + + +class Socket: public IO::EventObject +{ +protected: + struct Private; + + Private *priv; + SockAddr *local_addr; + + Socket(const Private &); + Socket(Family, int, int); +public: + ~Socket(); + + virtual void set_block(bool); + virtual const IO::Handle &get_event_handle(); + + /** Associates the socket with a local address. There must be no existing + users of the address. */ + void bind(const SockAddr &); + + const SockAddr &get_local_address() const; + + void set_timeout(const Time::TimeDelta &); +protected: + int set_option(int, int, const void *, socklen_t); + int get_option(int, int, void *, socklen_t *) const; +}; + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/socket_private.h b/source/net/socket_private.h new file mode 100644 index 0000000..83cc369 --- /dev/null +++ b/source/net/socket_private.h @@ -0,0 +1,26 @@ +#ifndef MSP_NET_SOCKET_PRIVATE_H_ +#define MSP_NET_SOCKET_PRIVATE_H_ + +#include +#include "socket.h" + +namespace Msp { +namespace Net { + +struct Socket::Private +{ +#ifdef WIN32 + SOCKET handle; +#else + int handle; +#endif + + /* On POSIX platforms this is the same as the handle. This might seem + strange but it allows the same syntax on both POSIX and Windows. */ + IO::Handle event; +}; + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/streamserversocket.cpp b/source/net/streamserversocket.cpp new file mode 100644 index 0000000..bbf91bc --- /dev/null +++ b/source/net/streamserversocket.cpp @@ -0,0 +1,51 @@ +#include +#include +#include +#include +#include +#include "sockaddr_private.h" +#include "socket_private.h" +#include "streamserversocket.h" +#include "streamsocket.h" + +using namespace std; + +namespace Msp { +namespace Net { + +StreamServerSocket::StreamServerSocket(Family af, int proto): + ServerSocket(af, SOCK_STREAM, proto), + listening(false) +{ } + +void StreamServerSocket::listen(const SockAddr &addr, unsigned backlog) +{ + bind(addr); + + int err = ::listen(priv->handle, backlog); + if(err==-1) + throw system_error("listen"); + +#ifdef WIN32 + WSAEventSelect(priv->handle, *priv->event, FD_ACCEPT); +#endif + set_events(IO::P_INPUT); + + listening = true; +} + +StreamSocket *StreamServerSocket::accept() +{ + if(!listening) + throw bad_socket_state("not listening"); + + SockAddr::SysAddr sa; + Private new_p; + new_p.handle = ::accept(priv->handle, reinterpret_cast(&sa.addr), &sa.size); + + RefPtr paddr = SockAddr::new_from_sys(sa); + return new StreamSocket(new_p, *paddr); +} + +} // namespace Net +} // namespace Msp diff --git a/source/net/streamserversocket.h b/source/net/streamserversocket.h new file mode 100644 index 0000000..aa4868c --- /dev/null +++ b/source/net/streamserversocket.h @@ -0,0 +1,25 @@ +#ifndef MSP_NET_STREAMSERVERSOCKET_H_ +#define MSP_NET_STREAMSERVERSOCKET_H_ + +#include "serversocket.h" +#include "streamsocket.h" + +namespace Msp { +namespace Net { + +class StreamServerSocket: public ServerSocket +{ +private: + bool listening; + +public: + StreamServerSocket(Family, int = 0); + + virtual void listen(const SockAddr &, unsigned = 4); + virtual StreamSocket *accept(); +}; + +} // namespace Net +} // namespace Msp + +#endif diff --git a/source/net/streamsocket.cpp b/source/net/streamsocket.cpp new file mode 100644 index 0000000..8a26245 --- /dev/null +++ b/source/net/streamsocket.cpp @@ -0,0 +1,148 @@ +#ifndef WIN32 +#include +#endif +#include +#include +#include +#include +#include +#include "sockaddr_private.h" +#include "socket_private.h" +#include "streamsocket.h" + +namespace Msp { +namespace Net { + +StreamSocket::StreamSocket(const Private &p, const SockAddr &paddr): + ClientSocket(p, paddr) +{ +#ifdef WIN32 + WSAEventSelect(priv->handle, *priv->event, FD_READ|FD_CLOSE); +#endif + set_events(IO::P_INPUT); +} + +StreamSocket::StreamSocket(Family af, int proto): + ClientSocket(af, SOCK_STREAM, proto) +{ } + +bool StreamSocket::connect(const SockAddr &addr) +{ + if(connected) + throw bad_socket_state("already connected"); + + SockAddr::SysAddr sa = addr.to_sys(); + + int err = ::connect(priv->handle, reinterpret_cast(&sa.addr), sa.size); +#ifdef WIN32 + if(err==SOCKET_ERROR) + { + int err_code = WSAGetLastError(); + if(err_code==WSAEWOULDBLOCK) + { + connecting = true; + WSAEventSelect(priv->handle, *priv->event, FD_CONNECT); + set_events(IO::P_OUTPUT); + } + else + throw system_error("connect", err_code); + } +#else + if(err==-1) + { + if(errno==EINPROGRESS) + { + connecting = true; + set_events(IO::P_OUTPUT); + } + else + throw system_error("connect"); + } +#endif + + delete peer_addr; + peer_addr = addr.copy(); + + delete local_addr; + SockAddr::SysAddr lsa; + getsockname(priv->handle, reinterpret_cast(&lsa.addr), &lsa.size); + local_addr = SockAddr::new_from_sys(lsa); + + if(err==0) + { + connected = true; + set_events(IO::P_INPUT); + signal_connect_finished.emit(0); + } + + return connected; +} + +bool StreamSocket::poll_connect(const Time::TimeDelta &timeout) +{ + if(!connecting) + return false; + + IO::PollEvent res = poll(*this, IO::P_OUTPUT, timeout); + if(res&IO::P_OUTPUT) + { + connecting = false; + + int err; + socklen_t len = sizeof(int); + get_option(SOL_SOCKET, SO_ERROR, &err, &len); + + if(err!=0) + { + set_events(IO::P_NONE); +#ifdef WIN32 + throw system_error("connect", WSAGetLastError()); +#else + throw system_error("connect"); +#endif + } + +#ifdef WIN32 + WSAEventSelect(priv->handle, *priv->event, FD_READ|FD_CLOSE); +#endif + set_events(IO::P_INPUT); + + connected = true; + } + + return connected; +} + +void StreamSocket::on_event(IO::PollEvent ev) +{ + if((ev&(IO::P_OUTPUT|IO::P_ERROR)) && connecting) + { + int err; + socklen_t len = sizeof(err); + get_option(SOL_SOCKET, SO_ERROR, &err, &len); + + connecting = false; + connected = (err==0); + if(err) + { + system_error exc("connect", err); + signal_connect_finished.emit(&exc); + } + else + signal_connect_finished.emit(0); + + if(err!=0) + { + delete peer_addr; + peer_addr = 0; + } + +#ifdef WIN32 + WSAEventSelect(priv->handle, *priv->event, FD_READ|FD_CLOSE); +#endif + set_events((err==0) ? IO::P_INPUT : IO::P_NONE); + } +} + +} // namespace Net +} // namespace Msp diff --git a/source/net/streamsocket.h b/source/net/streamsocket.h new file mode 100644 index 0000000..8b39e91 --- /dev/null +++ b/source/net/streamsocket.h @@ -0,0 +1,37 @@ +#ifndef MSP_NET_STREAMSOCKET_H_ +#define MSP_NET_STREAMSOCKET_H_ + +#include "clientsocket.h" + +namespace Msp { +namespace Net { + +class StreamSocket: public ClientSocket +{ + friend class StreamServerSocket; + +private: + /// Used by StreamListenSocket to construct a new socket from accept. + StreamSocket(const Private &, const SockAddr &); +public: + StreamSocket(Family, int = 0); + + /** Connects to a remote address. StreamSockets must be connected before + data can be sent and received. Returns 0 if the connection was successfully + established, 1 if it's in progress. + + If the socket is non-blocking, this function may return before the + connection is fully established. The caller must then use either the + poll_connect function or an EventDispatcher to finish the process. */ + virtual bool connect(const SockAddr &); + + virtual bool poll_connect(const Time::TimeDelta &); + +private: + void on_event(IO::PollEvent); +}; + +} // namespace Net +} // namespace Msp + +#endif