Skip to content

Commit

Permalink
fixed checkpoints saving and loading
Browse files Browse the repository at this point in the history
  • Loading branch information
josura committed Jun 21, 2024
1 parent 688b4b6 commit e76ee5b
Showing 1 changed file with 37 additions and 12 deletions.
49 changes: 37 additions & 12 deletions src/Checkpoint.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ void Checkpoint::saveState(const std::string type, const int interIteration, con
std::string fileName = this->checkPointFolder + "checkpoint_" + type + "_" + std::to_string(interIteration) + "_" + std::to_string(intraIteration) + ".tsv";
std::ofstream file(fileName);
std::vector<std::string> nodeNames = currentComputation->getAugmentedGraph()->getNodeNames();
std::vector<double> nodeValues = currentComputation->getOutputAugmented();
std::vector<double> nodeValues = currentComputation->getInputAugmented();
if (file.is_open())
{
//header
file << "interIteration\tintraIteration\tnodeName\tnodeValue\n";
// file << "interIteration\tintraIteration\tnodeName\tnodeValue\n";
file << "nodeName\tnodeValue\n";
//body
for(uint i = 0; i < nodeValues.size(); i++){
file<<interIteration<< "\t"<< intraIteration<<"\t"<< nodeNames[i]<<"\t"<<nodeNames[i]<<"\t"<<std::to_string(nodeValues[i]);
// file<<interIteration<< "\t"<< intraIteration<<"\t"<< nodeNames[i]<<"\t"<<std::to_string(nodeValues[i]);
file << nodeNames[i]<<"\t"<<std::to_string(nodeValues[i]);
file << std::endl;
}
file.close();
Expand All @@ -47,17 +49,38 @@ void Checkpoint::cleanCheckpoints(std::string type) {
{
if (file.find("checkpoint_" + type) != std::string::npos)
{
std::string fileName = folder + file;
if (remove(fileName.c_str()) != 0)
// std::string fileName = folder + file;
if (remove(file.c_str()))
{
std::cerr << "[ERROR] Checkpoint::cleanCheckpoints: Unable to delete file " << fileName << std::endl;
std::cerr << "[ERROR] Checkpoint::cleanCheckpoints: Unable to delete file " << file << std::endl;
}
}
}
}

void Checkpoint::loadState(const std::string type, int& interIteration, int& intraIteration, Computation* computation) {
std::string fileName = this->checkPointFolder + "checkpoint_" + type + "_" + std::to_string(interIteration) + "_" + std::to_string(intraIteration) + ".tsv";
std::string fileName = this->checkPointFolder + "checkpoint_" + type + "_";
std::vector<std::string> files = listFiles(this->checkPointFolder);
bool checkPointExists = false;
for (std::string file : files)
{
if (file.find(fileName) != std::string::npos)
{
interIteration = stoi(splitStringIntoVector(file, "_")[2]);
intraIteration = stoi(splitStringIntoVector(file, "_")[3]);
checkPointExists = true;
fileName = file;
break;
}
}

if (!checkPointExists)
{
std::cerr << "[ERROR] Checkpoint::loadState: Checkpoint file not found" << std::endl;
throw std::runtime_error("[ERROR] Checkpoint::loadState: Checkpoint file not found");
}


std::ifstream file(fileName);
if (file.is_open())
{
Expand All @@ -66,20 +89,22 @@ void Checkpoint::loadState(const std::string type, int& interIteration, int& int
while (std::getline(file, line))
{
std::istringstream iss(line);
std::string interIterationStr;
std::string intraIterationStr;
// std::string interIterationStr;
// std::string intraIterationStr;
std::string nodeName;
std::string nodeValueStr;
iss >> interIterationStr >> intraIterationStr >> nodeName >> nodeValueStr;
// iss >> interIterationStr >> intraIterationStr >> nodeName >> nodeValueStr;
iss >> nodeName >> nodeValueStr;
double nodeValue = std::stod(nodeValueStr);
interIteration = std::stoi(interIterationStr);
intraIteration = std::stoi(intraIterationStr);
// interIteration = std::stoi(interIterationStr);
// intraIteration = std::stoi(intraIterationStr);
computation->setInputNodeValue(nodeName, nodeValue);
}
file.close();
}
else
{
std::cerr << "[ERROR] Checkpoint::loadState: Unable to open file " << fileName << std::endl;
throw std::runtime_error("[ERROR] Checkpoint::loadState: Unable to open file " + fileName);
}
}

0 comments on commit e76ee5b

Please sign in to comment.