# Source code for aiohttp.http_websocket

"""WebSocket protocol versions 13 and 8."""

import collections
import json
import random
import re
import sys
import zlib
from enum import IntEnum
from struct import Struct

from .helpers import NO_EXTENSIONS
from .log import ws_logger


# Names re-exported by ``from aiohttp.http_websocket import *``.
__all__ = (
    'WS_CLOSED_MESSAGE',
    'WS_CLOSING_MESSAGE',
    'WS_KEY',
    'WebSocketReader',
    'WebSocketWriter',
    'WSMessage',
    'WebSocketError',
    'WSMsgType',
    'WSCloseCode',
)


class WSCloseCode(IntEnum):
    """Status codes that may be carried in a WebSocket CLOSE frame.

    Members are plain integers (``IntEnum``), so they compare equal to the
    numeric code unpacked from the wire.
    """

    OK = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    INVALID_TEXT = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013
# Plain-int view of every WSCloseCode member, for fast membership tests.
ALLOWED_CLOSE_CODES = set(map(int, WSCloseCode))
class WSMsgType(IntEnum):
    """Message types: wire opcodes plus aiohttp-internal pseudo types.

    The lowercase names at the bottom are enum aliases of the uppercase
    members (``WSMsgType.text is WSMsgType.TEXT``).
    """

    # websocket spec types
    CONTINUATION = 0x0
    TEXT = 0x1
    BINARY = 0x2
    PING = 0x9
    PONG = 0xa
    CLOSE = 0x8

    # aiohttp specific types
    CLOSING = 0x100
    CLOSED = 0x101
    ERROR = 0x102

    # lowercase aliases of the members above
    text = TEXT
    binary = BINARY
    ping = PING
    pong = PONG
    close = CLOSE
    closing = CLOSING
    closed = CLOSED
    error = ERROR
# GUID defined by RFC 6455 for the opening handshake (not used in this
# section of the file).
WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'


# Pre-compiled big-endian (network order) struct helpers for frame headers.
UNPACK_LEN2 = Struct('!H').unpack_from  # 16-bit extended payload length
UNPACK_LEN3 = Struct('!Q').unpack_from  # 64-bit extended payload length
UNPACK_CLOSE_CODE = Struct('!H').unpack  # 16-bit close status code
PACK_LEN1 = Struct('!BB').pack  # header with 7-bit payload length
PACK_LEN2 = Struct('!BBH').pack  # header with 16-bit payload length
PACK_LEN3 = Struct('!BBQ').pack  # header with 64-bit payload length
PACK_CLOSE_CODE = Struct('!H').pack
# Unmasked payloads larger than this are written separately from the header.
MSG_SIZE = 2 ** 14
# Default number of written bytes after which the writer drains.
DEFAULT_LIMIT = 2 ** 16


_WSMessageBase = collections.namedtuple('_WSMessageBase',
                                        ['type', 'data', 'extra'])
class WSMessage(_WSMessageBase):
    """A WebSocket message: a ``(type, data, extra)`` named tuple."""

    def json(self, *, loads=json.loads):
        """Return parsed JSON data.

        ``loads`` is the callable applied to ``self.data``; it defaults to
        :func:`json.loads`.

        .. versionadded:: 0.22
        """
        return loads(self.data)
# Shared sentinel messages for the aiohttp-specific CLOSED / CLOSING types.
WS_CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None)
WS_CLOSING_MESSAGE = WSMessage(WSMsgType.CLOSING, None, None)


class WebSocketError(Exception):
    """WebSocket protocol parser error.

    ``code`` is stored on the instance; the reader in this module passes
    ``WSCloseCode`` members for it.
    """

    def __init__(self, code, message):
        self.code = code
        super().__init__(message)


class WSHandshakeError(Exception):
    """WebSocket protocol handshake error."""


native_byteorder = sys.byteorder


# Used by _websocket_mask_python
# _XOR_TABLE[b] is a 256-byte translation table mapping a -> a ^ b.
_XOR_TABLE = [bytes(a ^ b for a in range(256)) for b in range(256)]


def _websocket_mask_python(mask, data):
    """Websocket masking function.

    `mask` is a `bytes` object of length 4; `data` is a `bytearray`
    object of any length. The contents of `data` are masked with `mask`,
    as specified in section 5.3 of RFC 6455.

    Note that this function mutates the `data` argument.

    This pure-python implementation may be replaced by an optimized
    version when available.

    """
    assert isinstance(data, bytearray), data
    assert len(mask) == 4, mask

    if data:
        # One translation table per mask byte; each is applied to every
        # fourth byte of the payload via an extended slice.
        a, b, c, d = (_XOR_TABLE[n] for n in mask)
        data[::4] = data[::4].translate(a)
        data[1::4] = data[1::4].translate(b)
        data[2::4] = data[2::4].translate(c)
        data[3::4] = data[3::4].translate(d)


# Pick the masking implementation: the compiled one when extensions are
# allowed and importable, otherwise the pure-python fallback.
if NO_EXTENSIONS:
    _websocket_mask = _websocket_mask_python
else:
    try:
        from ._websocket import _websocket_mask_cython  # type: ignore
        _websocket_mask = _websocket_mask_cython
    except ImportError:  # pragma: no cover
        _websocket_mask = _websocket_mask_python

# Four-byte tail that the writer strips from compressed output and the
# reader re-appends before decompressing.
_WS_DEFLATE_TRAILING = bytes([0x00, 0x00, 0xff, 0xff])


# Parameters following ``permessage-deflate``.  Groups:
#   1: server_no_context_takeover   2: client_no_context_takeover
#   3: server_max_window_bits[=N]   4: N
#   5: client_max_window_bits[=N]   6: N
_WS_EXT_RE = re.compile(r'^(?:;\s*(?:'
                        r'(server_no_context_takeover)|'
                        r'(client_no_context_takeover)|'
                        r'(server_max_window_bits(?:=(\d+))?)|'
                        r'(client_max_window_bits(?:=(\d+))?)))*$')

# Finds each ``permessage-deflate`` offer; group 1 is its parameter string
# (everything up to the next comma), or None when there are no parameters.
_WS_EXT_RE_SPLIT = re.compile(r'permessage-deflate([^,]+)?')


def ws_ext_parse(extstr, isserver=False):
    """Parse a ``permessage-deflate`` extension header value.

    Each ``permessage-deflate`` occurrence in *extstr* is examined in order
    and the first acceptable one wins.

    :param extstr: extension header value; a falsy value means "no
        extensions".
    :param isserver: when true, read the ``server_*`` parameters and skip
        offers that cannot be honoured; when false, read the ``client_*``
        parameters and fail on anything unsupported.
    :return: ``(compress, notakeover)`` where *compress* is the window size
        in bits (``0`` when compression is not enabled) and *notakeover*
        tells whether the matching ``*_no_context_takeover`` parameter was
        present.
    :raises WSHandshakeError: only when *isserver* is false, for an
        out-of-range window size or unparseable parameters.
    """
    if not extstr:
        return 0, False

    compress = 0
    notakeover = False
    for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
        defext = ext.group(1)
        # Return compress = 15 when get `permessage-deflate`
        if not defext:
            compress = 15
            break
        match = _WS_EXT_RE.match(defext)
        if match:
            compress = 15
            if isserver:
                # Server never fail to detect compress handshake.
                # Server does not need to send max wbit to client
                if match.group(4):
                    compress = int(match.group(4))
                    # Group3 must match if group4 matches
                    # Compress wbit 8 does not support in zlib
                    # If compress level not support,
                    # CONTINUE to next extension
                    if compress > 15 or compress < 9:
                        compress = 0
                        continue
                if match.group(1):
                    notakeover = True
                # Ignore regex group 5 & 6 for client_max_window_bits
                break
            else:
                if match.group(6):
                    compress = int(match.group(6))
                    # Group5 must match if group6 matches
                    # Compress wbit 8 does not support in zlib
                    # If compress level not support,
                    # FAIL the parse progress
                    if compress > 15 or compress < 9:
                        raise WSHandshakeError('Invalid window size')
                if match.group(2):
                    notakeover = True
                # Ignore regex group 3 & 4 for server_max_window_bits
                break
        # Return Fail if client side and not match
        elif not isserver:
            raise WSHandshakeError('Extension for deflate not supported' +
                                   ext.group(1))

    return compress, notakeover


def ws_ext_gen(compress=15, isserver=False,
               server_notakeover=False):
    """Build a ``permessage-deflate`` extension header value.

    :param compress: window size in bits, 9..15 inclusive.
    :param isserver: when false, ``client_max_window_bits`` is added.
    :param server_notakeover: when true, ``server_no_context_takeover`` is
        added.
    :return: the parameters joined with ``'; '``.
    :raises ValueError: if *compress* is outside 9..15.
    """
    # client_notakeover=False not used for server
    # compress wbit 8 does not support in zlib
    if compress < 9 or compress > 15:
        raise ValueError('Compress wbits must between 9 and 15, '
                         'zlib does not support wbits=8')
    enabledext = ['permessage-deflate']
    if not isserver:
        enabledext.append('client_max_window_bits')

    if compress < 15:
        enabledext.append('server_max_window_bits=' + str(compress))
    if server_notakeover:
        enabledext.append('server_no_context_takeover')
    # if client_notakeover:
    #     enabledext.append('client_no_context_takeover')
    return '; '.join(enabledext)


class WSParserState(IntEnum):
    """States of the incremental frame parser in ``WebSocketReader``."""

    READ_HEADER = 1
    READ_PAYLOAD_LENGTH = 2
    READ_PAYLOAD_MASK = 3
    READ_PAYLOAD = 4


class WebSocketReader:
    """Incremental WebSocket frame parser and message assembler.

    Raw bytes go in through :meth:`feed_data`; complete messages are pushed
    to *queue* as ``WSMessage`` objects.

    :param queue: object providing ``feed_data(msg, size)``, ``feed_eof()``
        and ``set_exception(exc)``.
    :param max_msg_size: size limit for an assembled message; a falsy value
        disables the check.
    :param compress: when falsy, any frame with the RSV1 bit set is
        rejected as a protocol error.
    """

    def __init__(self, queue, max_msg_size, compress=True):
        self.queue = queue
        self._max_msg_size = max_msg_size

        self._exc = None  # first exception raised while parsing, if any
        self._partial = bytearray()  # payload of the message being assembled
        self._state = WSParserState.READ_HEADER

        self._opcode = None  # opcode of the fragmented message in progress
        self._frame_fin = False
        self._frame_opcode = None
        self._frame_payload = bytearray()  # payload of the frame being read

        self._tail = b''  # unparsed bytes carried over to the next call
        self._has_mask = False
        self._frame_mask = None
        self._payload_length = 0  # payload bytes still missing for the frame
        self._payload_length_flag = 0  # raw 7-bit length field of the header
        self._compressed = None
        self._decompressobj = None  # created lazily on first compressed frame
        self._compress = compress

    def feed_eof(self):
        """Forward end-of-stream to the message queue."""
        self.queue.feed_eof()

    def feed_data(self, data):
        """Parse *data* and queue any complete messages.

        :return: ``(False, b'')`` on success.  If parsing raises, the
            exception is remembered, set on the queue, and ``(True, b'')``
            is returned; every later call returns ``(True, data)`` without
            parsing.  NOTE(review): the first element presumably tells the
            caller to stop feeding - confirm against the caller.
        """
        if self._exc:
            return True, data

        try:
            return self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            self.queue.set_exception(exc)
            return True, b''

    def _feed_data(self, data):
        """Turn the frames found in *data* into queued ``WSMessage`` objects.

        Control frames (CLOSE, PING, PONG) are queued immediately; TEXT and
        BINARY frames are accumulated until a FIN frame arrives, then
        decompressed if needed, decoded (TEXT only) and queued.

        :raises WebSocketError: on protocol violations, invalid UTF-8, or a
            message exceeding ``max_msg_size``.
        :return: ``(False, b'')``.
        """
        for fin, opcode, payload, compressed in self.parse_frame(data):
            if compressed and not self._decompressobj:
                # Negative wbits selects a raw deflate stream (no header).
                self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
            if opcode == WSMsgType.CLOSE:
                if len(payload) >= 2:
                    close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                    # Codes below 3000 must be one of the known ones.
                    if (close_code < 3000 and
                            close_code not in ALLOWED_CLOSE_CODES):
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            'Invalid close code: {}'.format(close_code))
                    try:
                        close_message = payload[2:].decode('utf-8')
                    except UnicodeDecodeError as exc:
                        raise WebSocketError(
                            WSCloseCode.INVALID_TEXT,
                            'Invalid UTF-8 text message') from exc
                    msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
                elif payload:
                    # A one-byte payload cannot hold a close code.
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        'Invalid close frame: {} {} {!r}'.format(
                            fin, opcode, payload))
                else:
                    msg = WSMessage(WSMsgType.CLOSE, 0, '')

                self.queue.feed_data(msg, 0)

            elif opcode == WSMsgType.PING:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PING, payload, ''), len(payload))

            elif opcode == WSMsgType.PONG:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PONG, payload, ''), len(payload))

            elif opcode not in (
                    WSMsgType.TEXT, WSMsgType.BINARY) and self._opcode is None:
                # Neither a data frame nor a continuation of one in progress.
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "Unexpected opcode={!r}".format(opcode))
            else:
                # load text/binary
                if not fin:
                    # got partial frame payload
                    if opcode != WSMsgType.CONTINUATION:
                        self._opcode = opcode
                    self._partial.extend(payload)
                    if (self._max_msg_size and
                            len(self._partial) >= self._max_msg_size):
                        raise WebSocketError(
                            WSCloseCode.MESSAGE_TOO_BIG,
                            "Message size {} exceeds limit {}".format(
                                len(self._partial), self._max_msg_size))
                else:
                    # previous frame was non finished
                    # we should get continuation opcode
                    if self._partial:
                        if opcode != WSMsgType.CONTINUATION:
                            raise WebSocketError(
                                WSCloseCode.PROTOCOL_ERROR,
                                'The opcode in non-fin frame is expected '
                                'to be zero, got {!r}'.format(opcode))

                    if opcode == WSMsgType.CONTINUATION:
                        # The message type comes from its first fragment.
                        opcode = self._opcode
                        self._opcode = None

                    self._partial.extend(payload)
                    if (self._max_msg_size and
                            len(self._partial) >= self._max_msg_size):
                        raise WebSocketError(
                            WSCloseCode.MESSAGE_TOO_BIG,
                            "Message size {} exceeds limit {}".format(
                                len(self._partial), self._max_msg_size))

                    # Decompress process must to be done after all packets
                    # received.
                    if compressed:
                        self._partial.extend(_WS_DEFLATE_TRAILING)
                        # max_length caps the output; whatever input could
                        # not be consumed ends up in unconsumed_tail.
                        payload_merged = self._decompressobj.decompress(
                            self._partial, self._max_msg_size)
                        if self._decompressobj.unconsumed_tail:
                            left = len(self._decompressobj.unconsumed_tail)
                            # Both values are reported: the original
                            # template had a single placeholder and
                            # printed the size where the limit belonged.
                            raise WebSocketError(
                                WSCloseCode.MESSAGE_TOO_BIG,
                                "Decompressed message size {} "
                                "exceeds limit {}".format(
                                    self._max_msg_size + left,
                                    self._max_msg_size))
                    else:
                        payload_merged = bytes(self._partial)

                    self._partial.clear()

                    if opcode == WSMsgType.TEXT:
                        try:
                            text = payload_merged.decode('utf-8')
                            self.queue.feed_data(
                                WSMessage(WSMsgType.TEXT, text, ''), len(text))
                        except UnicodeDecodeError as exc:
                            raise WebSocketError(
                                WSCloseCode.INVALID_TEXT,
                                'Invalid UTF-8 text message') from exc
                    else:
                        self.queue.feed_data(
                            WSMessage(WSMsgType.BINARY, payload_merged, ''),
                            len(payload_merged))

        return False, b''

    def parse_frame(self, buf):
        """Parse *buf* and return every frame completed by it.

        The parser is a resumable state machine: header, payload length and
        mask bytes that are not fully available are kept in ``self._tail``
        and re-read on the next call, while partial payload bytes are
        accumulated in ``self._frame_payload``.

        :return: list of ``(fin, opcode, payload, compressed)`` tuples, with
            *payload* already unmasked; empty if no frame was completed.
        :raises WebSocketError: on reserved-bit, fragmented-control-frame or
            oversized-control-frame violations.
        """
        frames = []
        if self._tail:
            buf, self._tail = self._tail + buf, b''

        start_pos = 0
        buf_length = len(buf)

        while True:
            # read header
            if self._state == WSParserState.READ_HEADER:
                if buf_length - start_pos >= 2:
                    data = buf[start_pos:start_pos+2]
                    start_pos += 2
                    first_byte, second_byte = data

                    fin = (first_byte >> 7) & 1
                    rsv1 = (first_byte >> 6) & 1
                    rsv2 = (first_byte >> 5) & 1
                    rsv3 = (first_byte >> 4) & 1
                    opcode = first_byte & 0xf

                    # frame-fin = %x0 ; more frames of this message follow
                    #           / %x1 ; final frame of this message
                    # frame-rsv1 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    # frame-rsv2 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    # frame-rsv3 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    #
                    # Remove rsv1 from this test for deflate development
                    if rsv2 or rsv3 or (rsv1 and not self._compress):
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            'Received frame with non-zero reserved bits')

                    if opcode > 0x7 and fin == 0:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            'Received fragmented control frame')

                    has_mask = (second_byte >> 7) & 1
                    length = second_byte & 0x7f

                    # Control frames MUST have a payload
                    # length of 125 bytes or less
                    if opcode > 0x7 and length > 125:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            'Control frame payload cannot be '
                            'larger than 125 bytes')

                    # Set compress status if last package is FIN
                    # OR set compress status if this is first fragment
                    # Raise error if not first fragment with rsv1 = 0x1
                    if self._frame_fin or self._compressed is None:
                        self._compressed = True if rsv1 else False
                    elif rsv1:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            'Received frame with non-zero reserved bits')

                    self._frame_fin = fin
                    self._frame_opcode = opcode
                    self._has_mask = has_mask
                    self._payload_length_flag = length
                    self._state = WSParserState.READ_PAYLOAD_LENGTH
                else:
                    break

            # read payload length
            if self._state == WSParserState.READ_PAYLOAD_LENGTH:
                length = self._payload_length_flag
                if length == 126:
                    # 16-bit extended length follows
                    if buf_length - start_pos >= 2:
                        data = buf[start_pos:start_pos+2]
                        start_pos += 2
                        length = UNPACK_LEN2(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD)
                    else:
                        break
                elif length > 126:
                    # 64-bit extended length follows (7-bit field == 127)
                    if buf_length - start_pos >= 8:
                        data = buf[start_pos:start_pos+8]
                        start_pos += 8
                        length = UNPACK_LEN3(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD)
                    else:
                        break
                else:
                    self._payload_length = length
                    self._state = (
                        WSParserState.READ_PAYLOAD_MASK
                        if self._has_mask
                        else WSParserState.READ_PAYLOAD)

            # read payload mask
            if self._state == WSParserState.READ_PAYLOAD_MASK:
                if buf_length - start_pos >= 4:
                    self._frame_mask = buf[start_pos:start_pos+4]
                    start_pos += 4
                    self._state = WSParserState.READ_PAYLOAD
                else:
                    break

            if self._state == WSParserState.READ_PAYLOAD:
                length = self._payload_length
                payload = self._frame_payload

                chunk_len = buf_length - start_pos
                if length >= chunk_len:
                    # Everything left in buf belongs to this frame.
                    self._payload_length = length - chunk_len
                    payload.extend(buf[start_pos:])
                    start_pos = buf_length
                else:
                    self._payload_length = 0
                    payload.extend(buf[start_pos:start_pos+length])
                    start_pos = start_pos + length

                if self._payload_length == 0:
                    if self._has_mask:
                        _websocket_mask(self._frame_mask, payload)

                    frames.append((
                        self._frame_fin,
                        self._frame_opcode,
                        payload,
                        self._compressed))

                    self._frame_payload = bytearray()
                    self._state = WSParserState.READ_HEADER
                else:
                    break

        self._tail = buf[start_pos:]

        return frames


class WebSocketWriter:
    """Serialize WebSocket frames and write them to a transport.

    Every message is sent as a single frame with the FIN bit set.

    :param protocol: object providing the ``_drain_helper()`` coroutine.
    :param transport: object providing ``write(data)``.
    :param use_mask: mask outgoing payloads with a random 4-byte key.
    :param limit: number of written bytes after which the writer awaits
        ``protocol._drain_helper()``.
    :param random: source of ``randrange`` for mask keys.  The default
        instance is created once, when the class is defined, and shared.
    :param compress: deflate window bits for data frames; ``0`` disables
        compression.
    :param notakeover: flush with ``Z_FULL_FLUSH`` instead of
        ``Z_SYNC_FLUSH`` after each message.
    """

    def __init__(self, protocol, transport, *,
                 use_mask=False, limit=DEFAULT_LIMIT, random=random.Random(),
                 compress=0, notakeover=False):
        self.protocol = protocol
        self.transport = transport
        self.use_mask = use_mask
        self.randrange = random.randrange
        self.compress = compress
        self.notakeover = notakeover
        self._closing = False
        self._limit = limit
        self._output_size = 0  # bytes written since the last drain
        self._compressobj = None  # created lazily, reused across messages

    async def _send_frame(self, message, opcode, compress=None):
        """Send a frame over the websocket with message as its payload.

        :param message: payload bytes.
        :param opcode: frame opcode; only opcodes below 8 (data frames) are
            ever compressed.
        :param compress: window bits for a one-off compressor used for this
            frame only; when falsy, ``self.compress`` decides.
        """
        if self._closing:
            ws_logger.warning('websocket connection is closing.')

        rsv = 0

        # Only compress larger packets (disabled)
        # Does small packet needs to be compressed?
        # if self.compress and opcode < 8 and len(message) > 124:
        if (compress or self.compress) and opcode < 8:
            if compress:
                # Do not set self._compress if compressing is for this frame
                compressobj = zlib.compressobj(wbits=-compress)
            else:  # self.compress
                if not self._compressobj:
                    self._compressobj = zlib.compressobj(wbits=-self.compress)
                compressobj = self._compressobj

            message = compressobj.compress(message)
            message = message + compressobj.flush(
                zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH)
            # Strip the flush tail; the reader appends it back.
            if message.endswith(_WS_DEFLATE_TRAILING):
                message = message[:-4]
            rsv = rsv | 0x40  # RSV1 marks the frame as compressed

        msg_length = len(message)

        use_mask = self.use_mask
        if use_mask:
            mask_bit = 0x80
        else:
            mask_bit = 0

        # 0x80 in the first byte is the FIN bit.
        if msg_length < 126:
            header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit)
        elif msg_length < (1 << 16):
            header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
        else:
            header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)
        if use_mask:
            # NOTE: randrange excludes its upper bound, so the key
            # 0xffffffff itself is never produced.
            mask = self.randrange(0, 0xffffffff)
            mask = mask.to_bytes(4, 'big')
            message = bytearray(message)
            _websocket_mask(mask, message)
            self.transport.write(header + mask + message)
            self._output_size += len(header) + len(mask) + len(message)
        else:
            if len(message) > MSG_SIZE:
                # Avoid copying a large payload just to prepend the header.
                self.transport.write(header)
                self.transport.write(message)
            else:
                self.transport.write(header + message)

            self._output_size += len(header) + len(message)

        if self._output_size > self._limit:
            self._output_size = 0
            await self.protocol._drain_helper()

    async def pong(self, message=b''):
        """Send pong message."""
        if isinstance(message, str):
            message = message.encode('utf-8')
        return await self._send_frame(message, WSMsgType.PONG)

    async def ping(self, message=b''):
        """Send ping message."""
        if isinstance(message, str):
            message = message.encode('utf-8')
        return await self._send_frame(message, WSMsgType.PING)

    async def send(self, message, binary=False, compress=None):
        """Send a frame over the websocket with message as its payload.

        ``str`` messages are encoded as UTF-8.  The frame is BINARY when
        *binary* is true and TEXT otherwise; *compress* is forwarded to
        ``_send_frame``.
        """
        if isinstance(message, str):
            message = message.encode('utf-8')
        if binary:
            return await self._send_frame(message, WSMsgType.BINARY, compress)
        else:
            return await self._send_frame(message, WSMsgType.TEXT, compress)

    async def close(self, code=1000, message=b''):
        """Close the websocket, sending the specified code and message."""
        if isinstance(message, str):
            message = message.encode('utf-8')
        try:
            return await self._send_frame(
                PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE)
        finally:
            # Marked as closing even if sending the CLOSE frame failed.
            self._closing = True