Skip to content

Commit

Permalink
remove tensorinfo and tensor type
Browse files Browse the repository at this point in the history
  • Loading branch information
chunseoklee committed May 16, 2024
1 parent 332cabb commit d09603b
Showing 1 changed file with 12 additions and 82 deletions.
94 changes: 12 additions & 82 deletions onert-micro/onert-micro/include/onert-micro-train.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ om_create_context(&ctx);
// 1. load model
om_load_model_from_file(ctx, MODEL_PATH);
// 1-1. (optional) configure training options
// 1-1. (optional, TBD) configure training options
om_train_compile(ctx);
Expand All @@ -59,19 +59,21 @@ for(int epoch=0; epoch < NUM_EPOCHS; epoch++) {
memcpy(training_output, train_output_data + THIS_BATCH_OFFSET, BATCH_SIZE*OUTPUT_SIZE);
// 2. feed training input / output
om_train_set_input(ctx, 0 , training_input, input_tensorinfo);
om_train_set_expected(ctx, 0 , training_input, output_tensorinfo);
om_train_set_input(ctx, 0 , training_input, BATCH_SIZE*INPUT_SIZE);
om_train_set_expected(ctx, 0 , training_input,BATCH_SIZE*INPUT_SIZE);
// 3. train a step
om_train_single_step(ctx);
}
// 4. check loss
float loss;
om_train_get_loss(ctx, 0, &loss);
if(loss < TARGET_LOSS) {
if(loss > TARGET_LOSS) {
om_train_save_as_checkpoint(ctx, PATH);
}
else {
om_train_save_as_inferencemodel(ctx, PATH);
}
}
*/

Expand All @@ -92,76 +94,6 @@ typedef enum
OM_STATUS_ERROR = 1,
} OM_STATUS;

/**
* @brief Tensor types
*
* The type of tensor represented in {@link om_tensorinfo}
*/
typedef enum
{
/** A tensor of 32 bit floating point */
OM_TYPE_TENSOR_FLOAT32 = 0,
/** A tensor of 32 bit signed integer */
OM_TYPE_TENSOR_INT32 = 1,
/**
* A tensor of 8 bit unsigned integers that represent real numbers.
*
* real_value = (integer_value - zeroPoint) * scale.
*/
OM_TYPE_TENSOR_QUANT8_ASYMM = 2,
/** A tensor of boolean */
OM_TYPE_TENSOR_BOOL = 3,

/** A tensor of 8 bit unsigned integer */
OM_TYPE_TENSOR_UINT8 = 4,

/** A tensor of 64 bit signed integer */
OM_TYPE_TENSOR_INT64 = 5,

/**
* A tensor of 8 bit signed integers that represent real numbers.
*
* real_value = (integer_value - zeroPoint) * scale.
*/
OM_TYPE_TENSOR_QUANT8_ASYMM_SIGNED = 6,

/**
* A tensor of 16 bit signed integers that represent real numbers.
*
* real_value = (integer_value - zeroPoint) * scale.
*
* Forced to have zeroPoint equal to 0.
*/
OM_TYPE_TENSOR_QUANT16_SYMM_SIGNED = 7,

} OM_TYPE;

/**
* @brief tensor info describes the type and shape of tensors
*
* <p>This structure is used to describe input and output tensors.
* Application can get input and output tensor type and shape described in model by using
* {@link om_input_tensorinfo} and {@link om_output_tensorinfo}
*
* <p>Maximum rank is 6 (OM_MAX_RANK). And tensor's dimension value is filled in 'dims' field from
* index 0.
* For example, if tensor's rank is 4,
* application can get dimension value from dims[0], dims[1], dims[2], and dims[3]
*/
typedef struct om_tensorinfo
{
/** The data type */
OM_TYPE dtype;
/** The number of dimensions (rank) */
int32_t rank;
/**
* The dimension of tensor.
* Maximum rank is 6 (OM_MAX_RANK).
*/
int32_t dims[OM_MAX_RANK];
} om_tensorinfo;


struct om_context;

/**
Expand Down Expand Up @@ -211,12 +143,11 @@ OM_STATUS om_train_compile(om_context *ctx);
* @param[in] context The context to be set training inputs and expected model outputs
* @param[in] index The index of training input
* @param[in] input The input buffers for training
* @param[in] input_info The shape and type of input buffer
* @param[in] size The byte size of input buffer
* If it is nullptr, it will not change shape and batch size
* @return @c OM_STATUS_NO_ERROR if successful
*/
OM_STATUS om_train_set_input(om_context *ctx, uint32_t index, const void *input,
const om_tensorinfo *input_info);
OM_STATUS om_train_set_input(om_context *ctx, uint32_t index, const void *input, int size);

/**
* @brief Set training expected output
Expand All @@ -225,12 +156,11 @@ OM_STATUS om_train_set_input(om_context *ctx, uint32_t index, const void *input,
* @param context The context to be set training inputs and expected model outputs
* @param index The index of training expected output
* @param expected The expected buffers for training
* @param expected_info The shape and type of expected buffer
* If it is nullptr, it will not change shape and batch size
* @param[in] size The byte size of input buffer
*
* @return @c OM_STATUS_NO_ERROR if successful
*/
OM_STATUS om_train_set_expected(om_context *ctx, uint32_t index, const void *expected,
const om_tensorinfo *expected_info);
OM_STATUS om_train_set_expected(om_context *ctx, uint32_t index, const void *expected, int size);

/**
* @brief Perform a inference with current state
Expand Down

0 comments on commit d09603b

Please sign in to comment.