-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcreate_test_data.py
67 lines (51 loc) · 2.97 KB
/
create_test_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import logging
import sys
import yaml
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
from technews_nlp_aggregator.application import Application
from technews_nlp_aggregator.model import FeatureFiller
from technews_nlp_aggregator.common import load_config
import argparse
from datetime import timedelta
import sys
def create_test_data( starting_date, feature_filler, similarArticlesRepo, version, max_skipped=100):
    """Insert freshly computed similarity scores for recent article pairs.

    Retrieves similar-article pairs recorded since ``starting_date`` for the
    given model ``version``, drops any pairs already present in the stored
    test set, and delegates scoring/insertion of the remainder to
    :func:`retrieves_test`.

    :param starting_date: lower-bound date (string) for pair retrieval.
    :param feature_filler: object that computes a score map per article pair.
    :param similarArticlesRepo: repository used to load pairs and persist scores.
    :param version: model version tag used for both retrieval and test-set lookup.
    :param max_skipped: stop threshold forwarded to ``retrieves_test``.
    """
    logging.info("Retrieving data starting from {}".format(starting_date))
    similarDF = similarArticlesRepo.retrieve_similar_since( starting_date, version)
    logging.info("Retrieved {} ".format(len(similarDF )))
    # Bug fix: use the `version` parameter instead of the module-level
    # `config` dict — the old code raised NameError when this function was
    # imported and called outside the __main__ script path.
    testDF = similarArticlesRepo.load_test_set(version=version)
    logging.info("Test_df has {} rows".format(len(testDF)))
    # Keep only the pairs that are NOT already part of the stored test set.
    mergedDF = similarDF[~similarDF.index.isin(testDF.index)]
    logging.info("merged_df has {} rows".format(len(mergedDF )))
    retrieves_test( mergedDF, feature_filler, similarArticlesRepo, max_skipped=max_skipped)
def retrieves_test( merged_DF, feature_filler, similarArticlesRepo, max_skipped=100 ):
    """Score each article pair in *merged_DF* and persist it via the repository.

    Iterates the DataFrame (indexed by an ``(article_id1, article_id2)`` pair),
    computes a score map per pair, and inserts it. Counts how many inserts the
    repository accepted (``added``) vs. rejected (``skipped``) and stops early
    once ``skipped > max_skipped`` or ``added > max_skipped * 3``.

    :param merged_DF: pandas DataFrame whose index yields the two article ids.
    :param feature_filler: provides ``fill_score_map(id1, id2)``.
    :param similarArticlesRepo: provides ``get_connection()`` and
        ``insert_score(score, con)`` (truthy return means the row was added).
    :param max_skipped: early-exit threshold for skipped/added rows.
    """
    con = similarArticlesRepo.get_connection()
    skipped, added = 0, 0
    for index, row in merged_DF.iterrows():
        # The DataFrame index carries the (article_id1, article_id2) pair.
        article_id1, article_id2 = index[0], index[1]
        score = feature_filler.fill_score_map( article_id1, article_id2)
        scores_found = similarArticlesRepo.insert_score(score, con)
        if not scores_found:
            skipped += 1
        else:
            added += 1
        # Progress heartbeat every 50 processed rows.
        if ((added+skipped) % 50 == 0):
            logging.info("Added {}, skipped {} rows".format(added, skipped))
        if skipped > max_skipped or added > max_skipped * 3:
            # Bug fix: the original format string had a single placeholder for
            # two arguments, silently dropping `added` from the log message.
            logging.info("Skipped {}, Added {} rows, leaving".format(skipped, added))
            break
    logging.info("Added {}, skipped {} rows".format(added, skipped))
if __name__ == '__main__':
    # Config path may be supplied as the first CLI argument; default config.yml.
    config_file = sys.argv[1] if (len(sys.argv) > 1) else 'config.yml'
    # Fix: close the config file deterministically instead of leaking the
    # handle (the original passed a bare open() into yaml.safe_load).
    with open(config_file) as f:
        config = yaml.safe_load(f)
    version = config["version"]
    max_skipped = config.get("max_skipped", 100)
    application = Application(config, True)
    # Look back `score_go_back` days from the newest article on record.
    latest_article_date = application.articleDatasetRepo.get_latest_article_date()
    sincewhen_date = latest_article_date - timedelta(config["score_go_back"])
    # NOTE(review): intentionally unpadded 'YYYY-M-D' — kept as-is because the
    # repository query presumably accepts this form; confirm before changing
    # to isoformat() (which zero-pads month/day).
    sincewhen = str(sincewhen_date.year) + '-' + str(sincewhen_date.month) + '-' + str(sincewhen_date.day)
    feature_filler = FeatureFiller(
        articleLoader=application.articleLoader,
        summaryFacade=application.summaryFacade,
        tfidfFacade=application.tfidfFacade,
        doc2VecFacade=application.doc2VecFacade,
        classifierAggregator=application.classifierAggregator,
        tf2wv_mapper=application.tf2wv_mapper,
        version=version,
    )
    similarArticlesRepo = application.similarArticlesRepo
    application.gramFacade.load_phrases()
    create_test_data(
        starting_date=sincewhen,
        feature_filler=feature_filler,
        similarArticlesRepo=similarArticlesRepo,
        version=version,
        max_skipped=max_skipped,
    )