"""General utility functions. """ import copy import os import re import socket import threading import uuid from stomp.constants import * # List of all host names (unqualified, fully-qualified, and IP # addresses) that refer to the local host (both loopback interface # and external interfaces). This is used for determining # preferred targets. LOCALHOST_NAMES = ["localhost", "127.0.0.1"] NULL = b"\x00" if not os.environ.get("STOMP_SKIP_HOSTNAME_SCAN"): try: LOCALHOST_NAMES.append(socket.gethostbyname(socket.gethostname())) except Exception: pass try: LOCALHOST_NAMES.append(socket.gethostname()) except Exception: pass try: LOCALHOST_NAMES.append(socket.getfqdn(socket.gethostname())) except Exception: pass def decode(byte_data, encoding="utf-8"): """ Decode the byte data to a string if not None. :param bytes byte_data: the data to decode :param string encoding: character encoding :rtype: str """ if byte_data is None: return None return byte_data.decode(encoding, errors="replace") def encode(char_data, encoding="utf-8"): """ Encode the parameter as a byte string. :param char_data: the data to encode :param string encoding: character encoding :rtype: bytes """ if type(char_data) is str: return char_data.encode(encoding, errors="replace") elif type(char_data) is bytes: return char_data else: raise TypeError("message should be a string or bytes, found %s" % type(char_data)) def pack(pieces=()): """ Join a sequence of strings together. :param list pieces: list of strings :rtype: bytes """ return b"".join(pieces) def join(chars=()): """ Join a sequence of characters into a string. :param bytes chars: list of chars :rtype: str """ return b"".join(chars).decode() def is_eol_default(c): return c == b"\x0a" ## # Used to parse STOMP header lines in the format "key:value", # HEADER_LINE_RE = re.compile("(?P[^:]+)[:](?P.*)") ## # As of STOMP 1.2, lines can end with either line feed, or carriage return plus line feed. # PREAMBLE_END_RE = re.compile(b"\n\n|\r\n\r\n") ## # As of STOMP 1.2, lines can end with either line feed, or carriage return plus line feed. # LINE_END_RE = re.compile("\n|\r\n") ## # Used to replace the "passcode" to be dumped in the transport log (at debug level) # PASSCODE_RE = re.compile(r"passcode:\s+[^\' ]+") ENC_NEWLINE = encode("\n") ENC_NULL = encode(NULL) def default_create_thread(callback): """ Default thread creation - used to create threads when the client doesn't want to provide their own thread creation. :param function callback: the callback function provided to threading.Thread """ thread = threading.Thread(None, callback) thread.daemon = True # Don't let thread prevent termination thread.start() return thread def is_localhost(host_and_port): """ Return 1 if the specified host+port is a member of the 'localhost' list of hosts, 2 if not (predominately used as a sort key. :param (str,int) host_and_port: tuple containing host and port :rtype: int """ (host, _) = host_and_port if host in LOCALHOST_NAMES: return 1 else: return 2 _HEADER_ESCAPES = { "\r": "\\r", "\n": "\\n", ":": "\\c", "\\": "\\\\", } _HEADER_UNESCAPES = dict((value, key) for (key, value) in _HEADER_ESCAPES.items()) def _unescape_header(matchobj): escaped = matchobj.group(0) unescaped = _HEADER_UNESCAPES.get(escaped) if not unescaped: # TODO: unknown escapes MUST be treated as fatal protocol error per spec unescaped = escaped return unescaped def parse_headers(lines, offset=0): """ Parse the headers in a STOMP response :param list(str) lines: the lines received in the message response :param int offset: the starting line number :rtype: dict(str,str) """ headers = {} for header_line in lines[offset:]: header_match = HEADER_LINE_RE.match(header_line) if header_match: key = header_match.group("key") key = re.sub(r"\\.", _unescape_header, key) if key not in headers: value = header_match.group("value") value = re.sub(r"\\.", _unescape_header, value) headers[key] = value return headers def parse_frame(frame): """ Parse a STOMP frame into a Frame object. :param bytes frame: the frame received from the server (as a byte string) :rtype: Frame """ mat = PREAMBLE_END_RE.search(frame) if mat: preamble_end = mat.start() body_start = mat.end() else: preamble_end = len(frame) body_start = preamble_end preamble = decode(frame[0:preamble_end]) preamble_lines = LINE_END_RE.split(preamble) preamble_len = len(preamble_lines) body = frame[body_start:] # Skip any leading newlines first_line = 0 while first_line < preamble_len and len(preamble_lines[first_line]) == 0: first_line += 1 if first_line >= preamble_len: return None # Extract frame type/command cmd = preamble_lines[first_line] # Put headers into a key/value map headers = parse_headers(preamble_lines, first_line + 1) return Frame(cmd, headers, body) def merge_headers(header_map_list): """ Helper function for combining multiple header maps into one. :param list(dict) header_map_list: list of maps :rtype: dict """ headers = {} for header_map in header_map_list: if header_map: headers.update(header_map) return headers def clean_headers(headers): rtn = headers if "passcode" in headers: rtn = copy.copy(headers) rtn["passcode"] = "********" return rtn # lines: lines returned from a call to convert_frames def clean_lines(lines): return re.sub(PASSCODE_RE, "passcode:********", str(lines)) def calculate_heartbeats(shb, chb): """ Given a heartbeat string from the server, and a heartbeat tuple from the client, calculate what the actual heartbeat settings should be. :param (str,str) shb: server heartbeat numbers :param (int,int) chb: client heartbeat numbers :rtype: (int,int) """ (sx, sy) = shb (cx, cy) = chb x = 0 y = 0 if cx != 0 and sy != "0": x = max(cx, int(sy)) if cy != 0 and sx != "0": y = max(cy, int(sx)) return x, y def convert_frame(frame): """ Convert a frame to a list of lines separated by newlines. :param Frame frame: the Frame object to convert :rtype: list(str) """ lines = [] body = None if frame.body: body = encode(frame.body) if HDR_CONTENT_LENGTH in frame.headers: frame.headers[HDR_CONTENT_LENGTH] = len(body) if frame.cmd: lines.append(encode(frame.cmd)) lines.append(ENC_NEWLINE) for key, vals in sorted(frame.headers.items()): if vals is None: continue if type(vals) != tuple: vals = (vals,) for val in vals: lines.append(encode("%s:%s\n" % (key, val))) lines.append(ENC_NEWLINE) if body: lines.append(body) if frame.cmd: lines.append(ENC_NULL) return lines def length(s): """ Null (none) safe length function. :param str s: the string to return length of (None allowed) :rtype: int """ if s is not None: return len(s) return 0 class Frame(object): """ A STOMP frame (or message). :param str cmd: the protocol command :param dict headers: a map of headers for the frame :param body: the content of the frame. """ def __init__(self, cmd, headers=None, body=None): self.cmd = cmd self.headers = headers if headers is not None else {} self.original_headers = copy.copy(self.headers) self.body = body def __str__(self): return "{cmd=%s,headers=[%s],body=%s}" % (self.cmd, self.headers, self.body) def get_uuid(): return str(uuid.uuid4()) def get_errno(e): """ Return the errno of an exception, or the first argument if errno is not available. :param Exception e: the exception object """ try: return e.errno except AttributeError: return e.args[0] HEARTBEAT_FRAME = Frame("heartbeat")