nudles commented on a change in pull request #638: SINGA-500 add onnx example models URL: https://github.com/apache/singa/pull/638#discussion_r397844262
########## File path: examples/onnx/bert/bert-squad.py ########## @@ -0,0 +1,151 @@ +import os +import zipfile +import numpy as np +import json + +from singa import device +from singa import tensor +from singa import sonnx +import onnx +import tokenization +from run_onnx_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions + +import sys +sys.path.append(os.path.dirname(__file__) + '/..') +from utils import download_model, update_batch_size, check_exist_or_download + +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(message)s') + +max_answer_length = 30 +max_seq_length = 256 +doc_stride = 128 +max_query_length = 64 +n_best_size = 20 +batch_size = 1 + + +def load_vocab(): + url = 'https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip' + download_dir = '/tmp/' + filename = os.path.join(download_dir, 'uncased_L-12_H-768_A-12', '.', + 'vocab.txt') + with zipfile.ZipFile(check_exist_or_download(url), 'r') as z: + z.extractall(path=download_dir) + return filename + + +class Infer: + + def __init__(self, sg_ir): + self.sg_ir = sg_ir + + def forward(self, x): + return sg_ir.run(x) + + +def preprocessing(): + vocab_file = load_vocab() + tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, + do_lower_case=True) + predict_file = os.path.join(os.path.dirname(__file__), 'inputs.json') + # print content + with open(predict_file) as json_file: + test_data = json.load(json_file) + print("The input is:", json.dumps(test_data, indent=2)) + + eval_examples = read_squad_examples(input_file=predict_file) + + # Use convert_examples_to_features method from run_onnx_squad to get parameters from the input + input_ids, input_mask, segment_ids, extra_data = convert_examples_to_features( + eval_examples, tokenizer, max_seq_length, doc_stride, max_query_length) + return input_ids, input_mask, segment_ids, extra_data, eval_examples + + +def postprocessing(eval_examples, extra_data, all_results): Review comment: postprocess? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
