diff -r 000000000000 -r 6474c204b198 security/manager/ssl/tests/unit/tlsserver/lib/TLSServer.cpp --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/security/manager/ssl/tests/unit/tlsserver/lib/TLSServer.cpp Wed Dec 31 06:09:35 2014 +0100 @@ -0,0 +1,328 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "TLSServer.h" + +#include +#include "ScopedNSSTypes.h" +#include "nspr.h" +#include "nss.h" +#include "plarenas.h" +#include "prenv.h" +#include "prerror.h" +#include "prnetdb.h" +#include "prtime.h" +#include "ssl.h" + +namespace mozilla { namespace test { + +static const uint16_t LISTEN_PORT = 8443; + +DebugLevel gDebugLevel = DEBUG_ERRORS; +uint16_t gCallbackPort = 0; + +const char DEFAULT_CERT_NICKNAME[] = "localhostAndExampleCom"; + +struct Connection +{ + PRFileDesc *mSocket; + char mByte; + + Connection(PRFileDesc *aSocket); + ~Connection(); +}; + +Connection::Connection(PRFileDesc *aSocket) +: mSocket(aSocket) +, mByte(0) +{} + +Connection::~Connection() +{ + if (mSocket) { + PR_Close(mSocket); + } +} + +void +PrintPRError(const char *aPrefix) +{ + const char *err = PR_ErrorToName(PR_GetError()); + if (err) { + if (gDebugLevel >= DEBUG_ERRORS) { + fprintf(stderr, "%s: %s\n", aPrefix, err); + } + } else { + if (gDebugLevel >= DEBUG_ERRORS) { + fprintf(stderr, "%s\n", aPrefix); + } + } +} + +nsresult +SendAll(PRFileDesc *aSocket, const char *aData, size_t aDataLen) +{ + if (gDebugLevel >= DEBUG_VERBOSE) { + fprintf(stderr, "sending '%s'\n", aData); + } + + while (aDataLen > 0) { + int32_t bytesSent = PR_Send(aSocket, aData, aDataLen, 0, + PR_INTERVAL_NO_TIMEOUT); + if (bytesSent == -1) { + PrintPRError("PR_Send failed"); + return NS_ERROR_FAILURE; + } + + aDataLen -= bytesSent; + aData += bytesSent; + } + + return NS_OK; +} + +nsresult +ReplyToRequest(Connection *aConn) +{ + // For debugging purposes, SendAll can print out what it's sending. + // So, any strings we give to it to send need to be null-terminated. + char buf[2] = { aConn->mByte, 0 }; + return SendAll(aConn->mSocket, buf, 1); +} + +nsresult +SetupTLS(Connection *aConn, PRFileDesc *aModelSocket) +{ + PRFileDesc *sslSocket = SSL_ImportFD(aModelSocket, aConn->mSocket); + if (!sslSocket) { + PrintPRError("SSL_ImportFD failed"); + return NS_ERROR_FAILURE; + } + aConn->mSocket = sslSocket; + + SSL_OptionSet(sslSocket, SSL_SECURITY, true); + SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT, false); + SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_SERVER, true); + + SSL_ResetHandshake(sslSocket, /* asServer */ 1); + + return NS_OK; +} + +nsresult +ReadRequest(Connection *aConn) +{ + int32_t bytesRead = PR_Recv(aConn->mSocket, &aConn->mByte, 1, 0, + PR_INTERVAL_NO_TIMEOUT); + if (bytesRead < 0) { + PrintPRError("PR_Recv failed"); + return NS_ERROR_FAILURE; + } else if (bytesRead == 0) { + PR_SetError(PR_IO_ERROR, 0); + PrintPRError("PR_Recv EOF in ReadRequest"); + return NS_ERROR_FAILURE; + } else { + if (gDebugLevel >= DEBUG_VERBOSE) { + fprintf(stderr, "read '0x%hhx'\n", aConn->mByte); + } + } + return NS_OK; +} + +void +HandleConnection(PRFileDesc *aSocket, PRFileDesc *aModelSocket) +{ + Connection conn(aSocket); + nsresult rv = SetupTLS(&conn, aModelSocket); + if (NS_FAILED(rv)) { + PR_SetError(PR_INVALID_STATE_ERROR, 0); + PrintPRError("PR_Recv failed"); + exit(1); + } + + // TODO: On tests that are expected to fail (e.g. due to a revoked + // certificate), the client will close the connection wtihout sending us the + // request byte. In those cases, we should keep going. But, in the cases + // where the connection is supposed to suceed, we should verify that we + // successfully receive the request and send the response. + rv = ReadRequest(&conn); + if (NS_SUCCEEDED(rv)) { + rv = ReplyToRequest(&conn); + } +} + +// returns 0 on success, non-zero on error +int +DoCallback() +{ + ScopedPRFileDesc socket(PR_NewTCPSocket()); + if (!socket) { + PrintPRError("PR_NewTCPSocket failed"); + return 1; + } + + PRNetAddr addr; + PR_InitializeNetAddr(PR_IpAddrLoopback, gCallbackPort, &addr); + if (PR_Connect(socket, &addr, PR_INTERVAL_NO_TIMEOUT) != PR_SUCCESS) { + PrintPRError("PR_Connect failed"); + return 1; + } + + const char *request = "GET / HTTP/1.0\r\n\r\n"; + SendAll(socket, request, strlen(request)); + char buf[4096]; + memset(buf, 0, sizeof(buf)); + int32_t bytesRead = PR_Recv(socket, buf, sizeof(buf) - 1, 0, + PR_INTERVAL_NO_TIMEOUT); + if (bytesRead < 0) { + PrintPRError("PR_Recv failed 1"); + return 1; + } + if (bytesRead == 0) { + fprintf(stderr, "PR_Recv eof 1\n"); + return 1; + } + fprintf(stderr, "%s\n", buf); + return 0; +} + +SECStatus +ConfigSecureServerWithNamedCert(PRFileDesc *fd, const char *certName, + /*optional*/ ScopedCERTCertificate *certOut, + /*optional*/ SSLKEAType *keaOut) +{ + ScopedCERTCertificate cert(PK11_FindCertFromNickname(certName, nullptr)); + if (!cert) { + PrintPRError("PK11_FindCertFromNickname failed"); + return SECFailure; + } + + ScopedSECKEYPrivateKey key(PK11_FindKeyByAnyCert(cert, nullptr)); + if (!key) { + PrintPRError("PK11_FindKeyByAnyCert failed"); + return SECFailure; + } + + SSLKEAType certKEA = NSS_FindCertKEAType(cert); + + if (SSL_ConfigSecureServer(fd, cert, key, certKEA) != SECSuccess) { + PrintPRError("SSL_ConfigSecureServer failed"); + return SECFailure; + } + + if (certOut) { + *certOut = cert.forget(); + } + + if (keaOut) { + *keaOut = certKEA; + } + + return SECSuccess; +} + +int +StartServer(const char *nssCertDBDir, SSLSNISocketConfig sniSocketConfig, + void *sniSocketConfigArg) +{ + const char *debugLevel = PR_GetEnv("MOZ_TLS_SERVER_DEBUG_LEVEL"); + if (debugLevel) { + int level = atoi(debugLevel); + switch (level) { + case DEBUG_ERRORS: gDebugLevel = DEBUG_ERRORS; break; + case DEBUG_WARNINGS: gDebugLevel = DEBUG_WARNINGS; break; + case DEBUG_VERBOSE: gDebugLevel = DEBUG_VERBOSE; break; + default: + PrintPRError("invalid MOZ_TLS_SERVER_DEBUG_LEVEL"); + return 1; + } + } + + const char *callbackPort = PR_GetEnv("MOZ_TLS_SERVER_CALLBACK_PORT"); + if (callbackPort) { + gCallbackPort = atoi(callbackPort); + } + + if (NSS_Init(nssCertDBDir) != SECSuccess) { + PrintPRError("NSS_Init failed"); + return 1; + } + + if (NSS_SetDomesticPolicy() != SECSuccess) { + PrintPRError("NSS_SetDomesticPolicy failed"); + return 1; + } + + if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) { + PrintPRError("SSL_ConfigServerSessionIDCache failed"); + return 1; + } + + ScopedPRFileDesc serverSocket(PR_NewTCPSocket()); + if (!serverSocket) { + PrintPRError("PR_NewTCPSocket failed"); + return 1; + } + + PRSocketOptionData socketOption; + socketOption.option = PR_SockOpt_Reuseaddr; + socketOption.value.reuse_addr = true; + PR_SetSocketOption(serverSocket, &socketOption); + + PRNetAddr serverAddr; + PR_InitializeNetAddr(PR_IpAddrLoopback, LISTEN_PORT, &serverAddr); + if (PR_Bind(serverSocket, &serverAddr) != PR_SUCCESS) { + PrintPRError("PR_Bind failed"); + return 1; + } + + if (PR_Listen(serverSocket, 1) != PR_SUCCESS) { + PrintPRError("PR_Listen failed"); + return 1; + } + + ScopedPRFileDesc rawModelSocket(PR_NewTCPSocket()); + if (!rawModelSocket) { + PrintPRError("PR_NewTCPSocket failed for rawModelSocket"); + return 1; + } + + ScopedPRFileDesc modelSocket(SSL_ImportFD(nullptr, rawModelSocket.forget())); + if (!modelSocket) { + PrintPRError("SSL_ImportFD of rawModelSocket failed"); + return 1; + } + + if (SECSuccess != SSL_SNISocketConfigHook(modelSocket, sniSocketConfig, + sniSocketConfigArg)) { + PrintPRError("SSL_SNISocketConfigHook failed"); + return 1; + } + + // We have to configure the server with a certificate, but it's not one + // we're actually going to end up using. In the SNI callback, we pick + // the right certificate for the connection. + if (SECSuccess != ConfigSecureServerWithNamedCert(modelSocket, + DEFAULT_CERT_NICKNAME, + nullptr, nullptr)) { + return 1; + } + + if (gCallbackPort != 0) { + if (DoCallback()) { + return 1; + } + } + + while (true) { + PRNetAddr clientAddr; + PRFileDesc *clientSocket = PR_Accept(serverSocket, &clientAddr, + PR_INTERVAL_NO_TIMEOUT); + HandleConnection(clientSocket, modelSocket); + } + + return 0; +} + +} } // namespace mozilla::test