#!/usr/bin/python3

import argparse
import collections
from enum import Enum
import networkx as nx
import operator
import re
import sys


class IBGUID:
    """A 64-bit InfiniBand GUID, normalised to 16 lowercase hex digits.

    Accepts 1-16 hex digits with an optional leading ``0x``/``0X`` prefix and
    surrounding whitespace; shorter values are left-padded with zeros.
    """

    def __init__(self, guid: str):
        guid = guid.strip()
        # Strip only a leading prefix: an embedded "0x" means malformed input
        # and must be rejected by the check below, not silently removed.
        if guid[:2] in ('0x', '0X'):
            guid = guid[2:]
        if not re.fullmatch('[0-9a-fA-F]{1,16}', guid):
            raise ValueError('guid should be 1-16 hexadecimal numbers')
        # Some guids have leading '0' stripped
        self.guid = bytearray.fromhex(guid.zfill(16))

    def __str__(self):
        return self.guid.hex()

    def __repr__(self):
        return "IBGUID('{}')".format(self)

    def __hash__(self):
        return hash(self.guid.hex())

    def __eq__(self, other):
        # Returning NotImplemented (instead of touching other.guid) lets
        # comparisons against None or unrelated objects evaluate to False.
        if not isinstance(other, IBGUID):
            return NotImplemented
        return self.guid == other.guid


class IBNode:
    """A node (switch or channel adapter) of an InfiniBand fabric.

    Only ``guid`` is set by the constructor; other attributes (``type``,
    ``nodeid``, ``vendid``, ``nodedesc``, ``num_ports`` ...) are attached
    later by the parser. A node is truthy only once its guid is known.
    """

    class Type(Enum):
        SWITCH = 1
        CA = 2

    @classmethod
    def from_nodeid(cls, nodeid):
        """Build a node from an ibnetdiscover node id ``S-<guid>`` / ``H-<guid>``.

        ``S`` yields a switch and ``H`` a CA; ``guid`` must be 16 lowercase
        hex digits.

        :raises RuntimeError: if *nodeid* does not have that form.
        """
        m = re.fullmatch(r'(?P<type>[SH])-(?P<guid>[0-9a-f]{16})', nodeid)
        if not m:
            raise RuntimeError("Can't parse nodeid " + nodeid)
        node = cls(m.group('guid'))
        # The regex guarantees the type letter is either 'S' or 'H'.
        if m.group('type') == 'S':
            node.type = IBNode.Type.SWITCH
        else:
            node.type = IBNode.Type.CA
        node.nodeid = nodeid
        return node

    def __init__(self, guid: str = None):
        self.guid = IBGUID(guid) if guid else None

    def __hash__(self):
        if self.guid is None:
            raise ValueError('Can\'t hash IBNode without guid')
        return hash(self.guid)

    def __eq__(self, other):
        """Nodes are equal when their guids are equal (two guid-less nodes compare equal)."""
        if not isinstance(other, IBNode):
            return NotImplemented
        if self.guid is None or other.guid is None:
            return self.guid is other.guid
        return self.guid == other.guid

    def __bool__(self):
        return self.guid is not None

    def __str__(self):
        if not self.guid:
            return '<Node without guid>'
        # ``type`` may not have been assigned yet (e.g. a guid set before the
        # caguid/switchguid line was seen), so don't assume it exists.
        node_type = getattr(self, 'type', None)
        if node_type == IBNode.Type.CA:
            return "CA " + str(self.guid)
        if node_type == IBNode.Type.SWITCH:
            return "Switch " + str(self.guid)
        return "Unknown node " + str(self.guid)

    def __getitem__(self, key):
        return self.__dict__[key]

    def __setitem__(self, key, value):
        """Set attribute *key* once; re-setting it to a different value raises ValueError."""
        if key not in self.__dict__:
            self.__dict__[key] = value
            self.try_update_nodeid()
        elif self.__dict__[key] != value:
            raise ValueError('Trying to update with conflicting info: ' + str(value) + '-> ' + str(self) + '[' + key + '] = ' + str(self.__dict__[key]))

    def set_guid(self, guid):
        """Assign the guid (IBGUID or string); raises ValueError if one is already set."""
        if self.guid:
            raise ValueError('Attempt to redefine IBNode guid')
        if isinstance(guid, IBGUID):
            self.guid = guid
        else:
            self.guid = IBGUID(guid)
        self.try_update_nodeid()

    def try_update_nodeid(self):
        """Derive ``nodeid`` ("S-<guid>" / "H-<guid>") once both type and guid are known.

        Does nothing if ``nodeid`` is already set or type/guid are missing.
        """
        if hasattr(self, 'nodeid'):
            return
        if not (hasattr(self, 'type') and self.guid):
            return
        prefix = {IBNode.Type.SWITCH: "S-", IBNode.Type.CA: "H-"}.get(self.type)
        if prefix is None:
            return
        self.nodeid = prefix + str(self.guid)
        # Diagnostics go to stderr: stdout is the default scope output file.
        print("Set nodeid to", self.nodeid, file=sys.stderr)

    def update(self, other_node):
        """
        Update self with new data from other_node
        If node and other_node have keys with different values,
        raise an exception
        """
        for p, value in other_node.__dict__.items():
            if p in self.__dict__ and self.__dict__[p] != value:
                raise ValueError('Trying to update with conflicting info: ' + str(value) + '-> ' + str(self) + '[' + p + '] = ' + str(self.__dict__[p]))
            self.__dict__[p] = value

# Description of the link(s) between two nodes: how many parallel links were
# recorded and their type string (e.g. '4xHDR').
LinkDesc = collections.namedtuple('LinkDesc', 'num_links linktype')

class IBNetwork:
    def __init__(self):
        self.nodes = {}
        """
        connections[node1][node2] - LinkDesc (num_links, linktype)
        linktype - string like '4xHDR' - type of link
        num_links - doubled (*2) number of parallel links between same nodes
        num_links is doubled because every connection in ibnetdiscover file is
        processed twice, once for every side of the connection
        """
        self.connections = {}

    def add_node(self, node):
        """
        Add a node to network
        If node with the same guid exists, add new information.
        If old node and new node have the same member (e.g. vendid)
        with different values, raise exception
        """
        if not node:
            raise ValueError("Can't add node without guid")
        old_node = self.find_node_byguid(node.guid)
        if not old_node:
            self.nodes[node.guid] = node
        else:
            old_node.update(node)

    def remove_node(self, node):
        node_known = self.find_node_byguid(node.guid)
        if not node_known:
            raise RuntimeError("Can't remove node " + str(node))
        for node2 in self.connections_by_node(node):
            del self.connections[node2][node_known]
        del self.connections[node_known]

    def add_connection(self, node1, node2, linktype):
        """
        Add connection between nodes of type linktype (string)
        If previous connection with different linktype exists, raise exception
        If previous connection with the same linktype exists, increment connection count
        """
        node1_known = self.find_node_byguid(node1.guid)
        node2_known = self.find_node_byguid(node2.guid)
        if not node1_known or not node2_known:
            raise RuntimeError("Can't find node " + str(node1) + " or " + str(node2))
        if node1_known not in self.connections:
            self.connections[node1_known] = {}
        if node2_known not in self.connections:
            self.connections[node2_known] = {}
        if node2_known not in self.connections[node1_known]:
            self.connections[node1_known][node2_known] = LinkDesc(0, linktype)
        elif self.connections[node1_known][node2_known].linktype != linktype:
            raise ValueError(
                "Trying to update linktype with conflictiing info: link between {} and {} was {}, updating with {}".format(
                    str(node1), str(node2), self.connections[node1_known][node2_known], linktype
                )
            )
        if node1_known not in self.connections[node2_known]:
            self.connections[node2_known][node1_known] = LinkDesc(0, linktype)
        elif self.connections[node2_known][node1_known].linktype != linktype:
            raise ValueError(
                "Trying to update linktype with conflicting info: link between {} and {} was {}, updating with {}".format(
                    str(node1), str(node2), self.connections[node1_known][node2_known], linktype
                )
            )
        
        desc = self.connections[node1_known][node2_known]
        self.connections[node1_known][node2_known] = LinkDesc(desc.num_links + 1, desc.linktype)
        desc = self.connections[node2_known][node1_known]
        self.connections[node2_known][node1_known] = LinkDesc(desc.num_links + 1, desc.linktype)

    def check_connection(self, node1, node2):
        return node1 in self.connections and node2 in self.connections[node1]

    def connection_num_links(self, node1, node2):
        if node1 in self.connections and node2 in self.connections[node1]:
            return self.connections[node1][node2].num_links
        raise KeyError("no connection between {} and {}".format(str(node1), str(node2)))

    def connection_linktype(self, node1, node2):
        if node1 in self.connections and node2 in self.connections[node1]:
            return self.connections[node1][node2].linktype
        raise KeyError("no connection between {} and {}".format(str(node1), str(node2)))

    def connections_by_node(self, node):
        if node in self.connections:
            return list(self.connections[node])
        else:
            return []

    def find_node_byguid(self, guid: IBGUID):
        if guid in self.nodes:
            return self.nodes[guid]
        else:
            return None

    def find_node_bynodeid(self, nodeid):
        new_node = IBNode.from_nodeid(nodeid)
        old_node = self.find_node_byguid(new_node.guid)
        return old_node or None

def parse_ibnetdiscover_output(ibnetdiscover_output):
    """Parse ``ibnetdiscover`` text output into an IBNetwork.

    :param ibnetdiscover_output: iterable of text lines (e.g. an open file).
    :returns: the populated IBNetwork.
    :raises RuntimeError: on a port/Switch/Ca line that cannot be parsed, or
        on an unknown ``name=0x...`` property.

    Node blocks are separated by blank lines. Within a block, property lines
    (``vendid=``, ``caguid=``/``switchguid=`` ...) identify the node, the
    ``Switch``/``Ca`` header registers it, and ``[port]`` lines add links.
    Lines matching none of these (e.g. ``#`` comments) are ignored.
    """
    # Raw strings: '\d', '\[' etc. are invalid escapes in plain literals
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    flags = re.IGNORECASE | re.ASCII
    port_re = re.compile(
        r'\[(?P<srcport>\d+)\](\((?P<srcportguid>[0-9a-f]+)\))?\s*'
        r'"(?P<dstnodeid>[^"]+)"\[(?P<dstport>\d+)\](\((?P<dstportguid>[0-9a-f]+)\))?\s*'
        r'#.*"(?P<dstnodedesc>[^"]+)".*\s+(?P<linktype>\S+)',
        flags)
    switch_re = re.compile(
        r'Switch\s+(?P<numports>\d+)\s+"(?P<nodeid>[^"]+)"\s*#\s*"(?P<nodedesc>[^"]+)"\s*(?P<comment>.*)',
        flags)
    ca_re = re.compile(
        r'Ca\s+(?P<numports>\d+)\s+"(?P<nodeid>[^"]+)"\s*#\s*"(?P<nodedesc>[^"]+)"',
        flags)
    prop_re = re.compile(
        r'(?P<propname>[a-z]+)=(?P<propval>0x[0-9a-f]+)(\((?P<portguid>[0-9a-f]+)\))?',
        flags)

    network = IBNetwork()

    cur_node = IBNode()
    for line in ibnetdiscover_output:
        line = line.strip()
        if line == '':
            # A blank line closes the current node block.
            if cur_node:
                network.add_node(cur_node)
            cur_node = IBNode()
        elif line[0] == '[':
            #  [1](b8599f03004af86d) 	"S-0c42a1030060ac30"[2]		# lid 230 lmc 0 "MF0;rvn-everest-t1:MQM8700/U1" lid 2 4xHDR
            #  [19]	"S-043f720300ea0662"[11]		# "MF0;rvn-vega-l1:MQM8700/U1" lid 6 4xHDR
            #  [41]	"H-1c34da0300728cce"[1](1c34da0300728cce) 		# "Mellanox Technologies Aggregation Node" lid 75 4xHDR
            m = port_re.fullmatch(line)
            if not m:
                raise RuntimeError("Can't parse line " + line)
            dstnodeid = m.group('dstnodeid')
            dst_node = network.find_node_bynodeid(dstnodeid)
            if not dst_node:
                # Peer not seen yet: register a stub. When the peer's own
                # block is parsed, add_node() merges its attributes into it.
                dst_node = IBNode.from_nodeid(dstnodeid)
                network.add_node(dst_node)
            network.add_connection(cur_node, dst_node, m.group('linktype'))
        elif line.startswith("Switch"):
            #  Switch	41 "S-1c34da0300728cc6"		# "MF0;rvn-vega-t3:MQM8700/U1" enhanced port 0 lid 74 lmc 0
            m = switch_re.fullmatch(line)
            if not m:
                raise RuntimeError("Can't parse line " + line)
            cur_node.num_ports = int(m.group('numports'))
            cur_node.nodeid = m.group('nodeid')
            cur_node.nodedesc = m.group('nodedesc')
            cur_node.comment = m.group('comment')
            cur_node.num_special_nodes = 0
            network.add_node(cur_node)
        elif line.startswith("Ca"):
            #  Ca	1 "H-98039b0300860c37"		# "clx-everest-003 HCA-3"
            m = ca_re.fullmatch(line)
            if not m:
                raise RuntimeError("Can't parse line " + line)
            cur_node.num_ports = int(m.group('numports'))
            cur_node.nodeid = m.group('nodeid')
            cur_node.nodedesc = m.group('nodedesc')
            network.add_node(cur_node)
        else:
            #  vendid=0x2c9
            #  devid=0x101b
            #  sysimgguid=0xb8599f03004af86c
            #  caguid=0xb8599f03004af86d
            #  switchguid=0xb8599f0300557370(b8599f0300557370)
            m = prop_re.fullmatch(line)
            if not m:
                # Anything else (e.g. '#' header comments) is ignored.
                continue
            propname = m.group('propname')
            propval = m.group('propval')
            if propname in ('vendid', 'devid'):
                cur_node[propname] = propval
            elif propname == 'sysimgguid':
                cur_node.sysimgguid = IBGUID(propval)
            elif propname == 'caguid':
                cur_node.set_guid(propval)
                cur_node.type = IBNode.Type.CA
            elif propname == 'switchguid':
                cur_node.set_guid(propval)
                cur_node.type = IBNode.Type.SWITCH
                # Ignore portguid for now
                # If we have connections to the same node to different ports
                # a conflict can occur
            else:
                raise RuntimeError('Unknown property ' + propname + '=' + propval)

    # Flush the last block when the input does not end with a blank line;
    # add_node() merges idempotently if the node was already registered.
    if cur_node:
        network.add_node(cur_node)

    print("Parsing ibnetdiscover output file finished.", file=sys.stderr)
    return network

def is_node(n):
    """Return True when *n* is a channel adapter (CA / host) node."""
    return n.type == IBNode.Type.CA


def is_switch(n):
    """Return True when *n* is a switch."""
    return n.type == IBNode.Type.SWITCH

# Node names that should be cleared from the final topology.
# Each pattern is applied with re.match to a CA node's nodedesc.
# Raw strings are used because sequences such as '\d' are invalid escapes in
# plain string literals (DeprecationWarning; SyntaxWarning since Python 3.12).
prohibited_nodenames_re = [
        r'^localhost',
        r'^MT.*ConnectX.*Mellanox Technologies$',
    ]

# Special nodes that should not be counted towards
# port count for oversubscribe ratio
special_nodenames_re = [
        r'^Mellanox Technologies Aggregation Node$',
        r'^MF\d+;.*/AN\d+$',
    ]

def fix_node_names(network):
    """Clean up CA nodes in *network* in place.

    For every CA node (is_node):
    - if its nodedesc matches a prohibited pattern, its links are removed;
    - if it matches a special-node pattern, each neighbour's
      ``num_special_nodes`` is incremented and its links are removed;
    - otherwise its nodedesc is shortened to the first whitespace-separated
      word (left unchanged when it is empty or all whitespace).
    """
    # Iterate over a snapshot so the loop stays valid even if remove_node
    # ever starts deleting entries from network.nodes.
    for n in list(filter(is_node, network.nodes.values())):
        nodedesc = n.nodedesc
        if any(re.match(r, nodedesc) for r in prohibited_nodenames_re):
            network.remove_node(n)
            continue
        if any(re.match(r, nodedesc) for r in special_nodenames_re):
            for s in network.connections_by_node(n):
                s.num_special_nodes += 1
            network.remove_node(n)
            continue
        # Leave only a part up to first space; guard against an empty or
        # whitespace-only description (split() would return []).
        words = nodedesc.split()
        if words:
            n.nodedesc = words[0]

def extract_tors(network, min_nodes_in_tor=4):
    """Find top-of-rack (TOR) switches in *network*.

    A switch is a TOR when at least *min_nodes_in_tor* CA nodes are directly
    connected to it. As a side effect, ``num_connected_nodes`` is set on every
    switch examined (it is used later for sorting/oversubscription).

    :returns: ``(tors, tor_nodes)`` — the list of TOR switches, and a dict
        mapping each TOR to the list of its connected CA nodes.
    """
    tors = []
    tor_nodes = {}

    for n in filter(is_switch, network.nodes.values()):
        connected_nodes = list(filter(is_node, network.connections_by_node(n)))
        n.num_connected_nodes = len(connected_nodes)
        if n.num_connected_nodes >= min_nodes_in_tor:
            tors.append(n)
            tor_nodes[n] = connected_nodes
    return tors, tor_nodes

def write_scope_file(tors, tor_nodes, full_scope_file):
    """Emit ``<nodedesc>,<TOR nodeid>,<TOR index>`` for every node under every TOR.

    TOR indices follow the order of *tors*. The file is closed afterwards
    unless it is stdout.
    """
    rows = (
        "{},{},{}\n".format(member.nodedesc, tor.nodeid, index)
        for index, tor in enumerate(tors)
        for member in tor_nodes[tor]
    )
    full_scope_file.writelines(rows)
    if full_scope_file is not sys.stdout:
        full_scope_file.close()

def select_switches(network):
    """Build an undirected networkx graph of the switches and the switch-to-switch links."""
    graph = nx.Graph()
    for switch in filter(is_switch, network.nodes.values()):
        graph.add_node(switch)
        neighbours = network.connections_by_node(switch)
        for peer in filter(is_switch, neighbours):
            graph.add_edge(switch, peer)
    return graph

def calculate_distances(network_switches_only, tors):
    """Map each TOR to ``nx.single_source_shortest_path_length`` from that TOR.

    The inner dicts only contain switches reachable from the TOR.
    """
    return {
        tor: nx.single_source_shortest_path_length(network_switches_only, tor)
        for tor in tors
    }

def _count_uplinks(tor, network):
    """Return the number of physical links between *tor* and other switches.

    IBNetwork stores each link count doubled (every connection is processed
    once from each side), hence the division by 2.
    """
    return sum(
        network.connection_num_links(tor, peer)
        for peer in filter(is_switch, network.connections_by_node(tor))
    ) / 2


def _clamped_ratio(num_downlinks, num_uplinks):
    """Return downlinks / uplinks, never below 1.0 (1.0 when there are no uplinks)."""
    if num_uplinks == 0:
        return 1.0
    return max(1.0, num_downlinks / num_uplinks)


def calculate_oversubscribe(tors, network):
    """Return ``{tor: ratio}``, counting as downlinks every port that is not an
    uplink and not attached to a special node (``num_ports - num_special_nodes - uplinks``).
    """
    oversubscribe_info = {}
    for t in tors:
        num_uplinks = _count_uplinks(t, network)
        num_downlinks = t.num_ports - t.num_special_nodes - num_uplinks
        oversubscribe_info[t] = _clamped_ratio(num_downlinks, num_uplinks)
    return oversubscribe_info


def calculate_oversubscribe_downlink_nodes_only(tors, network):
    """Return ``{tor: ratio}``, counting as downlinks only the directly
    connected CA nodes (``num_connected_nodes``, set by extract_tors).
    """
    oversubscribe_info = {}
    for t in tors:
        num_uplinks = _count_uplinks(t, network)
        oversubscribe_info[t] = _clamped_ratio(t.num_connected_nodes, num_uplinks)
    return oversubscribe_info

def write_distances(tors, lengths, distances_file):
    """Write ``<nodeid> -- <nodeid> == <hops>`` for every unordered TOR pair.

    Pairs follow the order of *tors*. The file is closed afterwards unless it
    is stdout.
    """
    pairs = ((src, dst) for pos, src in enumerate(tors) for dst in tors[pos + 1:])
    for src, dst in pairs:
        print("{} -- {} == {}".format(src.nodeid, dst.nodeid, lengths[src][dst]),
              file=distances_file)
    if distances_file is not sys.stdout:
        distances_file.close()

def write_oversubscribe(tors, oversubscribe_info, topo_file):
    """Write ``<TOR nodeid>,<TOR index>,<oversubscribe ratio>`` per TOR.

    The file is closed afterwards unless it is stdout.
    """
    rows = (
        "{},{},{}".format(tor.nodeid, idx, oversubscribe_info[tor])
        for idx, tor in enumerate(tors)
    )
    for row in rows:
        print(row, file=topo_file)

    if topo_file is not sys.stdout:
        topo_file.close()

def write_order_file(tors, lengths, all_pairs, order_file):
    """Write a scope_order file of ``<run>,<scope1>,<scope2>`` lines.

    TOR pairs are grouped by hop distance (groups in order of first
    appearance). Within a group, each run holds pairs in which no TOR appears
    twice; pairs that collide are deferred to later runs. With *all_pairs*
    false, only the first run of each distance group is emitted.

    The file is closed afterwards unless it is stdout, like the other writers.
    """
    pairs_by_dist = collections.defaultdict(list)
    for pos, tor_s in enumerate(tors):
        for tor_d in tors[pos + 1:]:
            pairs_by_dist[lengths[tor_s][tor_d]].append((tor_s.nodeid, tor_d.nodeid))

    num_run = 0
    for dist, pairs in pairs_by_dist.items():
        print("Distance {} starting from run {}:".format(dist, num_run + 1), file=sys.stderr)
        remaining = pairs
        while remaining:
            num_run += 1
            seen_scopes = set()
            # Pairs sharing a TOR with an already-scheduled pair in this run.
            deferred = []
            for scope1, scope2 in remaining:
                if scope1 in seen_scopes or scope2 in seen_scopes:
                    deferred.append((scope1, scope2))
                    continue
                seen_scopes.add(scope1)
                seen_scopes.add(scope2)
                print('{},{},{}'.format(num_run, scope1, scope2), file=order_file)

            if not all_pairs:
                # All possible pairs are not requested
                break
            remaining = deferred

    if order_file is not sys.stdout:
        order_file.close()

def run(args):
    """Drive the whole pipeline for the parsed command-line *args*.

    The scope file is always written. Distances, topology (oversubscription)
    and scope_order outputs are computed and written only when their files
    were requested. Calls sys.exit(0) early when no further output is needed.
    """
    network = parse_ibnetdiscover_output(args.ibnetdiscover_output)
    fix_node_names(network)
    tors, tor_nodes = extract_tors(network)
    # NOTE(review): TOR indices in the scope file are assigned before the
    # optional sort below, while the topo file enumerates TORs after it —
    # confirm consumers join the two files on nodeid rather than index.
    write_scope_file(tors, tor_nodes, args.full_scope)

    print("TORs: {} found".format(len(tors)), file=sys.stderr)

    if not args.distances_file and not args.order and not args.topo_file:
        print("Finished, no topology information calculated.", file=sys.stderr)
        sys.exit(0)

    # Make a warning for large networks
    if len(tors) > 500:
        print("Calculating topology information, it may take significant time (up to tens of minutes for large networks).", file=sys.stderr)

    # Hop distances are only consumed by the distances and order writers, and
    # calculate_distances runs one single-source shortest-path search per TOR,
    # so skip it when only the topo file was requested.
    lengths = None
    if args.distances_file or args.order:
        network_switches_only = select_switches(network)
        lengths = calculate_distances(network_switches_only, tors)

    oversubscribe_info = None
    if args.topo_file:
        if args.downlink_nodes_only:
            oversubscribe_info = calculate_oversubscribe_downlink_nodes_only(tors, network)
        else:
            oversubscribe_info = calculate_oversubscribe(tors, network)

    if not args.no_sort:
        tors.sort(reverse=True, key=operator.attrgetter('num_connected_nodes'))

    if args.distances_file:
        write_distances(tors, lengths, args.distances_file)

    if args.topo_file:
        write_oversubscribe(tors, oversubscribe_info, args.topo_file)

    print("Calculating distances between TORs finished.", file=sys.stderr)

    if not args.order:
        print("Finished.", file=sys.stderr)
        sys.exit(0)

    print("Writing order file, it may take significant time.", file=sys.stderr)

    write_order_file(tors, lengths, args.all_pairs, args.order)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Positional input: the text produced by ibnetdiscover.
    parser.add_argument(
        'ibnetdiscover_output',
        metavar='<ibnetdiscover output>',
        type=argparse.FileType('r'),
    )

    # Output files. The scope file is always written (stdout by default);
    # the remaining outputs are produced only when their option is given.
    parser.add_argument(
        '-f', '--full-scope',
        help='Output scopefile. If not given, use stdout.',
        type=argparse.FileType('w'),
        default=sys.stdout,
        required=False,
    )
    parser.add_argument(
        '-d', '--distances-file',
        help='File to write distances between all TOR pairs.',
        type=argparse.FileType('w'),
        required=False,
    )
    parser.add_argument(
        '-t', '--topo-file',
        help='File to write topology information.',
        type=argparse.FileType('w'),
        required=False,
    )

    # Behaviour switches.
    parser.add_argument(
        '--downlink-nodes-only',
        action='store_true',
        help='Only ports connected to nodes are counted as downlinks '
             'when calculating oversubscribe ratio '
             '(default is to count all ports except connected to switches and special nodes).',
    )
    parser.add_argument(
        '--no-sort',
        action='store_true',
        help='Do not sort TORs in distances and scope_order files by number of connected nodes',
    )
    parser.add_argument(
        '-o', '--order',
        help='Output scope_order file.',
        type=argparse.FileType('w'),
        required=False,
    )
    parser.add_argument(
        '-a', '--all-pairs',
        action='store_true',
        help='Use all possible pairs of TORs. Without it, only one set of pairs will be generated.',
    )

    args = parser.parse_args()

    print("Started.", file=sys.stderr)
    run(args)
    print("Finished.", file=sys.stderr)
