SelectEventPoll.cc 9.42 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
/* <!-- copyright */
/*
 * aria2 - The high speed download utility
 *
 * Copyright (C) 2009 Tatsuhiro Tsujikawa
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 *
 * In addition, as a special exception, the copyright holders give
 * permission to link the code of portions of this program with the
 * OpenSSL library under certain conditions as described in each
 * individual source file, and distribute linked combinations
 * including the two.
 * You must obey the GNU General Public License in all respects
 * for all of the code used other than OpenSSL.  If you modify
 * file(s) with this exception, you may extend this exception to your
 * version of the file(s), but you are not obligated to do so.  If you
 * do not wish to do so, delete this exception statement from your
 * version.  If you delete this exception statement from all source
 * files in the program, then also delete it here.
 */
/* copyright --> */
#include "SelectEventPoll.h"

#ifdef __MINGW32__
38
#include <cassert>
39 40 41 42 43 44 45 46 47 48 49 50 51 52
#endif // __MINGW32__
#include <cstring>
#include <algorithm>
#include <numeric>

#include "Command.h"
#include "LogFactory.h"
#include "Logger.h"
#include "a2functional.h"
#include "fmt.h"
#include "util.h"

namespace aria2 {

53 54 55 56
// Pairs |command| with the event flags (|events|) it is registered for.
// Both values are stored as-is; nothing else is done here.
SelectEventPoll::CommandEvent::CommandEvent(Command* command, int events)
    : command_(command), events_(events)
{
}
57 58 59

void SelectEventPoll::CommandEvent::processEvents(int events)
{
60 61
  if ((events_ & events) ||
      ((EventPoll::EVENT_ERROR | EventPoll::EVENT_HUP) & events)) {
62 63
    command_->setStatusActive();
  }
64
  if (EventPoll::EVENT_READ & events) {
65 66
    command_->readEventReceived();
  }
67
  if (EventPoll::EVENT_WRITE & events) {
68 69
    command_->writeEventReceived();
  }
70
  if (EventPoll::EVENT_ERROR & events) {
71 72
    command_->errorEventReceived();
  }
73
  if (EventPoll::EVENT_HUP & events) {
74 75 76 77
    command_->hupEventReceived();
  }
}

78
// Creates an entry for |socket|; command events are added later through
// addCommandEvent().
SelectEventPoll::SocketEntry::SocketEntry(sock_t socket) : socket_(socket) {}
79

80
// Registers |events| for |command| on this socket.
//
// If an equal CommandEvent (as decided by its operator==) is already stored,
// |events| is merged into it; otherwise a new CommandEvent is appended.
void SelectEventPoll::SocketEntry::addCommandEvent(Command* command, int events)
{
  CommandEvent wanted(command, events);
  auto pos = std::find(commandEvents_.begin(), commandEvents_.end(), wanted);
  if (pos != commandEvents_.end()) {
    pos->addEvents(events);
    return;
  }
  commandEvents_.push_back(wanted);
}
91 92
// Unregisters |events| for |command| on this socket.
//
// Nothing happens when no matching CommandEvent is stored. Otherwise the
// given bits are removed from the stored entry, and the entry itself is
// dropped once it has no events left.
void SelectEventPoll::SocketEntry::removeCommandEvent(Command* command,
                                                      int events)
{
  CommandEvent target(command, events);
  auto pos = std::find(commandEvents_.begin(), commandEvents_.end(), target);
  if (pos == commandEvents_.end()) {
    // not found
    return;
  }

  pos->removeEvents(events);
  if (pos->eventsEmpty()) {
    commandEvents_.erase(pos);
  }
}
void SelectEventPoll::SocketEntry::processEvents(int events)
{
108
  using namespace std::placeholders;
109
  std::for_each(commandEvents_.begin(), commandEvents_.end(),
110
                std::bind(&CommandEvent::processEvents, _1, events));
111 112 113 114
}

// Fold step used when combining event masks: returns |events| with the bits
// of |event| OR-ed in.
int accumulateEvent(int events, const SelectEventPoll::CommandEvent& event)
{
  const int merged = events | event.getEvents();
  return merged;
}

int SelectEventPoll::SocketEntry::getEvents()
{
120 121
  return std::accumulate(commandEvents_.begin(), commandEvents_.end(), 0,
                         accumulateEvent);
122 123 124 125
}

#ifdef ENABLE_ASYNC_DNS

126 127 128 129 130
// Associates |nameResolver| with the |command| that should be woken up by
// process() once resolution finishes. The shared_ptr is copied, so this
// entry shares ownership of the resolver; |command| is stored as a raw
// pointer.
SelectEventPoll::AsyncNameResolverEntry::AsyncNameResolverEntry(
    const std::shared_ptr<AsyncNameResolver>& nameResolver, Command* command)
    : nameResolver_(nameResolver), command_(command)
{
}
131

132 133
// Lets the resolver add its descriptors to |rfdsPtr| / |wfdsPtr| and returns
// the resolver's result unchanged. poll() compares that result against
// fdmax_, so it is presumably the highest descriptor the resolver uses --
// confirm against AsyncNameResolver::getFds.
int SelectEventPoll::AsyncNameResolverEntry::getFds(fd_set* rfdsPtr,
                                                    fd_set* wfdsPtr)
{
  const int resolverFd = nameResolver_->getFds(rfdsPtr, wfdsPtr);
  return resolverFd;
}

138 139
// Hands the select() results in |rfdsPtr| / |wfdsPtr| to the resolver and,
// if the resolver has then reached a final state (STATUS_SUCCESS or
// STATUS_ERROR), marks the associated command active. Any other status
// leaves the command untouched.
void SelectEventPoll::AsyncNameResolverEntry::process(fd_set* rfdsPtr,
                                                      fd_set* wfdsPtr)
{
  nameResolver_->process(rfdsPtr, wfdsPtr);

  const auto status = nameResolver_->getStatus();
  const bool finished = status == AsyncNameResolver::STATUS_SUCCESS ||
                        status == AsyncNameResolver::STATUS_ERROR;
  if (finished) {
    command_->setStatusActive();
  }
}

#endif // ENABLE_ASYNC_DNS

// Builds the poller with empty registrations and initialises the cached fd
// sets via updateFdSet().
//
// On MinGW a TCP socket is created up front; updateFdSet() always places it
// in both fd sets. NOTE(review): presumably this keeps the sets passed to
// select() from ever being empty on Windows -- confirm.
SelectEventPoll::SelectEventPoll()
{
#ifdef __MINGW32__
  dummySocket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
  // Creation failure is only caught by this assertion.
  assert(dummySocket_ != static_cast<sock_t>(-1));
#endif // __MINGW32__
  updateFdSet();
}

// Releases the dummy socket created by the constructor on MinGW builds;
// does nothing on other platforms.
SelectEventPoll::~SelectEventPoll()
{
#ifdef __MINGW32__
  ::closesocket(dummySocket_);
#endif // __MINGW32__
}

void SelectEventPoll::poll(const struct timeval& tv)
{
  fd_set rfds;
  fd_set wfds;
174

175 176
  memcpy(&rfds, &rfdset_, sizeof(fd_set));
  memcpy(&wfds, &wfdset_, sizeof(fd_set));
177

178 179
#ifdef __MINGW32__
  fd_set efds;
180
  memcpy(&efds, &wfdset_, sizeof(fd_set));
181
#endif // __MINGW32__
182

183 184
#ifdef ENABLE_ASYNC_DNS

185 186 187
  for (auto& i : nameResolverEntries_) {
    auto& entry = i.second;
    int fd = entry.getFds(&rfds, &wfds);
188
    // TODO force error if fd == 0
189
    if (fdmax_ < fd) {
190 191 192 193 194 195 196 197 198
      fdmax_ = fd;
    }
  }

#endif // ENABLE_ASYNC_DNS
  int retval;
  do {
    struct timeval ttv = tv;
#ifdef __MINGW32__
199 200
    // winsock will report non-blocking connect() errors in efds,
    // unlike posix, which will mark such sockets as writable.
201
    retval = select(fdmax_ + 1, &rfds, &wfds, &efds, &ttv);
202
#else  // !__MINGW32__
203
    retval = select(fdmax_ + 1, &rfds, &wfds, nullptr, &ttv);
204
#endif // !__MINGW32__
205 206 207 208
  } while (retval == -1 && errno == EINTR);
  if (retval > 0) {
    for (auto& i : socketEntries_) {
      auto& e = i.second;
209
      int events = 0;
210
      if (FD_ISSET(e.getSocket(), &rfds)) {
211 212
        events |= EventPoll::EVENT_READ;
      }
213
      if (FD_ISSET(e.getSocket(), &wfds)) {
214 215
        events |= EventPoll::EVENT_WRITE;
      }
216 217 218 219 220
#ifdef __MINGW32__
      if (FD_ISSET(e.getSocket(), &efds)) {
        events |= EventPoll::EVENT_ERROR;
      }
#endif // __MINGW32__
221
      e.processEvents(events);
222
    }
223 224
  }
  else if (retval == -1) {
225
    int errNum = errno;
226 227
    A2_LOG_INFO(fmt("select error: %s, fdmax: %d",
                    util::safeStrerror(errNum).c_str(), fdmax_));
228 229 230
  }
#ifdef ENABLE_ASYNC_DNS

231 232
  for (auto& i : nameResolverEntries_) {
    i.second.process(&rfds, &wfds);
233 234 235 236 237 238 239 240 241
  }

#endif // ENABLE_ASYNC_DNS
}

#ifdef __MINGW32__
namespace {
void checkFdCountMingw(const fd_set& fdset)
{
242
  if (fdset.fd_count >= FD_SETSIZE) {
243 244 245 246 247 248 249 250 251
    A2_LOG_WARN("The number of file descriptor exceeded FD_SETSIZE. "
                "Download may slow down or fail.");
  }
}
} // namespace
#endif // __MINGW32__

void SelectEventPoll::updateFdSet()
{
252 253
  FD_ZERO(&rfdset_);
  FD_ZERO(&wfdset_);
254
#ifdef __MINGW32__
255 256
  FD_SET(dummySocket_, &rfdset_);
  FD_SET(dummySocket_, &wfdset_);
257
  fdmax_ = dummySocket_;
258
#else  // !__MINGW32__
259 260
  fdmax_ = 0;
#endif // !__MINGW32__
261

262 263 264
  for (auto& i : socketEntries_) {
    auto& e = i.second;
    sock_t fd = e.getSocket();
265
#ifndef __MINGW32__
266
    if (fd < 0 || FD_SETSIZE <= fd) {
267 268 269 270 271
      A2_LOG_WARN("Detected file descriptor >= FD_SETSIZE or < 0. "
                  "Download may slow down or fail.");
      continue;
    }
#endif // !__MINGW32__
272 273
    int events = e.getEvents();
    if (events & EventPoll::EVENT_READ) {
274 275 276 277 278
#ifdef __MINGW32__
      checkFdCountMingw(rfdset_);
#endif // __MINGW32__
      FD_SET(fd, &rfdset_);
    }
279
    if (events & EventPoll::EVENT_WRITE) {
280 281 282 283 284
#ifdef __MINGW32__
      checkFdCountMingw(wfdset_);
#endif // __MINGW32__
      FD_SET(fd, &wfdset_);
    }
285
    if (fdmax_ < fd) {
286 287 288 289 290 291 292 293
      fdmax_ = fd;
    }
  }
}

// Registers |events| for |command| on |socket|.
//
// A SocketEntry is created for |socket| if none exists yet, the command
// event is added to it, and the cached fd sets are rebuilt. Always returns
// true.
bool SelectEventPoll::addEvents(sock_t socket, Command* command,
                                EventPoll::EventType events)
{
  auto pos = socketEntries_.lower_bound(socket);
  const bool known = pos != std::end(socketEntries_) && pos->first == socket;
  if (!known) {
    // lower_bound's result doubles as the insertion hint.
    pos = socketEntries_.insert(pos,
                                std::make_pair(socket, SocketEntry(socket)));
  }
  pos->second.addCommandEvent(command, events);

  updateFdSet();
  return true;
}

// Unregisters |events| for |command| on |socket|.
//
// Returns false (after a debug log) when |socket| has no entry. Otherwise
// the events are removed, the SocketEntry is dropped once it has no events
// left, the cached fd sets are rebuilt, and true is returned.
bool SelectEventPoll::deleteEvents(sock_t socket, Command* command,
                                   EventPoll::EventType events)
{
  auto pos = socketEntries_.find(socket);
  if (pos == std::end(socketEntries_)) {
    A2_LOG_DEBUG(fmt("Socket %d is not found in SocketEntries.", socket));
    return false;
  }

  auto& entry = pos->second;
  entry.removeCommandEvent(command, events);
  const bool drained = entry.eventEmpty();
  if (drained) {
    // |entry| dangles after this erase; it is not used again.
    socketEntries_.erase(pos);
  }

  updateFdSet();
  return true;
}

#ifdef ENABLE_ASYNC_DNS
325 326
bool SelectEventPoll::addNameResolver(
    const std::shared_ptr<AsyncNameResolver>& resolver, Command* command)
327
{
328 329 330
  auto key = std::make_pair(resolver.get(), command);
  auto itr = nameResolverEntries_.lower_bound(key);
  if (itr != std::end(nameResolverEntries_) && (*itr).first == key) {
331 332
    return false;
  }
333 334 335 336 337

  nameResolverEntries_.insert(
      itr, std::make_pair(key, AsyncNameResolverEntry(resolver, command)));

  return true;
338 339
}

340 341
bool SelectEventPoll::deleteNameResolver(
    const std::shared_ptr<AsyncNameResolver>& resolver, Command* command)
342
{
343 344
  auto key = std::make_pair(resolver.get(), command);
  return nameResolverEntries_.erase(key) == 1;
345 346 347 348
}
#endif // ENABLE_ASYNC_DNS

} // namespace aria2