#! /usr/bin/python3
#
# Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
import optparse
import os
import pwd
import sys
import signal

# The OVS python bindings are an external dependency. When they cannot be
# imported, print installation hints and exit with status 1 instead of
# letting the import error surface as a traceback.
try:
    from ovs.metrics import MetricsDB, MetricsReadError
except ModuleNotFoundError:
    print(u"""\
ERROR: Missing dependencies.
Please install the Open vSwitch python libraries: python3-openvswitch (version 2.17.8).
Alternatively, install them from source: ( cd ovs/python ; python3 setup.py install ).
Alternatively, check that your PYTHONPATH is pointing to the correct location.""")
    sys.exit(1)


# Parsed command-line options. Assigned by main() and read by the mode
# functions (metrics_watch, metrics_one_shot, metrics_follow).
options = None
# File object receiving the watch-mode measures when an output file was
# requested, None otherwise. It is also closed by the SIGINT handler.
outfile = None


def sigint_handler(sig, frame):
    """Terminate the script cleanly when SIGINT (Ctrl-C) is received.

    Closes the measures file if one is open, then exits with status 0.

    Args:
        sig: Signal number delivered to the process (unused).
        frame: Stack frame that was interrupted (unused).

    Raises:
        SystemExit: always, with exit status 0.
    """
    if outfile:
        outfile.close()
    # Use sys.exit() rather than the bare exit() helper: the latter is only
    # injected by the 'site' module and is not guaranteed to exist.
    sys.exit(0)


signal.signal(signal.SIGINT, sigint_handler)


def human_convert(value: float, power: int = 1000) -> (float, str):
    """Scale a number down to a human-friendly magnitude.

    The value is repeatedly divided by ``power`` until its magnitude is
    strictly below ``power`` or the largest known prefix has been reached.

    Args:
        value: Number to scale; may be negative.
        power: 1000 for decimal prefixes (K, M, G, ...) or 1024 for binary
            prefixes (Ki, Mi, Gi, ...).

    Returns:
        A ``(scaled_value, prefix)`` pair. ``prefix`` is the empty string
        when no scaling took place, in which case ``value`` is returned
        untouched.

    Raises:
        ValueError: if ``power`` is neither 1000 nor 1024.
    """
    # Reject unsupported bases before doing any work.
    if power not in (1000, 1024):
        raise ValueError("Unsupported power value for human conversion.")

    labels = ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y', 'R', 'Q']
    if power == 1024:
        # Change mega to mebi etc, except when there is no label.
        labels = [x + 'i' if x else x for x in labels]

    n = 0
    # '>=' so that exactly one unit of the next order is promoted to it:
    # 1000 -> 1 K and 1000000 -> 1 M (instead of 1000 and 1000 K).
    while abs(value) >= power and n < len(labels) - 1:
        value /= power
        n += 1
    return value, labels[n]


class Entry:
    """One column (or pair of columns) of the watch-mode table.

    An entry keeps the latest value computed from one or two metrics and
    the per-second variation of that value. The per-second column is only
    rendered when ``derive`` is a non-empty string.
    """

    def __init__(self,
                 name: str,
                 unit_name: str,
                 metrics_names: list,
                 derive: str = '',
                 precision: int = 0,
                 power: int = 1000):
        """Create an entry.

        Args:
            name: Legend of the value column.
            unit_name: Unit shown under the legend.
            metrics_names: One metric name, or two names whose values are
                subtracted (first minus second).
            derive: Legend of the per-second column; empty for none.
            precision: Decimals used when the value is printed unscaled.
            power: Base handed to human_convert() (1000 or 1024).
        """
        # Presentation.
        self.name = name
        self.derive = derive
        self.unit_name = unit_name
        self.precision = precision
        self.power = power

        # Source metric(s).
        self.metrics_names = metrics_names

        # Running state, refreshed by update().
        self.value = 0
        self.value_per_sec = 0
        self.update_count = 0
        self.disabled = False

    @staticmethod
    def _fixed(number: float, digits: int = 2, width: str = '') -> str:
        """Render ``number`` in fixed-point notation with ``digits`` decimals.

        ``width`` is an optional minimum field width placed in front of the
        precision in the format specification.
        """
        return format(number, '{}.{}f'.format(width, digits))

    def update(self, db):
        """Read the current value from ``db`` and refresh the rate.

        The rate and the update counter only change when ``db.ts_delta()``
        (milliseconds) is strictly positive.
        """
        current = self.read_value(db)
        elapsed_ms = db.ts_delta()
        if elapsed_ms > 0:
            self.value_per_sec = (current - self.value) * (1000 / elapsed_ms)
            self.update_count += 1
        self.value = current

    def read_value(self, db):
        """Return the entry value read from ``db``, or 0 if unavailable.

        A missing metric (KeyError) disables the entry for good; a disabled
        entry always reads 0.
        """
        if self.disabled:
            return 0
        try:
            count = len(self.metrics_names)
            if count == 1:
                return db[self.metrics_names[0]].last()
            if count == 2:
                first = db[self.metrics_names[0]].last()
                second = db[self.metrics_names[1]].last()
                return first - second
        except KeyError:
            # Some metrics can be disabled, such as hw-offload ones if
            # hw-offload=false in OVS config.
            self.disabled = True
        return 0

    def legend(self, sep: str = ' ') -> str:
        """Return the right-aligned legend cell(s), joined by ``sep``."""
        cell = format(self.name, '>10')
        if self.derive != '':
            cell = cell + sep + format(self.derive, '>10')
        return cell

    def unit(self, sep: str = ' ') -> str:
        """Return the right-aligned unit cell(s), joined by ``sep``."""
        cell = format(self.unit_name, '>10')
        if self.derive != '':
            cell = cell + sep + format(self.unit_name, '>8') + '/s'
        return cell

    def format_human(self):
        """Return the value (and rate, if derived) scaled for reading."""
        scaled, suffix = human_convert(self.value, self.power)
        # Keep the configured precision only when no scaling happened.
        digits = self.precision if scaled == self.value else 2
        cells = [(self._fixed(scaled, digits) + suffix).rjust(10)]

        if self.derive != '':
            # Per-sec delta is incorrect after only one read.
            rate = self.value_per_sec if self.update_count > 1 else 0
            rate, suffix = human_convert(rate, self.power)
            cells.append((self._fixed(rate) + suffix).rjust(10))

        return ' '.join(cells)

    def format_csv(self):
        """Return the raw value (and rate, if derived) for the output file."""
        cells = [self._fixed(self.value, self.precision, '10')]

        if self.derive != '':
            # Per-sec delta is incorrect after only one read.
            rate = self.value_per_sec if self.update_count > 1 else 0
            cells.append(self._fixed(rate, width='10'))

        return ' '.join(cells)


def metrics_watch():
    """Poll the OVS metrics forever and print them as a table.

    One row is printed on stdout per polling period, with the legend
    repeated every ten rows. When the global ``outfile`` is set, a
    tab-separated header and one raw row per period are written to it too.
    With the 'wait' option, rows are held back until the first entry
    reports a non-zero per-second rate. A metrics read failure resets the
    database and the row/traffic state, then polling is retried.
    """
    # 'options' and 'outfile' are module globals that are only read here.
    erase_line = '\033[K'
    interval = options.period
    hold_for_traffic = options.waitforpkt
    traffic_seen = False

    entries = [
        Entry('sw-pkts', 'packet',
              ['ovs_vswitchd_datapath_tx_packets_total',
               'ovs_vswitchd_datapath_tx_offloaded_packets_total'],
              derive='sw-pps'),
        Entry('sw-conns', 'conn',
              ['ovs_vswitchd_conntrack_n_connections'],
              derive='sw-cps'),
        Entry('hw-pkts', 'packet',
              ['ovs_vswitchd_datapath_tx_offloaded_packets_total'],
              derive='hw-pps'),
        Entry('hw-conns', 'conn',
              ['ovs_vswitchd_datapath_hw_offload_n_ct_bidir'],
              derive='hw-cps'),
        Entry('enqueued', 'request',
              ['ovs_vswitchd_datapath_hw_offload_n_enqueued']),
        Entry('hw-rules', 'rule',
              ['ovs_vswitchd_datapath_hw_offload_n_inserted'],
              derive='hw-rps'),
        Entry('RSS', 'byte',
              ['ovs_vswitchd_memory_rss'],
              power=1024),
        Entry('UsedMem', 'byte',
              ['ovs_vswitchd_memory_in_use'],
              power=1024),
        Entry('frag', '',
              ['ovs_vswitchd_memory_frag_factor'],
              precision=2),
    ]

    db = MetricsDB(extended=False, debug=False)

    # Header of the measures file: legends, then units.
    if outfile is not None:
        file_legend = 'time\t' + '\t'.join(e.legend() for e in entries)
        outfile.write(file_legend + '\n')
        file_units = 'second\t' + '\t'.join(e.unit() for e in entries)
        outfile.write(file_units + '\n')

    # Header of the terminal table.
    legend = '    time ' + ' '.join(e.legend() for e in entries)
    print(legend)
    units = '  second ' + ' '.join(e.unit() for e in entries)
    print(units, flush=True)

    rows = 0
    while True:
        try:
            db.update()
        except MetricsReadError:
            print('%sFailed to read OVS metrics, retrying...' % erase_line, end='\r')
            db.reset()
            db.wait(interval)
            # Start over as if nothing had been displayed yet.
            rows = 0
            traffic_seen = False
            continue

        for e in entries:
            e.update(db)

        if hold_for_traffic and not traffic_seen:
            # The first entry (software packets) is the traffic indicator.
            if entries[0].value_per_sec == 0:
                print('%sWaiting for traffic.' % erase_line, end='\r')
                db.wait(interval)
                continue
            traffic_seen = True

        # Repeat the legend every ten rows, but not right under the header.
        if rows and rows % 10 == 0:
            print(legend)
        rows += 1

        stamp = format(db.last_ts() / 1000, '08.3f')
        print(stamp + ' ' + ' '.join(e.format_human() for e in entries),
              flush=True)
        if outfile:
            raw_row = '\t'.join(e.format_csv() for e in entries)
            outfile.write(stamp + '\t' + raw_row + '\n')

        db.wait(interval)


def metrics_one_shot():
    """Print every point of every metric once.

    Each line is the database timestamp (divided by 1000) followed by a
    colon and the point. The 'extended' and 'debug' options are forwarded
    to the metrics database.
    """
    line = "{0:04.3f}:{1}"
    db = MetricsDB(extended=options.extended, debug=options.debug)
    for _name, metric in db.items():
        for point in metric:
            stamp = db.last_ts() / 1000
            print(line.format(stamp, point))


def metrics_follow():
    """Print the metric points returned by ``db.delta()`` forever.

    A warning is written to stderr when the requested period is shorter
    than the duration of the initial query. The period used for the next
    wait is whatever the previous ``db.wait()`` call returned.
    """
    too_short = "Requested period is too short to perform a query: {0} < {1} ms\n"
    line = "{0:04.3f}:{1}"

    period = options.period
    db = MetricsDB(extended=options.extended, debug=options.debug)

    if db.last_query_duration > period:
        sys.stderr.write(too_short.format(period, db.last_query_duration))

    while True:
        db.update()
        for point in db.delta():
            stamp = db.last_ts() / 1000
            print(line.format(stamp, point), flush=True)
        period = db.wait(period)


def validate_options() -> bool:
    """Check the parsed options for incompatible combinations.

    Every conflict found is reported on stderr, not only the first one.

    Returns:
        True when the options are consistent, False otherwise.
    """
    conflicts = (
        (options.one and options.follow,
         "'One-shot' and 'follow' mode are not compatible.\n"),
        (options.one and options.outfile,
         "Output option does not work with 'one-shot' mode.\n"),
        (options.follow and options.outfile,
         "Output option does not work with 'follow' mode.\n"),
    )

    valid = True
    for clash, message in conflicts:
        if clash:
            sys.stderr.write(message)
            valid = False
    return valid



def main():
    """Parse the command line and run the selected mode.

    Modes are 'one-shot' (-1), 'follow' (-f) and, by default, 'watch'.
    In watch mode the measures are additionally written to the file given
    with -o, if any. The parsed options and the opened file are stored in
    the module globals ``options`` and ``outfile``.

    Raises:
        SystemExit: with status 1 when the options are inconsistent or the
            output file cannot be opened (optparse itself also exits on
            malformed command lines).
    """
    global options
    global outfile

    description = u'Open vSwitch metrics access script.'
    parser = optparse.OptionParser(version='2.17.8',
                                   usage='usage: %prog [options]',
                                   description=description)

    parser.add_option('-1', '--one-shot', dest='one', default=False,
                      help='Display metrics once then quit',
                      action='store_true')

    parser.add_option('-f', '--follow', dest='follow', default=False,
                      help='display metrics changes',
                      action='store_true')

    parser.add_option('-x', '--extended', dest='extended', default=False,
                      help='read the extended metrics page as well',
                      action='store_true')

    parser.add_option('-d', '--debug', dest='debug', default=False,
                      help='read the debug metrics page as well',
                      action='store_true')

    parser.add_option('-p', '--period', dest='period', type='int', default=1000,
                      help='metrics query periodicity', metavar='<ms>')

    parser.add_option('-o', '--output', dest='outfile', default='',
                      help='write measures to file', metavar='<file.csv>')

    parser.add_option('-w', '--wait', dest='waitforpkt', default=False,
                      help='Wait for packet activity before starting.',
                      action='store_true')

    (options, args) = parser.parse_args()

    # sys.exit() rather than the bare exit() helper, which only exists when
    # the 'site' module has been loaded.
    if not validate_options():
        sys.exit(1)

    if options.one:
        metrics_one_shot()
        return
    if options.follow:
        metrics_follow()
        return

    # Watch mode (default).
    if options.outfile:
        try:
            outfile = open(options.outfile, 'w')
        except OSError as e:
            # Only OSError carries the 'filename' and 'strerror' attributes
            # used in the message below.
            sys.stderr.write("Failed to open '%s': %s.\n" % (e.filename, e.strerror))
            sys.exit(1)

    try:
        metrics_watch()
    finally:
        # Close the measures file however the watch loop ends. Closing a
        # file the SIGINT handler already closed is harmless.
        if outfile:
            outfile.close()


# Run the command-line interface only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()

# vi: filetype=python
