#!/usr/bin/env python3

""" This file is part of msrtool.
   
  Copyright (C) 2011 Anton Kochkov <anton.kochkov@gmail.com>

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License version 2 as
  published by the Free Software Foundation.
  
  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
  
  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301 USA
"""

import sys
import os
import xml.etree.ElementTree
import re

# define BITFIELD class

class BITFIELD:
	"""A bit range inside an MSR, plus the named values it can take.

	start/size/name/description/type are stored exactly as passed in; values
	maps an id to a {"value": ..., "description": ...} dict and
	value_quantity counts how many addValue() calls were made.
	"""

	def __init__(self, start = 63, size = 63, name = "unknown", description = "none", type = "bin"):
		self.start = start
		self.size = size
		self.name = name
		self.description = description
		self.type = type
		# id -> {"value": ..., "description": ...}
		self.values = {}
		self.value_quantity = 0

	def __eq__(self, other):
		"""Bitfields compare equal when both start and size match."""
		if not isinstance(other, BITFIELD):
			# Let Python fall back to its default handling instead of
			# raising AttributeError on e.g. `bitfield == None`.
			return NotImplemented
		return self.start == other.start and self.size == other.size

	def addValue(self, value, description, id):
		"""Store a named value under key *id* and bump value_quantity."""
		self.values[id] = {"value": value, "description": description}
		self.value_quantity += 1

	def getValue(self, id):
		"""Return the value dict stored under *id*, or None if absent."""
		return self.values.get(id)

# define MSR class

class MSR:
	"""One model-specific register entry: address, access type, names and bitfields.

	bitfields maps an id to a BITFIELD-like object and bitfield_quantity
	counts how many addBitfield() calls were made.
	"""

	def __init__(self, address = 0, type = "ro", name = "", description = ""):
		self.address = address
		self.type = type
		self.name = name
		self.description = description
		# id -> bitfield object
		self.bitfields = {}
		self.bitfield_quantity = 0

	def __eq__(self, other):
		"""MSRs compare equal when their addresses match."""
		if not isinstance(other, MSR):
			# Avoid AttributeError on comparisons such as `msr == None`.
			return NotImplemented
		return self.address == other.address

	def addBitfield(self, bitfield, id):
		"""Store *bitfield* under key *id* and bump bitfield_quantity."""
		self.bitfields[id] = bitfield
		self.bitfield_quantity += 1

	def getBitfield(self, id):
		"""Return the bitfield stored under *id*, or None if absent."""
		return self.bitfields.get(id)

	

# define MSR_ARRAY class

class MSR_ARRAY:
	"""A whole target: its name/description, CPU probe condition and MSR table.

	msrs maps an id to an MSR-like object and msr_quantity counts how many
	addMsr() calls were made.
	"""

	def __init__(self, name = "unknown", description = "none"):
		self.name = name
		self.description = description
		# Text of the C condition emitted into the generated probe function
		self.cpu_condition = ""
		# id -> MSR object
		self.msrs = {}
		self.msr_quantity = 0

	def addMsr(self, msr, id):
		"""Store *msr* under key *id* and bump msr_quantity."""
		self.msrs[id] = msr
		self.msr_quantity += 1

	def getMsr(self, id):
		"""Return the MSR stored under *id*, or None if absent."""
		return self.msrs.get(id)

# Module-wide target object: parse_definitions() fills it and the output and
# patch functions read it.
msrarray = MSR_ARRAY()

# Parse an *.msr.xml file and fill the global msrarray

def parse_definitions(filename):
	"""Parse one *.msr.xml file into the module-level msrarray.

	The root element's "name"/"description" attributes name the target.
	Each <cpu> child adds a family/model term to msrarray.cpu_condition,
	which is the C condition used by the generated probe function. Each
	<msr> child, with its nested <bitfield> and <value> elements, becomes an
	MSR object added to msrarray.

	Returns True on success; prints an error and returns False otherwise.
	"""
	prog = os.path.basename(sys.argv[0])
	try:
		tree = xml.etree.ElementTree.parse(filename)
	except (OSError, xml.etree.ElementTree.ParseError) as err:
		# OSError is Python 3's name for EnvironmentError; ElementTree
		# reports malformed XML as ParseError rather than expat's ExpatError.
		print("{0}: import error: {1}".format(prog, err))
		return False

	# msrarray is shared by every file processed in one run, so drop what the
	# previous file left behind; otherwise its CPU terms and MSRs would leak
	# into this target's generated code.
	xml_msrarray = tree.getroot()
	msrarray.name = xml_msrarray.get("name")
	msrarray.description = xml_msrarray.get("description")
	msrarray.cpu_condition = ""
	msrarray.msrs.clear()
	msrarray.msr_quantity = 0

	# Read supported CPUs; one term per <cpu>, joined with " | ".
	# The "stepping" attribute is not used in the condition.
	for element in tree.findall("cpu"):
		term = "(({0} == id->family) & ({1} == id->model))".format(
			element.get("family"), element.get("model"))
		if msrarray.cpu_condition != "":
			msrarray.cpu_condition = msrarray.cpu_condition + " | "
		msrarray.cpu_condition = msrarray.cpu_condition + term

	# XML "type" attributes -> msrtool.h constants. Unknown or missing values
	# fall back to a read-only MSR / reserved bitfield.
	# NOTE(review): msrtool.h is expected to spell the read/write type
	# MSRTYPE_RDWR — confirm against the header.
	msr_types = {"wo": "MSRTYPE_WRONLY", "rw": "MSRTYPE_RDWR"}
	bitfield_types = {
		"bin": "PRESENT_BIN",
		"dec": "PRESENT_DEC",
		"hex": "PRESENT_HEX",
		"oct": "PRESENT_OCT",
	}

	# Read all MSRs with their bitfields and value descriptions
	try:
		for element in tree.findall("msr"):
			msr_tmp = MSR(element.get("address"),
				msr_types.get(element.get("type"), "MSRTYPE_RDONLY"),
				element.get("name"), element.get("description"))

			for bitf in element.findall("bitfield"):
				btf_tmp = BITFIELD(bitf.get("start"), bitf.get("size"),
					bitf.get("name"), bitf.get("description"),
					bitfield_types.get(bitf.get("type"), "PRESENT_RSVD"))

				for bitval in bitf.findall("value"):
					btf_tmp.addValue(bitval.get("number"), bitval.get("description"), btf_tmp.value_quantity)

				msr_tmp.addBitfield(btf_tmp, msr_tmp.bitfield_quantity)

			msrarray.addMsr(msr_tmp, msrarray.msr_quantity)
	except (ValueError, LookupError) as err:
		print("{0}: import error: {1}".format(prog, err))
		return False

	return True

# Print the generated C source to stdout (debug aid)

def output_print():
	"""Print the generated msrtool C source for msrarray to stdout.

	Each MSR entry nests its bitfields, and each bitfield nests its value
	list, so every level is opened with "{" and closed with "}}," after its
	end-of-table marker:

	    { <addr>, <type>, MSR2(0, 0), "<name>", "<desc>", {
	        { <start>, <size>, "<name>", "<desc>", <present>, {
	            { MSR1(<value>), "<desc>" },
	            { BITVAL_EOT }
	        }},
	        { BITS_EOT }
	    }},
	"""
	print("#include \"msrtool.h\"\n")
	print("int {0}_probe(const struct targetdef *target) ".format(msrarray.name) + "{")
	print("\tstruct cpuid_t *id = cpuid();")
	print("\tint cond = ({0});".format(msrarray.cpu_condition))
	print("\treturn cond;\n}\n")
	print("const struct msrdef {0}_msrs[] = ".format(msrarray.name) + "{")
	for i in range(msrarray.msr_quantity):
		msr = msrarray.getMsr(i)
		print("\t{ " + "{0}, {1}, MSR2(0, 0), \"{2}\", \"{3}\"".format(
			msr.address, msr.type, msr.name, msr.description) + ", {")
		for j in range(msr.bitfield_quantity):
			bitfield = msr.getBitfield(j)
			# Keep the bitfield initializer open: its value list is nested inside.
			print("\t\t{ " + "{0}, {1}, \"{2}\", \"{3}\", {4}".format(
				bitfield.start, bitfield.size, bitfield.name,
				bitfield.description, bitfield.type) + ", {")
			for k in range(bitfield.value_quantity):
				value = bitfield.getValue(k)
				print("\t\t\t{ " + "MSR1({0}), \"{1}\"".format(
					value["value"], value["description"]) + " },")
			print("\t\t\t{ BITVAL_EOT }")
			print("\t\t}},")
		print("\t\t{ BITS_EOT }")
		print("\t}},")
	print("\t{ MSR_EOT }")
	print("};")

# Write the generated C source to <msrarray.name>.c

def output_file():
	"""Write the generated msrtool C source for msrarray to "<name>.c".

	Each MSR entry nests its bitfields, and each bitfield nests its value
	list, so every level is opened with "{" and closed with "}}," after its
	end-of-table marker (BITVAL_EOT / BITS_EOT / MSR_EOT).

	Returns True on success; prints an error and returns False if the file
	cannot be written.
	"""
	filename = msrarray.name + ".c"
	try:
		with open(filename, "w", encoding="utf8") as fh:
			fh.write("#include \"msrtool.h\"\n\n")
			fh.write("int {0}_probe(const struct targetdef *target) ".format(msrarray.name) + "{\n")
			fh.write("\tstruct cpuid_t *id = cpuid();\n")
			fh.write("\tint cond = ({0});\n".format(msrarray.cpu_condition))
			fh.write("\treturn cond;\n}\n\n")
			fh.write("const struct msrdef {0}_msrs[] = ".format(msrarray.name) + "{\n")
			for i in range(msrarray.msr_quantity):
				msr = msrarray.getMsr(i)
				fh.write("\t{ " + "{0}, {1}, MSR2(0, 0), \"{2}\", \"{3}\"".format(
					msr.address, msr.type, msr.name, msr.description) + ", {\n")
				for j in range(msr.bitfield_quantity):
					bitfield = msr.getBitfield(j)
					# Keep the bitfield initializer open: its value list is nested inside.
					fh.write("\t\t{ " + "{0}, {1}, \"{2}\", \"{3}\", {4}".format(
						bitfield.start, bitfield.size, bitfield.name,
						bitfield.description, bitfield.type) + ", {\n")
					for k in range(bitfield.value_quantity):
						value = bitfield.getValue(k)
						fh.write("\t\t\t{ " + "MSR1({0}), \"{1}\"".format(
							value["value"], value["description"]) + " },\n")
					fh.write("\t\t\t{ BITVAL_EOT }\n")
					fh.write("\t\t}},\n")
				fh.write("\t\t{ BITS_EOT }\n")
				fh.write("\t}},\n")
			fh.write("\t{ MSR_EOT }\n")
			fh.write("};\n")
	except OSError as err:
		# OSError is Python 3's name for EnvironmentError.
		print("{0}: export error: {1}".format(os.path.basename(sys.argv[0]), err))
		return False
	return True


# Produce patch file
def produce_patch():
	msrtool_header_file = "msrtool.h"
	msrtool_c_file = "msrtool.c"
	patch_name = msrarray.name + ".patch"

	# paste this before "#endif /* MSRTOOL_H */"
	header_add_line = "+/* {0}.c */\n+extern int {1}_probe(const struct targetdef *t);\n+extern const struct msrdef {2}_msrs[];\n".format(msrarray.name, msrarray.name, msrarray.name)
	header_line_number = 0
	
	# paste this before "\t{ TARGET_EOT }" and after "static struct targetdef alltargets[] = {"
	c_add_line = "+\t{" + "\"{0}\", \"{1}\", {2}_probe, {3}_msrs".format(msrarray.name, msrarray.description, msrarray.name, msrarray.name) + " },"
	c_line_number = 0

	# Search line numbers

	fh = None
	try:
		
		# Read header file

		fh = open(msrtool_header_file, "r", encoding="utf8")
		pattern = '#endif /* MSRTOOL_H */'
		src = fh.read()
		m = re.match(pattern, src)
		if m is not None:
			start = m.start()
			header_line_number = src.count('\n', 0, start) - 1
			print("Found chunk in {0} at line #{1}".format(msrtool_header_file, header_line_number))
		else:
			print("Not found anything useful in {0}".format(msrtool_header_file))

	except (EnviromentError) as err:
		print("{0}: header file reading error: {1}".format(os.path.basename(sys.argv[0]), err))
		return False
	finally:
		if fh is not None:
			fh.close()

	fh = None
	try:

		# Read C file

		fh = open(msrtool_c_file, "r", encoding="utf8")
		pattern = '\t{ TARGET_EOT }'
		src = fh.read()
		m = re.match(pattern, src)
		if m is not None:
			start = m.start()
			c_line_number = src.count('\n', 0, start)
			print("Found chunk in {0} at line #{1}".format(msrtool_c_file, c_line_number))
		else:
			print("Not found anything useful in {0}".format(msrtool_c_file))

	except (EnviromentError) as err:
		print("{0}: c file reading error: {1}".format(os.path.basename(sys.argv[0]), err))
		return False
	finally:
		if fh is not None:
			fh.close()

	fh = None
	try:
		
		# Patch header file

		fh = open(patch_name, "w", encoding="utf8")
		fh.write("Index: {0}\n".format(msrtool_header_file))
		fh.write("===================================================================\n")
		fh.write("--- {0} (revision 0)\n".format(msrtool_header_file))
		fh.write("+++ {0} (revision 0)\n".format(msrtool_header_file))
		fh.write("@@ -23,10 +23,10 @@\n")

		# fh.write("-CFLAGS  = @CFLAGS@\n") # We dont need this, really
		fh.write("{0}\n".format(header_add_line))
		
		# Patch *.c file

		fh.write("Index: {0}\n".format(msrtool_c_file))
		fh.write("===================================================================\n")
		fh.write("--- {0} (revision 0)\n".format(msrtool_c_file))
		fh.write("+++ {0} (revision 0)\n".format(msrtool_c_file))
		fh.write("@@ -23,10 +23,10 @@\n")
		
		# fh.write("-CFLAGS  = @CFLAGS@\n") # We dont need this, really
		fh.write("{0}\n".format(c_add_line))

		print("Patch \"{0}\" written successfully".format(patch_name))

	except (EnviromentError) as err:
		print("{0}: patch creating error: {1}".format(os.path.basename(sys.argv[0]), err))
		return False
	finally:
		if fh is not None:
			fh.close()
	

# Program body: find every *.msr.xml file below the current directory,
# parse it, then write the generated C source and a patch for msrtool.

# Set to a true value to also print the generated C source to stdout.
debug = 0

if __name__ == "__main__":
	for dirpath, dirnames, filenames in os.walk("."):
		# Sort in place so directories and files are visited in a stable order.
		dirnames.sort()
		for filename in sorted(filenames):
			if not filename.endswith(".msr.xml"):
				continue
			# Skip files that fail to parse instead of emitting stale or
			# partial output for them.
			if not parse_definitions(os.path.join(dirpath, filename)):
				continue
			if debug:
				output_print()
			output_file()
			produce_patch()

