cconvey commented on code in PR #11400:
URL: https://github.com/apache/tvm/pull/11400#discussion_r881863827


##########
tests/python/contrib/test_hexagon/benchmarks_table.py:
##########
@@ -0,0 +1,231 @@
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import csv
+
+class benchmarks_table:
+    """
+    Stores/reports the result of benchmark runs.
+
+    Each line item has a status: success, fail, or skip.
+
+    Each 'success' line item must include benchmark data,
+    in the form provided by TVM's `time_evaluator` mechanism.
+
+    Each line item may also specify values for any subset of
+    the columns provided to the table's construstor.
+    """
+    BUILTIN_COLUMN_NAMES_TO_DESCS = {
+            "row_status":"status",
+            "timings_median_usecs":"median(µsec)",
+            "timings_min_usecs":"min(µsec)",
+            "timings_max_usecs":"max(µsec)",
+            }
+
+    class column_metadata_:
+        def __init__(self, name, is_reserved, header_text):
+            self.name = name
+            self.is_reserved = is_reserved
+            self.header_text = header_text
+
+    class column_collection_metadata_:
+        def __init__(self):
+            self.by_name = {}
+            self.by_header_text = {}
+
+        def add(self, cm):
+            if cm.name in self.by_name:
+                raise Exception(f"A column already exists with name 
'{cm.name}'")
+
+            if cm.header_text in self.by_header_text:
+                raise Exception(f"A column already exists with header_text 
'{cm.header_text}'")
+
+            self.by_name[ cm.name ] = cm
+            self.by_header_text[ cm.header_text ] = cm
+
+        def get_column_names(self):
+            return set(self.by_name.keys())
+
+        def get_unreserved_column_names(self):
+            return set([ k for k,v in self.by_name.items() if not 
v.is_reserved])
+
+        def get_reserved_column_names(self):
+            return set([ k for k,v in self.by_name.items() if v.is_reserved])
+
+        def get_ordered_by_name_sequence(self, name_sequence):
+            """
+            Returns a list of column_metadata objects, ordered according to
+            `name_sequence`.
+            """
+            return_list = []
+            for column_name in name_sequence:
+                assert column_name in self.by_name
+                return_list.append(self.by_name[column_name])
+            return return_list
+
+        def convert_dict_key_from_column_name_to_header_text(self, d_in):
+            """
+            `d_in` : A dictionary whose keys are a subset of those in 
`self.by_name`
+
+            Returns a new dictionary whose keys have been replaced with the
+            corresponding `header_text`.
+
+            Useful for things like csv.DictWriter.
+            """
+            d_out = {}
+
+            for k_in,v in d_in.items():
+                k_out = self.by_name[k_in].header_text
+                d_out[ k_out ] = v
+
+            return d_out
+
+    def __init__(self, user_column_defns):
+        """
+        `user_column_defns` : A dictionary of the form
+            (column_name : column_description).
+
+            The combination of this dictionary and the
+            BUILTIN_COLUMN_NAMES_TO_DESCS dictionary defines the set
+            of columns in that the benchmark table supports.
+
+            In the combined dictionary, no two columns can have
+            the same name or the same description.
+        """
+        self.all_cols_metadata_ = self.column_collection_metadata_()
+
+        for col_name, col_header_text in 
self.BUILTIN_COLUMN_NAMES_TO_DESCS.items():
+            self.all_cols_metadata_.add(self.column_metadata_(col_name, True, 
col_header_text))
+
+        for col_name, col_header_text in user_column_defns.items():
+            self.all_cols_metadata_.add(self.column_metadata_(col_name, False, 
col_header_text))
+
+        self.line_items_ = []
+
+    def validate_user_supplied_kwargs(self, kwarg_dict):
+        provided_column_names = set(kwarg_dict.keys())
+        defined_column_names = self.all_cols_metadata_.get_column_names()
+        reserved_column_names = 
self.all_cols_metadata_.get_reserved_column_names()
+
+        reserved_names_used = 
provided_column_names.intersection(reserved_column_names)
+        undefined_names_used = provided_column_names - defined_column_names
+
+        if len(reserved_names_used) > 0:
+            name_list = ', '.join(reserved_names_used)
+            raise Exception(f'Cannot supply a value for reserved column names: 
{reserved_names_used}')
+
+        if len(undefined_names_used) > 0:
+            name_list = ', '.join(undefined_names_used)
+            raise Exception(f'Cannot supply a value for undefined column 
names: {undefined_names_used}')
+
+    def record_success(self, timings, **kwargs):
+        """
+        `timings` : Assumed to have the structure and syntax of
+          the timing results provided by TVM's `time_evaluator`
+          mechanism.
+
+        `kwargs` : Optional values for any of the other columns
+          defined for this benchmark table.
+        """
+        self.validate_user_supplied_kwargs(kwargs)
+
+        line_item = dict(kwargs)
+
+        line_item['row_status'] = 'SUCCESS'
+        line_item['timings_min_usecs'] = timings.min * 1000000
+        line_item['timings_max_usecs'] = timings.max * 1000000
+        line_item['timings_median_usecs'] = timings.median * 1000000
+
+        self.line_items_.append(line_item)
+
+    def record_skip(self, **kwargs):
+        self.validate_user_supplied_kwargs(kwargs)
+
+        line_item = dict(kwargs)
+        line_item['row_status'] = 'SKIP'
+        self.line_items_.append(line_item)
+
+    def record_fail(self, **kwargs):
+        self.validate_user_supplied_kwargs(kwargs)
+
+        line_item = dict(kwargs)
+        line_item['row_status'] = 'FAIL'
+        self.line_items_.append(line_item)
+
+    def has_fail(self):
+        """
+        Returns True if the table contains at least one 'file' line item,
+        otherwise returns False.
+        """
+        for li in self.line_items_:
+            if li['row_status'] == 'FAIL':
+                return True
+
+        return False
+
+    def print_csv(self, f, column_name_order, timing_decimal_places=3):
+        """
+        Print the benchmark results as a csv.
+
+        `f` : The output stream.
+
+        `column_name_order`: an iterable sequence of column names, indicating 
the
+           order of column in the CSV output.
+           Each string must be one of the column names provided by
+           BUILTIN_COLUMN_NAMES_TO_DESCS or provided to the class constructor.
+
+           The CSV output will contain only those columns that are mentioned in
+           this list.
+
+        `timing_decimal_places`: for the numeric timing values, this is the
+           number of decimal places to provide in the printed output.
+           For example, a value of 3 is equivalent to the Python formatting 
string
+           `'{:.3f}'`
+        """
+        csv.register_dialect(
+            "benchmarks",
+            delimiter="\t",
+            quotechar='"',
+            quoting=csv.QUOTE_MINIMAL,
+        )
+
+        output_order_cm_list = 
self.all_cols_metadata_.get_ordered_by_name_sequence(column_name_order)
+
+        output_order_header_texts = [ cm.header_text for cm in 
output_order_cm_list ]
+
+        writer = csv.DictWriter(f, output_order_header_texts, 
dialect="benchmarks", restval="")
+
+        writer.writeheader()
+        for line_item_dict in self.line_items_:
+            for k in [
+                    "timings_median_usecs",
+                    "timings_min_usecs",
+                    "timings_max_usecs",
+                    ]:
+                if k in line_item_dict:
+                    old_value = line_item_dict[k]
+                    if isinstance(old_value, float):
+                        str_value = f"{old_value:>0.{timing_decimal_places}f}"
+                        line_item_dict[k] = str_value

Review Comment:
   > I'd recommend constructing a temporary dictionary containing the updated 
object.
   
   Good idea. Done.
   
   > We can also avoid the nested format string by using the [builtin 
round](https://docs.python.org/3/library/functions.html#round) function
   
   In this particular case I prefer keeping the rounding / truncation in the string rendering instead of using `round(...)`, for two reasons:
   - Some programmers (myself included) are accustomed to being vigilant about floating-point imprecision when representing base-10 fractions. So even if the recipe you gave works in practice, it can be distracting because it fails a smell test.
   - Using only `round(...)` removes the user's ability to specify the exact number of decimal places in the output. E.g., `print(round(4.2, 3))` prints "4.2", not "4.200". This discards some context w.r.t. the precision of the measurements.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to