// dt-learn.cpp
#include <algorithm>
#include <cerrno>
#include <chrono>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include "DecisionTree.hpp"
int main(int argc, const char * argv[]) {
if (argc < 4) {
cout << "usage: ./dt-learn train-set-file test-set-file m [percentage-of-train-set]" << endl;
} else {
string trainSetFile = argv[1];
string testSetFile = argv[2];
int stopThreshold = atoi(argv[3]);
int percentageOfTrainSet = argc == 5 ? atoi(argv[4]) : 100;
shared_ptr<Dataset> dataset(Dataset::loadDataset(trainSetFile, testSetFile));
const DatasetMetadata* metadata = dataset->getMetadata();
vector<Instance*> trainSet(dataset->getTrainSet().begin(), dataset->getTrainSet().end());
if (percentageOfTrainSet < 100) {
unsigned int seed = (unsigned int)chrono::system_clock::now().time_since_epoch().count();
shuffle (trainSet.begin(), trainSet.end(), default_random_engine(seed));
int newSize = (int)trainSet.size() * percentageOfTrainSet / 100;
trainSet.resize(newSize);
}
DecisionTree tree(metadata, trainSet, stopThreshold);
cout << tree.toString();
const vector<Instance*>& testSet = dataset->getTestSet();
int correctCount = 0;
cout << "<Predictions for the Test Set Instances>" << endl;
for (int i = 0; i < testSet.size(); ++i) {
Instance* inst = testSet[i];
string predicted = tree.predict(inst);
string actual = inst->toString(metadata, true);
if (predicted == actual)
correctCount++;
cout << setfill(' ') << setw(3) << (i + 1) << ": ";
cout << "Actual: " << actual << " Predicted: " << predicted<< endl;
}
cout << "Number of correctly classified: " << correctCount << " Total number of test instances: " << testSet.size() << endl;
}
}