diff --git a/onert-micro/onert-micro/include/onert-micro-train.h b/onert-micro/onert-micro/include/onert-micro-train.h index bae99b74a0d..9322f48b8b4 100644 --- a/onert-micro/onert-micro/include/onert-micro-train.h +++ b/onert-micro/onert-micro/include/onert-micro-train.h @@ -44,7 +44,7 @@ om_create_context(&ctx); // 1. load model om_load_model_from_file(ctx, MODEL_PATH); -// 1-1. (optional) configure training options +// 1-1. (optional, TBD) configure training options om_train_compile(ctx); @@ -59,8 +59,8 @@ for(int epoch=0; epoch < NUM_EPOCHS; epoch++) { memcpy(training_output, train_output_data + THIS_BATCH_OFFSET, BATCH_SIZE*OUTPUT_SIZE); // 2. feed training input / output - om_train_set_input(ctx, 0 , training_input, input_tensorinfo); - om_train_set_expected(ctx, 0 , training_input, output_tensorinfo); + om_train_set_input(ctx, 0 , training_input, BATCH_SIZE*INPUT_SIZE); + om_train_set_expected(ctx, 0 , training_input,BATCH_SIZE*INPUT_SIZE); // 3. train a step om_train_single_step(ctx); } @@ -68,10 +68,12 @@ for(int epoch=0; epoch < NUM_EPOCHS; epoch++) { float loss; om_train_get_loss(ctx, 0, &loss); - if(loss < TARGET_LOSS) { + if(loss > TARGET_LOSS) { om_train_save_as_checkpoint(ctx, PATH); } - + else { + om_train_save_as_inferencemodel(ctx, PATH); + } } */ @@ -92,76 +94,6 @@ typedef enum OM_STATUS_ERROR = 1, } OM_STATUS; -/** - * @brief Tensor types - * - * The type of tensor represented in {@link om_tensorinfo} - */ -typedef enum -{ - /** A tensor of 32 bit floating point */ - OM_TYPE_TENSOR_FLOAT32 = 0, - /** A tensor of 32 bit signed integer */ - OM_TYPE_TENSOR_INT32 = 1, - /** - * A tensor of 8 bit unsigned integers that represent real numbers. - * - * real_value = (integer_value - zeroPoint) * scale. - */ - OM_TYPE_TENSOR_QUANT8_ASYMM = 2, - /** A tensor of boolean */ - OM_TYPE_TENSOR_BOOL = 3, - - /** A tensor of 8 bit unsigned integer */ - OM_TYPE_TENSOR_UINT8 = 4, - - /** A tensor of 64 bit signed integer */ - OM_TYPE_TENSOR_INT64 = 5, - - /** - * A tensor of 8 bit signed integers that represent real numbers. - * - * real_value = (integer_value - zeroPoint) * scale. - */ - OM_TYPE_TENSOR_QUANT8_ASYMM_SIGNED = 6, - - /** - * A tensor of 16 bit signed integers that represent real numbers. - * - * real_value = (integer_value - zeroPoint) * scale. - * - * Forced to have zeroPoint equal to 0. - */ - OM_TYPE_TENSOR_QUANT16_SYMM_SIGNED = 7, - -} OM_TYPE; - -/** - * @brief tensor info describes the type and shape of tensors - * - *
This structure is used to describe input and output tensors. - * Application can get input and output tensor type and shape described in model by using - * {@link om_input_tensorinfo} and {@link om_output_tensorinfo} - * - *
Maximum rank is 6 (OM_MAX_RANK). And tensor's dimension value is filled in 'dims' field from - * index 0. - * For example, if tensor's rank is 4, - * application can get dimension value from dims[0], dims[1], dims[2], and dims[3] - */ -typedef struct om_tensorinfo -{ - /** The data type */ - OM_TYPE dtype; - /** The number of dimensions (rank) */ - int32_t rank; - /** - * The dimension of tensor. - * Maximum rank is 6 (OM_MAX_RANK). - */ - int32_t dims[OM_MAX_RANK]; -} om_tensorinfo; - - struct om_context; /** @@ -211,12 +143,11 @@ OM_STATUS om_train_compile(om_context *ctx); * @param[in] context The context to be set training inputs and expected model outputs * @param[in] index The index of training input * @param[in] input The input buffers for training - * @param[in] input_info The shape and type of input buffer + * @param[in] size The byte size of input buffer * If it is nullptr, it will not change shape and batch size * @return @c OM_STATUS_NO_ERROR if successful */ -OM_STATUS om_train_set_input(om_context *ctx, uint32_t index, const void *input, - const om_tensorinfo *input_info); +OM_STATUS om_train_set_input(om_context *ctx, uint32_t index, const void *input, int size); /** * @brief Set training expected output @@ -225,12 +156,11 @@ OM_STATUS om_train_set_input(om_context *ctx, uint32_t index, const void *input, * @param context The context to be set training inputs and expected model outputs * @param index The index of training expected output * @param expected The expected buffers for training - * @param expected_info The shape and type of expected buffer - * If it is nullptr, it will not change shape and batch size + * @param[in] size The byte size of input buffer + * * @return @c OM_STATUS_NO_ERROR if successful */ -OM_STATUS om_train_set_expected(om_context *ctx, uint32_t index, const void *expected, - const om_tensorinfo *expected_info); +OM_STATUS om_train_set_expected(om_context *ctx, uint32_t index, const void *expected, int size); /** * @brief Perform a inference with current state