+#include <msp/net/protocol.h>
+#include <msp/test/test.h>
+
+using namespace std;
+using namespace Msp;
+
+class ProtocolTests: public Test::RegisteredTest<ProtocolTests>
+{
+public:
+ ProtocolTests();
+
+ static const char *get_name() { return "Protocol"; }
+
+private:
+ void hash_match();
+ void buffer_overflow();
+ void truncated_packet();
+ void stub_header();
+
+ template<typename P>
+ void transmit(const P &, P &, size_t);
+
+ void transmit_int();
+ void transmit_string();
+ void transmit_array();
+ void transmit_composite();
+};
+
+
+class Protocol: public Msp::Net::Protocol
+{
+public:
+ Protocol();
+};
+
+struct Packet1
+{
+ uint32_t value;
+};
+
+struct Packet2
+{
+ std::string value;
+};
+
+struct Packet3
+{
+ std::vector<uint32_t> values;
+};
+
+struct Packet4
+{
+ Packet2 sub1;
+ std::vector<Packet3> sub2;
+};
+
+Protocol::Protocol()
+{
+ add<Packet1>()(&Packet1::value);
+ add<Packet2>()(&Packet2::value);
+ add<Packet3>()(&Packet3::values);
+ add<Packet4>()(&Packet4::sub1)(&Packet4::sub2);
+}
+
+template<typename T>
+class Receiver: public Net::PacketReceiver<T>
+{
+private:
+ T &storage;
+
+public:
+ Receiver(T &s): storage(s) { }
+
+ void receive(const T &p) override { storage = p; }
+};
+
+
+ProtocolTests::ProtocolTests()
+{
+ add(&ProtocolTests::hash_match, "Hash match");
+ add(&ProtocolTests::buffer_overflow, "Serialization buffer overflow").expect_throw<Net::buffer_error>();
+ add(&ProtocolTests::truncated_packet, "Truncated packet").expect_throw<Net::bad_packet>();
+ add(&ProtocolTests::stub_header, "Stub header");
+ add(&ProtocolTests::transmit_int, "Integer transmission");
+ add(&ProtocolTests::transmit_string, "String transmission");
+ add(&ProtocolTests::transmit_array, "Array transmission");
+ add(&ProtocolTests::transmit_composite, "Composite transmission");
+}
+
+void ProtocolTests::hash_match()
+{
+ Protocol proto1;
+ Protocol proto2;
+ EXPECT_EQUAL(proto1.get_hash(), proto2.get_hash());
+}
+
+void ProtocolTests::buffer_overflow()
+{
+ Protocol proto;
+ Packet1 pkt = { 42 };
+ char buffer[7];
+ proto.serialize(pkt, buffer, sizeof(buffer));
+}
+
+void ProtocolTests::truncated_packet()
+{
+ Protocol proto;
+ Packet1 pkt = { 42 };
+ char buffer[16];
+ size_t len = proto.serialize(pkt, buffer, sizeof(buffer));
+ Receiver<Packet1> recv(pkt);
+ proto.dispatch(recv, buffer, len-1);
+}
+
+void ProtocolTests::stub_header()
+{
+ Protocol proto;
+ char buffer[3] = { 4, 0, 1 };
+ size_t len = proto.get_packet_size(buffer, sizeof(buffer));
+ EXPECT_EQUAL(len, 0);
+}
+
+template<typename P>
+void ProtocolTests::transmit(const P &pkt, P &rpkt, size_t expected_length)
+{
+ Protocol proto;
+ char buffer[128];
+ size_t len = proto.serialize(pkt, buffer, sizeof(buffer));
+ EXPECT_EQUAL(len, expected_length);
+
+ size_t rlen = proto.get_packet_size(buffer, sizeof(buffer));
+ EXPECT_EQUAL(rlen, len);
+
+ Receiver<P> recv(rpkt);
+ size_t dlen = proto.dispatch(recv, buffer, sizeof(buffer));
+ EXPECT_EQUAL(dlen, len);
+}
+
+void ProtocolTests::transmit_int()
+{
+ Packet1 pkt = { 42 };
+ Packet1 rpkt;
+ transmit(pkt, rpkt, 8);
+ EXPECT_EQUAL(rpkt.value, 42);
+}
+
+void ProtocolTests::transmit_string()
+{
+ Packet2 pkt = { "Hello" };
+ Packet2 rpkt;
+ transmit(pkt, rpkt, 11);
+ EXPECT_EQUAL(rpkt.value, "Hello");
+}
+
+void ProtocolTests::transmit_array()
+{
+ Packet3 pkt = {{ 2, 3, 5, 7, 11 }};
+ Packet3 rpkt;
+ transmit(pkt, rpkt, 26);
+ EXPECT_EQUAL(rpkt.values.size(), 5);
+ for(size_t i=0; i<pkt.values.size(); ++i)
+ EXPECT_EQUAL(rpkt.values[i], pkt.values[i]);
+}
+
+void ProtocolTests::transmit_composite()
+{
+ Packet4 pkt = { "Don't panic", { }};
+ pkt.sub2.emplace_back(Packet3{{ 2, 3, 5, 7, 11 }});
+ pkt.sub2.emplace_back(Packet3{{ 20, 10, 5, 16, 8, 4, 2, 1 }});
+ Packet4 rpkt;
+ transmit(pkt, rpkt, 75);
+ EXPECT_EQUAL(rpkt.sub1.value, "Don't panic");
+ EXPECT_EQUAL(rpkt.sub2.size(), 2);
+ EXPECT_EQUAL(rpkt.sub2[0].values.size(), 5);
+ EXPECT_EQUAL(rpkt.sub2[1].values.size(), 8);
+ for(size_t i=0; i<pkt.sub2.size(); ++i)
+ for(size_t j=0; j<pkt.sub2[i].values.size(); ++j)
+ EXPECT_EQUAL(rpkt.sub2[i].values[j], pkt.sub2[i].values[j]);
+}