yeandy commented on code in PR #21738: URL: https://github.com/apache/beam/pull/21738#discussion_r951480517
########## .test-infra/jenkins/job_InferenceBenchmarkTests_Python.groovy: ########## @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import CommonJobProperties as commonJobProperties +import LoadTestsBuilder as loadTestsBuilder +import PhraseTriggeringPostCommitBuilder +import CronJobBuilder + +def now = new Date().format("MMddHHmmss", TimeZone.getTimeZone('UTC')) + +def loadTestConfigurations = { + -> + [ + // Benchmark test config. Add multiple configs for multiple models. + // (TODO): Add model name to experiments once decided on which models to use. 
+ [ + title : 'Pytorch Vision Classification with Resnet 101', + test : 'apache_beam.testing.benchmarks.inference.pytorch_image_classification_benchmarks', + runner : CommonTestProperties.Runner.DATAFLOW, + pipelineOptions: [ + job_name : 'benchmark-tests-pytorch-imagenet-python' + now, + project : 'apache-beam-testing', + region : 'us-central1', + staging_location : 'gs://temp-storage-for-perf-tests/loadtests', + temp_location : 'gs://temp-storage-for-perf-tests/loadtests', + requirements_file : 'apache_beam/ml/inference/torch_tests_requirements.txt', + publish_to_big_query : true, + metrics_dataset : 'beam_run_inference', + metrics_table : 'torch_inference_imagenet_results_resnet101', + input_options : '{}', // this option is not required for RunInference tests. + influx_measurement : 'torch_inference_imagenet_resnet101', + influx_db_name : InfluxDBCredentialsHelper.InfluxDBDatabaseName, + influx_hostname : InfluxDBCredentialsHelper.InfluxDBHostUrl, + // args defined in the performance test + pretrained_model_name : 'resnet101', + // args defined in the example. + input : 'gs://apache-beam-ml/testing/inputs/openimage_50k_benchmark.txt', + // TODO: make sure the model_state_dict_path weights are accurate. 
+ model_state_dict_path : 'gs://apache-beam-ml/models/torchvision.models.resnet101.pth', + output : 'gs://temp-storage-for-end-to-end-tests/torch/result_' + now + '.txt' + ] + ], + [ + title : 'Pytorch Imagenet Classification with Resnet 152', + test : 'apache_beam.testing.benchmarks.inference.pytorch_image_classification_benchmarks', + runner : CommonTestProperties.Runner.DATAFLOW, + pipelineOptions: [ + job_name : 'benchmark-tests-pytorch-imagenet-python' + now, + project : 'apache-beam-testing', + region : 'us-central1', + staging_location : 'gs://temp-storage-for-perf-tests/loadtests', + temp_location : 'gs://temp-storage-for-perf-tests/loadtests', + requirements_file : 'apache_beam/ml/inference/torch_tests_requirements.txt', + publish_to_big_query : true, + metrics_dataset : 'beam_run_inference', + metrics_table : 'torch_inference_imagenet_results_resnet152', + input_options : '{}', // this option is not required for RunInference tests. + influx_measurement : 'torch_inference_imagenet_resnet152', + influx_db_name : InfluxDBCredentialsHelper.InfluxDBDatabaseName, + influx_hostname : InfluxDBCredentialsHelper.InfluxDBHostUrl, + // args defined in the performance test + pretrained_model_name : 'resnet152', + // args defined in the example. + input : 'gs://apache-beam-ml/testing/inputs/openimage_50k_benchmark.txt', + // TODO: make sure the model_state_dict_path weights are accurate. + model_state_dict_path : 'gs://apache-beam-ml/models/torchvision.models.resnet152.pth', + output : 'gs://temp-storage-for-end-to-end-tests/torch/result_' + now + '.txt' + ] + ], + // pytorch language modeling test using HuggingFace bert models Review Comment: ```suggestion // Pytorch language modeling test using HuggingFace BERT models ``` ########## sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py: ########## @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pytype: skip-file + +import logging + +from apache_beam.examples.inference import pytorch_image_classification +from apache_beam.testing.load_tests.load_test import LoadTest +from torchvision import models + +_PERF_TEST_MODELS = ['resnet50', 'resnet101', 'resnet152'] +_PRETRAINED_MODEL_MODULE = 'torchvision.models' + + +class PytorchVisionBenchmarkTest(LoadTest): + def __init__(self): + # TODO (anandinguva): make get_namespace() method in RunInference static + self.metrics_namespace = 'RunInferencePytorch' + super().__init__(metrics_namespace=self.metrics_namespace) + + def test(self): + pretrained_model_name = self.pipeline.get_option('pretrained_model_name') + if not pretrained_model_name: + raise RuntimeError( + 'Please provide a pretrained torch model name.' + ' Model name must be from the module torchvision.models') + if pretrained_model_name == _PERF_TEST_MODELS[0]: + model_class = models.resnet50 + elif pretrained_model_name == _PERF_TEST_MODELS[1]: + model_class = models.resnet101 + elif pretrained_model_name == _PERF_TEST_MODELS[2]: + model_class = models.resnet152 Review Comment: Maybe define `_PERF_TEST_MODELS` as a dict that maps each model name to its class, e.g. `_PERF_TEST_MODELS = {'resnet50': models.resnet50}`, and then look up the class by name instead of using the `if`/`elif` chain?
########## .test-infra/jenkins/job_InferenceBenchmarkTests_Python.groovy: ########## @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import CommonJobProperties as commonJobProperties +import LoadTestsBuilder as loadTestsBuilder +import PhraseTriggeringPostCommitBuilder +import CronJobBuilder + +def now = new Date().format("MMddHHmmss", TimeZone.getTimeZone('UTC')) + +def loadTestConfigurations = { + -> + [ + // Benchmark test config. Add multiple configs for multiple models. + // (TODO): Add model name to experiments once decided on which models to use. 
+ [ + title : 'Pytorch Vision Classification with Resnet 101', + test : 'apache_beam.testing.benchmarks.inference.pytorch_image_classification_benchmarks', + runner : CommonTestProperties.Runner.DATAFLOW, + pipelineOptions: [ + job_name : 'benchmark-tests-pytorch-imagenet-python' + now, + project : 'apache-beam-testing', + region : 'us-central1', + staging_location : 'gs://temp-storage-for-perf-tests/loadtests', + temp_location : 'gs://temp-storage-for-perf-tests/loadtests', + requirements_file : 'apache_beam/ml/inference/torch_tests_requirements.txt', + publish_to_big_query : true, + metrics_dataset : 'beam_run_inference', + metrics_table : 'torch_inference_imagenet_results_resnet101', + input_options : '{}', // this option is not required for RunInference tests. + influx_measurement : 'torch_inference_imagenet_resnet101', + influx_db_name : InfluxDBCredentialsHelper.InfluxDBDatabaseName, + influx_hostname : InfluxDBCredentialsHelper.InfluxDBHostUrl, + // args defined in the performance test + pretrained_model_name : 'resnet101', + // args defined in the example. + input : 'gs://apache-beam-ml/testing/inputs/openimage_50k_benchmark.txt', + // TODO: make sure the model_state_dict_path weights are accurate. 
+ model_state_dict_path : 'gs://apache-beam-ml/models/torchvision.models.resnet101.pth', + output : 'gs://temp-storage-for-end-to-end-tests/torch/result_' + now + '.txt' + ] + ], + [ + title : 'Pytorch Imagenet Classification with Resnet 152', + test : 'apache_beam.testing.benchmarks.inference.pytorch_image_classification_benchmarks', + runner : CommonTestProperties.Runner.DATAFLOW, + pipelineOptions: [ + job_name : 'benchmark-tests-pytorch-imagenet-python' + now, + project : 'apache-beam-testing', + region : 'us-central1', + staging_location : 'gs://temp-storage-for-perf-tests/loadtests', + temp_location : 'gs://temp-storage-for-perf-tests/loadtests', + requirements_file : 'apache_beam/ml/inference/torch_tests_requirements.txt', + publish_to_big_query : true, + metrics_dataset : 'beam_run_inference', + metrics_table : 'torch_inference_imagenet_results_resnet152', + input_options : '{}', // this option is not required for RunInference tests. + influx_measurement : 'torch_inference_imagenet_resnet152', + influx_db_name : InfluxDBCredentialsHelper.InfluxDBDatabaseName, + influx_hostname : InfluxDBCredentialsHelper.InfluxDBHostUrl, + // args defined in the performance test + pretrained_model_name : 'resnet152', + // args defined in the example. + input : 'gs://apache-beam-ml/testing/inputs/openimage_50k_benchmark.txt', + // TODO: make sure the model_state_dict_path weights are accurate. 
+ model_state_dict_path : 'gs://apache-beam-ml/models/torchvision.models.resnet152.pth', + output : 'gs://temp-storage-for-end-to-end-tests/torch/result_' + now + '.txt' + ] + ], + // pytorch language modeling test using HuggingFace bert models + [ + title : 'Pytorch Lanugaue Modeling using Hugging face bert-base-uncased model', Review Comment: ```suggestion title : 'Pytorch Language Modeling using Hugging Face bert-base-uncased model', ``` ########## .test-infra/jenkins/job_InferenceBenchmarkTests_Python.groovy: ########## @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import CommonJobProperties as commonJobProperties +import LoadTestsBuilder as loadTestsBuilder Review Comment: Is it just me, or does anyone see these red-ish color blocks here?  ########## sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py: ########## @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pytype: skip-file + +import logging + +from apache_beam.examples.inference import pytorch_image_classification +from apache_beam.testing.load_tests.load_test import LoadTest +from torchvision import models + +_PERF_TEST_MODELS = ['resnet50', 'resnet101', 'resnet152'] +_PRETRAINED_MODEL_MODULE = 'torchvision.models' + + +class PytorchVisionBenchmarkTest(LoadTest): + def __init__(self): + # TODO (anandinguva): make get_namespace() method in RunInference static Review Comment: Do you mean `get_metrics_namespace()`? Let's file a GH issue for this and reference it in the TODO. ########## .test-infra/jenkins/job_InferenceBenchmarkTests_Python.groovy: ########## @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import CommonJobProperties as commonJobProperties +import LoadTestsBuilder as loadTestsBuilder +import PhraseTriggeringPostCommitBuilder +import CronJobBuilder + +def now = new Date().format("MMddHHmmss", TimeZone.getTimeZone('UTC')) + +def loadTestConfigurations = { + -> + [ + // Benchmark test config. Add multiple configs for multiple models. + // (TODO): Add model name to experiments once decided on which models to use. + [ + title : 'Pytorch Vision Classification with Resnet 101', + test : 'apache_beam.testing.benchmarks.inference.pytorch_image_classification_benchmarks', + runner : CommonTestProperties.Runner.DATAFLOW, + pipelineOptions: [ + job_name : 'benchmark-tests-pytorch-imagenet-python' + now, + project : 'apache-beam-testing', + region : 'us-central1', + staging_location : 'gs://temp-storage-for-perf-tests/loadtests', + temp_location : 'gs://temp-storage-for-perf-tests/loadtests', + requirements_file : 'apache_beam/ml/inference/torch_tests_requirements.txt', + publish_to_big_query : true, + metrics_dataset : 'beam_run_inference', + metrics_table : 'torch_inference_imagenet_results_resnet101', + input_options : '{}', // this option is not required for RunInference tests. + influx_measurement : 'torch_inference_imagenet_resnet101', + influx_db_name : InfluxDBCredentialsHelper.InfluxDBDatabaseName, + influx_hostname : InfluxDBCredentialsHelper.InfluxDBHostUrl, + // args defined in the performance test + pretrained_model_name : 'resnet101', + // args defined in the example. + input : 'gs://apache-beam-ml/testing/inputs/openimage_50k_benchmark.txt', + // TODO: make sure the model_state_dict_path weights are accurate. 
+ model_state_dict_path : 'gs://apache-beam-ml/models/torchvision.models.resnet101.pth', + output : 'gs://temp-storage-for-end-to-end-tests/torch/result_' + now + '.txt' + ] + ], + [ + title : 'Pytorch Imagenet Classification with Resnet 152', + test : 'apache_beam.testing.benchmarks.inference.pytorch_image_classification_benchmarks', + runner : CommonTestProperties.Runner.DATAFLOW, + pipelineOptions: [ + job_name : 'benchmark-tests-pytorch-imagenet-python' + now, + project : 'apache-beam-testing', + region : 'us-central1', + staging_location : 'gs://temp-storage-for-perf-tests/loadtests', + temp_location : 'gs://temp-storage-for-perf-tests/loadtests', + requirements_file : 'apache_beam/ml/inference/torch_tests_requirements.txt', + publish_to_big_query : true, + metrics_dataset : 'beam_run_inference', + metrics_table : 'torch_inference_imagenet_results_resnet152', + input_options : '{}', // this option is not required for RunInference tests. + influx_measurement : 'torch_inference_imagenet_resnet152', + influx_db_name : InfluxDBCredentialsHelper.InfluxDBDatabaseName, + influx_hostname : InfluxDBCredentialsHelper.InfluxDBHostUrl, + // args defined in the performance test + pretrained_model_name : 'resnet152', + // args defined in the example. + input : 'gs://apache-beam-ml/testing/inputs/openimage_50k_benchmark.txt', + // TODO: make sure the model_state_dict_path weights are accurate. 
+ model_state_dict_path : 'gs://apache-beam-ml/models/torchvision.models.resnet152.pth', + output : 'gs://temp-storage-for-end-to-end-tests/torch/result_' + now + '.txt' + ] + ], + // pytorch language modeling test using HuggingFace bert models + [ + title : 'Pytorch Lanugaue Modeling using Hugging face bert-base-uncased model', + test : 'apache_beam.testing.benchmarks.inference.pytorch_language_modeling_benchmarks', + runner : CommonTestProperties.Runner.DATAFLOW, + pipelineOptions: [ + job_name : 'benchmark-tests-pytorch-language-modeling-bert-base-uncased' + now, + project : 'apache-beam-testing', + region : 'us-central1', + staging_location : 'gs://temp-storage-for-perf-tests/loadtests', + temp_location : 'gs://temp-storage-for-perf-tests/loadtests', + requirements_file : 'apache_beam/ml/inference/torch_tests_requirements.txt', + pickle_library : 'cloudpickle', + publish_to_big_query : true, + metrics_dataset : 'beam_run_inference', + metrics_table : 'torch_language_modeling_bert_base_uncased', + input_options : '{}', // this option is not required for RunInference tests. + influx_measurement : 'torch_language_modeling_bert_base_uncased', + influx_db_name : InfluxDBCredentialsHelper.InfluxDBDatabaseName, + influx_hostname : InfluxDBCredentialsHelper.InfluxDBHostUrl, + // args defined in the example. + input : 'gs://apache-beam-ml/testing/inputs/sentences_50k.txt', + // TODO: make sure the model_state_dict_path weights are accurate. 
+ bert_tokenizer : 'bert-base-uncased', + model_state_dict_path : 'gs://apache-beam-ml/models/huggingface.BertForMaskedLM.bert-base-uncased.pth', + output : 'gs://temp-storage-for-end-to-end-tests/torch/result_' + now + '.txt', + ] + ], + [ + title : 'Pytorch Lanugaue Modeling using Hugging face bert-large-uncased model', Review Comment: ```suggestion title : 'Pytorch Language Modeling using Hugging Face bert-large-uncased model', ``` ########## sdks/python/apache_beam/options/pipeline_options.py: ########## @@ -205,7 +205,9 @@ def __init__(self, flags=None, **kwargs): # Build parser that will parse options recognized by the [sub]class of # PipelineOptions whose object is being instantiated. - parser = _BeamArgumentParser() + # set allow_abbrev=False to avoid prefix matching while parsing. + # https://docs.python.org/3/library/argparse.html#partial-parsing + parser = _BeamArgumentParser(allow_abbrev=False) Review Comment: Is this to resolve the ambiguity between the `input` and `input_options` options? Will it break users whose existing pipelines pass abbreviated option names, especially for custom options? ########## .test-infra/jenkins/job_InferenceBenchmarkTests_Python.groovy: ########## @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import CommonJobProperties as commonJobProperties +import LoadTestsBuilder as loadTestsBuilder +import PhraseTriggeringPostCommitBuilder +import CronJobBuilder + +def now = new Date().format("MMddHHmmss", TimeZone.getTimeZone('UTC')) + +def loadTestConfigurations = { + -> + [ + // Benchmark test config. Add multiple configs for multiple models. + // (TODO): Add model name to experiments once decided on which models to use. + [ + title : 'Pytorch Vision Classification with Resnet 101', + test : 'apache_beam.testing.benchmarks.inference.pytorch_image_classification_benchmarks', + runner : CommonTestProperties.Runner.DATAFLOW, + pipelineOptions: [ + job_name : 'benchmark-tests-pytorch-imagenet-python' + now, Review Comment: Should we be consistent with the naming of the `test` and `job_name`? ########## sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py: ########## @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# pytype: skip-file + +import logging + +from apache_beam.examples.inference import pytorch_image_classification +from apache_beam.testing.load_tests.load_test import LoadTest +from torchvision import models + +_PERF_TEST_MODELS = ['resnet50', 'resnet101', 'resnet152'] +_PRETRAINED_MODEL_MODULE = 'torchvision.models' + + +class PytorchVisionBenchmarkTest(LoadTest): Review Comment: Right now, this only runs `pytorch_image_classification`. Are we going to add `pytorch_image_segmentation` to this class as well? ########## sdks/python/apache_beam/testing/benchmarks/inference/pytorch_image_classification_benchmarks.py: ########## @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# pytype: skip-file + +import logging + +from apache_beam.examples.inference import pytorch_image_classification +from apache_beam.testing.load_tests.load_test import LoadTest +from torchvision import models + +_PERF_TEST_MODELS = ['resnet50', 'resnet101', 'resnet152'] +_PRETRAINED_MODEL_MODULE = 'torchvision.models' + + +class PytorchVisionBenchmarkTest(LoadTest): + def __init__(self): + # TODO (anandinguva): make get_namespace() method in RunInference static + self.metrics_namespace = 'RunInferencePytorch' + super().__init__(metrics_namespace=self.metrics_namespace) + + def test(self): + pretrained_model_name = self.pipeline.get_option('pretrained_model_name') + if not pretrained_model_name: + raise RuntimeError( + 'Please provide a pretrained torch model name.' + ' Model name must be from the module torchvision.models') + if pretrained_model_name == _PERF_TEST_MODELS[0]: + model_class = models.resnet50 + elif pretrained_model_name == _PERF_TEST_MODELS[1]: + model_class = models.resnet101 + elif pretrained_model_name == _PERF_TEST_MODELS[2]: + model_class = models.resnet152 + else: + raise NotImplementedError + + # model_params are same for all the models. But this may change if we add + # different models. + model_params = {'num_classes': 1000, 'pretrained': False} Review Comment: Should we parametrize this now? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
