1.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000 1.2 +++ b/security/manager/ssl/tests/unit/tlsserver/lib/TLSServer.cpp Wed Dec 31 06:09:35 2014 +0100 1.3 @@ -0,0 +1,328 @@ 1.4 +/* This Source Code Form is subject to the terms of the Mozilla Public 1.5 + * License, v. 2.0. If a copy of the MPL was not distributed with this 1.6 + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ 1.7 + 1.8 +#include "TLSServer.h" 1.9 + 1.10 +#include <stdio.h> 1.11 +#include "ScopedNSSTypes.h" 1.12 +#include "nspr.h" 1.13 +#include "nss.h" 1.14 +#include "plarenas.h" 1.15 +#include "prenv.h" 1.16 +#include "prerror.h" 1.17 +#include "prnetdb.h" 1.18 +#include "prtime.h" 1.19 +#include "ssl.h" 1.20 + 1.21 +namespace mozilla { namespace test { 1.22 + 1.23 +static const uint16_t LISTEN_PORT = 8443; 1.24 + 1.25 +DebugLevel gDebugLevel = DEBUG_ERRORS; 1.26 +uint16_t gCallbackPort = 0; 1.27 + 1.28 +const char DEFAULT_CERT_NICKNAME[] = "localhostAndExampleCom"; 1.29 + 1.30 +struct Connection 1.31 +{ 1.32 + PRFileDesc *mSocket; 1.33 + char mByte; 1.34 + 1.35 + Connection(PRFileDesc *aSocket); 1.36 + ~Connection(); 1.37 +}; 1.38 + 1.39 +Connection::Connection(PRFileDesc *aSocket) 1.40 +: mSocket(aSocket) 1.41 +, mByte(0) 1.42 +{} 1.43 + 1.44 +Connection::~Connection() 1.45 +{ 1.46 + if (mSocket) { 1.47 + PR_Close(mSocket); 1.48 + } 1.49 +} 1.50 + 1.51 +void 1.52 +PrintPRError(const char *aPrefix) 1.53 +{ 1.54 + const char *err = PR_ErrorToName(PR_GetError()); 1.55 + if (err) { 1.56 + if (gDebugLevel >= DEBUG_ERRORS) { 1.57 + fprintf(stderr, "%s: %s\n", aPrefix, err); 1.58 + } 1.59 + } else { 1.60 + if (gDebugLevel >= DEBUG_ERRORS) { 1.61 + fprintf(stderr, "%s\n", aPrefix); 1.62 + } 1.63 + } 1.64 +} 1.65 + 1.66 +nsresult 1.67 +SendAll(PRFileDesc *aSocket, const char *aData, size_t aDataLen) 1.68 +{ 1.69 + if (gDebugLevel >= DEBUG_VERBOSE) { 1.70 + fprintf(stderr, "sending '%s'\n", aData); 1.71 + } 1.72 + 1.73 + while (aDataLen > 0) { 1.74 + int32_t bytesSent = PR_Send(aSocket, aData, aDataLen, 0, 1.75 + PR_INTERVAL_NO_TIMEOUT); 1.76 + if (bytesSent == -1) { 1.77 + PrintPRError("PR_Send failed"); 1.78 + return NS_ERROR_FAILURE; 1.79 + } 1.80 + 1.81 + aDataLen -= bytesSent; 1.82 + aData += bytesSent; 1.83 + } 1.84 + 1.85 + return NS_OK; 1.86 +} 1.87 + 1.88 +nsresult 1.89 +ReplyToRequest(Connection *aConn) 1.90 +{ 1.91 + // For debugging purposes, SendAll can print out what it's sending. 1.92 + // So, any strings we give to it to send need to be null-terminated. 1.93 + char buf[2] = { aConn->mByte, 0 }; 1.94 + return SendAll(aConn->mSocket, buf, 1); 1.95 +} 1.96 + 1.97 +nsresult 1.98 +SetupTLS(Connection *aConn, PRFileDesc *aModelSocket) 1.99 +{ 1.100 + PRFileDesc *sslSocket = SSL_ImportFD(aModelSocket, aConn->mSocket); 1.101 + if (!sslSocket) { 1.102 + PrintPRError("SSL_ImportFD failed"); 1.103 + return NS_ERROR_FAILURE; 1.104 + } 1.105 + aConn->mSocket = sslSocket; 1.106 + 1.107 + SSL_OptionSet(sslSocket, SSL_SECURITY, true); 1.108 + SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT, false); 1.109 + SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_SERVER, true); 1.110 + 1.111 + SSL_ResetHandshake(sslSocket, /* asServer */ 1); 1.112 + 1.113 + return NS_OK; 1.114 +} 1.115 + 1.116 +nsresult 1.117 +ReadRequest(Connection *aConn) 1.118 +{ 1.119 + int32_t bytesRead = PR_Recv(aConn->mSocket, &aConn->mByte, 1, 0, 1.120 + PR_INTERVAL_NO_TIMEOUT); 1.121 + if (bytesRead < 0) { 1.122 + PrintPRError("PR_Recv failed"); 1.123 + return NS_ERROR_FAILURE; 1.124 + } else if (bytesRead == 0) { 1.125 + PR_SetError(PR_IO_ERROR, 0); 1.126 + PrintPRError("PR_Recv EOF in ReadRequest"); 1.127 + return NS_ERROR_FAILURE; 1.128 + } else { 1.129 + if (gDebugLevel >= DEBUG_VERBOSE) { 1.130 + fprintf(stderr, "read '0x%hhx'\n", aConn->mByte); 1.131 + } 1.132 + } 1.133 + return NS_OK; 1.134 +} 1.135 + 1.136 +void 1.137 +HandleConnection(PRFileDesc *aSocket, PRFileDesc *aModelSocket) 1.138 +{ 1.139 + Connection conn(aSocket); 1.140 + nsresult rv = SetupTLS(&conn, aModelSocket); 1.141 + if (NS_FAILED(rv)) { 1.142 + PR_SetError(PR_INVALID_STATE_ERROR, 0); 1.143 + PrintPRError("PR_Recv failed"); 1.144 + exit(1); 1.145 + } 1.146 + 1.147 + // TODO: On tests that are expected to fail (e.g. due to a revoked 1.148 + // certificate), the client will close the connection wtihout sending us the 1.149 + // request byte. In those cases, we should keep going. But, in the cases 1.150 + // where the connection is supposed to suceed, we should verify that we 1.151 + // successfully receive the request and send the response. 1.152 + rv = ReadRequest(&conn); 1.153 + if (NS_SUCCEEDED(rv)) { 1.154 + rv = ReplyToRequest(&conn); 1.155 + } 1.156 +} 1.157 + 1.158 +// returns 0 on success, non-zero on error 1.159 +int 1.160 +DoCallback() 1.161 +{ 1.162 + ScopedPRFileDesc socket(PR_NewTCPSocket()); 1.163 + if (!socket) { 1.164 + PrintPRError("PR_NewTCPSocket failed"); 1.165 + return 1; 1.166 + } 1.167 + 1.168 + PRNetAddr addr; 1.169 + PR_InitializeNetAddr(PR_IpAddrLoopback, gCallbackPort, &addr); 1.170 + if (PR_Connect(socket, &addr, PR_INTERVAL_NO_TIMEOUT) != PR_SUCCESS) { 1.171 + PrintPRError("PR_Connect failed"); 1.172 + return 1; 1.173 + } 1.174 + 1.175 + const char *request = "GET / HTTP/1.0\r\n\r\n"; 1.176 + SendAll(socket, request, strlen(request)); 1.177 + char buf[4096]; 1.178 + memset(buf, 0, sizeof(buf)); 1.179 + int32_t bytesRead = PR_Recv(socket, buf, sizeof(buf) - 1, 0, 1.180 + PR_INTERVAL_NO_TIMEOUT); 1.181 + if (bytesRead < 0) { 1.182 + PrintPRError("PR_Recv failed 1"); 1.183 + return 1; 1.184 + } 1.185 + if (bytesRead == 0) { 1.186 + fprintf(stderr, "PR_Recv eof 1\n"); 1.187 + return 1; 1.188 + } 1.189 + fprintf(stderr, "%s\n", buf); 1.190 + return 0; 1.191 +} 1.192 + 1.193 +SECStatus 1.194 +ConfigSecureServerWithNamedCert(PRFileDesc *fd, const char *certName, 1.195 + /*optional*/ ScopedCERTCertificate *certOut, 1.196 + /*optional*/ SSLKEAType *keaOut) 1.197 +{ 1.198 + ScopedCERTCertificate cert(PK11_FindCertFromNickname(certName, nullptr)); 1.199 + if (!cert) { 1.200 + PrintPRError("PK11_FindCertFromNickname failed"); 1.201 + return SECFailure; 1.202 + } 1.203 + 1.204 + ScopedSECKEYPrivateKey key(PK11_FindKeyByAnyCert(cert, nullptr)); 1.205 + if (!key) { 1.206 + PrintPRError("PK11_FindKeyByAnyCert failed"); 1.207 + return SECFailure; 1.208 + } 1.209 + 1.210 + SSLKEAType certKEA = NSS_FindCertKEAType(cert); 1.211 + 1.212 + if (SSL_ConfigSecureServer(fd, cert, key, certKEA) != SECSuccess) { 1.213 + PrintPRError("SSL_ConfigSecureServer failed"); 1.214 + return SECFailure; 1.215 + } 1.216 + 1.217 + if (certOut) { 1.218 + *certOut = cert.forget(); 1.219 + } 1.220 + 1.221 + if (keaOut) { 1.222 + *keaOut = certKEA; 1.223 + } 1.224 + 1.225 + return SECSuccess; 1.226 +} 1.227 + 1.228 +int 1.229 +StartServer(const char *nssCertDBDir, SSLSNISocketConfig sniSocketConfig, 1.230 + void *sniSocketConfigArg) 1.231 +{ 1.232 + const char *debugLevel = PR_GetEnv("MOZ_TLS_SERVER_DEBUG_LEVEL"); 1.233 + if (debugLevel) { 1.234 + int level = atoi(debugLevel); 1.235 + switch (level) { 1.236 + case DEBUG_ERRORS: gDebugLevel = DEBUG_ERRORS; break; 1.237 + case DEBUG_WARNINGS: gDebugLevel = DEBUG_WARNINGS; break; 1.238 + case DEBUG_VERBOSE: gDebugLevel = DEBUG_VERBOSE; break; 1.239 + default: 1.240 + PrintPRError("invalid MOZ_TLS_SERVER_DEBUG_LEVEL"); 1.241 + return 1; 1.242 + } 1.243 + } 1.244 + 1.245 + const char *callbackPort = PR_GetEnv("MOZ_TLS_SERVER_CALLBACK_PORT"); 1.246 + if (callbackPort) { 1.247 + gCallbackPort = atoi(callbackPort); 1.248 + } 1.249 + 1.250 + if (NSS_Init(nssCertDBDir) != SECSuccess) { 1.251 + PrintPRError("NSS_Init failed"); 1.252 + return 1; 1.253 + } 1.254 + 1.255 + if (NSS_SetDomesticPolicy() != SECSuccess) { 1.256 + PrintPRError("NSS_SetDomesticPolicy failed"); 1.257 + return 1; 1.258 + } 1.259 + 1.260 + if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) { 1.261 + PrintPRError("SSL_ConfigServerSessionIDCache failed"); 1.262 + return 1; 1.263 + } 1.264 + 1.265 + ScopedPRFileDesc serverSocket(PR_NewTCPSocket()); 1.266 + if (!serverSocket) { 1.267 + PrintPRError("PR_NewTCPSocket failed"); 1.268 + return 1; 1.269 + } 1.270 + 1.271 + PRSocketOptionData socketOption; 1.272 + socketOption.option = PR_SockOpt_Reuseaddr; 1.273 + socketOption.value.reuse_addr = true; 1.274 + PR_SetSocketOption(serverSocket, &socketOption); 1.275 + 1.276 + PRNetAddr serverAddr; 1.277 + PR_InitializeNetAddr(PR_IpAddrLoopback, LISTEN_PORT, &serverAddr); 1.278 + if (PR_Bind(serverSocket, &serverAddr) != PR_SUCCESS) { 1.279 + PrintPRError("PR_Bind failed"); 1.280 + return 1; 1.281 + } 1.282 + 1.283 + if (PR_Listen(serverSocket, 1) != PR_SUCCESS) { 1.284 + PrintPRError("PR_Listen failed"); 1.285 + return 1; 1.286 + } 1.287 + 1.288 + ScopedPRFileDesc rawModelSocket(PR_NewTCPSocket()); 1.289 + if (!rawModelSocket) { 1.290 + PrintPRError("PR_NewTCPSocket failed for rawModelSocket"); 1.291 + return 1; 1.292 + } 1.293 + 1.294 + ScopedPRFileDesc modelSocket(SSL_ImportFD(nullptr, rawModelSocket.forget())); 1.295 + if (!modelSocket) { 1.296 + PrintPRError("SSL_ImportFD of rawModelSocket failed"); 1.297 + return 1; 1.298 + } 1.299 + 1.300 + if (SECSuccess != SSL_SNISocketConfigHook(modelSocket, sniSocketConfig, 1.301 + sniSocketConfigArg)) { 1.302 + PrintPRError("SSL_SNISocketConfigHook failed"); 1.303 + return 1; 1.304 + } 1.305 + 1.306 + // We have to configure the server with a certificate, but it's not one 1.307 + // we're actually going to end up using. In the SNI callback, we pick 1.308 + // the right certificate for the connection. 1.309 + if (SECSuccess != ConfigSecureServerWithNamedCert(modelSocket, 1.310 + DEFAULT_CERT_NICKNAME, 1.311 + nullptr, nullptr)) { 1.312 + return 1; 1.313 + } 1.314 + 1.315 + if (gCallbackPort != 0) { 1.316 + if (DoCallback()) { 1.317 + return 1; 1.318 + } 1.319 + } 1.320 + 1.321 + while (true) { 1.322 + PRNetAddr clientAddr; 1.323 + PRFileDesc *clientSocket = PR_Accept(serverSocket, &clientAddr, 1.324 + PR_INTERVAL_NO_TIMEOUT); 1.325 + HandleConnection(clientSocket, modelSocket); 1.326 + } 1.327 + 1.328 + return 0; 1.329 +} 1.330 + 1.331 +} } // namespace mozilla::test