|
1 # Copyright 2011, Google Inc. |
|
2 # All rights reserved. |
|
3 # |
|
4 # Redistribution and use in source and binary forms, with or without |
|
5 # modification, are permitted provided that the following conditions are |
|
6 # met: |
|
7 # |
|
8 # * Redistributions of source code must retain the above copyright |
|
9 # notice, this list of conditions and the following disclaimer. |
|
10 # * Redistributions in binary form must reproduce the above |
|
11 # copyright notice, this list of conditions and the following disclaimer |
|
12 # in the documentation and/or other materials provided with the |
|
13 # distribution. |
|
14 # * Neither the name of Google Inc. nor the names of its |
|
15 # contributors may be used to endorse or promote products derived from |
|
16 # this software without specific prior written permission. |
|
17 # |
|
18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
|
19 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
|
20 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
|
21 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
|
22 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
|
23 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
|
24 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
|
25 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
|
26 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
|
27 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
|
28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
|
29 |
|
30 |
|
31 """Common functions and exceptions used by WebSocket opening handshake |
|
32 processors. |
|
33 """ |
|
34 |
|
35 |
|
36 from mod_pywebsocket import common |
|
37 from mod_pywebsocket import http_header_util |
|
38 |
|
39 |
|
class AbortedByUserException(Exception):
    """Raised by a handler to drop the connection on purpose.

    Raised from a do_extra_handshake handler: the connection is abandoned
    and no further WebSocket or HTTP(S) handler runs.

    Raised from a transfer_data_handler: the connection is closed without
    performing the closing handshake, and again no further WebSocket or
    HTTP(S) handler runs.
    """
|
52 |
|
53 |
|
class HandshakeException(Exception):
    """Signals a failure while processing the WebSocket opening handshake.

    Attributes:
        status: optional status value supplied by the raiser (None when not
            given). NOTE(review): presumably the HTTP status code to send in
            the error response — confirm against the callers that read it.
    """

    def __init__(self, name, status=None):
        # The message goes to Exception so str(exc) / exc.args report it.
        super(HandshakeException, self).__init__(name)
        self.status = status
|
62 |
|
63 |
|
class VersionException(Exception):
    """Raised when the protocol version requested by the client does not
    match any version the server supports.
    """

    def __init__(self, name, supported_versions=''):
        """Construct an instance.

        Args:
            name: the exception message.
            supported_versions: a str listing the supported hybi versions,
                e.g. '8, 13'. Defaults to an empty string.
        """
        super(VersionException, self).__init__(name)
        self.supported_versions = supported_versions
|
78 |
|
79 |
|
def get_default_port(is_secure):
    """Return the default WebSocket port for the given connection kind.

    Args:
        is_secure: truthy for a secure connection.

    Returns:
        common.DEFAULT_WEB_SOCKET_SECURE_PORT when is_secure is truthy,
        otherwise common.DEFAULT_WEB_SOCKET_PORT.
    """
    return (common.DEFAULT_WEB_SOCKET_SECURE_PORT if is_secure
            else common.DEFAULT_WEB_SOCKET_PORT)
|
85 |
|
86 |
|
def validate_subprotocol(subprotocol, hixie):
    """Check a subprotocol field value (WebSocket-Protocol /
    Sec-WebSocket-Protocol) and raise if it is malformed.

    See
    - RFC 6455: Section 4.1., 4.2.2., and 4.3.
    - HyBi 00: Section 4.1. Opening handshake
    - Hixie 75: Section 4.1. Handshake

    Args:
        subprotocol: the field value to check.
        hixie: when truthy, apply the Hixie rule (printable ASCII only);
            otherwise the value must be exactly one HTTP token.

    Raises:
        HandshakeException: the value is empty or violates the rule above.
    """
    if not subprotocol:
        raise HandshakeException('Invalid subprotocol name: empty')

    if not hixie:
        # Consume one token; anything left over means the value was not a
        # single token. Since subprotocol is non-empty, a None remainder
        # implies the whole value was consumed as a token.
        parse_state = http_header_util.ParsingState(subprotocol)
        http_header_util.consume_token(parse_state)
        remainder = http_header_util.peek(parse_state)
        if remainder is not None:
            raise HandshakeException('Invalid non-token string in subprotocol '
                                     'name: %r' % remainder)
        return

    # Hixie drafts: every character must lie in U+0020..U+007E.
    for ch in subprotocol:
        if not 0x20 <= ord(ch) <= 0x7e:
            raise HandshakeException(
                'Illegal character in subprotocol name: %r' % ch)
|
116 |
|
117 |
|
def parse_host_header(request):
    """Split the request's Host header into a (host, port) pair.

    If the header carries no port, the port returned is
    get_default_port(request.is_https()). A bracketed IPv6 literal such as
    '[::1]:8080' is supported; the brackets are kept in the returned host.

    Args:
        request: object exposing headers_in['Host'] and is_https().

    Returns:
        (host, port) where port is an int.

    Raises:
        HandshakeException: the port part is not a valid integer, or text
            other than ':port' follows a bracketed IPv6 literal.
    """
    host_header = request.headers_in['Host']

    # For an IPv6 literal the port separator is the ':' after ']', not the
    # first ':' in the string (which lies inside the address).
    bracket_end = -1
    if host_header.startswith('['):
        bracket_end = host_header.find(']')

    if bracket_end != -1:
        host = host_header[:bracket_end + 1]
        rest = host_header[bracket_end + 1:]
        if not rest:
            return host, get_default_port(request.is_https())
        if not rest.startswith(':'):
            raise HandshakeException('Invalid Host header: %r' % host_header)
        port_text = rest[1:]
    else:
        fields = host_header.split(':', 1)
        if len(fields) == 1:
            return fields[0], get_default_port(request.is_https())
        host, port_text = fields

    try:
        return host, int(port_text)
    except ValueError as e:  # portable syntax (Python 2.6+ and 3)
        raise HandshakeException('Invalid port number format: %r' % e)
|
126 |
|
127 |
|
def format_header(name, value):
    """Render one HTTP header line, 'name: value', terminated by CRLF."""
    return '%(name)s: %(value)s\r\n' % {'name': name, 'value': value}
|
130 |
|
131 |
|
def build_location(request):
    """Build the WebSocket location URL for request.

    The result is '<scheme>://<host>[:<port>]<uri>'.
    - scheme: common.WEB_SOCKET_SECURE_SCHEME when request.is_https() is
      truthy, otherwise common.WEB_SOCKET_SCHEME.
    - host, port: taken from the Host header via parse_host_header.
    - ':<port>' is included only when the port differs from
      get_default_port(request.is_https()).
    - uri: request.uri.

    Raises:
        HandshakeException: the Host header port differs from the local port
            of the underlying connection (request.connection.local_addr[1]).
    """
    scheme = (common.WEB_SOCKET_SECURE_SCHEME if request.is_https()
              else common.WEB_SOCKET_SCHEME)

    host, port = parse_host_header(request)
    local_port = request.connection.local_addr[1]
    if port != local_port:
        raise HandshakeException('Header/connection port mismatch: %d/%d' %
                                 (port, local_port))

    pieces = [scheme, '://', host]
    if port != get_default_port(request.is_https()):
        pieces.extend([':', str(port)])
    pieces.append(request.uri)
    return ''.join(pieces)
|
151 |
|
152 |
|
def get_mandatory_header(request, key):
    """Return the value of header key from request.headers_in.

    Raises:
        HandshakeException: the header is absent (lookup returned None).
    """
    header_value = request.headers_in.get(key)
    if header_value is not None:
        return header_value
    raise HandshakeException('Header %s is not defined' % key)
|
158 |
|
159 |
|
def validate_mandatory_header(request, key, expected_value, fail_status=None):
    """Require header key to exist and to equal expected_value, ignoring case.

    Raises:
        HandshakeException: the header is missing (via get_mandatory_header),
            or its value differs case-insensitively from expected_value. The
            mismatch exception carries status=fail_status.
    """
    actual = get_mandatory_header(request, key)
    if actual.lower() == expected_value.lower():
        return
    raise HandshakeException(
        'Expected %r for header %s but found %r (case-insensitive)' %
        (expected_value, key, actual), status=fail_status)
|
167 |
|
168 |
|
def check_request_line(request):
    """Reject a handshake request whose method is not exactly 'GET'.

    Raises:
        HandshakeException: request.method != 'GET'.
    """
    # 5.1 1. The three character UTF-8 string "GET".
    # 5.1 2. A UTF-8-encoded U+0020 SPACE character (0x20 byte).
    if request.method == 'GET':
        return
    raise HandshakeException('Method is not GET')
|
174 |
|
175 |
|
def check_header_lines(request, mandatory_headers):
    """Validate the request method and a set of required header values.

    Args:
        request: the handshake request.
        mandatory_headers: iterable of (header name, expected value) pairs,
            e.g. the Upgrade and Connection headers. Each is checked with
            validate_mandatory_header (case-insensitive match).

    Raises:
        HandshakeException: the method is not GET, or any header is missing
            or has an unexpected value.
    """
    check_request_line(request)

    for header_name, wanted in mandatory_headers:
        validate_mandatory_header(request, header_name, wanted)
|
184 |
|
185 |
|
def parse_token_list(data):
    """Parse a header value of the form 1#token into a list of strings.

    The caller must strip leading linear whitespace (LWS) from data. Empty
    list elements (nothing between two commas) are skipped.

    Raises:
        HandshakeException: something other than a comma follows an
            element, or no token is found at all.
    """
    cursor = http_header_util.ParsingState(data)
    tokens = []

    while True:
        tok = http_header_util.consume_token(cursor)
        if tok is not None:
            tokens.append(tok)

        http_header_util.consume_lwses(cursor)

        if http_header_util.peek(cursor) is None:
            break  # end of input

        if http_header_util.consume_string(cursor, ','):
            http_header_util.consume_lwses(cursor)
            continue

        raise HandshakeException(
            'Expected a comma but found %r' % http_header_util.peek(cursor))

    if not tokens:
        raise HandshakeException('No valid token found')

    return tokens
|
217 |
|
218 |
|
def _parse_extension_param(state, definition, allow_quoted_string):
    """Parse one extension parameter and add it to definition.

    Accepts either 'name' (stored with value None) or 'name = value'. When
    allow_quoted_string is truthy the value may be a token or a quoted
    string; otherwise it must be a token.

    Raises:
        HandshakeException: no parameter name, or '=' not followed by a
            valid value.
    """
    name = http_header_util.consume_token(state)
    if name is None:
        raise HandshakeException('No valid parameter name found')

    http_header_util.consume_lwses(state)

    if http_header_util.consume_string(state, '='):
        http_header_util.consume_lwses(state)
        if allow_quoted_string:
            # TODO(toyoshim): Add code to validate that parsed param_value is token
            value = http_header_util.consume_token_or_quoted_string(state)
        else:
            value = http_header_util.consume_token(state)
        if value is None:
            raise HandshakeException(
                'No valid parameter value found on the right-hand side of '
                'parameter %r' % name)
    else:
        # Bare parameter name with no '=value' part.
        value = None

    definition.add_parameter(name, value)
|
244 |
|
245 |
|
def _parse_extension(state, allow_quoted_string):
    """Parse one extension: a token followed by ';'-separated parameters.

    Args:
        state: http_header_util.ParsingState positioned at the extension.
        allow_quoted_string: forwarded to _parse_extension_param; when truthy
            parameter values may be quoted strings.

    Returns:
        A common.ExtensionParameter holding the extension name and its
        parameters, or None when no token is present at the current position.

    Raises:
        HandshakeException: a parameter failed to parse. The original error
            is embedded (via %r) in the new message.
    """
    extension_token = http_header_util.consume_token(state)
    if extension_token is None:
        return None

    extension = common.ExtensionParameter(extension_token)

    while True:
        http_header_util.consume_lwses(state)

        # Parameters are introduced by ';'; anything else ends this extension.
        if not http_header_util.consume_string(state, ';'):
            break

        http_header_util.consume_lwses(state)

        try:
            _parse_extension_param(state, extension, allow_quoted_string)
        except HandshakeException as e:  # portable syntax (Python 2.6+ and 3)
            raise HandshakeException(
                'Failed to parse Sec-WebSocket-Extensions header: '
                'Failed to parse parameter for %r (%r)' %
                (extension_token, e))

    return extension
|
270 |
|
271 |
|
def parse_extensions(data, allow_quoted_string=False):
    """Parse a Sec-WebSocket-Extensions header value.

    The caller must strip leading linear whitespace (LWS) from data.
    Extensions are comma-separated; empty elements are skipped.

    Args:
        data: the header value.
        allow_quoted_string: when truthy, parameter values may be quoted
            strings as well as tokens.

    Returns:
        A list of common.ExtensionParameter objects.

    Raises:
        HandshakeException: malformed separators or parameters, or no
            extension found.
    """
    cursor = http_header_util.ParsingState(data)
    parsed = []

    while True:
        ext = _parse_extension(cursor, allow_quoted_string)
        if ext is not None:
            parsed.append(ext)

        http_header_util.consume_lwses(cursor)

        if http_header_util.peek(cursor) is None:
            break  # end of input

        if http_header_util.consume_string(cursor, ','):
            http_header_util.consume_lwses(cursor)
            continue

        raise HandshakeException(
            'Failed to parse Sec-WebSocket-Extensions header: '
            'Expected a comma but found %r' %
            http_header_util.peek(cursor))

    if not parsed:
        raise HandshakeException(
            'Sec-WebSocket-Extensions header contains no valid extension')

    return parsed
|
305 |
|
306 |
|
def format_extensions(extension_list):
    """Serialize extension objects into a Sec-WebSocket-Extensions value.

    Each extension becomes 'name; p1; p2=v2' (parameters without a value are
    written bare; values pass through http_header_util.quote_if_necessary).
    Extensions are joined with ', '. An empty list yields ''.
    """
    def _render(extension):
        # Render a single extension and its parameters.
        pieces = [extension.name()]
        for key, val in extension.get_parameters():
            if val is None:
                pieces.append(key)
                continue
            pieces.append(
                '%s=%s' % (key, http_header_util.quote_if_necessary(val)))
        return '; '.join(pieces)

    return ', '.join([_render(ext) for ext in extension_list])
|
321 |
|
322 |
|
323 # vi:sts=4 sw=4 et |