Github user njayaram2 commented on a diff in the pull request: https://github.com/apache/madlib/pull/229#discussion_r163048733 --- Diff: src/modules/convex/algo/igd.hpp --- @@ -56,6 +59,62 @@ IGD<State, ConstState, Task>::transition(state_type &state, state.task.stepsize * tuple.weight); } +/** + * @brief Update the transition state in mini-batches + * + * Note: We assume that + * 1. Task defines a model_eigen_type + * 2. A batch of tuple.indVar is a Matrix + * 3. A batch of tuple.depVar is a ColumnVector + * 4. Task defines a getLossAndUpdateModel method + * + */ + template <class State, class ConstState, class Task> + void + IGD<State, ConstState, Task>::transitionInMiniBatch( + state_type &state, + const tuple_type &tuple) { + + madlib_assert(tuple.indVar.rows() == tuple.depVar.rows(), + std::runtime_error("Invalid data. Independent and dependent " + "batches don't have same number of rows.")); + + int batch_size = state.algo.batchSize; + int n_epochs = state.algo.nEpochs; + + // n_rows/n_ind_cols are the rows/cols in a transition tuple. + int n_rows = tuple.indVar.rows(); + int n_ind_cols = tuple.indVar.cols(); + int n_batches = n_rows < batch_size ? 1 : + n_rows / batch_size + + int(n_rows%batch_size > 0); + + for (int curr_epoch=0; curr_epoch < n_epochs; curr_epoch++) { + double loss = 0.0; + for (int curr_batch=0, curr_batch_row_index=0; curr_batch < n_batches; + curr_batch++, curr_batch_row_index += batch_size) { + Matrix X_batch; + ColumnVector y_batch; + if (curr_batch == n_batches-1) { + // last batch + X_batch = tuple.indVar.bottomRows(n_rows-curr_batch_row_index); + y_batch = tuple.depVar.tail(n_rows-curr_batch_row_index); + } else { + X_batch = tuple.indVar.block(curr_batch_row_index, 0, batch_size, n_ind_cols); + y_batch = tuple.depVar.segment(curr_batch_row_index, batch_size); + } + loss += Task::getLossAndUpdateModel( + state.task.model, X_batch, y_batch, state.task.stepsize); + } + + // The first epoch will most likely have the most loss. 
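+ // So being pessimistic, we return average loss only for the first epoch. + if (curr_epoch==0) state.algo.loss += loss; --- End diff -- Should we average `loss` over `n_batches` here? The comment says we return the average loss, but `loss` is the sum of the per-batch losses for the epoch, so `state.algo.loss` currently accumulates the un-normalized sum.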
---