1.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000 1.2 +++ b/testing/mochitest/ssltunnel/ssltunnel.cpp Wed Dec 31 06:09:35 2014 +0100 1.3 @@ -0,0 +1,1394 @@ 1.4 +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ 1.5 +/* This Source Code Form is subject to the terms of the Mozilla Public 1.6 + * License, v. 2.0. If a copy of the MPL was not distributed with this 1.7 + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ 1.8 + 1.9 +/* 1.10 + * WARNING: DO NOT USE THIS CODE IN PRODUCTION SYSTEMS. It is highly likely to 1.11 + * be plagued with the usual problems endemic to C (buffer overflows 1.12 + * and the like). We don't especially care here (but would accept 1.13 + * patches!) because this is only intended for use in our test 1.14 + * harnesses in controlled situations where input is guaranteed not to 1.15 + * be malicious. 1.16 + */ 1.17 + 1.18 +#include "ScopedNSSTypes.h" 1.19 +#include <assert.h> 1.20 +#include <stdio.h> 1.21 +#include <string> 1.22 +#include <vector> 1.23 +#include <algorithm> 1.24 +#include <stdarg.h> 1.25 +#include "prinit.h" 1.26 +#include "prerror.h" 1.27 +#include "prenv.h" 1.28 +#include "prnetdb.h" 1.29 +#include "prtpool.h" 1.30 +#include "nsAlgorithm.h" 1.31 +#include "nss.h" 1.32 +#include "key.h" 1.33 +#include "ssl.h" 1.34 +#include "plhash.h" 1.35 + 1.36 +using namespace mozilla; 1.37 +using namespace mozilla::psm; 1.38 +using std::string; 1.39 +using std::vector; 1.40 + 1.41 +#define IS_DELIM(m, c) ((m)[(c) >> 3] & (1 << ((c) & 7))) 1.42 +#define SET_DELIM(m, c) ((m)[(c) >> 3] |= (1 << ((c) & 7))) 1.43 +#define DELIM_TABLE_SIZE 32 1.44 + 1.45 +// You can set the level of logging by env var SSLTUNNEL_LOG_LEVEL=n, where n 1.46 +// is 0 through 3. The default is 1, INFO level logging. 1.47 +enum LogLevel { 1.48 + LEVEL_DEBUG = 0, 1.49 + LEVEL_INFO = 1, 1.50 + LEVEL_ERROR = 2, 1.51 + LEVEL_SILENT = 3 1.52 +} gLogLevel, gLastLogLevel; 1.53 + 1.54 +#define _LOG_OUTPUT(level, func, params) \ 1.55 +PR_BEGIN_MACRO \ 1.56 + if (level >= gLogLevel) { \ 1.57 + gLastLogLevel = level; \ 1.58 + func params;\ 1.59 + } \ 1.60 +PR_END_MACRO 1.61 + 1.62 +// The most verbose output 1.63 +#define LOG_DEBUG(params) \ 1.64 + _LOG_OUTPUT(LEVEL_DEBUG, printf, params) 1.65 + 1.66 +// Top level informative messages 1.67 +#define LOG_INFO(params) \ 1.68 + _LOG_OUTPUT(LEVEL_INFO, printf, params) 1.69 + 1.70 +// Serious errors that must be logged always until completely gag 1.71 +#define LOG_ERROR(params) \ 1.72 + _LOG_OUTPUT(LEVEL_ERROR, eprintf, params) 1.73 + 1.74 +// Same as LOG_ERROR, but when logging is set to LEVEL_DEBUG, the message 1.75 +// will be put to the stdout instead of stderr to keep continuity with other 1.76 +// LOG_DEBUG message output 1.77 +#define LOG_ERRORD(params) \ 1.78 +PR_BEGIN_MACRO \ 1.79 + if (gLogLevel == LEVEL_DEBUG) \ 1.80 + _LOG_OUTPUT(LEVEL_ERROR, printf, params); \ 1.81 + else \ 1.82 + _LOG_OUTPUT(LEVEL_ERROR, eprintf, params); \ 1.83 +PR_END_MACRO 1.84 + 1.85 +// If there is any output written between LOG_BEGIN_BLOCK() and 1.86 +// LOG_END_BLOCK() then a new line will be put to the proper output (out/err) 1.87 +#define LOG_BEGIN_BLOCK() \ 1.88 + gLastLogLevel = LEVEL_SILENT; 1.89 + 1.90 +#define LOG_END_BLOCK() \ 1.91 +PR_BEGIN_MACRO \ 1.92 + if (gLastLogLevel == LEVEL_ERROR) \ 1.93 + LOG_ERROR(("\n")); \ 1.94 + if (gLastLogLevel < LEVEL_ERROR) \ 1.95 + _LOG_OUTPUT(gLastLogLevel, printf, ("\n")); \ 1.96 +PR_END_MACRO 1.97 + 1.98 +int eprintf(const char* str, ...) 1.99 +{ 1.100 + va_list ap; 1.101 + va_start(ap, str); 1.102 + int result = vfprintf(stderr, str, ap); 1.103 + va_end(ap); 1.104 + return result; 1.105 +} 1.106 + 1.107 +// Copied from nsCRT 1.108 +char* strtok2(char* string, const char* delims, char* *newStr) 1.109 +{ 1.110 + PR_ASSERT(string); 1.111 + 1.112 + char delimTable[DELIM_TABLE_SIZE]; 1.113 + uint32_t i; 1.114 + char* result; 1.115 + char* str = string; 1.116 + 1.117 + for (i = 0; i < DELIM_TABLE_SIZE; i++) 1.118 + delimTable[i] = '\0'; 1.119 + 1.120 + for (i = 0; delims[i]; i++) { 1.121 + SET_DELIM(delimTable, static_cast<uint8_t>(delims[i])); 1.122 + } 1.123 + 1.124 + // skip to beginning 1.125 + while (*str && IS_DELIM(delimTable, static_cast<uint8_t>(*str))) { 1.126 + str++; 1.127 + } 1.128 + result = str; 1.129 + 1.130 + // fix up the end of the token 1.131 + while (*str) { 1.132 + if (IS_DELIM(delimTable, static_cast<uint8_t>(*str))) { 1.133 + *str++ = '\0'; 1.134 + break; 1.135 + } 1.136 + str++; 1.137 + } 1.138 + *newStr = str; 1.139 + 1.140 + return str == result ? nullptr : result; 1.141 +} 1.142 + 1.143 + 1.144 + 1.145 +enum client_auth_option { 1.146 + caNone = 0, 1.147 + caRequire = 1, 1.148 + caRequest = 2 1.149 +}; 1.150 + 1.151 +// Structs for passing data into jobs on the thread pool 1.152 +typedef struct { 1.153 + int32_t listen_port; 1.154 + string cert_nickname; 1.155 + PLHashTable* host_cert_table; 1.156 + PLHashTable* host_clientauth_table; 1.157 + PLHashTable* host_redir_table; 1.158 +} server_info_t; 1.159 + 1.160 +typedef struct { 1.161 + PRFileDesc* client_sock; 1.162 + PRNetAddr client_addr; 1.163 + server_info_t* server_info; 1.164 + // the original host in the Host: header for this connection is 1.165 + // stored here, for proxied connections 1.166 + string original_host; 1.167 + // true if no SSL should be used for this connection 1.168 + bool http_proxy_only; 1.169 + // true if this connection is for a WebSocket 1.170 + bool iswebsocket; 1.171 +} connection_info_t; 1.172 + 1.173 +typedef struct { 1.174 + string fullHost; 1.175 + bool matched; 1.176 +} server_match_t; 1.177 + 1.178 +const int32_t BUF_SIZE = 16384; 1.179 +const int32_t BUF_MARGIN = 1024; 1.180 +const int32_t BUF_TOTAL = BUF_SIZE + BUF_MARGIN; 1.181 + 1.182 +struct relayBuffer 1.183 +{ 1.184 + char *buffer, *bufferhead, *buffertail, *bufferend; 1.185 + 1.186 + relayBuffer() 1.187 + { 1.188 + // Leave 1024 bytes more for request line manipulations 1.189 + bufferhead = buffertail = buffer = new char[BUF_TOTAL]; 1.190 + bufferend = buffer + BUF_SIZE; 1.191 + } 1.192 + 1.193 + ~relayBuffer() 1.194 + { 1.195 + delete [] buffer; 1.196 + } 1.197 + 1.198 + void compact() { 1.199 + if (buffertail == bufferhead) 1.200 + buffertail = bufferhead = buffer; 1.201 + } 1.202 + 1.203 + bool empty() { return bufferhead == buffertail; } 1.204 + size_t areafree() { return bufferend - buffertail; } 1.205 + size_t margin() { return areafree() + BUF_MARGIN; } 1.206 + size_t present() { return buffertail - bufferhead; } 1.207 +}; 1.208 + 1.209 +// These numbers are multiplied by the number of listening ports (actual 1.210 +// servers running). According the thread pool implementation there is no 1.211 +// need to limit the number of threads initially, threads are allocated 1.212 +// dynamically and stored in a linked list. Initial number of 2 is chosen 1.213 +// to allocate a thread for socket accept and preallocate one for the first 1.214 +// connection that is with high probability expected to come. 1.215 +const uint32_t INITIAL_THREADS = 2; 1.216 +const uint32_t MAX_THREADS = 100; 1.217 +const uint32_t DEFAULT_STACKSIZE = (512 * 1024); 1.218 + 1.219 +// global data 1.220 +string nssconfigdir; 1.221 +vector<server_info_t> servers; 1.222 +PRNetAddr remote_addr; 1.223 +PRNetAddr websocket_server; 1.224 +PRThreadPool* threads = nullptr; 1.225 +PRLock* shutdown_lock = nullptr; 1.226 +PRCondVar* shutdown_condvar = nullptr; 1.227 +// Not really used, unless something fails to start 1.228 +bool shutdown_server = false; 1.229 +bool do_http_proxy = false; 1.230 +bool any_host_spec_config = false; 1.231 + 1.232 +int ClientAuthValueComparator(const void *v1, const void *v2) 1.233 +{ 1.234 + int a = *static_cast<const client_auth_option*>(v1) - 1.235 + *static_cast<const client_auth_option*>(v2); 1.236 + if (a == 0) 1.237 + return 0; 1.238 + if (a > 0) 1.239 + return 1; 1.240 + else // (a < 0) 1.241 + return -1; 1.242 +} 1.243 + 1.244 +static int match_hostname(PLHashEntry *he, int index, void* arg) 1.245 +{ 1.246 + server_match_t *match = (server_match_t*)arg; 1.247 + if (match->fullHost.find((char*)he->key) != string::npos) 1.248 + match->matched = true; 1.249 + return HT_ENUMERATE_NEXT; 1.250 +} 1.251 + 1.252 +/* 1.253 + * Signal the main thread that the application should shut down. 1.254 + */ 1.255 +void SignalShutdown() 1.256 +{ 1.257 + PR_Lock(shutdown_lock); 1.258 + PR_NotifyCondVar(shutdown_condvar); 1.259 + PR_Unlock(shutdown_lock); 1.260 +} 1.261 + 1.262 +bool ReadConnectRequest(server_info_t* server_info, 1.263 + relayBuffer& buffer, int32_t* result, string& certificate, 1.264 + client_auth_option* clientauth, string& host, string& location) 1.265 +{ 1.266 + if (buffer.present() < 4) { 1.267 + LOG_DEBUG((" !! only %d bytes present in the buffer", (int)buffer.present())); 1.268 + return false; 1.269 + } 1.270 + if (strncmp(buffer.buffertail-4, "\r\n\r\n", 4)) { 1.271 + LOG_ERRORD((" !! request is not tailed with CRLFCRLF but with %x %x %x %x", 1.272 + *(buffer.buffertail-4), 1.273 + *(buffer.buffertail-3), 1.274 + *(buffer.buffertail-2), 1.275 + *(buffer.buffertail-1))); 1.276 + return false; 1.277 + } 1.278 + 1.279 + LOG_DEBUG((" parsing initial connect request, dump:\n%.*s\n", (int)buffer.present(), buffer.bufferhead)); 1.280 + 1.281 + *result = 400; 1.282 + 1.283 + char* token; 1.284 + char* _caret; 1.285 + token = strtok2(buffer.bufferhead, " ", &_caret); 1.286 + if (!token) { 1.287 + LOG_ERRORD((" no space found")); 1.288 + return true; 1.289 + } 1.290 + if (strcmp(token, "CONNECT")) { 1.291 + LOG_ERRORD((" not CONNECT request but %s", token)); 1.292 + return true; 1.293 + } 1.294 + 1.295 + token = strtok2(_caret, " ", &_caret); 1.296 + void* c = PL_HashTableLookup(server_info->host_cert_table, token); 1.297 + if (c) 1.298 + certificate = static_cast<char*>(c); 1.299 + 1.300 + host = "https://"; 1.301 + host += token; 1.302 + 1.303 + c = PL_HashTableLookup(server_info->host_clientauth_table, token); 1.304 + if (c) 1.305 + *clientauth = *static_cast<client_auth_option*>(c); 1.306 + else 1.307 + *clientauth = caNone; 1.308 + 1.309 + void *redir = PL_HashTableLookup(server_info->host_redir_table, token); 1.310 + if (redir) 1.311 + location = static_cast<char*>(redir); 1.312 + 1.313 + token = strtok2(_caret, "/", &_caret); 1.314 + if (strcmp(token, "HTTP")) { 1.315 + LOG_ERRORD((" not tailed with HTTP but with %s", token)); 1.316 + return true; 1.317 + } 1.318 + 1.319 + *result = (redir) ? 302 : 200; 1.320 + return true; 1.321 +} 1.322 + 1.323 +bool ConfigureSSLServerSocket(PRFileDesc* socket, server_info_t* si, string &certificate, client_auth_option clientAuth) 1.324 +{ 1.325 + const char* certnick = certificate.empty() ? 1.326 + si->cert_nickname.c_str() : certificate.c_str(); 1.327 + 1.328 + ScopedCERTCertificate cert(PK11_FindCertFromNickname(certnick, nullptr)); 1.329 + if (!cert) { 1.330 + LOG_ERROR(("Failed to find cert %s\n", certnick)); 1.331 + return false; 1.332 + } 1.333 + 1.334 + ScopedSECKEYPrivateKey privKey(PK11_FindKeyByAnyCert(cert, nullptr)); 1.335 + if (!privKey) { 1.336 + LOG_ERROR(("Failed to find private key\n")); 1.337 + return false; 1.338 + } 1.339 + 1.340 + PRFileDesc* ssl_socket = SSL_ImportFD(nullptr, socket); 1.341 + if (!ssl_socket) { 1.342 + LOG_ERROR(("Error importing SSL socket\n")); 1.343 + return false; 1.344 + } 1.345 + 1.346 + SSLKEAType certKEA = NSS_FindCertKEAType(cert); 1.347 + if (SSL_ConfigSecureServer(ssl_socket, cert, privKey, certKEA) 1.348 + != SECSuccess) { 1.349 + LOG_ERROR(("Error configuring SSL server socket\n")); 1.350 + return false; 1.351 + } 1.352 + 1.353 + SSL_OptionSet(ssl_socket, SSL_SECURITY, true); 1.354 + SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_CLIENT, false); 1.355 + SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_SERVER, true); 1.356 + 1.357 + if (clientAuth != caNone) 1.358 + { 1.359 + SSL_OptionSet(ssl_socket, SSL_REQUEST_CERTIFICATE, true); 1.360 + SSL_OptionSet(ssl_socket, SSL_REQUIRE_CERTIFICATE, clientAuth == caRequire); 1.361 + } 1.362 + 1.363 + SSL_ResetHandshake(ssl_socket, true); 1.364 + 1.365 + return true; 1.366 +} 1.367 + 1.368 +/** 1.369 + * This function examines the buffer for a Sec-WebSocket-Location: field, 1.370 + * and if it's present, it replaces the hostname in that field with the 1.371 + * value in the server's original_host field. This function works 1.372 + * in the reverse direction as AdjustWebSocketHost(), replacing the real 1.373 + * hostname of a response with the potentially fake hostname that is expected 1.374 + * by the browser (e.g., mochi.test). 1.375 + * 1.376 + * @return true if the header was adjusted successfully, or not found, false 1.377 + * if the header is present but the url is not, which should indicate 1.378 + * that more data needs to be read from the socket 1.379 + */ 1.380 +bool AdjustWebSocketLocation(relayBuffer& buffer, connection_info_t *ci) 1.381 +{ 1.382 + assert(buffer.margin()); 1.383 + buffer.buffertail[1] = '\0'; 1.384 + 1.385 + char* wsloc = strstr(buffer.bufferhead, "Sec-WebSocket-Location:"); 1.386 + if (!wsloc) 1.387 + return true; 1.388 + // advance pointer to the start of the hostname 1.389 + wsloc = strstr(wsloc, "ws://"); 1.390 + if (!wsloc) 1.391 + return false; 1.392 + wsloc += 5; 1.393 + // find the end of the hostname 1.394 + char* wslocend = strchr(wsloc + 1, '/'); 1.395 + if (!wslocend) 1.396 + return false; 1.397 + char *crlf = strstr(wsloc, "\r\n"); 1.398 + if (!crlf) 1.399 + return false; 1.400 + if (ci->original_host.empty()) 1.401 + return true; 1.402 + 1.403 + int diff = ci->original_host.length() - (wslocend-wsloc); 1.404 + if (diff > 0) 1.405 + assert(size_t(diff) <= buffer.margin()); 1.406 + memmove(wslocend + diff, wslocend, buffer.buffertail - wsloc - diff); 1.407 + buffer.buffertail += diff; 1.408 + 1.409 + memcpy(wsloc, ci->original_host.c_str(), ci->original_host.length()); 1.410 + return true; 1.411 +} 1.412 + 1.413 +/** 1.414 + * This function examines the buffer for a Host: field, and if it's present, 1.415 + * it replaces the hostname in that field with the hostname in the server's 1.416 + * remote_addr field. This is needed because proxy requests may be coming 1.417 + * from mochitest with fake hosts, like mochi.test, and these need to be 1.418 + * replaced with the host that the destination server is actually running 1.419 + * on. 1.420 + */ 1.421 +bool AdjustWebSocketHost(relayBuffer& buffer, connection_info_t *ci) 1.422 +{ 1.423 + const char HEADER_UPGRADE[] = "Upgrade:"; 1.424 + const char HEADER_HOST[] = "Host:"; 1.425 + 1.426 + PRNetAddr inet_addr = (websocket_server.inet.port ? websocket_server : 1.427 + remote_addr); 1.428 + 1.429 + assert(buffer.margin()); 1.430 + 1.431 + // Cannot use strnchr so add a null char at the end. There is always some 1.432 + // space left because we preserve a margin. 1.433 + buffer.buffertail[1] = '\0'; 1.434 + 1.435 + // Verify this is a WebSocket header. 1.436 + char* h1 = strstr(buffer.bufferhead, HEADER_UPGRADE); 1.437 + if (!h1) 1.438 + return false; 1.439 + h1 += strlen(HEADER_UPGRADE); 1.440 + h1 += strspn(h1, " \t"); 1.441 + char* h2 = strstr(h1, "WebSocket\r\n"); 1.442 + if (!h2) h2 = strstr(h1, "websocket\r\n"); 1.443 + if (!h2) h2 = strstr(h1, "Websocket\r\n"); 1.444 + if (!h2) 1.445 + return false; 1.446 + 1.447 + char* host = strstr(buffer.bufferhead, HEADER_HOST); 1.448 + if (!host) 1.449 + return false; 1.450 + // advance pointer to beginning of hostname 1.451 + host += strlen(HEADER_HOST); 1.452 + host += strspn(host, " \t"); 1.453 + 1.454 + char* endhost = strstr(host, "\r\n"); 1.455 + if (!endhost) 1.456 + return false; 1.457 + 1.458 + // Save the original host, so we can use it later on responses from the 1.459 + // server. 1.460 + ci->original_host.assign(host, endhost-host); 1.461 + 1.462 + char newhost[40]; 1.463 + PR_NetAddrToString(&inet_addr, newhost, sizeof(newhost)); 1.464 + assert(strlen(newhost) < sizeof(newhost) - 7); 1.465 + sprintf(newhost, "%s:%d", newhost, PR_ntohs(inet_addr.inet.port)); 1.466 + 1.467 + int diff = strlen(newhost) - (endhost-host); 1.468 + if (diff > 0) 1.469 + assert(size_t(diff) <= buffer.margin()); 1.470 + memmove(endhost + diff, endhost, buffer.buffertail - host - diff); 1.471 + buffer.buffertail += diff; 1.472 + 1.473 + memcpy(host, newhost, strlen(newhost)); 1.474 + return true; 1.475 +} 1.476 + 1.477 +/** 1.478 + * This function prefixes Request-URI path with a full scheme-host-port 1.479 + * string. 1.480 + */ 1.481 +bool AdjustRequestURI(relayBuffer& buffer, string *host) 1.482 +{ 1.483 + assert(buffer.margin()); 1.484 + 1.485 + // Cannot use strnchr so add a null char at the end. There is always some space left 1.486 + // because we preserve a margin. 1.487 + buffer.buffertail[1] = '\0'; 1.488 + LOG_DEBUG((" incoming request to adjust:\n%s\n", buffer.bufferhead)); 1.489 + 1.490 + char *token, *path; 1.491 + path = strchr(buffer.bufferhead, ' ') + 1; 1.492 + if (!path) 1.493 + return false; 1.494 + 1.495 + // If the path doesn't start with a slash don't change it, it is probably '*' or a full 1.496 + // path already. Return true, we are done with this request adjustment. 1.497 + if (*path != '/') 1.498 + return true; 1.499 + 1.500 + token = strchr(path, ' ') + 1; 1.501 + if (!token) 1.502 + return false; 1.503 + 1.504 + if (strncmp(token, "HTTP/", 5)) 1.505 + return false; 1.506 + 1.507 + size_t hostlength = host->length(); 1.508 + assert(hostlength <= buffer.margin()); 1.509 + 1.510 + memmove(path + hostlength, path, buffer.buffertail - path); 1.511 + memcpy(path, host->c_str(), hostlength); 1.512 + buffer.buffertail += hostlength; 1.513 + 1.514 + return true; 1.515 +} 1.516 + 1.517 +bool ConnectSocket(PRFileDesc *fd, const PRNetAddr *addr, PRIntervalTime timeout) 1.518 +{ 1.519 + PRStatus stat = PR_Connect(fd, addr, timeout); 1.520 + if (stat != PR_SUCCESS) 1.521 + return false; 1.522 + 1.523 + PRSocketOptionData option; 1.524 + option.option = PR_SockOpt_Nonblocking; 1.525 + option.value.non_blocking = true; 1.526 + PR_SetSocketOption(fd, &option); 1.527 + 1.528 + return true; 1.529 +} 1.530 + 1.531 +/* 1.532 + * Handle an incoming client connection. The server thread has already 1.533 + * accepted the connection, so we just need to connect to the remote 1.534 + * port and then proxy data back and forth. 1.535 + * The data parameter is a connection_info_t*, and must be deleted 1.536 + * by this function. 1.537 + */ 1.538 +void HandleConnection(void* data) 1.539 +{ 1.540 + connection_info_t* ci = static_cast<connection_info_t*>(data); 1.541 + PRIntervalTime connect_timeout = PR_SecondsToInterval(30); 1.542 + 1.543 + ScopedPRFileDesc other_sock(PR_NewTCPSocket()); 1.544 + bool client_done = false; 1.545 + bool client_error = false; 1.546 + bool connect_accepted = !do_http_proxy; 1.547 + bool ssl_updated = !do_http_proxy; 1.548 + bool expect_request_start = do_http_proxy; 1.549 + string certificateToUse; 1.550 + string locationHeader; 1.551 + client_auth_option clientAuth; 1.552 + string fullHost; 1.553 + 1.554 + LOG_DEBUG(("SSLTUNNEL(%p)): incoming connection csock(0)=%p, ssock(1)=%p\n", 1.555 + static_cast<void*>(data), 1.556 + static_cast<void*>(ci->client_sock), 1.557 + static_cast<void*>(other_sock))); 1.558 + if (other_sock) 1.559 + { 1.560 + int32_t numberOfSockets = 1; 1.561 + 1.562 + relayBuffer buffers[2]; 1.563 + 1.564 + if (!do_http_proxy) 1.565 + { 1.566 + if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info, certificateToUse, caNone)) 1.567 + client_error = true; 1.568 + else if (!ConnectSocket(other_sock, &remote_addr, connect_timeout)) 1.569 + client_error = true; 1.570 + else 1.571 + numberOfSockets = 2; 1.572 + } 1.573 + 1.574 + PRPollDesc sockets[2] = 1.575 + { 1.576 + {ci->client_sock, PR_POLL_READ, 0}, 1.577 + {other_sock, PR_POLL_READ, 0} 1.578 + }; 1.579 + bool socketErrorState[2] = {false, false}; 1.580 + 1.581 + while (!((client_error||client_done) && buffers[0].empty() && buffers[1].empty())) 1.582 + { 1.583 + sockets[0].in_flags |= PR_POLL_EXCEPT; 1.584 + sockets[1].in_flags |= PR_POLL_EXCEPT; 1.585 + LOG_DEBUG(("SSLTUNNEL(%p)): polling flags csock(0)=%c%c, ssock(1)=%c%c\n", 1.586 + static_cast<void*>(data), 1.587 + sockets[0].in_flags & PR_POLL_READ ? 'R' : '-', 1.588 + sockets[0].in_flags & PR_POLL_WRITE ? 'W' : '-', 1.589 + sockets[1].in_flags & PR_POLL_READ ? 'R' : '-', 1.590 + sockets[1].in_flags & PR_POLL_WRITE ? 'W' : '-')); 1.591 + int32_t pollStatus = PR_Poll(sockets, numberOfSockets, PR_MillisecondsToInterval(1000)); 1.592 + if (pollStatus < 0) 1.593 + { 1.594 + LOG_DEBUG(("SSLTUNNEL(%p)): pollStatus=%d, exiting\n", 1.595 + static_cast<void*>(data), pollStatus)); 1.596 + client_error = true; 1.597 + break; 1.598 + } 1.599 + 1.600 + if (pollStatus == 0) 1.601 + { 1.602 + // timeout 1.603 + LOG_DEBUG(("SSLTUNNEL(%p)): poll timeout, looping\n", 1.604 + static_cast<void*>(data))); 1.605 + continue; 1.606 + } 1.607 + 1.608 + for (int32_t s = 0; s < numberOfSockets; ++s) 1.609 + { 1.610 + int32_t s2 = s == 1 ? 0 : 1; 1.611 + int16_t out_flags = sockets[s].out_flags; 1.612 + int16_t &in_flags = sockets[s].in_flags; 1.613 + int16_t &in_flags2 = sockets[s2].in_flags; 1.614 + sockets[s].out_flags = 0; 1.615 + 1.616 + LOG_BEGIN_BLOCK(); 1.617 + LOG_DEBUG(("SSLTUNNEL(%p)): %csock(%d)=%p out_flags=%d", 1.618 + static_cast<void*>(data), 1.619 + s == 0 ? 'c' : 's', 1.620 + s, 1.621 + static_cast<void*>(sockets[s].fd), 1.622 + out_flags)); 1.623 + if (out_flags & (PR_POLL_EXCEPT | PR_POLL_ERR | PR_POLL_HUP)) 1.624 + { 1.625 + LOG_DEBUG((" :exception\n")); 1.626 + client_error = true; 1.627 + socketErrorState[s] = true; 1.628 + // We got a fatal error state on the socket. Clear the output buffer 1.629 + // for this socket to break the main loop, we will never more be able 1.630 + // to send those data anyway. 1.631 + buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer; 1.632 + continue; 1.633 + } // PR_POLL_EXCEPT, PR_POLL_ERR, PR_POLL_HUP handling 1.634 + 1.635 + if (out_flags & PR_POLL_READ && !buffers[s].areafree()) 1.636 + { 1.637 + LOG_DEBUG((" no place in read buffer but got read flag, dropping it now!")); 1.638 + in_flags &= ~PR_POLL_READ; 1.639 + } 1.640 + 1.641 + if (out_flags & PR_POLL_READ && buffers[s].areafree()) 1.642 + { 1.643 + LOG_DEBUG((" :reading")); 1.644 + int32_t bytesRead = PR_Recv(sockets[s].fd, buffers[s].buffertail, 1.645 + buffers[s].areafree(), 0, PR_INTERVAL_NO_TIMEOUT); 1.646 + 1.647 + if (bytesRead == 0) 1.648 + { 1.649 + LOG_DEBUG((" socket gracefully closed")); 1.650 + client_done = true; 1.651 + in_flags &= ~PR_POLL_READ; 1.652 + } 1.653 + else if (bytesRead < 0) 1.654 + { 1.655 + if (PR_GetError() != PR_WOULD_BLOCK_ERROR) 1.656 + { 1.657 + LOG_DEBUG((" error=%d", PR_GetError())); 1.658 + // We are in error state, indicate that the connection was 1.659 + // not closed gracefully 1.660 + client_error = true; 1.661 + socketErrorState[s] = true; 1.662 + // Wipe out our send buffer, we cannot send it anyway. 1.663 + buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer; 1.664 + } 1.665 + else 1.666 + LOG_DEBUG((" would block")); 1.667 + } 1.668 + else 1.669 + { 1.670 + // If the other socket is in error state (unable to send/receive) 1.671 + // throw this data away and continue loop 1.672 + if (socketErrorState[s2]) 1.673 + { 1.674 + LOG_DEBUG((" have read but other socket is in error state\n")); 1.675 + continue; 1.676 + } 1.677 + 1.678 + buffers[s].buffertail += bytesRead; 1.679 + LOG_DEBUG((", read %d bytes", bytesRead)); 1.680 + 1.681 + // We have to accept and handle the initial CONNECT request here 1.682 + int32_t response; 1.683 + if (!connect_accepted && ReadConnectRequest(ci->server_info, buffers[s], 1.684 + &response, certificateToUse, &clientAuth, fullHost, locationHeader)) 1.685 + { 1.686 + // Mark this as a proxy-only connection (no SSL) if the CONNECT 1.687 + // request didn't come for port 443 or from any of the server's 1.688 + // cert or clientauth hostnames. 1.689 + if (fullHost.find(":443") == string::npos) 1.690 + { 1.691 + server_match_t match; 1.692 + match.fullHost = fullHost; 1.693 + match.matched = false; 1.694 + PL_HashTableEnumerateEntries(ci->server_info->host_cert_table, 1.695 + match_hostname, 1.696 + &match); 1.697 + PL_HashTableEnumerateEntries(ci->server_info->host_clientauth_table, 1.698 + match_hostname, 1.699 + &match); 1.700 + ci->http_proxy_only = !match.matched; 1.701 + } 1.702 + else 1.703 + { 1.704 + ci->http_proxy_only = false; 1.705 + } 1.706 + 1.707 + // Clean the request as it would be read 1.708 + buffers[s].bufferhead = buffers[s].buffertail = buffers[s].buffer; 1.709 + in_flags |= PR_POLL_WRITE; 1.710 + connect_accepted = true; 1.711 + 1.712 + // Store response to the oposite buffer 1.713 + if (response == 200) 1.714 + { 1.715 + LOG_DEBUG((" accepted CONNECT request, connected to the server, sending OK to the client\n")); 1.716 + strcpy(buffers[s2].buffer, "HTTP/1.1 200 Connected\r\nConnection: keep-alive\r\n\r\n"); 1.717 + } 1.718 + else if (response == 302) 1.719 + { 1.720 + LOG_DEBUG((" accepted CONNECT request with redirection, " 1.721 + "sending location and 302 to the client\n")); 1.722 + client_done = true; 1.723 + sprintf(buffers[s2].buffer, 1.724 + "HTTP/1.1 302 Moved\r\n" 1.725 + "Location: https://%s/\r\n" 1.726 + "Connection: close\r\n\r\n", 1.727 + locationHeader.c_str()); 1.728 + } 1.729 + else 1.730 + { 1.731 + LOG_ERRORD((" could not read the connect request, closing connection with %d", response)); 1.732 + client_done = true; 1.733 + sprintf(buffers[s2].buffer, "HTTP/1.1 %d ERROR\r\nConnection: close\r\n\r\n", response); 1.734 + 1.735 + break; 1.736 + } 1.737 + 1.738 + buffers[s2].buffertail = buffers[s2].buffer + strlen(buffers[s2].buffer); 1.739 + 1.740 + // Send the response to the client socket 1.741 + break; 1.742 + } // end of CONNECT handling 1.743 + 1.744 + if (!buffers[s].areafree()) 1.745 + { 1.746 + // Do not poll for read when the buffer is full 1.747 + LOG_DEBUG((" no place in our read buffer, stop reading")); 1.748 + in_flags &= ~PR_POLL_READ; 1.749 + } 1.750 + 1.751 + if (ssl_updated) 1.752 + { 1.753 + if (s == 0 && expect_request_start) 1.754 + { 1.755 + if (!strstr(buffers[s].bufferhead, "\r\n\r\n")) 1.756 + { 1.757 + // We haven't received the complete header yet, so wait. 1.758 + continue; 1.759 + } 1.760 + else 1.761 + { 1.762 + ci->iswebsocket = AdjustWebSocketHost(buffers[s], ci); 1.763 + expect_request_start = !(ci->iswebsocket || 1.764 + AdjustRequestURI(buffers[s], &fullHost)); 1.765 + PRNetAddr* addr = &remote_addr; 1.766 + if (ci->iswebsocket && websocket_server.inet.port) 1.767 + addr = &websocket_server; 1.768 + if (!ConnectSocket(other_sock, addr, connect_timeout)) 1.769 + { 1.770 + LOG_ERRORD((" could not open connection to the real server\n")); 1.771 + client_error = true; 1.772 + break; 1.773 + } 1.774 + LOG_DEBUG(("\n connected to remote server\n")); 1.775 + numberOfSockets = 2; 1.776 + } 1.777 + } 1.778 + else if (s == 1 && ci->iswebsocket) 1.779 + { 1.780 + if (!AdjustWebSocketLocation(buffers[s], ci)) 1.781 + continue; 1.782 + } 1.783 + 1.784 + in_flags2 |= PR_POLL_WRITE; 1.785 + LOG_DEBUG((" telling the other socket to write")); 1.786 + } 1.787 + else 1.788 + LOG_DEBUG((" we have something for the other socket to write, but ssl has not been administered on it")); 1.789 + } 1.790 + } // PR_POLL_READ handling 1.791 + 1.792 + if (out_flags & PR_POLL_WRITE) 1.793 + { 1.794 + LOG_DEBUG((" :writing")); 1.795 + int32_t bytesWrite = PR_Send(sockets[s].fd, buffers[s2].bufferhead, 1.796 + buffers[s2].present(), 0, PR_INTERVAL_NO_TIMEOUT); 1.797 + 1.798 + if (bytesWrite < 0) 1.799 + { 1.800 + if (PR_GetError() != PR_WOULD_BLOCK_ERROR) { 1.801 + LOG_DEBUG((" error=%d", PR_GetError())); 1.802 + client_error = true; 1.803 + socketErrorState[s] = true; 1.804 + // We got a fatal error while writting the buffer. Clear it to break 1.805 + // the main loop, we will never more be able to send it. 1.806 + buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer; 1.807 + } 1.808 + else 1.809 + LOG_DEBUG((" would block")); 1.810 + } 1.811 + else 1.812 + { 1.813 + LOG_DEBUG((", written %d bytes", bytesWrite)); 1.814 + buffers[s2].buffertail[1] = '\0'; 1.815 + LOG_DEBUG((" dump:\n%.*s\n", bytesWrite, buffers[s2].bufferhead)); 1.816 + 1.817 + buffers[s2].bufferhead += bytesWrite; 1.818 + if (buffers[s2].present()) 1.819 + { 1.820 + LOG_DEBUG((" still have to write %d bytes", (int)buffers[s2].present())); 1.821 + in_flags |= PR_POLL_WRITE; 1.822 + } 1.823 + else 1.824 + { 1.825 + if (!ssl_updated) 1.826 + { 1.827 + LOG_DEBUG((" proxy response sent to the client")); 1.828 + // Proxy response has just been writen, update to ssl 1.829 + ssl_updated = true; 1.830 + if (ci->http_proxy_only) 1.831 + { 1.832 + LOG_DEBUG((" not updating to SSL based on http_proxy_only for this socket")); 1.833 + } 1.834 + else if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info, 1.835 + certificateToUse, clientAuth)) 1.836 + { 1.837 + LOG_ERRORD((" failed to config server socket\n")); 1.838 + client_error = true; 1.839 + break; 1.840 + } 1.841 + else 1.842 + { 1.843 + LOG_DEBUG((" client socket updated to SSL")); 1.844 + } 1.845 + } // sslUpdate 1.846 + 1.847 + LOG_DEBUG((" dropping our write flag and setting other socket read flag")); 1.848 + in_flags &= ~PR_POLL_WRITE; 1.849 + in_flags2 |= PR_POLL_READ; 1.850 + buffers[s2].compact(); 1.851 + } 1.852 + } 1.853 + } // PR_POLL_WRITE handling 1.854 + LOG_END_BLOCK(); // end the log 1.855 + } // for... 1.856 + } // while, poll 1.857 + } 1.858 + else 1.859 + client_error = true; 1.860 + 1.861 + LOG_DEBUG(("SSLTUNNEL(%p)): exiting root function for csock=%p, ssock=%p\n", 1.862 + static_cast<void*>(data), 1.863 + static_cast<void*>(ci->client_sock), 1.864 + static_cast<void*>(other_sock))); 1.865 + if (!client_error) 1.866 + PR_Shutdown(ci->client_sock, PR_SHUTDOWN_SEND); 1.867 + PR_Close(ci->client_sock); 1.868 + 1.869 + delete ci; 1.870 +} 1.871 + 1.872 +/* 1.873 + * Start listening for SSL connections on a specified port, handing 1.874 + * them off to client threads after accepting the connection. 1.875 + * The data parameter is a server_info_t*, owned by the calling 1.876 + * function. 1.877 + */ 1.878 +void StartServer(void* data) 1.879 +{ 1.880 + server_info_t* si = static_cast<server_info_t*>(data); 1.881 + 1.882 + //TODO: select ciphers? 1.883 + ScopedPRFileDesc listen_socket(PR_NewTCPSocket()); 1.884 + if (!listen_socket) { 1.885 + LOG_ERROR(("failed to create socket\n")); 1.886 + SignalShutdown(); 1.887 + return; 1.888 + } 1.889 + 1.890 + // In case the socket is still open in the TIME_WAIT state from a previous 1.891 + // instance of ssltunnel we ask to reuse the port. 1.892 + PRSocketOptionData socket_option; 1.893 + socket_option.option = PR_SockOpt_Reuseaddr; 1.894 + socket_option.value.reuse_addr = true; 1.895 + PR_SetSocketOption(listen_socket, &socket_option); 1.896 + 1.897 + PRNetAddr server_addr; 1.898 + PR_InitializeNetAddr(PR_IpAddrAny, si->listen_port, &server_addr); 1.899 + if (PR_Bind(listen_socket, &server_addr) != PR_SUCCESS) { 1.900 + LOG_ERROR(("failed to bind socket\n")); 1.901 + SignalShutdown(); 1.902 + return; 1.903 + } 1.904 + 1.905 + if (PR_Listen(listen_socket, 1) != PR_SUCCESS) { 1.906 + LOG_ERROR(("failed to listen on socket\n")); 1.907 + SignalShutdown(); 1.908 + return; 1.909 + } 1.910 + 1.911 + LOG_INFO(("Server listening on port %d with cert %s\n", si->listen_port, 1.912 + si->cert_nickname.c_str())); 1.913 + 1.914 + while (!shutdown_server) { 1.915 + connection_info_t* ci = new connection_info_t(); 1.916 + ci->server_info = si; 1.917 + ci->http_proxy_only = do_http_proxy; 1.918 + // block waiting for connections 1.919 + ci->client_sock = PR_Accept(listen_socket, &ci->client_addr, 1.920 + PR_INTERVAL_NO_TIMEOUT); 1.921 + 1.922 + PRSocketOptionData option; 1.923 + option.option = PR_SockOpt_Nonblocking; 1.924 + option.value.non_blocking = true; 1.925 + PR_SetSocketOption(ci->client_sock, &option); 1.926 + 1.927 + if (ci->client_sock) 1.928 + // Not actually using this PRJob*... 1.929 + //PRJob* job = 1.930 + PR_QueueJob(threads, HandleConnection, ci, true); 1.931 + else 1.932 + delete ci; 1.933 + } 1.934 +} 1.935 + 1.936 +// bogus password func, just don't use passwords. :-P 1.937 +char* password_func(PK11SlotInfo* slot, PRBool retry, void* arg) 1.938 +{ 1.939 + if (retry) 1.940 + return nullptr; 1.941 + 1.942 + return PL_strdup(""); 1.943 +} 1.944 + 1.945 +server_info_t* findServerInfo(int portnumber) 1.946 +{ 1.947 + for (vector<server_info_t>::iterator it = servers.begin(); 1.948 + it != servers.end(); it++) 1.949 + { 1.950 + if (it->listen_port == portnumber) 1.951 + return &(*it); 1.952 + } 1.953 + 1.954 + return nullptr; 1.955 +} 1.956 + 1.957 +int processConfigLine(char* configLine) 1.958 +{ 1.959 + if (*configLine == 0 || *configLine == '#') 1.960 + return 0; 1.961 + 1.962 + char* _caret; 1.963 + char* keyword = strtok2(configLine, ":", &_caret); 1.964 + 1.965 + // Configure usage of http/ssl tunneling proxy behavior 1.966 + if (!strcmp(keyword, "httpproxy")) 1.967 + { 1.968 + char* value = strtok2(_caret, ":", &_caret); 1.969 + if (!strcmp(value, "1")) 1.970 + do_http_proxy = true; 1.971 + 1.972 + return 0; 1.973 + } 1.974 + 1.975 + if (!strcmp(keyword, "websocketserver")) 1.976 + { 1.977 + char* ipstring = strtok2(_caret, ":", &_caret); 1.978 + if (PR_StringToNetAddr(ipstring, &websocket_server) != PR_SUCCESS) { 1.979 + LOG_ERROR(("Invalid IP address in proxy config: %s\n", ipstring)); 1.980 + return 1; 1.981 + } 1.982 + char* remoteport = strtok2(_caret, ":", &_caret); 1.983 + int port = atoi(remoteport); 1.984 + if (port <= 0) { 1.985 + LOG_ERROR(("Invalid remote port in proxy config: %s\n", remoteport)); 1.986 + return 1; 1.987 + } 1.988 + websocket_server.inet.port = PR_htons(port); 1.989 + return 0; 1.990 + } 1.991 + 1.992 + // Configure the forward address of the target server 1.993 + if (!strcmp(keyword, "forward")) 1.994 + { 1.995 + char* ipstring = strtok2(_caret, ":", &_caret); 1.996 + if (PR_StringToNetAddr(ipstring, &remote_addr) != PR_SUCCESS) { 1.997 + LOG_ERROR(("Invalid remote IP address: %s\n", ipstring)); 1.998 + return 1; 1.999 + } 1.1000 + char* serverportstring = strtok2(_caret, ":", &_caret); 1.1001 + int port = atoi(serverportstring); 1.1002 + if (port <= 0) { 1.1003 + LOG_ERROR(("Invalid remote port: %s\n", serverportstring)); 1.1004 + return 1; 1.1005 + } 1.1006 + remote_addr.inet.port = PR_htons(port); 1.1007 + 1.1008 + return 0; 1.1009 + } 1.1010 + 1.1011 + // Configure all listen sockets and port+certificate bindings 1.1012 + if (!strcmp(keyword, "listen")) 1.1013 + { 1.1014 + char* hostname = strtok2(_caret, ":", &_caret); 1.1015 + char* hostportstring = nullptr; 1.1016 + if (strcmp(hostname, "*")) 1.1017 + { 1.1018 + any_host_spec_config = true; 1.1019 + hostportstring = strtok2(_caret, ":", &_caret); 1.1020 + } 1.1021 + 1.1022 + char* serverportstring = strtok2(_caret, ":", &_caret); 1.1023 + char* certnick = strtok2(_caret, ":", &_caret); 1.1024 + 1.1025 + int port = atoi(serverportstring); 1.1026 + if (port <= 0) { 1.1027 + LOG_ERROR(("Invalid port specified: %s\n", serverportstring)); 1.1028 + return 1; 1.1029 + } 1.1030 + 1.1031 + if (server_info_t* existingServer = findServerInfo(port)) 1.1032 + { 1.1033 + char *certnick_copy = new char[strlen(certnick)+1]; 1.1034 + char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2]; 1.1035 + 1.1036 + strcpy(hostname_copy, hostname); 1.1037 + strcat(hostname_copy, ":"); 1.1038 + strcat(hostname_copy, hostportstring); 1.1039 + strcpy(certnick_copy, certnick); 1.1040 + 1.1041 + PLHashEntry* entry = PL_HashTableAdd(existingServer->host_cert_table, hostname_copy, certnick_copy); 1.1042 + if (!entry) { 1.1043 + LOG_ERROR(("Out of memory")); 1.1044 + return 1; 1.1045 + } 1.1046 + } 1.1047 + else 1.1048 + { 1.1049 + server_info_t server; 1.1050 + server.cert_nickname = certnick; 1.1051 + server.listen_port = port; 1.1052 + server.host_cert_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings, 1.1053 + PL_CompareStrings, nullptr, nullptr); 1.1054 + if (!server.host_cert_table) 1.1055 + { 1.1056 + LOG_ERROR(("Internal, could not create hash table\n")); 1.1057 + return 1; 1.1058 + } 1.1059 + server.host_clientauth_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings, 1.1060 + ClientAuthValueComparator, nullptr, nullptr); 1.1061 + if (!server.host_clientauth_table) 1.1062 + { 1.1063 + LOG_ERROR(("Internal, could not create hash table\n")); 1.1064 + return 1; 1.1065 + } 1.1066 + server.host_redir_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings, 1.1067 + PL_CompareStrings, nullptr, nullptr); 1.1068 + if (!server.host_redir_table) 1.1069 + { 1.1070 + LOG_ERROR(("Internal, could not create hash table\n")); 1.1071 + return 1; 1.1072 + } 1.1073 + servers.push_back(server); 1.1074 + } 1.1075 + 1.1076 + return 0; 1.1077 + } 1.1078 + 1.1079 + if (!strcmp(keyword, "clientauth")) 1.1080 + { 1.1081 + char* hostname = strtok2(_caret, ":", &_caret); 1.1082 + char* hostportstring = strtok2(_caret, ":", &_caret); 1.1083 + char* serverportstring = strtok2(_caret, ":", &_caret); 1.1084 + 1.1085 + int port = atoi(serverportstring); 1.1086 + if (port <= 0) { 1.1087 + LOG_ERROR(("Invalid port specified: %s\n", serverportstring)); 1.1088 + return 1; 1.1089 + } 1.1090 + 1.1091 + if (server_info_t* existingServer = findServerInfo(port)) 1.1092 + { 1.1093 + char* authoptionstring = strtok2(_caret, ":", &_caret); 1.1094 + client_auth_option* authoption = new client_auth_option; 1.1095 + if (!authoption) { 1.1096 + LOG_ERROR(("Out of memory")); 1.1097 + return 1; 1.1098 + } 1.1099 + 1.1100 + if (!strcmp(authoptionstring, "require")) 1.1101 + *authoption = caRequire; 1.1102 + else if (!strcmp(authoptionstring, "request")) 1.1103 + *authoption = caRequest; 1.1104 + else if (!strcmp(authoptionstring, "none")) 1.1105 + *authoption = caNone; 1.1106 + else 1.1107 + { 1.1108 + LOG_ERROR(("Incorrect client auth option modifier for host '%s'", hostname)); 1.1109 + return 1; 1.1110 + } 1.1111 + 1.1112 + any_host_spec_config = true; 1.1113 + 1.1114 + char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2]; 1.1115 + if (!hostname_copy) { 1.1116 + LOG_ERROR(("Out of memory")); 1.1117 + return 1; 1.1118 + } 1.1119 + 1.1120 + strcpy(hostname_copy, hostname); 1.1121 + strcat(hostname_copy, ":"); 1.1122 + strcat(hostname_copy, hostportstring); 1.1123 + 1.1124 + PLHashEntry* entry = PL_HashTableAdd(existingServer->host_clientauth_table, hostname_copy, authoption); 1.1125 + if (!entry) { 1.1126 + LOG_ERROR(("Out of memory")); 1.1127 + return 1; 1.1128 + } 1.1129 + } 1.1130 + else 1.1131 + { 1.1132 + LOG_ERROR(("Server on port %d for client authentication option is not defined, use 'listen' option first", port)); 1.1133 + return 1; 1.1134 + } 1.1135 + 1.1136 + return 0; 1.1137 + } 1.1138 + 1.1139 + if (!strcmp(keyword, "redirhost")) 1.1140 + { 1.1141 + char* hostname = strtok2(_caret, ":", &_caret); 1.1142 + char* hostportstring = strtok2(_caret, ":", &_caret); 1.1143 + char* serverportstring = strtok2(_caret, ":", &_caret); 1.1144 + 1.1145 + int port = atoi(serverportstring); 1.1146 + if (port <= 0) { 1.1147 + LOG_ERROR(("Invalid port specified: %s\n", serverportstring)); 1.1148 + return 1; 1.1149 + } 1.1150 + 1.1151 + if (server_info_t* existingServer = findServerInfo(port)) 1.1152 + { 1.1153 + char* redirhoststring = strtok2(_caret, ":", &_caret); 1.1154 + 1.1155 + any_host_spec_config = true; 1.1156 + 1.1157 + char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2]; 1.1158 + if (!hostname_copy) { 1.1159 + LOG_ERROR(("Out of memory")); 1.1160 + return 1; 1.1161 + } 1.1162 + 1.1163 + strcpy(hostname_copy, hostname); 1.1164 + strcat(hostname_copy, ":"); 1.1165 + strcat(hostname_copy, hostportstring); 1.1166 + 1.1167 + char *redir_copy = new char[strlen(redirhoststring)+1]; 1.1168 + strcpy(redir_copy, redirhoststring); 1.1169 + PLHashEntry* entry = PL_HashTableAdd(existingServer->host_redir_table, hostname_copy, redir_copy); 1.1170 + if (!entry) { 1.1171 + LOG_ERROR(("Out of memory")); 1.1172 + return 1; 1.1173 + } 1.1174 + } 1.1175 + else 1.1176 + { 1.1177 + LOG_ERROR(("Server on port %d for redirhost option is not defined, use 'listen' option first", port)); 1.1178 + return 1; 1.1179 + } 1.1180 + 1.1181 + return 0; 1.1182 + } 1.1183 + 1.1184 + // Configure the NSS certificate database directory 1.1185 + if (!strcmp(keyword, "certdbdir")) 1.1186 + { 1.1187 + nssconfigdir = strtok2(_caret, "\n", &_caret); 1.1188 + return 0; 1.1189 + } 1.1190 + 1.1191 + LOG_ERROR(("Error: keyword \"%s\" unexpected\n", keyword)); 1.1192 + return 1; 1.1193 +} 1.1194 + 1.1195 +int parseConfigFile(const char* filePath) 1.1196 +{ 1.1197 + FILE* f = fopen(filePath, "r"); 1.1198 + if (!f) 1.1199 + return 1; 1.1200 + 1.1201 + char buffer[1024], *b = buffer; 1.1202 + while (!feof(f)) 1.1203 + { 1.1204 + char c; 1.1205 + fscanf(f, "%c", &c); 1.1206 + switch (c) 1.1207 + { 1.1208 + case '\n': 1.1209 + *b++ = 0; 1.1210 + if (processConfigLine(buffer)) 1.1211 + return 1; 1.1212 + b = buffer; 1.1213 + case '\r': 1.1214 + continue; 1.1215 + default: 1.1216 + *b++ = c; 1.1217 + } 1.1218 + } 1.1219 + 1.1220 + fclose(f); 1.1221 + 1.1222 + // Check mandatory items 1.1223 + if (nssconfigdir.empty()) 1.1224 + { 1.1225 + LOG_ERROR(("Error: missing path to NSS certification database\n,use certdbdir:<path> in the config file\n")); 1.1226 + return 1; 1.1227 + } 1.1228 + 1.1229 + if (any_host_spec_config && !do_http_proxy) 1.1230 + { 1.1231 + LOG_ERROR(("Warning: any host-specific configurations are ignored, add httpproxy:1 to allow them\n")); 1.1232 + } 1.1233 + 1.1234 + return 0; 1.1235 +} 1.1236 + 1.1237 +int freeHostCertHashItems(PLHashEntry *he, int i, void *arg) 1.1238 +{ 1.1239 + delete [] (char*)he->key; 1.1240 + delete [] (char*)he->value; 1.1241 + return HT_ENUMERATE_REMOVE; 1.1242 +} 1.1243 + 1.1244 +int freeHostRedirHashItems(PLHashEntry *he, int i, void *arg) 1.1245 +{ 1.1246 + delete [] (char*)he->key; 1.1247 + delete [] (char*)he->value; 1.1248 + return HT_ENUMERATE_REMOVE; 1.1249 +} 1.1250 + 1.1251 +int freeClientAuthHashItems(PLHashEntry *he, int i, void *arg) 1.1252 +{ 1.1253 + delete [] (char*)he->key; 1.1254 + delete (client_auth_option*)he->value; 1.1255 + return HT_ENUMERATE_REMOVE; 1.1256 +} 1.1257 + 1.1258 +int main(int argc, char** argv) 1.1259 +{ 1.1260 + const char* configFilePath; 1.1261 + 1.1262 + const char* logLevelEnv = PR_GetEnv("SSLTUNNEL_LOG_LEVEL"); 1.1263 + gLogLevel = logLevelEnv ? (LogLevel)atoi(logLevelEnv) : LEVEL_INFO; 1.1264 + 1.1265 + if (argc == 1) 1.1266 + configFilePath = "ssltunnel.cfg"; 1.1267 + else 1.1268 + configFilePath = argv[1]; 1.1269 + 1.1270 + memset(&websocket_server, 0, sizeof(PRNetAddr)); 1.1271 + 1.1272 + if (parseConfigFile(configFilePath)) { 1.1273 + LOG_ERROR(("Error: config file \"%s\" missing or formating incorrect\n" 1.1274 + "Specify path to the config file as parameter to ssltunnel or \n" 1.1275 + "create ssltunnel.cfg in the working directory.\n\n" 1.1276 + "Example format of the config file:\n\n" 1.1277 + " # Enable http/ssl tunneling proxy-like behavior.\n" 1.1278 + " # If not specified ssltunnel simply does direct forward.\n" 1.1279 + " httpproxy:1\n\n" 1.1280 + " # Specify path to the certification database used.\n" 1.1281 + " certdbdir:/path/to/certdb\n\n" 1.1282 + " # Forward/proxy all requests in raw to 127.0.0.1:8888.\n" 1.1283 + " forward:127.0.0.1:8888\n\n" 1.1284 + " # Accept connections on port 4443 or 5678 resp. and authenticate\n" 1.1285 + " # to any host ('*') using the 'server cert' or 'server cert 2' resp.\n" 1.1286 + " listen:*:4443:server cert\n" 1.1287 + " listen:*:5678:server cert 2\n\n" 1.1288 + " # Accept connections on port 4443 and authenticate using\n" 1.1289 + " # 'a different cert' when target host is 'my.host.name:443'.\n" 1.1290 + " # This only works in httpproxy mode and has higher priority\n" 1.1291 + " # than the previous option.\n" 1.1292 + " listen:my.host.name:443:4443:a different cert\n\n" 1.1293 + " # To make a specific host require or just request a client certificate\n" 1.1294 + " # to authenticate use the following options. This can only be used\n" 1.1295 + " # in httpproxy mode and only after the 'listen' option has been\n" 1.1296 + " # specified. You also have to specify the tunnel listen port.\n" 1.1297 + " clientauth:requesting-client-cert.host.com:443:4443:request\n" 1.1298 + " clientauth:requiring-client-cert.host.com:443:4443:require\n" 1.1299 + " # Proxy WebSocket traffic to the server at 127.0.0.1:9999,\n" 1.1300 + " # instead of the server specified in the 'forward' option.\n" 1.1301 + " websocketserver:127.0.0.1:9999\n", 1.1302 + configFilePath)); 1.1303 + return 1; 1.1304 + } 1.1305 + 1.1306 + // create a thread pool to handle connections 1.1307 + threads = PR_CreateThreadPool(INITIAL_THREADS * servers.size(), 1.1308 + MAX_THREADS * servers.size(), 1.1309 + DEFAULT_STACKSIZE); 1.1310 + if (!threads) { 1.1311 + LOG_ERROR(("Failed to create thread pool\n")); 1.1312 + return 1; 1.1313 + } 1.1314 + 1.1315 + shutdown_lock = PR_NewLock(); 1.1316 + if (!shutdown_lock) { 1.1317 + LOG_ERROR(("Failed to create lock\n")); 1.1318 + PR_ShutdownThreadPool(threads); 1.1319 + return 1; 1.1320 + } 1.1321 + shutdown_condvar = PR_NewCondVar(shutdown_lock); 1.1322 + if (!shutdown_condvar) { 1.1323 + LOG_ERROR(("Failed to create condvar\n")); 1.1324 + PR_ShutdownThreadPool(threads); 1.1325 + PR_DestroyLock(shutdown_lock); 1.1326 + return 1; 1.1327 + } 1.1328 + 1.1329 + PK11_SetPasswordFunc(password_func); 1.1330 + 1.1331 + // Initialize NSS 1.1332 + if (NSS_Init(nssconfigdir.c_str()) != SECSuccess) { 1.1333 + int32_t errorlen = PR_GetErrorTextLength(); 1.1334 + char* err = new char[errorlen+1]; 1.1335 + PR_GetErrorText(err); 1.1336 + LOG_ERROR(("Failed to init NSS: %s", err)); 1.1337 + delete[] err; 1.1338 + PR_ShutdownThreadPool(threads); 1.1339 + PR_DestroyCondVar(shutdown_condvar); 1.1340 + PR_DestroyLock(shutdown_lock); 1.1341 + return 1; 1.1342 + } 1.1343 + 1.1344 + if (NSS_SetDomesticPolicy() != SECSuccess) { 1.1345 + LOG_ERROR(("NSS_SetDomesticPolicy failed\n")); 1.1346 + PR_ShutdownThreadPool(threads); 1.1347 + PR_DestroyCondVar(shutdown_condvar); 1.1348 + PR_DestroyLock(shutdown_lock); 1.1349 + NSS_Shutdown(); 1.1350 + return 1; 1.1351 + } 1.1352 + 1.1353 + // these values should make NSS use the defaults 1.1354 + if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) { 1.1355 + LOG_ERROR(("SSL_ConfigServerSessionIDCache failed\n")); 1.1356 + PR_ShutdownThreadPool(threads); 1.1357 + PR_DestroyCondVar(shutdown_condvar); 1.1358 + PR_DestroyLock(shutdown_lock); 1.1359 + NSS_Shutdown(); 1.1360 + return 1; 1.1361 + } 1.1362 + 1.1363 + for (vector<server_info_t>::iterator it = servers.begin(); 1.1364 + it != servers.end(); it++) { 1.1365 + // Not actually using this PRJob*... 1.1366 + // PRJob* server_job = 1.1367 + PR_QueueJob(threads, StartServer, &(*it), true); 1.1368 + } 1.1369 + // now wait for someone to tell us to quit 1.1370 + PR_Lock(shutdown_lock); 1.1371 + PR_WaitCondVar(shutdown_condvar, PR_INTERVAL_NO_TIMEOUT); 1.1372 + PR_Unlock(shutdown_lock); 1.1373 + shutdown_server = true; 1.1374 + LOG_INFO(("Shutting down...\n")); 1.1375 + // cleanup 1.1376 + PR_ShutdownThreadPool(threads); 1.1377 + PR_JoinThreadPool(threads); 1.1378 + PR_DestroyCondVar(shutdown_condvar); 1.1379 + PR_DestroyLock(shutdown_lock); 1.1380 + if (NSS_Shutdown() == SECFailure) { 1.1381 + LOG_DEBUG(("Leaked NSS objects!\n")); 1.1382 + } 1.1383 + 1.1384 + for (vector<server_info_t>::iterator it = servers.begin(); 1.1385 + it != servers.end(); it++) 1.1386 + { 1.1387 + PL_HashTableEnumerateEntries(it->host_cert_table, freeHostCertHashItems, nullptr); 1.1388 + PL_HashTableEnumerateEntries(it->host_clientauth_table, freeClientAuthHashItems, nullptr); 1.1389 + PL_HashTableEnumerateEntries(it->host_redir_table, freeHostRedirHashItems, nullptr); 1.1390 + PL_HashTableDestroy(it->host_cert_table); 1.1391 + PL_HashTableDestroy(it->host_clientauth_table); 1.1392 + PL_HashTableDestroy(it->host_redir_table); 1.1393 + } 1.1394 + 1.1395 + PR_Cleanup(); 1.1396 + return 0; 1.1397 +}