/* DefaultBtMessageFactory.cc (14 KB) */
/* <!-- copyright */
/*
 * aria2 - The high speed download utility
 *
 * Copyright (C) 2006 Tatsuhiro Tsujikawa
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 *
 * In addition, as a special exception, the copyright holders give
 * permission to link the code of portions of this program with the
 * OpenSSL library under certain conditions as described in each
 * individual source file, and distribute linked combinations
 * including the two.
 * You must obey the GNU General Public License in all respects
 * for all of the code used other than OpenSSL.  If you modify
 * file(s) with this exception, you may extend this exception to your
 * version of the file(s), but you are not obligated to do so.  If you
 * do not wish to do so, delete this exception statement from your
 * version.  If you delete this exception statement from all source
 * files in the program, then also delete it here.
 */
/* copyright --> */
#include "DefaultBtMessageFactory.h"
#include "DlAbortEx.h"
#include "bittorrent_helper.h"
#include "BtKeepAliveMessage.h"
#include "BtChokeMessage.h"
#include "BtUnchokeMessage.h"
#include "BtInterestedMessage.h"
#include "BtNotInterestedMessage.h"
#include "BtHaveMessage.h"
#include "BtBitfieldMessage.h"
#include "BtBitfieldMessageValidator.h"
#include "RangeBtMessageValidator.h"
#include "IndexBtMessageValidator.h"
#include "BtRequestMessage.h"
#include "BtCancelMessage.h"
#include "BtPieceMessage.h"
#include "BtPieceMessageValidator.h"
#include "BtPortMessage.h"
#include "BtHaveAllMessage.h"
#include "BtHaveNoneMessage.h"
#include "BtRejectMessage.h"
#include "BtSuggestPieceMessage.h"
#include "BtAllowedFastMessage.h"
#include "BtHandshakeMessage.h"
#include "BtHandshakeMessageValidator.h"
#include "BtExtendedMessage.h"
#include "ExtensionMessage.h"
#include "Peer.h"
#include "Piece.h"
#include "DownloadContext.h"
#include "PieceStorage.h"
#include "PeerStorage.h"
#include "fmt.h"
#include "ExtensionMessageFactory.h"
#include "bittorrent_helper.h"

namespace aria2 {

73
DefaultBtMessageFactory::DefaultBtMessageFactory()
74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
    : cuid_{0},
      downloadContext_{nullptr},
      pieceStorage_{nullptr},
      peerStorage_{nullptr},
      dhtEnabled_(false),
      dispatcher_{nullptr},
      requestFactory_{nullptr},
      peerConnection_{nullptr},
      extensionMessageFactory_{nullptr},
      localNode_{nullptr},
      routingTable_{nullptr},
      taskQueue_{nullptr},
      taskFactory_{nullptr},
      metadataGetMode_(false)
{
}
90

91
std::unique_ptr<BtMessage>
92 93
DefaultBtMessageFactory::createBtMessage(const unsigned char* data,
                                         size_t dataLength)
94
{
95
  auto msg = std::unique_ptr<AbstractBtMessage>{};
96
  if (dataLength == 0) {
97
    // keep-alive
98
    msg = make_unique<BtKeepAliveMessage>();
99 100
  }
  else {
101
    uint8_t id = bittorrent::getId(data);
102
    switch (id) {
103
    case BtChokeMessage::ID:
104
      msg = BtChokeMessage::create(data, dataLength);
105 106
      break;
    case BtUnchokeMessage::ID:
107
      msg = BtUnchokeMessage::create(data, dataLength);
108
      break;
109 110 111 112
    case BtInterestedMessage::ID: {
      auto m = BtInterestedMessage::create(data, dataLength);
      m->setPeerStorage(peerStorage_);
      msg = std::move(m);
113
      break;
114 115 116 117 118
    }
    case BtNotInterestedMessage::ID: {
      auto m = BtNotInterestedMessage::create(data, dataLength);
      m->setPeerStorage(peerStorage_);
      msg = std::move(m);
119
      break;
120
    }
121
    case BtHaveMessage::ID:
122
      msg = BtHaveMessage::create(data, dataLength);
123 124 125 126
      if (!metadataGetMode_) {
        msg->setBtMessageValidator(make_unique<IndexBtMessageValidator>(
            static_cast<BtHaveMessage*>(msg.get()),
            downloadContext_->getNumPieces()));
127 128 129
      }
      break;
    case BtBitfieldMessage::ID:
130
      msg = BtBitfieldMessage::create(data, dataLength);
131 132 133 134
      if (!metadataGetMode_) {
        msg->setBtMessageValidator(make_unique<BtBitfieldMessageValidator>(
            static_cast<BtBitfieldMessage*>(msg.get()),
            downloadContext_->getNumPieces()));
135 136 137
      }
      break;
    case BtRequestMessage::ID: {
138
      auto m = BtRequestMessage::create(data, dataLength);
139 140 141
      if (!metadataGetMode_) {
        m->setBtMessageValidator(make_unique<RangeBtMessageValidator>(
            static_cast<BtRequestMessage*>(m.get()),
142
            downloadContext_->getNumPieces(),
143
            pieceStorage_->getPieceLength(m->getIndex())));
144
      }
145
      msg = std::move(m);
146 147
      break;
    }
148 149
    case BtPieceMessage::ID: {
      auto m = BtPieceMessage::create(data, dataLength);
150 151 152
      if (!metadataGetMode_) {
        m->setBtMessageValidator(make_unique<BtPieceMessageValidator>(
            static_cast<BtPieceMessage*>(m.get()),
153
            downloadContext_->getNumPieces(),
154
            pieceStorage_->getPieceLength(m->getIndex())));
155
      }
156 157 158
      m->setDownloadContext(downloadContext_);
      m->setPeerStorage(peerStorage_);
      msg = std::move(m);
159 160
      break;
    }
161 162
    case BtCancelMessage::ID: {
      auto m = BtCancelMessage::create(data, dataLength);
163 164 165
      if (!metadataGetMode_) {
        m->setBtMessageValidator(make_unique<RangeBtMessageValidator>(
            static_cast<BtCancelMessage*>(m.get()),
166
            downloadContext_->getNumPieces(),
167
            pieceStorage_->getPieceLength(m->getIndex())));
168
      }
169 170 171 172 173 174 175 176 177 178 179 180 181 182
      msg = std::move(m);
      break;
    }
    case BtPortMessage::ID: {
      auto m = BtPortMessage::create(data, dataLength);
      m->setLocalNode(localNode_);
      m->setRoutingTable(routingTable_);
      m->setTaskQueue(taskQueue_);
      m->setTaskFactory(taskFactory_);
      msg = std::move(m);
      break;
    }
    case BtSuggestPieceMessage::ID: {
      auto m = BtSuggestPieceMessage::create(data, dataLength);
183 184 185 186
      if (!metadataGetMode_) {
        m->setBtMessageValidator(make_unique<IndexBtMessageValidator>(
            static_cast<BtSuggestPieceMessage*>(m.get()),
            downloadContext_->getNumPieces()));
187 188
      }
      msg = std::move(m);
189 190 191
      break;
    }
    case BtHaveAllMessage::ID:
192
      msg = BtHaveAllMessage::create(data, dataLength);
193 194
      break;
    case BtHaveNoneMessage::ID:
195
      msg = BtHaveNoneMessage::create(data, dataLength);
196 197
      break;
    case BtRejectMessage::ID: {
198
      auto m = BtRejectMessage::create(data, dataLength);
199 200 201
      if (!metadataGetMode_) {
        m->setBtMessageValidator(make_unique<RangeBtMessageValidator>(
            static_cast<BtRejectMessage*>(m.get()),
202
            downloadContext_->getNumPieces(),
203
            pieceStorage_->getPieceLength(m->getIndex())));
204
      }
205
      msg = std::move(m);
206 207 208
      break;
    }
    case BtAllowedFastMessage::ID: {
209
      auto m = BtAllowedFastMessage::create(data, dataLength);
210 211 212 213
      if (!metadataGetMode_) {
        m->setBtMessageValidator(make_unique<IndexBtMessageValidator>(
            static_cast<BtAllowedFastMessage*>(m.get()),
            downloadContext_->getNumPieces()));
214
      }
215
      msg = std::move(m);
216 217 218
      break;
    }
    case BtExtendedMessage::ID: {
219 220 221 222 223
      if (peer_->isExtendedMessagingEnabled()) {
        msg = BtExtendedMessage::create(extensionMessageFactory_, peer_, data,
                                        dataLength);
      }
      else {
224 225 226 227 228 229 230 231 232
        throw DL_ABORT_EX("Received extended message from peer during"
                          " a session with extended messaging disabled.");
      }
      break;
    }
    default:
      throw DL_ABORT_EX(fmt("Invalid message ID. id=%u", id));
    }
  }
233
  setCommonProperty(msg.get());
234
  return std::move(msg);
235 236
}

237 238
void DefaultBtMessageFactory::setCommonProperty(AbstractBtMessage* msg)
{
239 240 241 242 243 244 245
  msg->setCuid(cuid_);
  msg->setPeer(peer_);
  msg->setPieceStorage(pieceStorage_);
  msg->setBtMessageDispatcher(dispatcher_);
  msg->setBtRequestFactory(requestFactory_);
  msg->setBtMessageFactory(this);
  msg->setPeerConnection(peerConnection_);
246
  if (metadataGetMode_) {
247 248 249 250
    msg->enableMetadataGetMode();
  }
}

251
std::unique_ptr<BtHandshakeMessage>
252 253
DefaultBtMessageFactory::createHandshakeMessage(const unsigned char* data,
                                                size_t dataLength)
254
{
255
  auto msg = BtHandshakeMessage::create(data, dataLength);
256 257
  msg->setBtMessageValidator(make_unique<BtHandshakeMessageValidator>(
      msg.get(), bittorrent::getInfoHash(downloadContext_)));
258
  setCommonProperty(msg.get());
259 260 261
  return msg;
}

262
std::unique_ptr<BtHandshakeMessage>
263 264 265
DefaultBtMessageFactory::createHandshakeMessage(const unsigned char* infoHash,
                                                const unsigned char* peerId)
{
266
  auto msg = make_unique<BtHandshakeMessage>(infoHash, peerId);
267
  msg->setDHTEnabled(dhtEnabled_);
268
  setCommonProperty(msg.get());
269 270 271
  return msg;
}

272 273
std::unique_ptr<BtRequestMessage> DefaultBtMessageFactory::createRequestMessage(
    const std::shared_ptr<Piece>& piece, size_t blockIndex)
274
{
275 276 277
  auto msg = make_unique<BtRequestMessage>(
      piece->getIndex(), blockIndex * piece->getBlockLength(),
      piece->getBlockLength(blockIndex), blockIndex);
278 279
  setCommonProperty(msg.get());
  return msg;
280 281
}

282
std::unique_ptr<BtCancelMessage>
283 284
DefaultBtMessageFactory::createCancelMessage(size_t index, int32_t begin,
                                             int32_t length)
285
{
286 287 288
  auto msg = make_unique<BtCancelMessage>(index, begin, length);
  setCommonProperty(msg.get());
  return msg;
289 290
}

291
std::unique_ptr<BtPieceMessage>
292 293
DefaultBtMessageFactory::createPieceMessage(size_t index, int32_t begin,
                                            int32_t length)
294
{
295
  auto msg = make_unique<BtPieceMessage>(index, begin, length);
296
  msg->setDownloadContext(downloadContext_);
297 298
  setCommonProperty(msg.get());
  return msg;
299 300
}

301
std::unique_ptr<BtHaveMessage>
302 303
DefaultBtMessageFactory::createHaveMessage(size_t index)
{
304 305 306
  auto msg = make_unique<BtHaveMessage>(index);
  setCommonProperty(msg.get());
  return msg;
307 308
}

309
std::unique_ptr<BtChokeMessage> DefaultBtMessageFactory::createChokeMessage()
310
{
311 312 313
  auto msg = make_unique<BtChokeMessage>();
  setCommonProperty(msg.get());
  return msg;
314 315
}

316
std::unique_ptr<BtUnchokeMessage>
317 318
DefaultBtMessageFactory::createUnchokeMessage()
{
319 320 321
  auto msg = make_unique<BtUnchokeMessage>();
  setCommonProperty(msg.get());
  return msg;
322
}
323

324
std::unique_ptr<BtInterestedMessage>
325 326
DefaultBtMessageFactory::createInterestedMessage()
{
327 328 329
  auto msg = make_unique<BtInterestedMessage>();
  setCommonProperty(msg.get());
  return msg;
330 331
}

332
std::unique_ptr<BtNotInterestedMessage>
333 334
DefaultBtMessageFactory::createNotInterestedMessage()
{
335 336 337
  auto msg = make_unique<BtNotInterestedMessage>();
  setCommonProperty(msg.get());
  return msg;
338 339
}

340
std::unique_ptr<BtBitfieldMessage>
341 342
DefaultBtMessageFactory::createBitfieldMessage()
{
343 344
  auto msg = make_unique<BtBitfieldMessage>(pieceStorage_->getBitfield(),
                                            pieceStorage_->getBitfieldLength());
345 346
  setCommonProperty(msg.get());
  return msg;
347 348
}

349
std::unique_ptr<BtKeepAliveMessage>
350 351
DefaultBtMessageFactory::createKeepAliveMessage()
{
352 353 354
  auto msg = make_unique<BtKeepAliveMessage>();
  setCommonProperty(msg.get());
  return msg;
355
}
356

357
std::unique_ptr<BtHaveAllMessage>
358 359
DefaultBtMessageFactory::createHaveAllMessage()
{
360 361 362
  auto msg = make_unique<BtHaveAllMessage>();
  setCommonProperty(msg.get());
  return msg;
363 364
}

365
std::unique_ptr<BtHaveNoneMessage>
366 367
DefaultBtMessageFactory::createHaveNoneMessage()
{
368 369 370
  auto msg = make_unique<BtHaveNoneMessage>();
  setCommonProperty(msg.get());
  return msg;
371 372
}

373
std::unique_ptr<BtRejectMessage>
374 375
DefaultBtMessageFactory::createRejectMessage(size_t index, int32_t begin,
                                             int32_t length)
376
{
377 378 379
  auto msg = make_unique<BtRejectMessage>(index, begin, length);
  setCommonProperty(msg.get());
  return msg;
380 381
}

382
std::unique_ptr<BtAllowedFastMessage>
383 384
DefaultBtMessageFactory::createAllowedFastMessage(size_t index)
{
385 386 387
  auto msg = make_unique<BtAllowedFastMessage>(index);
  setCommonProperty(msg.get());
  return msg;
388 389
}

390
std::unique_ptr<BtPortMessage>
391 392
DefaultBtMessageFactory::createPortMessage(uint16_t port)
{
393 394 395
  auto msg = make_unique<BtPortMessage>(port);
  setCommonProperty(msg.get());
  return msg;
396 397
}

398
std::unique_ptr<BtExtendedMessage>
399 400
DefaultBtMessageFactory::createBtExtendedMessage(
    std::unique_ptr<ExtensionMessage> exmsg)
401
{
402 403 404
  auto msg = make_unique<BtExtendedMessage>(std::move(exmsg));
  setCommonProperty(msg.get());
  return msg;
405 406 407 408 409 410 411 412 413 414 415 416
}

// Records the DHT task queue that createBtMessage() hands to parsed port
// messages. The raw pointer is stored as-is; this file never deletes it.
void DefaultBtMessageFactory::setTaskQueue(DHTTaskQueue* taskQueue)
{
  this->taskQueue_ = taskQueue;
}

// Records the DHT task factory that createBtMessage() hands to parsed port
// messages. The raw pointer is stored as-is; this file never deletes it.
void DefaultBtMessageFactory::setTaskFactory(DHTTaskFactory* taskFactory)
{
  this->taskFactory_ = taskFactory;
}

417
void DefaultBtMessageFactory::setPeer(const std::shared_ptr<Peer>& peer)
418 419 420 421
{
  peer_ = peer;
}

422 423
void DefaultBtMessageFactory::setDownloadContext(
    DownloadContext* downloadContext)
424 425 426 427
{
  downloadContext_ = downloadContext;
}

428
void DefaultBtMessageFactory::setPieceStorage(PieceStorage* pieceStorage)
429 430 431 432
{
  pieceStorage_ = pieceStorage;
}

433
void DefaultBtMessageFactory::setPeerStorage(PeerStorage* peerStorage)
434 435 436 437
{
  peerStorage_ = peerStorage;
}

438 439
void DefaultBtMessageFactory::setBtMessageDispatcher(
    BtMessageDispatcher* dispatcher)
440 441 442 443
{
  dispatcher_ = dispatcher;
}

444 445
void DefaultBtMessageFactory::setExtensionMessageFactory(
    ExtensionMessageFactory* factory)
446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470
{
  extensionMessageFactory_ = factory;
}

// Records the local DHT node that createBtMessage() hands to parsed port
// messages. The raw pointer is stored as-is; this file never deletes it.
void DefaultBtMessageFactory::setLocalNode(DHTNode* localNode)
{
  this->localNode_ = localNode;
}

// Records the DHT routing table that createBtMessage() hands to parsed
// port messages. The raw pointer is stored as-is; this file never deletes
// it.
void DefaultBtMessageFactory::setRoutingTable(DHTRoutingTable* routingTable)
{
  this->routingTable_ = routingTable;
}

// Records the request factory that setCommonProperty() attaches to every
// message. The raw pointer is stored as-is; this file never deletes it.
void DefaultBtMessageFactory::setBtRequestFactory(BtRequestFactory* factory)
{
  this->requestFactory_ = factory;
}

// Records the peer connection that setCommonProperty() attaches to every
// message. The raw pointer is stored as-is; this file never deletes it.
void DefaultBtMessageFactory::setPeerConnection(PeerConnection* connection)
{
  this->peerConnection_ = connection;
}

} // namespace aria2