# Copyright 2011, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


"""Common functions and exceptions used by WebSocket opening handshake
processors.
"""


from mod_pywebsocket import common
from mod_pywebsocket import http_header_util


class AbortedByUserException(Exception):
    """Exception for aborting a connection intentionally.

    If this exception is raised in do_extra_handshake handler, the connection
    will be abandoned. No other WebSocket or HTTP(S) handler will be invoked.

    If this exception is raised in transfer_data_handler, the connection will
    be closed without closing handshake. No other WebSocket or HTTP(S) handler
    will be invoked.
    """

    pass


class HandshakeException(Exception):
    """This exception will be raised when an error occurred while processing
    WebSocket initial handshake.
    """

    def __init__(self, name, status=None):
        """Construct an instance.

        Args:
            name: message describing what went wrong.
            status: optional status stored on the instance for whoever
                handles the exception (presumably the HTTP status code to
                respond with -- confirm with callers); None if unspecified.
        """
        super(HandshakeException, self).__init__(name)
        self.status = status


class VersionException(Exception):
    """This exception will be raised when a version of client request does not
    match with version the server supports.
    """

    def __init__(self, name, supported_versions=''):
        """Construct an instance.

        Args:
            name: message describing the version mismatch.
            supported_versions: a str object to show supported hybi versions.
                                (e.g. '8, 13')
        """
        super(VersionException, self).__init__(name)
        self.supported_versions = supported_versions


def get_default_port(is_secure):
    """Return the default WebSocket port for a secure or plain connection.

    Args:
        is_secure: true for a secure connection (callers in this module pass
            request.is_https()).

    Returns:
        common.DEFAULT_WEB_SOCKET_SECURE_PORT when is_secure is true,
        otherwise common.DEFAULT_WEB_SOCKET_PORT.
    """
    if is_secure:
        return common.DEFAULT_WEB_SOCKET_SECURE_PORT
    else:
        return common.DEFAULT_WEB_SOCKET_PORT


def validate_subprotocol(subprotocol, hixie):
    """Validate a value in subprotocol fields such as WebSocket-Protocol,
    Sec-WebSocket-Protocol.

    See
    - RFC 6455: Section 4.1., 4.2.2., and 4.3.
    - HyBi 00: Section 4.1. Opening handshake
    - Hixie 75: Section 4.1. Handshake

    Args:
        subprotocol: the subprotocol name to validate.
        hixie: if true, apply the Hixie/HyBi 00 rule (printable ASCII);
            otherwise require a single HTTP token.

    Raises:
        HandshakeException: the name is empty or violates the selected rule.
    """

    if not subprotocol:
        raise HandshakeException('Invalid subprotocol name: empty')
    if hixie:
        # Parameter should be in the range U+0020 to U+007E.
        for c in subprotocol:
            if not 0x20 <= ord(c) <= 0x7e:
                raise HandshakeException(
                    'Illegal character in subprotocol name: %r' % c)
    else:
        # Parameter should be encoded HTTP token.
        state = http_header_util.ParsingState(subprotocol)
        token = http_header_util.consume_token(state)
        rest = http_header_util.peek(state)
        # If |rest| is not None, |subprotocol| is not one token or invalid. If
        # |rest| is None, |token| must not be None because |subprotocol| is
        # concatenation of |token| and |rest| and is not None.
        if rest is not None:
            raise HandshakeException('Invalid non-token string in subprotocol '
                                     'name: %r' % rest)


def parse_host_header(request):
    """Split the Host header of |request| into a (host, port) pair.

    When the header carries no port, the default port for the request's
    scheme (see get_default_port) is returned. A bracketed IPv6 literal such
    as "[::1]:8080" (RFC 3986 IP-literal) keeps its brackets in the host part
    so that it can be put back into a URI unchanged.

    Returns:
        (host, port) where host is a str and port an int.

    Raises:
        HandshakeException: the Host header is missing, malformed after an
            IPv6 literal, or its port part is not an integer.
    """
    # Use the shared helper so a missing Host header is reported as a
    # HandshakeException rather than escaping as a raw KeyError.
    value = get_mandatory_header(request, 'Host')

    if value.startswith('[') and ']' in value:
        # IPv6 literal: the address itself contains ':' so only look for a
        # port after the closing bracket.
        host_end = value.index(']') + 1
        host, rest = value[:host_end], value[host_end:]
        if not rest:
            return host, get_default_port(request.is_https())
        if not rest.startswith(':'):
            raise HandshakeException('Invalid Host header: %r' % value)
        port_text = rest[1:]
    else:
        fields = value.split(':', 1)
        if len(fields) == 1:
            return fields[0], get_default_port(request.is_https())
        host, port_text = fields

    try:
        return host, int(port_text)
    except ValueError as e:
        raise HandshakeException('Invalid port number format: %r' % e)


def format_header(name, value):
    """Return the header line '<name>: <value>' terminated by CRLF."""
    return '%s: %s\r\n' % (name, value)


def build_location(request):
    """Build WebSocket location for request.

    The result is <scheme>://<host>[:<port>]<request.uri>, where the scheme
    is common.WEB_SOCKET_SECURE_SCHEME for HTTPS requests and
    common.WEB_SOCKET_SCHEME otherwise, and ':<port>' is omitted when the
    port equals the scheme's default port.

    Raises:
        HandshakeException: the Host header cannot be parsed, or its port
            differs from the local port of request.connection.
    """
    is_secure = request.is_https()
    location_parts = []
    if is_secure:
        location_parts.append(common.WEB_SOCKET_SECURE_SCHEME)
    else:
        location_parts.append(common.WEB_SOCKET_SCHEME)
    location_parts.append('://')
    host, port = parse_host_header(request)
    connection_port = request.connection.local_addr[1]
    if port != connection_port:
        raise HandshakeException('Header/connection port mismatch: %d/%d' %
                                 (port, connection_port))
    location_parts.append(host)
    if port != get_default_port(is_secure):
        location_parts.append(':')
        location_parts.append(str(port))
    location_parts.append(request.uri)
    return ''.join(location_parts)


def get_mandatory_header(request, key):
    """Return the value of header |key| from request.headers_in.

    Raises:
        HandshakeException: the header is not present.
    """
    value = request.headers_in.get(key)
    if value is None:
        raise HandshakeException('Header %s is not defined' % key)
    return value


def validate_mandatory_header(request, key, expected_value, fail_status=None):
    """Check that header |key| is present and equals |expected_value|.

    The comparison is case-insensitive.

    Raises:
        HandshakeException: the header is missing (raised without a status),
            or its value differs (raised with status=fail_status).
    """
    value = get_mandatory_header(request, key)

    if value.lower() != expected_value.lower():
        raise HandshakeException(
            'Expected %r for header %s but found %r (case-insensitive)' %
            (expected_value, key, value), status=fail_status)


def check_request_line(request):
    """Raise HandshakeException unless the request method is GET."""
    # 5.1 1. The three character UTF-8 string "GET".
    # 5.1 2. A UTF-8-encoded U+0020 SPACE character (0x20 byte).
    if request.method != 'GET':
        raise HandshakeException('Method is not GET')


def check_header_lines(request, mandatory_headers):
    """Validate the request method and a set of required header values.

    Args:
        request: the request to check.
        mandatory_headers: iterable of (header name, expected value) pairs;
            each header must be present and match case-insensitively.

    Raises:
        HandshakeException: any of the checks fails.
    """
    check_request_line(request)

    # The expected field names, and the meaning of their corresponding
    # values, are as follows.
    #  |Upgrade| and |Connection|
    for key, expected_value in mandatory_headers:
        validate_mandatory_header(request, key, expected_value)


def parse_token_list(data):
    """Parses a header value which follows 1#token and returns parsed elements
    as a list of strings.

    Leading LWSes must be trimmed.

    Raises:
        HandshakeException: a non-comma separator is found between elements,
            or the value contains no token at all.
    """

    state = http_header_util.ParsingState(data)

    token_list = []

    while True:
        token = http_header_util.consume_token(state)
        # Empty list elements (e.g. "a,,b") are skipped rather than rejected.
        if token is not None:
            token_list.append(token)

        http_header_util.consume_lwses(state)

        if http_header_util.peek(state) is None:
            break

        if not http_header_util.consume_string(state, ','):
            raise HandshakeException(
                'Expected a comma but found %r' % http_header_util.peek(state))

        http_header_util.consume_lwses(state)

    if not token_list:
        raise HandshakeException('No valid token found')

    return token_list


def _parse_extension_param(state, definition, allow_quoted_string):
    """Parse one 'name' or 'name=value' parameter and add it to |definition|.

    A parameter without '=' is added with value None. When
    |allow_quoted_string| is true the value may be a token or a
    quoted-string; otherwise it must be a token.

    Raises:
        HandshakeException: no parameter name is found, or '=' is not
            followed by a valid value.
    """
    param_name = http_header_util.consume_token(state)

    if param_name is None:
        raise HandshakeException('No valid parameter name found')

    http_header_util.consume_lwses(state)

    if not http_header_util.consume_string(state, '='):
        definition.add_parameter(param_name, None)
        return

    http_header_util.consume_lwses(state)

    if allow_quoted_string:
        # TODO(toyoshim): Add code to validate that parsed param_value is token
        param_value = http_header_util.consume_token_or_quoted_string(state)
    else:
        param_value = http_header_util.consume_token(state)
    if param_value is None:
        raise HandshakeException(
            'No valid parameter value found on the right-hand side of '
            'parameter %r' % param_name)

    definition.add_parameter(param_name, param_value)


def _parse_extension(state, allow_quoted_string):
    """Parse one extension (a token followed by ';'-separated parameters).

    Returns:
        A common.ExtensionParameter, or None when no extension token is found
        at the current position.

    Raises:
        HandshakeException: a parameter of the extension is malformed.
    """
    extension_token = http_header_util.consume_token(state)
    if extension_token is None:
        return None

    extension = common.ExtensionParameter(extension_token)

    while True:
        http_header_util.consume_lwses(state)

        if not http_header_util.consume_string(state, ';'):
            break

        http_header_util.consume_lwses(state)

        try:
            _parse_extension_param(state, extension, allow_quoted_string)
        except HandshakeException as e:
            raise HandshakeException(
                'Failed to parse Sec-WebSocket-Extensions header: '
                'Failed to parse parameter for %r (%r)' %
                (extension_token, e))

    return extension


def parse_extensions(data, allow_quoted_string=False):
    """Parses Sec-WebSocket-Extensions header value returns a list of
    common.ExtensionParameter objects.

    Leading LWSes must be trimmed.

    Raises:
        HandshakeException: the value is malformed or contains no extension.
    """

    state = http_header_util.ParsingState(data)

    extension_list = []
    while True:
        extension = _parse_extension(state, allow_quoted_string)
        # Empty list elements are skipped rather than rejected.
        if extension is not None:
            extension_list.append(extension)

        http_header_util.consume_lwses(state)

        if http_header_util.peek(state) is None:
            break

        if not http_header_util.consume_string(state, ','):
            raise HandshakeException(
                'Failed to parse Sec-WebSocket-Extensions header: '
                'Expected a comma but found %r' %
                http_header_util.peek(state))

        http_header_util.consume_lwses(state)

    if not extension_list:
        raise HandshakeException(
            'Sec-WebSocket-Extensions header contains no valid extension')

    return extension_list


def format_extensions(extension_list):
    """Format ExtensionParameter objects as a Sec-WebSocket-Extensions value.

    Parameters are joined with '; ' after the extension name and extensions
    with ', '. Parameters whose value is None are written as a bare name;
    other values go through http_header_util.quote_if_necessary (presumably
    quoting values that are not plain tokens -- see that helper).
    """
    formatted_extension_list = []
    for extension in extension_list:
        formatted_params = [extension.name()]
        for param_name, param_value in extension.get_parameters():
            if param_value is None:
                formatted_params.append(param_name)
            else:
                quoted_value = http_header_util.quote_if_necessary(param_value)
                formatted_params.append('%s=%s' % (param_name, quoted_value))

        formatted_extension_list.append('; '.join(formatted_params))

    return ', '.join(formatted_extension_list)


# vi:sts=4 sw=4 et