diff --git a/onert-micro/onert-micro/include/onert-micro.h b/onert-micro/onert-micro/include/onert-micro.h index 9322f48b8b4..dedd6c61bc9 100644 --- a/onert-micro/onert-micro/include/onert-micro.h +++ b/onert-micro/onert-micro/include/onert-micro.h @@ -15,8 +15,8 @@ */ -#ifndef ONERT-MICRO-TRAIN_H_ -#define ONERT-MICRO-TRAIN_H_ +#ifndef _ONERT_MICRO_H_ +#define _ONERT_MICRO_H_ #ifdef __cplusplus extern "C" { @@ -36,49 +36,50 @@ extern "C" { * 4-2. more training -> go to 2 * 5. save current state to inference model * 6. inference with inference model - // sample example // 0. create context -om_context *ctx; -om_create_context(&ctx); - -// 1. load model -om_load_model_from_file(ctx, MODEL_PATH); +nnfw_session *session; +nnfw_create_session(&session); +// 1. load model (and checkpoint if continue training) +nnfw_load_model_from_file(session, MODEL_PATH); // 1-1. (optional, TBD) configure training options - -om_train_compile(ctx); - +nnfw_train_import_checkpoint(session, CKPT_PATH); +nnfw_train_prepare(session); float training_input[BATCH_SIZE*INPUT_SIZE]; float training_label[BATCH_SIZE*OUTPUT_SIZE]; - // main training loop for(int epoch=0; epoch < NUM_EPOCHS; epoch++) { for(int step=0; step < NUM_BATCHES ; step++) { // prepare this steps's intput/label memcpy(training_input, train_input_data + THIS_BATCH_OFFSET, BATCH_SIZE*INPUT_SIZE); memcpy(training_output, train_output_data + THIS_BATCH_OFFSET, BATCH_SIZE*OUTPUT_SIZE); - - // 2. feed training input / output - om_train_set_input(ctx, 0 , training_input, BATCH_SIZE*INPUT_SIZE); - om_train_set_expected(ctx, 0 , training_input,BATCH_SIZE*INPUT_SIZE); + // 2. feed training input / expected output + nnfw_train_set_input(session, 0 , training_input, NULL); + nnfw_train_set_expected(session, 0 , training_label, NULL); // 3. train a step - om_train_single_step(ctx); + nnfw_train(session); } // 4. 
check loss float loss; - om_train_get_loss(ctx, 0, &loss); - + nnfw_train_get_loss(session, 0, &loss); if(loss > TARGET_LOSS) { - om_train_save_as_checkpoint(ctx, PATH); + nnfw_train_export_checkpoint(session, CKPT_PATH); } else { - om_train_save_as_inferencemodel(ctx, PATH); + nnfw_train_export_circle(session, CIRCLE_PATH); } } */ +typedef struct nnfw_session nnfw_session; - +typedef enum +{ + /** A tensor of 32 bit floating point */ + NNFW_TYPE_TENSOR_FLOAT32 = 0, + /** A tensor of 32 bit signed integer */ + NNFW_TYPE_TENSOR_INT32 = 1, +} NNFW_TYPE; /** * @brief Result values returned from a call to an API function @@ -86,150 +87,248 @@ typedef enum { /** Successful */ - OM_STATUS_NO_ERROR = 0, + NNFW_STATUS_NO_ERROR = 0, /** * An error code for general use. * Mostly used when there is no specific value for that certain situation. */ - OM_STATUS_ERROR = 1, -} OM_STATUS; - -struct om_context; + NNFW_STATUS_ERROR = 1, + /** Unexpected null argument is given. */ + NNFW_STATUS_UNEXPECTED_NULL = 2, + /** When a function was called but it is not valid for the current session state. */ + NNFW_STATUS_INVALID_STATE = 3, + /** When it is out of memory */ + NNFW_STATUS_OUT_OF_MEMORY = 4, + /** When it was given an insufficient output buffer */ + NNFW_STATUS_INSUFFICIENT_OUTPUT_SIZE = 5, + /** When API is deprecated */ + NNFW_STATUS_DEPRECATED_API = 6, +} NNFW_STATUS; /** - * @brief Create a new context. 
- * - * @param[out] cxt The context to be created - * @return OM_STATUS_NO_ERROR if successful + * @brief Data format of a tensor */ -OM_STATUS om_create_context(om_context **cxt); +typedef enum +{ + /** Don't care layout */ + NNFW_LAYOUT_NONE = 0, + /** + * Channel last layout + * If rank is 4, layout is NHWC + */ + NNFW_LAYOUT_CHANNELS_LAST = 1, + /** + * Channel first layout + * If rank is 4, layout is NCHW + */ + NNFW_LAYOUT_CHANNELS_FIRST = 2, +} NNFW_LAYOUT; + /** - * @brief Load model from model dir - * - * @param[in] ctx context loading the given model dir - * @param[in] path Path to the model or to the dir which contains training artifacts - * - * @return @c OM_STATUS_NO_ERROR if successful + * @brief Maximum rank expressible with nnfw */ -OM_STATUS om_load_model_from_file(om_context *ctx, const char *path); +#define NNFW_MAX_RANK (6) /** - * @brief Load checkpoint from checkpoint dir + * @brief tensor info describes the type and shape of tensors * - * @param[in] ctx context loading the given checkpoint - * @param[in] path Path to the checkpoint dir which contains all the checkpoint files + *

This structure is used to describe input and output tensors. + * Application can get input and output tensor type and shape described in model by using + * {@link nnfw_input_tensorinfo} and {@link nnfw_output_tensorinfo} * - * @return @c OM_STATUS_NO_ERROR if successful + *

Maximum rank is 6 (NNFW_MAX_RANK). And tensor's dimension value is filled in 'dims' field from + * index 0. + * For example, if tensor's rank is 4, + * application can get dimension value from dims[0], dims[1], dims[2], and dims[3] */ -OM_STATUS om_load_checkpoint_from_file(om_context *ctx, const char *path); +typedef struct nnfw_tensorinfo +{ + /** The data type */ + NNFW_TYPE dtype; + /** The number of dimensions (rank) */ + int32_t rank; + /** + * The dimension of tensor. + * Maximum rank is 6 (NNFW_MAX_RANK). + */ + int32_t dims[NNFW_MAX_RANK]; +} nnfw_tensorinfo; +////////////////////////////////////////////// +// Essential APIs for training +////////////////////////////////////////////// +typedef enum +{ + NNFW_TRAIN_LOSS_UNDEFINED = 0, + NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR = 1, + NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY = 2, +} NNFW_TRAIN_LOSS; + +typedef enum +{ + /** Undefined */ + NNFW_TRAIN_LOSS_REDUCTION_UNDEFINED = 0, + /** Scalar sum divided by number of elements in losses */ + NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE = 1, + /** Scalar sum of weighted losses */ + NNFW_TRAIN_LOSS_REDUCTION_SUM = 2, +} NNFW_TRAIN_LOSS_REDUCTION; + +typedef enum +{ + NNFW_TRAIN_OPTIMIZER_UNDEFINED = 0, + NNFW_TRAIN_OPTIMIZER_SGD = 1, + NNFW_TRAIN_OPTIMIZER_ADAM = 2, +} NNFW_TRAIN_OPTIMIZER; + +typedef struct nnfw_loss_info +{ + NNFW_TRAIN_LOSS loss; + NNFW_TRAIN_LOSS_REDUCTION reduction_type; +} nnfw_loss_info; + /** - * @brief Prepare context to be ready for training - * @note The context will be entered into training mode - * - * @param[in] ctx The context to be prepared for training - * - * @return @c OM_STATUS_NO_ERROR if successful + * @brief Training information to prepare training + * @todo Add more training information + * (e.g. optimizer, loss function, ...) 
*/ -OM_STATUS om_train_compile(om_context *ctx); +typedef struct nnfw_train_info +{ + /** Learning rate */ + float learning_rate = 0.001f; + /** Batch size */ + uint32_t batch_size = 1; + /** loss info */ + nnfw_loss_info loss_info{.loss = NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR, + .reduction_type = NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE}; + /** optimizer type */ + NNFW_TRAIN_OPTIMIZER opt = NNFW_TRAIN_OPTIMIZER_SGD; +} nnfw_train_info; /** - * @brief Set training input - * @note This function should be called after {@link om_train_compile} + * @brief Create a new session instance. * - * @param[in] context The context to be set training inputs and expected model outputs - * @param[in] index The index of training input - * @param[in] input The input buffers for training - * @param[in] size The byte size of input buffer - * If it is nullptr, it will not change shape and batch size - * @return @c OM_STATUS_NO_ERROR if successful + *

This only creates a session. + * Model is loaded after {@link nnfw_load_model_from_file} is invoked. + * And inference is performed after {@link nnfw_run} is invoked. + * + *

{@link nnfw_close_session} should be called once + * if session is no longer needed + * + * @param[out] session The session to be created + * @return NNFW_STATUS_NO_ERROR if successful */ -OM_STATUS om_train_set_input(om_context *ctx, uint32_t index, const void *input, int size); +NNFW_STATUS nnfw_create_session(nnfw_session **session); /** - * @brief Set training expected output - * @note This function should be called after {@link om_train_compile} + * @brief Close a session instance * - * @param context The context to be set training inputs and expected model outputs - * @param index The index of training expected output - * @param expected The expected buffers for training - * @param[in] size The byte size of input buffer + * After called, access to closed session by application will be invalid * - * @return @c OM_STATUS_NO_ERROR if successful + * @param[in] session The session to be closed + * @return @c NNFW_STATUS_NO_ERROR if successful */ -OM_STATUS om_train_set_expected(om_context *ctx, uint32_t index, const void *expected, int size); +NNFW_STATUS nnfw_close_session(nnfw_session *session); /** - * @brief Perform a inference with current state - * @note This function should be called after {@link om_train_set_input} + * @brief Load model from nnpackage file or directory * - * @param[in] ctx The context to be inferenced + * The length of \p package_file_path must not exceed 1024 bytes including zero at the end. 
* - * @return @c OM_STATUS_NO_ERROR if successful + * @param[in] session nnfw_session loading the given nnpackage file/dir + * @param[in] package_file_path Path to the nnpackage file or unzipped directory to be loaded + * + * @return @c NNFW_STATUS_NO_ERROR if successful */ -OM_STATUS om_train_inference(om_context *context); - +NNFW_STATUS nnfw_load_model_from_file(nnfw_session *session, const char *package_file_path); /** - * @brief Get output + * @brief Prepare session to be ready for training + * @note The session will be entered into training mode * - * @param[in] context Context from inference output is to be extracted - * @param[in] index Index of output to be set (0-indexed) - * @param[out] buffer output buffera pointer + * If training info is NOT set in session, this function returns @c NNFW_STATUS_ERROR . + * You should set training info using {@link nnfw_train_set_traininfo}. * - * @return @c OM_STATUS_NO_ERROR if successful + * @param[in] session The session to be prepared for training + * + * @return @c NNFW_STATUS_NO_ERROR if successful */ -OM_STATUS om_train_get_output(om_context *ctx, uint32_t index, void *buffer); +NNFW_STATUS nnfw_train_prepare(nnfw_session *session); /** - * @brief Perform an training step - * @note This function should be called after {@link om_train_set_input} and - * {@link om_train_set_expected} for each input and expected output + * @brief Train the model + * @note This function should be called after {@link nnfw_train_set_input} and + * {@link nnfw_train_set_expected} for each input and expected output * - * @param[in] ctx The context to be trained + * In order to use \p update_weights as false, it should be called after + * {@link nnfw_train_set_output}. 
* - * @return @c OM_STATUS_NO_ERROR if successful + * @param[in] session The session to be trained + * @param[in] update_weights If true, update weights of the model + * If false, do not update weights of the model (for validation) + * @return @c NNFW_STATUS_NO_ERROR if successful */ -OM_STATUS om_train_single_step(om_context *context); +NNFW_STATUS nnfw_train(nnfw_session *session, bool update_weights); /** - * @brief Get loss value for expected output - * @note This function should be called after {@link om_train} + * @brief Export current training model into circle model + * @note This function should be called on training mode + * This function should be called after {@link nnfw_train} * - * @param[in] ctx The context to get loss value - * @param[in] index The output index for loss value - * @param[out] loss The loss value - * - * @return @c OM_STATUS_NO_ERROR if successful + * @param[in] session The session to export inference model + * @param[in] path The path to export inference model + * @return @c NNFW_STATUS_NO_ERROR if successful */ -OM_STATUS om_train_get_loss(om_context *context, uint32_t index, float *loss); +NNFW_STATUS nnfw_train_export_circle(nnfw_session *session, const char *path); + +NNFW_STATUS nnfw_train_export_checkpoint(nnfw_session *session, const char *path); +NNFW_STATUS nnfw_train_import_checkpoint(nnfw_session *session, const char *path); /** - * @brief Save current state as checkpoint - * - * @param[in] ctx context for saving the given checkpoint - * @param[in] path Path to the checkpoint dir + * @brief Set training input + * @note This function should be called after {@link nnfw_train_prepare} * - * @return @c OM_STATUS_NO_ERROR if successful + * @param[in] session The session to be set training inputs and expected model outputs + * @param[in] index The index of training input + * @param[in] input The input buffers for training + * @param[in] input_info The shape and type of input buffer + * If it is nullptr, it will not change shape 
and batch size + * @return @c NNFW_STATUS_NO_ERROR if successful */ -OM_STATUS om_save_as_checkpoint(om_context *ctx, const char *path); +NNFW_STATUS nnfw_train_set_input(nnfw_session *session, uint32_t index, void *input, + const nnfw_tensorinfo *input_info); /** - * @brief Save current(ctx) state into inference circle model + * @brief Set training expected output + * @note This function should be called after {@link nnfw_train_prepare} * - * @param[in] ctx context for saving - * @param[in] path Path to the model(circle) + * @param session The session to be set training inputs and expected model outputs + * @param index The index of training expected output + * @param expected The expected buffers for training + * @param expected_info The shape and type of expected buffer + * If it is nullptr, it will not change shape and batch size + * @return @c NNFW_STATUS_NO_ERROR if successful + */ +NNFW_STATUS nnfw_train_set_expected(nnfw_session *session, uint32_t index, void *expected, + const nnfw_tensorinfo *expected_info); + +/** + * @brief Get loss value for expected output + * @note This function should be called after {@link nnfw_train} * - * @return @c OM_STATUS_NO_ERROR if successful + * @param[in] session The session to get loss value + * @param[in] index The index of loss value [0, number of expected outputs) + * @param[out] loss The loss value + * @return @c NNFW_STATUS_NO_ERROR if successful */ -OM_STATUS om_save_as_inferencemodel(om_context *ctx, const char *path); +NNFW_STATUS nnfw_train_get_loss(nnfw_session *session, uint32_t index, float *loss); #ifdef __cplusplus } #endif -#endif // ONERT-MICRO-TRAIN_H_ +#endif //_ONERT_MICRO_H_