Skip to content

Commit

Permalink
[onert-micro] Add SparseCrossEntropy Test
Browse files Browse the repository at this point in the history
This commit adds simple mnist test case for onert-micro that can test Sparse cross entropy train loss.

ONE-DCO-1.0-Signed-off-by: Jungwoo Lee <[email protected]>
  • Loading branch information
ljwoo94 authored and Jungwoo Lee committed Aug 9, 2024
1 parent 0fc4de4 commit c7f1938
Show file tree
Hide file tree
Showing 6 changed files with 1,226 additions and 2 deletions.
16 changes: 14 additions & 2 deletions onert-micro/onert-micro/include/train/tests/OMTestUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ OMStatus train(OMTrainingInterpreter &train_interpreter, OMConfig &config,
const uint32_t num_train_data_samples = test_base.getTrainNumSamples();
const uint32_t batch_size = config.training_context.batch_size;
const uint32_t input_size = train_interpreter.getInputSizeAt(0);
const uint32_t target_size = train_interpreter.getOutputSizeAt(0);
uint32_t target_size = train_interpreter.getOutputSizeAt(0);

// TODO: Need to revisit this to make getOuputSize can get proper output number
// when 'all' target number and output numbers are different
if (config.training_context.loss == SPARSE_CROSS_ENTROPY)
target_size = 1;

for (uint32_t e = 0; e < training_epochs; ++e)
{
config.training_context.num_epoch = e + 1;
Expand Down Expand Up @@ -92,7 +98,13 @@ OMStatus evaluate(OMTrainingInterpreter &train_interpreter, OMConfig &config,
const uint32_t num_test_data_samples = test_base.getTestNumSamples();
const uint32_t batch_size = 1;
const uint32_t input_size = train_interpreter.getInputSizeAt(0);
const uint32_t target_size = train_interpreter.getOutputSizeAt(0);
uint32_t target_size = train_interpreter.getOutputSizeAt(0);

// TODO: Need to revisit this to make getOuputSize can get proper output number
// when 'all' target number and output numbers are different
if (config.training_context.loss == SPARSE_CROSS_ENTROPY)
target_size = 1;

for (int i = 0; i < num_test_data_samples; ++i)
{
// Read current input and target data
Expand Down
Loading

0 comments on commit c7f1938

Please sign in to comment.