Thu, 22 Jan 2015 13:21:57 +0100
Incorporate requested changes from Mozilla in review:
https://bugzilla.mozilla.org/show_bug.cgi?id=1123480#c6
1 /* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "TLSServer.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "ScopedNSSTypes.h"
#include "nspr.h"
#include "nss.h"
#include "plarenas.h"
#include "prenv.h"
#include "prerror.h"
#include "prnetdb.h"
#include "prtime.h"
#include "ssl.h"
18 namespace mozilla { namespace test {
20 static const uint16_t LISTEN_PORT = 8443;
22 DebugLevel gDebugLevel = DEBUG_ERRORS;
23 uint16_t gCallbackPort = 0;
25 const char DEFAULT_CERT_NICKNAME[] = "localhostAndExampleCom";
27 struct Connection
28 {
29 PRFileDesc *mSocket;
30 char mByte;
32 Connection(PRFileDesc *aSocket);
33 ~Connection();
34 };
36 Connection::Connection(PRFileDesc *aSocket)
37 : mSocket(aSocket)
38 , mByte(0)
39 {}
41 Connection::~Connection()
42 {
43 if (mSocket) {
44 PR_Close(mSocket);
45 }
46 }
48 void
49 PrintPRError(const char *aPrefix)
50 {
51 const char *err = PR_ErrorToName(PR_GetError());
52 if (err) {
53 if (gDebugLevel >= DEBUG_ERRORS) {
54 fprintf(stderr, "%s: %s\n", aPrefix, err);
55 }
56 } else {
57 if (gDebugLevel >= DEBUG_ERRORS) {
58 fprintf(stderr, "%s\n", aPrefix);
59 }
60 }
61 }
63 nsresult
64 SendAll(PRFileDesc *aSocket, const char *aData, size_t aDataLen)
65 {
66 if (gDebugLevel >= DEBUG_VERBOSE) {
67 fprintf(stderr, "sending '%s'\n", aData);
68 }
70 while (aDataLen > 0) {
71 int32_t bytesSent = PR_Send(aSocket, aData, aDataLen, 0,
72 PR_INTERVAL_NO_TIMEOUT);
73 if (bytesSent == -1) {
74 PrintPRError("PR_Send failed");
75 return NS_ERROR_FAILURE;
76 }
78 aDataLen -= bytesSent;
79 aData += bytesSent;
80 }
82 return NS_OK;
83 }
85 nsresult
86 ReplyToRequest(Connection *aConn)
87 {
88 // For debugging purposes, SendAll can print out what it's sending.
89 // So, any strings we give to it to send need to be null-terminated.
90 char buf[2] = { aConn->mByte, 0 };
91 return SendAll(aConn->mSocket, buf, 1);
92 }
94 nsresult
95 SetupTLS(Connection *aConn, PRFileDesc *aModelSocket)
96 {
97 PRFileDesc *sslSocket = SSL_ImportFD(aModelSocket, aConn->mSocket);
98 if (!sslSocket) {
99 PrintPRError("SSL_ImportFD failed");
100 return NS_ERROR_FAILURE;
101 }
102 aConn->mSocket = sslSocket;
104 SSL_OptionSet(sslSocket, SSL_SECURITY, true);
105 SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT, false);
106 SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_SERVER, true);
108 SSL_ResetHandshake(sslSocket, /* asServer */ 1);
110 return NS_OK;
111 }
113 nsresult
114 ReadRequest(Connection *aConn)
115 {
116 int32_t bytesRead = PR_Recv(aConn->mSocket, &aConn->mByte, 1, 0,
117 PR_INTERVAL_NO_TIMEOUT);
118 if (bytesRead < 0) {
119 PrintPRError("PR_Recv failed");
120 return NS_ERROR_FAILURE;
121 } else if (bytesRead == 0) {
122 PR_SetError(PR_IO_ERROR, 0);
123 PrintPRError("PR_Recv EOF in ReadRequest");
124 return NS_ERROR_FAILURE;
125 } else {
126 if (gDebugLevel >= DEBUG_VERBOSE) {
127 fprintf(stderr, "read '0x%hhx'\n", aConn->mByte);
128 }
129 }
130 return NS_OK;
131 }
133 void
134 HandleConnection(PRFileDesc *aSocket, PRFileDesc *aModelSocket)
135 {
136 Connection conn(aSocket);
137 nsresult rv = SetupTLS(&conn, aModelSocket);
138 if (NS_FAILED(rv)) {
139 PR_SetError(PR_INVALID_STATE_ERROR, 0);
140 PrintPRError("PR_Recv failed");
141 exit(1);
142 }
144 // TODO: On tests that are expected to fail (e.g. due to a revoked
145 // certificate), the client will close the connection wtihout sending us the
146 // request byte. In those cases, we should keep going. But, in the cases
147 // where the connection is supposed to suceed, we should verify that we
148 // successfully receive the request and send the response.
149 rv = ReadRequest(&conn);
150 if (NS_SUCCEEDED(rv)) {
151 rv = ReplyToRequest(&conn);
152 }
153 }
155 // returns 0 on success, non-zero on error
156 int
157 DoCallback()
158 {
159 ScopedPRFileDesc socket(PR_NewTCPSocket());
160 if (!socket) {
161 PrintPRError("PR_NewTCPSocket failed");
162 return 1;
163 }
165 PRNetAddr addr;
166 PR_InitializeNetAddr(PR_IpAddrLoopback, gCallbackPort, &addr);
167 if (PR_Connect(socket, &addr, PR_INTERVAL_NO_TIMEOUT) != PR_SUCCESS) {
168 PrintPRError("PR_Connect failed");
169 return 1;
170 }
172 const char *request = "GET / HTTP/1.0\r\n\r\n";
173 SendAll(socket, request, strlen(request));
174 char buf[4096];
175 memset(buf, 0, sizeof(buf));
176 int32_t bytesRead = PR_Recv(socket, buf, sizeof(buf) - 1, 0,
177 PR_INTERVAL_NO_TIMEOUT);
178 if (bytesRead < 0) {
179 PrintPRError("PR_Recv failed 1");
180 return 1;
181 }
182 if (bytesRead == 0) {
183 fprintf(stderr, "PR_Recv eof 1\n");
184 return 1;
185 }
186 fprintf(stderr, "%s\n", buf);
187 return 0;
188 }
190 SECStatus
191 ConfigSecureServerWithNamedCert(PRFileDesc *fd, const char *certName,
192 /*optional*/ ScopedCERTCertificate *certOut,
193 /*optional*/ SSLKEAType *keaOut)
194 {
195 ScopedCERTCertificate cert(PK11_FindCertFromNickname(certName, nullptr));
196 if (!cert) {
197 PrintPRError("PK11_FindCertFromNickname failed");
198 return SECFailure;
199 }
201 ScopedSECKEYPrivateKey key(PK11_FindKeyByAnyCert(cert, nullptr));
202 if (!key) {
203 PrintPRError("PK11_FindKeyByAnyCert failed");
204 return SECFailure;
205 }
207 SSLKEAType certKEA = NSS_FindCertKEAType(cert);
209 if (SSL_ConfigSecureServer(fd, cert, key, certKEA) != SECSuccess) {
210 PrintPRError("SSL_ConfigSecureServer failed");
211 return SECFailure;
212 }
214 if (certOut) {
215 *certOut = cert.forget();
216 }
218 if (keaOut) {
219 *keaOut = certKEA;
220 }
222 return SECSuccess;
223 }
225 int
226 StartServer(const char *nssCertDBDir, SSLSNISocketConfig sniSocketConfig,
227 void *sniSocketConfigArg)
228 {
229 const char *debugLevel = PR_GetEnv("MOZ_TLS_SERVER_DEBUG_LEVEL");
230 if (debugLevel) {
231 int level = atoi(debugLevel);
232 switch (level) {
233 case DEBUG_ERRORS: gDebugLevel = DEBUG_ERRORS; break;
234 case DEBUG_WARNINGS: gDebugLevel = DEBUG_WARNINGS; break;
235 case DEBUG_VERBOSE: gDebugLevel = DEBUG_VERBOSE; break;
236 default:
237 PrintPRError("invalid MOZ_TLS_SERVER_DEBUG_LEVEL");
238 return 1;
239 }
240 }
242 const char *callbackPort = PR_GetEnv("MOZ_TLS_SERVER_CALLBACK_PORT");
243 if (callbackPort) {
244 gCallbackPort = atoi(callbackPort);
245 }
247 if (NSS_Init(nssCertDBDir) != SECSuccess) {
248 PrintPRError("NSS_Init failed");
249 return 1;
250 }
252 if (NSS_SetDomesticPolicy() != SECSuccess) {
253 PrintPRError("NSS_SetDomesticPolicy failed");
254 return 1;
255 }
257 if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) {
258 PrintPRError("SSL_ConfigServerSessionIDCache failed");
259 return 1;
260 }
262 ScopedPRFileDesc serverSocket(PR_NewTCPSocket());
263 if (!serverSocket) {
264 PrintPRError("PR_NewTCPSocket failed");
265 return 1;
266 }
268 PRSocketOptionData socketOption;
269 socketOption.option = PR_SockOpt_Reuseaddr;
270 socketOption.value.reuse_addr = true;
271 PR_SetSocketOption(serverSocket, &socketOption);
273 PRNetAddr serverAddr;
274 PR_InitializeNetAddr(PR_IpAddrLoopback, LISTEN_PORT, &serverAddr);
275 if (PR_Bind(serverSocket, &serverAddr) != PR_SUCCESS) {
276 PrintPRError("PR_Bind failed");
277 return 1;
278 }
280 if (PR_Listen(serverSocket, 1) != PR_SUCCESS) {
281 PrintPRError("PR_Listen failed");
282 return 1;
283 }
285 ScopedPRFileDesc rawModelSocket(PR_NewTCPSocket());
286 if (!rawModelSocket) {
287 PrintPRError("PR_NewTCPSocket failed for rawModelSocket");
288 return 1;
289 }
291 ScopedPRFileDesc modelSocket(SSL_ImportFD(nullptr, rawModelSocket.forget()));
292 if (!modelSocket) {
293 PrintPRError("SSL_ImportFD of rawModelSocket failed");
294 return 1;
295 }
297 if (SECSuccess != SSL_SNISocketConfigHook(modelSocket, sniSocketConfig,
298 sniSocketConfigArg)) {
299 PrintPRError("SSL_SNISocketConfigHook failed");
300 return 1;
301 }
303 // We have to configure the server with a certificate, but it's not one
304 // we're actually going to end up using. In the SNI callback, we pick
305 // the right certificate for the connection.
306 if (SECSuccess != ConfigSecureServerWithNamedCert(modelSocket,
307 DEFAULT_CERT_NICKNAME,
308 nullptr, nullptr)) {
309 return 1;
310 }
312 if (gCallbackPort != 0) {
313 if (DoCallback()) {
314 return 1;
315 }
316 }
318 while (true) {
319 PRNetAddr clientAddr;
320 PRFileDesc *clientSocket = PR_Accept(serverSocket, &clientAddr,
321 PR_INTERVAL_NO_TIMEOUT);
322 HandleConnection(clientSocket, modelSocket);
323 }
325 return 0;
326 }
328 } } // namespace mozilla::test