// DefaultExtensionMessageFactoryTest.cc
#include "DefaultExtensionMessageFactory.h"

#include <iostream>

#include <cppunit/extensions/HelperMacros.h>

#include "Peer.h"
#include "MockPeerStorage.h"
#include "bittorrent_helper.h"
#include "HandshakeExtensionMessage.h"
#include "UTPexExtensionMessage.h"
#include "Exception.h"
#include "FileEntry.h"
#include "ExtensionMessageRegistry.h"
#include "DownloadContext.h"
#include "MockBtMessageDispatcher.h"
#include "MockBtMessageFactory.h"
#include "DownloadContext.h"
#include "BtHandshakeMessage.h"
#include "UTMetadataRequestExtensionMessage.h"
#include "UTMetadataDataExtensionMessage.h"
#include "UTMetadataRejectExtensionMessage.h"
#include "BtRuntime.h"
#include "PieceStorage.h"
#include "RequestGroup.h"
#include "Option.h"

namespace aria2 {

class DefaultExtensionMessageFactoryTest:public CppUnit::TestFixture {

  CPPUNIT_TEST_SUITE(DefaultExtensionMessageFactoryTest);
  CPPUNIT_TEST(testCreateMessage_unknown);
  CPPUNIT_TEST(testCreateMessage_Handshake);
  CPPUNIT_TEST(testCreateMessage_UTPex);
  CPPUNIT_TEST(testCreateMessage_UTMetadataRequest);
  CPPUNIT_TEST(testCreateMessage_UTMetadataData);
  CPPUNIT_TEST(testCreateMessage_UTMetadataReject);
  CPPUNIT_TEST_SUITE_END();
private:
41 42 43 44 45 46 47 48
  std::unique_ptr<MockPeerStorage> peerStorage_;
  std::shared_ptr<Peer> peer_;
  std::unique_ptr<DefaultExtensionMessageFactory> factory_;
  std::unique_ptr<ExtensionMessageRegistry> registry_;
  std::unique_ptr<MockBtMessageDispatcher> dispatcher_;
  std::unique_ptr<MockBtMessageFactory> messageFactory_;
  std::shared_ptr<DownloadContext> dctx_;
  std::unique_ptr<RequestGroup> requestGroup_;
49 50 51
public:
  void setUp()
  {
52
    peerStorage_ = make_unique<MockPeerStorage>();
53

54
    peer_ = std::make_shared<Peer>("192.168.0.1", 6969);
55
    peer_->allocateSessionResource(1_k, 1_m);
56
    peer_->setExtension(ExtensionMessageRegistry::UT_PEX, 1);
57

58 59 60 61 62 63
    registry_ = make_unique<ExtensionMessageRegistry>();
    dispatcher_ = make_unique<MockBtMessageDispatcher>();
    messageFactory_ = make_unique<MockBtMessageFactory>();
    dctx_ = std::make_shared<DownloadContext>();
    auto option = std::make_shared<Option>();
    requestGroup_ = make_unique<RequestGroup>(GroupId::create(), option);
64 65
    requestGroup_->setDownloadContext(dctx_);

66 67
    factory_ = make_unique<DefaultExtensionMessageFactory>();
    factory_->setPeerStorage(peerStorage_.get());
68
    factory_->setPeer(peer_);
69
    factory_->setExtensionMessageRegistry(registry_.get());
70 71
    factory_->setBtMessageDispatcher(dispatcher_.get());
    factory_->setBtMessageFactory(messageFactory_.get());
72
    factory_->setDownloadContext(dctx_.get());
73 74
  }

75
  std::string getExtensionMessageID(int key)
76
  {
77
    unsigned char id[1] = { registry_->getExtensionMessageID(key) };
78 79 80 81
    return std::string(&id[0], &id[1]);
  }

  template<typename T>
82
  std::shared_ptr<T> createMessage(const std::string& data)
83
  {
84 85 86 87
    auto m = factory_->createMessage
      (reinterpret_cast<const unsigned char*>(data.c_str()), data.size());
    return std::dynamic_pointer_cast<T>(std::shared_ptr<T>
      {static_cast<T*>(m.release())});
88 89 90 91 92 93 94 95 96 97 98 99 100 101
  }

  void testCreateMessage_unknown();
  void testCreateMessage_Handshake();
  void testCreateMessage_UTPex();
  void testCreateMessage_UTMetadataRequest();
  void testCreateMessage_UTMetadataData();
  void testCreateMessage_UTMetadataReject();
};

// Register the fixture's suite with CppUnit so the test runner discovers it.
CPPUNIT_TEST_SUITE_REGISTRATION(DefaultExtensionMessageFactoryTest);

void DefaultExtensionMessageFactoryTest::testCreateMessage_unknown()
{
102
  peer_->setExtension(ExtensionMessageRegistry::UT_PEX, 255);
103

104
  unsigned char id[1] = { 255 };
105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121

  std::string data = std::string(&id[0], &id[1]);
  try {
    // this test fails because localhost doesn't have extension id = 255.
    factory_->createMessage
      (reinterpret_cast<const unsigned char*>(data.c_str()), data.size());
    CPPUNIT_FAIL("exception must be thrown.");
  } catch(Exception& e) {
    std::cerr << e.stackTrace() << std::endl;
  }
}

void DefaultExtensionMessageFactoryTest::testCreateMessage_Handshake()
{
  char id[1] = { 0 };

  std::string data = std::string(&id[0], &id[1])+"d1:v5:aria2e";
122
  auto m = createMessage<HandshakeExtensionMessage>(data);
123 124 125 126 127 128 129 130 131 132 133 134 135 136
  CPPUNIT_ASSERT_EQUAL(std::string("aria2"), m->getClientVersion());
}

void DefaultExtensionMessageFactoryTest::testCreateMessage_UTPex()
{
  unsigned char c1[COMPACT_LEN_IPV6];
  unsigned char c2[COMPACT_LEN_IPV6];
  unsigned char c3[COMPACT_LEN_IPV6];
  unsigned char c4[COMPACT_LEN_IPV6];
  bittorrent::packcompact(c1, "192.168.0.1", 6881);
  bittorrent::packcompact(c2, "10.1.1.2", 9999);
  bittorrent::packcompact(c3, "192.168.0.2", 6882);
  bittorrent::packcompact(c4, "10.1.1.3",10000);

137 138 139 140
  registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_PEX, 1);

  std::string data = getExtensionMessageID(ExtensionMessageRegistry::UT_PEX)
    +"d5:added12:"+
141 142 143 144 145
    std::string(&c1[0], &c1[6])+std::string(&c2[0], &c2[6])+
    "7:added.f2:207:dropped12:"+
    std::string(&c3[0], &c3[6])+std::string(&c4[0], &c4[6])+
    "e";

146
  auto m = createMessage<UTPexExtensionMessage>(data);
147 148
  CPPUNIT_ASSERT_EQUAL(registry_->getExtensionMessageID
                       (ExtensionMessageRegistry::UT_PEX),
149 150 151 152 153
                       m->getExtensionMessageID());
}

void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataRequest()
{
154 155 156 157
  registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 1);

  std::string data = getExtensionMessageID
    (ExtensionMessageRegistry::UT_METADATA)+
158
    "d8:msg_typei0e5:piecei1ee";
159
  auto m = createMessage<UTMetadataRequestExtensionMessage>(data);
160 161 162 163 164
  CPPUNIT_ASSERT_EQUAL((size_t)1, m->getIndex());
}

void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataData()
{
165 166 167 168
  registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 1);

  std::string data = getExtensionMessageID
    (ExtensionMessageRegistry::UT_METADATA)+
169
    "d8:msg_typei1e5:piecei1e10:total_sizei300ee0000000000";
170
  auto m = createMessage<UTMetadataDataExtensionMessage>(data);
171 172 173 174 175 176 177
  CPPUNIT_ASSERT_EQUAL((size_t)1, m->getIndex());
  CPPUNIT_ASSERT_EQUAL((size_t)300, m->getTotalSize());
  CPPUNIT_ASSERT_EQUAL(std::string(10, '0'), m->getData());
}

void DefaultExtensionMessageFactoryTest::testCreateMessage_UTMetadataReject()
{
178 179 180 181
  registry_->setExtensionMessageID(ExtensionMessageRegistry::UT_METADATA, 1);

  std::string data = getExtensionMessageID
    (ExtensionMessageRegistry::UT_METADATA)+
182
    "d8:msg_typei2e5:piecei1ee";
183
  auto m = createMessage<UTMetadataRejectExtensionMessage>(data);
184 185 186 187
  CPPUNIT_ASSERT_EQUAL((size_t)1, m->getIndex());
}

} // namespace aria2