testing/mochitest/ssltunnel/ssltunnel.cpp

changeset 0
6474c204b198
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/testing/mochitest/ssltunnel/ssltunnel.cpp	Wed Dec 31 06:09:35 2014 +0100
     1.3 @@ -0,0 +1,1394 @@
     1.4 +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
     1.5 +/* This Source Code Form is subject to the terms of the Mozilla Public
     1.6 + * License, v. 2.0. If a copy of the MPL was not distributed with this
     1.7 + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
     1.8 +
     1.9 +/*
    1.10 + * WARNING: DO NOT USE THIS CODE IN PRODUCTION SYSTEMS.  It is highly likely to
    1.11 + *          be plagued with the usual problems endemic to C (buffer overflows
    1.12 + *          and the like).  We don't especially care here (but would accept
    1.13 + *          patches!) because this is only intended for use in our test
    1.14 + *          harnesses in controlled situations where input is guaranteed not to
    1.15 + *          be malicious.
    1.16 + */
    1.17 +
    1.18 +#include "ScopedNSSTypes.h"
    1.19 +#include <assert.h>
    1.20 +#include <stdio.h>
    1.21 +#include <string>
    1.22 +#include <vector>
    1.23 +#include <algorithm>
    1.24 +#include <stdarg.h>
    1.25 +#include "prinit.h"
    1.26 +#include "prerror.h"
    1.27 +#include "prenv.h"
    1.28 +#include "prnetdb.h"
    1.29 +#include "prtpool.h"
    1.30 +#include "nsAlgorithm.h"
    1.31 +#include "nss.h"
    1.32 +#include "key.h"
    1.33 +#include "ssl.h"
    1.34 +#include "plhash.h"
    1.35 +
    1.36 +using namespace mozilla;
    1.37 +using namespace mozilla::psm;
    1.38 +using std::string;
    1.39 +using std::vector;
    1.40 +
    1.41 +#define IS_DELIM(m, c)          ((m)[(c) >> 3] & (1 << ((c) & 7)))
    1.42 +#define SET_DELIM(m, c)         ((m)[(c) >> 3] |= (1 << ((c) & 7)))
    1.43 +#define DELIM_TABLE_SIZE        32
    1.44 +
    1.45 +// You can set the level of logging by env var SSLTUNNEL_LOG_LEVEL=n, where n
    1.46 +// is 0 through 3.  The default is 1, INFO level logging.
    1.47 +enum LogLevel {
    1.48 +  LEVEL_DEBUG = 0,
    1.49 +  LEVEL_INFO = 1,
    1.50 +  LEVEL_ERROR = 2,
    1.51 +  LEVEL_SILENT = 3
    1.52 +} gLogLevel, gLastLogLevel;
    1.53 +
    1.54 +#define _LOG_OUTPUT(level, func, params) \
    1.55 +PR_BEGIN_MACRO \
    1.56 +  if (level >= gLogLevel) { \
    1.57 +    gLastLogLevel = level; \
    1.58 +    func params;\
    1.59 +  } \
    1.60 +PR_END_MACRO
    1.61 +
    1.62 +// The most verbose output
    1.63 +#define LOG_DEBUG(params) \
    1.64 +  _LOG_OUTPUT(LEVEL_DEBUG, printf, params)
    1.65 +
    1.66 +// Top level informative messages
    1.67 +#define LOG_INFO(params) \
    1.68 +  _LOG_OUTPUT(LEVEL_INFO, printf, params)
    1.69 +
    1.70 +// Serious errors that must be logged always until completely gag
    1.71 +#define LOG_ERROR(params) \
    1.72 +  _LOG_OUTPUT(LEVEL_ERROR, eprintf, params)
    1.73 +
    1.74 +// Same as LOG_ERROR, but when logging is set to LEVEL_DEBUG, the message 
    1.75 +// will be put to the stdout instead of stderr to keep continuity with other 
    1.76 +// LOG_DEBUG message output
    1.77 +#define LOG_ERRORD(params) \
    1.78 +PR_BEGIN_MACRO \
    1.79 +  if (gLogLevel == LEVEL_DEBUG) \
    1.80 +    _LOG_OUTPUT(LEVEL_ERROR, printf, params); \
    1.81 +  else \
    1.82 +    _LOG_OUTPUT(LEVEL_ERROR, eprintf, params); \
    1.83 +PR_END_MACRO
    1.84 +
    1.85 +// If there is any output written between LOG_BEGIN_BLOCK() and
    1.86 +// LOG_END_BLOCK() then a new line will be put to the proper output (out/err)
    1.87 +#define LOG_BEGIN_BLOCK() \
    1.88 +  gLastLogLevel = LEVEL_SILENT;
    1.89 +
    1.90 +#define LOG_END_BLOCK() \
    1.91 +PR_BEGIN_MACRO \
    1.92 +  if (gLastLogLevel == LEVEL_ERROR) \
    1.93 +    LOG_ERROR(("\n")); \
    1.94 +  if (gLastLogLevel < LEVEL_ERROR) \
    1.95 +    _LOG_OUTPUT(gLastLogLevel, printf, ("\n")); \
    1.96 +PR_END_MACRO
    1.97 +
    1.98 +int eprintf(const char* str, ...)
    1.99 +{
   1.100 +  va_list ap;
   1.101 +  va_start(ap, str);
   1.102 +  int result = vfprintf(stderr, str, ap);
   1.103 +  va_end(ap);
   1.104 +  return result;
   1.105 +}
   1.106 +
   1.107 +// Copied from nsCRT
   1.108 +char* strtok2(char* string, const char* delims, char* *newStr)
   1.109 +{
   1.110 +  PR_ASSERT(string);
   1.111 +  
   1.112 +  char delimTable[DELIM_TABLE_SIZE];
   1.113 +  uint32_t i;
   1.114 +  char* result;
   1.115 +  char* str = string;
   1.116 +  
   1.117 +  for (i = 0; i < DELIM_TABLE_SIZE; i++)
   1.118 +    delimTable[i] = '\0';
   1.119 +  
   1.120 +  for (i = 0; delims[i]; i++) {
   1.121 +    SET_DELIM(delimTable, static_cast<uint8_t>(delims[i]));
   1.122 +  }
   1.123 +  
   1.124 +  // skip to beginning
   1.125 +  while (*str && IS_DELIM(delimTable, static_cast<uint8_t>(*str))) {
   1.126 +    str++;
   1.127 +  }
   1.128 +  result = str;
   1.129 +  
   1.130 +  // fix up the end of the token
   1.131 +  while (*str) {
   1.132 +    if (IS_DELIM(delimTable, static_cast<uint8_t>(*str))) {
   1.133 +      *str++ = '\0';
   1.134 +      break;
   1.135 +    }
   1.136 +    str++;
   1.137 +  }
   1.138 +  *newStr = str;
   1.139 +  
   1.140 +  return str == result ? nullptr : result;
   1.141 +}
   1.142 +
   1.143 +
   1.144 +
   1.145 +enum client_auth_option {
   1.146 +  caNone = 0,
   1.147 +  caRequire = 1,
   1.148 +  caRequest = 2
   1.149 +};
   1.150 +
   1.151 +// Structs for passing data into jobs on the thread pool
   1.152 +typedef struct {
   1.153 +  int32_t listen_port;
   1.154 +  string cert_nickname;
   1.155 +  PLHashTable* host_cert_table;
   1.156 +  PLHashTable* host_clientauth_table;
   1.157 +  PLHashTable* host_redir_table;
   1.158 +} server_info_t;
   1.159 +
   1.160 +typedef struct {
   1.161 +  PRFileDesc* client_sock;
   1.162 +  PRNetAddr client_addr;
   1.163 +  server_info_t* server_info;
   1.164 +  // the original host in the Host: header for this connection is
   1.165 +  // stored here, for proxied connections
   1.166 +  string original_host;
   1.167 +  // true if no SSL should be used for this connection
   1.168 +  bool http_proxy_only;
   1.169 +  // true if this connection is for a WebSocket
   1.170 +  bool iswebsocket;
   1.171 +} connection_info_t;
   1.172 +
   1.173 +typedef struct {
   1.174 +  string fullHost;
   1.175 +  bool matched;
   1.176 +} server_match_t;
   1.177 +
   1.178 +const int32_t BUF_SIZE = 16384;
   1.179 +const int32_t BUF_MARGIN = 1024;
   1.180 +const int32_t BUF_TOTAL = BUF_SIZE + BUF_MARGIN;
   1.181 +
   1.182 +struct relayBuffer
   1.183 +{
   1.184 +  char *buffer, *bufferhead, *buffertail, *bufferend;
   1.185 +
   1.186 +  relayBuffer()
   1.187 +  {
   1.188 +    // Leave 1024 bytes more for request line manipulations
   1.189 +    bufferhead = buffertail = buffer = new char[BUF_TOTAL];
   1.190 +    bufferend = buffer + BUF_SIZE;
   1.191 +  }
   1.192 +
   1.193 +  ~relayBuffer()
   1.194 +  {
   1.195 +    delete [] buffer;
   1.196 +  }
   1.197 +
   1.198 +  void compact() {
   1.199 +    if (buffertail == bufferhead)
   1.200 +      buffertail = bufferhead = buffer;
   1.201 +  }
   1.202 +
   1.203 +  bool empty() { return bufferhead == buffertail; }
   1.204 +  size_t areafree() { return bufferend - buffertail; }
   1.205 +  size_t margin() { return areafree() + BUF_MARGIN; }
   1.206 +  size_t present() { return buffertail - bufferhead; }
   1.207 +};
   1.208 +
   1.209 +// These numbers are multiplied by the number of listening ports (actual
   1.210 +// servers running).  According the thread pool implementation there is no
   1.211 +// need to limit the number of threads initially, threads are allocated
   1.212 +// dynamically and stored in a linked list.  Initial number of 2 is chosen
   1.213 +// to allocate a thread for socket accept and preallocate one for the first
   1.214 +// connection that is with high probability expected to come.
   1.215 +const uint32_t INITIAL_THREADS = 2;
   1.216 +const uint32_t MAX_THREADS = 100;
   1.217 +const uint32_t DEFAULT_STACKSIZE = (512 * 1024);
   1.218 +
   1.219 +// global data
   1.220 +string nssconfigdir;
   1.221 +vector<server_info_t> servers;
   1.222 +PRNetAddr remote_addr;
   1.223 +PRNetAddr websocket_server;
   1.224 +PRThreadPool* threads = nullptr;
   1.225 +PRLock* shutdown_lock = nullptr;
   1.226 +PRCondVar* shutdown_condvar = nullptr;
   1.227 +// Not really used, unless something fails to start
   1.228 +bool shutdown_server = false;
   1.229 +bool do_http_proxy = false;
   1.230 +bool any_host_spec_config = false;
   1.231 +
   1.232 +int ClientAuthValueComparator(const void *v1, const void *v2)
   1.233 +{
   1.234 +  int a = *static_cast<const client_auth_option*>(v1) -
   1.235 +          *static_cast<const client_auth_option*>(v2);
   1.236 +  if (a == 0)
   1.237 +    return 0;
   1.238 +  if (a > 0)
   1.239 +    return 1;
   1.240 +  else // (a < 0)
   1.241 +    return -1;
   1.242 +}
   1.243 +
   1.244 +static int match_hostname(PLHashEntry *he, int index, void* arg)
   1.245 +{
   1.246 +  server_match_t *match = (server_match_t*)arg;
   1.247 +  if (match->fullHost.find((char*)he->key) != string::npos)
   1.248 +    match->matched = true;
   1.249 +  return HT_ENUMERATE_NEXT;
   1.250 +}
   1.251 +
   1.252 +/*
   1.253 + * Signal the main thread that the application should shut down.
   1.254 + */
   1.255 +void SignalShutdown()
   1.256 +{
   1.257 +  PR_Lock(shutdown_lock);
   1.258 +  PR_NotifyCondVar(shutdown_condvar);
   1.259 +  PR_Unlock(shutdown_lock);
   1.260 +}
   1.261 +
   1.262 +bool ReadConnectRequest(server_info_t* server_info, 
   1.263 +    relayBuffer& buffer, int32_t* result, string& certificate,
   1.264 +    client_auth_option* clientauth, string& host, string& location)
   1.265 +{
   1.266 +  if (buffer.present() < 4) {
   1.267 +    LOG_DEBUG((" !! only %d bytes present in the buffer", (int)buffer.present()));
   1.268 +    return false;
   1.269 +  }
   1.270 +  if (strncmp(buffer.buffertail-4, "\r\n\r\n", 4)) {
   1.271 +    LOG_ERRORD((" !! request is not tailed with CRLFCRLF but with %x %x %x %x", 
   1.272 +               *(buffer.buffertail-4),
   1.273 +               *(buffer.buffertail-3),
   1.274 +               *(buffer.buffertail-2),
   1.275 +               *(buffer.buffertail-1)));
   1.276 +    return false;
   1.277 +  }
   1.278 +
   1.279 +  LOG_DEBUG((" parsing initial connect request, dump:\n%.*s\n", (int)buffer.present(), buffer.bufferhead));
   1.280 +
   1.281 +  *result = 400;
   1.282 +
   1.283 +  char* token;
   1.284 +  char* _caret;
   1.285 +  token = strtok2(buffer.bufferhead, " ", &_caret);
   1.286 +  if (!token) {
   1.287 +    LOG_ERRORD((" no space found"));
   1.288 +    return true;
   1.289 +  }
   1.290 +  if (strcmp(token, "CONNECT")) {
   1.291 +    LOG_ERRORD((" not CONNECT request but %s", token));
   1.292 +    return true;
   1.293 +  }
   1.294 +
   1.295 +  token = strtok2(_caret, " ", &_caret);
   1.296 +  void* c = PL_HashTableLookup(server_info->host_cert_table, token);
   1.297 +  if (c)
   1.298 +    certificate = static_cast<char*>(c);
   1.299 +
   1.300 +  host = "https://";
   1.301 +  host += token;
   1.302 +
   1.303 +  c = PL_HashTableLookup(server_info->host_clientauth_table, token);
   1.304 +  if (c)
   1.305 +    *clientauth = *static_cast<client_auth_option*>(c);
   1.306 +  else
   1.307 +    *clientauth = caNone;
   1.308 +
   1.309 +  void *redir = PL_HashTableLookup(server_info->host_redir_table, token);
   1.310 +  if (redir)
   1.311 +    location = static_cast<char*>(redir);
   1.312 +
   1.313 +  token = strtok2(_caret, "/", &_caret);
   1.314 +  if (strcmp(token, "HTTP")) {  
   1.315 +    LOG_ERRORD((" not tailed with HTTP but with %s", token));
   1.316 +    return true;
   1.317 +  }
   1.318 +
   1.319 +  *result = (redir) ? 302 : 200;
   1.320 +  return true;
   1.321 +}
   1.322 +
   1.323 +bool ConfigureSSLServerSocket(PRFileDesc* socket, server_info_t* si, string &certificate, client_auth_option clientAuth)
   1.324 +{
   1.325 +  const char* certnick = certificate.empty() ?
   1.326 +      si->cert_nickname.c_str() : certificate.c_str();
   1.327 +
   1.328 +  ScopedCERTCertificate cert(PK11_FindCertFromNickname(certnick, nullptr));
   1.329 +  if (!cert) {
   1.330 +    LOG_ERROR(("Failed to find cert %s\n", certnick));
   1.331 +    return false;
   1.332 +  }
   1.333 +
   1.334 +  ScopedSECKEYPrivateKey privKey(PK11_FindKeyByAnyCert(cert, nullptr));
   1.335 +  if (!privKey) {
   1.336 +    LOG_ERROR(("Failed to find private key\n"));
   1.337 +    return false;
   1.338 +  }
   1.339 +
   1.340 +  PRFileDesc* ssl_socket = SSL_ImportFD(nullptr, socket);
   1.341 +  if (!ssl_socket) {
   1.342 +    LOG_ERROR(("Error importing SSL socket\n"));
   1.343 +    return false;
   1.344 +  }
   1.345 +
   1.346 +  SSLKEAType certKEA = NSS_FindCertKEAType(cert);
   1.347 +  if (SSL_ConfigSecureServer(ssl_socket, cert, privKey, certKEA)
   1.348 +      != SECSuccess) {
   1.349 +    LOG_ERROR(("Error configuring SSL server socket\n"));
   1.350 +    return false;
   1.351 +  }
   1.352 +
   1.353 +  SSL_OptionSet(ssl_socket, SSL_SECURITY, true);
   1.354 +  SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_CLIENT, false);
   1.355 +  SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_SERVER, true);
   1.356 +
   1.357 +  if (clientAuth != caNone)
   1.358 +  {
   1.359 +    SSL_OptionSet(ssl_socket, SSL_REQUEST_CERTIFICATE, true);
   1.360 +    SSL_OptionSet(ssl_socket, SSL_REQUIRE_CERTIFICATE, clientAuth == caRequire);
   1.361 +  }
   1.362 +
   1.363 +  SSL_ResetHandshake(ssl_socket, true);
   1.364 +
   1.365 +  return true;
   1.366 +}
   1.367 +
   1.368 +/**
   1.369 + * This function examines the buffer for a Sec-WebSocket-Location: field, 
   1.370 + * and if it's present, it replaces the hostname in that field with the
   1.371 + * value in the server's original_host field.  This function works
   1.372 + * in the reverse direction as AdjustWebSocketHost(), replacing the real
   1.373 + * hostname of a response with the potentially fake hostname that is expected
   1.374 + * by the browser (e.g., mochi.test).
   1.375 + *
   1.376 + * @return true if the header was adjusted successfully, or not found, false
   1.377 + * if the header is present but the url is not, which should indicate
   1.378 + * that more data needs to be read from the socket
   1.379 + */
   1.380 +bool AdjustWebSocketLocation(relayBuffer& buffer, connection_info_t *ci)
   1.381 +{
   1.382 +  assert(buffer.margin());
   1.383 +  buffer.buffertail[1] = '\0';
   1.384 +
   1.385 +  char* wsloc = strstr(buffer.bufferhead, "Sec-WebSocket-Location:");
   1.386 +  if (!wsloc)
   1.387 +    return true;
   1.388 +  // advance pointer to the start of the hostname
   1.389 +  wsloc = strstr(wsloc, "ws://");
   1.390 +  if (!wsloc)
   1.391 +    return false;
   1.392 +  wsloc += 5;
   1.393 +  // find the end of the hostname
   1.394 +  char* wslocend = strchr(wsloc + 1, '/');
   1.395 +  if (!wslocend)
   1.396 +    return false;
   1.397 +  char *crlf = strstr(wsloc, "\r\n");
   1.398 +  if (!crlf)
   1.399 +    return false;
   1.400 +  if (ci->original_host.empty())
   1.401 +    return true;
   1.402 +
   1.403 +  int diff = ci->original_host.length() - (wslocend-wsloc);
   1.404 +  if (diff > 0)
   1.405 +    assert(size_t(diff) <= buffer.margin());
   1.406 +  memmove(wslocend + diff, wslocend, buffer.buffertail - wsloc - diff);
   1.407 +  buffer.buffertail += diff;
   1.408 +
   1.409 +  memcpy(wsloc, ci->original_host.c_str(), ci->original_host.length());
   1.410 +  return true;
   1.411 +}
   1.412 +
   1.413 +/**
   1.414 + * This function examines the buffer for a Host: field, and if it's present,
   1.415 + * it replaces the hostname in that field with the hostname in the server's
   1.416 + * remote_addr field.  This is needed because proxy requests may be coming
   1.417 + * from mochitest with fake hosts, like mochi.test, and these need to be
   1.418 + * replaced with the host that the destination server is actually running
   1.419 + * on.
   1.420 + */
   1.421 +bool AdjustWebSocketHost(relayBuffer& buffer, connection_info_t *ci)
   1.422 +{
   1.423 +  const char HEADER_UPGRADE[] = "Upgrade:";
   1.424 +  const char HEADER_HOST[] = "Host:";
   1.425 +
   1.426 +  PRNetAddr inet_addr = (websocket_server.inet.port ? websocket_server :
   1.427 +    remote_addr);
   1.428 +
   1.429 +  assert(buffer.margin());
   1.430 +
   1.431 +  // Cannot use strnchr so add a null char at the end. There is always some
   1.432 +  // space left because we preserve a margin.
   1.433 +  buffer.buffertail[1] = '\0';
   1.434 +
   1.435 +  // Verify this is a WebSocket header.
   1.436 +  char* h1 = strstr(buffer.bufferhead, HEADER_UPGRADE);
   1.437 +  if (!h1)
   1.438 +    return false;
   1.439 +  h1 += strlen(HEADER_UPGRADE);
   1.440 +  h1 += strspn(h1, " \t");
   1.441 +  char* h2 = strstr(h1, "WebSocket\r\n");
   1.442 +  if (!h2) h2 = strstr(h1, "websocket\r\n");
   1.443 +  if (!h2) h2 = strstr(h1, "Websocket\r\n");
   1.444 +  if (!h2)
   1.445 +    return false;
   1.446 +
   1.447 +  char* host = strstr(buffer.bufferhead, HEADER_HOST);
   1.448 +  if (!host)
   1.449 +    return false;
   1.450 +  // advance pointer to beginning of hostname
   1.451 +  host += strlen(HEADER_HOST);
   1.452 +  host += strspn(host, " \t");
   1.453 +
   1.454 +  char* endhost = strstr(host, "\r\n");
   1.455 +  if (!endhost)
   1.456 +    return false;
   1.457 +
   1.458 +  // Save the original host, so we can use it later on responses from the
   1.459 +  // server.
   1.460 +  ci->original_host.assign(host, endhost-host);
   1.461 +
   1.462 +  char newhost[40];
   1.463 +  PR_NetAddrToString(&inet_addr, newhost, sizeof(newhost));
   1.464 +  assert(strlen(newhost) < sizeof(newhost) - 7);
   1.465 +  sprintf(newhost, "%s:%d", newhost, PR_ntohs(inet_addr.inet.port));
   1.466 +
   1.467 +  int diff = strlen(newhost) - (endhost-host);
   1.468 +  if (diff > 0)
   1.469 +    assert(size_t(diff) <= buffer.margin());
   1.470 +  memmove(endhost + diff, endhost, buffer.buffertail - host - diff);
   1.471 +  buffer.buffertail += diff;
   1.472 +
   1.473 +  memcpy(host, newhost, strlen(newhost));
   1.474 +  return true;
   1.475 +}
   1.476 +
   1.477 +/**
   1.478 + * This function prefixes Request-URI path with a full scheme-host-port
   1.479 + * string.
   1.480 + */
   1.481 +bool AdjustRequestURI(relayBuffer& buffer, string *host)
   1.482 +{
   1.483 +  assert(buffer.margin());
   1.484 +
   1.485 +  // Cannot use strnchr so add a null char at the end. There is always some space left
   1.486 +  // because we preserve a margin.
   1.487 +  buffer.buffertail[1] = '\0';
   1.488 +  LOG_DEBUG((" incoming request to adjust:\n%s\n", buffer.bufferhead));
   1.489 +
   1.490 +  char *token, *path;
   1.491 +  path = strchr(buffer.bufferhead, ' ') + 1;
   1.492 +  if (!path)
   1.493 +    return false;
   1.494 +
   1.495 +  // If the path doesn't start with a slash don't change it, it is probably '*' or a full
   1.496 +  // path already. Return true, we are done with this request adjustment.
   1.497 +  if (*path != '/')
   1.498 +    return true;
   1.499 +
   1.500 +  token = strchr(path, ' ') + 1;
   1.501 +  if (!token)
   1.502 +    return false;
   1.503 +
   1.504 +  if (strncmp(token, "HTTP/", 5))
   1.505 +    return false;
   1.506 +
   1.507 +  size_t hostlength = host->length();
   1.508 +  assert(hostlength <= buffer.margin());
   1.509 +
   1.510 +  memmove(path + hostlength, path, buffer.buffertail - path);
   1.511 +  memcpy(path, host->c_str(), hostlength);
   1.512 +  buffer.buffertail += hostlength;
   1.513 +
   1.514 +  return true;
   1.515 +}
   1.516 +
   1.517 +bool ConnectSocket(PRFileDesc *fd, const PRNetAddr *addr, PRIntervalTime timeout)
   1.518 +{
   1.519 +  PRStatus stat = PR_Connect(fd, addr, timeout);
   1.520 +  if (stat != PR_SUCCESS)
   1.521 +    return false;
   1.522 +
   1.523 +  PRSocketOptionData option;
   1.524 +  option.option = PR_SockOpt_Nonblocking;
   1.525 +  option.value.non_blocking = true;
   1.526 +  PR_SetSocketOption(fd, &option);
   1.527 +
   1.528 +  return true;
   1.529 +}
   1.530 +
   1.531 +/*
   1.532 + * Handle an incoming client connection. The server thread has already
   1.533 + * accepted the connection, so we just need to connect to the remote
   1.534 + * port and then proxy data back and forth.
   1.535 + * The data parameter is a connection_info_t*, and must be deleted
   1.536 + * by this function.
   1.537 + */
   1.538 +void HandleConnection(void* data)
   1.539 +{
   1.540 +  connection_info_t* ci = static_cast<connection_info_t*>(data);
   1.541 +  PRIntervalTime connect_timeout = PR_SecondsToInterval(30);
   1.542 +
   1.543 +  ScopedPRFileDesc other_sock(PR_NewTCPSocket());
   1.544 +  bool client_done = false;
   1.545 +  bool client_error = false;
   1.546 +  bool connect_accepted = !do_http_proxy;
   1.547 +  bool ssl_updated = !do_http_proxy;
   1.548 +  bool expect_request_start = do_http_proxy;
   1.549 +  string certificateToUse;
   1.550 +  string locationHeader;
   1.551 +  client_auth_option clientAuth;
   1.552 +  string fullHost;
   1.553 +
   1.554 +  LOG_DEBUG(("SSLTUNNEL(%p)): incoming connection csock(0)=%p, ssock(1)=%p\n",
   1.555 +         static_cast<void*>(data),
   1.556 +         static_cast<void*>(ci->client_sock),
   1.557 +         static_cast<void*>(other_sock)));
   1.558 +  if (other_sock) 
   1.559 +  {
   1.560 +    int32_t numberOfSockets = 1;
   1.561 +
   1.562 +    relayBuffer buffers[2];
   1.563 +
   1.564 +    if (!do_http_proxy)
   1.565 +    {
   1.566 +      if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info, certificateToUse, caNone))
   1.567 +        client_error = true;
   1.568 +      else if (!ConnectSocket(other_sock, &remote_addr, connect_timeout))
   1.569 +        client_error = true;
   1.570 +      else
   1.571 +        numberOfSockets = 2;
   1.572 +    }
   1.573 +
   1.574 +    PRPollDesc sockets[2] = 
   1.575 +    { 
   1.576 +      {ci->client_sock, PR_POLL_READ, 0},
   1.577 +      {other_sock, PR_POLL_READ, 0}
   1.578 +    };
   1.579 +    bool socketErrorState[2] = {false, false};
   1.580 +
   1.581 +    while (!((client_error||client_done) && buffers[0].empty() && buffers[1].empty()))
   1.582 +    {
   1.583 +      sockets[0].in_flags |= PR_POLL_EXCEPT;
   1.584 +      sockets[1].in_flags |= PR_POLL_EXCEPT;
   1.585 +      LOG_DEBUG(("SSLTUNNEL(%p)): polling flags csock(0)=%c%c, ssock(1)=%c%c\n",
   1.586 +                 static_cast<void*>(data),
   1.587 +                 sockets[0].in_flags & PR_POLL_READ  ? 'R' : '-',
   1.588 +                 sockets[0].in_flags & PR_POLL_WRITE ? 'W' : '-',
   1.589 +                 sockets[1].in_flags & PR_POLL_READ  ? 'R' : '-',
   1.590 +                 sockets[1].in_flags & PR_POLL_WRITE ? 'W' : '-'));
   1.591 +      int32_t pollStatus = PR_Poll(sockets, numberOfSockets, PR_MillisecondsToInterval(1000));
   1.592 +      if (pollStatus < 0)
   1.593 +      {
   1.594 +        LOG_DEBUG(("SSLTUNNEL(%p)): pollStatus=%d, exiting\n",
   1.595 +                   static_cast<void*>(data), pollStatus));
   1.596 +        client_error = true;
   1.597 +        break;
   1.598 +      }
   1.599 +
   1.600 +      if (pollStatus == 0)
   1.601 +      {
   1.602 +        // timeout
   1.603 +        LOG_DEBUG(("SSLTUNNEL(%p)): poll timeout, looping\n",
   1.604 +                   static_cast<void*>(data)));
   1.605 +        continue;
   1.606 +      }
   1.607 +
   1.608 +      for (int32_t s = 0; s < numberOfSockets; ++s)
   1.609 +      {
   1.610 +        int32_t s2 = s == 1 ? 0 : 1;
   1.611 +        int16_t out_flags = sockets[s].out_flags;
   1.612 +        int16_t &in_flags = sockets[s].in_flags;
   1.613 +        int16_t &in_flags2 = sockets[s2].in_flags;
   1.614 +        sockets[s].out_flags = 0;
   1.615 +
   1.616 +        LOG_BEGIN_BLOCK();
   1.617 +        LOG_DEBUG(("SSLTUNNEL(%p)): %csock(%d)=%p out_flags=%d",
   1.618 +                   static_cast<void*>(data),
   1.619 +                   s == 0 ? 'c' : 's',
   1.620 +                   s,
   1.621 +                   static_cast<void*>(sockets[s].fd),
   1.622 +                   out_flags));
   1.623 +        if (out_flags & (PR_POLL_EXCEPT | PR_POLL_ERR | PR_POLL_HUP))
   1.624 +        {
   1.625 +          LOG_DEBUG((" :exception\n"));
   1.626 +          client_error = true;
   1.627 +          socketErrorState[s] = true;
   1.628 +          // We got a fatal error state on the socket. Clear the output buffer
   1.629 +          // for this socket to break the main loop, we will never more be able
   1.630 +          // to send those data anyway.
   1.631 +          buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
   1.632 +          continue;
   1.633 +        } // PR_POLL_EXCEPT, PR_POLL_ERR, PR_POLL_HUP handling
   1.634 +
   1.635 +        if (out_flags & PR_POLL_READ && !buffers[s].areafree())
   1.636 +        {
   1.637 +           LOG_DEBUG((" no place in read buffer but got read flag, dropping it now!"));
   1.638 +           in_flags &= ~PR_POLL_READ;
   1.639 +        }
   1.640 +
   1.641 +        if (out_flags & PR_POLL_READ && buffers[s].areafree())
   1.642 +        {
   1.643 +          LOG_DEBUG((" :reading"));
   1.644 +          int32_t bytesRead = PR_Recv(sockets[s].fd, buffers[s].buffertail, 
   1.645 +              buffers[s].areafree(), 0, PR_INTERVAL_NO_TIMEOUT);
   1.646 +
   1.647 +          if (bytesRead == 0)
   1.648 +          {
   1.649 +            LOG_DEBUG((" socket gracefully closed"));
   1.650 +            client_done = true;
   1.651 +            in_flags &= ~PR_POLL_READ;
   1.652 +          }
   1.653 +          else if (bytesRead < 0)
   1.654 +          {
   1.655 +            if (PR_GetError() != PR_WOULD_BLOCK_ERROR)
   1.656 +            {
   1.657 +              LOG_DEBUG((" error=%d", PR_GetError()));
   1.658 +              // We are in error state, indicate that the connection was 
   1.659 +              // not closed gracefully
   1.660 +              client_error = true;
   1.661 +              socketErrorState[s] = true;
   1.662 +              // Wipe out our send buffer, we cannot send it anyway.
   1.663 +              buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
   1.664 +            }
   1.665 +            else
   1.666 +              LOG_DEBUG((" would block"));
   1.667 +          }
   1.668 +          else
   1.669 +          {
   1.670 +            // If the other socket is in error state (unable to send/receive)
   1.671 +            // throw this data away and continue loop
   1.672 +            if (socketErrorState[s2])
   1.673 +            {
   1.674 +              LOG_DEBUG((" have read but other socket is in error state\n"));
   1.675 +              continue;
   1.676 +            }
   1.677 +
   1.678 +            buffers[s].buffertail += bytesRead;
   1.679 +            LOG_DEBUG((", read %d bytes", bytesRead));
   1.680 +
   1.681 +            // We have to accept and handle the initial CONNECT request here
   1.682 +            int32_t response;
   1.683 +            if (!connect_accepted && ReadConnectRequest(ci->server_info, buffers[s],
   1.684 +                &response, certificateToUse, &clientAuth, fullHost, locationHeader))
   1.685 +            {
   1.686 +              // Mark this as a proxy-only connection (no SSL) if the CONNECT
   1.687 +              // request didn't come for port 443 or from any of the server's
   1.688 +              // cert or clientauth hostnames.
   1.689 +              if (fullHost.find(":443") == string::npos)
   1.690 +              {
   1.691 +                server_match_t match;
   1.692 +                match.fullHost = fullHost;
   1.693 +                match.matched = false;
   1.694 +                PL_HashTableEnumerateEntries(ci->server_info->host_cert_table, 
   1.695 +                                             match_hostname, 
   1.696 +                                             &match);
   1.697 +                PL_HashTableEnumerateEntries(ci->server_info->host_clientauth_table, 
   1.698 +                                             match_hostname, 
   1.699 +                                             &match);
   1.700 +                ci->http_proxy_only = !match.matched;
   1.701 +              }
   1.702 +              else
   1.703 +              {
   1.704 +                ci->http_proxy_only = false;
   1.705 +              }
   1.706 +
   1.707 +              // Clean the request as it would be read
   1.708 +              buffers[s].bufferhead = buffers[s].buffertail = buffers[s].buffer;
   1.709 +              in_flags |= PR_POLL_WRITE;
   1.710 +              connect_accepted = true;
   1.711 +
   1.712 +              // Store response to the oposite buffer
   1.713 +              if (response == 200)
   1.714 +              {
   1.715 +                  LOG_DEBUG((" accepted CONNECT request, connected to the server, sending OK to the client\n"));
   1.716 +                  strcpy(buffers[s2].buffer, "HTTP/1.1 200 Connected\r\nConnection: keep-alive\r\n\r\n");
   1.717 +              }
   1.718 +              else if (response == 302)
   1.719 +              {
   1.720 +                  LOG_DEBUG((" accepted CONNECT request with redirection, "
   1.721 +                             "sending location and 302 to the client\n"));
   1.722 +                  client_done = true;
   1.723 +                  sprintf(buffers[s2].buffer, 
   1.724 +                          "HTTP/1.1 302 Moved\r\n"
   1.725 +                          "Location: https://%s/\r\n"
   1.726 +                          "Connection: close\r\n\r\n",
   1.727 +                          locationHeader.c_str());
   1.728 +              }
   1.729 +              else
   1.730 +              {
   1.731 +                LOG_ERRORD((" could not read the connect request, closing connection with %d", response));
   1.732 +                client_done = true;
   1.733 +                sprintf(buffers[s2].buffer, "HTTP/1.1 %d ERROR\r\nConnection: close\r\n\r\n", response);
   1.734 +
   1.735 +                break;
   1.736 +              }
   1.737 +
   1.738 +              buffers[s2].buffertail = buffers[s2].buffer + strlen(buffers[s2].buffer);
   1.739 +
   1.740 +              // Send the response to the client socket
   1.741 +              break;
   1.742 +            } // end of CONNECT handling
   1.743 +
   1.744 +            if (!buffers[s].areafree())
   1.745 +            {
   1.746 +              // Do not poll for read when the buffer is full
   1.747 +              LOG_DEBUG((" no place in our read buffer, stop reading"));
   1.748 +              in_flags &= ~PR_POLL_READ;
   1.749 +            }
   1.750 +
   1.751 +            if (ssl_updated)
   1.752 +            {
   1.753 +              if (s == 0 && expect_request_start) 
   1.754 +              {
   1.755 +                if (!strstr(buffers[s].bufferhead, "\r\n\r\n"))
   1.756 +                {
   1.757 +                  // We haven't received the complete header yet, so wait.
   1.758 +                  continue;
   1.759 +                }
   1.760 +                else
   1.761 +                {
   1.762 +                  ci->iswebsocket = AdjustWebSocketHost(buffers[s], ci);
   1.763 +                  expect_request_start = !(ci->iswebsocket || 
   1.764 +                                           AdjustRequestURI(buffers[s], &fullHost));
   1.765 +                  PRNetAddr* addr = &remote_addr;
   1.766 +                  if (ci->iswebsocket && websocket_server.inet.port)
   1.767 +                    addr = &websocket_server;
   1.768 +                  if (!ConnectSocket(other_sock, addr, connect_timeout))
   1.769 +                  {
   1.770 +                    LOG_ERRORD((" could not open connection to the real server\n"));
   1.771 +                    client_error = true;
   1.772 +                    break;
   1.773 +                  }
   1.774 +                  LOG_DEBUG(("\n connected to remote server\n"));
   1.775 +                  numberOfSockets = 2;
   1.776 +                }
   1.777 +              }
   1.778 +              else if (s == 1 && ci->iswebsocket)
   1.779 +              {
   1.780 +                if (!AdjustWebSocketLocation(buffers[s], ci))
   1.781 +                  continue;
   1.782 +              }
   1.783 +
   1.784 +              in_flags2 |= PR_POLL_WRITE;
   1.785 +              LOG_DEBUG((" telling the other socket to write"));
   1.786 +            }
   1.787 +            else
   1.788 +              LOG_DEBUG((" we have something for the other socket to write, but ssl has not been administered on it"));
   1.789 +          }
   1.790 +        } // PR_POLL_READ handling
   1.791 +
   1.792 +        if (out_flags & PR_POLL_WRITE)
   1.793 +        {
   1.794 +          LOG_DEBUG((" :writing"));
   1.795 +          int32_t bytesWrite = PR_Send(sockets[s].fd, buffers[s2].bufferhead, 
   1.796 +              buffers[s2].present(), 0, PR_INTERVAL_NO_TIMEOUT);
   1.797 +
   1.798 +          if (bytesWrite < 0)
   1.799 +          {
   1.800 +            if (PR_GetError() != PR_WOULD_BLOCK_ERROR) {
   1.801 +              LOG_DEBUG((" error=%d", PR_GetError()));
   1.802 +              client_error = true;
   1.803 +              socketErrorState[s] = true;
   1.804 +              // We got a fatal error while writting the buffer. Clear it to break
   1.805 +              // the main loop, we will never more be able to send it.
   1.806 +              buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
   1.807 +            }
   1.808 +            else
   1.809 +              LOG_DEBUG((" would block"));
   1.810 +          }
   1.811 +          else
   1.812 +          {
   1.813 +            LOG_DEBUG((", written %d bytes", bytesWrite));
   1.814 +            buffers[s2].buffertail[1] = '\0';
   1.815 +            LOG_DEBUG((" dump:\n%.*s\n", bytesWrite, buffers[s2].bufferhead));
   1.816 +            
   1.817 +            buffers[s2].bufferhead += bytesWrite;
   1.818 +            if (buffers[s2].present())
   1.819 +            {
   1.820 +              LOG_DEBUG((" still have to write %d bytes", (int)buffers[s2].present()));
   1.821 +              in_flags |= PR_POLL_WRITE;
   1.822 +            }              
   1.823 +            else
   1.824 +            {
   1.825 +              if (!ssl_updated)
   1.826 +              {
   1.827 +                LOG_DEBUG((" proxy response sent to the client"));
   1.828 +                // Proxy response has just been writen, update to ssl
   1.829 +                ssl_updated = true;
   1.830 +                if (ci->http_proxy_only)
   1.831 +                {
   1.832 +                  LOG_DEBUG((" not updating to SSL based on http_proxy_only for this socket"));
   1.833 +                }
   1.834 +                else if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info, 
   1.835 +                                                   certificateToUse, clientAuth))
   1.836 +                {
   1.837 +                  LOG_ERRORD((" failed to config server socket\n"));
   1.838 +                  client_error = true;
   1.839 +                  break;
   1.840 +                }
   1.841 +                else
   1.842 +                {
   1.843 +                  LOG_DEBUG((" client socket updated to SSL"));
   1.844 +                }
   1.845 +              } // sslUpdate
   1.846 +
   1.847 +              LOG_DEBUG((" dropping our write flag and setting other socket read flag"));
   1.848 +              in_flags &= ~PR_POLL_WRITE;
   1.849 +              in_flags2 |= PR_POLL_READ;
   1.850 +              buffers[s2].compact();
   1.851 +            }
   1.852 +          }
   1.853 +        } // PR_POLL_WRITE handling
   1.854 +        LOG_END_BLOCK(); // end the log
   1.855 +      } // for...
   1.856 +    } // while, poll
   1.857 +  }
   1.858 +  else
   1.859 +    client_error = true;
   1.860 +
   1.861 +  LOG_DEBUG(("SSLTUNNEL(%p)): exiting root function for csock=%p, ssock=%p\n",
   1.862 +             static_cast<void*>(data),
   1.863 +             static_cast<void*>(ci->client_sock),
   1.864 +             static_cast<void*>(other_sock)));
   1.865 +  if (!client_error)
   1.866 +    PR_Shutdown(ci->client_sock, PR_SHUTDOWN_SEND);
   1.867 +  PR_Close(ci->client_sock);
   1.868 +
   1.869 +  delete ci;
   1.870 +}
   1.871 +
   1.872 +/*
   1.873 + * Start listening for SSL connections on a specified port, handing
   1.874 + * them off to client threads after accepting the connection.
   1.875 + * The data parameter is a server_info_t*, owned by the calling
   1.876 + * function.
   1.877 + */
   1.878 +void StartServer(void* data)
   1.879 +{
   1.880 +  server_info_t* si = static_cast<server_info_t*>(data);
   1.881 +
   1.882 +  //TODO: select ciphers?
   1.883 +  ScopedPRFileDesc listen_socket(PR_NewTCPSocket());
   1.884 +  if (!listen_socket) {
   1.885 +    LOG_ERROR(("failed to create socket\n"));
   1.886 +    SignalShutdown();
   1.887 +    return;
   1.888 +  }
   1.889 +
   1.890 +  // In case the socket is still open in the TIME_WAIT state from a previous
   1.891 +  // instance of ssltunnel we ask to reuse the port.
   1.892 +  PRSocketOptionData socket_option;
   1.893 +  socket_option.option = PR_SockOpt_Reuseaddr;
   1.894 +  socket_option.value.reuse_addr = true;
   1.895 +  PR_SetSocketOption(listen_socket, &socket_option);
   1.896 +
   1.897 +  PRNetAddr server_addr;
   1.898 +  PR_InitializeNetAddr(PR_IpAddrAny, si->listen_port, &server_addr);
   1.899 +  if (PR_Bind(listen_socket, &server_addr) != PR_SUCCESS) {
   1.900 +    LOG_ERROR(("failed to bind socket\n"));
   1.901 +    SignalShutdown();
   1.902 +    return;
   1.903 +  }
   1.904 +
   1.905 +  if (PR_Listen(listen_socket, 1) != PR_SUCCESS) {
   1.906 +    LOG_ERROR(("failed to listen on socket\n"));
   1.907 +    SignalShutdown();
   1.908 +    return;
   1.909 +  }
   1.910 +
   1.911 +  LOG_INFO(("Server listening on port %d with cert %s\n", si->listen_port,
   1.912 +         si->cert_nickname.c_str()));
   1.913 +
   1.914 +  while (!shutdown_server) {
   1.915 +    connection_info_t* ci = new connection_info_t();
   1.916 +    ci->server_info = si;
   1.917 +    ci->http_proxy_only = do_http_proxy;
   1.918 +    // block waiting for connections
   1.919 +    ci->client_sock = PR_Accept(listen_socket, &ci->client_addr,
   1.920 +                                PR_INTERVAL_NO_TIMEOUT);
   1.921 +    
   1.922 +    PRSocketOptionData option;
   1.923 +    option.option = PR_SockOpt_Nonblocking;
   1.924 +    option.value.non_blocking = true;
   1.925 +    PR_SetSocketOption(ci->client_sock, &option);
   1.926 +
   1.927 +    if (ci->client_sock)
   1.928 +      // Not actually using this PRJob*...
   1.929 +      //PRJob* job =
   1.930 +      PR_QueueJob(threads, HandleConnection, ci, true);
   1.931 +    else
   1.932 +      delete ci;
   1.933 +  }
   1.934 +}
   1.935 +
   1.936 +// bogus password func, just don't use passwords. :-P
   1.937 +char* password_func(PK11SlotInfo* slot, PRBool retry, void* arg)
   1.938 +{
   1.939 +  if (retry)
   1.940 +    return nullptr;
   1.941 +
   1.942 +  return PL_strdup("");
   1.943 +}
   1.944 +
   1.945 +server_info_t* findServerInfo(int portnumber)
   1.946 +{
   1.947 +  for (vector<server_info_t>::iterator it = servers.begin();
   1.948 +       it != servers.end(); it++) 
   1.949 +  {
   1.950 +    if (it->listen_port == portnumber)
   1.951 +      return &(*it);
   1.952 +  }
   1.953 +
   1.954 +  return nullptr;
   1.955 +}
   1.956 +
   1.957 +int processConfigLine(char* configLine)
   1.958 +{
   1.959 +  if (*configLine == 0 || *configLine == '#')
   1.960 +    return 0;
   1.961 +
   1.962 +  char* _caret;
   1.963 +  char* keyword = strtok2(configLine, ":", &_caret);
   1.964 +
   1.965 +  // Configure usage of http/ssl tunneling proxy behavior
   1.966 +  if (!strcmp(keyword, "httpproxy"))
   1.967 +  {
   1.968 +    char* value = strtok2(_caret, ":", &_caret);
   1.969 +    if (!strcmp(value, "1"))
   1.970 +      do_http_proxy = true;
   1.971 +
   1.972 +    return 0;
   1.973 +  }
   1.974 +
   1.975 +  if (!strcmp(keyword, "websocketserver"))
   1.976 +  {
   1.977 +    char* ipstring = strtok2(_caret, ":", &_caret);
   1.978 +    if (PR_StringToNetAddr(ipstring, &websocket_server) != PR_SUCCESS) {
   1.979 +      LOG_ERROR(("Invalid IP address in proxy config: %s\n", ipstring));
   1.980 +      return 1;
   1.981 +    }
   1.982 +    char* remoteport = strtok2(_caret, ":", &_caret);
   1.983 +    int port = atoi(remoteport);
   1.984 +    if (port <= 0) {
   1.985 +      LOG_ERROR(("Invalid remote port in proxy config: %s\n", remoteport));
   1.986 +      return 1;
   1.987 +    }
   1.988 +    websocket_server.inet.port = PR_htons(port);
   1.989 +    return 0;
   1.990 +  }
   1.991 +
   1.992 +  // Configure the forward address of the target server
   1.993 +  if (!strcmp(keyword, "forward"))
   1.994 +  {
   1.995 +    char* ipstring = strtok2(_caret, ":", &_caret);
   1.996 +    if (PR_StringToNetAddr(ipstring, &remote_addr) != PR_SUCCESS) {
   1.997 +      LOG_ERROR(("Invalid remote IP address: %s\n", ipstring));
   1.998 +      return 1;
   1.999 +    }
  1.1000 +    char* serverportstring = strtok2(_caret, ":", &_caret);
  1.1001 +    int port = atoi(serverportstring);
  1.1002 +    if (port <= 0) {
  1.1003 +      LOG_ERROR(("Invalid remote port: %s\n", serverportstring));
  1.1004 +      return 1;
  1.1005 +    }
  1.1006 +    remote_addr.inet.port = PR_htons(port);
  1.1007 +
  1.1008 +    return 0;
  1.1009 +  }
  1.1010 +
  1.1011 +  // Configure all listen sockets and port+certificate bindings
  1.1012 +  if (!strcmp(keyword, "listen"))
  1.1013 +  {
  1.1014 +    char* hostname = strtok2(_caret, ":", &_caret);
  1.1015 +    char* hostportstring = nullptr;
  1.1016 +    if (strcmp(hostname, "*"))
  1.1017 +    {
  1.1018 +      any_host_spec_config = true;
  1.1019 +      hostportstring = strtok2(_caret, ":", &_caret);
  1.1020 +    }
  1.1021 +
  1.1022 +    char* serverportstring = strtok2(_caret, ":", &_caret);
  1.1023 +    char* certnick = strtok2(_caret, ":", &_caret);
  1.1024 +
  1.1025 +    int port = atoi(serverportstring);
  1.1026 +    if (port <= 0) {
  1.1027 +      LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
  1.1028 +      return 1;
  1.1029 +    }
  1.1030 +
  1.1031 +    if (server_info_t* existingServer = findServerInfo(port))
  1.1032 +    {
  1.1033 +      char *certnick_copy = new char[strlen(certnick)+1];
  1.1034 +      char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2];
  1.1035 +
  1.1036 +      strcpy(hostname_copy, hostname);
  1.1037 +      strcat(hostname_copy, ":");
  1.1038 +      strcat(hostname_copy, hostportstring);
  1.1039 +      strcpy(certnick_copy, certnick);
  1.1040 +
  1.1041 +      PLHashEntry* entry = PL_HashTableAdd(existingServer->host_cert_table, hostname_copy, certnick_copy);
  1.1042 +      if (!entry) {
  1.1043 +        LOG_ERROR(("Out of memory"));
  1.1044 +        return 1;
  1.1045 +      }
  1.1046 +    }
  1.1047 +    else
  1.1048 +    {
  1.1049 +      server_info_t server;
  1.1050 +      server.cert_nickname = certnick;
  1.1051 +      server.listen_port = port;
  1.1052 +      server.host_cert_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
  1.1053 +                                               PL_CompareStrings, nullptr, nullptr);
  1.1054 +      if (!server.host_cert_table)
  1.1055 +      {
  1.1056 +        LOG_ERROR(("Internal, could not create hash table\n"));
  1.1057 +        return 1;
  1.1058 +      }
  1.1059 +      server.host_clientauth_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
  1.1060 +                                                     ClientAuthValueComparator, nullptr, nullptr);
  1.1061 +      if (!server.host_clientauth_table)
  1.1062 +      {
  1.1063 +        LOG_ERROR(("Internal, could not create hash table\n"));
  1.1064 +        return 1;
  1.1065 +      }
  1.1066 +      server.host_redir_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
  1.1067 +                                                PL_CompareStrings, nullptr, nullptr);
  1.1068 +      if (!server.host_redir_table)
  1.1069 +      {
  1.1070 +        LOG_ERROR(("Internal, could not create hash table\n"));
  1.1071 +        return 1;
  1.1072 +      }
  1.1073 +      servers.push_back(server);
  1.1074 +    }
  1.1075 +
  1.1076 +    return 0;
  1.1077 +  }
  1.1078 +  
  1.1079 +  if (!strcmp(keyword, "clientauth"))
  1.1080 +  {
  1.1081 +    char* hostname = strtok2(_caret, ":", &_caret);
  1.1082 +    char* hostportstring = strtok2(_caret, ":", &_caret);
  1.1083 +    char* serverportstring = strtok2(_caret, ":", &_caret);
  1.1084 +
  1.1085 +    int port = atoi(serverportstring);
  1.1086 +    if (port <= 0) {
  1.1087 +      LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
  1.1088 +      return 1;
  1.1089 +    }
  1.1090 +
  1.1091 +    if (server_info_t* existingServer = findServerInfo(port))
  1.1092 +    {
  1.1093 +      char* authoptionstring = strtok2(_caret, ":", &_caret);
  1.1094 +      client_auth_option* authoption = new client_auth_option;
  1.1095 +      if (!authoption) {
  1.1096 +        LOG_ERROR(("Out of memory"));
  1.1097 +        return 1;
  1.1098 +      }
  1.1099 +
  1.1100 +      if (!strcmp(authoptionstring, "require"))
  1.1101 +        *authoption = caRequire;
  1.1102 +      else if (!strcmp(authoptionstring, "request"))
  1.1103 +        *authoption = caRequest;
  1.1104 +      else if (!strcmp(authoptionstring, "none"))
  1.1105 +        *authoption = caNone;
  1.1106 +      else
  1.1107 +      {
  1.1108 +        LOG_ERROR(("Incorrect client auth option modifier for host '%s'", hostname));
  1.1109 +        return 1;
  1.1110 +      }
  1.1111 +
  1.1112 +      any_host_spec_config = true;
  1.1113 +
  1.1114 +      char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2];
  1.1115 +      if (!hostname_copy) {
  1.1116 +        LOG_ERROR(("Out of memory"));
  1.1117 +        return 1;
  1.1118 +      }
  1.1119 +
  1.1120 +      strcpy(hostname_copy, hostname);
  1.1121 +      strcat(hostname_copy, ":");
  1.1122 +      strcat(hostname_copy, hostportstring);
  1.1123 +
  1.1124 +      PLHashEntry* entry = PL_HashTableAdd(existingServer->host_clientauth_table, hostname_copy, authoption);
  1.1125 +      if (!entry) {
  1.1126 +        LOG_ERROR(("Out of memory"));
  1.1127 +        return 1;
  1.1128 +      }
  1.1129 +    }
  1.1130 +    else
  1.1131 +    {
  1.1132 +      LOG_ERROR(("Server on port %d for client authentication option is not defined, use 'listen' option first", port));
  1.1133 +      return 1;
  1.1134 +    }
  1.1135 +
  1.1136 +    return 0;
  1.1137 +  }
  1.1138 +
  1.1139 +  if (!strcmp(keyword, "redirhost"))
  1.1140 +  {
  1.1141 +    char* hostname = strtok2(_caret, ":", &_caret);
  1.1142 +    char* hostportstring = strtok2(_caret, ":", &_caret);
  1.1143 +    char* serverportstring = strtok2(_caret, ":", &_caret);
  1.1144 +
  1.1145 +    int port = atoi(serverportstring);
  1.1146 +    if (port <= 0) {
  1.1147 +      LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
  1.1148 +      return 1;
  1.1149 +    }
  1.1150 +
  1.1151 +    if (server_info_t* existingServer = findServerInfo(port))
  1.1152 +    {
  1.1153 +      char* redirhoststring = strtok2(_caret, ":", &_caret);
  1.1154 +
  1.1155 +      any_host_spec_config = true;
  1.1156 +
  1.1157 +      char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2];
  1.1158 +      if (!hostname_copy) {
  1.1159 +        LOG_ERROR(("Out of memory"));
  1.1160 +        return 1;
  1.1161 +      }
  1.1162 +
  1.1163 +      strcpy(hostname_copy, hostname);
  1.1164 +      strcat(hostname_copy, ":");
  1.1165 +      strcat(hostname_copy, hostportstring);
  1.1166 +
  1.1167 +      char *redir_copy = new char[strlen(redirhoststring)+1];
  1.1168 +      strcpy(redir_copy, redirhoststring);
  1.1169 +      PLHashEntry* entry = PL_HashTableAdd(existingServer->host_redir_table, hostname_copy, redir_copy);
  1.1170 +      if (!entry) {
  1.1171 +        LOG_ERROR(("Out of memory"));
  1.1172 +        return 1;
  1.1173 +      }
  1.1174 +    }
  1.1175 +    else
  1.1176 +    {
  1.1177 +      LOG_ERROR(("Server on port %d for redirhost option is not defined, use 'listen' option first", port));
  1.1178 +      return 1;
  1.1179 +    }
  1.1180 +
  1.1181 +    return 0;
  1.1182 +  }
  1.1183 +
  1.1184 +  // Configure the NSS certificate database directory
  1.1185 +  if (!strcmp(keyword, "certdbdir"))
  1.1186 +  {
  1.1187 +    nssconfigdir = strtok2(_caret, "\n", &_caret);
  1.1188 +    return 0;
  1.1189 +  }
  1.1190 +
  1.1191 +  LOG_ERROR(("Error: keyword \"%s\" unexpected\n", keyword));
  1.1192 +  return 1;
  1.1193 +}
  1.1194 +
  1.1195 +int parseConfigFile(const char* filePath)
  1.1196 +{
  1.1197 +  FILE* f = fopen(filePath, "r");
  1.1198 +  if (!f)
  1.1199 +    return 1;
  1.1200 +
  1.1201 +  char buffer[1024], *b = buffer;
  1.1202 +  while (!feof(f))
  1.1203 +  {
  1.1204 +    char c;
  1.1205 +    fscanf(f, "%c", &c);
  1.1206 +    switch (c)
  1.1207 +    {
  1.1208 +    case '\n':
  1.1209 +      *b++ = 0;
  1.1210 +      if (processConfigLine(buffer))
  1.1211 +        return 1;
  1.1212 +      b = buffer;
  1.1213 +    case '\r':
  1.1214 +      continue;
  1.1215 +    default:
  1.1216 +      *b++ = c;
  1.1217 +    }
  1.1218 +  }
  1.1219 +
  1.1220 +  fclose(f);
  1.1221 +
  1.1222 +  // Check mandatory items
  1.1223 +  if (nssconfigdir.empty())
  1.1224 +  {
  1.1225 +    LOG_ERROR(("Error: missing path to NSS certification database\n,use certdbdir:<path> in the config file\n"));
  1.1226 +    return 1;
  1.1227 +  }
  1.1228 +
  1.1229 +  if (any_host_spec_config && !do_http_proxy)
  1.1230 +  {
  1.1231 +    LOG_ERROR(("Warning: any host-specific configurations are ignored, add httpproxy:1 to allow them\n"));
  1.1232 +  }
  1.1233 +
  1.1234 +  return 0;
  1.1235 +}
  1.1236 +
  1.1237 +int freeHostCertHashItems(PLHashEntry *he, int i, void *arg)
  1.1238 +{
  1.1239 +  delete [] (char*)he->key;
  1.1240 +  delete [] (char*)he->value;
  1.1241 +  return HT_ENUMERATE_REMOVE;
  1.1242 +}
  1.1243 +
  1.1244 +int freeHostRedirHashItems(PLHashEntry *he, int i, void *arg)
  1.1245 +{
  1.1246 +  delete [] (char*)he->key;
  1.1247 +  delete [] (char*)he->value;
  1.1248 +  return HT_ENUMERATE_REMOVE;
  1.1249 +}
  1.1250 +
  1.1251 +int freeClientAuthHashItems(PLHashEntry *he, int i, void *arg)
  1.1252 +{
  1.1253 +  delete [] (char*)he->key;
  1.1254 +  delete (client_auth_option*)he->value;
  1.1255 +  return HT_ENUMERATE_REMOVE;
  1.1256 +}
  1.1257 +
  1.1258 +int main(int argc, char** argv)
  1.1259 +{
  1.1260 +  const char* configFilePath;
  1.1261 +  
  1.1262 +  const char* logLevelEnv = PR_GetEnv("SSLTUNNEL_LOG_LEVEL");
  1.1263 +  gLogLevel = logLevelEnv ? (LogLevel)atoi(logLevelEnv) : LEVEL_INFO;
  1.1264 +  
  1.1265 +  if (argc == 1)
  1.1266 +    configFilePath = "ssltunnel.cfg";
  1.1267 +  else
  1.1268 +    configFilePath = argv[1];
  1.1269 +
  1.1270 +  memset(&websocket_server, 0, sizeof(PRNetAddr));
  1.1271 +
  1.1272 +  if (parseConfigFile(configFilePath)) {
  1.1273 +    LOG_ERROR(("Error: config file \"%s\" missing or formating incorrect\n"
  1.1274 +      "Specify path to the config file as parameter to ssltunnel or \n"
  1.1275 +      "create ssltunnel.cfg in the working directory.\n\n"
  1.1276 +      "Example format of the config file:\n\n"
  1.1277 +      "       # Enable http/ssl tunneling proxy-like behavior.\n"
  1.1278 +      "       # If not specified ssltunnel simply does direct forward.\n"
  1.1279 +      "       httpproxy:1\n\n"
  1.1280 +      "       # Specify path to the certification database used.\n"
  1.1281 +      "       certdbdir:/path/to/certdb\n\n"
  1.1282 +      "       # Forward/proxy all requests in raw to 127.0.0.1:8888.\n"
  1.1283 +      "       forward:127.0.0.1:8888\n\n"
  1.1284 +      "       # Accept connections on port 4443 or 5678 resp. and authenticate\n"
  1.1285 +      "       # to any host ('*') using the 'server cert' or 'server cert 2' resp.\n"
  1.1286 +      "       listen:*:4443:server cert\n"
  1.1287 +      "       listen:*:5678:server cert 2\n\n"
  1.1288 +      "       # Accept connections on port 4443 and authenticate using\n"
  1.1289 +      "       # 'a different cert' when target host is 'my.host.name:443'.\n"
  1.1290 +      "       # This only works in httpproxy mode and has higher priority\n"
  1.1291 +      "       # than the previous option.\n"
  1.1292 +      "       listen:my.host.name:443:4443:a different cert\n\n"
  1.1293 +      "       # To make a specific host require or just request a client certificate\n"
  1.1294 +      "       # to authenticate use the following options. This can only be used\n"
  1.1295 +      "       # in httpproxy mode and only after the 'listen' option has been\n"
  1.1296 +      "       # specified. You also have to specify the tunnel listen port.\n"
  1.1297 +      "       clientauth:requesting-client-cert.host.com:443:4443:request\n"
  1.1298 +      "       clientauth:requiring-client-cert.host.com:443:4443:require\n"
  1.1299 +      "       # Proxy WebSocket traffic to the server at 127.0.0.1:9999,\n"
  1.1300 +      "       # instead of the server specified in the 'forward' option.\n"
  1.1301 +      "       websocketserver:127.0.0.1:9999\n",
  1.1302 +      configFilePath));
  1.1303 +    return 1;
  1.1304 +  }
  1.1305 +
  1.1306 +  // create a thread pool to handle connections
  1.1307 +  threads = PR_CreateThreadPool(INITIAL_THREADS * servers.size(),
  1.1308 +                                MAX_THREADS * servers.size(),
  1.1309 +                                DEFAULT_STACKSIZE);
  1.1310 +  if (!threads) {
  1.1311 +    LOG_ERROR(("Failed to create thread pool\n"));
  1.1312 +    return 1;
  1.1313 +  }
  1.1314 +
  1.1315 +  shutdown_lock = PR_NewLock();
  1.1316 +  if (!shutdown_lock) {
  1.1317 +    LOG_ERROR(("Failed to create lock\n"));
  1.1318 +    PR_ShutdownThreadPool(threads);
  1.1319 +    return 1;
  1.1320 +  }
  1.1321 +  shutdown_condvar = PR_NewCondVar(shutdown_lock);
  1.1322 +  if (!shutdown_condvar) {
  1.1323 +    LOG_ERROR(("Failed to create condvar\n"));
  1.1324 +    PR_ShutdownThreadPool(threads);
  1.1325 +    PR_DestroyLock(shutdown_lock);
  1.1326 +    return 1;
  1.1327 +  }
  1.1328 +
  1.1329 +  PK11_SetPasswordFunc(password_func);
  1.1330 +
  1.1331 +  // Initialize NSS
  1.1332 +  if (NSS_Init(nssconfigdir.c_str()) != SECSuccess) {
  1.1333 +    int32_t errorlen = PR_GetErrorTextLength();
  1.1334 +    char* err = new char[errorlen+1];
  1.1335 +    PR_GetErrorText(err);
  1.1336 +    LOG_ERROR(("Failed to init NSS: %s", err));
  1.1337 +    delete[] err;
  1.1338 +    PR_ShutdownThreadPool(threads);
  1.1339 +    PR_DestroyCondVar(shutdown_condvar);
  1.1340 +    PR_DestroyLock(shutdown_lock);
  1.1341 +    return 1;
  1.1342 +  }
  1.1343 +
  1.1344 +  if (NSS_SetDomesticPolicy() != SECSuccess) {
  1.1345 +    LOG_ERROR(("NSS_SetDomesticPolicy failed\n"));
  1.1346 +    PR_ShutdownThreadPool(threads);
  1.1347 +    PR_DestroyCondVar(shutdown_condvar);
  1.1348 +    PR_DestroyLock(shutdown_lock);
  1.1349 +    NSS_Shutdown();
  1.1350 +    return 1;
  1.1351 +  }
  1.1352 +
  1.1353 +  // these values should make NSS use the defaults
  1.1354 +  if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) {
  1.1355 +    LOG_ERROR(("SSL_ConfigServerSessionIDCache failed\n"));
  1.1356 +    PR_ShutdownThreadPool(threads);
  1.1357 +    PR_DestroyCondVar(shutdown_condvar);
  1.1358 +    PR_DestroyLock(shutdown_lock);
  1.1359 +    NSS_Shutdown();
  1.1360 +    return 1;
  1.1361 +  }
  1.1362 +
  1.1363 +  for (vector<server_info_t>::iterator it = servers.begin();
  1.1364 +       it != servers.end(); it++) {
  1.1365 +    // Not actually using this PRJob*...
  1.1366 +    // PRJob* server_job =
  1.1367 +    PR_QueueJob(threads, StartServer, &(*it), true);
  1.1368 +  }
  1.1369 +  // now wait for someone to tell us to quit
  1.1370 +  PR_Lock(shutdown_lock);
  1.1371 +  PR_WaitCondVar(shutdown_condvar, PR_INTERVAL_NO_TIMEOUT);
  1.1372 +  PR_Unlock(shutdown_lock);
  1.1373 +  shutdown_server = true;
  1.1374 +  LOG_INFO(("Shutting down...\n"));
  1.1375 +  // cleanup
  1.1376 +  PR_ShutdownThreadPool(threads);
  1.1377 +  PR_JoinThreadPool(threads);
  1.1378 +  PR_DestroyCondVar(shutdown_condvar);
  1.1379 +  PR_DestroyLock(shutdown_lock);
  1.1380 +  if (NSS_Shutdown() == SECFailure) {
  1.1381 +    LOG_DEBUG(("Leaked NSS objects!\n"));
  1.1382 +  }
  1.1383 +  
  1.1384 +  for (vector<server_info_t>::iterator it = servers.begin();
  1.1385 +       it != servers.end(); it++) 
  1.1386 +  {
  1.1387 +    PL_HashTableEnumerateEntries(it->host_cert_table, freeHostCertHashItems, nullptr);
  1.1388 +    PL_HashTableEnumerateEntries(it->host_clientauth_table, freeClientAuthHashItems, nullptr);
  1.1389 +    PL_HashTableEnumerateEntries(it->host_redir_table, freeHostRedirHashItems, nullptr);
  1.1390 +    PL_HashTableDestroy(it->host_cert_table);
  1.1391 +    PL_HashTableDestroy(it->host_clientauth_table);
  1.1392 +    PL_HashTableDestroy(it->host_redir_table);
  1.1393 +  }
  1.1394 +
  1.1395 +  PR_Cleanup();
  1.1396 +  return 0;
  1.1397 +}

mercurial