media/mtransport/transportlayerdtls.cpp

changeset 0
6474c204b198
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/media/mtransport/transportlayerdtls.cpp	Wed Dec 31 06:09:35 2014 +0100
     1.3 @@ -0,0 +1,954 @@
     1.4 +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
     1.5 +/* vim: set ts=2 et sw=2 tw=80: */
     1.6 +/* This Source Code Form is subject to the terms of the Mozilla Public
     1.7 + * License, v. 2.0. If a copy of the MPL was not distributed with this file,
     1.8 + * You can obtain one at http://mozilla.org/MPL/2.0/. */
     1.9 +
    1.10 +// Original author: ekr@rtfm.com
    1.11 +
    1.12 +#include <queue>
    1.13 +#include <algorithm>
    1.14 +
    1.15 +#include "logging.h"
    1.16 +#include "ssl.h"
    1.17 +#include "sslerr.h"
    1.18 +#include "sslproto.h"
    1.19 +#include "keyhi.h"
    1.20 +
    1.21 +#include "nsCOMPtr.h"
    1.22 +#include "nsComponentManagerUtils.h"
    1.23 +#include "nsIEventTarget.h"
    1.24 +#include "nsNetCID.h"
    1.25 +#include "nsComponentManagerUtils.h"
    1.26 +#include "nsServiceManagerUtils.h"
    1.27 +
    1.28 +#include "dtlsidentity.h"
    1.29 +#include "transportflow.h"
    1.30 +#include "transportlayerdtls.h"
    1.31 +
    1.32 +namespace mozilla {
    1.33 +
    1.34 +MOZ_MTLOG_MODULE("mtransport")
    1.35 +
    1.36 +static PRDescIdentity transport_layer_identity = PR_INVALID_IO_LAYER;
    1.37 +
    1.38 +// TODO: Implement a mode for this where
    1.39 +// the channel is not ready until confirmed externally
    1.40 +// (e.g., after cert check).
    1.41 +
    1.42 +#define UNIMPLEMENTED                                           \
    1.43 +  MOZ_MTLOG(ML_ERROR,                                           \
    1.44 +       "Call to unimplemented function "<< __FUNCTION__);       \
    1.45 +  MOZ_ASSERT(false);                                            \
    1.46 +  PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0)
    1.47 +
    1.48 +
    1.49 +// We need to adapt the NSPR/libssl model to the TransportFlow model.
    1.50 +// The former wants pull semantics and TransportFlow wants push.
    1.51 +//
    1.52 +// - A TransportLayerDtls assumes it is sitting on top of another
    1.53 +//   TransportLayer, which means that events come in asynchronously.
    1.54 +// - NSS (libssl) wants to sit on top of a PRFileDesc and poll.
    1.55 +// - The TransportLayerNSPRAdapter is a PRFileDesc containing a
    1.56 +//   FIFO.
    1.57 +// - When TransportLayerDtls.PacketReceived() is called, we insert
    1.58 +//   the packets in the FIFO and then do a PR_Recv() on the NSS
    1.59 +//   PRFileDesc, which eventually reads off the FIFO.
    1.60 +//
    1.61 +// All of this stuff is assumed to happen solely in a single thread
    1.62 +// (generally the SocketTransportService thread)
    1.63 +struct Packet {
    1.64 +  Packet() : data_(nullptr), len_(0), offset_(0) {}
    1.65 +
    1.66 +  void Assign(const void *data, int32_t len) {
    1.67 +    data_ = new uint8_t[len];
    1.68 +    memcpy(data_, data, len);
    1.69 +    len_ = len;
    1.70 +  }
    1.71 +
    1.72 +  ScopedDeleteArray<uint8_t> data_;
    1.73 +  int32_t len_;
    1.74 +  int32_t offset_;
    1.75 +};
    1.76 +
    1.77 +void TransportLayerNSPRAdapter::PacketReceived(const void *data, int32_t len) {
    1.78 +  input_.push(new Packet());
    1.79 +  input_.back()->Assign(data, len);
    1.80 +}
    1.81 +
    1.82 +int32_t TransportLayerNSPRAdapter::Read(void *data, int32_t len) {
    1.83 +  if (input_.empty()) {
    1.84 +    PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
    1.85 +    return TE_WOULDBLOCK;
    1.86 +  }
    1.87 +
    1.88 +  Packet* front = input_.front();
    1.89 +  int32_t to_read = std::min(len, front->len_ - front->offset_);
    1.90 +  memcpy(data, front->data_, to_read);
    1.91 +  front->offset_ += to_read;
    1.92 +
    1.93 +  if (front->offset_ == front->len_) {
    1.94 +    input_.pop();
    1.95 +    delete front;
    1.96 +  }
    1.97 +
    1.98 +  return to_read;
    1.99 +}
   1.100 +
   1.101 +int32_t TransportLayerNSPRAdapter::Write(const void *buf, int32_t length) {
   1.102 +  TransportResult r = output_->SendPacket(
   1.103 +      static_cast<const unsigned char *>(buf), length);
   1.104 +  if (r >= 0) {
   1.105 +    return r;
   1.106 +  }
   1.107 +
   1.108 +  if (r == TE_WOULDBLOCK) {
   1.109 +    PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
   1.110 +  } else {
   1.111 +    PR_SetError(PR_IO_ERROR, 0);
   1.112 +  }
   1.113 +
   1.114 +  return -1;
   1.115 +}
   1.116 +
   1.117 +
   1.118 +// Implementation of NSPR methods
   1.119 +static PRStatus TransportLayerClose(PRFileDesc *f) {
   1.120 +  f->secret = nullptr;
   1.121 +  return PR_SUCCESS;
   1.122 +}
   1.123 +
   1.124 +static int32_t TransportLayerRead(PRFileDesc *f, void *buf, int32_t length) {
   1.125 +  TransportLayerNSPRAdapter *io = reinterpret_cast<TransportLayerNSPRAdapter *>(f->secret);
   1.126 +  return io->Read(buf, length);
   1.127 +}
   1.128 +
   1.129 +static int32_t TransportLayerWrite(PRFileDesc *f, const void *buf, int32_t length) {
   1.130 +  TransportLayerNSPRAdapter *io = reinterpret_cast<TransportLayerNSPRAdapter *>(f->secret);
   1.131 +  return io->Write(buf, length);
   1.132 +}
   1.133 +
   1.134 +static int32_t TransportLayerAvailable(PRFileDesc *f) {
   1.135 +  UNIMPLEMENTED;
   1.136 +  return -1;
   1.137 +}
   1.138 +
   1.139 +int64_t TransportLayerAvailable64(PRFileDesc *f) {
   1.140 +  UNIMPLEMENTED;
   1.141 +  return -1;
   1.142 +}
   1.143 +
   1.144 +static PRStatus TransportLayerSync(PRFileDesc *f) {
   1.145 +  UNIMPLEMENTED;
   1.146 +  return PR_FAILURE;
   1.147 +}
   1.148 +
   1.149 +static int32_t TransportLayerSeek(PRFileDesc *f, int32_t offset,
   1.150 +                                  PRSeekWhence how) {
   1.151 +  UNIMPLEMENTED;
   1.152 +  return -1;
   1.153 +}
   1.154 +
   1.155 +static int64_t TransportLayerSeek64(PRFileDesc *f, int64_t offset,
   1.156 +                                    PRSeekWhence how) {
   1.157 +  UNIMPLEMENTED;
   1.158 +  return -1;
   1.159 +}
   1.160 +
   1.161 +static PRStatus TransportLayerFileInfo(PRFileDesc *f, PRFileInfo *info) {
   1.162 +  UNIMPLEMENTED;
   1.163 +  return PR_FAILURE;
   1.164 +}
   1.165 +
   1.166 +static PRStatus TransportLayerFileInfo64(PRFileDesc *f, PRFileInfo64 *info) {
   1.167 +  UNIMPLEMENTED;
   1.168 +  return PR_FAILURE;
   1.169 +}
   1.170 +
   1.171 +static int32_t TransportLayerWritev(PRFileDesc *f, const PRIOVec *iov,
   1.172 +                                    int32_t iov_size, PRIntervalTime to) {
   1.173 +  UNIMPLEMENTED;
   1.174 +  return -1;
   1.175 +}
   1.176 +
   1.177 +static PRStatus TransportLayerConnect(PRFileDesc *f, const PRNetAddr *addr,
   1.178 +                                      PRIntervalTime to) {
   1.179 +  UNIMPLEMENTED;
   1.180 +  return PR_FAILURE;
   1.181 +}
   1.182 +
   1.183 +static PRFileDesc *TransportLayerAccept(PRFileDesc *sd, PRNetAddr *addr,
   1.184 +                                        PRIntervalTime to) {
   1.185 +  UNIMPLEMENTED;
   1.186 +  return nullptr;
   1.187 +}
   1.188 +
   1.189 +static PRStatus TransportLayerBind(PRFileDesc *f, const PRNetAddr *addr) {
   1.190 +  UNIMPLEMENTED;
   1.191 +  return PR_FAILURE;
   1.192 +}
   1.193 +
   1.194 +static PRStatus TransportLayerListen(PRFileDesc *f, int32_t depth) {
   1.195 +  UNIMPLEMENTED;
   1.196 +  return PR_FAILURE;
   1.197 +}
   1.198 +
   1.199 +static PRStatus TransportLayerShutdown(PRFileDesc *f, int32_t how) {
   1.200 +  UNIMPLEMENTED;
   1.201 +  return PR_FAILURE;
   1.202 +}
   1.203 +
   1.204 +// This function does not support peek.
   1.205 +static int32_t TransportLayerRecv(PRFileDesc *f, void *buf, int32_t amount,
   1.206 +                                  int32_t flags, PRIntervalTime to) {
   1.207 +  MOZ_ASSERT(flags == 0);
   1.208 +  if (flags != 0) {
   1.209 +    PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
   1.210 +    return -1;
   1.211 +  }
   1.212 +
   1.213 +  return TransportLayerRead(f, buf, amount);
   1.214 +}
   1.215 +
   1.216 +// Note: this is always nonblocking and assumes a zero timeout.
   1.217 +static int32_t TransportLayerSend(PRFileDesc *f, const void *buf, int32_t amount,
   1.218 +                                  int32_t flags, PRIntervalTime to) {
   1.219 +  int32_t written = TransportLayerWrite(f, buf, amount);
   1.220 +  return written;
   1.221 +}
   1.222 +
   1.223 +static int32_t TransportLayerRecvfrom(PRFileDesc *f, void *buf, int32_t amount,
   1.224 +                                      int32_t flags, PRNetAddr *addr, PRIntervalTime to) {
   1.225 +  UNIMPLEMENTED;
   1.226 +  return -1;
   1.227 +}
   1.228 +
   1.229 +static int32_t TransportLayerSendto(PRFileDesc *f, const void *buf, int32_t amount,
   1.230 +                                    int32_t flags, const PRNetAddr *addr, PRIntervalTime to) {
   1.231 +  UNIMPLEMENTED;
   1.232 +  return -1;
   1.233 +}
   1.234 +
   1.235 +static int16_t TransportLayerPoll(PRFileDesc *f, int16_t in_flags, int16_t *out_flags) {
   1.236 +  UNIMPLEMENTED;
   1.237 +  return -1;
   1.238 +}
   1.239 +
   1.240 +static int32_t TransportLayerAcceptRead(PRFileDesc *sd, PRFileDesc **nd,
   1.241 +                                        PRNetAddr **raddr,
   1.242 +                                        void *buf, int32_t amount, PRIntervalTime t) {
   1.243 +  UNIMPLEMENTED;
   1.244 +  return -1;
   1.245 +}
   1.246 +
   1.247 +static int32_t TransportLayerTransmitFile(PRFileDesc *sd, PRFileDesc *f,
   1.248 +                                          const void *headers, int32_t hlen,
   1.249 +                                          PRTransmitFileFlags flags, PRIntervalTime t) {
   1.250 +  UNIMPLEMENTED;
   1.251 +  return -1;
   1.252 +}
   1.253 +
   1.254 +static PRStatus TransportLayerGetpeername(PRFileDesc *f, PRNetAddr *addr) {
   1.255 +  // TODO: Modify to return unique names for each channel
   1.256 +  // somehow, as opposed to always the same static address. The current
   1.257 +  // implementation messes up the session cache, which is why it's off
   1.258 +  // elsewhere
   1.259 +  addr->inet.family = PR_AF_INET;
   1.260 +  addr->inet.port = 0;
   1.261 +  addr->inet.ip = 0;
   1.262 +
   1.263 +  return PR_SUCCESS;
   1.264 +}
   1.265 +
   1.266 +static PRStatus TransportLayerGetsockname(PRFileDesc *f, PRNetAddr *addr) {
   1.267 +  UNIMPLEMENTED;
   1.268 +  return PR_FAILURE;
   1.269 +}
   1.270 +
   1.271 +static PRStatus TransportLayerGetsockoption(PRFileDesc *f, PRSocketOptionData *opt) {
   1.272 +  switch (opt->option) {
   1.273 +    case PR_SockOpt_Nonblocking:
   1.274 +      opt->value.non_blocking = PR_TRUE;
   1.275 +      return PR_SUCCESS;
   1.276 +    default:
   1.277 +      UNIMPLEMENTED;
   1.278 +      break;
   1.279 +  }
   1.280 +
   1.281 +  return PR_FAILURE;
   1.282 +}
   1.283 +
   1.284 +// Imitate setting socket options. These are mostly noops.
   1.285 +static PRStatus TransportLayerSetsockoption(PRFileDesc *f,
   1.286 +                                            const PRSocketOptionData *opt) {
   1.287 +  switch (opt->option) {
   1.288 +    case PR_SockOpt_Nonblocking:
   1.289 +      return PR_SUCCESS;
   1.290 +    case PR_SockOpt_NoDelay:
   1.291 +      return PR_SUCCESS;
   1.292 +    default:
   1.293 +      UNIMPLEMENTED;
   1.294 +      break;
   1.295 +  }
   1.296 +
   1.297 +  return PR_FAILURE;
   1.298 +}
   1.299 +
   1.300 +static int32_t TransportLayerSendfile(PRFileDesc *out, PRSendFileData *in,
   1.301 +                                      PRTransmitFileFlags flags, PRIntervalTime to) {
   1.302 +  UNIMPLEMENTED;
   1.303 +  return -1;
   1.304 +}
   1.305 +
   1.306 +static PRStatus TransportLayerConnectContinue(PRFileDesc *f, int16_t flags) {
   1.307 +  UNIMPLEMENTED;
   1.308 +  return PR_FAILURE;
   1.309 +}
   1.310 +
   1.311 +static int32_t TransportLayerReserved(PRFileDesc *f) {
   1.312 +  UNIMPLEMENTED;
   1.313 +  return -1;
   1.314 +}
   1.315 +
   1.316 +static const struct PRIOMethods TransportLayerMethods = {
   1.317 +  PR_DESC_LAYERED,
   1.318 +  TransportLayerClose,
   1.319 +  TransportLayerRead,
   1.320 +  TransportLayerWrite,
   1.321 +  TransportLayerAvailable,
   1.322 +  TransportLayerAvailable64,
   1.323 +  TransportLayerSync,
   1.324 +  TransportLayerSeek,
   1.325 +  TransportLayerSeek64,
   1.326 +  TransportLayerFileInfo,
   1.327 +  TransportLayerFileInfo64,
   1.328 +  TransportLayerWritev,
   1.329 +  TransportLayerConnect,
   1.330 +  TransportLayerAccept,
   1.331 +  TransportLayerBind,
   1.332 +  TransportLayerListen,
   1.333 +  TransportLayerShutdown,
   1.334 +  TransportLayerRecv,
   1.335 +  TransportLayerSend,
   1.336 +  TransportLayerRecvfrom,
   1.337 +  TransportLayerSendto,
   1.338 +  TransportLayerPoll,
   1.339 +  TransportLayerAcceptRead,
   1.340 +  TransportLayerTransmitFile,
   1.341 +  TransportLayerGetsockname,
   1.342 +  TransportLayerGetpeername,
   1.343 +  TransportLayerReserved,
   1.344 +  TransportLayerReserved,
   1.345 +  TransportLayerGetsockoption,
   1.346 +  TransportLayerSetsockoption,
   1.347 +  TransportLayerSendfile,
   1.348 +  TransportLayerConnectContinue,
   1.349 +  TransportLayerReserved,
   1.350 +  TransportLayerReserved,
   1.351 +  TransportLayerReserved,
   1.352 +  TransportLayerReserved
   1.353 +};
   1.354 +
   1.355 +TransportLayerDtls::~TransportLayerDtls() {
   1.356 +  if (timer_) {
   1.357 +    timer_->Cancel();
   1.358 +  }
   1.359 +}
   1.360 +
   1.361 +nsresult TransportLayerDtls::InitInternal() {
   1.362 +  // Get the transport service as an event target
   1.363 +  nsresult rv;
   1.364 +  target_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv);
   1.365 +
   1.366 +  if (NS_FAILED(rv)) {
   1.367 +    MOZ_MTLOG(ML_ERROR, "Couldn't get socket transport service");
   1.368 +    return rv;
   1.369 +  }
   1.370 +
   1.371 +  timer_ = do_CreateInstance(NS_TIMER_CONTRACTID, &rv);
   1.372 +  if (NS_FAILED(rv)) {
   1.373 +    MOZ_MTLOG(ML_ERROR, "Couldn't get timer");
   1.374 +    return rv;
   1.375 +  }
   1.376 +
   1.377 +  return NS_OK;
   1.378 +}
   1.379 +
   1.380 +
   1.381 +void TransportLayerDtls::WasInserted() {
   1.382 +  // Connect to the lower layers
   1.383 +  if (!Setup()) {
   1.384 +    TL_SET_STATE(TS_ERROR);
   1.385 +  }
   1.386 +}
   1.387 +
   1.388 +
   1.389 +nsresult TransportLayerDtls::SetVerificationAllowAll() {
   1.390 +  // Defensive programming
   1.391 +  if (verification_mode_ != VERIFY_UNSET)
   1.392 +    return NS_ERROR_ALREADY_INITIALIZED;
   1.393 +
   1.394 +  verification_mode_ = VERIFY_ALLOW_ALL;
   1.395 +
   1.396 +  return NS_OK;
   1.397 +}
   1.398 +
   1.399 +nsresult
   1.400 +TransportLayerDtls::SetVerificationDigest(const std::string digest_algorithm,
   1.401 +                                          const unsigned char *digest_value,
   1.402 +                                          size_t digest_len) {
   1.403 +  // Defensive programming
   1.404 +  if (verification_mode_ != VERIFY_UNSET &&
   1.405 +      verification_mode_ != VERIFY_DIGEST) {
   1.406 +    return NS_ERROR_ALREADY_INITIALIZED;
   1.407 +  }
   1.408 +
   1.409 +  // Note that we do not sanity check these values for length.
   1.410 +  // We merely ensure they will fit into the buffer.
   1.411 +  // TODO: is there a Data construct we could use?
   1.412 +  if (digest_len > kMaxDigestLength)
   1.413 +    return NS_ERROR_INVALID_ARG;
   1.414 +
   1.415 +  digests_.push_back(new VerificationDigest(
   1.416 +      digest_algorithm, digest_value, digest_len));
   1.417 +
   1.418 +  verification_mode_ = VERIFY_DIGEST;
   1.419 +
   1.420 +  return NS_OK;
   1.421 +}
   1.422 +
   1.423 +// TODO: make sure this is called from STS. Otherwise
   1.424 +// we have thread safety issues
   1.425 +bool TransportLayerDtls::Setup() {
   1.426 +  CheckThread();
   1.427 +  SECStatus rv;
   1.428 +
   1.429 +  if (!downward_) {
   1.430 +    MOZ_MTLOG(ML_ERROR, "DTLS layer with nothing below. This is useless");
   1.431 +    return false;
   1.432 +  }
   1.433 +  nspr_io_adapter_ = new TransportLayerNSPRAdapter(downward_);
   1.434 +
   1.435 +  if (!identity_) {
   1.436 +    MOZ_MTLOG(ML_ERROR, "Can't start DTLS without an identity");
   1.437 +    return false;
   1.438 +  }
   1.439 +
   1.440 +  if (verification_mode_ == VERIFY_UNSET) {
   1.441 +    MOZ_MTLOG(ML_ERROR,
   1.442 +              "Can't start DTLS without specifying a verification mode");
   1.443 +    return false;
   1.444 +  }
   1.445 +
   1.446 +  if (transport_layer_identity == PR_INVALID_IO_LAYER) {
   1.447 +    transport_layer_identity = PR_GetUniqueIdentity("nssstreamadapter");
   1.448 +  }
   1.449 +
   1.450 +  ScopedPRFileDesc pr_fd(PR_CreateIOLayerStub(transport_layer_identity,
   1.451 +                                              &TransportLayerMethods));
   1.452 +  MOZ_ASSERT(pr_fd != nullptr);
   1.453 +  if (!pr_fd)
   1.454 +    return false;
   1.455 +  pr_fd->secret = reinterpret_cast<PRFilePrivate *>(nspr_io_adapter_.get());
   1.456 +
   1.457 +  ScopedPRFileDesc ssl_fd;
   1.458 +  if (mode_ == DGRAM) {
   1.459 +    ssl_fd = DTLS_ImportFD(nullptr, pr_fd);
   1.460 +  } else {
   1.461 +    ssl_fd = SSL_ImportFD(nullptr, pr_fd);
   1.462 +  }
   1.463 +
   1.464 +  MOZ_ASSERT(ssl_fd != nullptr);  // This should never happen
   1.465 +  if (!ssl_fd) {
   1.466 +    return false;
   1.467 +  }
   1.468 +
   1.469 +  pr_fd.forget(); // ownership transfered to ssl_fd;
   1.470 +
   1.471 +  if (role_ == CLIENT) {
   1.472 +    MOZ_MTLOG(ML_DEBUG, "Setting up DTLS as client");
   1.473 +    rv = SSL_GetClientAuthDataHook(ssl_fd, GetClientAuthDataHook,
   1.474 +                                   this);
   1.475 +    if (rv != SECSuccess) {
   1.476 +      MOZ_MTLOG(ML_ERROR, "Couldn't set identity");
   1.477 +      return false;
   1.478 +    }
   1.479 +  } else {
   1.480 +    MOZ_MTLOG(ML_DEBUG, "Setting up DTLS as server");
   1.481 +    // Server side
   1.482 +    rv = SSL_ConfigSecureServer(ssl_fd, identity_->cert(),
   1.483 +                                identity_->privkey(),
   1.484 +                                kt_rsa);
   1.485 +    if (rv != SECSuccess) {
   1.486 +      MOZ_MTLOG(ML_ERROR, "Couldn't set identity");
   1.487 +      return false;
   1.488 +    }
   1.489 +
   1.490 +    // Insist on a certificate from the client
   1.491 +    rv = SSL_OptionSet(ssl_fd, SSL_REQUEST_CERTIFICATE, PR_TRUE);
   1.492 +    if (rv != SECSuccess) {
   1.493 +      MOZ_MTLOG(ML_ERROR, "Couldn't request certificate");
   1.494 +      return false;
   1.495 +    }
   1.496 +
   1.497 +    rv = SSL_OptionSet(ssl_fd, SSL_REQUIRE_CERTIFICATE, PR_TRUE);
   1.498 +    if (rv != SECSuccess) {
   1.499 +      MOZ_MTLOG(ML_ERROR, "Couldn't require certificate");
   1.500 +      return false;
   1.501 +    }
   1.502 +  }
   1.503 +
   1.504 +  // Require TLS 1.1. Perhaps some day in the future we will allow
   1.505 +  // TLS 1.0 for stream modes.
   1.506 +  SSLVersionRange version_range = {
   1.507 +    SSL_LIBRARY_VERSION_TLS_1_1,
   1.508 +    SSL_LIBRARY_VERSION_TLS_1_1
   1.509 +  };
   1.510 +
   1.511 +  rv = SSL_VersionRangeSet(ssl_fd, &version_range);
   1.512 +  if (rv != SECSuccess) {
   1.513 +    MOZ_MTLOG(ML_ERROR, "Can't disable SSLv3");
   1.514 +    return false;
   1.515 +  }
   1.516 +
   1.517 +  rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_SESSION_TICKETS, PR_FALSE);
   1.518 +  if (rv != SECSuccess) {
   1.519 +    MOZ_MTLOG(ML_ERROR, "Couldn't disable session tickets");
   1.520 +    return false;
   1.521 +  }
   1.522 +
   1.523 +  rv = SSL_OptionSet(ssl_fd, SSL_NO_CACHE, PR_TRUE);
   1.524 +  if (rv != SECSuccess) {
   1.525 +    MOZ_MTLOG(ML_ERROR, "Couldn't disable session caching");
   1.526 +    return false;
   1.527 +  }
   1.528 +
   1.529 +  rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_DEFLATE, PR_FALSE);
   1.530 +  if (rv != SECSuccess) {
   1.531 +    MOZ_MTLOG(ML_ERROR, "Couldn't disable deflate");
   1.532 +    return false;
   1.533 +  }
   1.534 +
   1.535 +  rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_RENEGOTIATION, SSL_RENEGOTIATE_NEVER);
   1.536 +  if (rv != SECSuccess) {
   1.537 +    MOZ_MTLOG(ML_ERROR, "Couldn't disable renegotiation");
   1.538 +    return false;
   1.539 +  }
   1.540 +
   1.541 +  rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_FALSE_START, PR_FALSE);
   1.542 +  if (rv != SECSuccess) {
   1.543 +    MOZ_MTLOG(ML_ERROR, "Couldn't disable false start");
   1.544 +    return false;
   1.545 +  }
   1.546 +
   1.547 +  rv = SSL_OptionSet(ssl_fd, SSL_NO_LOCKS, PR_TRUE);
   1.548 +  if (rv != SECSuccess) {
   1.549 +    MOZ_MTLOG(ML_ERROR, "Couldn't disable locks");
   1.550 +    return false;
   1.551 +  }
   1.552 +
   1.553 +  // Set the SRTP ciphers
   1.554 +  if (srtp_ciphers_.size()) {
   1.555 +    // Note: std::vector is guaranteed to contiguous
   1.556 +    rv = SSL_SetSRTPCiphers(ssl_fd, &srtp_ciphers_[0],
   1.557 +                            srtp_ciphers_.size());
   1.558 +
   1.559 +    if (rv != SECSuccess) {
   1.560 +      MOZ_MTLOG(ML_ERROR, "Couldn't set SRTP cipher suite");
   1.561 +      return false;
   1.562 +    }
   1.563 +  }
   1.564 +
   1.565 +  // Certificate validation
   1.566 +  rv = SSL_AuthCertificateHook(ssl_fd, AuthCertificateHook,
   1.567 +                               reinterpret_cast<void *>(this));
   1.568 +  if (rv != SECSuccess) {
   1.569 +    MOZ_MTLOG(ML_ERROR, "Couldn't set certificate validation hook");
   1.570 +    return false;
   1.571 +  }
   1.572 +
   1.573 +  // Now start the handshake
   1.574 +  rv = SSL_ResetHandshake(ssl_fd, role_ == SERVER ? PR_TRUE : PR_FALSE);
   1.575 +  if (rv != SECSuccess) {
   1.576 +    MOZ_MTLOG(ML_ERROR, "Couldn't reset handshake");
   1.577 +    return false;
   1.578 +  }
   1.579 +  ssl_fd_ = ssl_fd.forget();
   1.580 +
   1.581 +  // Finally, get ready to receive data
   1.582 +  downward_->SignalStateChange.connect(this, &TransportLayerDtls::StateChange);
   1.583 +  downward_->SignalPacketReceived.connect(this, &TransportLayerDtls::PacketReceived);
   1.584 +
   1.585 +  if (downward_->state() == TS_OPEN) {
   1.586 +    Handshake();
   1.587 +  }
   1.588 +
   1.589 +  return true;
   1.590 +}
   1.591 +
   1.592 +
   1.593 +void TransportLayerDtls::StateChange(TransportLayer *layer, State state) {
   1.594 +  if (state <= state_) {
   1.595 +    MOZ_MTLOG(ML_ERROR, "Lower layer state is going backwards from ours");
   1.596 +    TL_SET_STATE(TS_ERROR);
   1.597 +    return;
   1.598 +  }
   1.599 +
   1.600 +  switch (state) {
   1.601 +    case TS_NONE:
   1.602 +      MOZ_ASSERT(false);  // Can't happen
   1.603 +      break;
   1.604 +
   1.605 +    case TS_INIT:
   1.606 +      MOZ_MTLOG(ML_ERROR,
   1.607 +                LAYER_INFO << "State change of lower layer to INIT forbidden");
   1.608 +      TL_SET_STATE(TS_ERROR);
   1.609 +      break;
   1.610 +
   1.611 +    case TS_CONNECTING:
   1.612 +      MOZ_MTLOG(ML_ERROR, LAYER_INFO << "Lower lower is connecting.");
   1.613 +      break;
   1.614 +
   1.615 +    case TS_OPEN:
   1.616 +      MOZ_MTLOG(ML_ERROR,
   1.617 +                LAYER_INFO << "Lower lower is now open; starting TLS");
   1.618 +      Handshake();
   1.619 +      break;
   1.620 +
   1.621 +    case TS_CLOSED:
   1.622 +      MOZ_MTLOG(ML_ERROR, LAYER_INFO << "Lower lower is now closed");
   1.623 +      TL_SET_STATE(TS_CLOSED);
   1.624 +      break;
   1.625 +
   1.626 +    case TS_ERROR:
   1.627 +      MOZ_MTLOG(ML_ERROR, LAYER_INFO << "Lower lower experienced an error");
   1.628 +      TL_SET_STATE(TS_ERROR);
   1.629 +      break;
   1.630 +  }
   1.631 +}
   1.632 +
   1.633 +void TransportLayerDtls::Handshake() {
   1.634 +  TL_SET_STATE(TS_CONNECTING);
   1.635 +
   1.636 +  // Clear the retransmit timer
   1.637 +  timer_->Cancel();
   1.638 +
   1.639 +  SECStatus rv = SSL_ForceHandshake(ssl_fd_);
   1.640 +
   1.641 +  if (rv == SECSuccess) {
   1.642 +    MOZ_MTLOG(ML_NOTICE,
   1.643 +              LAYER_INFO << "****** SSL handshake completed ******");
   1.644 +    if (!cert_ok_) {
   1.645 +      MOZ_MTLOG(ML_ERROR, LAYER_INFO << "Certificate check never occurred");
   1.646 +      TL_SET_STATE(TS_ERROR);
   1.647 +      return;
   1.648 +    }
   1.649 +    TL_SET_STATE(TS_OPEN);
   1.650 +  } else {
   1.651 +    int32_t err = PR_GetError();
   1.652 +    switch(err) {
   1.653 +      case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
   1.654 +        if (mode_ != DGRAM) {
   1.655 +          MOZ_MTLOG(ML_ERROR, LAYER_INFO << "Malformed TLS message");
   1.656 +          TL_SET_STATE(TS_ERROR);
   1.657 +        } else {
   1.658 +          MOZ_MTLOG(ML_ERROR, LAYER_INFO << "Malformed DTLS message; ignoring");
   1.659 +        }
   1.660 +        // Fall through
   1.661 +      case PR_WOULD_BLOCK_ERROR:
   1.662 +        MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "Would have blocked");
   1.663 +        if (mode_ == DGRAM) {
   1.664 +          PRIntervalTime timeout;
   1.665 +          rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout);
   1.666 +          if (rv == SECSuccess) {
   1.667 +            uint32_t timeout_ms = PR_IntervalToMilliseconds(timeout);
   1.668 +
   1.669 +            MOZ_MTLOG(ML_DEBUG, LAYER_INFO << "Setting DTLS timeout to " <<
   1.670 +                 timeout_ms);
   1.671 +            timer_->SetTarget(target_);
   1.672 +            timer_->InitWithFuncCallback(TimerCallback,
   1.673 +                                         this, timeout_ms,
   1.674 +                                         nsITimer::TYPE_ONE_SHOT);
   1.675 +          }
   1.676 +        }
   1.677 +        break;
   1.678 +      default:
   1.679 +        MOZ_MTLOG(ML_ERROR, LAYER_INFO << "SSL handshake error "<< err);
   1.680 +        TL_SET_STATE(TS_ERROR);
   1.681 +        break;
   1.682 +    }
   1.683 +  }
   1.684 +}
   1.685 +
   1.686 +void TransportLayerDtls::PacketReceived(TransportLayer* layer,
   1.687 +                                        const unsigned char *data,
   1.688 +                                        size_t len) {
   1.689 +  CheckThread();
   1.690 +  MOZ_MTLOG(ML_DEBUG, LAYER_INFO << "PacketReceived(" << len << ")");
   1.691 +
   1.692 +  if (state_ != TS_CONNECTING && state_ != TS_OPEN) {
   1.693 +    MOZ_MTLOG(ML_DEBUG,
   1.694 +              LAYER_INFO << "Discarding packet in inappropriate state");
   1.695 +    return;
   1.696 +  }
   1.697 +
   1.698 +  nspr_io_adapter_->PacketReceived(data, len);
   1.699 +
   1.700 +  // If we're still connecting, try to handshake
   1.701 +  if (state_ == TS_CONNECTING) {
   1.702 +    Handshake();
   1.703 +  }
   1.704 +
   1.705 +  // Now try a recv if we're open, since there might be data left
   1.706 +  if (state_ == TS_OPEN) {
   1.707 +    unsigned char buf[2000];
   1.708 +
   1.709 +    int32_t rv = PR_Recv(ssl_fd_, buf, sizeof(buf), 0, PR_INTERVAL_NO_WAIT);
   1.710 +    if (rv > 0) {
   1.711 +      // We have data
   1.712 +      MOZ_MTLOG(ML_DEBUG, LAYER_INFO << "Read " << rv << " bytes from NSS");
   1.713 +      SignalPacketReceived(this, buf, rv);
   1.714 +    } else if (rv == 0) {
   1.715 +      TL_SET_STATE(TS_CLOSED);
   1.716 +    } else {
   1.717 +      int32_t err = PR_GetError();
   1.718 +
   1.719 +      if (err == PR_WOULD_BLOCK_ERROR) {
   1.720 +        // This gets ignored
   1.721 +        MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "Would have blocked");
   1.722 +      } else {
   1.723 +        MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "NSS Error " << err);
   1.724 +        TL_SET_STATE(TS_ERROR);
   1.725 +      }
   1.726 +    }
   1.727 +  }
   1.728 +}
   1.729 +
   1.730 +TransportResult TransportLayerDtls::SendPacket(const unsigned char *data,
   1.731 +                                               size_t len) {
   1.732 +  CheckThread();
   1.733 +  if (state_ != TS_OPEN) {
   1.734 +    MOZ_MTLOG(ML_ERROR, LAYER_INFO << "Can't call SendPacket() in state "
   1.735 +              << state_);
   1.736 +    return TE_ERROR;
   1.737 +  }
   1.738 +
   1.739 +  int32_t rv = PR_Send(ssl_fd_, data, len, 0, PR_INTERVAL_NO_WAIT);
   1.740 +
   1.741 +  if (rv > 0) {
   1.742 +    // We have data
   1.743 +    MOZ_MTLOG(ML_DEBUG, LAYER_INFO << "Wrote " << rv << " bytes to SSL Layer");
   1.744 +    return rv;
   1.745 +  }
   1.746 +
   1.747 +  if (rv == 0) {
   1.748 +    TL_SET_STATE(TS_CLOSED);
   1.749 +    return 0;
   1.750 +  }
   1.751 +
   1.752 +  int32_t err = PR_GetError();
   1.753 +
   1.754 +  if (err == PR_WOULD_BLOCK_ERROR) {
   1.755 +    // This gets ignored
   1.756 +    MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "Would have blocked");
   1.757 +    return TE_WOULDBLOCK;
   1.758 +  }
   1.759 +
   1.760 +  MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "NSS Error " << err);
   1.761 +  TL_SET_STATE(TS_ERROR);
   1.762 +  return TE_ERROR;
   1.763 +}
   1.764 +
   1.765 +SECStatus TransportLayerDtls::GetClientAuthDataHook(void *arg, PRFileDesc *fd,
   1.766 +                                                    CERTDistNames *caNames,
   1.767 +                                                    CERTCertificate **pRetCert,
   1.768 +                                                    SECKEYPrivateKey **pRetKey) {
   1.769 +  MOZ_MTLOG(ML_DEBUG, "Server requested client auth");
   1.770 +
   1.771 +  TransportLayerDtls *stream = reinterpret_cast<TransportLayerDtls *>(arg);
   1.772 +  stream->CheckThread();
   1.773 +
   1.774 +  if (!stream->identity_) {
   1.775 +    MOZ_MTLOG(ML_ERROR, "No identity available");
   1.776 +    PR_SetError(SSL_ERROR_NO_CERTIFICATE, 0);
   1.777 +    return SECFailure;
   1.778 +  }
   1.779 +
   1.780 +  *pRetCert = CERT_DupCertificate(stream->identity_->cert());
   1.781 +  if (!*pRetCert) {
   1.782 +    PR_SetError(PR_OUT_OF_MEMORY_ERROR, 0);
   1.783 +    return SECFailure;
   1.784 +  }
   1.785 +
   1.786 +  *pRetKey = SECKEY_CopyPrivateKey(stream->identity_->privkey());
   1.787 +  if (!*pRetKey) {
   1.788 +    CERT_DestroyCertificate(*pRetCert);
   1.789 +    *pRetCert = nullptr;
   1.790 +    PR_SetError(PR_OUT_OF_MEMORY_ERROR, 0);
   1.791 +    return SECFailure;
   1.792 +  }
   1.793 +
   1.794 +  return SECSuccess;
   1.795 +}
   1.796 +
   1.797 +nsresult TransportLayerDtls::SetSrtpCiphers(std::vector<uint16_t> ciphers) {
   1.798 +  // TODO: We should check these
   1.799 +  srtp_ciphers_ = ciphers;
   1.800 +
   1.801 +  return NS_OK;
   1.802 +}
   1.803 +
   1.804 +nsresult TransportLayerDtls::GetSrtpCipher(uint16_t *cipher) {
   1.805 +  CheckThread();
   1.806 +  SECStatus rv = SSL_GetSRTPCipher(ssl_fd_, cipher);
   1.807 +  if (rv != SECSuccess) {
   1.808 +    MOZ_MTLOG(ML_DEBUG, "No SRTP cipher negotiated");
   1.809 +    return NS_ERROR_FAILURE;
   1.810 +  }
   1.811 +
   1.812 +  return NS_OK;
   1.813 +}
   1.814 +
   1.815 +nsresult TransportLayerDtls::ExportKeyingMaterial(const std::string& label,
   1.816 +                                                  bool use_context,
   1.817 +                                                  const std::string& context,
   1.818 +                                                  unsigned char *out,
   1.819 +                                                  unsigned int outlen) {
   1.820 +  CheckThread();
   1.821 +  SECStatus rv = SSL_ExportKeyingMaterial(ssl_fd_,
   1.822 +                                          label.c_str(),
   1.823 +                                          label.size(),
   1.824 +                                          use_context,
   1.825 +                                          reinterpret_cast<const unsigned char *>(
   1.826 +                                              context.c_str()),
   1.827 +                                          context.size(),
   1.828 +                                          out,
   1.829 +                                          outlen);
   1.830 +  if (rv != SECSuccess) {
   1.831 +    MOZ_MTLOG(ML_ERROR, "Couldn't export SSL keying material");
   1.832 +    return NS_ERROR_FAILURE;
   1.833 +  }
   1.834 +
   1.835 +  return NS_OK;
   1.836 +}
   1.837 +
   1.838 +SECStatus TransportLayerDtls::AuthCertificateHook(void *arg,
   1.839 +                                                  PRFileDesc *fd,
   1.840 +                                                  PRBool checksig,
   1.841 +                                                  PRBool isServer) {
   1.842 +  TransportLayerDtls *stream = reinterpret_cast<TransportLayerDtls *>(arg);
   1.843 +  stream->CheckThread();
   1.844 +  return stream->AuthCertificateHook(fd, checksig, isServer);
   1.845 +}
   1.846 +
   1.847 +SECStatus
   1.848 +TransportLayerDtls::CheckDigest(const RefPtr<VerificationDigest>&
   1.849 +                                digest,
   1.850 +                                CERTCertificate *peer_cert) {
   1.851 +  unsigned char computed_digest[kMaxDigestLength];
   1.852 +  size_t computed_digest_len;
   1.853 +
   1.854 +  MOZ_MTLOG(ML_DEBUG, LAYER_INFO << "Checking digest, algorithm="
   1.855 +            << digest->algorithm_);
   1.856 +  nsresult res =
   1.857 +      DtlsIdentity::ComputeFingerprint(peer_cert,
   1.858 +                                       digest->algorithm_,
   1.859 +                                       computed_digest,
   1.860 +                                       sizeof(computed_digest),
   1.861 +                                       &computed_digest_len);
   1.862 +  if (NS_FAILED(res)) {
   1.863 +    MOZ_MTLOG(ML_ERROR, "Could not compute peer fingerprint for digest " <<
   1.864 +              digest->algorithm_);
   1.865 +    // Go to end
   1.866 +    PR_SetError(SSL_ERROR_BAD_CERTIFICATE, 0);
   1.867 +    return SECFailure;
   1.868 +  }
   1.869 +
   1.870 +  if (computed_digest_len != digest->len_) {
   1.871 +    MOZ_MTLOG(ML_ERROR, "Digest is wrong length " << digest->len_ <<
   1.872 +              " should be " << computed_digest_len << " for algorithm " <<
   1.873 +              digest->algorithm_);
   1.874 +    PR_SetError(SSL_ERROR_BAD_CERTIFICATE, 0);
   1.875 +    return SECFailure;
   1.876 +  }
   1.877 +
   1.878 +  if (memcmp(digest->value_, computed_digest, computed_digest_len) != 0) {
   1.879 +    MOZ_MTLOG(ML_ERROR, "Digest does not match");
   1.880 +    PR_SetError(SSL_ERROR_BAD_CERTIFICATE, 0);
   1.881 +    return SECFailure;
   1.882 +  }
   1.883 +
   1.884 +  return SECSuccess;
   1.885 +}
   1.886 +
   1.887 +
   1.888 +SECStatus TransportLayerDtls::AuthCertificateHook(PRFileDesc *fd,
   1.889 +                                                  PRBool checksig,
   1.890 +                                                  PRBool isServer) {
   1.891 +  CheckThread();
   1.892 +  ScopedCERTCertificate peer_cert;
   1.893 +  peer_cert = SSL_PeerCertificate(fd);
   1.894 +
   1.895 +
   1.896 +  // We are not set up to take this being called multiple
   1.897 +  // times. Change this if we ever add renegotiation.
   1.898 +  MOZ_ASSERT(!auth_hook_called_);
   1.899 +  if (auth_hook_called_) {
   1.900 +    PR_SetError(PR_UNKNOWN_ERROR, 0);
   1.901 +    return SECFailure;
   1.902 +  }
   1.903 +  auth_hook_called_ = true;
   1.904 +
   1.905 +  MOZ_ASSERT(verification_mode_ != VERIFY_UNSET);
   1.906 +  MOZ_ASSERT(peer_cert_ == nullptr);
   1.907 +
   1.908 +  switch (verification_mode_) {
   1.909 +    case VERIFY_UNSET:
   1.910 +      // Break out to error exit
   1.911 +      PR_SetError(PR_UNKNOWN_ERROR, 0);
   1.912 +      break;
   1.913 +
   1.914 +    case VERIFY_ALLOW_ALL:
   1.915 +      peer_cert_ = peer_cert.forget();
   1.916 +      cert_ok_ = true;
   1.917 +      return SECSuccess;
   1.918 +
   1.919 +    case VERIFY_DIGEST:
   1.920 +      {
   1.921 +        MOZ_ASSERT(digests_.size() != 0);
   1.922 +        // Check all the provided digests
   1.923 +
   1.924 +        // Checking functions call PR_SetError()
   1.925 +        SECStatus rv = SECFailure;
   1.926 +        for (size_t i = 0; i < digests_.size(); i++) {
   1.927 +          RefPtr<VerificationDigest> digest = digests_[i];
   1.928 +          rv = CheckDigest(digest, peer_cert);
   1.929 +
   1.930 +          if (rv != SECSuccess)
   1.931 +            break;
   1.932 +        }
   1.933 +
   1.934 +        if (rv == SECSuccess) {
   1.935 +          // Matches all digests, we are good to go
   1.936 +          cert_ok_ = true;
   1.937 +          peer_cert = peer_cert.forget();
   1.938 +          return SECSuccess;
   1.939 +        }
   1.940 +      }
   1.941 +      break;
   1.942 +    default:
   1.943 +      MOZ_CRASH();  // Can't happen
   1.944 +  }
   1.945 +
   1.946 +  return SECFailure;
   1.947 +}
   1.948 +
   1.949 +void TransportLayerDtls::TimerCallback(nsITimer *timer, void *arg) {
   1.950 +  TransportLayerDtls *dtls = reinterpret_cast<TransportLayerDtls *>(arg);
   1.951 +
   1.952 +  MOZ_MTLOG(ML_DEBUG, "DTLS timer expired");
   1.953 +
   1.954 +  dtls->Handshake();
   1.955 +}
   1.956 +
   1.957 +}  // close namespace

mercurial