Skip to content

Commit

Permalink
[onert/nnfw_api] Update DataSet name of TrainCaseData (#13269)
Browse files Browse the repository at this point in the history
This commit changes dataset name from `outputs` to `expects`.
And it removes unused `expects` variable from TrainCaseData.

ONE-DCO-1.0-Signed-off-by: Jiyoung Yun <[email protected]>
  • Loading branch information
jyoungyun authored Jun 24, 2024
1 parent dadb8ba commit a4e5883
Showing 1 changed file with 9 additions and 26 deletions.
35 changes: 9 additions & 26 deletions tests/nnfw_api/lib/GenModelTrain.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,11 @@ struct TrainCaseData
using OneStepData =
std::vector<std::vector<uint8_t>>; // An input/output's data list to be used in one step
using DataSet = std::tuple<std::vector<OneStepData>,
std::vector<OneStepData>>; // inputs dataset, outputs dataset
std::vector<OneStepData>>; // inputs dataset, expects dataset

public:
/**
* @brief A vector of expects buffers
*/
std::vector<std::vector<uint8_t>> expects;

/**
* @brief A dataset of inputs/outputs
* @brief A dataset of inputs/expects
*/
DataSet dataset;

Expand All @@ -69,29 +64,17 @@ struct TrainCaseData
}

/**
* @brief Append vector data list of outputs that are used in one step
* @brief Append vector data list of expects that are used in one step
*
* @tparam T Data type of outputs
* @param outputs Outputs that are used in one step
* @tparam T Data type of expects
* @param expects Expects that are used in one step
*/
template <typename T> TrainCaseData &addOutputs(const std::vector<std::vector<T>> &outputs)
template <typename T> TrainCaseData &addExpects(const std::vector<std::vector<T>> &expects)
{
auto &[unused, outputs_dataset] = dataset;
auto &[unused, expects_dataset] = dataset;
(void)unused;
addData(outputs_dataset.emplace_back(), outputs);

return *this;
}
addData(expects_dataset.emplace_back(), expects);

/**
* @brief Append vector data to expects
*
* @tparam T Data type
* @param data vector data array
*/
template <typename T> TrainCaseData &addExpects(const std::vector<T> &data)
{
addData(expects, data);
return *this;
}

Expand Down Expand Up @@ -186,7 +169,7 @@ static TrainCaseData uniformTCD(const std::vector<std::vector<std::vector<T>>> &
for (const auto &data : expects_dataset)
{
assert(data.size() == losses.size());
ret.addOutputs<T>(data);
ret.addExpects<T>(data);
}
ret.setLosses(losses);
return ret;
Expand Down

0 comments on commit a4e5883

Please sign in to comment.