#!/usr/local/bin/python3 import os import socket import struct import sys from ctypes import c_byte from ctypes import c_char from ctypes import c_int from ctypes import c_long from ctypes import c_uint32 from ctypes import c_ulong from ctypes import c_ushort from ctypes import sizeof from ctypes import Structure from typing import Dict from typing import List from typing import Optional from typing import Union def roundup2(val: int, num: int) -> int: if val % num: return (val | (num - 1)) + 1 else: return val class RtSockException(OSError): pass class RtConst: RTM_VERSION = 5 ALIGN = sizeof(c_long) AF_INET = socket.AF_INET AF_INET6 = socket.AF_INET6 AF_LINK = socket.AF_LINK RTA_DST = 0x1 RTA_GATEWAY = 0x2 RTA_NETMASK = 0x4 RTA_GENMASK = 0x8 RTA_IFP = 0x10 RTA_IFA = 0x20 RTA_AUTHOR = 0x40 RTA_BRD = 0x80 RTM_ADD = 1 RTM_DELETE = 2 RTM_CHANGE = 3 RTM_GET = 4 RTF_UP = 0x1 RTF_GATEWAY = 0x2 RTF_HOST = 0x4 RTF_REJECT = 0x8 RTF_DYNAMIC = 0x10 RTF_MODIFIED = 0x20 RTF_DONE = 0x40 RTF_XRESOLVE = 0x200 RTF_LLINFO = 0x400 RTF_LLDATA = 0x400 RTF_STATIC = 0x800 RTF_BLACKHOLE = 0x1000 RTF_PROTO2 = 0x4000 RTF_PROTO1 = 0x8000 RTF_PROTO3 = 0x40000 RTF_FIXEDMTU = 0x80000 RTF_PINNED = 0x100000 RTF_LOCAL = 0x200000 RTF_BROADCAST = 0x400000 RTF_MULTICAST = 0x800000 RTF_STICKY = 0x10000000 RTF_RNH_LOCKED = 0x40000000 RTF_GWFLAG_COMPAT = 0x80000000 RTV_MTU = 0x1 RTV_HOPCOUNT = 0x2 RTV_EXPIRE = 0x4 RTV_RPIPE = 0x8 RTV_SPIPE = 0x10 RTV_SSTHRESH = 0x20 RTV_RTT = 0x40 RTV_RTTVAR = 0x80 RTV_WEIGHT = 0x100 @staticmethod def get_props(prefix: str) -> List[str]: return [n for n in dir(RtConst) if n.startswith(prefix)] @staticmethod def get_name(prefix: str, value: int) -> str: props = RtConst.get_props(prefix) for prop in props: if getattr(RtConst, prop) == value: return prop return "U:{}:{}".format(prefix, value) @staticmethod def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]: props = RtConst.get_props(prefix) propmap = {getattr(RtConst, prop): prop for prop in props} v = 1 ret = {} while value: if v & value: if v in propmap: ret[v] = propmap[v] else: ret[v] = hex(v) value -= v v *= 2 return ret @staticmethod def get_bitmask_str(prefix: str, value: int) -> str: bmap = RtConst.get_bitmask_map(prefix, value) return ",".join([v for k, v in bmap.items()]) class RtMetrics(Structure): _fields_ = [ ("rmx_locks", c_ulong), ("rmx_mtu", c_ulong), ("rmx_hopcount", c_ulong), ("rmx_expire", c_ulong), ("rmx_recvpipe", c_ulong), ("rmx_sendpipe", c_ulong), ("rmx_ssthresh", c_ulong), ("rmx_rtt", c_ulong), ("rmx_rttvar", c_ulong), ("rmx_pksent", c_ulong), ("rmx_weight", c_ulong), ("rmx_nhidx", c_ulong), ("rmx_filler", c_ulong * 2), ] class RtMsgHdr(Structure): _fields_ = [ ("rtm_msglen", c_ushort), ("rtm_version", c_byte), ("rtm_type", c_byte), ("rtm_index", c_ushort), ("_rtm_spare1", c_ushort), ("rtm_flags", c_int), ("rtm_addrs", c_int), ("rtm_pid", c_int), ("rtm_seq", c_int), ("rtm_errno", c_int), ("rtm_fmask", c_int), ("rtm_inits", c_ulong), ("rtm_rmx", RtMetrics), ] class SockaddrIn(Structure): _fields_ = [ ("sin_len", c_byte), ("sin_family", c_byte), ("sin_port", c_ushort), ("sin_addr", c_uint32), ("sin_zero", c_char * 8), ] class SockaddrIn6(Structure): _fields_ = [ ("sin6_len", c_byte), ("sin6_family", c_byte), ("sin6_port", c_ushort), ("sin6_flowinfo", c_uint32), ("sin6_addr", c_byte * 16), ("sin6_scope_id", c_uint32), ] class SockaddrDl(Structure): _fields_ = [ ("sdl_len", c_byte), ("sdl_family", c_byte), ("sdl_index", c_ushort), ("sdl_type", c_byte), ("sdl_nlen", c_byte), ("sdl_alen", c_byte), ("sdl_slen", c_byte), ("sdl_data", c_byte * 8), ] class SaHelper(object): @staticmethod def is_ipv6(ip: str) -> bool: return ":" in ip @staticmethod def ip_sa(ip: str, scopeid: int = 0) -> bytes: if SaHelper.is_ipv6(ip): return SaHelper.ip6_sa(ip, scopeid) else: return SaHelper.ip4_sa(ip) @staticmethod def ip4_sa(ip: str) -> bytes: addr_int = int.from_bytes(socket.inet_pton(2, ip), sys.byteorder) sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int) return bytes(sin) @staticmethod def ip6_sa(ip6: str, scopeid: int) -> bytes: addr_bytes = (c_byte * 16)() for i, b in enumerate(socket.inet_pton(socket.AF_INET6, ip6)): addr_bytes[i] = b sin6 = SockaddrIn6( sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr_bytes, scopeid ) return bytes(sin6) @staticmethod def link_sa(ifindex: int = 0, iftype: int = 0) -> bytes: sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, c_ushort(ifindex), iftype) return bytes(sa) @staticmethod def pxlen4_sa(pxlen: int) -> bytes: return SaHelper.ip_sa(SaHelper.pxlen_to_ip4(pxlen)) @staticmethod def pxlen_to_ip4(pxlen: int) -> str: if pxlen == 32: return "255.255.255.255" else: addr = 0xFFFFFFFF - ((1 << (32 - pxlen)) - 1) addr_bytes = struct.pack("!I", addr) return socket.inet_ntop(socket.AF_INET, addr_bytes) @staticmethod def pxlen6_sa(pxlen: int) -> bytes: return SaHelper.ip_sa(SaHelper.pxlen_to_ip6(pxlen)) @staticmethod def pxlen_to_ip6(pxlen: int) -> str: ip6_b = [0] * 16 start = 0 while pxlen > 8: ip6_b[start] = 0xFF pxlen -= 8 start += 1 ip6_b[start] = 0xFF - ((1 << (8 - pxlen)) - 1) return socket.inet_ntop(socket.AF_INET6, bytes(ip6_b)) @staticmethod def print_sa_inet(sa: bytes): if len(sa) < 8: raise RtSockException("IPv4 sa size too small: {}".format(len(sa))) addr = socket.inet_ntop(socket.AF_INET, sa[4:8]) return "{}".format(addr) @staticmethod def print_sa_inet6(sa: bytes): if len(sa) < sizeof(SockaddrIn6): raise RtSockException("IPv6 sa size too small: {}".format(len(sa))) addr = socket.inet_ntop(socket.AF_INET6, sa[8:24]) scopeid = struct.unpack(">I", sa[24:28])[0] return "{} scopeid {}".format(addr, scopeid) @staticmethod def print_sa_link(sa: bytes, hd: Optional[bool] = True): if len(sa) < sizeof(SockaddrDl): raise RtSockException("LINK sa size too small: {}".format(len(sa))) sdl = SockaddrDl.from_buffer_copy(sa) if sdl.sdl_index: ifindex = "link#{} ".format(sdl.sdl_index) else: ifindex = "" if sdl.sdl_nlen: iface_offset = 8 if sdl.sdl_nlen + iface_offset > len(sa): raise RtSockException( "LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa)) ) ifname = "ifname:{} ".format( bytes.decode(sa[iface_offset : iface_offset + sdl.sdl_nlen]) ) else: ifname = "" return "{}{}".format(ifindex, ifname) @staticmethod def print_sa_unknown(sa: bytes): return "unknown_type:{}".format(sa[1]) @classmethod def print_sa(cls, sa: bytes, hd: Optional[bool] = False): if sa[0] != len(sa): raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa))) if len(sa) < 2: raise Exception( "sa type {} too short: {}".format( RtConst.get_name("AF_", sa[1]), len(sa) ) ) if sa[1] == socket.AF_INET: text = cls.print_sa_inet(sa) elif sa[1] == socket.AF_INET6: text = cls.print_sa_inet6(sa) elif sa[1] == socket.AF_LINK: text = cls.print_sa_link(sa) else: text = cls.print_sa_unknown(sa) if hd: dump = " [{!r}]".format(sa) else: dump = "" return "{}{}".format(text, dump) class BaseRtsockMessage(object): def __init__(self, rtm_type): self.rtm_type = rtm_type self.sa = SaHelper() @staticmethod def print_rtm_type(rtm_type): return RtConst.get_name("RTM_", rtm_type) @property def rtm_type_str(self): return self.print_rtm_type(self.rtm_type) class RtsockRtMessage(BaseRtsockMessage): messages = [ RtConst.RTM_ADD, RtConst.RTM_DELETE, RtConst.RTM_CHANGE, RtConst.RTM_GET, ] def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None): super().__init__(rtm_type) self.rtm_flags = 0 self.rtm_seq = rtm_seq self._attrs = {} self.rtm_errno = 0 self.rtm_pid = 0 self.rtm_inits = 0 self.rtm_rmx = RtMetrics() self._orig_data = None if dst_sa: self.add_sa_attr(RtConst.RTA_DST, dst_sa) if mask_sa: self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa) def add_sa_attr(self, attr_type, attr_bytes: bytes): self._attrs[attr_type] = attr_bytes def add_ip_attr(self, attr_type, ip_addr: str, scopeid: int = 0): if ":" in ip_addr: self.add_ip6_attr(attr_type, ip_addr, scopeid) else: self.add_ip4_attr(attr_type, ip_addr) def add_ip4_attr(self, attr_type, ip: str): self.add_sa_attr(attr_type, self.sa.ip_sa(ip)) def add_ip6_attr(self, attr_type, ip6: str, scopeid: int): self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid)) def add_link_attr(self, attr_type, ifindex: Optional[int] = 0): self.add_sa_attr(attr_type, self.sa.link_sa(ifindex)) def get_sa(self, attr_type) -> bytes: return self._attrs.get(attr_type) def print_message(self): # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags: if self._orig_data: rtm_len = len(self._orig_data) else: rtm_len = len(bytes(self)) print( "{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format( self.rtm_type_str, rtm_len, self.rtm_pid, self.rtm_seq, self.rtm_errno, RtConst.get_bitmask_str("RTF_", self.rtm_flags), ) ) rtm_addrs = sum(list(self._attrs.keys())) print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs))) for attr in sorted(self._attrs.keys()): sa_data = SaHelper.print_sa(self._attrs[attr]) print(" {}: {}".format(RtConst.get_name("RTA_", attr), sa_data)) def print_in_message(self): print("vvvvvvvv IN vvvvvvvv") self.print_message() print() def verify_sa_inet(self, sa_data): if len(sa_data) < 8: raise Exception("IPv4 sa size too small: {}".format(sa_data)) if sa_data[0] > len(sa_data): raise Exception( "IPv4 sin_len too big: {} vs sa size {}: {}".format( sa_data[0], len(sa_data), sa_data ) ) sin = SockaddrIn.from_buffer_copy(sa_data) assert sin.sin_port == 0 assert sin.sin_zero == [0] * 8 def compare_sa(self, sa_type, sa_data): if len(sa_data) < 4: sa_type_name = RtConst.get_name("RTA_", sa_type) raise Exception( "sa_len for type {} too short: {}".format(sa_type_name, len(sa_data)) ) our_sa = self._attrs[sa_type] assert SaHelper.print_sa(sa_data) == SaHelper.print_sa(our_sa) assert len(sa_data) == len(our_sa) assert sa_data == our_sa def verify(self, rtm_type: int, rtm_sa): assert self.rtm_type_str == self.print_rtm_type(rtm_type) assert self.rtm_errno == 0 hdr = RtMsgHdr.from_buffer_copy(self._orig_data) assert hdr._rtm_spare1 == 0 for sa_type, sa_data in rtm_sa.items(): if sa_type not in self._attrs: sa_type_name = RtConst.get_name("RTA_", sa_type) raise Exception("SA type {} not present".format(sa_type_name)) self.compare_sa(sa_type, sa_data) @classmethod def from_bytes(cls, data: bytes): if len(data) < sizeof(RtMsgHdr): raise Exception( "messages size {} is less than expected {}".format( len(data), sizeof(RtMsgHdr) ) ) hdr = RtMsgHdr.from_buffer_copy(data) self = cls(hdr.rtm_type) self.rtm_flags = hdr.rtm_flags self.rtm_seq = hdr.rtm_seq self.rtm_errno = hdr.rtm_errno self.rtm_pid = hdr.rtm_pid self.rtm_inits = hdr.rtm_inits self.rtm_rmx = hdr.rtm_rmx self._orig_data = data off = sizeof(RtMsgHdr) v = 1 addrs_mask = hdr.rtm_addrs while addrs_mask: if addrs_mask & v: addrs_mask -= v if off + data[off] > len(data): raise Exception( "SA sizeof for {} > total message length: {}+{} > {}".format( RtConst.get_name("RTA_", v), off, data[off], len(data) ) ) self._attrs[v] = data[off : off + data[off]] off += roundup2(data[off], RtConst.ALIGN) v *= 2 return self def __bytes__(self): sz = sizeof(RtMsgHdr) addrs_mask = 0 for k, v in self._attrs.items(): sz += roundup2(len(v), RtConst.ALIGN) addrs_mask += k hdr = RtMsgHdr( rtm_msglen=sz, rtm_version=RtConst.RTM_VERSION, rtm_type=self.rtm_type, rtm_flags=self.rtm_flags, rtm_seq=self.rtm_seq, rtm_addrs=addrs_mask, rtm_inits=self.rtm_inits, rtm_rmx=self.rtm_rmx, ) buf = bytearray(sz) buf[0 : sizeof(RtMsgHdr)] = hdr off = sizeof(RtMsgHdr) for attr in sorted(self._attrs.keys()): v = self._attrs[attr] sa_len = len(v) buf[off : off + sa_len] = v off += roundup2(len(v), RtConst.ALIGN) return bytes(buf) class Rtsock: def __init__(self): self.socket = self._setup_rtsock() self.rtm_seq = 1 self.msgmap = self.build_msgmap() def build_msgmap(self): classes = [RtsockRtMessage] xmap = {} for cls in classes: for message in cls.messages: xmap[message] = cls return xmap def get_seq(self): ret = self.rtm_seq self.rtm_seq += 1 return ret def get_weight(self, weight) -> int: if weight: return weight else: return 1 # RT_DEFAULT_WEIGHT def new_rtm_any(self, msg_type, prefix: str, gw: Union[str, bytes]): px = prefix.split("/") addr_sa = SaHelper.ip_sa(px[0]) if len(px) > 1: pxlen = int(px[1]) if SaHelper.is_ipv6(px[0]): mask_sa = SaHelper.pxlen6_sa(pxlen) else: mask_sa = SaHelper.pxlen4_sa(pxlen) else: mask_sa = None msg = RtsockRtMessage(msg_type, self.get_seq(), addr_sa, mask_sa) if isinstance(gw, bytes): msg.add_sa_attr(RtConst.RTA_GATEWAY, gw) else: # String msg.add_ip_attr(RtConst.RTA_GATEWAY, gw) return msg def new_rtm_add(self, prefix: str, gw: Union[str, bytes]): return self.new_rtm_any(RtConst.RTM_ADD, prefix, gw) def new_rtm_del(self, prefix: str, gw: Union[str, bytes]): return self.new_rtm_any(RtConst.RTM_DELETE, prefix, gw) def new_rtm_change(self, prefix: str, gw: Union[str, bytes]): return self.new_rtm_any(RtConst.RTM_CHANGE, prefix, gw) def _setup_rtsock(self) -> socket.socket: s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC) s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1) return s def print_hd(self, data: bytes): width = 16 print("==========================================") for chunk in [data[i : i + width] for i in range(0, len(data), width)]: for b in chunk: print("0x{:02X} ".format(b), end="") print() print() def write_message(self, msg): print("vvvvvvvv OUT vvvvvvvv") msg.print_message() print() msg_bytes = bytes(msg) ret = os.write(self.socket.fileno(), msg_bytes) if ret != -1: assert ret == len(msg_bytes) def parse_message(self, data: bytes): if len(data) < 4: raise OSError("Short read from rtsock: {} bytes".format(len(data))) rtm_type = data[4] if rtm_type not in self.msgmap: return None def write_data(self, data: bytes): self.socket.send(data) def read_data(self, seq: Optional[int] = None) -> bytes: while True: data = self.socket.recv(4096) if seq is None: break if len(data) > sizeof(RtMsgHdr): hdr = RtMsgHdr.from_buffer_copy(data) if hdr.rtm_seq == seq: break return data def read_message(self) -> bytes: data = self.read_data() return self.parse_message(data)