/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

/*
 * WARNING: DO NOT USE THIS CODE IN PRODUCTION SYSTEMS.  It is highly likely to
 *          be plagued with the usual problems endemic to C (buffer overflows
 *          and the like).  We don't especially care here (but would accept
 *          patches!) because this is only intended for use in our test
 *          harnesses in controlled situations where input is guaranteed not to
 *          be malicious.
 */

#include "ScopedNSSTypes.h"
#include <assert.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <string>
#include <vector>
#include "prinit.h"
#include "prerror.h"
#include "prenv.h"
#include "prnetdb.h"
#include "prtpool.h"
#include "nsAlgorithm.h"
#include "nss.h"
#include "key.h"
#include "ssl.h"
#include "plhash.h"

using namespace mozilla;
using namespace mozilla::psm;
using std::string;
using std::vector;

// 256-bit membership table (one bit per byte value) used by strtok2().
#define IS_DELIM(m, c)          ((m)[(c) >> 3] & (1 << ((c) & 7)))
#define SET_DELIM(m, c)         ((m)[(c) >> 3] |= (1 << ((c) & 7)))
#define DELIM_TABLE_SIZE        32

// You can set the level of logging by env var SSLTUNNEL_LOG_LEVEL=n, where n
// is 0 through 3. The default is 1, INFO level logging.
michael@0: enum LogLevel { michael@0: LEVEL_DEBUG = 0, michael@0: LEVEL_INFO = 1, michael@0: LEVEL_ERROR = 2, michael@0: LEVEL_SILENT = 3 michael@0: } gLogLevel, gLastLogLevel; michael@0: michael@0: #define _LOG_OUTPUT(level, func, params) \ michael@0: PR_BEGIN_MACRO \ michael@0: if (level >= gLogLevel) { \ michael@0: gLastLogLevel = level; \ michael@0: func params;\ michael@0: } \ michael@0: PR_END_MACRO michael@0: michael@0: // The most verbose output michael@0: #define LOG_DEBUG(params) \ michael@0: _LOG_OUTPUT(LEVEL_DEBUG, printf, params) michael@0: michael@0: // Top level informative messages michael@0: #define LOG_INFO(params) \ michael@0: _LOG_OUTPUT(LEVEL_INFO, printf, params) michael@0: michael@0: // Serious errors that must be logged always until completely gag michael@0: #define LOG_ERROR(params) \ michael@0: _LOG_OUTPUT(LEVEL_ERROR, eprintf, params) michael@0: michael@0: // Same as LOG_ERROR, but when logging is set to LEVEL_DEBUG, the message michael@0: // will be put to the stdout instead of stderr to keep continuity with other michael@0: // LOG_DEBUG message output michael@0: #define LOG_ERRORD(params) \ michael@0: PR_BEGIN_MACRO \ michael@0: if (gLogLevel == LEVEL_DEBUG) \ michael@0: _LOG_OUTPUT(LEVEL_ERROR, printf, params); \ michael@0: else \ michael@0: _LOG_OUTPUT(LEVEL_ERROR, eprintf, params); \ michael@0: PR_END_MACRO michael@0: michael@0: // If there is any output written between LOG_BEGIN_BLOCK() and michael@0: // LOG_END_BLOCK() then a new line will be put to the proper output (out/err) michael@0: #define LOG_BEGIN_BLOCK() \ michael@0: gLastLogLevel = LEVEL_SILENT; michael@0: michael@0: #define LOG_END_BLOCK() \ michael@0: PR_BEGIN_MACRO \ michael@0: if (gLastLogLevel == LEVEL_ERROR) \ michael@0: LOG_ERROR(("\n")); \ michael@0: if (gLastLogLevel < LEVEL_ERROR) \ michael@0: _LOG_OUTPUT(gLastLogLevel, printf, ("\n")); \ michael@0: PR_END_MACRO michael@0: michael@0: int eprintf(const char* str, ...) 
michael@0: { michael@0: va_list ap; michael@0: va_start(ap, str); michael@0: int result = vfprintf(stderr, str, ap); michael@0: va_end(ap); michael@0: return result; michael@0: } michael@0: michael@0: // Copied from nsCRT michael@0: char* strtok2(char* string, const char* delims, char* *newStr) michael@0: { michael@0: PR_ASSERT(string); michael@0: michael@0: char delimTable[DELIM_TABLE_SIZE]; michael@0: uint32_t i; michael@0: char* result; michael@0: char* str = string; michael@0: michael@0: for (i = 0; i < DELIM_TABLE_SIZE; i++) michael@0: delimTable[i] = '\0'; michael@0: michael@0: for (i = 0; delims[i]; i++) { michael@0: SET_DELIM(delimTable, static_cast(delims[i])); michael@0: } michael@0: michael@0: // skip to beginning michael@0: while (*str && IS_DELIM(delimTable, static_cast(*str))) { michael@0: str++; michael@0: } michael@0: result = str; michael@0: michael@0: // fix up the end of the token michael@0: while (*str) { michael@0: if (IS_DELIM(delimTable, static_cast(*str))) { michael@0: *str++ = '\0'; michael@0: break; michael@0: } michael@0: str++; michael@0: } michael@0: *newStr = str; michael@0: michael@0: return str == result ? 
nullptr : result; michael@0: } michael@0: michael@0: michael@0: michael@0: enum client_auth_option { michael@0: caNone = 0, michael@0: caRequire = 1, michael@0: caRequest = 2 michael@0: }; michael@0: michael@0: // Structs for passing data into jobs on the thread pool michael@0: typedef struct { michael@0: int32_t listen_port; michael@0: string cert_nickname; michael@0: PLHashTable* host_cert_table; michael@0: PLHashTable* host_clientauth_table; michael@0: PLHashTable* host_redir_table; michael@0: } server_info_t; michael@0: michael@0: typedef struct { michael@0: PRFileDesc* client_sock; michael@0: PRNetAddr client_addr; michael@0: server_info_t* server_info; michael@0: // the original host in the Host: header for this connection is michael@0: // stored here, for proxied connections michael@0: string original_host; michael@0: // true if no SSL should be used for this connection michael@0: bool http_proxy_only; michael@0: // true if this connection is for a WebSocket michael@0: bool iswebsocket; michael@0: } connection_info_t; michael@0: michael@0: typedef struct { michael@0: string fullHost; michael@0: bool matched; michael@0: } server_match_t; michael@0: michael@0: const int32_t BUF_SIZE = 16384; michael@0: const int32_t BUF_MARGIN = 1024; michael@0: const int32_t BUF_TOTAL = BUF_SIZE + BUF_MARGIN; michael@0: michael@0: struct relayBuffer michael@0: { michael@0: char *buffer, *bufferhead, *buffertail, *bufferend; michael@0: michael@0: relayBuffer() michael@0: { michael@0: // Leave 1024 bytes more for request line manipulations michael@0: bufferhead = buffertail = buffer = new char[BUF_TOTAL]; michael@0: bufferend = buffer + BUF_SIZE; michael@0: } michael@0: michael@0: ~relayBuffer() michael@0: { michael@0: delete [] buffer; michael@0: } michael@0: michael@0: void compact() { michael@0: if (buffertail == bufferhead) michael@0: buffertail = bufferhead = buffer; michael@0: } michael@0: michael@0: bool empty() { return bufferhead == buffertail; } michael@0: size_t 
areafree() { return bufferend - buffertail; } michael@0: size_t margin() { return areafree() + BUF_MARGIN; } michael@0: size_t present() { return buffertail - bufferhead; } michael@0: }; michael@0: michael@0: // These numbers are multiplied by the number of listening ports (actual michael@0: // servers running). According the thread pool implementation there is no michael@0: // need to limit the number of threads initially, threads are allocated michael@0: // dynamically and stored in a linked list. Initial number of 2 is chosen michael@0: // to allocate a thread for socket accept and preallocate one for the first michael@0: // connection that is with high probability expected to come. michael@0: const uint32_t INITIAL_THREADS = 2; michael@0: const uint32_t MAX_THREADS = 100; michael@0: const uint32_t DEFAULT_STACKSIZE = (512 * 1024); michael@0: michael@0: // global data michael@0: string nssconfigdir; michael@0: vector servers; michael@0: PRNetAddr remote_addr; michael@0: PRNetAddr websocket_server; michael@0: PRThreadPool* threads = nullptr; michael@0: PRLock* shutdown_lock = nullptr; michael@0: PRCondVar* shutdown_condvar = nullptr; michael@0: // Not really used, unless something fails to start michael@0: bool shutdown_server = false; michael@0: bool do_http_proxy = false; michael@0: bool any_host_spec_config = false; michael@0: michael@0: int ClientAuthValueComparator(const void *v1, const void *v2) michael@0: { michael@0: int a = *static_cast(v1) - michael@0: *static_cast(v2); michael@0: if (a == 0) michael@0: return 0; michael@0: if (a > 0) michael@0: return 1; michael@0: else // (a < 0) michael@0: return -1; michael@0: } michael@0: michael@0: static int match_hostname(PLHashEntry *he, int index, void* arg) michael@0: { michael@0: server_match_t *match = (server_match_t*)arg; michael@0: if (match->fullHost.find((char*)he->key) != string::npos) michael@0: match->matched = true; michael@0: return HT_ENUMERATE_NEXT; michael@0: } michael@0: michael@0: /* 
michael@0: * Signal the main thread that the application should shut down. michael@0: */ michael@0: void SignalShutdown() michael@0: { michael@0: PR_Lock(shutdown_lock); michael@0: PR_NotifyCondVar(shutdown_condvar); michael@0: PR_Unlock(shutdown_lock); michael@0: } michael@0: michael@0: bool ReadConnectRequest(server_info_t* server_info, michael@0: relayBuffer& buffer, int32_t* result, string& certificate, michael@0: client_auth_option* clientauth, string& host, string& location) michael@0: { michael@0: if (buffer.present() < 4) { michael@0: LOG_DEBUG((" !! only %d bytes present in the buffer", (int)buffer.present())); michael@0: return false; michael@0: } michael@0: if (strncmp(buffer.buffertail-4, "\r\n\r\n", 4)) { michael@0: LOG_ERRORD((" !! request is not tailed with CRLFCRLF but with %x %x %x %x", michael@0: *(buffer.buffertail-4), michael@0: *(buffer.buffertail-3), michael@0: *(buffer.buffertail-2), michael@0: *(buffer.buffertail-1))); michael@0: return false; michael@0: } michael@0: michael@0: LOG_DEBUG((" parsing initial connect request, dump:\n%.*s\n", (int)buffer.present(), buffer.bufferhead)); michael@0: michael@0: *result = 400; michael@0: michael@0: char* token; michael@0: char* _caret; michael@0: token = strtok2(buffer.bufferhead, " ", &_caret); michael@0: if (!token) { michael@0: LOG_ERRORD((" no space found")); michael@0: return true; michael@0: } michael@0: if (strcmp(token, "CONNECT")) { michael@0: LOG_ERRORD((" not CONNECT request but %s", token)); michael@0: return true; michael@0: } michael@0: michael@0: token = strtok2(_caret, " ", &_caret); michael@0: void* c = PL_HashTableLookup(server_info->host_cert_table, token); michael@0: if (c) michael@0: certificate = static_cast(c); michael@0: michael@0: host = "https://"; michael@0: host += token; michael@0: michael@0: c = PL_HashTableLookup(server_info->host_clientauth_table, token); michael@0: if (c) michael@0: *clientauth = *static_cast(c); michael@0: else michael@0: *clientauth = caNone; 
michael@0: michael@0: void *redir = PL_HashTableLookup(server_info->host_redir_table, token); michael@0: if (redir) michael@0: location = static_cast(redir); michael@0: michael@0: token = strtok2(_caret, "/", &_caret); michael@0: if (strcmp(token, "HTTP")) { michael@0: LOG_ERRORD((" not tailed with HTTP but with %s", token)); michael@0: return true; michael@0: } michael@0: michael@0: *result = (redir) ? 302 : 200; michael@0: return true; michael@0: } michael@0: michael@0: bool ConfigureSSLServerSocket(PRFileDesc* socket, server_info_t* si, string &certificate, client_auth_option clientAuth) michael@0: { michael@0: const char* certnick = certificate.empty() ? michael@0: si->cert_nickname.c_str() : certificate.c_str(); michael@0: michael@0: ScopedCERTCertificate cert(PK11_FindCertFromNickname(certnick, nullptr)); michael@0: if (!cert) { michael@0: LOG_ERROR(("Failed to find cert %s\n", certnick)); michael@0: return false; michael@0: } michael@0: michael@0: ScopedSECKEYPrivateKey privKey(PK11_FindKeyByAnyCert(cert, nullptr)); michael@0: if (!privKey) { michael@0: LOG_ERROR(("Failed to find private key\n")); michael@0: return false; michael@0: } michael@0: michael@0: PRFileDesc* ssl_socket = SSL_ImportFD(nullptr, socket); michael@0: if (!ssl_socket) { michael@0: LOG_ERROR(("Error importing SSL socket\n")); michael@0: return false; michael@0: } michael@0: michael@0: SSLKEAType certKEA = NSS_FindCertKEAType(cert); michael@0: if (SSL_ConfigSecureServer(ssl_socket, cert, privKey, certKEA) michael@0: != SECSuccess) { michael@0: LOG_ERROR(("Error configuring SSL server socket\n")); michael@0: return false; michael@0: } michael@0: michael@0: SSL_OptionSet(ssl_socket, SSL_SECURITY, true); michael@0: SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_CLIENT, false); michael@0: SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_SERVER, true); michael@0: michael@0: if (clientAuth != caNone) michael@0: { michael@0: SSL_OptionSet(ssl_socket, SSL_REQUEST_CERTIFICATE, true); michael@0: 
SSL_OptionSet(ssl_socket, SSL_REQUIRE_CERTIFICATE, clientAuth == caRequire); michael@0: } michael@0: michael@0: SSL_ResetHandshake(ssl_socket, true); michael@0: michael@0: return true; michael@0: } michael@0: michael@0: /** michael@0: * This function examines the buffer for a Sec-WebSocket-Location: field, michael@0: * and if it's present, it replaces the hostname in that field with the michael@0: * value in the server's original_host field. This function works michael@0: * in the reverse direction as AdjustWebSocketHost(), replacing the real michael@0: * hostname of a response with the potentially fake hostname that is expected michael@0: * by the browser (e.g., mochi.test). michael@0: * michael@0: * @return true if the header was adjusted successfully, or not found, false michael@0: * if the header is present but the url is not, which should indicate michael@0: * that more data needs to be read from the socket michael@0: */ michael@0: bool AdjustWebSocketLocation(relayBuffer& buffer, connection_info_t *ci) michael@0: { michael@0: assert(buffer.margin()); michael@0: buffer.buffertail[1] = '\0'; michael@0: michael@0: char* wsloc = strstr(buffer.bufferhead, "Sec-WebSocket-Location:"); michael@0: if (!wsloc) michael@0: return true; michael@0: // advance pointer to the start of the hostname michael@0: wsloc = strstr(wsloc, "ws://"); michael@0: if (!wsloc) michael@0: return false; michael@0: wsloc += 5; michael@0: // find the end of the hostname michael@0: char* wslocend = strchr(wsloc + 1, '/'); michael@0: if (!wslocend) michael@0: return false; michael@0: char *crlf = strstr(wsloc, "\r\n"); michael@0: if (!crlf) michael@0: return false; michael@0: if (ci->original_host.empty()) michael@0: return true; michael@0: michael@0: int diff = ci->original_host.length() - (wslocend-wsloc); michael@0: if (diff > 0) michael@0: assert(size_t(diff) <= buffer.margin()); michael@0: memmove(wslocend + diff, wslocend, buffer.buffertail - wsloc - diff); michael@0: buffer.buffertail += 
diff; michael@0: michael@0: memcpy(wsloc, ci->original_host.c_str(), ci->original_host.length()); michael@0: return true; michael@0: } michael@0: michael@0: /** michael@0: * This function examines the buffer for a Host: field, and if it's present, michael@0: * it replaces the hostname in that field with the hostname in the server's michael@0: * remote_addr field. This is needed because proxy requests may be coming michael@0: * from mochitest with fake hosts, like mochi.test, and these need to be michael@0: * replaced with the host that the destination server is actually running michael@0: * on. michael@0: */ michael@0: bool AdjustWebSocketHost(relayBuffer& buffer, connection_info_t *ci) michael@0: { michael@0: const char HEADER_UPGRADE[] = "Upgrade:"; michael@0: const char HEADER_HOST[] = "Host:"; michael@0: michael@0: PRNetAddr inet_addr = (websocket_server.inet.port ? websocket_server : michael@0: remote_addr); michael@0: michael@0: assert(buffer.margin()); michael@0: michael@0: // Cannot use strnchr so add a null char at the end. There is always some michael@0: // space left because we preserve a margin. michael@0: buffer.buffertail[1] = '\0'; michael@0: michael@0: // Verify this is a WebSocket header. 
michael@0: char* h1 = strstr(buffer.bufferhead, HEADER_UPGRADE); michael@0: if (!h1) michael@0: return false; michael@0: h1 += strlen(HEADER_UPGRADE); michael@0: h1 += strspn(h1, " \t"); michael@0: char* h2 = strstr(h1, "WebSocket\r\n"); michael@0: if (!h2) h2 = strstr(h1, "websocket\r\n"); michael@0: if (!h2) h2 = strstr(h1, "Websocket\r\n"); michael@0: if (!h2) michael@0: return false; michael@0: michael@0: char* host = strstr(buffer.bufferhead, HEADER_HOST); michael@0: if (!host) michael@0: return false; michael@0: // advance pointer to beginning of hostname michael@0: host += strlen(HEADER_HOST); michael@0: host += strspn(host, " \t"); michael@0: michael@0: char* endhost = strstr(host, "\r\n"); michael@0: if (!endhost) michael@0: return false; michael@0: michael@0: // Save the original host, so we can use it later on responses from the michael@0: // server. michael@0: ci->original_host.assign(host, endhost-host); michael@0: michael@0: char newhost[40]; michael@0: PR_NetAddrToString(&inet_addr, newhost, sizeof(newhost)); michael@0: assert(strlen(newhost) < sizeof(newhost) - 7); michael@0: sprintf(newhost, "%s:%d", newhost, PR_ntohs(inet_addr.inet.port)); michael@0: michael@0: int diff = strlen(newhost) - (endhost-host); michael@0: if (diff > 0) michael@0: assert(size_t(diff) <= buffer.margin()); michael@0: memmove(endhost + diff, endhost, buffer.buffertail - host - diff); michael@0: buffer.buffertail += diff; michael@0: michael@0: memcpy(host, newhost, strlen(newhost)); michael@0: return true; michael@0: } michael@0: michael@0: /** michael@0: * This function prefixes Request-URI path with a full scheme-host-port michael@0: * string. michael@0: */ michael@0: bool AdjustRequestURI(relayBuffer& buffer, string *host) michael@0: { michael@0: assert(buffer.margin()); michael@0: michael@0: // Cannot use strnchr so add a null char at the end. There is always some space left michael@0: // because we preserve a margin. 
michael@0: buffer.buffertail[1] = '\0'; michael@0: LOG_DEBUG((" incoming request to adjust:\n%s\n", buffer.bufferhead)); michael@0: michael@0: char *token, *path; michael@0: path = strchr(buffer.bufferhead, ' ') + 1; michael@0: if (!path) michael@0: return false; michael@0: michael@0: // If the path doesn't start with a slash don't change it, it is probably '*' or a full michael@0: // path already. Return true, we are done with this request adjustment. michael@0: if (*path != '/') michael@0: return true; michael@0: michael@0: token = strchr(path, ' ') + 1; michael@0: if (!token) michael@0: return false; michael@0: michael@0: if (strncmp(token, "HTTP/", 5)) michael@0: return false; michael@0: michael@0: size_t hostlength = host->length(); michael@0: assert(hostlength <= buffer.margin()); michael@0: michael@0: memmove(path + hostlength, path, buffer.buffertail - path); michael@0: memcpy(path, host->c_str(), hostlength); michael@0: buffer.buffertail += hostlength; michael@0: michael@0: return true; michael@0: } michael@0: michael@0: bool ConnectSocket(PRFileDesc *fd, const PRNetAddr *addr, PRIntervalTime timeout) michael@0: { michael@0: PRStatus stat = PR_Connect(fd, addr, timeout); michael@0: if (stat != PR_SUCCESS) michael@0: return false; michael@0: michael@0: PRSocketOptionData option; michael@0: option.option = PR_SockOpt_Nonblocking; michael@0: option.value.non_blocking = true; michael@0: PR_SetSocketOption(fd, &option); michael@0: michael@0: return true; michael@0: } michael@0: michael@0: /* michael@0: * Handle an incoming client connection. The server thread has already michael@0: * accepted the connection, so we just need to connect to the remote michael@0: * port and then proxy data back and forth. michael@0: * The data parameter is a connection_info_t*, and must be deleted michael@0: * by this function. 
michael@0: */ michael@0: void HandleConnection(void* data) michael@0: { michael@0: connection_info_t* ci = static_cast(data); michael@0: PRIntervalTime connect_timeout = PR_SecondsToInterval(30); michael@0: michael@0: ScopedPRFileDesc other_sock(PR_NewTCPSocket()); michael@0: bool client_done = false; michael@0: bool client_error = false; michael@0: bool connect_accepted = !do_http_proxy; michael@0: bool ssl_updated = !do_http_proxy; michael@0: bool expect_request_start = do_http_proxy; michael@0: string certificateToUse; michael@0: string locationHeader; michael@0: client_auth_option clientAuth; michael@0: string fullHost; michael@0: michael@0: LOG_DEBUG(("SSLTUNNEL(%p)): incoming connection csock(0)=%p, ssock(1)=%p\n", michael@0: static_cast(data), michael@0: static_cast(ci->client_sock), michael@0: static_cast(other_sock))); michael@0: if (other_sock) michael@0: { michael@0: int32_t numberOfSockets = 1; michael@0: michael@0: relayBuffer buffers[2]; michael@0: michael@0: if (!do_http_proxy) michael@0: { michael@0: if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info, certificateToUse, caNone)) michael@0: client_error = true; michael@0: else if (!ConnectSocket(other_sock, &remote_addr, connect_timeout)) michael@0: client_error = true; michael@0: else michael@0: numberOfSockets = 2; michael@0: } michael@0: michael@0: PRPollDesc sockets[2] = michael@0: { michael@0: {ci->client_sock, PR_POLL_READ, 0}, michael@0: {other_sock, PR_POLL_READ, 0} michael@0: }; michael@0: bool socketErrorState[2] = {false, false}; michael@0: michael@0: while (!((client_error||client_done) && buffers[0].empty() && buffers[1].empty())) michael@0: { michael@0: sockets[0].in_flags |= PR_POLL_EXCEPT; michael@0: sockets[1].in_flags |= PR_POLL_EXCEPT; michael@0: LOG_DEBUG(("SSLTUNNEL(%p)): polling flags csock(0)=%c%c, ssock(1)=%c%c\n", michael@0: static_cast(data), michael@0: sockets[0].in_flags & PR_POLL_READ ? 'R' : '-', michael@0: sockets[0].in_flags & PR_POLL_WRITE ? 
'W' : '-', michael@0: sockets[1].in_flags & PR_POLL_READ ? 'R' : '-', michael@0: sockets[1].in_flags & PR_POLL_WRITE ? 'W' : '-')); michael@0: int32_t pollStatus = PR_Poll(sockets, numberOfSockets, PR_MillisecondsToInterval(1000)); michael@0: if (pollStatus < 0) michael@0: { michael@0: LOG_DEBUG(("SSLTUNNEL(%p)): pollStatus=%d, exiting\n", michael@0: static_cast(data), pollStatus)); michael@0: client_error = true; michael@0: break; michael@0: } michael@0: michael@0: if (pollStatus == 0) michael@0: { michael@0: // timeout michael@0: LOG_DEBUG(("SSLTUNNEL(%p)): poll timeout, looping\n", michael@0: static_cast(data))); michael@0: continue; michael@0: } michael@0: michael@0: for (int32_t s = 0; s < numberOfSockets; ++s) michael@0: { michael@0: int32_t s2 = s == 1 ? 0 : 1; michael@0: int16_t out_flags = sockets[s].out_flags; michael@0: int16_t &in_flags = sockets[s].in_flags; michael@0: int16_t &in_flags2 = sockets[s2].in_flags; michael@0: sockets[s].out_flags = 0; michael@0: michael@0: LOG_BEGIN_BLOCK(); michael@0: LOG_DEBUG(("SSLTUNNEL(%p)): %csock(%d)=%p out_flags=%d", michael@0: static_cast(data), michael@0: s == 0 ? 'c' : 's', michael@0: s, michael@0: static_cast(sockets[s].fd), michael@0: out_flags)); michael@0: if (out_flags & (PR_POLL_EXCEPT | PR_POLL_ERR | PR_POLL_HUP)) michael@0: { michael@0: LOG_DEBUG((" :exception\n")); michael@0: client_error = true; michael@0: socketErrorState[s] = true; michael@0: // We got a fatal error state on the socket. Clear the output buffer michael@0: // for this socket to break the main loop, we will never more be able michael@0: // to send those data anyway. 
michael@0: buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer; michael@0: continue; michael@0: } // PR_POLL_EXCEPT, PR_POLL_ERR, PR_POLL_HUP handling michael@0: michael@0: if (out_flags & PR_POLL_READ && !buffers[s].areafree()) michael@0: { michael@0: LOG_DEBUG((" no place in read buffer but got read flag, dropping it now!")); michael@0: in_flags &= ~PR_POLL_READ; michael@0: } michael@0: michael@0: if (out_flags & PR_POLL_READ && buffers[s].areafree()) michael@0: { michael@0: LOG_DEBUG((" :reading")); michael@0: int32_t bytesRead = PR_Recv(sockets[s].fd, buffers[s].buffertail, michael@0: buffers[s].areafree(), 0, PR_INTERVAL_NO_TIMEOUT); michael@0: michael@0: if (bytesRead == 0) michael@0: { michael@0: LOG_DEBUG((" socket gracefully closed")); michael@0: client_done = true; michael@0: in_flags &= ~PR_POLL_READ; michael@0: } michael@0: else if (bytesRead < 0) michael@0: { michael@0: if (PR_GetError() != PR_WOULD_BLOCK_ERROR) michael@0: { michael@0: LOG_DEBUG((" error=%d", PR_GetError())); michael@0: // We are in error state, indicate that the connection was michael@0: // not closed gracefully michael@0: client_error = true; michael@0: socketErrorState[s] = true; michael@0: // Wipe out our send buffer, we cannot send it anyway. 
michael@0: buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer; michael@0: } michael@0: else michael@0: LOG_DEBUG((" would block")); michael@0: } michael@0: else michael@0: { michael@0: // If the other socket is in error state (unable to send/receive) michael@0: // throw this data away and continue loop michael@0: if (socketErrorState[s2]) michael@0: { michael@0: LOG_DEBUG((" have read but other socket is in error state\n")); michael@0: continue; michael@0: } michael@0: michael@0: buffers[s].buffertail += bytesRead; michael@0: LOG_DEBUG((", read %d bytes", bytesRead)); michael@0: michael@0: // We have to accept and handle the initial CONNECT request here michael@0: int32_t response; michael@0: if (!connect_accepted && ReadConnectRequest(ci->server_info, buffers[s], michael@0: &response, certificateToUse, &clientAuth, fullHost, locationHeader)) michael@0: { michael@0: // Mark this as a proxy-only connection (no SSL) if the CONNECT michael@0: // request didn't come for port 443 or from any of the server's michael@0: // cert or clientauth hostnames. 
michael@0: if (fullHost.find(":443") == string::npos) michael@0: { michael@0: server_match_t match; michael@0: match.fullHost = fullHost; michael@0: match.matched = false; michael@0: PL_HashTableEnumerateEntries(ci->server_info->host_cert_table, michael@0: match_hostname, michael@0: &match); michael@0: PL_HashTableEnumerateEntries(ci->server_info->host_clientauth_table, michael@0: match_hostname, michael@0: &match); michael@0: ci->http_proxy_only = !match.matched; michael@0: } michael@0: else michael@0: { michael@0: ci->http_proxy_only = false; michael@0: } michael@0: michael@0: // Clean the request as it would be read michael@0: buffers[s].bufferhead = buffers[s].buffertail = buffers[s].buffer; michael@0: in_flags |= PR_POLL_WRITE; michael@0: connect_accepted = true; michael@0: michael@0: // Store response to the oposite buffer michael@0: if (response == 200) michael@0: { michael@0: LOG_DEBUG((" accepted CONNECT request, connected to the server, sending OK to the client\n")); michael@0: strcpy(buffers[s2].buffer, "HTTP/1.1 200 Connected\r\nConnection: keep-alive\r\n\r\n"); michael@0: } michael@0: else if (response == 302) michael@0: { michael@0: LOG_DEBUG((" accepted CONNECT request with redirection, " michael@0: "sending location and 302 to the client\n")); michael@0: client_done = true; michael@0: sprintf(buffers[s2].buffer, michael@0: "HTTP/1.1 302 Moved\r\n" michael@0: "Location: https://%s/\r\n" michael@0: "Connection: close\r\n\r\n", michael@0: locationHeader.c_str()); michael@0: } michael@0: else michael@0: { michael@0: LOG_ERRORD((" could not read the connect request, closing connection with %d", response)); michael@0: client_done = true; michael@0: sprintf(buffers[s2].buffer, "HTTP/1.1 %d ERROR\r\nConnection: close\r\n\r\n", response); michael@0: michael@0: break; michael@0: } michael@0: michael@0: buffers[s2].buffertail = buffers[s2].buffer + strlen(buffers[s2].buffer); michael@0: michael@0: // Send the response to the client socket michael@0: break; 
michael@0: } // end of CONNECT handling michael@0: michael@0: if (!buffers[s].areafree()) michael@0: { michael@0: // Do not poll for read when the buffer is full michael@0: LOG_DEBUG((" no place in our read buffer, stop reading")); michael@0: in_flags &= ~PR_POLL_READ; michael@0: } michael@0: michael@0: if (ssl_updated) michael@0: { michael@0: if (s == 0 && expect_request_start) michael@0: { michael@0: if (!strstr(buffers[s].bufferhead, "\r\n\r\n")) michael@0: { michael@0: // We haven't received the complete header yet, so wait. michael@0: continue; michael@0: } michael@0: else michael@0: { michael@0: ci->iswebsocket = AdjustWebSocketHost(buffers[s], ci); michael@0: expect_request_start = !(ci->iswebsocket || michael@0: AdjustRequestURI(buffers[s], &fullHost)); michael@0: PRNetAddr* addr = &remote_addr; michael@0: if (ci->iswebsocket && websocket_server.inet.port) michael@0: addr = &websocket_server; michael@0: if (!ConnectSocket(other_sock, addr, connect_timeout)) michael@0: { michael@0: LOG_ERRORD((" could not open connection to the real server\n")); michael@0: client_error = true; michael@0: break; michael@0: } michael@0: LOG_DEBUG(("\n connected to remote server\n")); michael@0: numberOfSockets = 2; michael@0: } michael@0: } michael@0: else if (s == 1 && ci->iswebsocket) michael@0: { michael@0: if (!AdjustWebSocketLocation(buffers[s], ci)) michael@0: continue; michael@0: } michael@0: michael@0: in_flags2 |= PR_POLL_WRITE; michael@0: LOG_DEBUG((" telling the other socket to write")); michael@0: } michael@0: else michael@0: LOG_DEBUG((" we have something for the other socket to write, but ssl has not been administered on it")); michael@0: } michael@0: } // PR_POLL_READ handling michael@0: michael@0: if (out_flags & PR_POLL_WRITE) michael@0: { michael@0: LOG_DEBUG((" :writing")); michael@0: int32_t bytesWrite = PR_Send(sockets[s].fd, buffers[s2].bufferhead, michael@0: buffers[s2].present(), 0, PR_INTERVAL_NO_TIMEOUT); michael@0: michael@0: if (bytesWrite < 0) 
michael@0: { michael@0: if (PR_GetError() != PR_WOULD_BLOCK_ERROR) { michael@0: LOG_DEBUG((" error=%d", PR_GetError())); michael@0: client_error = true; michael@0: socketErrorState[s] = true; michael@0: // We got a fatal error while writting the buffer. Clear it to break michael@0: // the main loop, we will never more be able to send it. michael@0: buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer; michael@0: } michael@0: else michael@0: LOG_DEBUG((" would block")); michael@0: } michael@0: else michael@0: { michael@0: LOG_DEBUG((", written %d bytes", bytesWrite)); michael@0: buffers[s2].buffertail[1] = '\0'; michael@0: LOG_DEBUG((" dump:\n%.*s\n", bytesWrite, buffers[s2].bufferhead)); michael@0: michael@0: buffers[s2].bufferhead += bytesWrite; michael@0: if (buffers[s2].present()) michael@0: { michael@0: LOG_DEBUG((" still have to write %d bytes", (int)buffers[s2].present())); michael@0: in_flags |= PR_POLL_WRITE; michael@0: } michael@0: else michael@0: { michael@0: if (!ssl_updated) michael@0: { michael@0: LOG_DEBUG((" proxy response sent to the client")); michael@0: // Proxy response has just been writen, update to ssl michael@0: ssl_updated = true; michael@0: if (ci->http_proxy_only) michael@0: { michael@0: LOG_DEBUG((" not updating to SSL based on http_proxy_only for this socket")); michael@0: } michael@0: else if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info, michael@0: certificateToUse, clientAuth)) michael@0: { michael@0: LOG_ERRORD((" failed to config server socket\n")); michael@0: client_error = true; michael@0: break; michael@0: } michael@0: else michael@0: { michael@0: LOG_DEBUG((" client socket updated to SSL")); michael@0: } michael@0: } // sslUpdate michael@0: michael@0: LOG_DEBUG((" dropping our write flag and setting other socket read flag")); michael@0: in_flags &= ~PR_POLL_WRITE; michael@0: in_flags2 |= PR_POLL_READ; michael@0: buffers[s2].compact(); michael@0: } michael@0: } michael@0: } // PR_POLL_WRITE handling 
michael@0: LOG_END_BLOCK(); // end the log michael@0: } // for... michael@0: } // while, poll michael@0: } michael@0: else michael@0: client_error = true; michael@0: michael@0: LOG_DEBUG(("SSLTUNNEL(%p)): exiting root function for csock=%p, ssock=%p\n", michael@0: static_cast(data), michael@0: static_cast(ci->client_sock), michael@0: static_cast(other_sock))); michael@0: if (!client_error) michael@0: PR_Shutdown(ci->client_sock, PR_SHUTDOWN_SEND); michael@0: PR_Close(ci->client_sock); michael@0: michael@0: delete ci; michael@0: } michael@0: michael@0: /* michael@0: * Start listening for SSL connections on a specified port, handing michael@0: * them off to client threads after accepting the connection. michael@0: * The data parameter is a server_info_t*, owned by the calling michael@0: * function. michael@0: */ michael@0: void StartServer(void* data) michael@0: { michael@0: server_info_t* si = static_cast(data); michael@0: michael@0: //TODO: select ciphers? michael@0: ScopedPRFileDesc listen_socket(PR_NewTCPSocket()); michael@0: if (!listen_socket) { michael@0: LOG_ERROR(("failed to create socket\n")); michael@0: SignalShutdown(); michael@0: return; michael@0: } michael@0: michael@0: // In case the socket is still open in the TIME_WAIT state from a previous michael@0: // instance of ssltunnel we ask to reuse the port. 
michael@0: PRSocketOptionData socket_option; michael@0: socket_option.option = PR_SockOpt_Reuseaddr; michael@0: socket_option.value.reuse_addr = true; michael@0: PR_SetSocketOption(listen_socket, &socket_option); michael@0: michael@0: PRNetAddr server_addr; michael@0: PR_InitializeNetAddr(PR_IpAddrAny, si->listen_port, &server_addr); michael@0: if (PR_Bind(listen_socket, &server_addr) != PR_SUCCESS) { michael@0: LOG_ERROR(("failed to bind socket\n")); michael@0: SignalShutdown(); michael@0: return; michael@0: } michael@0: michael@0: if (PR_Listen(listen_socket, 1) != PR_SUCCESS) { michael@0: LOG_ERROR(("failed to listen on socket\n")); michael@0: SignalShutdown(); michael@0: return; michael@0: } michael@0: michael@0: LOG_INFO(("Server listening on port %d with cert %s\n", si->listen_port, michael@0: si->cert_nickname.c_str())); michael@0: michael@0: while (!shutdown_server) { michael@0: connection_info_t* ci = new connection_info_t(); michael@0: ci->server_info = si; michael@0: ci->http_proxy_only = do_http_proxy; michael@0: // block waiting for connections michael@0: ci->client_sock = PR_Accept(listen_socket, &ci->client_addr, michael@0: PR_INTERVAL_NO_TIMEOUT); michael@0: michael@0: PRSocketOptionData option; michael@0: option.option = PR_SockOpt_Nonblocking; michael@0: option.value.non_blocking = true; michael@0: PR_SetSocketOption(ci->client_sock, &option); michael@0: michael@0: if (ci->client_sock) michael@0: // Not actually using this PRJob*... michael@0: //PRJob* job = michael@0: PR_QueueJob(threads, HandleConnection, ci, true); michael@0: else michael@0: delete ci; michael@0: } michael@0: } michael@0: michael@0: // bogus password func, just don't use passwords. 
:-P michael@0: char* password_func(PK11SlotInfo* slot, PRBool retry, void* arg) michael@0: { michael@0: if (retry) michael@0: return nullptr; michael@0: michael@0: return PL_strdup(""); michael@0: } michael@0: michael@0: server_info_t* findServerInfo(int portnumber) michael@0: { michael@0: for (vector::iterator it = servers.begin(); michael@0: it != servers.end(); it++) michael@0: { michael@0: if (it->listen_port == portnumber) michael@0: return &(*it); michael@0: } michael@0: michael@0: return nullptr; michael@0: } michael@0: michael@0: int processConfigLine(char* configLine) michael@0: { michael@0: if (*configLine == 0 || *configLine == '#') michael@0: return 0; michael@0: michael@0: char* _caret; michael@0: char* keyword = strtok2(configLine, ":", &_caret); michael@0: michael@0: // Configure usage of http/ssl tunneling proxy behavior michael@0: if (!strcmp(keyword, "httpproxy")) michael@0: { michael@0: char* value = strtok2(_caret, ":", &_caret); michael@0: if (!strcmp(value, "1")) michael@0: do_http_proxy = true; michael@0: michael@0: return 0; michael@0: } michael@0: michael@0: if (!strcmp(keyword, "websocketserver")) michael@0: { michael@0: char* ipstring = strtok2(_caret, ":", &_caret); michael@0: if (PR_StringToNetAddr(ipstring, &websocket_server) != PR_SUCCESS) { michael@0: LOG_ERROR(("Invalid IP address in proxy config: %s\n", ipstring)); michael@0: return 1; michael@0: } michael@0: char* remoteport = strtok2(_caret, ":", &_caret); michael@0: int port = atoi(remoteport); michael@0: if (port <= 0) { michael@0: LOG_ERROR(("Invalid remote port in proxy config: %s\n", remoteport)); michael@0: return 1; michael@0: } michael@0: websocket_server.inet.port = PR_htons(port); michael@0: return 0; michael@0: } michael@0: michael@0: // Configure the forward address of the target server michael@0: if (!strcmp(keyword, "forward")) michael@0: { michael@0: char* ipstring = strtok2(_caret, ":", &_caret); michael@0: if (PR_StringToNetAddr(ipstring, &remote_addr) != 
PR_SUCCESS) { michael@0: LOG_ERROR(("Invalid remote IP address: %s\n", ipstring)); michael@0: return 1; michael@0: } michael@0: char* serverportstring = strtok2(_caret, ":", &_caret); michael@0: int port = atoi(serverportstring); michael@0: if (port <= 0) { michael@0: LOG_ERROR(("Invalid remote port: %s\n", serverportstring)); michael@0: return 1; michael@0: } michael@0: remote_addr.inet.port = PR_htons(port); michael@0: michael@0: return 0; michael@0: } michael@0: michael@0: // Configure all listen sockets and port+certificate bindings michael@0: if (!strcmp(keyword, "listen")) michael@0: { michael@0: char* hostname = strtok2(_caret, ":", &_caret); michael@0: char* hostportstring = nullptr; michael@0: if (strcmp(hostname, "*")) michael@0: { michael@0: any_host_spec_config = true; michael@0: hostportstring = strtok2(_caret, ":", &_caret); michael@0: } michael@0: michael@0: char* serverportstring = strtok2(_caret, ":", &_caret); michael@0: char* certnick = strtok2(_caret, ":", &_caret); michael@0: michael@0: int port = atoi(serverportstring); michael@0: if (port <= 0) { michael@0: LOG_ERROR(("Invalid port specified: %s\n", serverportstring)); michael@0: return 1; michael@0: } michael@0: michael@0: if (server_info_t* existingServer = findServerInfo(port)) michael@0: { michael@0: char *certnick_copy = new char[strlen(certnick)+1]; michael@0: char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2]; michael@0: michael@0: strcpy(hostname_copy, hostname); michael@0: strcat(hostname_copy, ":"); michael@0: strcat(hostname_copy, hostportstring); michael@0: strcpy(certnick_copy, certnick); michael@0: michael@0: PLHashEntry* entry = PL_HashTableAdd(existingServer->host_cert_table, hostname_copy, certnick_copy); michael@0: if (!entry) { michael@0: LOG_ERROR(("Out of memory")); michael@0: return 1; michael@0: } michael@0: } michael@0: else michael@0: { michael@0: server_info_t server; michael@0: server.cert_nickname = certnick; michael@0: server.listen_port = 
port; michael@0: server.host_cert_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings, michael@0: PL_CompareStrings, nullptr, nullptr); michael@0: if (!server.host_cert_table) michael@0: { michael@0: LOG_ERROR(("Internal, could not create hash table\n")); michael@0: return 1; michael@0: } michael@0: server.host_clientauth_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings, michael@0: ClientAuthValueComparator, nullptr, nullptr); michael@0: if (!server.host_clientauth_table) michael@0: { michael@0: LOG_ERROR(("Internal, could not create hash table\n")); michael@0: return 1; michael@0: } michael@0: server.host_redir_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings, michael@0: PL_CompareStrings, nullptr, nullptr); michael@0: if (!server.host_redir_table) michael@0: { michael@0: LOG_ERROR(("Internal, could not create hash table\n")); michael@0: return 1; michael@0: } michael@0: servers.push_back(server); michael@0: } michael@0: michael@0: return 0; michael@0: } michael@0: michael@0: if (!strcmp(keyword, "clientauth")) michael@0: { michael@0: char* hostname = strtok2(_caret, ":", &_caret); michael@0: char* hostportstring = strtok2(_caret, ":", &_caret); michael@0: char* serverportstring = strtok2(_caret, ":", &_caret); michael@0: michael@0: int port = atoi(serverportstring); michael@0: if (port <= 0) { michael@0: LOG_ERROR(("Invalid port specified: %s\n", serverportstring)); michael@0: return 1; michael@0: } michael@0: michael@0: if (server_info_t* existingServer = findServerInfo(port)) michael@0: { michael@0: char* authoptionstring = strtok2(_caret, ":", &_caret); michael@0: client_auth_option* authoption = new client_auth_option; michael@0: if (!authoption) { michael@0: LOG_ERROR(("Out of memory")); michael@0: return 1; michael@0: } michael@0: michael@0: if (!strcmp(authoptionstring, "require")) michael@0: *authoption = caRequire; michael@0: else if (!strcmp(authoptionstring, "request")) michael@0: *authoption = caRequest; michael@0: else 
if (!strcmp(authoptionstring, "none")) michael@0: *authoption = caNone; michael@0: else michael@0: { michael@0: LOG_ERROR(("Incorrect client auth option modifier for host '%s'", hostname)); michael@0: return 1; michael@0: } michael@0: michael@0: any_host_spec_config = true; michael@0: michael@0: char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2]; michael@0: if (!hostname_copy) { michael@0: LOG_ERROR(("Out of memory")); michael@0: return 1; michael@0: } michael@0: michael@0: strcpy(hostname_copy, hostname); michael@0: strcat(hostname_copy, ":"); michael@0: strcat(hostname_copy, hostportstring); michael@0: michael@0: PLHashEntry* entry = PL_HashTableAdd(existingServer->host_clientauth_table, hostname_copy, authoption); michael@0: if (!entry) { michael@0: LOG_ERROR(("Out of memory")); michael@0: return 1; michael@0: } michael@0: } michael@0: else michael@0: { michael@0: LOG_ERROR(("Server on port %d for client authentication option is not defined, use 'listen' option first", port)); michael@0: return 1; michael@0: } michael@0: michael@0: return 0; michael@0: } michael@0: michael@0: if (!strcmp(keyword, "redirhost")) michael@0: { michael@0: char* hostname = strtok2(_caret, ":", &_caret); michael@0: char* hostportstring = strtok2(_caret, ":", &_caret); michael@0: char* serverportstring = strtok2(_caret, ":", &_caret); michael@0: michael@0: int port = atoi(serverportstring); michael@0: if (port <= 0) { michael@0: LOG_ERROR(("Invalid port specified: %s\n", serverportstring)); michael@0: return 1; michael@0: } michael@0: michael@0: if (server_info_t* existingServer = findServerInfo(port)) michael@0: { michael@0: char* redirhoststring = strtok2(_caret, ":", &_caret); michael@0: michael@0: any_host_spec_config = true; michael@0: michael@0: char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2]; michael@0: if (!hostname_copy) { michael@0: LOG_ERROR(("Out of memory")); michael@0: return 1; michael@0: } michael@0: michael@0: 
strcpy(hostname_copy, hostname); michael@0: strcat(hostname_copy, ":"); michael@0: strcat(hostname_copy, hostportstring); michael@0: michael@0: char *redir_copy = new char[strlen(redirhoststring)+1]; michael@0: strcpy(redir_copy, redirhoststring); michael@0: PLHashEntry* entry = PL_HashTableAdd(existingServer->host_redir_table, hostname_copy, redir_copy); michael@0: if (!entry) { michael@0: LOG_ERROR(("Out of memory")); michael@0: return 1; michael@0: } michael@0: } michael@0: else michael@0: { michael@0: LOG_ERROR(("Server on port %d for redirhost option is not defined, use 'listen' option first", port)); michael@0: return 1; michael@0: } michael@0: michael@0: return 0; michael@0: } michael@0: michael@0: // Configure the NSS certificate database directory michael@0: if (!strcmp(keyword, "certdbdir")) michael@0: { michael@0: nssconfigdir = strtok2(_caret, "\n", &_caret); michael@0: return 0; michael@0: } michael@0: michael@0: LOG_ERROR(("Error: keyword \"%s\" unexpected\n", keyword)); michael@0: return 1; michael@0: } michael@0: michael@0: int parseConfigFile(const char* filePath) michael@0: { michael@0: FILE* f = fopen(filePath, "r"); michael@0: if (!f) michael@0: return 1; michael@0: michael@0: char buffer[1024], *b = buffer; michael@0: while (!feof(f)) michael@0: { michael@0: char c; michael@0: fscanf(f, "%c", &c); michael@0: switch (c) michael@0: { michael@0: case '\n': michael@0: *b++ = 0; michael@0: if (processConfigLine(buffer)) michael@0: return 1; michael@0: b = buffer; michael@0: case '\r': michael@0: continue; michael@0: default: michael@0: *b++ = c; michael@0: } michael@0: } michael@0: michael@0: fclose(f); michael@0: michael@0: // Check mandatory items michael@0: if (nssconfigdir.empty()) michael@0: { michael@0: LOG_ERROR(("Error: missing path to NSS certification database\n,use certdbdir: in the config file\n")); michael@0: return 1; michael@0: } michael@0: michael@0: if (any_host_spec_config && !do_http_proxy) michael@0: { michael@0: 
LOG_ERROR(("Warning: any host-specific configurations are ignored, add httpproxy:1 to allow them\n")); michael@0: } michael@0: michael@0: return 0; michael@0: } michael@0: michael@0: int freeHostCertHashItems(PLHashEntry *he, int i, void *arg) michael@0: { michael@0: delete [] (char*)he->key; michael@0: delete [] (char*)he->value; michael@0: return HT_ENUMERATE_REMOVE; michael@0: } michael@0: michael@0: int freeHostRedirHashItems(PLHashEntry *he, int i, void *arg) michael@0: { michael@0: delete [] (char*)he->key; michael@0: delete [] (char*)he->value; michael@0: return HT_ENUMERATE_REMOVE; michael@0: } michael@0: michael@0: int freeClientAuthHashItems(PLHashEntry *he, int i, void *arg) michael@0: { michael@0: delete [] (char*)he->key; michael@0: delete (client_auth_option*)he->value; michael@0: return HT_ENUMERATE_REMOVE; michael@0: } michael@0: michael@0: int main(int argc, char** argv) michael@0: { michael@0: const char* configFilePath; michael@0: michael@0: const char* logLevelEnv = PR_GetEnv("SSLTUNNEL_LOG_LEVEL"); michael@0: gLogLevel = logLevelEnv ? 
(LogLevel)atoi(logLevelEnv) : LEVEL_INFO; michael@0: michael@0: if (argc == 1) michael@0: configFilePath = "ssltunnel.cfg"; michael@0: else michael@0: configFilePath = argv[1]; michael@0: michael@0: memset(&websocket_server, 0, sizeof(PRNetAddr)); michael@0: michael@0: if (parseConfigFile(configFilePath)) { michael@0: LOG_ERROR(("Error: config file \"%s\" missing or formating incorrect\n" michael@0: "Specify path to the config file as parameter to ssltunnel or \n" michael@0: "create ssltunnel.cfg in the working directory.\n\n" michael@0: "Example format of the config file:\n\n" michael@0: " # Enable http/ssl tunneling proxy-like behavior.\n" michael@0: " # If not specified ssltunnel simply does direct forward.\n" michael@0: " httpproxy:1\n\n" michael@0: " # Specify path to the certification database used.\n" michael@0: " certdbdir:/path/to/certdb\n\n" michael@0: " # Forward/proxy all requests in raw to 127.0.0.1:8888.\n" michael@0: " forward:127.0.0.1:8888\n\n" michael@0: " # Accept connections on port 4443 or 5678 resp. and authenticate\n" michael@0: " # to any host ('*') using the 'server cert' or 'server cert 2' resp.\n" michael@0: " listen:*:4443:server cert\n" michael@0: " listen:*:5678:server cert 2\n\n" michael@0: " # Accept connections on port 4443 and authenticate using\n" michael@0: " # 'a different cert' when target host is 'my.host.name:443'.\n" michael@0: " # This only works in httpproxy mode and has higher priority\n" michael@0: " # than the previous option.\n" michael@0: " listen:my.host.name:443:4443:a different cert\n\n" michael@0: " # To make a specific host require or just request a client certificate\n" michael@0: " # to authenticate use the following options. This can only be used\n" michael@0: " # in httpproxy mode and only after the 'listen' option has been\n" michael@0: " # specified. 
You also have to specify the tunnel listen port.\n" michael@0: " clientauth:requesting-client-cert.host.com:443:4443:request\n" michael@0: " clientauth:requiring-client-cert.host.com:443:4443:require\n" michael@0: " # Proxy WebSocket traffic to the server at 127.0.0.1:9999,\n" michael@0: " # instead of the server specified in the 'forward' option.\n" michael@0: " websocketserver:127.0.0.1:9999\n", michael@0: configFilePath)); michael@0: return 1; michael@0: } michael@0: michael@0: // create a thread pool to handle connections michael@0: threads = PR_CreateThreadPool(INITIAL_THREADS * servers.size(), michael@0: MAX_THREADS * servers.size(), michael@0: DEFAULT_STACKSIZE); michael@0: if (!threads) { michael@0: LOG_ERROR(("Failed to create thread pool\n")); michael@0: return 1; michael@0: } michael@0: michael@0: shutdown_lock = PR_NewLock(); michael@0: if (!shutdown_lock) { michael@0: LOG_ERROR(("Failed to create lock\n")); michael@0: PR_ShutdownThreadPool(threads); michael@0: return 1; michael@0: } michael@0: shutdown_condvar = PR_NewCondVar(shutdown_lock); michael@0: if (!shutdown_condvar) { michael@0: LOG_ERROR(("Failed to create condvar\n")); michael@0: PR_ShutdownThreadPool(threads); michael@0: PR_DestroyLock(shutdown_lock); michael@0: return 1; michael@0: } michael@0: michael@0: PK11_SetPasswordFunc(password_func); michael@0: michael@0: // Initialize NSS michael@0: if (NSS_Init(nssconfigdir.c_str()) != SECSuccess) { michael@0: int32_t errorlen = PR_GetErrorTextLength(); michael@0: char* err = new char[errorlen+1]; michael@0: PR_GetErrorText(err); michael@0: LOG_ERROR(("Failed to init NSS: %s", err)); michael@0: delete[] err; michael@0: PR_ShutdownThreadPool(threads); michael@0: PR_DestroyCondVar(shutdown_condvar); michael@0: PR_DestroyLock(shutdown_lock); michael@0: return 1; michael@0: } michael@0: michael@0: if (NSS_SetDomesticPolicy() != SECSuccess) { michael@0: LOG_ERROR(("NSS_SetDomesticPolicy failed\n")); michael@0: PR_ShutdownThreadPool(threads); 
michael@0: PR_DestroyCondVar(shutdown_condvar); michael@0: PR_DestroyLock(shutdown_lock); michael@0: NSS_Shutdown(); michael@0: return 1; michael@0: } michael@0: michael@0: // these values should make NSS use the defaults michael@0: if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) { michael@0: LOG_ERROR(("SSL_ConfigServerSessionIDCache failed\n")); michael@0: PR_ShutdownThreadPool(threads); michael@0: PR_DestroyCondVar(shutdown_condvar); michael@0: PR_DestroyLock(shutdown_lock); michael@0: NSS_Shutdown(); michael@0: return 1; michael@0: } michael@0: michael@0: for (vector::iterator it = servers.begin(); michael@0: it != servers.end(); it++) { michael@0: // Not actually using this PRJob*... michael@0: // PRJob* server_job = michael@0: PR_QueueJob(threads, StartServer, &(*it), true); michael@0: } michael@0: // now wait for someone to tell us to quit michael@0: PR_Lock(shutdown_lock); michael@0: PR_WaitCondVar(shutdown_condvar, PR_INTERVAL_NO_TIMEOUT); michael@0: PR_Unlock(shutdown_lock); michael@0: shutdown_server = true; michael@0: LOG_INFO(("Shutting down...\n")); michael@0: // cleanup michael@0: PR_ShutdownThreadPool(threads); michael@0: PR_JoinThreadPool(threads); michael@0: PR_DestroyCondVar(shutdown_condvar); michael@0: PR_DestroyLock(shutdown_lock); michael@0: if (NSS_Shutdown() == SECFailure) { michael@0: LOG_DEBUG(("Leaked NSS objects!\n")); michael@0: } michael@0: michael@0: for (vector::iterator it = servers.begin(); michael@0: it != servers.end(); it++) michael@0: { michael@0: PL_HashTableEnumerateEntries(it->host_cert_table, freeHostCertHashItems, nullptr); michael@0: PL_HashTableEnumerateEntries(it->host_clientauth_table, freeClientAuthHashItems, nullptr); michael@0: PL_HashTableEnumerateEntries(it->host_redir_table, freeHostRedirHashItems, nullptr); michael@0: PL_HashTableDestroy(it->host_cert_table); michael@0: PL_HashTableDestroy(it->host_clientauth_table); michael@0: PL_HashTableDestroy(it->host_redir_table); michael@0: } 
michael@0: michael@0: PR_Cleanup(); michael@0: return 0; michael@0: }