|
1 /* This Source Code Form is subject to the terms of the Mozilla Public |
|
2 * License, v. 2.0. If a copy of the MPL was not distributed with this |
|
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ |
|
4 |
|
5 #include "TLSServer.h" |
|
6 |
|
7 #include <stdio.h> |
|
8 #include "ScopedNSSTypes.h" |
|
9 #include "nspr.h" |
|
10 #include "nss.h" |
|
11 #include "plarenas.h" |
|
12 #include "prenv.h" |
|
13 #include "prerror.h" |
|
14 #include "prnetdb.h" |
|
15 #include "prtime.h" |
|
16 #include "ssl.h" |
|
17 |
|
18 namespace mozilla { namespace test { |
|
19 |
|
20 static const uint16_t LISTEN_PORT = 8443; |
|
21 |
|
22 DebugLevel gDebugLevel = DEBUG_ERRORS; |
|
23 uint16_t gCallbackPort = 0; |
|
24 |
|
25 const char DEFAULT_CERT_NICKNAME[] = "localhostAndExampleCom"; |
|
26 |
|
27 struct Connection |
|
28 { |
|
29 PRFileDesc *mSocket; |
|
30 char mByte; |
|
31 |
|
32 Connection(PRFileDesc *aSocket); |
|
33 ~Connection(); |
|
34 }; |
|
35 |
|
36 Connection::Connection(PRFileDesc *aSocket) |
|
37 : mSocket(aSocket) |
|
38 , mByte(0) |
|
39 {} |
|
40 |
|
41 Connection::~Connection() |
|
42 { |
|
43 if (mSocket) { |
|
44 PR_Close(mSocket); |
|
45 } |
|
46 } |
|
47 |
|
48 void |
|
49 PrintPRError(const char *aPrefix) |
|
50 { |
|
51 const char *err = PR_ErrorToName(PR_GetError()); |
|
52 if (err) { |
|
53 if (gDebugLevel >= DEBUG_ERRORS) { |
|
54 fprintf(stderr, "%s: %s\n", aPrefix, err); |
|
55 } |
|
56 } else { |
|
57 if (gDebugLevel >= DEBUG_ERRORS) { |
|
58 fprintf(stderr, "%s\n", aPrefix); |
|
59 } |
|
60 } |
|
61 } |
|
62 |
|
63 nsresult |
|
64 SendAll(PRFileDesc *aSocket, const char *aData, size_t aDataLen) |
|
65 { |
|
66 if (gDebugLevel >= DEBUG_VERBOSE) { |
|
67 fprintf(stderr, "sending '%s'\n", aData); |
|
68 } |
|
69 |
|
70 while (aDataLen > 0) { |
|
71 int32_t bytesSent = PR_Send(aSocket, aData, aDataLen, 0, |
|
72 PR_INTERVAL_NO_TIMEOUT); |
|
73 if (bytesSent == -1) { |
|
74 PrintPRError("PR_Send failed"); |
|
75 return NS_ERROR_FAILURE; |
|
76 } |
|
77 |
|
78 aDataLen -= bytesSent; |
|
79 aData += bytesSent; |
|
80 } |
|
81 |
|
82 return NS_OK; |
|
83 } |
|
84 |
|
85 nsresult |
|
86 ReplyToRequest(Connection *aConn) |
|
87 { |
|
88 // For debugging purposes, SendAll can print out what it's sending. |
|
89 // So, any strings we give to it to send need to be null-terminated. |
|
90 char buf[2] = { aConn->mByte, 0 }; |
|
91 return SendAll(aConn->mSocket, buf, 1); |
|
92 } |
|
93 |
|
94 nsresult |
|
95 SetupTLS(Connection *aConn, PRFileDesc *aModelSocket) |
|
96 { |
|
97 PRFileDesc *sslSocket = SSL_ImportFD(aModelSocket, aConn->mSocket); |
|
98 if (!sslSocket) { |
|
99 PrintPRError("SSL_ImportFD failed"); |
|
100 return NS_ERROR_FAILURE; |
|
101 } |
|
102 aConn->mSocket = sslSocket; |
|
103 |
|
104 SSL_OptionSet(sslSocket, SSL_SECURITY, true); |
|
105 SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT, false); |
|
106 SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_SERVER, true); |
|
107 |
|
108 SSL_ResetHandshake(sslSocket, /* asServer */ 1); |
|
109 |
|
110 return NS_OK; |
|
111 } |
|
112 |
|
113 nsresult |
|
114 ReadRequest(Connection *aConn) |
|
115 { |
|
116 int32_t bytesRead = PR_Recv(aConn->mSocket, &aConn->mByte, 1, 0, |
|
117 PR_INTERVAL_NO_TIMEOUT); |
|
118 if (bytesRead < 0) { |
|
119 PrintPRError("PR_Recv failed"); |
|
120 return NS_ERROR_FAILURE; |
|
121 } else if (bytesRead == 0) { |
|
122 PR_SetError(PR_IO_ERROR, 0); |
|
123 PrintPRError("PR_Recv EOF in ReadRequest"); |
|
124 return NS_ERROR_FAILURE; |
|
125 } else { |
|
126 if (gDebugLevel >= DEBUG_VERBOSE) { |
|
127 fprintf(stderr, "read '0x%hhx'\n", aConn->mByte); |
|
128 } |
|
129 } |
|
130 return NS_OK; |
|
131 } |
|
132 |
|
133 void |
|
134 HandleConnection(PRFileDesc *aSocket, PRFileDesc *aModelSocket) |
|
135 { |
|
136 Connection conn(aSocket); |
|
137 nsresult rv = SetupTLS(&conn, aModelSocket); |
|
138 if (NS_FAILED(rv)) { |
|
139 PR_SetError(PR_INVALID_STATE_ERROR, 0); |
|
140 PrintPRError("PR_Recv failed"); |
|
141 exit(1); |
|
142 } |
|
143 |
|
144 // TODO: On tests that are expected to fail (e.g. due to a revoked |
|
145 // certificate), the client will close the connection wtihout sending us the |
|
146 // request byte. In those cases, we should keep going. But, in the cases |
|
147 // where the connection is supposed to suceed, we should verify that we |
|
148 // successfully receive the request and send the response. |
|
149 rv = ReadRequest(&conn); |
|
150 if (NS_SUCCEEDED(rv)) { |
|
151 rv = ReplyToRequest(&conn); |
|
152 } |
|
153 } |
|
154 |
|
155 // returns 0 on success, non-zero on error |
|
156 int |
|
157 DoCallback() |
|
158 { |
|
159 ScopedPRFileDesc socket(PR_NewTCPSocket()); |
|
160 if (!socket) { |
|
161 PrintPRError("PR_NewTCPSocket failed"); |
|
162 return 1; |
|
163 } |
|
164 |
|
165 PRNetAddr addr; |
|
166 PR_InitializeNetAddr(PR_IpAddrLoopback, gCallbackPort, &addr); |
|
167 if (PR_Connect(socket, &addr, PR_INTERVAL_NO_TIMEOUT) != PR_SUCCESS) { |
|
168 PrintPRError("PR_Connect failed"); |
|
169 return 1; |
|
170 } |
|
171 |
|
172 const char *request = "GET / HTTP/1.0\r\n\r\n"; |
|
173 SendAll(socket, request, strlen(request)); |
|
174 char buf[4096]; |
|
175 memset(buf, 0, sizeof(buf)); |
|
176 int32_t bytesRead = PR_Recv(socket, buf, sizeof(buf) - 1, 0, |
|
177 PR_INTERVAL_NO_TIMEOUT); |
|
178 if (bytesRead < 0) { |
|
179 PrintPRError("PR_Recv failed 1"); |
|
180 return 1; |
|
181 } |
|
182 if (bytesRead == 0) { |
|
183 fprintf(stderr, "PR_Recv eof 1\n"); |
|
184 return 1; |
|
185 } |
|
186 fprintf(stderr, "%s\n", buf); |
|
187 return 0; |
|
188 } |
|
189 |
|
190 SECStatus |
|
191 ConfigSecureServerWithNamedCert(PRFileDesc *fd, const char *certName, |
|
192 /*optional*/ ScopedCERTCertificate *certOut, |
|
193 /*optional*/ SSLKEAType *keaOut) |
|
194 { |
|
195 ScopedCERTCertificate cert(PK11_FindCertFromNickname(certName, nullptr)); |
|
196 if (!cert) { |
|
197 PrintPRError("PK11_FindCertFromNickname failed"); |
|
198 return SECFailure; |
|
199 } |
|
200 |
|
201 ScopedSECKEYPrivateKey key(PK11_FindKeyByAnyCert(cert, nullptr)); |
|
202 if (!key) { |
|
203 PrintPRError("PK11_FindKeyByAnyCert failed"); |
|
204 return SECFailure; |
|
205 } |
|
206 |
|
207 SSLKEAType certKEA = NSS_FindCertKEAType(cert); |
|
208 |
|
209 if (SSL_ConfigSecureServer(fd, cert, key, certKEA) != SECSuccess) { |
|
210 PrintPRError("SSL_ConfigSecureServer failed"); |
|
211 return SECFailure; |
|
212 } |
|
213 |
|
214 if (certOut) { |
|
215 *certOut = cert.forget(); |
|
216 } |
|
217 |
|
218 if (keaOut) { |
|
219 *keaOut = certKEA; |
|
220 } |
|
221 |
|
222 return SECSuccess; |
|
223 } |
|
224 |
|
225 int |
|
226 StartServer(const char *nssCertDBDir, SSLSNISocketConfig sniSocketConfig, |
|
227 void *sniSocketConfigArg) |
|
228 { |
|
229 const char *debugLevel = PR_GetEnv("MOZ_TLS_SERVER_DEBUG_LEVEL"); |
|
230 if (debugLevel) { |
|
231 int level = atoi(debugLevel); |
|
232 switch (level) { |
|
233 case DEBUG_ERRORS: gDebugLevel = DEBUG_ERRORS; break; |
|
234 case DEBUG_WARNINGS: gDebugLevel = DEBUG_WARNINGS; break; |
|
235 case DEBUG_VERBOSE: gDebugLevel = DEBUG_VERBOSE; break; |
|
236 default: |
|
237 PrintPRError("invalid MOZ_TLS_SERVER_DEBUG_LEVEL"); |
|
238 return 1; |
|
239 } |
|
240 } |
|
241 |
|
242 const char *callbackPort = PR_GetEnv("MOZ_TLS_SERVER_CALLBACK_PORT"); |
|
243 if (callbackPort) { |
|
244 gCallbackPort = atoi(callbackPort); |
|
245 } |
|
246 |
|
247 if (NSS_Init(nssCertDBDir) != SECSuccess) { |
|
248 PrintPRError("NSS_Init failed"); |
|
249 return 1; |
|
250 } |
|
251 |
|
252 if (NSS_SetDomesticPolicy() != SECSuccess) { |
|
253 PrintPRError("NSS_SetDomesticPolicy failed"); |
|
254 return 1; |
|
255 } |
|
256 |
|
257 if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) { |
|
258 PrintPRError("SSL_ConfigServerSessionIDCache failed"); |
|
259 return 1; |
|
260 } |
|
261 |
|
262 ScopedPRFileDesc serverSocket(PR_NewTCPSocket()); |
|
263 if (!serverSocket) { |
|
264 PrintPRError("PR_NewTCPSocket failed"); |
|
265 return 1; |
|
266 } |
|
267 |
|
268 PRSocketOptionData socketOption; |
|
269 socketOption.option = PR_SockOpt_Reuseaddr; |
|
270 socketOption.value.reuse_addr = true; |
|
271 PR_SetSocketOption(serverSocket, &socketOption); |
|
272 |
|
273 PRNetAddr serverAddr; |
|
274 PR_InitializeNetAddr(PR_IpAddrLoopback, LISTEN_PORT, &serverAddr); |
|
275 if (PR_Bind(serverSocket, &serverAddr) != PR_SUCCESS) { |
|
276 PrintPRError("PR_Bind failed"); |
|
277 return 1; |
|
278 } |
|
279 |
|
280 if (PR_Listen(serverSocket, 1) != PR_SUCCESS) { |
|
281 PrintPRError("PR_Listen failed"); |
|
282 return 1; |
|
283 } |
|
284 |
|
285 ScopedPRFileDesc rawModelSocket(PR_NewTCPSocket()); |
|
286 if (!rawModelSocket) { |
|
287 PrintPRError("PR_NewTCPSocket failed for rawModelSocket"); |
|
288 return 1; |
|
289 } |
|
290 |
|
291 ScopedPRFileDesc modelSocket(SSL_ImportFD(nullptr, rawModelSocket.forget())); |
|
292 if (!modelSocket) { |
|
293 PrintPRError("SSL_ImportFD of rawModelSocket failed"); |
|
294 return 1; |
|
295 } |
|
296 |
|
297 if (SECSuccess != SSL_SNISocketConfigHook(modelSocket, sniSocketConfig, |
|
298 sniSocketConfigArg)) { |
|
299 PrintPRError("SSL_SNISocketConfigHook failed"); |
|
300 return 1; |
|
301 } |
|
302 |
|
303 // We have to configure the server with a certificate, but it's not one |
|
304 // we're actually going to end up using. In the SNI callback, we pick |
|
305 // the right certificate for the connection. |
|
306 if (SECSuccess != ConfigSecureServerWithNamedCert(modelSocket, |
|
307 DEFAULT_CERT_NICKNAME, |
|
308 nullptr, nullptr)) { |
|
309 return 1; |
|
310 } |
|
311 |
|
312 if (gCallbackPort != 0) { |
|
313 if (DoCallback()) { |
|
314 return 1; |
|
315 } |
|
316 } |
|
317 |
|
318 while (true) { |
|
319 PRNetAddr clientAddr; |
|
320 PRFileDesc *clientSocket = PR_Accept(serverSocket, &clientAddr, |
|
321 PR_INTERVAL_NO_TIMEOUT); |
|
322 HandleConnection(clientSocket, modelSocket); |
|
323 } |
|
324 |
|
325 return 0; |
|
326 } |
|
327 |
|
328 } } // namespace mozilla::test |