DemesneGH commented on code in PR #173:
URL: 
https://github.com/apache/incubator-teaclave-trustzone-sdk/pull/173#discussion_r1993412576


##########
examples/mnist-rs/host/src/commands/train.rs:
##########
@@ -0,0 +1,130 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::io::{Cursor, Read};
+use std::path::PathBuf;
+
+use crate::tee::Trainer;
+use optee_teec::Context;
+use proto::Image;
+use rand::seq::SliceRandom;
+
+#[derive(clap::Parser, Debug)]
+pub struct Args {
+    #[arg(short, long, default_value_t = 6)]
+    num_epochs: usize,
+    #[arg(short, long, default_value_t = 64)]
+    batch_size: usize,
+    #[arg(short, long, default_value_t = 0.0001)]
+    learning_rate: f64,
+    #[arg(short, long)]
+    output: Option<String>,
+}
+
+fn convert_datasets(images: &Vec<Image>, labels: &[u8]) -> Vec<(Image, u8)> {
+    let mut datasets: Vec<(Image, u8)> = images
+        .iter()
+        .map(|v| v.to_owned())
+        .zip(labels.iter().copied())
+        .collect();
+    datasets.shuffle(&mut rand::rng());
+    datasets
+}
+
+pub fn execute(args: &Args) -> anyhow::Result<()> {
+    // Initialize trainer
+    let mut ctx = Context::new()?;
+    let mut trainer = Trainer::new(&mut ctx, args.learning_rate)?;
+    // Download mnist data
+    let data = check_download_mnist_data()?;
+    // Prepare datasets
+    let train_datasets = convert_datasets(&data.train_data, 
&data.train_labels);
+    let valid_datasets = convert_datasets(&data.test_data, &data.test_labels);
+    // Training loop, Originally inspired by burn/crates/custom-training-loop
+    for epoch in 1..args.num_epochs + 1 {
+        for (iteration, data) in 
train_datasets.chunks(args.batch_size).enumerate() {
+            let images: Vec<Image> = data.iter().map(|v| v.0).collect();
+            let labels: Vec<u8> = data.iter().map(|v| v.1).collect();
+            let output = trainer.train(&images, &labels)?;
+            println!(
+                "[Train - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} 
%",
+                epoch, iteration, output.loss, output.accuracy,
+            );
+        }
+
+        for (iteration, data) in 
valid_datasets.chunks(args.batch_size).enumerate() {
+            let images: Vec<Image> = data.iter().map(|v| v.0).collect();
+            let labels: Vec<u8> = data.iter().map(|v| v.1).collect();
+            let output = trainer.valid(&images, &labels)?;
+            println!(
+                "[Valid - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} 
%",
+                epoch, iteration, output.loss, output.accuracy,
+            );
+        }
+    }
+    // Export the model to the given path
+    match args.output.as_ref() {
+        None => {}

Review Comment:
   Use `if let Some(output_path) = args.output.as_ref() { ... }` instead, since the `None` arm is empty?



##########
examples/mnist-rs/host/src/tee.rs:
##########
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use optee_teec::{Context, ErrorKind, Operation, ParamNone, ParamTmpRef, 
Session, Uuid};
+use proto::{inference, train, Image};
+
+pub struct Trainer<'a> {
+    sess: Session<'a>,
+}
+
+impl<'a> Trainer<'a> {
+    pub fn new(ctx: &'a mut Context, learning_rate: f64) -> 
optee_teec::Result<Self> {
+        let bytes = learning_rate.to_le_bytes();
+        let uuid = Uuid::parse_str(train::UUID).map_err(|err| {
+            println!("parse uuid \"{}\" failed due to: {:?}", train::UUID, 
err);
+            ErrorKind::BadParameters
+        })?;
+        let mut op = Operation::new(
+            0,
+            ParamTmpRef::new_input(bytes.as_slice()),
+            ParamNone,
+            ParamNone,
+            ParamNone,
+        );
+
+        Ok(Self {
+            sess: ctx.open_session_with_operation(uuid, &mut op)?,
+        })
+    }
+    pub fn train(&mut self, images: &[Image], labels: &[u8]) -> 
optee_teec::Result<train::Output> {
+        let mut buffer = vec![0_u8; 1024];

Review Comment:
   Can we define the maximum buffer sizes as constants? Same for line 66 and line 88.



##########
examples/mnist-rs/host/src/tee.rs:
##########
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use optee_teec::{Context, ErrorKind, Operation, ParamNone, ParamTmpRef, 
Session, Uuid};
+use proto::{inference, train, Image};
+
+pub struct Trainer<'a> {
+    sess: Session<'a>,
+}
+
+impl<'a> Trainer<'a> {
+    pub fn new(ctx: &'a mut Context, learning_rate: f64) -> 
optee_teec::Result<Self> {
+        let bytes = learning_rate.to_le_bytes();
+        let uuid = Uuid::parse_str(train::UUID).map_err(|err| {
+            println!("parse uuid \"{}\" failed due to: {:?}", train::UUID, 
err);
+            ErrorKind::BadParameters
+        })?;
+        let mut op = Operation::new(
+            0,
+            ParamTmpRef::new_input(bytes.as_slice()),
+            ParamNone,
+            ParamNone,
+            ParamNone,
+        );
+
+        Ok(Self {
+            sess: ctx.open_session_with_operation(uuid, &mut op)?,
+        })
+    }
+    pub fn train(&mut self, images: &[Image], labels: &[u8]) -> 
optee_teec::Result<train::Output> {
+        let mut buffer = vec![0_u8; 1024];
+        let images = bytemuck::cast_slice(images);
+        let size = {
+            let mut op = Operation::new(
+                0,
+                ParamTmpRef::new_input(images),
+                ParamTmpRef::new_input(labels),
+                ParamTmpRef::new_output(&mut buffer),
+                ParamNone,
+            );
+            self.sess
+                .invoke_command(train::Command::Train as u32, &mut op)?;
+            op.parameters().2.updated_size()
+        };
+        let result = serde_json::from_slice(&buffer[0..size]).map_err(|err| {
+            println!("proto error: {:?}", err);
+            ErrorKind::BadFormat
+        })?;
+        Ok(result)
+    }
+    pub fn valid(&mut self, images: &[Image], labels: &[u8]) -> 
optee_teec::Result<train::Output> {
+        let mut buffer = vec![0_u8; 1024];
+        let images = bytemuck::cast_slice(images);
+        let size = {
+            let mut op = Operation::new(
+                0,
+                ParamTmpRef::new_input(images),
+                ParamTmpRef::new_input(labels),
+                ParamTmpRef::new_output(&mut buffer),
+                ParamNone,
+            );
+            self.sess
+                .invoke_command(train::Command::Valid as u32, &mut op)?;
+            op.parameters().2.updated_size()
+        };
+        let result = serde_json::from_slice(&buffer[0..size]).map_err(|err| {
+            println!("proto error: {:?}", err);
+            ErrorKind::BadFormat
+        })?;
+        Ok(result)
+    }
+
+    pub fn export(&mut self) -> optee_teec::Result<Vec<u8>> {
+        let mut buffer = vec![0_u8; 10 * 1024 * 1024];
+        let size = {
+            let mut op = Operation::new(
+                0,
+                ParamTmpRef::new_output(&mut buffer),
+                ParamNone,
+                ParamNone,
+                ParamNone,
+            );
+            self.sess
+                .invoke_command(train::Command::Export as u32, &mut op)?;
+            op.parameters().0.updated_size()
+        };
+        buffer.resize(size, 0);
+        Ok(buffer)
+    }
+}
+
+pub struct Model<'a> {
+    sess: Session<'a>,
+}
+
+unsafe impl Send for Model<'_> {}
+
+impl<'a> Model<'a> {
+    pub fn new(ctx: &'a mut Context, record: &[u8]) -> 
optee_teec::Result<Self> {
+        let uuid = Uuid::parse_str(inference::UUID).map_err(|err| {
+            println!(
+                "parse uuid \"{}\" failed due to: {:?}",
+                inference::UUID,
+                err
+            );
+            ErrorKind::BadParameters
+        })?;
+        let mut op = Operation::new(
+            0,
+            ParamTmpRef::new_input(record),
+            ParamNone,
+            ParamNone,
+            ParamNone,
+        );
+
+        Ok(Self {
+            sess: ctx.open_session_with_operation(uuid, &mut op)?,
+        })
+    }
+    pub fn infer_batch(&mut self, images: &[Image]) -> 
optee_teec::Result<Vec<u8>> {
+        let mut output = vec![0_u8; images.len()];
+        let size = {
+            let mut op = Operation::new(
+                0,
+                ParamTmpRef::new_input(bytemuck::cast_slice(images)),
+                ParamTmpRef::new_output(&mut output),
+                ParamNone,
+                ParamNone,
+            );
+            self.sess.invoke_command(0, &mut op)?;
+            op.parameters().1.updated_size()
+        };
+
+        assert_eq!(output.len(), size);

Review Comment:
   Prefer returning an error (e.g. with `anyhow::ensure!`) instead of `assert_eq!`, so that a size mismatch does not panic.



##########
examples/mnist-rs/ta/inference/src/main.rs:
##########
@@ -0,0 +1,107 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#![no_std]
+#![no_main]
+extern crate alloc;
+
+use burn::{
+    backend::{ndarray::NdArrayDevice, NdArray},
+    tensor::cast::ToElement,
+};
+
+use model::Model;
+use optee_utee::{
+    ta_close_session, ta_create, ta_destroy, ta_invoke_command, 
ta_open_session, trace_println,
+};
+use optee_utee::{ErrorKind, Parameter, Parameters, Result};
+use proto::Image;
+use spin::Mutex;
+
+type NoStdModel = Model<NdArray>;
+const DEVICE: NdArrayDevice = NdArrayDevice::Cpu;
+static MODEL: Mutex<Option<NoStdModel>> = Mutex::new(Option::None);
+
+#[ta_create]
+fn create() -> Result<()> {
+    trace_println!("[+] TA create");
+    Ok(())
+}
+
+#[ta_open_session]
+fn open_session(params: &mut Parameters) -> Result<()> {
+    let mut p0 = unsafe { params.0.as_memref()? };
+
+    let mut model = MODEL.lock();
+    model.replace(Model::import(&DEVICE, p0.buffer().to_vec()).map_err(|err| {
+        trace_println!("import failed: {:?}", err);
+        ErrorKind::BadParameters
+    })?);
+
+    Ok(())
+}
+
+#[ta_close_session]
+fn close_session() {
+    trace_println!("[+] TA close session");
+}
+
+#[ta_destroy]
+fn destroy() {
+    trace_println!("[+] TA destroy");
+}
+
+fn copy_to_output(param: &mut Parameter, data: &[u8]) -> Result<()> {
+    let mut output = unsafe { param.as_memref()? };
+
+    let buffer = output.buffer();
+    if buffer.len() < data.len() {
+        trace_println!(
+            "expect output buffer size {}, got size {} instead",
+            data.len(),
+            buffer.len()
+        );
+        return Err(ErrorKind::ShortBuffer.into());
+    }
+    buffer[..data.len()].copy_from_slice(data);
+    output.set_updated_size(data.len());
+    Ok(())
+}
+
+#[ta_invoke_command]
+fn invoke_command(_cmd_id: u32, params: &mut Parameters) -> Result<()> {
+    trace_println!("[+] TA invoke command");
+    let mut p0 = unsafe { params.0.as_memref()? };
+    let images: &[Image] = bytemuck::cast_slice(p0.buffer());
+    let input = NoStdModel::images_to_tensors(&DEVICE, images);
+
+    let output = {
+        let model = MODEL.lock();
+        model.as_ref().unwrap().forward(input)

Review Comment:
   Eliminate `unwrap()` to avoid a panic; return an error instead if the model has not been initialized.



##########
examples/mnist-rs/host/src/commands/train.rs:
##########
@@ -0,0 +1,130 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::io::{Cursor, Read};
+use std::path::PathBuf;
+
+use crate::tee::Trainer;
+use optee_teec::Context;
+use proto::Image;
+use rand::seq::SliceRandom;
+
+#[derive(clap::Parser, Debug)]
+pub struct Args {
+    #[arg(short, long, default_value_t = 6)]
+    num_epochs: usize,
+    #[arg(short, long, default_value_t = 64)]
+    batch_size: usize,
+    #[arg(short, long, default_value_t = 0.0001)]
+    learning_rate: f64,
+    #[arg(short, long)]
+    output: Option<String>,
+}
+
+fn convert_datasets(images: &Vec<Image>, labels: &[u8]) -> Vec<(Image, u8)> {
+    let mut datasets: Vec<(Image, u8)> = images
+        .iter()
+        .map(|v| v.to_owned())
+        .zip(labels.iter().copied())
+        .collect();
+    datasets.shuffle(&mut rand::rng());
+    datasets
+}
+
+pub fn execute(args: &Args) -> anyhow::Result<()> {
+    // Initialize trainer
+    let mut ctx = Context::new()?;
+    let mut trainer = Trainer::new(&mut ctx, args.learning_rate)?;
+    // Download mnist data
+    let data = check_download_mnist_data()?;
+    // Prepare datasets
+    let train_datasets = convert_datasets(&data.train_data, 
&data.train_labels);
+    let valid_datasets = convert_datasets(&data.test_data, &data.test_labels);
+    // Training loop, Originally inspired by burn/crates/custom-training-loop
+    for epoch in 1..args.num_epochs + 1 {
+        for (iteration, data) in 
train_datasets.chunks(args.batch_size).enumerate() {
+            let images: Vec<Image> = data.iter().map(|v| v.0).collect();
+            let labels: Vec<u8> = data.iter().map(|v| v.1).collect();
+            let output = trainer.train(&images, &labels)?;
+            println!(
+                "[Train - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} 
%",
+                epoch, iteration, output.loss, output.accuracy,
+            );
+        }
+
+        for (iteration, data) in 
valid_datasets.chunks(args.batch_size).enumerate() {
+            let images: Vec<Image> = data.iter().map(|v| v.0).collect();
+            let labels: Vec<u8> = data.iter().map(|v| v.1).collect();
+            let output = trainer.valid(&images, &labels)?;
+            println!(
+                "[Valid - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} 
%",
+                epoch, iteration, output.loss, output.accuracy,
+            );
+        }
+    }
+    // Export the model to the given path
+    match args.output.as_ref() {
+        None => {}
+        Some(output_path) => {
+            let record = trainer.export()?;
+            println!("Export record to \"{}\"", output_path);
+            std::fs::write(output_path, &record)?;
+        }
+    }
+    println!("Train Success");
+    Ok(())
+}
+
+fn check_download_mnist_data() -> anyhow::Result<rust_mnist::Mnist> {
+    const DATA_PATH: &str = "./data/";
+
+    let folder = PathBuf::from(DATA_PATH);
+    if !folder.exists() {
+        std::fs::create_dir_all(&folder)?;
+    }
+    for (filename, gz_size, flat_size) in [
+        ("train-images-idx3-ubyte", 9912422, 47040016),
+        ("train-labels-idx1-ubyte", 28881, 60008),
+        ("t10k-images-idx3-ubyte", 1648877, 7840016),
+        ("t10k-labels-idx1-ubyte", 4542, 10008),
+    ]
+    .iter()

Review Comment:
   Can we get the file sizes dynamically instead of hardcoding them here?



##########
examples/mnist-rs/host/src/commands/train.rs:
##########
@@ -0,0 +1,130 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::io::{Cursor, Read};
+use std::path::PathBuf;
+
+use crate::tee::Trainer;
+use optee_teec::Context;
+use proto::Image;
+use rand::seq::SliceRandom;
+
+#[derive(clap::Parser, Debug)]
+pub struct Args {
+    #[arg(short, long, default_value_t = 6)]
+    num_epochs: usize,
+    #[arg(short, long, default_value_t = 64)]
+    batch_size: usize,
+    #[arg(short, long, default_value_t = 0.0001)]
+    learning_rate: f64,
+    #[arg(short, long)]
+    output: Option<String>,
+}
+
+fn convert_datasets(images: &Vec<Image>, labels: &[u8]) -> Vec<(Image, u8)> {
+    let mut datasets: Vec<(Image, u8)> = images
+        .iter()
+        .map(|v| v.to_owned())
+        .zip(labels.iter().copied())
+        .collect();
+    datasets.shuffle(&mut rand::rng());
+    datasets
+}
+
+pub fn execute(args: &Args) -> anyhow::Result<()> {
+    // Initialize trainer
+    let mut ctx = Context::new()?;
+    let mut trainer = Trainer::new(&mut ctx, args.learning_rate)?;
+    // Download mnist data
+    let data = check_download_mnist_data()?;
+    // Prepare datasets
+    let train_datasets = convert_datasets(&data.train_data, 
&data.train_labels);
+    let valid_datasets = convert_datasets(&data.test_data, &data.test_labels);
+    // Training loop, Originally inspired by burn/crates/custom-training-loop
+    for epoch in 1..args.num_epochs + 1 {
+        for (iteration, data) in 
train_datasets.chunks(args.batch_size).enumerate() {
+            let images: Vec<Image> = data.iter().map(|v| v.0).collect();
+            let labels: Vec<u8> = data.iter().map(|v| v.1).collect();
+            let output = trainer.train(&images, &labels)?;
+            println!(
+                "[Train - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} 
%",
+                epoch, iteration, output.loss, output.accuracy,
+            );
+        }
+
+        for (iteration, data) in 
valid_datasets.chunks(args.batch_size).enumerate() {
+            let images: Vec<Image> = data.iter().map(|v| v.0).collect();
+            let labels: Vec<u8> = data.iter().map(|v| v.1).collect();
+            let output = trainer.valid(&images, &labels)?;
+            println!(
+                "[Valid - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} 
%",
+                epoch, iteration, output.loss, output.accuracy,
+            );
+        }
+    }
+    // Export the model to the given path
+    match args.output.as_ref() {
+        None => {}
+        Some(output_path) => {
+            let record = trainer.export()?;
+            println!("Export record to \"{}\"", output_path);
+            std::fs::write(output_path, &record)?;
+        }
+    }
+    println!("Train Success");
+    Ok(())
+}
+
+fn check_download_mnist_data() -> anyhow::Result<rust_mnist::Mnist> {
+    const DATA_PATH: &str = "./data/";
+
+    let folder = PathBuf::from(DATA_PATH);
+    if !folder.exists() {
+        std::fs::create_dir_all(&folder)?;
+    }
+    for (filename, gz_size, flat_size) in [
+        ("train-images-idx3-ubyte", 9912422, 47040016),
+        ("train-labels-idx1-ubyte", 28881, 60008),
+        ("t10k-images-idx3-ubyte", 1648877, 7840016),
+        ("t10k-labels-idx1-ubyte", 4542, 10008),
+    ]
+    .iter()
+    {
+        let file = folder.join(filename);
+        if file.exists() && file.is_file() && std::fs::metadata(&file)?.len() 
== *flat_size {
+            println!("File {} exist, skip.", file.display());
+            continue;
+        }
+
+        let url = format!(
+            "https://storage.googleapis.com/cvdf-datasets/mnist/{}.gz";,
+            filename
+        );
+        println!("Download {} from {}", filename, url);
+        let body = ureq::get(&url).call()?.body_mut().read_to_vec()?;
+
+        assert_eq!(body.len(), *gz_size as usize);

Review Comment:
   Prefer `anyhow::ensure!` over `assert_eq!` so that a mismatch returns an error instead of panicking.



##########
examples/mnist-rs/ta/train/src/main.rs:
##########
@@ -0,0 +1,128 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#![no_std]
+#![no_main]
+extern crate alloc;
+
+use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
+use optee_utee::{
+    ta_close_session, ta_create, ta_destroy, ta_invoke_command, 
ta_open_session, trace_println,
+};
+use optee_utee::{ErrorKind, Parameters, Result, Parameter};
+use proto::train::Command;
+use spin::Mutex;
+
+mod trainer;
+
+type NoStdTrainer = trainer::Trainer<Autodiff<NdArray>>;
+
+const DEVICE: NdArrayDevice = NdArrayDevice::Cpu;
+static TRAINER: Mutex<Option<NoStdTrainer>> = Mutex::new(Option::None);
+
+#[ta_create]
+fn create() -> Result<()> {
+    trace_println!("[+] TA create");
+    Ok(())
+}
+
+#[ta_open_session]
+fn open_session(params: &mut Parameters) -> Result<()> {
+    let mut p0 = unsafe { params.0.as_memref()? };
+
+    let learning_rate = 
f64::from_le_bytes(p0.buffer().try_into().map_err(|err| {
+        trace_println!("bad parameter {:?}", err);
+        ErrorKind::BadParameters
+    })?);
+    trace_println!("Initialize with learning_rate: {}", learning_rate);
+
+    let mut trainer = TRAINER.lock();
+    trainer.replace(NoStdTrainer::new(DEVICE, learning_rate));
+
+    Ok(())
+}
+
+#[ta_close_session]
+fn close_session() {
+    trace_println!("[+] TA close session");
+}
+
+#[ta_destroy]
+fn destroy() {
+    trace_println!("[+] TA destroy");
+}
+
+fn copy_to_output(param: &mut Parameter, data: &[u8]) -> Result<()>{
+    let mut output = unsafe { param.as_memref()? };
+
+    let buffer = output.buffer();
+    if buffer.len() < data.len() {
+        return Err(ErrorKind::ShortBuffer.into());
+    }
+    buffer[..data.len()].copy_from_slice(data);
+    output.set_updated_size(data.len());
+    Ok(())
+}
+
+#[ta_invoke_command]
+fn invoke_command(cmd_id: u32, params: &mut Parameters) -> Result<()> {
+    match Command::try_from(cmd_id) {
+        Ok(Command::Train) => {
+            let mut p0 = unsafe { params.0.as_memref()? };
+            let mut p1 = unsafe { params.1.as_memref()? };
+
+            let images = p0.buffer();
+            let labels = p1.buffer();
+
+            let mut trainer = TRAINER.lock();
+            let result = 
trainer.as_mut().unwrap().train(bytemuck::cast_slice(images), labels);

Review Comment:
   Eliminate `unwrap()`; the same applies at lines 108 and 118.



##########
examples/mnist-rs/ta/inference/src/main.rs:
##########
@@ -0,0 +1,107 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#![no_std]
+#![no_main]
+extern crate alloc;
+
+use burn::{
+    backend::{ndarray::NdArrayDevice, NdArray},
+    tensor::cast::ToElement,
+};
+
+use model::Model;
+use optee_utee::{
+    ta_close_session, ta_create, ta_destroy, ta_invoke_command, 
ta_open_session, trace_println,
+};
+use optee_utee::{ErrorKind, Parameter, Parameters, Result};
+use proto::Image;
+use spin::Mutex;
+
+type NoStdModel = Model<NdArray>;
+const DEVICE: NdArrayDevice = NdArrayDevice::Cpu;
+static MODEL: Mutex<Option<NoStdModel>> = Mutex::new(Option::None);
+
+#[ta_create]
+fn create() -> Result<()> {
+    trace_println!("[+] TA create");
+    Ok(())
+}
+
+#[ta_open_session]
+fn open_session(params: &mut Parameters) -> Result<()> {
+    let mut p0 = unsafe { params.0.as_memref()? };
+
+    let mut model = MODEL.lock();
+    model.replace(Model::import(&DEVICE, p0.buffer().to_vec()).map_err(|err| {
+        trace_println!("import failed: {:?}", err);
+        ErrorKind::BadParameters
+    })?);
+
+    Ok(())
+}
+
+#[ta_close_session]
+fn close_session() {
+    trace_println!("[+] TA close session");
+}
+
+#[ta_destroy]
+fn destroy() {
+    trace_println!("[+] TA destroy");
+}
+
+fn copy_to_output(param: &mut Parameter, data: &[u8]) -> Result<()> {
+    let mut output = unsafe { param.as_memref()? };
+
+    let buffer = output.buffer();
+    if buffer.len() < data.len() {
+        trace_println!(
+            "expect output buffer size {}, got size {} instead",
+            data.len(),
+            buffer.len()
+        );
+        return Err(ErrorKind::ShortBuffer.into());
+    }
+    buffer[..data.len()].copy_from_slice(data);
+    output.set_updated_size(data.len());
+    Ok(())
+}

Review Comment:
   The train TA has a similar implementation that differs only slightly in how it
logs. Suggest keeping a single implementation and moving the shared code into a
common module:
   - `train/`
   - `inference/`
   - `common/`
       - model.rs: struct `Model`
       - utils.rs: `copy_to_output`



##########
examples/mnist-rs/host/src/tee.rs:
##########
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use optee_teec::{Context, ErrorKind, Operation, ParamNone, ParamTmpRef, 
Session, Uuid};
+use proto::{inference, train, Image};
+
+pub struct Trainer<'a> {
+    sess: Session<'a>,
+}

Review Comment:
   This appears to be a connector for invoking the training TA (and struct `Model`
below does the same for the inference TA). A more specific name would clarify
its role, e.g. `TrainingTaConnector` and `InferenceTaConnector`, or any better
names you prefer.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: dev-unsubscr...@teaclave.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscr...@teaclave.apache.org
For additional commands, e-mail: dev-h...@teaclave.apache.org

Reply via email to