#!/usr/bin/env python3

import sys
import os
import xml.etree.ElementTree

# define BITFIELD class

class BITFIELD:
	"""One bit range of an MSR, as described by a <bitfield> element.

	``start``, ``size``, ``name``, ``description`` and ``type`` are stored
	exactly as given. ``value`` maps each enumerated field value to its
	description (filled by addValue()).
	"""

	# Object initialisation

	def __init__(self, start = 63, size = 63, name = "unknown", description = "none", type = "bin"):
		self.start = start
		self.size = size
		self.name = name
		self.description = description
		self.type = type
		# field value -> description, one entry per enumerated value
		self.value = {}

	# Two bitfields are equal when they cover the same bit range.

	def __eq__(self, other):
		return self.start == other.start and self.size == other.size

	def addValue(self, value, description):
		"""Record *description* for the field value *value*.

		Every distinct value keeps its own entry (previously each call
		overwrote the last one); re-adding a value replaces its description.
		"""
		self.value[value] = description

	def getValue(self, value):
		"""Return the description recorded for *value*, or None if unknown."""
		return self.value.get(value)

	# define properties

	@property
	def start(self):
		return self.__start
	
	@start.setter
	def start(self, start):
		self.__start = start

	@property
	def size(self):
		return self.__size
	
	@size.setter
	def size(self, size):
		self.__size = size

	@property
	def name(self):
		return self.__name
	
	@name.setter
	def name(self, name):
		self.__name = name

	@property
	def description(self):
		return self.__description
	
	@description.setter
	def description(self, description):
		self.__description = description

	@property
	def type(self):
		return self.__type
	
	@type.setter
	def type(self, type):
		self.__type = type

# define MSR class

class MSR:
	"""A model-specific register definition taken from one <msr> element.

	Keeps the register address, access type, symbolic name and description,
	plus a dict of bitfields keyed by an id the caller chooses. The
	``bitfield_quantity`` counter is not touched by addBitfield(); callers
	maintain it themselves.
	"""

	def __init__(self, address = 0, type = "ro", name = "", description = ""):
		# Each of these goes through the property setter defined below.
		self.address = address
		self.type = type
		self.name = name
		self.description = description
		self.bitfields = {}
		self.bitfield_quantity = 0

	def __eq__(self, other):
		"""Two MSRs compare equal when they share the same address."""
		same_address = self.address == other.address
		return same_address

	# Bitfield storage

	def addBitfield(self, bitfield, id):
		"""Store *bitfield* under *id*, replacing any earlier entry."""
		self.bitfields.update({id: bitfield})

	def getBitfield(self, id):
		"""Return the bitfield stored under *id*, or None when absent."""
		return self.bitfields.get(id, None)

	# Plain read/write properties

	@property
	def address(self):
		return self.__address

	@address.setter
	def address(self, value):
		self.__address = value

	@property
	def name(self):
		return self.__name

	@name.setter
	def name(self, value):
		self.__name = value

	@property
	def description(self):
		return self.__description

	@description.setter
	def description(self, value):
		self.__description = value

	@property
	def type(self):
		return self.__type

	@type.setter
	def type(self, value):
		self.__type = value

	@property
	def bitfield_quantity(self):
		return self.__bitfield_quantity

	@bitfield_quantity.setter
	def bitfield_quantity(self, value):
		self.__bitfield_quantity = value

	

# define MSR_ARRAY class

class MSR_ARRAY:
	"""Top-level container for one MSR definition file.

	Holds the module name, the C expression used by the generated probe
	function (``cpu_condition``), and the MSRs keyed by id. addMsr()
	increments ``msr_quantity`` on every call.
	"""

	def __init__(self, name = "unknown"):
		self.name = name
		self.cpu_condition = ""
		self.msrs = {}
		self.msr_quantity = 0

	# MSR storage

	def addMsr(self, msr, id):
		"""Store *msr* under *id* and count it in ``msr_quantity``.

		The counter is bumped even when *id* replaces an existing entry.
		"""
		self.msrs.update({id: msr})
		self.msr_quantity = self.msr_quantity + 1

	def getMsr(self, id):
		"""Return the MSR stored under *id*, or None when absent."""
		return self.msrs.get(id, None)

	# Plain read/write properties

	@property
	def name(self):
		return self.__name

	@name.setter
	def name(self, value):
		self.__name = value

	@property
	def cpu_condition(self):
		return self.__cpu_condition

	@cpu_condition.setter
	def cpu_condition(self, value):
		self.__cpu_condition = value

	@property
	def msr_quantity(self):
		return self.__msr_quantity

	@msr_quantity.setter
	def msr_quantity(self, value):
		self.__msr_quantity = value

# Module-wide container filled by parse_definitions() and read by output_print().
msrarray = MSR_ARRAY(name = "unknown")

# Parse XML file and fill MSR object properties

def _report_import_error(err):
	"""Print *err* as an import error, prefixed with the program name.

	Written to stderr because stdout carries the generated C source.
	"""
	print("{0}: import error: {1}".format(os.path.basename(sys.argv[0]), err), file=sys.stderr)


def parse_definitions(filename):
	"""Parse the MSR definition file *filename* into the global ``msrarray``.

	Sets ``msrarray.name`` from the root element's ``name`` attribute.
	Appends one ``((family == id->family) & (model == id->model))`` clause
	per ``<cpu>`` child to ``msrarray.cpu_condition``, joined with `` | ``.
	Adds one MSR per ``<msr>`` child, together with its ``<bitfield>`` and
	``<value>`` children. Attribute values are kept as the raw strings read
	from the XML (None when an attribute is absent).

	Returns True on success. On failure it prints a diagnostic to stderr
	and returns False.
	"""
	try:
		tree = xml.etree.ElementTree.parse(filename)
	except (OSError, xml.etree.ElementTree.ParseError) as err:
		# OSError (EnvironmentError is its alias) covers missing or unreadable
		# files; ElementTree reports malformed XML as ParseError.
		_report_import_error(err)
		return False

	# Read name of MSR module
	msrarray.name = tree.getroot().get("name")

	try:
		# Build the C expression evaluated by the generated probe function.
		for element in tree.findall("cpu"):
			clause = "(({0} == id->family) & ({1} == id->model))".format(element.get("family"), element.get("model"))
			if msrarray.cpu_condition != "":
				msrarray.cpu_condition += " | "
			msrarray.cpu_condition += clause

		# NOTE(review): confirm these constants against the msrtype enum in
		# msrtool.h; any type other than "wo"/"rw" is emitted as read-only.
		msr_types = {"wo": "MSRTYPE_RDWO", "rw": "MSRTYPE_RDRW"}

		for element in tree.findall("msr"):
			msr_tmp = MSR(element.get("address"), msr_types.get(element.get("type"), "MSRTYPE_RDRO"), element.get("name"), element.get("description"))

			for bitf in element.findall("bitfield"):
				btf_tmp = BITFIELD(bitf.get("start"), bitf.get("size"), bitf.get("name"), bitf.get("description"), bitf.get("type"))
				for bitval in bitf.findall("value"):
					btf_tmp.addValue(bitval.get("number"), bitval.get("description"))
				# MSR.addBitfield does not count entries, so keep the counter here.
				msr_tmp.addBitfield(btf_tmp, msr_tmp.bitfield_quantity)
				msr_tmp.bitfield_quantity += 1

			msrarray.addMsr(msr_tmp, msrarray.msr_quantity)
	except (ValueError, LookupError) as err:
		_report_import_error(err)
		return False

	return True

# Simple output implementation

def _c_string(text):
	"""Escape *text* for use inside a double-quoted C string literal.

	Non-string values go through str() first, so a missing XML attribute
	(None) still renders as "None", as before.
	"""
	return str(text).replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n")


def output_print():
	"""Print the C source for the global ``msrarray`` to stdout.

	Emits ``<name>_probe()``, which returns the accumulated CPU condition,
	and a ``<name>_msrs[]`` table. The table has one initializer per MSR,
	each listing its bitfields and ending with ``{ BITS_EOT }``. The whole
	table ends with ``{ MSR_EOT }``.
	"""
	name = msrarray.name
	# An empty OR over zero CPUs is false; "()" would not compile.
	condition = msrarray.cpu_condition or "0"

	print("#include \"msrtool.h\"\n")
	print("int {0}_probe(const struct targetdef *target) {{".format(name))
	print("\tstruct cpuid_t *id = cpuid();")
	print("\tint cond = ({0});".format(condition))
	print("\treturn cond;\n}\n")
	print("const struct msrdef {0}_msrs[] = {{".format(name))
	for i in range(msrarray.msr_quantity):
		msr = msrarray.getMsr(i)
		print("\t{{ {0}, {1}, MSR2(0, 0), \"{2}\", \"{3}\", {{".format(msr.address, msr.type, _c_string(msr.name), _c_string(msr.description)))
		for j in range(msr.bitfield_quantity):
			bitfield = msr.getBitfield(j)
			# Trailing comma: each bitfield initializer is followed by another element.
			print("\t\t{{ {0}, {1}, \"{2}\", \"{3}\", {4} }},".format(bitfield.start, bitfield.size, _c_string(bitfield.name), _c_string(bitfield.description), bitfield.type))
		print("\t\t{ BITS_EOT }")
		print("\t}},")
	print("\t{ MSR_EOT }")
	# An array declaration with an initializer must end with ';'.
	print("};")

# Program body

# Read the definitions (optional first argument, default "msr.xml") and emit
# C source only when parsing succeeded. On failure the error has already been
# reported by parse_definitions(); exit non-zero instead of printing a
# bogus table.
if parse_definitions(sys.argv[1] if len(sys.argv) > 1 else "msr.xml"):
	output_print()
else:
	sys.exit(1)


