michael@0: /* This Source Code Form is subject to the terms of the Mozilla Public michael@0: * License, v. 2.0. If a copy of the MPL was not distributed with this michael@0: * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ michael@0: michael@0: #include "TLSServer.h" michael@0: michael@0: #include michael@0: #include "ScopedNSSTypes.h" michael@0: #include "nspr.h" michael@0: #include "nss.h" michael@0: #include "plarenas.h" michael@0: #include "prenv.h" michael@0: #include "prerror.h" michael@0: #include "prnetdb.h" michael@0: #include "prtime.h" michael@0: #include "ssl.h" michael@0: michael@0: namespace mozilla { namespace test { michael@0: michael@0: static const uint16_t LISTEN_PORT = 8443; michael@0: michael@0: DebugLevel gDebugLevel = DEBUG_ERRORS; michael@0: uint16_t gCallbackPort = 0; michael@0: michael@0: const char DEFAULT_CERT_NICKNAME[] = "localhostAndExampleCom"; michael@0: michael@0: struct Connection michael@0: { michael@0: PRFileDesc *mSocket; michael@0: char mByte; michael@0: michael@0: Connection(PRFileDesc *aSocket); michael@0: ~Connection(); michael@0: }; michael@0: michael@0: Connection::Connection(PRFileDesc *aSocket) michael@0: : mSocket(aSocket) michael@0: , mByte(0) michael@0: {} michael@0: michael@0: Connection::~Connection() michael@0: { michael@0: if (mSocket) { michael@0: PR_Close(mSocket); michael@0: } michael@0: } michael@0: michael@0: void michael@0: PrintPRError(const char *aPrefix) michael@0: { michael@0: const char *err = PR_ErrorToName(PR_GetError()); michael@0: if (err) { michael@0: if (gDebugLevel >= DEBUG_ERRORS) { michael@0: fprintf(stderr, "%s: %s\n", aPrefix, err); michael@0: } michael@0: } else { michael@0: if (gDebugLevel >= DEBUG_ERRORS) { michael@0: fprintf(stderr, "%s\n", aPrefix); michael@0: } michael@0: } michael@0: } michael@0: michael@0: nsresult michael@0: SendAll(PRFileDesc *aSocket, const char *aData, size_t aDataLen) michael@0: { michael@0: if (gDebugLevel >= DEBUG_VERBOSE) { michael@0: fprintf(stderr, "sending '%s'\n", aData); michael@0: } michael@0: michael@0: while (aDataLen > 0) { michael@0: int32_t bytesSent = PR_Send(aSocket, aData, aDataLen, 0, michael@0: PR_INTERVAL_NO_TIMEOUT); michael@0: if (bytesSent == -1) { michael@0: PrintPRError("PR_Send failed"); michael@0: return NS_ERROR_FAILURE; michael@0: } michael@0: michael@0: aDataLen -= bytesSent; michael@0: aData += bytesSent; michael@0: } michael@0: michael@0: return NS_OK; michael@0: } michael@0: michael@0: nsresult michael@0: ReplyToRequest(Connection *aConn) michael@0: { michael@0: // For debugging purposes, SendAll can print out what it's sending. michael@0: // So, any strings we give to it to send need to be null-terminated. michael@0: char buf[2] = { aConn->mByte, 0 }; michael@0: return SendAll(aConn->mSocket, buf, 1); michael@0: } michael@0: michael@0: nsresult michael@0: SetupTLS(Connection *aConn, PRFileDesc *aModelSocket) michael@0: { michael@0: PRFileDesc *sslSocket = SSL_ImportFD(aModelSocket, aConn->mSocket); michael@0: if (!sslSocket) { michael@0: PrintPRError("SSL_ImportFD failed"); michael@0: return NS_ERROR_FAILURE; michael@0: } michael@0: aConn->mSocket = sslSocket; michael@0: michael@0: SSL_OptionSet(sslSocket, SSL_SECURITY, true); michael@0: SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT, false); michael@0: SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_SERVER, true); michael@0: michael@0: SSL_ResetHandshake(sslSocket, /* asServer */ 1); michael@0: michael@0: return NS_OK; michael@0: } michael@0: michael@0: nsresult michael@0: ReadRequest(Connection *aConn) michael@0: { michael@0: int32_t bytesRead = PR_Recv(aConn->mSocket, &aConn->mByte, 1, 0, michael@0: PR_INTERVAL_NO_TIMEOUT); michael@0: if (bytesRead < 0) { michael@0: PrintPRError("PR_Recv failed"); michael@0: return NS_ERROR_FAILURE; michael@0: } else if (bytesRead == 0) { michael@0: PR_SetError(PR_IO_ERROR, 0); michael@0: PrintPRError("PR_Recv EOF in ReadRequest"); michael@0: return NS_ERROR_FAILURE; michael@0: } else { michael@0: if (gDebugLevel >= DEBUG_VERBOSE) { michael@0: fprintf(stderr, "read '0x%hhx'\n", aConn->mByte); michael@0: } michael@0: } michael@0: return NS_OK; michael@0: } michael@0: michael@0: void michael@0: HandleConnection(PRFileDesc *aSocket, PRFileDesc *aModelSocket) michael@0: { michael@0: Connection conn(aSocket); michael@0: nsresult rv = SetupTLS(&conn, aModelSocket); michael@0: if (NS_FAILED(rv)) { michael@0: PR_SetError(PR_INVALID_STATE_ERROR, 0); michael@0: PrintPRError("PR_Recv failed"); michael@0: exit(1); michael@0: } michael@0: michael@0: // TODO: On tests that are expected to fail (e.g. due to a revoked michael@0: // certificate), the client will close the connection wtihout sending us the michael@0: // request byte. In those cases, we should keep going. But, in the cases michael@0: // where the connection is supposed to suceed, we should verify that we michael@0: // successfully receive the request and send the response. michael@0: rv = ReadRequest(&conn); michael@0: if (NS_SUCCEEDED(rv)) { michael@0: rv = ReplyToRequest(&conn); michael@0: } michael@0: } michael@0: michael@0: // returns 0 on success, non-zero on error michael@0: int michael@0: DoCallback() michael@0: { michael@0: ScopedPRFileDesc socket(PR_NewTCPSocket()); michael@0: if (!socket) { michael@0: PrintPRError("PR_NewTCPSocket failed"); michael@0: return 1; michael@0: } michael@0: michael@0: PRNetAddr addr; michael@0: PR_InitializeNetAddr(PR_IpAddrLoopback, gCallbackPort, &addr); michael@0: if (PR_Connect(socket, &addr, PR_INTERVAL_NO_TIMEOUT) != PR_SUCCESS) { michael@0: PrintPRError("PR_Connect failed"); michael@0: return 1; michael@0: } michael@0: michael@0: const char *request = "GET / HTTP/1.0\r\n\r\n"; michael@0: SendAll(socket, request, strlen(request)); michael@0: char buf[4096]; michael@0: memset(buf, 0, sizeof(buf)); michael@0: int32_t bytesRead = PR_Recv(socket, buf, sizeof(buf) - 1, 0, michael@0: PR_INTERVAL_NO_TIMEOUT); michael@0: if (bytesRead < 0) { michael@0: PrintPRError("PR_Recv failed 1"); michael@0: return 1; michael@0: } michael@0: if (bytesRead == 0) { michael@0: fprintf(stderr, "PR_Recv eof 1\n"); michael@0: return 1; michael@0: } michael@0: fprintf(stderr, "%s\n", buf); michael@0: return 0; michael@0: } michael@0: michael@0: SECStatus michael@0: ConfigSecureServerWithNamedCert(PRFileDesc *fd, const char *certName, michael@0: /*optional*/ ScopedCERTCertificate *certOut, michael@0: /*optional*/ SSLKEAType *keaOut) michael@0: { michael@0: ScopedCERTCertificate cert(PK11_FindCertFromNickname(certName, nullptr)); michael@0: if (!cert) { michael@0: PrintPRError("PK11_FindCertFromNickname failed"); michael@0: return SECFailure; michael@0: } michael@0: michael@0: ScopedSECKEYPrivateKey key(PK11_FindKeyByAnyCert(cert, nullptr)); michael@0: if (!key) { michael@0: PrintPRError("PK11_FindKeyByAnyCert failed"); michael@0: return SECFailure; michael@0: } michael@0: michael@0: SSLKEAType certKEA = NSS_FindCertKEAType(cert); michael@0: michael@0: if (SSL_ConfigSecureServer(fd, cert, key, certKEA) != SECSuccess) { michael@0: PrintPRError("SSL_ConfigSecureServer failed"); michael@0: return SECFailure; michael@0: } michael@0: michael@0: if (certOut) { michael@0: *certOut = cert.forget(); michael@0: } michael@0: michael@0: if (keaOut) { michael@0: *keaOut = certKEA; michael@0: } michael@0: michael@0: return SECSuccess; michael@0: } michael@0: michael@0: int michael@0: StartServer(const char *nssCertDBDir, SSLSNISocketConfig sniSocketConfig, michael@0: void *sniSocketConfigArg) michael@0: { michael@0: const char *debugLevel = PR_GetEnv("MOZ_TLS_SERVER_DEBUG_LEVEL"); michael@0: if (debugLevel) { michael@0: int level = atoi(debugLevel); michael@0: switch (level) { michael@0: case DEBUG_ERRORS: gDebugLevel = DEBUG_ERRORS; break; michael@0: case DEBUG_WARNINGS: gDebugLevel = DEBUG_WARNINGS; break; michael@0: case DEBUG_VERBOSE: gDebugLevel = DEBUG_VERBOSE; break; michael@0: default: michael@0: PrintPRError("invalid MOZ_TLS_SERVER_DEBUG_LEVEL"); michael@0: return 1; michael@0: } michael@0: } michael@0: michael@0: const char *callbackPort = PR_GetEnv("MOZ_TLS_SERVER_CALLBACK_PORT"); michael@0: if (callbackPort) { michael@0: gCallbackPort = atoi(callbackPort); michael@0: } michael@0: michael@0: if (NSS_Init(nssCertDBDir) != SECSuccess) { michael@0: PrintPRError("NSS_Init failed"); michael@0: return 1; michael@0: } michael@0: michael@0: if (NSS_SetDomesticPolicy() != SECSuccess) { michael@0: PrintPRError("NSS_SetDomesticPolicy failed"); michael@0: return 1; michael@0: } michael@0: michael@0: if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) { michael@0: PrintPRError("SSL_ConfigServerSessionIDCache failed"); michael@0: return 1; michael@0: } michael@0: michael@0: ScopedPRFileDesc serverSocket(PR_NewTCPSocket()); michael@0: if (!serverSocket) { michael@0: PrintPRError("PR_NewTCPSocket failed"); michael@0: return 1; michael@0: } michael@0: michael@0: PRSocketOptionData socketOption; michael@0: socketOption.option = PR_SockOpt_Reuseaddr; michael@0: socketOption.value.reuse_addr = true; michael@0: PR_SetSocketOption(serverSocket, &socketOption); michael@0: michael@0: PRNetAddr serverAddr; michael@0: PR_InitializeNetAddr(PR_IpAddrLoopback, LISTEN_PORT, &serverAddr); michael@0: if (PR_Bind(serverSocket, &serverAddr) != PR_SUCCESS) { michael@0: PrintPRError("PR_Bind failed"); michael@0: return 1; michael@0: } michael@0: michael@0: if (PR_Listen(serverSocket, 1) != PR_SUCCESS) { michael@0: PrintPRError("PR_Listen failed"); michael@0: return 1; michael@0: } michael@0: michael@0: ScopedPRFileDesc rawModelSocket(PR_NewTCPSocket()); michael@0: if (!rawModelSocket) { michael@0: PrintPRError("PR_NewTCPSocket failed for rawModelSocket"); michael@0: return 1; michael@0: } michael@0: michael@0: ScopedPRFileDesc modelSocket(SSL_ImportFD(nullptr, rawModelSocket.forget())); michael@0: if (!modelSocket) { michael@0: PrintPRError("SSL_ImportFD of rawModelSocket failed"); michael@0: return 1; michael@0: } michael@0: michael@0: if (SECSuccess != SSL_SNISocketConfigHook(modelSocket, sniSocketConfig, michael@0: sniSocketConfigArg)) { michael@0: PrintPRError("SSL_SNISocketConfigHook failed"); michael@0: return 1; michael@0: } michael@0: michael@0: // We have to configure the server with a certificate, but it's not one michael@0: // we're actually going to end up using. In the SNI callback, we pick michael@0: // the right certificate for the connection. michael@0: if (SECSuccess != ConfigSecureServerWithNamedCert(modelSocket, michael@0: DEFAULT_CERT_NICKNAME, michael@0: nullptr, nullptr)) { michael@0: return 1; michael@0: } michael@0: michael@0: if (gCallbackPort != 0) { michael@0: if (DoCallback()) { michael@0: return 1; michael@0: } michael@0: } michael@0: michael@0: while (true) { michael@0: PRNetAddr clientAddr; michael@0: PRFileDesc *clientSocket = PR_Accept(serverSocket, &clientAddr, michael@0: PR_INTERVAL_NO_TIMEOUT); michael@0: HandleConnection(clientSocket, modelSocket); michael@0: } michael@0: michael@0: return 0; michael@0: } michael@0: michael@0: } } // namespace mozilla::test