import os
import sys
import simplejson
import json
import fnmatch
import argparse
import configparser
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import re
from fpdf import FPDF

# The next two lines put the parent of this script's directory on sys.path so that BWAnalyzer can be imported
p = os.path.abspath(os.path.join(__file__ ,"../..")) 
sys.path.insert(1, p)
from bwResultAnalyzer import BWAnalyzer


# Separator line used in text output.
PLINES = "--------------------------------------------"
# RGB triple for white.
BORDER_COLOR = [255] * 3



class Logger(object):
    """Tiny stdout logger with numeric verbosity thresholds.

    A message is printed when its level is <= the logger's current level, so a
    higher threshold (DEBUG=40) is more verbose than a lower one (ERROR=10).
    The message is %-formatted with the extra positional args only when args
    are supplied, so a literal '%' in an argument-less message is safe.
    """

    ERROR = 10
    WARNING = 20
    INFO = 30
    DEBUG = 40

    # Prefix printed before each message, per level.
    LEVEL_HEADER = {
        ERROR: '-E-',
        WARNING: '-W-',
        INFO: '-I-',
        DEBUG: '-D-',
    }

    def __init__(self, level=INFO):
        self._level = level

    def _log(self, level, msg, *args):
        """Print ``msg % args`` with the level prefix unless *level* is filtered out."""
        if level <= self._level:
            text = msg % args if args else msg
            print(self.LEVEL_HEADER[level], text)

    def set_level(self, level):
        """Change the threshold; an unknown level is reported and ignored."""
        if level not in self.LEVEL_HEADER:
            self.error("Invalid logging level: %s", level)
            return
        self._level = level

    def error(self, msg, *args):
        self._log(self.ERROR, msg, *args)

    def warning(self, msg, *args):
        self._log(self.WARNING, msg, *args)

    def info(self, msg, *args):
        self._log(self.INFO, msg, *args)

    def debug(self, msg, *args):
        self._log(self.DEBUG, msg, *args)


class CSVConfigFileParser(object):
    """Parse the user configuration file of device latencies and fiber part numbers.

    Everything after ``#`` on a line is a comment and all spaces, tabs and
    newlines are removed before parsing. Each remaining line is then one of:

    * ``<device id>,<latency>`` - a latency in ns for that device id, stored
      as ``latency * 1e-9 / 2.0`` (seconds, halved);
    * ``<token>`` - a single value added to the fiber set (presumably a cable
      part number to be treated as fiber - confirm against the caller);
    * anything with more fields is reported and ignored.

    The file is parsed lazily on the first call to ``get()``.
    """

    def __init__(self, file_name, logger):
        self._file_name = file_name
        self._logger = logger
        self._parsed = False
        self._latencies = dict()
        self._fibers = set()

    @staticmethod
    def _clean_line(line):
        """Remove every space, newline and tab character from *line*."""
        for c in ' \n\t':
            line = line.replace(c, '')
        return line

    def _add_latency(self, tokens, line_n):
        """Record a ``device id, latency-ns`` pair, warning when it conflicts with an earlier one."""
        try:
            dev_id, latency = tokens[0], tokens[1]
            latency = (float(latency) * 1e-9) / 2.0
        except (IndexError, ValueError):
            self._logger.error("line %d: Cannot extract device id and latency, ignoring line", line_n)
            return

        # Compare against None, not truthiness: a configured latency of 0.0
        # is a real value and a later conflicting entry must still be reported.
        exist_latency = self._latencies.get(dev_id)
        if exist_latency is not None and exist_latency != latency:
            # Convert back to the configured ns values; round() avoids float
            # truncation such as 89.9999... being printed as 89.
            str_exist = int(round(exist_latency * 2 / 1e-9))
            str_latency = int(round(latency * 2 / 1e-9))
            self._logger.warning("Ambiguous latency for device id %s. %d and %d. taking %d", dev_id,
                                 str_exist, str_latency, str_latency)
        self._latencies[dev_id] = latency

    def _add_fiber(self, fiber):
        self._fibers.add(fiber)

    @staticmethod
    def _to_tokens(line):
        """Split a cleaned line on commas, dropping empty fields."""
        return [token for token in line.split(',') if token != '']

    def _add_line(self, line, line_n):
        """Dispatch one cleaned line by its number of fields (empty lines are ignored)."""
        tokens = self._to_tokens(line)
        n = len(tokens)
        if n == 2:
            self._add_latency(tokens, line_n)
        elif n == 1:
            self._add_fiber(tokens[0])
        elif n > 2:
            self._logger.error("line %d: Cannot parse, ignoring line", line_n)

    def _parse(self):
        with open(self._file_name, 'r') as fp:
            for line_n, line in enumerate(fp, 1):
                # Drop the comment part, then all whitespace.
                line = self._clean_line(line.partition('#')[0])
                self._add_line(line, line_n)
        self._parsed = True

    def get(self):
        """Return ``(latencies, fibers)``: a {device id: latency} dict and a set of fiber entries."""
        if not self._parsed:
            self._parse()
        return self._latencies, self._fibers


class LatencyException(Exception):
    """Raised when the latency/path between two nodes cannot be determined."""


class GUIDLookup(object):
    """Process-wide table that replaces raw GUID strings with short aliases.

    A GUID string such as ``0x1234567890abcdef`` is mapped to a decimal
    string alias ``'1'``, ``'2'``, ... in first-seen order, which keeps the
    in-memory keys small. The alias is always a decimal string; ``Vertex``
    relies on that by converting it with ``int()``.
    """
    # raw GUID string -> alias string
    lookup = {}
    # number of the most recently issued alias
    next = 0

    @staticmethod
    def getGUID(v):
        """Return the alias of GUID string *v*, allocating a new one on first sight."""
        alias = GUIDLookup.lookup.get(v)
        if alias is None:
            GUIDLookup.next += 1
            alias = str(GUIDLookup.next)
            GUIDLookup.lookup[v] = alias
        return alias


class Node(object):
    """A fabric node (switch or HCA) from ibdiagnet2.db_csv plus its latency model.

    The latency constants below are in seconds and are half of the nanosecond
    figure they are written from (``X * 1e-9 / 2.0``).
    """
    NO_LAT = 0.0
    BASE_LAT = 90.0 * 1e-9 / 2.0
    STANDARD_FEC_LAT = 221.2 * 1e-9 / 2.0
    LOW_LATENCY_LAT = 208.8 * 1e-9 / 2.0
    NO_FEC_LAT = 192.1 * 1e-9 / 2.0
    QUANTUM_LAT = 130 * 1e-9 / 2.0
    CONNECTX5_LAT = 600 * 1e-9 / 2.0
    CONNECTX6_LAT = 800 * 1e-9 / 2.0

    # device id (as "0x%04x") -> (model name, default latency)
    # TBD: Get the ID's and latency of the more recent switches.  The list here is several years old
    info = {
        "0xc738": ("SwitchX2",  BASE_LAT),
        "0xcb20": ("SwitchIB",  BASE_LAT),
        "0xcb84": ("Spectrum",  BASE_LAT),
        "0xcf08": ("SwitchIB2", STANDARD_FEC_LAT),
        "0xd2f0": ("Quantum",   QUANTUM_LAT),
        "0x1017": ("ConnectX5", CONNECTX5_LAT),
        "0x101b": ("ConnectX6", CONNECTX6_LAT),
    }

    def __init__(self, node_desc, node_type, node_guid, port_guid, device_id, local_port, logger):
        self.node_desc = node_desc
        self.node_type = node_type
        self.node_guid = node_guid
        self.port_guid = port_guid
        self.local_port = local_port
        self.device_id = device_id
        self._logger = logger
        self.lid = 0
        # Optional per-device override; takes precedence over Node.info.
        self.latency = None

    def get_name(self):
        """Return the model name for this node's device id, or None if it is unknown."""
        entry = Node.info.get(self.device_id)
        return entry[0] if entry else None

    def get_delay(self):
        """Return this node's latency contribution.

        Priority: the explicit ``latency`` override, then the ``Node.info``
        default for the device id, else ``NO_LAT`` (logged at info level).
        """
        if self.latency is not None:
            return self.latency
        if self.device_id in Node.info:
            return Node.info[self.device_id][1]
        # Both values go to the logger as args so each '%' placeholder is filled.
        self._logger.info("No information for unknown device: '%s'  Assuming %.4f ns latency",
                          self.device_id, Node.NO_LAT * 1e9)
        return Node.NO_LAT


class Vertex(object):
    """One ``(port GUID alias, port number)`` endpoint of the fabric graph.

    Instances are interned through ``getVertex`` so equal endpoints share one
    object. The intern key (and ``repr``) is ``int(guid) * 1000 + port``;
    port numbers are assumed to stay below 1000.
    """
    # int key -> Vertex
    cache = {}

    @staticmethod
    def getVertex(port_guid, port_num):
        """Return the shared Vertex for this endpoint, creating it on first use."""
        scale = 1000
        key = int(port_guid) * scale + int(port_num)
        vertex = Vertex.cache.get(key)
        if vertex is None:
            vertex = Vertex(port_guid, port_num)
            Vertex.cache[key] = vertex
        return vertex

    def __init__(self, port_guid, port_num):
        self.port_guid = port_guid
        self.port_num = port_num

    def __hash__(self):
        return hash((self.port_guid, self.port_num))

    def __eq__(self, other):
        return isinstance(other, Vertex) and \
            (self.port_guid, self.port_num) == (other.port_guid, other.port_num)

    def __repr__(self):
        # guid * 1000 + port keeps "guid" and "port" readable as digit groups.
        return '{:10d}'.format(int(self.port_guid) * 1000 + int(self.port_num))
        # return "%s:%s" % (self.port_guid, self.port_num)


class Link(object):
    """Directed hop ``src --> dst`` between two Vertex endpoints, with its cable info.

    Links are interned through ``getLink`` (keyed on the endpoints' string
    form), and every Link gets a sequential ``link_id``.
    """
    # NOTE(review): all_links is not referenced by any code visible here.
    all_links = {}
    count = 0
    link_cache = {}

    @staticmethod
    def getLink(src, dst, cable):
        """Return the cached Link for ``(src, dst)``, creating it on first use.

        Only the *cable* given on the first call for a pair is kept; later calls
        return the existing Link unchanged. Only the forward direction is
        cached, so ``dst --> src`` is a separate Link.
        """
        key = str(src) + "." + str(dst)
        link = Link.link_cache.get(key)
        if link is None:
            link = Link(src, dst, cable)
            Link.link_cache[key] = link
        return link

    def __init__(self, src, dst, cable):
        self.src = src
        self.dst = dst
        # CableInfo-like object, or None when cable info was not requested.
        self.cable = cable
        self.link_id = Link.count
        Link.count += 1

    def __repr__(self):
        if self.cable is None:
            # Cable info is optional (need_cable_info=False); don't crash on it.
            return "%s --> %s (no cable info)" % (repr(self.src), repr(self.dst))
        return "%s --> %s (%s m, PN: %s)" % (
            repr(self.src), repr(self.dst), self.cable.length, self.cable.pn)

    def get_guid_and_port(self):
        """Return just the ``src --> dst`` endpoints part of the repr."""
        return "%s --> %s" % (repr(self.src), repr(self.dst))


class CableInfo(object):
    """Part number and length of one cable, plus a simple propagation-delay model."""

    # PN of optic fibers with length <= 5m
    SHORT_FIBERS = {
        "MC220731V-003", "MC2206310-003", "MC2210310-003", "MFA1A00-C003", "MFA1A00-E003", "MFA1A00-E003-TG",
        "MFA2P10-A003", "MFA7A50-C003", "MFS1S00-H003-LL", "MFS1S00-H003E", "MFS1S00-V003E", "MFS1S50-H003E",
        "MFS1S50-H003E-LL", "MFS1S50-V003E", "MFS1S90-H003E", "MFA2P10-A005", "MC2206310-005", "MC2210310-005",
        "MFA1A00-E005", "MFA1A00-E005-TG", "MFA7A20-C005", "MFA7A50-C005", "MFS1S00-H005-LL", "MFS1S00-H005E",
        "MFS1S00-V005E", "MFS1S50-H005E", "MFS1S50-H005E-LL", "MFS1S50-V005E", "MFS1S90-H005E"}
    # PNs matching these patterns get a fixed extra delay in calc_delay().
    HERCULES_1_REGEX = re.compile(r'MFS1S00-H\S\S\SE')
    HERCULES_2_REGEX = re.compile(r'MFS1S00-H0\S\S-LL')

    def __init__(self, pn='', length=0):
        self.pn = pn
        self.length = length
        # Optional explicit classification; None means "use the heuristic".
        self.fiber = None

    def is_fiber(self):
        """Return True if the cable is treated as an optical fiber.

        A truthy ``fiber`` flag forces True. Otherwise (including an explicit
        False) the heuristic applies: most fibers are longer than 5 m, and the
        shorter ones are recognised by part number.
        """
        if self.fiber:
            return True
        return self.length > 5 or self.pn in self.SHORT_FIBERS

    def pn_regex_add_delay(self, regex, delay):
        """Return *delay* if the part number matches *regex*, else 0."""
        return delay if regex.match(self.pn) else 0

    def calc_delay(self):
        """Return the cable delay in seconds: length / signal speed + PN-specific extras."""
        speed_of_light = 299792458.0  # m/s
        # Fiber: v = c / n with refractive index n = 1.60; otherwise assume copper.
        velocity = speed_of_light / 1.60 if self.is_fiber() else 230000000.0

        delay = float(self.length) / velocity
        delay += self.pn_regex_add_delay(self.HERCULES_1_REGEX, 64 * 1e-9)
        delay += self.pn_regex_add_delay(self.HERCULES_2_REGEX, 20 * 1e-9)
        return delay

    def to_tuple(self):
        """Return ``(length, pn, delay_seconds)``."""
        return self.length, self.pn, self.calc_delay()


class BaseParser(object):
    """Line-oriented parser for one file (``FILE_NAME``) inside an ibdiagnet output directory.

    Subclasses implement ``_parse_line(line)``; it receives each stripped,
    non-empty line and returns a truthy value to keep reading or a falsy value
    to stop. A missing/unreadable file is logged as an error, not raised.
    """
    FILE_NAME = ""

    def __init__(self, dir_name, logger):
        self._file_name = os.path.join(dir_name, self.FILE_NAME)
        self._logger = logger

    def _parse_line(self, line):
        # Base implementation stops after the first line.
        return None

    def parse(self):
        self._logger.info("Parsing file: %s" % self.FILE_NAME)
        try:
            with open(self._file_name, 'r') as handle:
                self._parse(handle)
        except IOError:
            self._logger.error("Failed to read file: %s" % self._file_name)

    def _parse(self, fp):
        stripped = (raw.strip() for raw in fp)
        for line in filter(None, stripped):  # skip blank lines
            if not self._parse_line(line):
                break


class DbParser(BaseParser):
    """Parse the NODES, PORTS and LINKS sections of ``ibdiagnet2.db_csv``.

    The file is read as a small state machine: a ``START_<SECTION>`` line is
    followed by one CSV header line (column names) and then data rows up to
    ``END_<SECTION>``. Parsing stops after ``END_LINKS``.

    Results:
        topology: {port-GUID alias: Node}
        links: {Vertex: Vertex}, filled in both directions for every link.
    """
    FILE_NAME = "ibdiagnet2.db_csv"
    _STAGE_NONE = 0
    _STAGE_NODES_HDR = 10
    _STAGE_NODES = 11
    _STAGE_PORTS_HDR = 20
    _STAGE_PORTS = 21
    _STAGE_LINKS_HDR = 30
    _STAGE_LINKS = 31
    _STAGE_DONE = 100

    # Column positions, filled from the section header lines.
    _node_desc_index = 0
    _node_guid_index = 0
    _node_port_guid_index = 0
    _node_local_port_index = 0
    _node_type_index = 0
    _node_device_id_index = 0

    _port_node_guid_index = 0
    _port_port_guid_index = 0
    _port_port_num_index = 0
    _port_lid_index = 0

    _link_node1_index = 0
    _link_port1_index = 0
    _link_node2_index = 0
    _link_port2_index = 0

    def __init__(self, dir_name, logger):
        super(DbParser, self).__init__(dir_name, logger)
        self._stage = self._STAGE_NONE
        self._topology = dict()
        self._links = dict()

        # stage -> handler for the next line
        self._methods = {
            self._STAGE_NONE: self._parse_general_line,
            self._STAGE_NODES_HDR: self._parse_node_header,
            self._STAGE_NODES: self._parse_node,
            self._STAGE_PORTS_HDR: self._parse_port_header,
            self._STAGE_PORTS: self._parse_port,
            self._STAGE_LINKS_HDR: self._parse_link_header,
            self._STAGE_LINKS: self._parse_link,
        }

    @property
    def topology(self):
        return self._topology

    @property
    def links(self):
        return self._links

    @staticmethod
    def _header_indices(line):
        """Map each column name of a CSV header line to its position."""
        return {attr: index for index, attr in enumerate(line.split(','))}

    def _parse_line(self, line):
        parse_method = self._methods[self._stage]
        parse_method(line)
        # Returning False tells BaseParser to stop reading the file.
        return self._stage != self._STAGE_DONE

    def _parse_general_line(self, line):
        """Outside any section: look for the start of a section we care about."""
        if line == "START_NODES":
            self._stage = self._STAGE_NODES_HDR
        elif line == "START_PORTS":
            self._stage = self._STAGE_PORTS_HDR
        elif line == "START_LINKS":
            self._stage = self._STAGE_LINKS_HDR

    def _parse_node_header(self, line):
        header = self._header_indices(line)
        self._node_desc_index = header['NodeDesc']
        self._node_guid_index = header['NodeGUID']
        self._node_device_id_index = header['DeviceID']
        self._node_port_guid_index = header['PortGUID']
        self._node_local_port_index = header['LocalPortNum']
        self._node_type_index = header['NodeType']
        self._stage += 1  # header -> data rows

    def _parse_port_header(self, line):
        header = self._header_indices(line)
        self._port_node_guid_index = header['NodeGuid']
        self._port_port_guid_index = header['PortGuid']
        self._port_port_num_index = header['PortNum']
        self._port_lid_index = header['LID']
        self._stage += 1  # header -> data rows

    def _parse_link_header(self, line):
        header = self._header_indices(line)
        self._link_node1_index = header['NodeGuid1']
        self._link_port1_index = header['PortNum1']
        self._link_node2_index = header['NodeGuid2']
        self._link_port2_index = header['PortNum2']
        self._stage += 1  # header -> data rows

    def _parse_node(self, line):
        """Create a Node per NODES row, keyed by its port-GUID alias."""
        if line == "END_NODES":
            self._stage = self._STAGE_NONE
            return

        node_attrs = line.split(',')
        node_desc = node_attrs[self._node_desc_index]
        node_guid = GUIDLookup.getGUID(node_attrs[self._node_guid_index])
        port_guid = GUIDLookup.getGUID(node_attrs[self._node_port_guid_index])
        local_port = int(node_attrs[self._node_local_port_index])
        node_type = int(node_attrs[self._node_type_index])
        # DeviceID is read as a decimal int and kept in the "0x%04x" form used by Node.info.
        device_id = "0x%04x" % int(node_attrs[self._node_device_id_index])
        node = Node(node_desc, node_type, node_guid, port_guid, device_id,
                    local_port, self._logger)
        if port_guid in self._topology:
            self._logger.warning("duplicate port-guid: %s", port_guid)
        self._topology[port_guid] = node

    def _parse_port(self, line):
        """Record the LID of each PORTS row on the matching topology node."""
        if line == "END_PORTS":
            self._stage = self._STAGE_NONE
            return
        port_attrs = line.split(',')
        raw_guid = port_attrs[self._port_port_guid_index]
        port_guid = GUIDLookup.getGUID(raw_guid)
        port_lid = int(port_attrs[self._port_lid_index])
        node = self._topology.get(port_guid)
        if node is None:
            # A port whose GUID has no NODES row cannot be attached to a node;
            # skip it instead of aborting the whole parse with a KeyError.
            self._logger.debug("port-guid %s not found in NODES section, ignoring", raw_guid)
            return
        node.lid = port_lid

    def _parse_link(self, line):
        """Register each LINKS row as a bidirectional Vertex <-> Vertex mapping."""
        if line == "END_LINKS":
            self._stage = self._STAGE_DONE
            return
        link_attrs = line.split(',')
        port_guid1 = GUIDLookup.getGUID(link_attrs[self._link_node1_index])
        port_num1 = int(link_attrs[self._link_port1_index])
        port_guid2 = GUIDLookup.getGUID(link_attrs[self._link_node2_index])
        port_num2 = int(link_attrs[self._link_port2_index])
        src = Vertex.getVertex(port_guid1, port_num1)
        dst = Vertex.getVertex(port_guid2, port_num2)
        self._links[src] = dst
        self._links[dst] = src


class FdbsParser(BaseParser):
    """Parse switch forwarding tables from ``ibdiagnet2.fdbs``.

    A line matching SWITCH_REGEX starts a new switch table; subsequent lines
    beginning with ``0x`` that match HOP_REGEX give ``<lid (hex)> : <port>``.
    Result: ``hops`` = {switch GUID alias: {lid int: output port int}}.
    """
    SWITCH_REGEX = re.compile(r'\w+\:\s+Switch\s+([0-9a-fxA-F]+)')
    HOP_REGEX = re.compile(r'(0x[0-9a-f]{4})\s+\:\s+([0-9]+)\s+.*')
    FILE_NAME = "ibdiagnet2.fdbs"

    def __init__(self, dir_name, logger):
        super(FdbsParser, self).__init__(dir_name, logger)
        self._hops = dict()
        self._curr_switch = None  # table of the switch currently being read

    def _parse_line(self, line):
        if not line.startswith('0x'):
            header = self.SWITCH_REGEX.match(line)
            if header:
                switch_name = GUIDLookup.getGUID(header.group(1))
                table = dict()
                self._curr_switch = table
                self._hops[switch_name] = table
            return True

        # Forwarding entries seen before any switch header are ignored.
        if self._curr_switch is not None:
            entry = self.HOP_REGEX.match(line)
            if entry:
                self._curr_switch[int(entry.group(1), 16)] = int(entry.group(2))
        return True

    @property
    def hops(self):
        return self._hops


class CableParser(BaseParser):
    """Parse cable length / part number per port from ``ibdiagnet2.cables``.

    A ``Port=<n> Lid=<lid> GUID=<guid> ...`` line starts a new CableInfo for
    that port; following ``Length:`` and ``PN:`` lines fill it in. Those lines
    are ignored if no port line has been seen yet.
    Result: ``cables`` = {GUID alias: {port number: CableInfo}}.
    """
    FILE_NAME = "ibdiagnet2.cables"
    SWITCH_REGEX = re.compile(r'Port=(\d+)\s+Lid=(\w+)\s+GUID=(\w+)\s+.*')
    LENGTH_REGEX = re.compile(r'Length\:\s+(\d+).*')
    PN_REGEX = re.compile(r'PN\:\s+(\S+).*')

    def __init__(self, dir_name, logger):
        super(CableParser, self).__init__(dir_name, logger)
        self._cables = dict()
        self._curr_cable = None  # CableInfo currently being filled

    def _parse_line(self, line):
        # The checks are ordered: Length, then PN, then the port header.
        length_match = self.LENGTH_REGEX.match(line)
        if length_match:
            if self._curr_cable is not None:
                self._curr_cable.length = int(length_match.group(1))
            return True

        pn_match = self.PN_REGEX.match(line)
        if pn_match:
            if self._curr_cable is not None:
                self._curr_cable.pn = pn_match.group(1)
            return True

        port_match = self.SWITCH_REGEX.match(line)
        if port_match:
            curr_port = int(port_match.group(1))
            switch_name = GUIDLookup.getGUID(port_match.group(3))
            per_port = self._cables.setdefault(switch_name, dict())
            cable = CableInfo()
            per_port[curr_port] = cable
            self._curr_cable = cable
            self._logger.debug("Gathering Cable info for: %s-%s", switch_name,
                               curr_port)
        # Never stop early: every line of the file is scanned.
        return True

    @property
    def cables(self):
        return self._cables


class NodeInfoParser(BaseParser):
    """Map node names to GUID aliases using ``ibdiagnet2.nodes_info``.

    A ``Node Name=<name>`` line selects the current node (only the first
    whitespace-delimited token is captured), and a following ``GUID=<guid>``
    line supplies its GUID. Result: ``node_mapping`` = {name: GUID alias}.
    """
    FILE_NAME = "ibdiagnet2.nodes_info"
    NODE_REGEX = re.compile(r'Node Name=(\S+)')
    GUID_REGEX = re.compile(r'GUID=(\w+)')

    def __init__(self, dir_name, logger):
        super(NodeInfoParser, self).__init__(dir_name, logger)
        self._node_mapping = dict()
        self._curr_node = None

    def _parse_line(self, line):
        match = self.NODE_REGEX.match(line)
        if match:
            self._curr_node = match.group(1)
            return True
        match = self.GUID_REGEX.match(line)
        if match:
            if self._curr_node is None:
                # A GUID with no preceding node name cannot be attributed.
                self._logger.debug("GUID line before any node name, ignoring: %s", line)
            else:
                # Store the same alias the other parsers use (topology keys are
                # GUIDLookup aliases), so the value can actually be looked up.
                self._node_mapping[self._curr_node] = GUIDLookup.getGUID(match.group(1))
        return True

    @property
    def node_mapping(self):
        return self._node_mapping


class LatencyCalculator(object):
    """Estimate the fabric latency between two named nodes of an ibdiagnet dump.

    ``load()`` reads the ibdiagnet2.* files in *dir_name*. A path is then
    rebuilt hop by hop from the source HCA by following the switch forwarding
    tables towards the destination LID. Its latency is the sum of the cable,
    switch and HCA delays along that path.

    Args:
        dir_name: directory containing the ibdiagnet2.* files.
        fibers_pn: optional collection of cable part numbers; when not None,
            every cable's ``fiber`` flag is set to ``pn in fibers_pn``.
        latencies: optional {device id: latency} overrides (same units as
            ``Node.get_delay``).
        logger: Logger-like object (info/debug/warning/error).
        per_hca: True when node names carry an HCA suffix (see
            ``get_node_name_mapping``).
    """

    def __init__(self, dir_name, fibers_pn, latencies, logger, per_hca=False):
        self._dir_name = dir_name
        self._logger = logger
        self._topology = None
        self._links = None
        self._hops = None
        self._cables = None
        self._per_hca = per_hca
        self._node_name_mapping = dict()
        self._fibers_pn = fibers_pn
        self._dev_id_to_lat = latencies

    def load(self):
        """Parse all ibdiagnet files; must run before any latency query."""
        self._load_db()
        self._load_fdbs()
        self._load_cables()
        self._load_nodes_info()

    def get_latency(self, src_node_name, dst_node_name):
        """Return the estimated src -> dst latency, scaled by 1e6.

        Component delays are in seconds (see ``CableInfo.calc_delay``), so the
        result is presumably in microseconds.
        """
        cables, switches, hcas = self.get_latency_info(src_node_name, dst_node_name)

        # Every tuple carries its delay as the last element.
        lat = 0.0
        for group in (cables, switches, hcas):
            for entry in group:
                lat += entry[-1]

        return lat * 1000000.0

    def get_cables_tuples(self, node_path):
        """Return ``(length, pn, delay)`` for the cable of every link in the path."""
        return [link.cable.to_tuple() for link in node_path]

    def get_switches_cables(self, node_path):
        """Return ``(port GUID, device id, delay)`` for each intermediate node of the path.

        The last link ends at the destination HCA, so it is excluded.
        """
        res = []
        for link in node_path[:-1]:
            node = self._topology[link.dst.port_guid]
            res.append((node.port_guid, node.device_id, node.get_delay()))
        return res

    def get_hca_tuple(self, v):
        """Return ``(port GUID, device id, delay)`` for the node owning vertex *v*."""
        node = self._topology[v.port_guid]
        return node.port_guid, node.device_id, node.get_delay()

    def get_hcas_tuples(self, node_path):
        """Return the tuples of the source and destination HCAs of the path."""
        first = self.get_hca_tuple(node_path[0].src)
        last = self.get_hca_tuple(node_path[-1].dst)
        return [first, last]

    def get_latency_info(self, src_node_name, dst_node_name):
        """Return ``(cables, switches, hcas)`` delay tuples along the src -> dst path."""
        node_path = self.get_path_tokens(src_node_name, dst_node_name)
        cables = self.get_cables_tuples(node_path)
        switches = self.get_switches_cables(node_path)
        hcas = self.get_hcas_tuples(node_path)
        return cables, switches, hcas

    def get_node_name_mapping(self, node_name):
        """Resolve a (possibly decorated) node name to its GUID alias, or None."""
        guid = self._node_name_mapping.get(node_name, None)
        if guid is None:
            # sometimes the node name is longer than the keys in self._node_name_mapping
            resolved = node_name.partition('.')[0]
            guid = self._node_name_mapping.get(resolved, None)
        if guid is None:
            # when running with --bycore, a "-XXX" postfix is added to node name (core number)
            if self._per_hca:
                idx = node_name.find(' HCA-')
                if idx >= 5:
                    resolved = node_name[:idx - 4] + node_name[idx:]
                    guid = self._node_name_mapping.get(resolved, None)
            else:
                if len(node_name) > 4:
                    resolved = node_name[:-4]
                    guid = self._node_name_mapping.get(resolved, None)
        return guid

    def get_path_tokens(self, src_node_name, dst_node_name):
        """Return the list of Links from src to dst.

        Raises:
            LatencyException: if either name cannot be resolved.
        """
        src_guid = self.get_node_name_mapping(src_node_name)
        dst_guid = self.get_node_name_mapping(dst_node_name)
        if not src_guid:
            raise LatencyException("Could not find %s in topology" % src_node_name)
        if not dst_guid:
            raise LatencyException("Could not find %s in topology" % dst_node_name)
        return self._calc_path(src_guid, dst_guid, need_cable_info=True)

    def is_per_hca(self):
        return self._per_hca is True

    def _load_fdbs(self):
        parser = FdbsParser(self._dir_name, self._logger)
        parser.parse()
        self._hops = parser.hops

    def _load_db(self):
        """Load topology and links, build name -> GUID mapping, apply latency overrides."""
        parser = DbParser(self._dir_name, self._logger)
        parser.parse()
        self._topology = parser.topology
        self._links = parser.links
        for node_obj in self._topology.values():
            node_name = node_obj.node_desc.replace('"', '')
            if not self._per_hca:
                # take node name only (first word of the description)
                parts = node_name.split()
                node_name = parts[0] if parts else ''
            if node_name:
                # A later node with the same name overrides an earlier one.
                self._node_name_mapping[node_name] = node_obj.port_guid
            else:
                self._logger.debug("Node with port guid %s has an empty description",
                                   node_obj.port_guid)
            device_id = node_obj.device_id
            if self._dev_id_to_lat and device_id in self._dev_id_to_lat:
                node_obj.latency = self._dev_id_to_lat[device_id]

    def _load_nodes_info(self):
        parser = NodeInfoParser(self._dir_name, self._logger)
        parser.parse()
        self._node_name_mapping.update(parser.node_mapping)

    def _load_cables(self):
        parser = CableParser(self._dir_name, self._logger)
        parser.parse()
        self._cables = parser.cables
        if self._fibers_pn is None:
            return
        # An explicit fiber list classifies every cable by part number.
        for per_port in self._cables.values():
            for cable_info in per_port.values():
                cable_info.fiber = cable_info.pn in self._fibers_pn

    def _get_cable_info(self, vertex):
        """Return the CableInfo of *vertex*, or a default (zero-length) one if unknown."""
        cable_info = self._cables.get(vertex.port_guid, dict()).get(vertex.port_num)
        if cable_info is None:
            # Without a default the path's cable tuple would crash on None.
            self._logger.debug("No cable info for port %s, assuming zero cable delay", vertex)
            cable_info = CableInfo()
        return cable_info

    def _append_path(self, node_path, switch_guid, dst_lid, need_cable_info=True):
        """Follow forwarding tables from *switch_guid* until the node with *dst_lid* is reached.

        Raises:
            LatencyException: if a switch is visited twice (routing loop).
        """
        visited = set()
        while True:
            if switch_guid in visited:
                raise LatencyException("Routing loop towards LID %d detected at switch %s"
                                       % (dst_lid, switch_guid))
            visited.add(switch_guid)
            port_num = self._hops[switch_guid][dst_lid]
            src_vertex = Vertex.getVertex(switch_guid, port_num)
            dst_vertex = self._links[src_vertex]
            cable_info = self._get_cable_info(dst_vertex) if need_cable_info else None
            dst_node = self._topology[dst_vertex.port_guid]
            node_path.append(Link.getLink(src_vertex, dst_vertex, cable_info))
            if dst_node.lid == dst_lid:
                return
            switch_guid = dst_node.port_guid

    def _calc_path(self, src_port_guid, dst_port_guid, need_cable_info=True):
        """Return the list of Links from the source HCA port to the destination HCA.

        Raises:
            LatencyException: if either port GUID is missing from the topology.
        """
        src_node_obj = self._topology.get(src_port_guid)
        if src_node_obj is None:
            raise LatencyException("Port GUID %s not found in topology" % src_port_guid)
        dst_node_obj = self._topology.get(dst_port_guid)
        if dst_node_obj is None:
            raise LatencyException("Port GUID %s not found in topology" % dst_port_guid)

        # First hop: out of the source HCA's local port to whatever it is cabled to.
        src_vertex = Vertex.getVertex(src_node_obj.port_guid, src_node_obj.local_port)
        dst_vertex = self._links[src_vertex]
        cable_info = self._get_cable_info(dst_vertex) if need_cable_info else None
        node_path = [Link.getLink(src_vertex, dst_vertex, cable_info)]

        self._append_path(node_path, dst_vertex.port_guid, dst_node_obj.lid, need_cable_info)
        verbose = False
        if verbose:
            for index, link in enumerate(node_path, 1):
                self._logger.info("hop #%d: %s", index, link)
        return node_path


class PrettyFloat(float):
    """A float whose ``repr()`` is limited to 15 significant digits."""

    def __repr__(self):
        return format(self, '.15g')


class NormalizePathLatency:
    """Subtract the estimated fabric latency from a clusterkit latency matrix.

    The JSON file is expected to hold ``"Nodes"`` (node name -> matrix index)
    and ``"Links"`` (rows addressed as ``Links[x][y - x]`` for ``x < y``).
    ``process()`` rewrites the ``Links`` values in place and ``save()`` writes
    the result next to the input as ``<name>_normalized.json``.
    """

    def __init__(self, filename, logger, latency_factor=0.85):
        self.filename = filename
        self._logger = logger
        self.vals = {}
        self.json = None
        # Fraction of the estimated fabric latency that is subtracted.
        self.factor = latency_factor
        self.load()

    def load(self):
        """Read the JSON file into ``self.json``; log and return None on failure."""
        try:
            with open(self.filename) as json_data:
                d = json.load(json_data)
        except (OSError, ValueError):
            self._logger.error("Failed to parse json file: %s" % self.filename)
            return None
        self.json = d
        return d

    def add_per_hca_data(self, nodes):
        """Rewrite ``name-XXX`` node names (in place) as ``name HCA-<XXX+1>``.

        When clusterkit runs per HCA it appends a 3-digit ``-XXX`` rank to each
        node name.

        Raises:
            LatencyException: if a name does not end in 3 digits.
        """
        for k, name in nodes.items():
            try:
                node_rank = int(name[-3:])
            except (TypeError, ValueError) as e:
                msg = "Cannot read hca info"
                self._logger.error(msg)
                raise LatencyException(msg) from e
            name = name.split('.')[0]
            nodes[k] = name + ' HCA-' + str(node_rank + 1)
        return nodes

    def get_value(self, x, y, links):
        """Return the matrix value for the unordered pair (x, y); 0.0 on the diagonal."""
        if x == y:
            return 0.0
        if x > y:
            x, y = y, x
        return links[x][y - x]

    def set_value(self, x, y, links, val):
        """Store *val* for the unordered pair (x, y); the diagonal is never written."""
        if x == y:
            self._logger.warning("should not get here...")
            return
        if x > y:
            x, y = y, x
        links[x][y - x] = val

    def get_data(self, n):
        if n not in self.vals:
            self.vals[n] = {}
        return self.vals[n]

    def _normalized_filename(self):
        """Return the output path: ``<root>_normalized.json``.

        Only the file extension is replaced (a ``.json`` elsewhere in the path
        is left alone), and a name without a ``.json`` extension gets the
        suffix appended, so the input file is never overwritten.
        """
        root, ext = os.path.splitext(self.filename)
        if ext.lower() == ".json":
            return root + "_normalized.json"
        return self.filename + "_normalized.json"

    def save(self):
        """Write the (normalized) JSON next to the input file and return its path."""
        dest_file = self._normalized_filename()
        with open(dest_file, "w") as normalized_file:
            normalized_file.write(simplejson.dumps(self.json, indent=4, sort_keys=True))
        return dest_file

    def process(self, latency_calculator):
        """Subtract ``factor`` x the mean of the two one-way path latencies from every pair.

        Raises:
            LatencyException: if no JSON data was loaded.
        """
        if self.json is None:
            raise LatencyException("No latency data loaded from %s" % self.filename)
        links = self.json["Links"]
        num_items = len(self.json["Nodes"])
        # "Nodes" maps name -> index; invert it to index -> name.
        nodes = {idx: name for name, idx in self.json["Nodes"].items()}
        if latency_calculator.is_per_hca():
            nodes = self.add_per_hca_data(nodes)

        for i in range(num_items):
            n1 = nodes[i]
            for j in range(i):
                n2 = nodes[j]
                val = self.get_value(i, j, links)
                # Account for the fact that there are 2 paths being traversed, n1 -> n2,   n2 -> n1,
                # and they are not symmetric
                lat = latency_calculator.get_latency(n1, n2)
                lat2 = latency_calculator.get_latency(n2, n1)
                delta = (self.factor * (lat + lat2)) / 2.0
                normalized_val = max(0, val - delta)
                self.set_value(i, j, links, normalized_val)

                self._logger.debug("%s ->  %s   %.4f  %.4f %.4f ", n1, n2, delta, lat, lat2)


class PDFWriter(FPDF):
    """A4 report writer on top of FPDF: Courier fonts and a centered page-number footer.

    The first page is added on construction.
    """

    def __init__(self):
        super().__init__()
        self.add_page()
        self.WIDTH = 210   # A4 SIZE in mm
        self.HEIGHT = 297  # A4 SIZE in mm
        self.MARGINS = 20  # total horizontal margin in mm: left = right = 10
        self.FONT = 'Courier'
        self.FONT_SIZE = 10
        self.FONT_SIZE_TITLE = 30
        self.POINT_TO_MM = 0.3527777778

    def footer(self):
        # This method is used to render the page footer.
        self.set_y(-15)  # Go to 1.5 cm from bottom
        self.set_font(self.FONT, 'I', self.FONT_SIZE)
        self.cell(0, 10, f'Page {self.page_no()}', 0, 0, 'C')  # Print centered page number

    def add_title(self, title):
        """Write *title* as a large bold centered line."""
        self.set_font(self.FONT, 'B', self.FONT_SIZE_TITLE)
        self.cell(self.WIDTH - 30, 7, title, 0, 1, 'C')

    def add_small_title(self, title):
        """Write *title* as a bold left-aligned line in the body font size."""
        self.set_font(self.FONT, 'B', self.FONT_SIZE)
        self.cell(self.WIDTH - 30, 7, title, 0, 1, 'L')

    def add_content(self, text):
        """Write *text* as a wrapped, left-aligned paragraph; empty text is skipped."""
        if not text:
            return
        self.set_font(self.FONT, size=self.FONT_SIZE)
        # Wrap within the printable width (page width minus both margins);
        # a full page width starting at the left margin runs off the page.
        # 'L' is the documented left-alignment value (0 is not an alignment).
        self.multi_cell(self.WIDTH - self.MARGINS, 5, text, 0, 'L')
        self.ln(3)

    def image(self, image):
        """Place PNG *image* near the top of the page, over half the page height."""
        super().image(image, x=10, y=25, w=self.WIDTH - 15, h=self.HEIGHT / 2, type='PNG')


class CKWriter:
    """Build the ClusterKit PDF report from the JSON/PNG files of one directory.

    :param writer: PDFWriter-like object used for all rendering.
    :param report_args: parsed arguments; only ``report_args.dir`` is used.
    :param lat: base name (no suffix) of the latency files.
    :param bw: base name (no suffix) of the bandwidth files.
    :param norm: base name (no suffix) of the normalized-latency files.
    """

    # Header keys that are not shown in the "general" section of the report.
    _HIDDEN_KEYS = ('Testname', 'Links', 'Units', 'JOBID', 'Nodes')

    def __init__(self, writer, report_args, lat='latency', bw='bandwidth', norm='latency_normalized'):
        self.writer = writer
        self.report_args = report_args
        self._logger = get_logger()
        self.lat_json = None
        self.lat_png = None
        self.bw_json = None
        self.bw_png = None
        self.norm_json = None
        self.norm_png = None
        self.lat = None
        self.bw = None
        self.norm = None
        self.set_files_names(lat, bw, norm)

    def set_files_names(self, lat, bw, norm):
        """Derive the JSON and PNG paths of every artifact from its base name."""
        JSON_SUFFIX = '.json'
        PNG_SUFFIX = '.png'
        directory = self.report_args.dir
        self.lat_json = os.path.join(directory, lat + JSON_SUFFIX)
        self.lat_png = os.path.join(directory, lat + PNG_SUFFIX)
        self.bw_json = os.path.join(directory, bw + JSON_SUFFIX)
        self.bw_png = os.path.join(directory, bw + PNG_SUFFIX)
        self.norm_json = os.path.join(directory, norm + JSON_SUFFIX)
        self.norm_png = os.path.join(directory, norm + PNG_SUFFIX)

    def parse_json_content(self, file_name):
        """Read a result JSON file and render its header for the report.

        :return: ``(general_text, nodes_text)`` on success, or ``False`` when
                 the file cannot be opened or decoded.
        """
        try:
            with open(file_name) as fp:
                data = json.load(fp)
        except (OSError, ValueError):
            self._logger.warning(f'Could not open {file_name}')
            return False

        nodes = data.get('Nodes', {})
        # Missing keys are tolerated: not every producer writes all of them.
        for key in self._HIDDEN_KEYS:
            data.pop(key, None)
        if data.get('HCA_Tag') == 'Unknown':
            del data['HCA_Tag']
        new_string = json.dumps(data, indent=0)
        # Strip JSON punctuation to get a plain "key: value" listing.
        s = re.sub('[\"{},]', '', new_string)
        s = s.replace('_', ' ')
        return s, self.nodes_string(nodes)

    def nodes_string(self, dict):
        """Lay the ``{node_name: index}`` mapping out as a multi-column table.

        Each entry is ``"<index> <name>"`` padded to a common width and
        followed by three tabs.  As many entries as fit into the printable
        page width go on each line, and every line ends with ``'\\n'``.
        An empty mapping yields ``''``.
        """
        nodes = dict  # parameter name kept for callers; avoid using the builtin name below
        if not nodes:
            return ''
        space = '\t' * 3
        longest_name = max(nodes, key=len)
        longest_ind = str(len(nodes) - 1)
        pattern = longest_name + ' ' + longest_ind + space

        # Measure with the body font actually used by add_content().
        # get_string_width() already returns the width in document units (mm).
        self.writer.set_font(self.writer.FONT, size=self.writer.FONT_SIZE)
        entry_width = self.writer.get_string_width(pattern)
        width_page = self.writer.WIDTH - self.writer.MARGINS
        num_col = max(1, int(width_page // entry_width)) if entry_width > 0 else 1

        value_padding = len(longest_ind)
        name_padding = len(longest_name)
        entries = [str(value).ljust(value_padding) + ' ' + key.ljust(name_padding) + space
                   for key, value in nodes.items()]
        # Group the entries into rows; the last (possibly partial) row is kept too.
        rows = [''.join(entries[start:start + num_col])
                for start in range(0, len(entries), num_col)]
        return ''.join(row + '\n' for row in rows)

    def analyzer(self, file_name):
        """Run BWAnalyzer on *file_name* and add its statistics to the report."""
        if not os.path.exists(file_name):
            self._logger.warning(f'Could not find {file_name}, skipping the analysis')
            return
        analyzer = BWAnalyzer(file_name, None, None, None, None, False)
        analyzer.run_analysis()

        bw_dic, noise_dic, latency_dic = analyzer.get_analysis()

        self.writer.add_small_title("Examined nodes:")
        self.writer.add_content(bw_dic['$examined_nodes'])

        self.writer.add_small_title("Bandwidth:")
        self.writer.add_content(bw_dic['$bandwidth_statistics'])
        self.writer.add_content(bw_dic['$bad_node_msg'])
        self.writer.add_content(bw_dic['$extreme_bad_node_msg'])

        if noise_dic:
            self.writer.add_small_title("Noise:")
            self.writer.add_content(noise_dic['$statistics'])
        if latency_dic:
            self.writer.add_small_title("Latency:")
            self.writer.add_content(latency_dic['$statistics'])

    def add_heatmap(self, png_file):
        """Add a new page titled after *png_file* and place the heatmap on it.

        A missing PNG is reported and skipped instead of aborting the report.
        """
        if not os.path.exists(png_file):
            self._logger.warning(f'Could not find heatmap {png_file}')
            return
        name = os.path.basename(png_file).replace('.png', '')
        self.writer.add_page()
        self.writer.add_title(name.capitalize() + ' Heatmap')
        self.writer.image(png_file)

    def run(self):
        """Render the whole report and write ``report.pdf`` into the report dir."""
        self.writer.set_font('Arial', 'B', self.writer.FONT_SIZE)
        self.writer.add_title('ClusterKit Report')

        # Prefer the latency header and fall back to the bandwidth one.
        content = self.parse_json_content(self.lat_json)
        if content is False:
            content = self.parse_json_content(self.bw_json)
            if content is False:
                self._logger.error('failed to open Latency or Bandwidth file')
                return
        general, nodes = content

        self.writer.add_content(general)
        self.analyzer(self.bw_json)
        self.add_heatmap(self.lat_png)
        self.add_heatmap(self.bw_png)
        self.writer.add_page()
        self.writer.add_small_title('Nodes:')
        self.writer.add_content(nodes)

        self.writer.output(os.path.join(self.report_args.dir, 'report.pdf'))


class ReportWrapper:
    """Entry point of the ``--report`` mode: turn a result directory into a PDF."""
    REPORT_PARSER = None

    def __init__(self):
        self._logger = get_logger()

    @classmethod
    def print_help(cls):
        """Print the usage of the report mode."""
        cls.get_parser().print_help()

    @classmethod
    def get_parser(cls):
        """Return the report-mode argument parser, creating and caching it on first use."""
        if cls.REPORT_PARSER is not None:
            return cls.REPORT_PARSER
        parser = argparse.ArgumentParser(add_help=False, description="script that generate PDF output file")
        parser.add_argument('-p', '--report', required=False, action='store_true', help='creating PDF report')
        parser.add_argument('-d', '--dir', required=True, help='directory with bandwidth/ latency JSON files')
        cls.REPORT_PARSER = parser
        return parser

    def run(self):
        """Parse the report options and render the PDF report."""
        known_args, _unknown = self.get_parser().parse_known_args()
        CKWriter(PDFWriter(), known_args).run()


class NormalizeWrapper:
    """Wrapper that runs the latency path-length normalization.

    Options come from the ``[normalize_latency]`` INI section and from the
    command line.  A LatencyCalculator is built from the ibdiagnet output, and
    the latency JSON file is normalized and saved through NormalizePathLatency.
    """
    IBD_FILES = ["ibdiagnet2.db_csv", "ibdiagnet2.fdbs", "ibdiagnet2.cables", "ibdiagnet2.nodes_info"]
    INI_DESCRIPTION = """
INI file content:
    The INI configuration file should consist a [normalize_latency] section with the following arguments:
        config_file   a CSV file with devices and latencies in ns or short optical fibers.
                      See latency_calc_config.csv for an example."""
    NORMALIZE_PARSER = None

    def __init__(self, file):
        self._args = dict_from_ini(file, 'normalize_latency')
        self._logger = get_logger(self._args)

    @classmethod
    def get_parser(cls):
        """Return the normalize-mode argument parser (created once, then cached)."""
        if cls.NORMALIZE_PARSER is None:
            cls.NORMALIZE_PARSER = argparse.ArgumentParser(add_help=False, description="normalize latency JSON file")
            cls.NORMALIZE_PARSER.add_argument('-f', '--config', required=True, help='configuration INI file')
            cls.NORMALIZE_PARSER.add_argument('-j', '--json', required=True, help='latency JSON file')
            cls.NORMALIZE_PARSER.add_argument('-d', '--ibdiag-dir', required=True,
                                              help=f'directory with the following files: {", ".join(cls.IBD_FILES)}')
            cls.NORMALIZE_PARSER.add_argument('-n', '--normalize', required=True, action='store_true',
                                              help='normalize latency JSON file')
            cls.NORMALIZE_PARSER.add_argument('-p', '--per-hca', required=False, action='store_true',
                                              help='use if JSON data is per HCA')
        return cls.NORMALIZE_PARSER

    @classmethod
    def print_help(cls):
        """Print the normalize-mode usage followed by the INI description."""
        parser = cls.get_parser()
        parser.print_help()
        print(cls.INI_DESCRIPTION)

    def file_is_good(self, file):
        """Return True when *file* is readable; log an error otherwise."""
        res = os.access(file, os.R_OK)
        if res is False:
            self._logger.error("Cannot open file: %s. Please check that the path is absolute" % file)
        return res

    def files_are_good(self, ibdiag_dir, files_to_check):
        """Return True when every file in *files_to_check* is readable in *ibdiag_dir*.

        Every file is checked (no short-circuit) so that all problems are logged.
        """
        res = True
        for f in files_to_check:
            if not self.file_is_good(os.path.join(ibdiag_dir, f)):
                res = False
        return res

    def run(self):
        """Normalize the latency JSON file; exit with status -1 on missing inputs."""
        parser = self.get_parser()
        normalize_args, _unknown = parser.parse_known_args()

        fibers = None
        latencies = None
        config_file = ini_line_to_string(self._args, "config_file")
        if config_file:
            if not self.file_is_good(config_file):
                sys.exit(-1)
            config_parser = CSVConfigFileParser(config_file, self._logger)
            latencies, fibers = config_parser.get()

        latency_calc = None
        ibdiag_dir = normalize_args.ibdiag_dir
        if os.path.exists(ibdiag_dir):
            self._logger.debug("ib diagnet files from: %s" % ibdiag_dir)
            if not self.files_are_good(ibdiag_dir, self.IBD_FILES):
                sys.exit(-1)
            per_hca = bool(normalize_args.per_hca)
            self._logger.info("per_hca = %s", per_hca)
            latency_calc = LatencyCalculator(ibdiag_dir, fibers, latencies, self._logger, per_hca)
            latency_calc.load()

        if latency_calc is None:
            self._logger.error("could not find required files in ibdiagnet dir")
            sys.exit(-1)

        file_name = normalize_args.json
        if not os.path.exists(file_name):
            # Previously a silent no-op; report it like the other missing inputs.
            self._logger.error("Cannot open latency file: %s" % file_name)
            sys.exit(-1)

        self._logger.debug("read latency file: %s" % file_name)
        n = NormalizePathLatency(file_name, self._logger)
        n.process(latency_calc)
        normalized_file = n.save()
        self._logger.info("path-length normalized file written to: %s" % normalized_file)


# normalization classes end here
# rendering classes start here

class Legend:
    """Seven-color scale used to paint heatmap cells, plus its legend image.

    Six thresholds (``self.vals``) split the value axis into seven bins.  By
    default they sit at ``avg`` +/- {0.5, 1, 2} * ``sigma``.  A user-supplied
    ``color_scale`` overrides them (see :meth:`set_color_scale`).  When the
    override cannot be parsed, the gaussian scale is kept and a warning is
    logged.
    """

    def __init__(self, logger, avg, sigma, color_scale=None, profile_name="Gaussian"):
        self.font = ImageFont.load_default()
        red = [185, 52, 52]
        med_red = [230, 103, 103]
        pink = [255, 190, 190]
        gray = [172, 172, 172]
        lt_green = [190, 255, 190]
        green = [120, 228, 120]
        dk_green = [56, 193, 56]
        self._logger = logger
        # Entry i is the color of bin i (values in ascending order).
        self.higher_better_colors = [red, med_red, pink, gray, lt_green, green, dk_green]
        self.lower_better_colors = [dk_green, green, lt_green, gray, pink, med_red, red]

        self.vals = self.set_gaussian_color_scale(avg, sigma)
        self.profile_name = profile_name
        if color_scale is not None:
            try:
                self.vals = self.set_color_scale(avg, sigma, color_scale)
            except Exception:
                # Best effort: any failure while evaluating the user scale falls
                # back to the gaussian scale instead of aborting the rendering.
                self._logger.warning("Could not parse thresholds, setting gaussian color scale")
                self.vals = self.set_gaussian_color_scale(avg, sigma)

        self.rev_range = range(len(self.vals) - 1, -1, -1)
        self.range = range(len(self.vals))

    def set_gaussian_color_scale(self, avg, sigma):
        """Set and return thresholds at avg -2, -1, -0.5, +0.5, +1, +2 sigma."""
        self.vals = [avg - 2.0 * sigma,
                     avg - 1.0 * sigma,
                     avg - 0.5 * sigma,
                     avg + 0.5 * sigma,
                     avg + 1.0 * sigma,
                     avg + 2.0 * sigma]
        return self.vals

    def set_color_scale(self, avg, sigma, color_scale):
        """Return the six thresholds described by *color_scale*.

        *color_scale* is either an already valid list of numbers (returned
        unchanged) or an expression string.  The string is evaluated with
        ``avg``/``sigma`` and the aliases below in scope.  If that fails, it is
        retried with newlines turned into commas.
        """
        # Aliases that threshold expressions may reference (seen by eval()).
        Average = mean = average = avg  # noqa: F841
        STD = std = sigma  # noqa: F841
        res = color_scale
        if not RenderWrapper.valid_thresholds(color_scale):
            # NOTE(security): eval() runs the configured expression with this
            # function's locals; thresholds must come from a trusted INI file.
            try:
                thresholds = eval(color_scale)
            except Exception:
                color_scale = color_scale.replace('\n', ',')
                thresholds = eval(color_scale)
            thresholds = RenderWrapper.to_list(thresholds)
            res = RenderWrapper.set_thresholds(thresholds)
        return res

    def get_color(self, v, higher_better):
        """Return the RGB list for value *v*; 0.0 (no data) maps to BORDER_COLOR."""
        if v == 0.0:
            return BORDER_COLOR

        if higher_better:
            if v < self.vals[0]:
                return self.higher_better_colors[0]
            if v > self.vals[-1]:
                return self.higher_better_colors[-1]
            for i in self.rev_range:
                if v >= self.vals[i]:
                    return self.higher_better_colors[i + 1]
            return self.higher_better_colors[-1]

        for i in self.range:
            if v < self.vals[i]:
                return self.lower_better_colors[i]
        return self.lower_better_colors[-1]

    def _render_scale(self, d, colors, first_label, label_step, units):
        """Draw one swatch per color from top to bottom.

        The bottom edge of each swatch is labelled with ``self.vals[k]``, where
        k starts at *first_label* and moves by *label_step* until it leaves the
        valid index range.  Swatches past that point get no label.
        """
        black = (0, 0, 0)
        x0, x1 = 20, 100
        txt_x = 120
        k = first_label
        y_loc = 50
        for col in colors:
            if 0 <= k < len(self.vals):
                legend_val = "%5.2f %s" % (self.vals[k], units)
                k += label_step
            else:
                legend_val = None
            y0 = y_loc
            y1 = y0 + 30
            d.rectangle([(x0, y0), (x1, y1)], fill=tuple(col), outline=black)
            if legend_val:
                d.text((txt_x, y1), legend_val, font=self.font, fill=(0, 0, 0))
            y_loc += 40

    def render_higher_better(self, d):
        """Draw the bandwidth scale: low (red) at the top, ascending labels."""
        self._render_scale(d, self.higher_better_colors, 0, 1, "MB/sec")

    def render_lower_better(self, d):
        """Draw the latency scale: high (red) at the top, descending labels."""
        self._render_scale(d, list(reversed(self.lower_better_colors)),
                           len(self.vals) - 1, -1, "usec")

    def get_as_image(self, higher_better):
        """Return a 260x400 RGB image with the profile name and the color scale."""
        legend_width = 260
        legend_height = 400

        img = Image.new('RGB', (legend_width, legend_height), color=(255, 255, 255))
        d = ImageDraw.Draw(img)
        d.text((10, 10), self.profile_name, font=self.font, fill=(0, 0, 0))

        if higher_better:
            self.render_higher_better(d)
        else:
            self.render_lower_better(d)

        return img


class ImageGenerator:
    """Render one bandwidth/latency JSON result as a PNG heatmap with a legend.

    The JSON holds ``"Nodes"`` (name -> rank) and ``"Links"`` (see
    :meth:`value`).  The PNG is written next to the JSON file with the same
    base name.  Files whose base name contains "bandwidth" are rendered as
    higher-is-better; all others as latency (lower-is-better).
    """

    def __init__(self, filename, logger):
        self.filename = filename
        self._logger = logger
        # Same base name with a ".png" suffix (e.g. run.json -> run.png).
        self.imgname = os.path.splitext(self.filename)[0] + ".png"
        # Decide on the file's own name only: a directory such as
        # ".../bandwidth_runs/latency.json" must still render as latency.
        self.higherisbetter = "bandwidth" in os.path.basename(self.filename).lower()
        self.title = "Bandwidth" if self.higherisbetter else "Latency"
        self.init()

    def init(self):
        """Load the JSON file into ``self.data``; exit with an error status on bad JSON."""
        with open(self.filename) as data_file:
            try:
                self.data = json.load(data_file)
            except ValueError:
                self._logger.error("json load failed. File not .json or formatted incorrectly.")
                # Non-zero status: nothing was rendered.
                sys.exit(-1)

    def value(self, x, y):
        """Return the value between two nodes, given by name (str) or by rank.

        Only pairs with a < b are stored, at ``Links[a][b - a]``.  The
        arguments are swapped when needed, and the diagonal is reported as 0.0.
        """
        a = self.data["Nodes"][x] if isinstance(x, str) else x
        b = self.data["Nodes"][y] if isinstance(y, str) else y
        if a == b:
            return 0.0
        if a > b:
            a, b = b, a
        return self.data["Links"][a][b - a]

    def add_all_ranks(self, draw, n_nodes, radius, font_small, heatmap_offset):
        """Label every row and column with its rank (used for small matrices)."""
        pos = heatmap_offset + (radius // 4)
        for i in range(n_nodes):
            if i == 10:
                # adjust to 2 digits number size
                pos -= 3
            rank = str(i)
            # row and column
            draw.text((pos, 5), rank, font=font_small, fill=(0, 0, 0))
            draw.text((5, pos), rank, font=font_small, fill=(0, 0, 0))
            pos += radius

    def add_n_ranks(self, draw, n_nodes, radius, font, last_pos, heatmap_offset, n=20):
        """Label about *n* evenly spaced ranks (used for large matrices)."""
        jump = (n_nodes + (n // 2)) // n
        pixel_jump = jump * radius
        pos = heatmap_offset - 5
        rank = 0
        prev_len = 1

        # print 0 once diagonal to heatmap
        draw.text((heatmap_offset - 10, heatmap_offset - 10), '0', font=font, fill=(0, 0, 0))

        for _ in range(1, n):
            pos += pixel_jump
            rank += jump
            s = str(rank)

            # rank is longer in string length - adjust position
            if len(s) != prev_len:
                diff = len(s) - prev_len
                pos -= 3 * diff
                prev_len += diff
            draw.text((pos, 5), s, font=font, fill=(0, 0, 0))
            draw.text((0, pos), s, font=font, fill=(0, 0, 0))

        # at least 20 pixels diff
        if last_pos - pos > 20:
            last = str(n_nodes)
            draw.text((last_pos, 5), last, font=font, fill=(0, 0, 0))
            draw.text((0, last_pos), last, font=font, fill=(0, 0, 0))

    def add_ranks(self, draw, n_nodes, radius, font, width, heatmap_offset):
        """Label all ranks below 60 nodes, otherwise about 20 of them."""
        if n_nodes < 60:
            self.add_all_ranks(draw, n_nodes, radius, font, heatmap_offset)
        else:
            self.add_n_ranks(draw, n_nodes, radius, font, width, heatmap_offset, n=20)

    def get_non_zero_stats(self, arr):
        """Return ``(std, average)`` of the non-zero elements of *arr*."""
        b = arr[arr != 0]
        std_dev = np.std(b, dtype=np.float64)
        average = np.average(b)
        return std_dev, average

    def process(self, color_scale, profile_name):
        """Render the heatmap and its legend and save them to ``self.imgname``.

        :param color_scale: thresholds forwarded to Legend (None -> gaussian).
        :param profile_name: caption printed at the top of the legend.
        """
        if "Nodes" not in self.data:
            self._logger.warning("No 'Nodes' info for: %s", self.filename)
            return

        size = len(self.data["Nodes"])
        if size == 0:
            self._logger.warning("Empty 'Nodes' info for: %s", self.filename)
            return

        if size > 1024:
            radius = 1
        else:
            # Largest cell size (pixels) that keeps the matrix within ~1024 px.
            radius = 1024 // (size + 1)
            if radius <= 0:
                self._logger.info("radius forced to 1")
                radius = 1
        self.pixels = np.zeros((radius * size, radius * size, 3), dtype=np.uint8)

        data_array2 = np.empty(size * (size - 1))
        idx = 0
        for i in range(size):
            for j in range(size):
                # Skip the diagonal: its values would skew the mean/stddev drastically.
                if i != j:
                    data_array2[idx] = self.value(i, j)
                    idx += 1

        std_dev, average = self.get_non_zero_stats(data_array2)
        self._logger.info(f"std_dev = {std_dev:.4f} average= {average:.4f}")

        legend = Legend(self._logger, average, std_dev, color_scale, profile_name)
        leg_img = legend.get_as_image(self.higherisbetter)

        # Cells of 5+ px get a white grid line on their last row and column.
        with_grid = radius >= 5
        for i in range(size):
            for j in range(size):
                color = legend.get_color(self.value(i, j), self.higherisbetter)
                cell = self.pixels[i * radius:(i + 1) * radius, j * radius:(j + 1) * radius]
                if with_grid:
                    cell[:] = BORDER_COLOR
                    cell[:radius - 1, :radius - 1] = color
                else:
                    cell[:] = color

        img = Image.fromarray(self.pixels)
        img_width, img_height = img.size
        leg_width, leg_height = leg_img.size

        heatmap_offset = 20
        img_to_legend_margin = 40
        new_width = img_width + img_to_legend_margin + leg_width
        new_height = max(img_height, leg_height) + 40

        new_im = Image.new('RGB', (new_width, new_height), color=(255, 255, 255))
        new_im.paste(img, (heatmap_offset, heatmap_offset))
        new_im.paste(leg_img, (img_to_legend_margin + img_width, 20))

        # Render the file name and the title as part of the image.
        d = ImageDraw.Draw(new_im)
        font = ImageFont.load_default()
        title_loc = (img_width + 60, new_height - 60)
        fname_loc = (20, new_height - 18)
        d.text(fname_loc, self.imgname, font=font, fill=(0, 0, 0))
        d.text(title_loc, self.title, font=font, fill=(0, 0, 0))

        self.add_ranks(d, size, radius, font, img_width, heatmap_offset)

        new_im.save(self.imgname)
        self._logger.info(self.imgname)


class RenderWrapper:
    """Wrapper that runs the heatmap rendering.

    Reads the ``[render_heatmap]`` INI section, chooses the color-scale
    thresholds for every JSON file, and renders it with ImageGenerator.  It
    also hosts the static helpers that normalize threshold lists.
    """
    INI_DESCRIPTION = """
INI file content:
    The INI configuration file should consist a [render_heatmap] section with the following arguments:
        thresholds        comma separated list (without spaces) of color scale thresholds
        thresholds_file   path to file with color scale thresholds"""
    RENDER_PARSER = None

    def __init__(self, file):
        self._args = dict_from_ini(file, 'render_heatmap')
        self._logger = get_logger(self._args)
        # Caption drawn on the legend; refreshed by set_color_scale_args().
        self._profile_name = "Gaussian"

    @classmethod
    def get_parser(cls):
        """Return the render-mode argument parser (created once, then cached)."""
        if cls.RENDER_PARSER is None:
            cls.RENDER_PARSER = argparse.ArgumentParser(add_help=False, description="render a bandwidth/ latency heatmap")
            cls.RENDER_PARSER.add_argument('-f', '--config', required=True, help='configuration INI file')
            cls.RENDER_PARSER.add_argument('-j', '--json', required=False, help='bandwidth/ latency JSON file')
            cls.RENDER_PARSER.add_argument('-d', '--dir', required=False, help=f'directory with bandwidth/ latency JSON files')
            cls.RENDER_PARSER.add_argument('-g', '--gaussian', required=False, action='store_true', help=f'set a gaussian color scale bar')
            cls.RENDER_PARSER.add_argument('-r', '--render', required=True, action='store_true', help='do heatmap rendering')

        return cls.RENDER_PARSER

    @classmethod
    def print_help(cls):
        """Print the render-mode usage followed by the INI description."""
        parser = cls.get_parser()
        parser.print_help()
        print(cls.INI_DESCRIPTION)

    @staticmethod
    def find_files(directory, pattern):
        """Yield paths under *directory* (recursively) whose base name matches *pattern*."""
        for root, dirs, files in os.walk(directory):
            for basename in files:
                if fnmatch.fnmatch(basename, pattern):
                    yield os.path.join(root, basename)

    def is_numbers(self, lines):
        """Return True when every line evaluates to an int or float."""
        evaluated = []
        try:
            for v in lines:
                # NOTE(security): eval() of thresholds-file content; the file must be trusted.
                evaluated.append(eval(v))
        except Exception:
            return False
        return self.valid_thresholds(evaluated)

    def thresholds_file_to_str(self, thresholds_file):
        """Read a thresholds file into one string, or return None if it cannot be read.

        ``#`` comments, surrounding whitespace and blank lines are dropped.
        Purely numeric files are joined with ``','``; expression files are
        joined with ``'\\n'``.
        """
        lines = []
        try:
            with open(thresholds_file, 'r') as fp:
                for a_line in fp:
                    a_line = a_line.partition('#')[0].strip()
                    # Blank/comment-only lines would otherwise break the evaluation.
                    if a_line:
                        lines.append(a_line)
        except (OSError, ValueError):
            self._logger.warning("Cannot read thresholds file. Please check that the path is absolute")
            return None

        if self.is_numbers(lines):
            return ','.join(lines)
        return '\n'.join(lines)

    @staticmethod
    def _to_bool(value):
        """Convert a JSON boolean or a "true"/"false" string (any case) to bool.

        :raises ValueError: for any other value.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            lowered = value.strip().lower()
            if lowered in ('true', 'false'):
                return lowered == 'true'
        raise ValueError(f"not a boolean: {value!r}")

    def read_tech_and_bidirectional(self, file_name):
        """Return ``(technology, bidirectional)`` from a result JSON, or None on error.

        A missing technology is reported as "Unknown", and a missing
        bidirectional flag as False.
        """
        try:
            with open(file_name) as json_data:
                d = json.load(json_data)
            tech = d.get("Technology", "Unknown")
            # Parsed explicitly instead of eval(): the JSON is external input.
            bidirectional = self._to_bool(d.get("Bidirectional", "False"))
        except (OSError, ValueError, AttributeError):
            self._logger.warning("Could not read %s" % file_name)
            return None
        return tech, bidirectional

    def get_test_prefix(self, file_name):
        """Return "latency" if the path mentions latency (any case), else "bandwidth"."""
        if 'latency' in file_name.lower():
            return "latency"
        return "bandwidth"

    def get_per_technology_thresholds(self, file_name):
        """Return the INI thresholds for the file's technology, or None.

        On success the profile name is set to the technology, with an
        "Unidirectional" suffix for unidirectional bandwidth runs.
        """
        rc = self.read_tech_and_bidirectional(file_name)
        if rc is None:
            return None
        tech, bidirectional = rc
        if tech == "Unknown":
            return None

        test = self.get_test_prefix(file_name)
        unidirectional_bw = bidirectional is False and test == 'bandwidth'
        if unidirectional_bw:
            res = ini_line_to_string(self._args, tech.lower() + '_bandwidth_unidirectional_thresholds')
            profile = f"{tech} Unidirectional"
        else:
            res = ini_line_to_string(self._args, tech.lower() + '_' + test + '_thresholds')
            profile = f"{tech}"

        if not res:
            # No thresholds configured for this technology: use the gaussian scale.
            return None

        if unidirectional_bw:
            self._logger.info("taking bandwidth unidirectional thresholds for %s technology", tech)
        else:
            self._logger.info("taking %s thresholds for %s technology", test, tech)
        self._profile_name = profile
        return res

    def get_user_thresholds(self, file_name):
        """Return ``(thresholds, thresholds_file)`` configured by the user, or None."""
        res = None
        bidirectional = True
        test = self.get_test_prefix(file_name)
        thresholds_postfix = '_thresholds'
        thresholds_file_postfix = '_thresholds_file'

        rc = self.read_tech_and_bidirectional(file_name)
        if rc is not None:
            _, bidirectional = rc
        if bidirectional is False and test == 'bandwidth':
            thresholds_postfix = '_unidirectional' + thresholds_postfix
            thresholds_file_postfix = '_unidirectional' + thresholds_file_postfix

        thresholds = ini_line_to_string(self._args, test + thresholds_postfix)
        thresholds_file = ini_line_to_string(self._args, test + thresholds_file_postfix)
        if thresholds or thresholds_file:
            res = thresholds, thresholds_file
        return res

    def set_color_scale_args(self, file_name):
        """Return the color scale for *file_name*.

        The result is a list of six thresholds, an expression string that
        Legend evaluates with ``avg``/``std``, or None for the gaussian scale.
        ``self._profile_name`` is updated to match.
        """
        # Start from the gaussian profile so no caption leaks from a previous file.
        self._profile_name = "Gaussian"
        rc = self.get_user_thresholds(file_name)
        if rc is None:
            as_str = self.get_per_technology_thresholds(file_name)
        else:
            thresholds, thresholds_file = rc
            as_str = thresholds if thresholds else self.thresholds_file_to_str(thresholds_file)
            if as_str:
                self._profile_name = "Custom Thresholds"

        if not as_str:
            return None

        try:
            # NOTE(security): eval() of INI/thresholds-file content; must be trusted.
            values = self.to_list(eval(as_str))
            return self.set_thresholds(values)
        except Exception:
            # Not a plain list of numbers: let Legend evaluate it with avg/std in scope.
            return as_str

    def dir_iterate_render(self, directory, file_pattern, gaussian=False):
        """Render every file under *directory* that matches *file_pattern*.

        A failure on one file is logged and does not stop the others.
        """
        for f in self.find_files(directory, file_pattern):
            try:
                self._logger.info("Processing: %s" % f)
                if gaussian:
                    self._profile_name = "Gaussian"
                    scale_thresholds = None
                else:
                    scale_thresholds = self.set_color_scale_args(f)
                img = ImageGenerator(f, self._logger)
                img.process(scale_thresholds, self._profile_name)
            except Exception as e:
                self._logger.error(f"{e} when processing {f}")

    def run(self):
        """Render the file given with -j and/or all matching files under -d."""
        parser = self.get_parser()
        render_args, _unknown = parser.parse_known_args()
        file_name = render_args.json
        directory = render_args.dir
        gaussian = render_args.gaussian
        self._logger.info("filename = %s" % file_name)
        self._logger.info("dir = %s" % directory)

        if not file_name and not directory:
            self._logger.error("Nothing to render: please specify a JSON file (-j) or a directory (-d)")
            return

        if file_name:
            scale_thresholds = None
            if not gaussian:
                scale_thresholds = self.set_color_scale_args(file_name)
            img = ImageGenerator(file_name, self._logger)
            img.process(scale_thresholds, self._profile_name)

        if directory:
            self.dir_iterate_render(directory, "*latency*.json", gaussian)
            self.dir_iterate_render(directory, "*bandwidth*.json", gaussian)

    @staticmethod
    def handle_thresholds_three(li):
        """Expand (low, mid, high) to [low, low, mid, mid, high, high]."""
        c = max(li)
        a = min(li)
        middle_set = set(li) - {a, c}
        if not middle_set:
            return [a, a, a, c, c, c]
        b = middle_set.pop()
        return [a, a, b, b, c, c]

    @staticmethod
    def handle_thresholds_two(li):
        """Expand (low, high) to [low]*3 + [high]*3."""
        a = min(li)
        b = max(li)
        return [a, a, a, b, b, b]

    @staticmethod
    def handle_thresholds_one(val):
        """Expand one threshold to six identical ones."""
        return [val] * 6

    @classmethod
    def set_thresholds(cls, arg):
        """Normalize 1, 2, 3 or 6 thresholds (or one number/numeric string) to six.

        :raises ValueError: for any other kind or length of *arg*.
        """
        if isinstance(arg, list) and len(arg) in (1, 2, 3, 6):
            res = list(arg)
            if len(res) == 6:
                res.sort()
            elif len(res) == 3:
                res = cls.handle_thresholds_three(res)
            elif len(res) == 2:
                res = cls.handle_thresholds_two(res)
            else:
                res = cls.handle_thresholds_one(res[0])
        elif isinstance(arg, str):
            res = cls.handle_thresholds_one(float(arg))
        elif type(arg) in (float, int):
            res = cls.handle_thresholds_one(arg)
        else:
            raise ValueError(f"Unsupported thresholds: {arg!r}")
        return res

    @staticmethod
    def to_list(values):
        """Return *values* as a list (tuples are converted, scalars are wrapped)."""
        if isinstance(values, tuple):
            return list(values)
        if not isinstance(values, list):
            return [values]
        return values

    @staticmethod
    def valid_thresholds(thresholds):
        """Return True for a non-empty sequence whose items are all int/float."""
        if len(thresholds) == 0:
            return False
        return all(isinstance(x, (int, float)) for x in thresholds)


def dict_from_ini(conf_file, section_name):
    """Return the options of one INI section as a plain dict.

    Options with an empty value are reported as ``None``.  Raises
    ``KeyError`` when the section is absent, including when the file could
    not be read at all.
    """
    parser = configparser.ConfigParser()
    parser.read(conf_file)
    return {name: (None if value == '' else value)
            for name, value in parser[section_name].items()}


def ini_line_to_string(ini_dict, key):
    """Return ``ini_dict[key]``, or ``''`` when the key is missing or its value is ``None``."""
    value = ini_dict.get(key)
    return '' if value is None else value


def ini_line_to_bool(ini_dict, key):
    """Interpret an INI option as an integer flag: any value > 0 means ``True``.

    Missing keys, ``None`` and non-integer strings yield ``False``.
    """
    try:
        return int(ini_dict.get(key, 0)) > 0
    except (TypeError, ValueError):
        # None / non-numeric text: treat as "flag not set".
        return False


def get_logger(args_dict=None):
    """Create a Logger whose level follows the ``verbose`` option of *args_dict*.

    ``verbose`` > 0 enables INFO messages; otherwise only warnings and errors
    are printed.  *args_dict* defaults to an empty mapping.
    """
    # None instead of a shared mutable {} default.
    if args_dict is None:
        args_dict = {}
    level = Logger.INFO if ini_line_to_bool(args_dict, 'verbose') else Logger.WARNING
    return Logger(level)


def check_args():
    """Validate the module-level command-line ``args``.

    :return: True when at least one mode was requested and, for the
             render/normalize modes, the INI file given with ``-f`` could be
             read and parsed; False otherwise (an error is printed).
    """
    if not (args.render or args.normalize or args.report):
        Logger().error("No process was specified")
        return False

    if args.report:
        # The report mode does not use an INI file.
        return True

    if not args.config:
        Logger().error("A configuration INI file (-f) is required")
        return False

    cp = configparser.ConfigParser()
    try:
        # read() silently skips unreadable files and returns the ones it parsed.
        read_ok = cp.read(args.config)
    except (configparser.Error, UnicodeDecodeError) as e:
        Logger().error("Cannot parse INI file %s: %s", args.config, e)
        return False
    if not read_ok:
        Logger().error("Cannot read INI file: %s", args.config)
        return False
    return True


def get_module():
    """Instantiate the wrapper for the mode chosen on the command line.

    Precedence: normalize, then report; rendering is the fallback.
    """
    if args.normalize:
        return NormalizeWrapper(args.config)
    if args.report:
        return ReportWrapper()
    return RenderWrapper(args.config)


def get_parser():
    """Build the top-level parser that only selects the mode.

    Mode-specific options are parsed later by each wrapper's own parser, so
    help is handled manually (``add_help=False`` plus a ``-h`` flag).
    """
    parser = argparse.ArgumentParser(
        prog=os.path.basename(__file__),
        description='Script for normalizing latency, rendering heatmaps and creating PDF report',
        add_help=False,
    )
    options = [
        ('-h', '--help', dict(action='store_true', help='show this help message and exit')),
        ('-f', '--config', dict(help='configuration INI file')),
        ('-r', '--render', dict(action='store_true', help='do heatmap rendering')),
        ('-n', '--normalize', dict(action='store_true', help='normalize latency JSON file')),
        ('-p', '--report', dict(action='store_true', help='creating PDF report')),
    ]
    for short_flag, long_flag, extra in options:
        parser.add_argument(short_flag, long_flag, required=False, **extra)
    return parser


def print_help(parser):
    """Print the general usage, followed by the usage of every mode."""
    print("General flags:")
    parser.print_help()
    sections = (
        ("Normalize Latency:", NormalizeWrapper),
        ("Render Heatmap:", RenderWrapper),
        ("Report:", ReportWrapper),
    )
    for title, wrapper in sections:
        print(PLINES)
        print(title)
        wrapper.print_help()


if __name__ == "__main__":
    main_parser = get_parser()
    # `args` is a module-level global read by check_args() and get_module().
    args, _ = main_parser.parse_known_args()
    if args.help:
        print_help(main_parser)
        sys.exit()

    if not check_args():
        print_help(main_parser)
        # Invalid invocation: report failure to the calling shell.
        sys.exit(1)

    module = get_module()
    module.run()
