Skip to content

Commit

Permalink
Small updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ejhusom committed Jun 21, 2024
1 parent 8b8e4e1 commit 2a70764
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 14 deletions.
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ Start the server by running:
python3 src/api.py
```



## Parameters


Expand Down
1 change: 0 additions & 1 deletion src/featurize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import numpy as np
import pandas as pd
import pycatch22

# import tsfresh
import yaml
from pandas.api.types import is_numeric_dtype
Expand Down
10 changes: 6 additions & 4 deletions src/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ def postprocess(model, cluster_centers, feature_vectors, labels):
"""

PLOTS_PATH.mkdir(parents=True, exist_ok=True)

with open("params.yaml", "r") as params_file:
params = yaml.safe_load(params_file)

Expand Down Expand Up @@ -354,15 +356,15 @@ def postprocess(model, cluster_centers, feature_vectors, labels):
# Create and save cluster names
cluster_names = generate_cluster_names(model, cluster_centers)

# Read predefined centroids from file
with open(PREDEFINED_CENTROIDS_PATH, "r") as f:
predefined_centroids_dict = json.load(f)

# Use cluster names from annotated data, if the number of clusters still
# matches the number of unique annotation label (the number of clusters
# might change when using cluster algorithms that automatically decide on a
# suitable number of clusters.
if use_predefined_centroids:
# Read predefined centroids from file
with open(PREDEFINED_CENTROIDS_PATH, "r") as f:
predefined_centroids_dict = json.load(f)

if len(predefined_centroids_dict) == n_clusters:
for i, key in enumerate(predefined_centroids_dict):
# cluster_names["cluster_name"][i] = (
Expand Down
10 changes: 3 additions & 7 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,12 @@
import json
import sys

import joblib
import numpy as np
import pandas as pd
import yaml
import joblib
from sklearn.cluster import (
DBSCAN,
AffinityPropagation,
MeanShift,
MiniBatchKMeans,
)
from sklearn.cluster import (DBSCAN, AffinityPropagation, MeanShift,
MiniBatchKMeans)

from annotations import *
from config import *
Expand Down

0 comments on commit 2a70764

Please sign in to comment.