-
Notifications
You must be signed in to change notification settings - Fork 53
/
Copy pathtest_api.py
64 lines (53 loc) · 2.23 KB
/
test_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import io
import os

import numpy as np
import pandas as pd
import pytest
import requests
from sklearn.metrics import roc_auc_score
def _post_training(host, params_path):
    '''Upload the training data plus a parameter file to ``train_pipeline``.

    Args:
        host: Base URL of the service under test. It arrives via the ``host``
            test argument, presumably a pytest fixture defined outside this
            file (e.g. conftest.py) -- TODO confirm.
        params_path: Path of the YAML parameter file to upload as ``params``.

    Returns:
        The ``requests.Response`` returned by the POST.
    '''
    # Join with '/' explicitly: os.path.join would insert '\\' on Windows,
    # which is not a valid URL path separator.
    url = host.rstrip('/') + '/train_pipeline'
    # Keep the uploads inside ``with`` so every handle is closed even if the
    # request raises.
    with open('data/data_train.json', 'rb') as raw_data, \
            open('data/label_train.json', 'rb') as labels, \
            open(params_path, 'rb') as params:
        train_files = {'raw_data': raw_data,
                       'labels': labels,
                       'params': params}
        return requests.post(url, files=train_files)


def test_train_model_1(host):
    '''Train sklearn model'''
    r = _post_training(host, 'parameters/train_parameters.yml')
    # Fail the test on any 4xx/5xx response instead of ignoring it.
    r.raise_for_status()


def test_train_model_2(host):
    '''Train TPOT model'''
    r = _post_training(host, 'parameters/train_parameters_model2.yml')
    # Fail the test on any 4xx/5xx response instead of ignoring it.
    r.raise_for_status()
def _serve_and_score(host, params_path):
    '''Request predictions for the test set and return their ROC AUC.

    Args:
        host: Base URL of the service under test (the ``host`` test argument,
            presumably a pytest fixture defined outside this file).
        params_path: Path of the YAML parameter file to upload as ``params``.

    Returns:
        ``roc_auc_score`` of the served predictions against the labels in
        ``data/label_test.json``.
    '''
    # Join with '/' explicitly: os.path.join would insert '\\' on Windows.
    serve_url = host.rstrip('/') + '/serve_prediction'
    # Close the uploaded files as soon as the request has been sent.
    with open('data/data_test.json', 'rb') as raw_data, \
            open(params_path, 'rb') as params:
        r = requests.post(serve_url,
                          files={'raw_data': raw_data, 'params': params})
    # Surface HTTP errors directly rather than as an obscure parse failure.
    r.raise_for_status()
    # r.json() is expected to yield a JSON *string* that pandas parses again
    # (presumably the server double-encodes its payload -- TODO confirm).
    # pandas >= 2.1 deprecates literal JSON strings, so wrap it in StringIO.
    result = pd.read_json(io.StringIO(r.json())).set_index('id')
    # np.int was only an alias of the builtin int and was removed in NumPy 1.24.
    result.index = result.index.astype(int)
    label_test = pd.read_json('data/label_test.json')
    # Reorder predictions to match the label order before scoring.
    result = result.loc[label_test.example_id]
    return roc_auc_score(label_test.label, result.values)


def test_serve_model(host):
    '''Serve predictions with parameters/test_parameters.yml; require AUC > 0.9.'''
    auc = _serve_and_score(host, 'parameters/test_parameters.yml')
    # Single-argument print() behaves the same on Python 2 and Python 3.
    print("Test AUC: {}".format(auc))
    assert auc > 0.9


def test_serve_model_2(host):
    '''Serve predictions with parameters/test_parameters_model2.yml; require AUC > 0.9.'''
    auc = _serve_and_score(host, 'parameters/test_parameters_model2.yml')
    print("Test AUC: {}".format(auc))
    assert auc > 0.9
def test_get_models(host):
    '''Show all available models'''
    # Join with '/' explicitly: os.path.join would insert '\\' on Windows,
    # which is not a valid URL path separator.
    url = host.rstrip('/') + '/models'
    r = requests.get(url)
    assert r.status_code == 200