#!/usr/bin/env python
#
# Copyright (c) 2005 Cisco Systems.  All rights reserved.
# 
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License as
# published by the Free Software Foundation; either version 2 of the
# License, or (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

import sys, getopt, struct, umad

class rec:
    """Empty record type; query results are attached to it as attributes."""

def usage():
    print "Usage:"
    print "  ", sys.argv[0], " <lid>"
    print
    print "Options:"
    print "  -d, --device=<device>      use device to send queries (default /dev/infiniband/umad0)"
    print "  -t, --trap=<local device>  path for trap dest in ClassPortInfo (Topspin/Cisco compat)"
    print "                             e.g. /sys/class/infiniband/mthca0/ports/1"

def _read_sysfs_attr(path):
    """Return the whole text of the file at *path*, always closing it."""
    attr_file = open(path)
    try:
        return attr_file.read()
    finally:
        attr_file.close()

def set_class_port_info(f, agt, dlid, dev):
    """Send a DevMgt ClassPortInfo SET naming the local port as trap destination.

    *dev* is a sysfs port directory (e.g. /sys/class/infiniband/mthca0/ports/1).
    Its "lid" and "gids/0" entries supply the LID and GID that are written
    into the ClassPortInfo payload sent to LID *dlid*.

    Raises IOError if the reply carries a non-zero MAD status.
    """
    # The sysfs "lid" entry is parsed as hexadecimal text.
    slid = struct.pack('!H', int(_read_sysfs_attr(dev + "/lid"), 16))

    # "gids/0" is read as 8 groups of 4 hex digits, each followed by one
    # separator character (hence the stride of 5); they pack to 16 bytes.
    sgids = _read_sysfs_attr(dev + "/gids/0")
    sgid = ''.join([struct.pack('!H', int(sgids[i * 5:i * 5 + 4], 16))
                    for i in range(8)])

    # Payload layout: GID at byte 40, LID at byte 60, all other bytes zero,
    # padded to the 192-byte data area.
    payload = '\x00' * 40 + sgid + '\x00' * 4 + slid
    payload += (192 - len(payload)) * '\x00'

    mad = umad.Mad(umad.Mad.header(umad.MgmtClass.DEV_MGT, 1,
                                   umad.Method.SET, 0, 0, 0,
                                   umad.devmgt.Attr.CLASS_PORT_INFO, 0) +
                   '\x00' * 40 + payload, qpn = 1, qkey = umad.QP1_QKEY,
                   lid = dlid, timeout_ms = 2000)

    agt.send(mad)
    recv = f.recv()

    if recv._mad._status:
        raise IOError("ClassPortInfo SET to LID %s failed: MAD status %r"
                      % (dlid, recv._mad._status))

def get_iou_info(f, agt, dlid):
    """Query the IOUnitInfo attribute of the device-management agent at *dlid*.

    Returns a rec with:
      _change_id, _max_controllers, _diagid_optionrom -- header fields
      _controller -- 4-bit per-slot status codes indexed by slot number;
                     entry 0 is a -1 placeholder, so slots start at 1

    Raises IOError if the reply carries a non-zero MAD status.
    """
    mad = umad.Mad(umad.Mad.header(umad.MgmtClass.DEV_MGT, 1,
                                   umad.Method.GET, 0, 0, 0,
                                   umad.devmgt.Attr.IO_UNIT_INFO, 0) +
                   '\x00' * 232, qpn = 1, qkey = umad.QP1_QKEY,
                   lid = dlid, timeout_ms = 2000)
    agt.send(mad)
    recv = f.recv()

    if recv._mad._status:
        raise IOError("IOUnitInfo GET to LID %s failed: MAD status %r"
                      % (dlid, recv._mad._status))

    data = recv._mad._data
    ret = rec()
    ret._change_id, ret._max_controllers, ret._diagid_optionrom = \
                    struct.unpack('!HBB', data[64:68])

    # The controller list is 128 bytes holding one 4-bit status per slot,
    # high nibble first: 256 slots, enough for any 8-bit _max_controllers.
    # Decoding fewer bytes would leave slots above 128 missing.
    list_bytes = 128
    ret._controller = [ -1 ]
    for i in range(list_bytes):
        byte = ord(data[68 + i])
        ret._controller += [ byte >> 4, byte & 0xf ]

    return ret

def get_ioc_prof(f, agt, dlid, index):
    """Fetch and decode the IOControllerProfile of controller slot *index*.

    Sends a DevMgt GET for the profile to LID *dlid*, using *index* as the
    attribute modifier.  It then unpacks the 128 bytes at offset 64 of the
    reply into a rec whose attributes are the profile fields with a leading
    underscore (_guid, _vendor_id, ..., _service_entries, _id).

    Raises IOError if the reply carries a non-zero MAD status.
    """
    header = umad.Mad.header(umad.MgmtClass.DEV_MGT, 1,
                             umad.Method.GET, 0, 0, 0,
                             umad.devmgt.Attr.IO_CONTROLLER_PROFILE,
                             index)
    request = umad.Mad(header + '\x00' * 232, qpn = 1,
                       qkey = umad.QP1_QKEY, lid = dlid, timeout_ms = 2000)
    agt.send(request)
    reply = f.recv()

    if reply._mad._status:
        raise IOError

    # Field names in the order the big-endian format below yields them;
    # the "x" codes in the format skip reserved bytes.
    field_names = ('guid', 'vendor_id', 'device_id', 'device_version',
                   'subsys_vendor_id', 'subsys_device_id',
                   'io_class', 'iosubclass',
                   'protocol', 'protocol_version', 'send_queue_depth',
                   'rdma_read_depth', 'send_size', 'rdma_size',
                   'cap_mask', 'service_entries', 'id')
    values = struct.unpack('!QIIH2xIIHHHH4xHxBIIBxB9x64s',
                           reply._mad._data[64:192])

    profile = rec()
    for name, value in zip(field_names, values):
        setattr(profile, '_' + name, value)

    return profile

def get_svc_entries(f, agt, dlid, index, begin, end):
    """Read service entries *begin*..*end* (inclusive) of controller slot *index*.

    The slot and entry range are packed into the attribute modifier of a
    DevMgt ServiceEntries GET sent to LID *dlid*.  Each entry in the reply
    is 48 bytes, starting at offset 64: a 40-byte name followed by a
    big-endian 64-bit ID.

    Returns a rec whose _service list holds one rec (_name, _id) per entry.
    Raises IOError if the reply carries a non-zero MAD status.
    """
    modifier = begin | (end << 8) | (index << 16)
    header = umad.Mad.header(umad.MgmtClass.DEV_MGT, 1,
                             umad.Method.GET, 0, 0, 0,
                             umad.devmgt.Attr.SERVICE_ENTRIES, modifier)
    request = umad.Mad(header + '\x00' * 232, qpn = 1,
                       qkey = umad.QP1_QKEY, lid = dlid, timeout_ms = 2000)
    agt.send(request)
    reply = f.recv()

    if reply._mad._status:
        raise IOError

    data = reply._mad._data
    entry_size = 48
    first = 64
    stop = first + entry_size * (end - begin + 1)

    result = rec()
    result._service = [ ]
    for offset in range(first, stop, entry_size):
        entry = rec()
        entry._name, entry._id = struct.unpack('!40sQ',
                                               data[offset:offset + entry_size])
        result._service.append(entry)

    return result

def main():
    try:
        opts, args = getopt.getopt(sys.argv[1:], "d:t:", [ "device=", "trap=" ])
    except:
        usage()
        sys.exit(1)

    if len(args) != 1:
        usage()
        sys.exit(1)

    dlid = int(args[0])
    dev  = "/dev/infiniband/umad0"
    trap = None

    for o, a in opts:
        if o in ("-d", "--device"):
            dev = a
        if o in ("-t", "--trap"):
            trap = a

    f = umad.UmadFile(dev)
    agt = f.reg_agent(1)

    if trap:
        set_class_port_info(f, agt, dlid, trap)

    iou_info = get_iou_info(f, agt, dlid)

    print "IO Unit Info:"
    print "    max controllers: ", iou_info._max_controllers
    print

    for i in range(1, iou_info._max_controllers + 1):
        if iou_info._controller[i] != umad.devmgt.IoUnitInfo.IOC_PRESENT:
            continue

        print "    controller[%3d]" % i

        ioc_prof = get_ioc_prof(f, agt, dlid, i)

        print "    GUID:      %016x" % ioc_prof._guid;
        print "    vendor ID: %06x" % ioc_prof._vendor_id;
        print "    device ID: %06x" % ioc_prof._device_id;
        print "    ID:        ", ioc_prof._id;
        print "    service entries: ", ioc_prof._service_entries;

        for j in range(0, ioc_prof._service_entries, 4):
            n = min(j + 3, ioc_prof._service_entries - 1)
            svc_entries = get_svc_entries(f, agt, dlid, i, j, n)

            for k in range(n - j + 1):
                print "        service[%3d]: %016x / %s" % (j + k,
                                                            svc_entries._service[k]._id,
                                                            svc_entries._service[k]._name)
    

if "__main__" == __name__:
    # Run the query only when executed as a script, not when imported.
    main()
