security/manager/ssl/tests/unit/tlsserver/lib/TLSServer.cpp

changeset 0
6474c204b198
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/security/manager/ssl/tests/unit/tlsserver/lib/TLSServer.cpp	Wed Dec 31 06:09:35 2014 +0100
     1.3 @@ -0,0 +1,328 @@
     1.4 +/* This Source Code Form is subject to the terms of the Mozilla Public
     1.5 + * License, v. 2.0. If a copy of the MPL was not distributed with this
     1.6 + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
     1.7 +
     1.8 +#include "TLSServer.h"
     1.9 +
    1.10 +#include <stdio.h>
    1.11 +#include "ScopedNSSTypes.h"
    1.12 +#include "nspr.h"
    1.13 +#include "nss.h"
    1.14 +#include "plarenas.h"
    1.15 +#include "prenv.h"
    1.16 +#include "prerror.h"
    1.17 +#include "prnetdb.h"
    1.18 +#include "prtime.h"
    1.19 +#include "ssl.h"
    1.20 +
    1.21 +namespace mozilla { namespace test {
    1.22 +
    1.23 +static const uint16_t LISTEN_PORT = 8443;
    1.24 +
    1.25 +DebugLevel gDebugLevel = DEBUG_ERRORS;
    1.26 +uint16_t gCallbackPort = 0;
    1.27 +
    1.28 +const char DEFAULT_CERT_NICKNAME[] = "localhostAndExampleCom";
    1.29 +
    1.30 +struct Connection
    1.31 +{
    1.32 +  PRFileDesc *mSocket;
    1.33 +  char mByte;
    1.34 +
    1.35 +  Connection(PRFileDesc *aSocket);
    1.36 +  ~Connection();
    1.37 +};
    1.38 +
    1.39 +Connection::Connection(PRFileDesc *aSocket)
    1.40 +: mSocket(aSocket)
    1.41 +, mByte(0)
    1.42 +{}
    1.43 +
    1.44 +Connection::~Connection()
    1.45 +{
    1.46 +  if (mSocket) {
    1.47 +    PR_Close(mSocket);
    1.48 +  }
    1.49 +}
    1.50 +
    1.51 +void
    1.52 +PrintPRError(const char *aPrefix)
    1.53 +{
    1.54 +  const char *err = PR_ErrorToName(PR_GetError());
    1.55 +  if (err) {
    1.56 +    if (gDebugLevel >= DEBUG_ERRORS) {
    1.57 +      fprintf(stderr, "%s: %s\n", aPrefix, err);
    1.58 +    }
    1.59 +  } else {
    1.60 +    if (gDebugLevel >= DEBUG_ERRORS) {
    1.61 +      fprintf(stderr, "%s\n", aPrefix);
    1.62 +    }
    1.63 +  }
    1.64 +}
    1.65 +
    1.66 +nsresult
    1.67 +SendAll(PRFileDesc *aSocket, const char *aData, size_t aDataLen)
    1.68 +{
    1.69 +  if (gDebugLevel >= DEBUG_VERBOSE) {
    1.70 +    fprintf(stderr, "sending '%s'\n", aData);
    1.71 +  }
    1.72 +
    1.73 +  while (aDataLen > 0) {
    1.74 +    int32_t bytesSent = PR_Send(aSocket, aData, aDataLen, 0,
    1.75 +                                PR_INTERVAL_NO_TIMEOUT);
    1.76 +    if (bytesSent == -1) {
    1.77 +      PrintPRError("PR_Send failed");
    1.78 +      return NS_ERROR_FAILURE;
    1.79 +    }
    1.80 +
    1.81 +    aDataLen -= bytesSent;
    1.82 +    aData += bytesSent;
    1.83 +  }
    1.84 +
    1.85 +  return NS_OK;
    1.86 +}
    1.87 +
    1.88 +nsresult
    1.89 +ReplyToRequest(Connection *aConn)
    1.90 +{
    1.91 +  // For debugging purposes, SendAll can print out what it's sending.
    1.92 +  // So, any strings we give to it to send need to be null-terminated.
    1.93 +  char buf[2] = { aConn->mByte, 0 };
    1.94 +  return SendAll(aConn->mSocket, buf, 1);
    1.95 +}
    1.96 +
    1.97 +nsresult
    1.98 +SetupTLS(Connection *aConn, PRFileDesc *aModelSocket)
    1.99 +{
   1.100 +  PRFileDesc *sslSocket = SSL_ImportFD(aModelSocket, aConn->mSocket);
   1.101 +  if (!sslSocket) {
   1.102 +    PrintPRError("SSL_ImportFD failed");
   1.103 +    return NS_ERROR_FAILURE;
   1.104 +  }
   1.105 +  aConn->mSocket = sslSocket;
   1.106 +
   1.107 +  SSL_OptionSet(sslSocket, SSL_SECURITY, true);
   1.108 +  SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT, false);
   1.109 +  SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_SERVER, true);
   1.110 +
   1.111 +  SSL_ResetHandshake(sslSocket, /* asServer */ 1);
   1.112 +
   1.113 +  return NS_OK;
   1.114 +}
   1.115 +
   1.116 +nsresult
   1.117 +ReadRequest(Connection *aConn)
   1.118 +{
   1.119 +  int32_t bytesRead = PR_Recv(aConn->mSocket, &aConn->mByte, 1, 0,
   1.120 +                              PR_INTERVAL_NO_TIMEOUT);
   1.121 +  if (bytesRead < 0) {
   1.122 +    PrintPRError("PR_Recv failed");
   1.123 +    return NS_ERROR_FAILURE;
   1.124 +  } else if (bytesRead == 0) {
   1.125 +    PR_SetError(PR_IO_ERROR, 0);
   1.126 +    PrintPRError("PR_Recv EOF in ReadRequest");
   1.127 +    return NS_ERROR_FAILURE;
   1.128 +  } else {
   1.129 +    if (gDebugLevel >= DEBUG_VERBOSE) {
   1.130 +      fprintf(stderr, "read '0x%hhx'\n", aConn->mByte);
   1.131 +    }
   1.132 +  }
   1.133 +  return NS_OK;
   1.134 +}
   1.135 +
   1.136 +void
   1.137 +HandleConnection(PRFileDesc *aSocket, PRFileDesc *aModelSocket)
   1.138 +{
   1.139 +  Connection conn(aSocket);
   1.140 +  nsresult rv = SetupTLS(&conn, aModelSocket);
   1.141 +  if (NS_FAILED(rv)) {
   1.142 +    PR_SetError(PR_INVALID_STATE_ERROR, 0);
   1.143 +    PrintPRError("PR_Recv failed");
   1.144 +    exit(1);
   1.145 +  }
   1.146 +
   1.147 +  // TODO: On tests that are expected to fail (e.g. due to a revoked
   1.148 +  // certificate), the client will close the connection wtihout sending us the
   1.149 +  // request byte. In those cases, we should keep going. But, in the cases
   1.150 +  // where the connection is supposed to suceed, we should verify that we
   1.151 +  // successfully receive the request and send the response.
   1.152 +  rv = ReadRequest(&conn);
   1.153 +  if (NS_SUCCEEDED(rv)) {
   1.154 +    rv = ReplyToRequest(&conn);
   1.155 +  }
   1.156 +}
   1.157 +
   1.158 +// returns 0 on success, non-zero on error
   1.159 +int
   1.160 +DoCallback()
   1.161 +{
   1.162 +  ScopedPRFileDesc socket(PR_NewTCPSocket());
   1.163 +  if (!socket) {
   1.164 +    PrintPRError("PR_NewTCPSocket failed");
   1.165 +    return 1;
   1.166 +  }
   1.167 +
   1.168 +  PRNetAddr addr;
   1.169 +  PR_InitializeNetAddr(PR_IpAddrLoopback, gCallbackPort, &addr);
   1.170 +  if (PR_Connect(socket, &addr, PR_INTERVAL_NO_TIMEOUT) != PR_SUCCESS) {
   1.171 +    PrintPRError("PR_Connect failed");
   1.172 +    return 1;
   1.173 +  }
   1.174 +
   1.175 +  const char *request = "GET / HTTP/1.0\r\n\r\n";
   1.176 +  SendAll(socket, request, strlen(request));
   1.177 +  char buf[4096];
   1.178 +  memset(buf, 0, sizeof(buf));
   1.179 +  int32_t bytesRead = PR_Recv(socket, buf, sizeof(buf) - 1, 0,
   1.180 +                              PR_INTERVAL_NO_TIMEOUT);
   1.181 +  if (bytesRead < 0) {
   1.182 +    PrintPRError("PR_Recv failed 1");
   1.183 +    return 1;
   1.184 +  }
   1.185 +  if (bytesRead == 0) {
   1.186 +    fprintf(stderr, "PR_Recv eof 1\n");
   1.187 +    return 1;
   1.188 +  }
   1.189 +  fprintf(stderr, "%s\n", buf);
   1.190 +  return 0;
   1.191 +}
   1.192 +
   1.193 +SECStatus
   1.194 +ConfigSecureServerWithNamedCert(PRFileDesc *fd, const char *certName,
   1.195 +                                /*optional*/ ScopedCERTCertificate *certOut,
   1.196 +                                /*optional*/ SSLKEAType *keaOut)
   1.197 +{
   1.198 +  ScopedCERTCertificate cert(PK11_FindCertFromNickname(certName, nullptr));
   1.199 +  if (!cert) {
   1.200 +    PrintPRError("PK11_FindCertFromNickname failed");
   1.201 +    return SECFailure;
   1.202 +  }
   1.203 +
   1.204 +  ScopedSECKEYPrivateKey key(PK11_FindKeyByAnyCert(cert, nullptr));
   1.205 +  if (!key) {
   1.206 +    PrintPRError("PK11_FindKeyByAnyCert failed");
   1.207 +    return SECFailure;
   1.208 +  }
   1.209 +
   1.210 +  SSLKEAType certKEA = NSS_FindCertKEAType(cert);
   1.211 +
   1.212 +  if (SSL_ConfigSecureServer(fd, cert, key, certKEA) != SECSuccess) {
   1.213 +    PrintPRError("SSL_ConfigSecureServer failed");
   1.214 +    return SECFailure;
   1.215 +  }
   1.216 +
   1.217 +  if (certOut) {
   1.218 +    *certOut = cert.forget();
   1.219 +  }
   1.220 +
   1.221 +  if (keaOut) {
   1.222 +    *keaOut = certKEA;
   1.223 +  }
   1.224 +
   1.225 +  return SECSuccess;
   1.226 +}
   1.227 +
   1.228 +int
   1.229 +StartServer(const char *nssCertDBDir, SSLSNISocketConfig sniSocketConfig,
   1.230 +            void *sniSocketConfigArg)
   1.231 +{
   1.232 +  const char *debugLevel = PR_GetEnv("MOZ_TLS_SERVER_DEBUG_LEVEL");
   1.233 +  if (debugLevel) {
   1.234 +    int level = atoi(debugLevel);
   1.235 +    switch (level) {
   1.236 +      case DEBUG_ERRORS: gDebugLevel = DEBUG_ERRORS; break;
   1.237 +      case DEBUG_WARNINGS: gDebugLevel = DEBUG_WARNINGS; break;
   1.238 +      case DEBUG_VERBOSE: gDebugLevel = DEBUG_VERBOSE; break;
   1.239 +      default:
   1.240 +        PrintPRError("invalid MOZ_TLS_SERVER_DEBUG_LEVEL");
   1.241 +        return 1;
   1.242 +    }
   1.243 +  }
   1.244 +
   1.245 +  const char *callbackPort = PR_GetEnv("MOZ_TLS_SERVER_CALLBACK_PORT");
   1.246 +  if (callbackPort) {
   1.247 +    gCallbackPort = atoi(callbackPort);
   1.248 +  }
   1.249 +
   1.250 +  if (NSS_Init(nssCertDBDir) != SECSuccess) {
   1.251 +    PrintPRError("NSS_Init failed");
   1.252 +    return 1;
   1.253 +  }
   1.254 +
   1.255 +  if (NSS_SetDomesticPolicy() != SECSuccess) {
   1.256 +    PrintPRError("NSS_SetDomesticPolicy failed");
   1.257 +    return 1;
   1.258 +  }
   1.259 +
   1.260 +  if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) {
   1.261 +    PrintPRError("SSL_ConfigServerSessionIDCache failed");
   1.262 +    return 1;
   1.263 +  }
   1.264 +
   1.265 +  ScopedPRFileDesc serverSocket(PR_NewTCPSocket());
   1.266 +  if (!serverSocket) {
   1.267 +    PrintPRError("PR_NewTCPSocket failed");
   1.268 +    return 1;
   1.269 +  }
   1.270 +
   1.271 +  PRSocketOptionData socketOption;
   1.272 +  socketOption.option = PR_SockOpt_Reuseaddr;
   1.273 +  socketOption.value.reuse_addr = true;
   1.274 +  PR_SetSocketOption(serverSocket, &socketOption);
   1.275 +
   1.276 +  PRNetAddr serverAddr;
   1.277 +  PR_InitializeNetAddr(PR_IpAddrLoopback, LISTEN_PORT, &serverAddr);
   1.278 +  if (PR_Bind(serverSocket, &serverAddr) != PR_SUCCESS) {
   1.279 +    PrintPRError("PR_Bind failed");
   1.280 +    return 1;
   1.281 +  }
   1.282 +
   1.283 +  if (PR_Listen(serverSocket, 1) != PR_SUCCESS) {
   1.284 +    PrintPRError("PR_Listen failed");
   1.285 +    return 1;
   1.286 +  }
   1.287 +
   1.288 +  ScopedPRFileDesc rawModelSocket(PR_NewTCPSocket());
   1.289 +  if (!rawModelSocket) {
   1.290 +    PrintPRError("PR_NewTCPSocket failed for rawModelSocket");
   1.291 +    return 1;
   1.292 +  }
   1.293 +
   1.294 +  ScopedPRFileDesc modelSocket(SSL_ImportFD(nullptr, rawModelSocket.forget()));
   1.295 +  if (!modelSocket) {
   1.296 +    PrintPRError("SSL_ImportFD of rawModelSocket failed");
   1.297 +    return 1;
   1.298 +  }
   1.299 +
   1.300 +  if (SECSuccess != SSL_SNISocketConfigHook(modelSocket, sniSocketConfig,
   1.301 +                                            sniSocketConfigArg)) {
   1.302 +    PrintPRError("SSL_SNISocketConfigHook failed");
   1.303 +    return 1;
   1.304 +  }
   1.305 +
   1.306 +  // We have to configure the server with a certificate, but it's not one
   1.307 +  // we're actually going to end up using. In the SNI callback, we pick
   1.308 +  // the right certificate for the connection.
   1.309 +  if (SECSuccess != ConfigSecureServerWithNamedCert(modelSocket,
   1.310 +                                                    DEFAULT_CERT_NICKNAME,
   1.311 +                                                    nullptr, nullptr)) {
   1.312 +    return 1;
   1.313 +  }
   1.314 +
   1.315 +  if (gCallbackPort != 0) {
   1.316 +    if (DoCallback()) {
   1.317 +      return 1;
   1.318 +    }
   1.319 +  }
   1.320 +
   1.321 +  while (true) {
   1.322 +    PRNetAddr clientAddr;
   1.323 +    PRFileDesc *clientSocket = PR_Accept(serverSocket, &clientAddr,
   1.324 +                                         PR_INTERVAL_NO_TIMEOUT);
   1.325 +    HandleConnection(clientSocket, modelSocket);
   1.326 +  }
   1.327 +
   1.328 +  return 0;
   1.329 +}
   1.330 +
   1.331 +} } // namespace mozilla::test

mercurial