|
1 /* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ |
|
2 /* This Source Code Form is subject to the terms of the Mozilla Public |
|
3 * License, v. 2.0. If a copy of the MPL was not distributed with this |
|
4 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ |
|
5 |
|
6 /* |
|
7 * WARNING: DO NOT USE THIS CODE IN PRODUCTION SYSTEMS. It is highly likely to |
|
8 * be plagued with the usual problems endemic to C (buffer overflows |
|
9 * and the like). We don't especially care here (but would accept |
|
10 * patches!) because this is only intended for use in our test |
|
11 * harnesses in controlled situations where input is guaranteed not to |
|
12 * be malicious. |
|
13 */ |
|
14 |
|
15 #include "ScopedNSSTypes.h" |
|
16 #include <assert.h> |
|
17 #include <stdio.h> |
|
18 #include <string> |
|
19 #include <vector> |
|
20 #include <algorithm> |
|
21 #include <stdarg.h> |
|
22 #include "prinit.h" |
|
23 #include "prerror.h" |
|
24 #include "prenv.h" |
|
25 #include "prnetdb.h" |
|
26 #include "prtpool.h" |
|
27 #include "nsAlgorithm.h" |
|
28 #include "nss.h" |
|
29 #include "key.h" |
|
30 #include "ssl.h" |
|
31 #include "plhash.h" |
|
32 |
|
33 using namespace mozilla; |
|
34 using namespace mozilla::psm; |
|
35 using std::string; |
|
36 using std::vector; |
|
37 |
|
// Helpers for a delimiter set stored as a bitmap: DELIM_TABLE_SIZE bytes hold
// one bit per possible 8-bit character value (32 * 8 = 256 bits). For a
// character code c the byte index is (c >> 3) and the bit inside that byte is
// (c & 7). Callers cast characters to uint8_t first so c is never negative.
#define DELIM_TABLE_SIZE 32
#define SET_DELIM(m, c) ((m)[(c) >> 3] |= (1 << ((c) & 7)))
#define IS_DELIM(m, c)  ((m)[(c) >> 3] & (1 << ((c) & 7)))
|
41 |
|
// Logging verbosity. The level can be chosen with the environment variable
// SSLTUNNEL_LOG_LEVEL=n, where n is 0 through 3. The default is 1, INFO level
// logging.
enum LogLevel {
  LEVEL_DEBUG,   // 0
  LEVEL_INFO,    // 1
  LEVEL_ERROR,   // 2
  LEVEL_SILENT   // 3
};

// gLogLevel is the threshold a message level must reach to be printed;
// gLastLogLevel remembers the level of the last message that was printed.
LogLevel gLogLevel, gLastLogLevel;
|
50 |
|
// Shared implementation of the LOG_* macros. When |level| reaches the
// gLogLevel threshold, |level| is recorded in gLastLogLevel and |func| is
// called with |params|. |params| must be an already parenthesized argument
// list, which is why callers write LOG_INFO(("fmt %d\n", value)).
#define _LOG_OUTPUT(level, func, params) \
PR_BEGIN_MACRO \
  if ((level) >= gLogLevel) { \
    gLastLogLevel = (level); \
    func params; \
  } \
PR_END_MACRO

// The most verbose output
#define LOG_DEBUG(params) \
  _LOG_OUTPUT(LEVEL_DEBUG, printf, params)

// Top level informative messages
#define LOG_INFO(params) \
  _LOG_OUTPUT(LEVEL_INFO, printf, params)

// Serious errors; written to stderr unless logging is completely silenced
// (LEVEL_SILENT)
#define LOG_ERROR(params) \
  _LOG_OUTPUT(LEVEL_ERROR, eprintf, params)

// Same as LOG_ERROR, but when logging is set to LEVEL_DEBUG the message goes
// to stdout instead of stderr, to keep it in sequence with the surrounding
// LOG_DEBUG output
#define LOG_ERRORD(params) \
PR_BEGIN_MACRO \
  if (gLogLevel == LEVEL_DEBUG) { \
    _LOG_OUTPUT(LEVEL_ERROR, printf, params); \
  } else { \
    _LOG_OUTPUT(LEVEL_ERROR, eprintf, params); \
  } \
PR_END_MACRO

// If there is any output written between LOG_BEGIN_BLOCK() and
// LOG_END_BLOCK() then a new line will be put to the proper output (out/err).
// The macro is wrapped so that it forms a single statement and the caller
// supplies the terminating semicolon.
#define LOG_BEGIN_BLOCK() \
PR_BEGIN_MACRO \
  gLastLogLevel = LEVEL_SILENT; \
PR_END_MACRO

#define LOG_END_BLOCK() \
PR_BEGIN_MACRO \
  if (gLastLogLevel == LEVEL_ERROR) { \
    LOG_ERROR(("\n")); \
  } \
  if (gLastLogLevel < LEVEL_ERROR) { \
    _LOG_OUTPUT(gLastLogLevel, printf, ("\n")); \
  } \
PR_END_MACRO
|
94 |
|
// printf-style helper that formats into stderr instead of stdout and hands
// back the value reported by vfprintf.
int eprintf(const char* str, ...)
{
  va_list args;
  va_start(args, str);
  const int written = vfprintf(stderr, str, args);
  va_end(args);
  return written;
}
|
103 |
|
104 // Copied from nsCRT |
|
105 char* strtok2(char* string, const char* delims, char* *newStr) |
|
106 { |
|
107 PR_ASSERT(string); |
|
108 |
|
109 char delimTable[DELIM_TABLE_SIZE]; |
|
110 uint32_t i; |
|
111 char* result; |
|
112 char* str = string; |
|
113 |
|
114 for (i = 0; i < DELIM_TABLE_SIZE; i++) |
|
115 delimTable[i] = '\0'; |
|
116 |
|
117 for (i = 0; delims[i]; i++) { |
|
118 SET_DELIM(delimTable, static_cast<uint8_t>(delims[i])); |
|
119 } |
|
120 |
|
121 // skip to beginning |
|
122 while (*str && IS_DELIM(delimTable, static_cast<uint8_t>(*str))) { |
|
123 str++; |
|
124 } |
|
125 result = str; |
|
126 |
|
127 // fix up the end of the token |
|
128 while (*str) { |
|
129 if (IS_DELIM(delimTable, static_cast<uint8_t>(*str))) { |
|
130 *str++ = '\0'; |
|
131 break; |
|
132 } |
|
133 str++; |
|
134 } |
|
135 *newStr = str; |
|
136 |
|
137 return str == result ? nullptr : result; |
|
138 } |
|
139 |
|
140 |
|
141 |
|
// Client certificate policy for a host: none, requested and required, or
// requested only.
enum client_auth_option {
  caNone,     // 0
  caRequire,  // 1
  caRequest   // 2
};
|
147 |
|
148 // Structs for passing data into jobs on the thread pool |
|
149 typedef struct { |
|
150 int32_t listen_port; |
|
151 string cert_nickname; |
|
152 PLHashTable* host_cert_table; |
|
153 PLHashTable* host_clientauth_table; |
|
154 PLHashTable* host_redir_table; |
|
155 } server_info_t; |
|
156 |
|
157 typedef struct { |
|
158 PRFileDesc* client_sock; |
|
159 PRNetAddr client_addr; |
|
160 server_info_t* server_info; |
|
161 // the original host in the Host: header for this connection is |
|
162 // stored here, for proxied connections |
|
163 string original_host; |
|
164 // true if no SSL should be used for this connection |
|
165 bool http_proxy_only; |
|
166 // true if this connection is for a WebSocket |
|
167 bool iswebsocket; |
|
168 } connection_info_t; |
|
169 |
|
// Argument block for the match_hostname hash table enumerator. Declared as a
// named struct rather than an anonymous typedef'd struct because it holds a
// non-trivial std::string member.
struct server_match_t {
  // Host string that the table keys are searched in.
  std::string fullHost;
  // Set to true once any table key occurs inside fullHost.
  bool matched;
};
|
174 |
|
// Number of bytes of a relay buffer that may be filled by socket reads.
const int32_t BUF_SIZE = 16384;
// Extra bytes allocated past the read area for request line manipulations.
const int32_t BUF_MARGIN = 1024;
// Total allocation size of a relay buffer.
const int32_t BUF_TOTAL = BUF_SIZE + BUF_MARGIN;

// Buffer holding data read from one socket until it is written to the other.
// Pending data occupies [bufferhead, buffertail); reads append at buffertail
// and may advance it up to bufferend. The allocation extends BUF_MARGIN bytes
// past bufferend so that headers can be rewritten in place.
struct relayBuffer
{
  char *buffer, *bufferhead, *buffertail, *bufferend;

  relayBuffer()
  {
    // Leave 1024 bytes more for request line manipulations
    bufferhead = buffertail = buffer = new char[BUF_TOTAL];
    bufferend = buffer + BUF_SIZE;
  }

  ~relayBuffer()
  {
    delete [] buffer;
  }

  // The object owns |buffer|; a member-wise copy would free it twice.
  relayBuffer(const relayBuffer&) = delete;
  relayBuffer& operator=(const relayBuffer&) = delete;

  // Rewinds both cursors to the start of the allocation, but only when no
  // data is pending; pending data is never moved.
  void compact() {
    if (buffertail == bufferhead)
      buffertail = bufferhead = buffer;
  }

  // True when there is no pending data.
  bool empty() const { return bufferhead == buffertail; }
  // Bytes that can still be read into the buffer.
  size_t areafree() const { return bufferend - buffertail; }
  // Bytes available past buffertail including the manipulation margin.
  size_t margin() const { return areafree() + BUF_MARGIN; }
  // Bytes of pending data.
  size_t present() const { return buffertail - bufferhead; }
};
|
205 |
|
// Thread pool sizing. These numbers are multiplied by the number of listening
// ports (actual servers running). According to the thread pool implementation
// there is no need to limit the number of threads initially: threads are
// allocated dynamically and stored in a linked list. The initial count of 2
// reserves one thread for the socket accept loop and preallocates one for the
// first connection, which is expected to arrive with high probability.
const uint32_t INITIAL_THREADS = 2;
const uint32_t MAX_THREADS = 100;
// Stack size in bytes (512 KiB).
const uint32_t DEFAULT_STACKSIZE = 512 * 1024;
|
215 |
|
216 // global data |
|
217 string nssconfigdir; |
|
218 vector<server_info_t> servers; |
|
219 PRNetAddr remote_addr; |
|
220 PRNetAddr websocket_server; |
|
221 PRThreadPool* threads = nullptr; |
|
222 PRLock* shutdown_lock = nullptr; |
|
223 PRCondVar* shutdown_condvar = nullptr; |
|
224 // Not really used, unless something fails to start |
|
225 bool shutdown_server = false; |
|
226 bool do_http_proxy = false; |
|
227 bool any_host_spec_config = false; |
|
228 |
|
229 int ClientAuthValueComparator(const void *v1, const void *v2) |
|
230 { |
|
231 int a = *static_cast<const client_auth_option*>(v1) - |
|
232 *static_cast<const client_auth_option*>(v2); |
|
233 if (a == 0) |
|
234 return 0; |
|
235 if (a > 0) |
|
236 return 1; |
|
237 else // (a < 0) |
|
238 return -1; |
|
239 } |
|
240 |
|
241 static int match_hostname(PLHashEntry *he, int index, void* arg) |
|
242 { |
|
243 server_match_t *match = (server_match_t*)arg; |
|
244 if (match->fullHost.find((char*)he->key) != string::npos) |
|
245 match->matched = true; |
|
246 return HT_ENUMERATE_NEXT; |
|
247 } |
|
248 |
|
249 /* |
|
250 * Signal the main thread that the application should shut down. |
|
251 */ |
|
252 void SignalShutdown() |
|
253 { |
|
254 PR_Lock(shutdown_lock); |
|
255 PR_NotifyCondVar(shutdown_condvar); |
|
256 PR_Unlock(shutdown_lock); |
|
257 } |
|
258 |
|
259 bool ReadConnectRequest(server_info_t* server_info, |
|
260 relayBuffer& buffer, int32_t* result, string& certificate, |
|
261 client_auth_option* clientauth, string& host, string& location) |
|
262 { |
|
263 if (buffer.present() < 4) { |
|
264 LOG_DEBUG((" !! only %d bytes present in the buffer", (int)buffer.present())); |
|
265 return false; |
|
266 } |
|
267 if (strncmp(buffer.buffertail-4, "\r\n\r\n", 4)) { |
|
268 LOG_ERRORD((" !! request is not tailed with CRLFCRLF but with %x %x %x %x", |
|
269 *(buffer.buffertail-4), |
|
270 *(buffer.buffertail-3), |
|
271 *(buffer.buffertail-2), |
|
272 *(buffer.buffertail-1))); |
|
273 return false; |
|
274 } |
|
275 |
|
276 LOG_DEBUG((" parsing initial connect request, dump:\n%.*s\n", (int)buffer.present(), buffer.bufferhead)); |
|
277 |
|
278 *result = 400; |
|
279 |
|
280 char* token; |
|
281 char* _caret; |
|
282 token = strtok2(buffer.bufferhead, " ", &_caret); |
|
283 if (!token) { |
|
284 LOG_ERRORD((" no space found")); |
|
285 return true; |
|
286 } |
|
287 if (strcmp(token, "CONNECT")) { |
|
288 LOG_ERRORD((" not CONNECT request but %s", token)); |
|
289 return true; |
|
290 } |
|
291 |
|
292 token = strtok2(_caret, " ", &_caret); |
|
293 void* c = PL_HashTableLookup(server_info->host_cert_table, token); |
|
294 if (c) |
|
295 certificate = static_cast<char*>(c); |
|
296 |
|
297 host = "https://"; |
|
298 host += token; |
|
299 |
|
300 c = PL_HashTableLookup(server_info->host_clientauth_table, token); |
|
301 if (c) |
|
302 *clientauth = *static_cast<client_auth_option*>(c); |
|
303 else |
|
304 *clientauth = caNone; |
|
305 |
|
306 void *redir = PL_HashTableLookup(server_info->host_redir_table, token); |
|
307 if (redir) |
|
308 location = static_cast<char*>(redir); |
|
309 |
|
310 token = strtok2(_caret, "/", &_caret); |
|
311 if (strcmp(token, "HTTP")) { |
|
312 LOG_ERRORD((" not tailed with HTTP but with %s", token)); |
|
313 return true; |
|
314 } |
|
315 |
|
316 *result = (redir) ? 302 : 200; |
|
317 return true; |
|
318 } |
|
319 |
|
320 bool ConfigureSSLServerSocket(PRFileDesc* socket, server_info_t* si, string &certificate, client_auth_option clientAuth) |
|
321 { |
|
322 const char* certnick = certificate.empty() ? |
|
323 si->cert_nickname.c_str() : certificate.c_str(); |
|
324 |
|
325 ScopedCERTCertificate cert(PK11_FindCertFromNickname(certnick, nullptr)); |
|
326 if (!cert) { |
|
327 LOG_ERROR(("Failed to find cert %s\n", certnick)); |
|
328 return false; |
|
329 } |
|
330 |
|
331 ScopedSECKEYPrivateKey privKey(PK11_FindKeyByAnyCert(cert, nullptr)); |
|
332 if (!privKey) { |
|
333 LOG_ERROR(("Failed to find private key\n")); |
|
334 return false; |
|
335 } |
|
336 |
|
337 PRFileDesc* ssl_socket = SSL_ImportFD(nullptr, socket); |
|
338 if (!ssl_socket) { |
|
339 LOG_ERROR(("Error importing SSL socket\n")); |
|
340 return false; |
|
341 } |
|
342 |
|
343 SSLKEAType certKEA = NSS_FindCertKEAType(cert); |
|
344 if (SSL_ConfigSecureServer(ssl_socket, cert, privKey, certKEA) |
|
345 != SECSuccess) { |
|
346 LOG_ERROR(("Error configuring SSL server socket\n")); |
|
347 return false; |
|
348 } |
|
349 |
|
350 SSL_OptionSet(ssl_socket, SSL_SECURITY, true); |
|
351 SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_CLIENT, false); |
|
352 SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_SERVER, true); |
|
353 |
|
354 if (clientAuth != caNone) |
|
355 { |
|
356 SSL_OptionSet(ssl_socket, SSL_REQUEST_CERTIFICATE, true); |
|
357 SSL_OptionSet(ssl_socket, SSL_REQUIRE_CERTIFICATE, clientAuth == caRequire); |
|
358 } |
|
359 |
|
360 SSL_ResetHandshake(ssl_socket, true); |
|
361 |
|
362 return true; |
|
363 } |
|
364 |
|
365 /** |
|
366 * This function examines the buffer for a Sec-WebSocket-Location: field, |
|
367 * and if it's present, it replaces the hostname in that field with the |
|
368 * value in the server's original_host field. This function works |
|
369 * in the reverse direction as AdjustWebSocketHost(), replacing the real |
|
370 * hostname of a response with the potentially fake hostname that is expected |
|
371 * by the browser (e.g., mochi.test). |
|
372 * |
|
373 * @return true if the header was adjusted successfully, or not found, false |
|
374 * if the header is present but the url is not, which should indicate |
|
375 * that more data needs to be read from the socket |
|
376 */ |
|
377 bool AdjustWebSocketLocation(relayBuffer& buffer, connection_info_t *ci) |
|
378 { |
|
379 assert(buffer.margin()); |
|
380 buffer.buffertail[1] = '\0'; |
|
381 |
|
382 char* wsloc = strstr(buffer.bufferhead, "Sec-WebSocket-Location:"); |
|
383 if (!wsloc) |
|
384 return true; |
|
385 // advance pointer to the start of the hostname |
|
386 wsloc = strstr(wsloc, "ws://"); |
|
387 if (!wsloc) |
|
388 return false; |
|
389 wsloc += 5; |
|
390 // find the end of the hostname |
|
391 char* wslocend = strchr(wsloc + 1, '/'); |
|
392 if (!wslocend) |
|
393 return false; |
|
394 char *crlf = strstr(wsloc, "\r\n"); |
|
395 if (!crlf) |
|
396 return false; |
|
397 if (ci->original_host.empty()) |
|
398 return true; |
|
399 |
|
400 int diff = ci->original_host.length() - (wslocend-wsloc); |
|
401 if (diff > 0) |
|
402 assert(size_t(diff) <= buffer.margin()); |
|
403 memmove(wslocend + diff, wslocend, buffer.buffertail - wsloc - diff); |
|
404 buffer.buffertail += diff; |
|
405 |
|
406 memcpy(wsloc, ci->original_host.c_str(), ci->original_host.length()); |
|
407 return true; |
|
408 } |
|
409 |
|
410 /** |
|
411 * This function examines the buffer for a Host: field, and if it's present, |
|
412 * it replaces the hostname in that field with the hostname in the server's |
|
413 * remote_addr field. This is needed because proxy requests may be coming |
|
414 * from mochitest with fake hosts, like mochi.test, and these need to be |
|
415 * replaced with the host that the destination server is actually running |
|
416 * on. |
|
417 */ |
|
418 bool AdjustWebSocketHost(relayBuffer& buffer, connection_info_t *ci) |
|
419 { |
|
420 const char HEADER_UPGRADE[] = "Upgrade:"; |
|
421 const char HEADER_HOST[] = "Host:"; |
|
422 |
|
423 PRNetAddr inet_addr = (websocket_server.inet.port ? websocket_server : |
|
424 remote_addr); |
|
425 |
|
426 assert(buffer.margin()); |
|
427 |
|
428 // Cannot use strnchr so add a null char at the end. There is always some |
|
429 // space left because we preserve a margin. |
|
430 buffer.buffertail[1] = '\0'; |
|
431 |
|
432 // Verify this is a WebSocket header. |
|
433 char* h1 = strstr(buffer.bufferhead, HEADER_UPGRADE); |
|
434 if (!h1) |
|
435 return false; |
|
436 h1 += strlen(HEADER_UPGRADE); |
|
437 h1 += strspn(h1, " \t"); |
|
438 char* h2 = strstr(h1, "WebSocket\r\n"); |
|
439 if (!h2) h2 = strstr(h1, "websocket\r\n"); |
|
440 if (!h2) h2 = strstr(h1, "Websocket\r\n"); |
|
441 if (!h2) |
|
442 return false; |
|
443 |
|
444 char* host = strstr(buffer.bufferhead, HEADER_HOST); |
|
445 if (!host) |
|
446 return false; |
|
447 // advance pointer to beginning of hostname |
|
448 host += strlen(HEADER_HOST); |
|
449 host += strspn(host, " \t"); |
|
450 |
|
451 char* endhost = strstr(host, "\r\n"); |
|
452 if (!endhost) |
|
453 return false; |
|
454 |
|
455 // Save the original host, so we can use it later on responses from the |
|
456 // server. |
|
457 ci->original_host.assign(host, endhost-host); |
|
458 |
|
459 char newhost[40]; |
|
460 PR_NetAddrToString(&inet_addr, newhost, sizeof(newhost)); |
|
461 assert(strlen(newhost) < sizeof(newhost) - 7); |
|
462 sprintf(newhost, "%s:%d", newhost, PR_ntohs(inet_addr.inet.port)); |
|
463 |
|
464 int diff = strlen(newhost) - (endhost-host); |
|
465 if (diff > 0) |
|
466 assert(size_t(diff) <= buffer.margin()); |
|
467 memmove(endhost + diff, endhost, buffer.buffertail - host - diff); |
|
468 buffer.buffertail += diff; |
|
469 |
|
470 memcpy(host, newhost, strlen(newhost)); |
|
471 return true; |
|
472 } |
|
473 |
|
474 /** |
|
475 * This function prefixes Request-URI path with a full scheme-host-port |
|
476 * string. |
|
477 */ |
|
478 bool AdjustRequestURI(relayBuffer& buffer, string *host) |
|
479 { |
|
480 assert(buffer.margin()); |
|
481 |
|
482 // Cannot use strnchr so add a null char at the end. There is always some space left |
|
483 // because we preserve a margin. |
|
484 buffer.buffertail[1] = '\0'; |
|
485 LOG_DEBUG((" incoming request to adjust:\n%s\n", buffer.bufferhead)); |
|
486 |
|
487 char *token, *path; |
|
488 path = strchr(buffer.bufferhead, ' ') + 1; |
|
489 if (!path) |
|
490 return false; |
|
491 |
|
492 // If the path doesn't start with a slash don't change it, it is probably '*' or a full |
|
493 // path already. Return true, we are done with this request adjustment. |
|
494 if (*path != '/') |
|
495 return true; |
|
496 |
|
497 token = strchr(path, ' ') + 1; |
|
498 if (!token) |
|
499 return false; |
|
500 |
|
501 if (strncmp(token, "HTTP/", 5)) |
|
502 return false; |
|
503 |
|
504 size_t hostlength = host->length(); |
|
505 assert(hostlength <= buffer.margin()); |
|
506 |
|
507 memmove(path + hostlength, path, buffer.buffertail - path); |
|
508 memcpy(path, host->c_str(), hostlength); |
|
509 buffer.buffertail += hostlength; |
|
510 |
|
511 return true; |
|
512 } |
|
513 |
|
514 bool ConnectSocket(PRFileDesc *fd, const PRNetAddr *addr, PRIntervalTime timeout) |
|
515 { |
|
516 PRStatus stat = PR_Connect(fd, addr, timeout); |
|
517 if (stat != PR_SUCCESS) |
|
518 return false; |
|
519 |
|
520 PRSocketOptionData option; |
|
521 option.option = PR_SockOpt_Nonblocking; |
|
522 option.value.non_blocking = true; |
|
523 PR_SetSocketOption(fd, &option); |
|
524 |
|
525 return true; |
|
526 } |
|
527 |
|
528 /* |
|
529 * Handle an incoming client connection. The server thread has already |
|
530 * accepted the connection, so we just need to connect to the remote |
|
531 * port and then proxy data back and forth. |
|
532 * The data parameter is a connection_info_t*, and must be deleted |
|
533 * by this function. |
|
534 */ |
|
535 void HandleConnection(void* data) |
|
536 { |
|
537 connection_info_t* ci = static_cast<connection_info_t*>(data); |
|
538 PRIntervalTime connect_timeout = PR_SecondsToInterval(30); |
|
539 |
|
540 ScopedPRFileDesc other_sock(PR_NewTCPSocket()); |
|
541 bool client_done = false; |
|
542 bool client_error = false; |
|
543 bool connect_accepted = !do_http_proxy; |
|
544 bool ssl_updated = !do_http_proxy; |
|
545 bool expect_request_start = do_http_proxy; |
|
546 string certificateToUse; |
|
547 string locationHeader; |
|
548 client_auth_option clientAuth; |
|
549 string fullHost; |
|
550 |
|
551 LOG_DEBUG(("SSLTUNNEL(%p)): incoming connection csock(0)=%p, ssock(1)=%p\n", |
|
552 static_cast<void*>(data), |
|
553 static_cast<void*>(ci->client_sock), |
|
554 static_cast<void*>(other_sock))); |
|
555 if (other_sock) |
|
556 { |
|
557 int32_t numberOfSockets = 1; |
|
558 |
|
559 relayBuffer buffers[2]; |
|
560 |
|
561 if (!do_http_proxy) |
|
562 { |
|
563 if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info, certificateToUse, caNone)) |
|
564 client_error = true; |
|
565 else if (!ConnectSocket(other_sock, &remote_addr, connect_timeout)) |
|
566 client_error = true; |
|
567 else |
|
568 numberOfSockets = 2; |
|
569 } |
|
570 |
|
571 PRPollDesc sockets[2] = |
|
572 { |
|
573 {ci->client_sock, PR_POLL_READ, 0}, |
|
574 {other_sock, PR_POLL_READ, 0} |
|
575 }; |
|
576 bool socketErrorState[2] = {false, false}; |
|
577 |
|
578 while (!((client_error||client_done) && buffers[0].empty() && buffers[1].empty())) |
|
579 { |
|
580 sockets[0].in_flags |= PR_POLL_EXCEPT; |
|
581 sockets[1].in_flags |= PR_POLL_EXCEPT; |
|
582 LOG_DEBUG(("SSLTUNNEL(%p)): polling flags csock(0)=%c%c, ssock(1)=%c%c\n", |
|
583 static_cast<void*>(data), |
|
584 sockets[0].in_flags & PR_POLL_READ ? 'R' : '-', |
|
585 sockets[0].in_flags & PR_POLL_WRITE ? 'W' : '-', |
|
586 sockets[1].in_flags & PR_POLL_READ ? 'R' : '-', |
|
587 sockets[1].in_flags & PR_POLL_WRITE ? 'W' : '-')); |
|
588 int32_t pollStatus = PR_Poll(sockets, numberOfSockets, PR_MillisecondsToInterval(1000)); |
|
589 if (pollStatus < 0) |
|
590 { |
|
591 LOG_DEBUG(("SSLTUNNEL(%p)): pollStatus=%d, exiting\n", |
|
592 static_cast<void*>(data), pollStatus)); |
|
593 client_error = true; |
|
594 break; |
|
595 } |
|
596 |
|
597 if (pollStatus == 0) |
|
598 { |
|
599 // timeout |
|
600 LOG_DEBUG(("SSLTUNNEL(%p)): poll timeout, looping\n", |
|
601 static_cast<void*>(data))); |
|
602 continue; |
|
603 } |
|
604 |
|
605 for (int32_t s = 0; s < numberOfSockets; ++s) |
|
606 { |
|
607 int32_t s2 = s == 1 ? 0 : 1; |
|
608 int16_t out_flags = sockets[s].out_flags; |
|
609 int16_t &in_flags = sockets[s].in_flags; |
|
610 int16_t &in_flags2 = sockets[s2].in_flags; |
|
611 sockets[s].out_flags = 0; |
|
612 |
|
613 LOG_BEGIN_BLOCK(); |
|
614 LOG_DEBUG(("SSLTUNNEL(%p)): %csock(%d)=%p out_flags=%d", |
|
615 static_cast<void*>(data), |
|
616 s == 0 ? 'c' : 's', |
|
617 s, |
|
618 static_cast<void*>(sockets[s].fd), |
|
619 out_flags)); |
|
620 if (out_flags & (PR_POLL_EXCEPT | PR_POLL_ERR | PR_POLL_HUP)) |
|
621 { |
|
622 LOG_DEBUG((" :exception\n")); |
|
623 client_error = true; |
|
624 socketErrorState[s] = true; |
|
625 // We got a fatal error state on the socket. Clear the output buffer |
|
626 // for this socket to break the main loop, we will never more be able |
|
627 // to send those data anyway. |
|
628 buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer; |
|
629 continue; |
|
630 } // PR_POLL_EXCEPT, PR_POLL_ERR, PR_POLL_HUP handling |
|
631 |
|
632 if (out_flags & PR_POLL_READ && !buffers[s].areafree()) |
|
633 { |
|
634 LOG_DEBUG((" no place in read buffer but got read flag, dropping it now!")); |
|
635 in_flags &= ~PR_POLL_READ; |
|
636 } |
|
637 |
|
638 if (out_flags & PR_POLL_READ && buffers[s].areafree()) |
|
639 { |
|
640 LOG_DEBUG((" :reading")); |
|
641 int32_t bytesRead = PR_Recv(sockets[s].fd, buffers[s].buffertail, |
|
642 buffers[s].areafree(), 0, PR_INTERVAL_NO_TIMEOUT); |
|
643 |
|
644 if (bytesRead == 0) |
|
645 { |
|
646 LOG_DEBUG((" socket gracefully closed")); |
|
647 client_done = true; |
|
648 in_flags &= ~PR_POLL_READ; |
|
649 } |
|
650 else if (bytesRead < 0) |
|
651 { |
|
652 if (PR_GetError() != PR_WOULD_BLOCK_ERROR) |
|
653 { |
|
654 LOG_DEBUG((" error=%d", PR_GetError())); |
|
655 // We are in error state, indicate that the connection was |
|
656 // not closed gracefully |
|
657 client_error = true; |
|
658 socketErrorState[s] = true; |
|
659 // Wipe out our send buffer, we cannot send it anyway. |
|
660 buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer; |
|
661 } |
|
662 else |
|
663 LOG_DEBUG((" would block")); |
|
664 } |
|
665 else |
|
666 { |
|
667 // If the other socket is in error state (unable to send/receive) |
|
668 // throw this data away and continue loop |
|
669 if (socketErrorState[s2]) |
|
670 { |
|
671 LOG_DEBUG((" have read but other socket is in error state\n")); |
|
672 continue; |
|
673 } |
|
674 |
|
675 buffers[s].buffertail += bytesRead; |
|
676 LOG_DEBUG((", read %d bytes", bytesRead)); |
|
677 |
|
678 // We have to accept and handle the initial CONNECT request here |
|
679 int32_t response; |
|
680 if (!connect_accepted && ReadConnectRequest(ci->server_info, buffers[s], |
|
681 &response, certificateToUse, &clientAuth, fullHost, locationHeader)) |
|
682 { |
|
683 // Mark this as a proxy-only connection (no SSL) if the CONNECT |
|
684 // request didn't come for port 443 or from any of the server's |
|
685 // cert or clientauth hostnames. |
|
686 if (fullHost.find(":443") == string::npos) |
|
687 { |
|
688 server_match_t match; |
|
689 match.fullHost = fullHost; |
|
690 match.matched = false; |
|
691 PL_HashTableEnumerateEntries(ci->server_info->host_cert_table, |
|
692 match_hostname, |
|
693 &match); |
|
694 PL_HashTableEnumerateEntries(ci->server_info->host_clientauth_table, |
|
695 match_hostname, |
|
696 &match); |
|
697 ci->http_proxy_only = !match.matched; |
|
698 } |
|
699 else |
|
700 { |
|
701 ci->http_proxy_only = false; |
|
702 } |
|
703 |
|
704 // Clean the request as it would be read |
|
705 buffers[s].bufferhead = buffers[s].buffertail = buffers[s].buffer; |
|
706 in_flags |= PR_POLL_WRITE; |
|
707 connect_accepted = true; |
|
708 |
|
709 // Store response to the oposite buffer |
|
710 if (response == 200) |
|
711 { |
|
712 LOG_DEBUG((" accepted CONNECT request, connected to the server, sending OK to the client\n")); |
|
713 strcpy(buffers[s2].buffer, "HTTP/1.1 200 Connected\r\nConnection: keep-alive\r\n\r\n"); |
|
714 } |
|
715 else if (response == 302) |
|
716 { |
|
717 LOG_DEBUG((" accepted CONNECT request with redirection, " |
|
718 "sending location and 302 to the client\n")); |
|
719 client_done = true; |
|
720 sprintf(buffers[s2].buffer, |
|
721 "HTTP/1.1 302 Moved\r\n" |
|
722 "Location: https://%s/\r\n" |
|
723 "Connection: close\r\n\r\n", |
|
724 locationHeader.c_str()); |
|
725 } |
|
726 else |
|
727 { |
|
728 LOG_ERRORD((" could not read the connect request, closing connection with %d", response)); |
|
729 client_done = true; |
|
730 sprintf(buffers[s2].buffer, "HTTP/1.1 %d ERROR\r\nConnection: close\r\n\r\n", response); |
|
731 |
|
732 break; |
|
733 } |
|
734 |
|
735 buffers[s2].buffertail = buffers[s2].buffer + strlen(buffers[s2].buffer); |
|
736 |
|
737 // Send the response to the client socket |
|
738 break; |
|
739 } // end of CONNECT handling |
|
740 |
|
741 if (!buffers[s].areafree()) |
|
742 { |
|
743 // Do not poll for read when the buffer is full |
|
744 LOG_DEBUG((" no place in our read buffer, stop reading")); |
|
745 in_flags &= ~PR_POLL_READ; |
|
746 } |
|
747 |
|
748 if (ssl_updated) |
|
749 { |
|
750 if (s == 0 && expect_request_start) |
|
751 { |
|
752 if (!strstr(buffers[s].bufferhead, "\r\n\r\n")) |
|
753 { |
|
754 // We haven't received the complete header yet, so wait. |
|
755 continue; |
|
756 } |
|
757 else |
|
758 { |
|
759 ci->iswebsocket = AdjustWebSocketHost(buffers[s], ci); |
|
760 expect_request_start = !(ci->iswebsocket || |
|
761 AdjustRequestURI(buffers[s], &fullHost)); |
|
762 PRNetAddr* addr = &remote_addr; |
|
763 if (ci->iswebsocket && websocket_server.inet.port) |
|
764 addr = &websocket_server; |
|
765 if (!ConnectSocket(other_sock, addr, connect_timeout)) |
|
766 { |
|
767 LOG_ERRORD((" could not open connection to the real server\n")); |
|
768 client_error = true; |
|
769 break; |
|
770 } |
|
771 LOG_DEBUG(("\n connected to remote server\n")); |
|
772 numberOfSockets = 2; |
|
773 } |
|
774 } |
|
775 else if (s == 1 && ci->iswebsocket) |
|
776 { |
|
777 if (!AdjustWebSocketLocation(buffers[s], ci)) |
|
778 continue; |
|
779 } |
|
780 |
|
781 in_flags2 |= PR_POLL_WRITE; |
|
782 LOG_DEBUG((" telling the other socket to write")); |
|
783 } |
|
784 else |
|
785 LOG_DEBUG((" we have something for the other socket to write, but ssl has not been administered on it")); |
|
786 } |
|
787 } // PR_POLL_READ handling |
|
788 |
|
789 if (out_flags & PR_POLL_WRITE) |
|
790 { |
|
791 LOG_DEBUG((" :writing")); |
|
792 int32_t bytesWrite = PR_Send(sockets[s].fd, buffers[s2].bufferhead, |
|
793 buffers[s2].present(), 0, PR_INTERVAL_NO_TIMEOUT); |
|
794 |
|
795 if (bytesWrite < 0) |
|
796 { |
|
797 if (PR_GetError() != PR_WOULD_BLOCK_ERROR) { |
|
798 LOG_DEBUG((" error=%d", PR_GetError())); |
|
799 client_error = true; |
|
800 socketErrorState[s] = true; |
|
801 // We got a fatal error while writting the buffer. Clear it to break |
|
802 // the main loop, we will never more be able to send it. |
|
803 buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer; |
|
804 } |
|
805 else |
|
806 LOG_DEBUG((" would block")); |
|
807 } |
|
808 else |
|
809 { |
|
810 LOG_DEBUG((", written %d bytes", bytesWrite)); |
|
811 buffers[s2].buffertail[1] = '\0'; |
|
812 LOG_DEBUG((" dump:\n%.*s\n", bytesWrite, buffers[s2].bufferhead)); |
|
813 |
|
814 buffers[s2].bufferhead += bytesWrite; |
|
815 if (buffers[s2].present()) |
|
816 { |
|
817 LOG_DEBUG((" still have to write %d bytes", (int)buffers[s2].present())); |
|
818 in_flags |= PR_POLL_WRITE; |
|
819 } |
|
820 else |
|
821 { |
|
822 if (!ssl_updated) |
|
823 { |
|
824 LOG_DEBUG((" proxy response sent to the client")); |
|
825 // Proxy response has just been writen, update to ssl |
|
826 ssl_updated = true; |
|
827 if (ci->http_proxy_only) |
|
828 { |
|
829 LOG_DEBUG((" not updating to SSL based on http_proxy_only for this socket")); |
|
830 } |
|
831 else if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info, |
|
832 certificateToUse, clientAuth)) |
|
833 { |
|
834 LOG_ERRORD((" failed to config server socket\n")); |
|
835 client_error = true; |
|
836 break; |
|
837 } |
|
838 else |
|
839 { |
|
840 LOG_DEBUG((" client socket updated to SSL")); |
|
841 } |
|
842 } // sslUpdate |
|
843 |
|
844 LOG_DEBUG((" dropping our write flag and setting other socket read flag")); |
|
845 in_flags &= ~PR_POLL_WRITE; |
|
846 in_flags2 |= PR_POLL_READ; |
|
847 buffers[s2].compact(); |
|
848 } |
|
849 } |
|
850 } // PR_POLL_WRITE handling |
|
851 LOG_END_BLOCK(); // end the log |
|
852 } // for... |
|
853 } // while, poll |
|
854 } |
|
855 else |
|
856 client_error = true; |
|
857 |
|
858 LOG_DEBUG(("SSLTUNNEL(%p)): exiting root function for csock=%p, ssock=%p\n", |
|
859 static_cast<void*>(data), |
|
860 static_cast<void*>(ci->client_sock), |
|
861 static_cast<void*>(other_sock))); |
|
862 if (!client_error) |
|
863 PR_Shutdown(ci->client_sock, PR_SHUTDOWN_SEND); |
|
864 PR_Close(ci->client_sock); |
|
865 |
|
866 delete ci; |
|
867 } |
|
868 |
|
869 /* |
|
870 * Start listening for SSL connections on a specified port, handing |
|
871 * them off to client threads after accepting the connection. |
|
872 * The data parameter is a server_info_t*, owned by the calling |
|
873 * function. |
|
874 */ |
|
875 void StartServer(void* data) |
|
876 { |
|
877 server_info_t* si = static_cast<server_info_t*>(data); |
|
878 |
|
879 //TODO: select ciphers? |
|
880 ScopedPRFileDesc listen_socket(PR_NewTCPSocket()); |
|
881 if (!listen_socket) { |
|
882 LOG_ERROR(("failed to create socket\n")); |
|
883 SignalShutdown(); |
|
884 return; |
|
885 } |
|
886 |
|
887 // In case the socket is still open in the TIME_WAIT state from a previous |
|
888 // instance of ssltunnel we ask to reuse the port. |
|
889 PRSocketOptionData socket_option; |
|
890 socket_option.option = PR_SockOpt_Reuseaddr; |
|
891 socket_option.value.reuse_addr = true; |
|
892 PR_SetSocketOption(listen_socket, &socket_option); |
|
893 |
|
894 PRNetAddr server_addr; |
|
895 PR_InitializeNetAddr(PR_IpAddrAny, si->listen_port, &server_addr); |
|
896 if (PR_Bind(listen_socket, &server_addr) != PR_SUCCESS) { |
|
897 LOG_ERROR(("failed to bind socket\n")); |
|
898 SignalShutdown(); |
|
899 return; |
|
900 } |
|
901 |
|
902 if (PR_Listen(listen_socket, 1) != PR_SUCCESS) { |
|
903 LOG_ERROR(("failed to listen on socket\n")); |
|
904 SignalShutdown(); |
|
905 return; |
|
906 } |
|
907 |
|
908 LOG_INFO(("Server listening on port %d with cert %s\n", si->listen_port, |
|
909 si->cert_nickname.c_str())); |
|
910 |
|
911 while (!shutdown_server) { |
|
912 connection_info_t* ci = new connection_info_t(); |
|
913 ci->server_info = si; |
|
914 ci->http_proxy_only = do_http_proxy; |
|
915 // block waiting for connections |
|
916 ci->client_sock = PR_Accept(listen_socket, &ci->client_addr, |
|
917 PR_INTERVAL_NO_TIMEOUT); |
|
918 |
|
919 PRSocketOptionData option; |
|
920 option.option = PR_SockOpt_Nonblocking; |
|
921 option.value.non_blocking = true; |
|
922 PR_SetSocketOption(ci->client_sock, &option); |
|
923 |
|
924 if (ci->client_sock) |
|
925 // Not actually using this PRJob*... |
|
926 //PRJob* job = |
|
927 PR_QueueJob(threads, HandleConnection, ci, true); |
|
928 else |
|
929 delete ci; |
|
930 } |
|
931 } |
|
932 |
|
933 // bogus password func, just don't use passwords. :-P |
|
934 char* password_func(PK11SlotInfo* slot, PRBool retry, void* arg) |
|
935 { |
|
936 if (retry) |
|
937 return nullptr; |
|
938 |
|
939 return PL_strdup(""); |
|
940 } |
|
941 |
|
942 server_info_t* findServerInfo(int portnumber) |
|
943 { |
|
944 for (vector<server_info_t>::iterator it = servers.begin(); |
|
945 it != servers.end(); it++) |
|
946 { |
|
947 if (it->listen_port == portnumber) |
|
948 return &(*it); |
|
949 } |
|
950 |
|
951 return nullptr; |
|
952 } |
|
953 |
|
954 int processConfigLine(char* configLine) |
|
955 { |
|
956 if (*configLine == 0 || *configLine == '#') |
|
957 return 0; |
|
958 |
|
959 char* _caret; |
|
960 char* keyword = strtok2(configLine, ":", &_caret); |
|
961 |
|
962 // Configure usage of http/ssl tunneling proxy behavior |
|
963 if (!strcmp(keyword, "httpproxy")) |
|
964 { |
|
965 char* value = strtok2(_caret, ":", &_caret); |
|
966 if (!strcmp(value, "1")) |
|
967 do_http_proxy = true; |
|
968 |
|
969 return 0; |
|
970 } |
|
971 |
|
972 if (!strcmp(keyword, "websocketserver")) |
|
973 { |
|
974 char* ipstring = strtok2(_caret, ":", &_caret); |
|
975 if (PR_StringToNetAddr(ipstring, &websocket_server) != PR_SUCCESS) { |
|
976 LOG_ERROR(("Invalid IP address in proxy config: %s\n", ipstring)); |
|
977 return 1; |
|
978 } |
|
979 char* remoteport = strtok2(_caret, ":", &_caret); |
|
980 int port = atoi(remoteport); |
|
981 if (port <= 0) { |
|
982 LOG_ERROR(("Invalid remote port in proxy config: %s\n", remoteport)); |
|
983 return 1; |
|
984 } |
|
985 websocket_server.inet.port = PR_htons(port); |
|
986 return 0; |
|
987 } |
|
988 |
|
989 // Configure the forward address of the target server |
|
990 if (!strcmp(keyword, "forward")) |
|
991 { |
|
992 char* ipstring = strtok2(_caret, ":", &_caret); |
|
993 if (PR_StringToNetAddr(ipstring, &remote_addr) != PR_SUCCESS) { |
|
994 LOG_ERROR(("Invalid remote IP address: %s\n", ipstring)); |
|
995 return 1; |
|
996 } |
|
997 char* serverportstring = strtok2(_caret, ":", &_caret); |
|
998 int port = atoi(serverportstring); |
|
999 if (port <= 0) { |
|
1000 LOG_ERROR(("Invalid remote port: %s\n", serverportstring)); |
|
1001 return 1; |
|
1002 } |
|
1003 remote_addr.inet.port = PR_htons(port); |
|
1004 |
|
1005 return 0; |
|
1006 } |
|
1007 |
|
1008 // Configure all listen sockets and port+certificate bindings |
|
1009 if (!strcmp(keyword, "listen")) |
|
1010 { |
|
1011 char* hostname = strtok2(_caret, ":", &_caret); |
|
1012 char* hostportstring = nullptr; |
|
1013 if (strcmp(hostname, "*")) |
|
1014 { |
|
1015 any_host_spec_config = true; |
|
1016 hostportstring = strtok2(_caret, ":", &_caret); |
|
1017 } |
|
1018 |
|
1019 char* serverportstring = strtok2(_caret, ":", &_caret); |
|
1020 char* certnick = strtok2(_caret, ":", &_caret); |
|
1021 |
|
1022 int port = atoi(serverportstring); |
|
1023 if (port <= 0) { |
|
1024 LOG_ERROR(("Invalid port specified: %s\n", serverportstring)); |
|
1025 return 1; |
|
1026 } |
|
1027 |
|
1028 if (server_info_t* existingServer = findServerInfo(port)) |
|
1029 { |
|
1030 char *certnick_copy = new char[strlen(certnick)+1]; |
|
1031 char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2]; |
|
1032 |
|
1033 strcpy(hostname_copy, hostname); |
|
1034 strcat(hostname_copy, ":"); |
|
1035 strcat(hostname_copy, hostportstring); |
|
1036 strcpy(certnick_copy, certnick); |
|
1037 |
|
1038 PLHashEntry* entry = PL_HashTableAdd(existingServer->host_cert_table, hostname_copy, certnick_copy); |
|
1039 if (!entry) { |
|
1040 LOG_ERROR(("Out of memory")); |
|
1041 return 1; |
|
1042 } |
|
1043 } |
|
1044 else |
|
1045 { |
|
1046 server_info_t server; |
|
1047 server.cert_nickname = certnick; |
|
1048 server.listen_port = port; |
|
1049 server.host_cert_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings, |
|
1050 PL_CompareStrings, nullptr, nullptr); |
|
1051 if (!server.host_cert_table) |
|
1052 { |
|
1053 LOG_ERROR(("Internal, could not create hash table\n")); |
|
1054 return 1; |
|
1055 } |
|
1056 server.host_clientauth_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings, |
|
1057 ClientAuthValueComparator, nullptr, nullptr); |
|
1058 if (!server.host_clientauth_table) |
|
1059 { |
|
1060 LOG_ERROR(("Internal, could not create hash table\n")); |
|
1061 return 1; |
|
1062 } |
|
1063 server.host_redir_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings, |
|
1064 PL_CompareStrings, nullptr, nullptr); |
|
1065 if (!server.host_redir_table) |
|
1066 { |
|
1067 LOG_ERROR(("Internal, could not create hash table\n")); |
|
1068 return 1; |
|
1069 } |
|
1070 servers.push_back(server); |
|
1071 } |
|
1072 |
|
1073 return 0; |
|
1074 } |
|
1075 |
|
1076 if (!strcmp(keyword, "clientauth")) |
|
1077 { |
|
1078 char* hostname = strtok2(_caret, ":", &_caret); |
|
1079 char* hostportstring = strtok2(_caret, ":", &_caret); |
|
1080 char* serverportstring = strtok2(_caret, ":", &_caret); |
|
1081 |
|
1082 int port = atoi(serverportstring); |
|
1083 if (port <= 0) { |
|
1084 LOG_ERROR(("Invalid port specified: %s\n", serverportstring)); |
|
1085 return 1; |
|
1086 } |
|
1087 |
|
1088 if (server_info_t* existingServer = findServerInfo(port)) |
|
1089 { |
|
1090 char* authoptionstring = strtok2(_caret, ":", &_caret); |
|
1091 client_auth_option* authoption = new client_auth_option; |
|
1092 if (!authoption) { |
|
1093 LOG_ERROR(("Out of memory")); |
|
1094 return 1; |
|
1095 } |
|
1096 |
|
1097 if (!strcmp(authoptionstring, "require")) |
|
1098 *authoption = caRequire; |
|
1099 else if (!strcmp(authoptionstring, "request")) |
|
1100 *authoption = caRequest; |
|
1101 else if (!strcmp(authoptionstring, "none")) |
|
1102 *authoption = caNone; |
|
1103 else |
|
1104 { |
|
1105 LOG_ERROR(("Incorrect client auth option modifier for host '%s'", hostname)); |
|
1106 return 1; |
|
1107 } |
|
1108 |
|
1109 any_host_spec_config = true; |
|
1110 |
|
1111 char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2]; |
|
1112 if (!hostname_copy) { |
|
1113 LOG_ERROR(("Out of memory")); |
|
1114 return 1; |
|
1115 } |
|
1116 |
|
1117 strcpy(hostname_copy, hostname); |
|
1118 strcat(hostname_copy, ":"); |
|
1119 strcat(hostname_copy, hostportstring); |
|
1120 |
|
1121 PLHashEntry* entry = PL_HashTableAdd(existingServer->host_clientauth_table, hostname_copy, authoption); |
|
1122 if (!entry) { |
|
1123 LOG_ERROR(("Out of memory")); |
|
1124 return 1; |
|
1125 } |
|
1126 } |
|
1127 else |
|
1128 { |
|
1129 LOG_ERROR(("Server on port %d for client authentication option is not defined, use 'listen' option first", port)); |
|
1130 return 1; |
|
1131 } |
|
1132 |
|
1133 return 0; |
|
1134 } |
|
1135 |
|
1136 if (!strcmp(keyword, "redirhost")) |
|
1137 { |
|
1138 char* hostname = strtok2(_caret, ":", &_caret); |
|
1139 char* hostportstring = strtok2(_caret, ":", &_caret); |
|
1140 char* serverportstring = strtok2(_caret, ":", &_caret); |
|
1141 |
|
1142 int port = atoi(serverportstring); |
|
1143 if (port <= 0) { |
|
1144 LOG_ERROR(("Invalid port specified: %s\n", serverportstring)); |
|
1145 return 1; |
|
1146 } |
|
1147 |
|
1148 if (server_info_t* existingServer = findServerInfo(port)) |
|
1149 { |
|
1150 char* redirhoststring = strtok2(_caret, ":", &_caret); |
|
1151 |
|
1152 any_host_spec_config = true; |
|
1153 |
|
1154 char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2]; |
|
1155 if (!hostname_copy) { |
|
1156 LOG_ERROR(("Out of memory")); |
|
1157 return 1; |
|
1158 } |
|
1159 |
|
1160 strcpy(hostname_copy, hostname); |
|
1161 strcat(hostname_copy, ":"); |
|
1162 strcat(hostname_copy, hostportstring); |
|
1163 |
|
1164 char *redir_copy = new char[strlen(redirhoststring)+1]; |
|
1165 strcpy(redir_copy, redirhoststring); |
|
1166 PLHashEntry* entry = PL_HashTableAdd(existingServer->host_redir_table, hostname_copy, redir_copy); |
|
1167 if (!entry) { |
|
1168 LOG_ERROR(("Out of memory")); |
|
1169 return 1; |
|
1170 } |
|
1171 } |
|
1172 else |
|
1173 { |
|
1174 LOG_ERROR(("Server on port %d for redirhost option is not defined, use 'listen' option first", port)); |
|
1175 return 1; |
|
1176 } |
|
1177 |
|
1178 return 0; |
|
1179 } |
|
1180 |
|
1181 // Configure the NSS certificate database directory |
|
1182 if (!strcmp(keyword, "certdbdir")) |
|
1183 { |
|
1184 nssconfigdir = strtok2(_caret, "\n", &_caret); |
|
1185 return 0; |
|
1186 } |
|
1187 |
|
1188 LOG_ERROR(("Error: keyword \"%s\" unexpected\n", keyword)); |
|
1189 return 1; |
|
1190 } |
|
1191 |
|
1192 int parseConfigFile(const char* filePath) |
|
1193 { |
|
1194 FILE* f = fopen(filePath, "r"); |
|
1195 if (!f) |
|
1196 return 1; |
|
1197 |
|
1198 char buffer[1024], *b = buffer; |
|
1199 while (!feof(f)) |
|
1200 { |
|
1201 char c; |
|
1202 fscanf(f, "%c", &c); |
|
1203 switch (c) |
|
1204 { |
|
1205 case '\n': |
|
1206 *b++ = 0; |
|
1207 if (processConfigLine(buffer)) |
|
1208 return 1; |
|
1209 b = buffer; |
|
1210 case '\r': |
|
1211 continue; |
|
1212 default: |
|
1213 *b++ = c; |
|
1214 } |
|
1215 } |
|
1216 |
|
1217 fclose(f); |
|
1218 |
|
1219 // Check mandatory items |
|
1220 if (nssconfigdir.empty()) |
|
1221 { |
|
1222 LOG_ERROR(("Error: missing path to NSS certification database\n,use certdbdir:<path> in the config file\n")); |
|
1223 return 1; |
|
1224 } |
|
1225 |
|
1226 if (any_host_spec_config && !do_http_proxy) |
|
1227 { |
|
1228 LOG_ERROR(("Warning: any host-specific configurations are ignored, add httpproxy:1 to allow them\n")); |
|
1229 } |
|
1230 |
|
1231 return 0; |
|
1232 } |
|
1233 |
|
1234 int freeHostCertHashItems(PLHashEntry *he, int i, void *arg) |
|
1235 { |
|
1236 delete [] (char*)he->key; |
|
1237 delete [] (char*)he->value; |
|
1238 return HT_ENUMERATE_REMOVE; |
|
1239 } |
|
1240 |
|
1241 int freeHostRedirHashItems(PLHashEntry *he, int i, void *arg) |
|
1242 { |
|
1243 delete [] (char*)he->key; |
|
1244 delete [] (char*)he->value; |
|
1245 return HT_ENUMERATE_REMOVE; |
|
1246 } |
|
1247 |
|
1248 int freeClientAuthHashItems(PLHashEntry *he, int i, void *arg) |
|
1249 { |
|
1250 delete [] (char*)he->key; |
|
1251 delete (client_auth_option*)he->value; |
|
1252 return HT_ENUMERATE_REMOVE; |
|
1253 } |
|
1254 |
|
1255 int main(int argc, char** argv) |
|
1256 { |
|
1257 const char* configFilePath; |
|
1258 |
|
1259 const char* logLevelEnv = PR_GetEnv("SSLTUNNEL_LOG_LEVEL"); |
|
1260 gLogLevel = logLevelEnv ? (LogLevel)atoi(logLevelEnv) : LEVEL_INFO; |
|
1261 |
|
1262 if (argc == 1) |
|
1263 configFilePath = "ssltunnel.cfg"; |
|
1264 else |
|
1265 configFilePath = argv[1]; |
|
1266 |
|
1267 memset(&websocket_server, 0, sizeof(PRNetAddr)); |
|
1268 |
|
1269 if (parseConfigFile(configFilePath)) { |
|
1270 LOG_ERROR(("Error: config file \"%s\" missing or formating incorrect\n" |
|
1271 "Specify path to the config file as parameter to ssltunnel or \n" |
|
1272 "create ssltunnel.cfg in the working directory.\n\n" |
|
1273 "Example format of the config file:\n\n" |
|
1274 " # Enable http/ssl tunneling proxy-like behavior.\n" |
|
1275 " # If not specified ssltunnel simply does direct forward.\n" |
|
1276 " httpproxy:1\n\n" |
|
1277 " # Specify path to the certification database used.\n" |
|
1278 " certdbdir:/path/to/certdb\n\n" |
|
1279 " # Forward/proxy all requests in raw to 127.0.0.1:8888.\n" |
|
1280 " forward:127.0.0.1:8888\n\n" |
|
1281 " # Accept connections on port 4443 or 5678 resp. and authenticate\n" |
|
1282 " # to any host ('*') using the 'server cert' or 'server cert 2' resp.\n" |
|
1283 " listen:*:4443:server cert\n" |
|
1284 " listen:*:5678:server cert 2\n\n" |
|
1285 " # Accept connections on port 4443 and authenticate using\n" |
|
1286 " # 'a different cert' when target host is 'my.host.name:443'.\n" |
|
1287 " # This only works in httpproxy mode and has higher priority\n" |
|
1288 " # than the previous option.\n" |
|
1289 " listen:my.host.name:443:4443:a different cert\n\n" |
|
1290 " # To make a specific host require or just request a client certificate\n" |
|
1291 " # to authenticate use the following options. This can only be used\n" |
|
1292 " # in httpproxy mode and only after the 'listen' option has been\n" |
|
1293 " # specified. You also have to specify the tunnel listen port.\n" |
|
1294 " clientauth:requesting-client-cert.host.com:443:4443:request\n" |
|
1295 " clientauth:requiring-client-cert.host.com:443:4443:require\n" |
|
1296 " # Proxy WebSocket traffic to the server at 127.0.0.1:9999,\n" |
|
1297 " # instead of the server specified in the 'forward' option.\n" |
|
1298 " websocketserver:127.0.0.1:9999\n", |
|
1299 configFilePath)); |
|
1300 return 1; |
|
1301 } |
|
1302 |
|
1303 // create a thread pool to handle connections |
|
1304 threads = PR_CreateThreadPool(INITIAL_THREADS * servers.size(), |
|
1305 MAX_THREADS * servers.size(), |
|
1306 DEFAULT_STACKSIZE); |
|
1307 if (!threads) { |
|
1308 LOG_ERROR(("Failed to create thread pool\n")); |
|
1309 return 1; |
|
1310 } |
|
1311 |
|
1312 shutdown_lock = PR_NewLock(); |
|
1313 if (!shutdown_lock) { |
|
1314 LOG_ERROR(("Failed to create lock\n")); |
|
1315 PR_ShutdownThreadPool(threads); |
|
1316 return 1; |
|
1317 } |
|
1318 shutdown_condvar = PR_NewCondVar(shutdown_lock); |
|
1319 if (!shutdown_condvar) { |
|
1320 LOG_ERROR(("Failed to create condvar\n")); |
|
1321 PR_ShutdownThreadPool(threads); |
|
1322 PR_DestroyLock(shutdown_lock); |
|
1323 return 1; |
|
1324 } |
|
1325 |
|
1326 PK11_SetPasswordFunc(password_func); |
|
1327 |
|
1328 // Initialize NSS |
|
1329 if (NSS_Init(nssconfigdir.c_str()) != SECSuccess) { |
|
1330 int32_t errorlen = PR_GetErrorTextLength(); |
|
1331 char* err = new char[errorlen+1]; |
|
1332 PR_GetErrorText(err); |
|
1333 LOG_ERROR(("Failed to init NSS: %s", err)); |
|
1334 delete[] err; |
|
1335 PR_ShutdownThreadPool(threads); |
|
1336 PR_DestroyCondVar(shutdown_condvar); |
|
1337 PR_DestroyLock(shutdown_lock); |
|
1338 return 1; |
|
1339 } |
|
1340 |
|
1341 if (NSS_SetDomesticPolicy() != SECSuccess) { |
|
1342 LOG_ERROR(("NSS_SetDomesticPolicy failed\n")); |
|
1343 PR_ShutdownThreadPool(threads); |
|
1344 PR_DestroyCondVar(shutdown_condvar); |
|
1345 PR_DestroyLock(shutdown_lock); |
|
1346 NSS_Shutdown(); |
|
1347 return 1; |
|
1348 } |
|
1349 |
|
1350 // these values should make NSS use the defaults |
|
1351 if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) { |
|
1352 LOG_ERROR(("SSL_ConfigServerSessionIDCache failed\n")); |
|
1353 PR_ShutdownThreadPool(threads); |
|
1354 PR_DestroyCondVar(shutdown_condvar); |
|
1355 PR_DestroyLock(shutdown_lock); |
|
1356 NSS_Shutdown(); |
|
1357 return 1; |
|
1358 } |
|
1359 |
|
1360 for (vector<server_info_t>::iterator it = servers.begin(); |
|
1361 it != servers.end(); it++) { |
|
1362 // Not actually using this PRJob*... |
|
1363 // PRJob* server_job = |
|
1364 PR_QueueJob(threads, StartServer, &(*it), true); |
|
1365 } |
|
1366 // now wait for someone to tell us to quit |
|
1367 PR_Lock(shutdown_lock); |
|
1368 PR_WaitCondVar(shutdown_condvar, PR_INTERVAL_NO_TIMEOUT); |
|
1369 PR_Unlock(shutdown_lock); |
|
1370 shutdown_server = true; |
|
1371 LOG_INFO(("Shutting down...\n")); |
|
1372 // cleanup |
|
1373 PR_ShutdownThreadPool(threads); |
|
1374 PR_JoinThreadPool(threads); |
|
1375 PR_DestroyCondVar(shutdown_condvar); |
|
1376 PR_DestroyLock(shutdown_lock); |
|
1377 if (NSS_Shutdown() == SECFailure) { |
|
1378 LOG_DEBUG(("Leaked NSS objects!\n")); |
|
1379 } |
|
1380 |
|
1381 for (vector<server_info_t>::iterator it = servers.begin(); |
|
1382 it != servers.end(); it++) |
|
1383 { |
|
1384 PL_HashTableEnumerateEntries(it->host_cert_table, freeHostCertHashItems, nullptr); |
|
1385 PL_HashTableEnumerateEntries(it->host_clientauth_table, freeClientAuthHashItems, nullptr); |
|
1386 PL_HashTableEnumerateEntries(it->host_redir_table, freeHostRedirHashItems, nullptr); |
|
1387 PL_HashTableDestroy(it->host_cert_table); |
|
1388 PL_HashTableDestroy(it->host_clientauth_table); |
|
1389 PL_HashTableDestroy(it->host_redir_table); |
|
1390 } |
|
1391 |
|
1392 PR_Cleanup(); |
|
1393 return 0; |
|
1394 } |