Skip to content

Commit

Permalink
add save train as h5
Browse files Browse the repository at this point in the history
  • Loading branch information
nargesr committed Dec 18, 2023
1 parent 5b71f78 commit c370c7b
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 23 deletions.
52 changes: 42 additions & 10 deletions Topyfic/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from itertools import repeat
import pickle
from sklearn.decomposition import LatentDirichletAllocation
import h5py

from Topyfic.topModel import TopModel

Expand Down Expand Up @@ -195,20 +196,51 @@ def make_LDA_models_attributes(self):

return all_components, all_exp_dirichlet_component, all_others

def save_train(self, name=None, save_path="", file_format='pickle'):
    """
    Save the Train object to disk, either as a pickle file or as an HDF5 file.

    :param name: base name of the output file, without extension
        (default is "train_" + self.name)
    :type name: str
    :param save_path: directory prefix for the output file
        (default is saving near script)
    :type save_path: str
    :param file_format: output format, either 'pickle' (.p) or 'HDF5' (.h5)
        (default is 'pickle')
    :type file_format: str
    """
    if file_format not in ['pickle', 'HDF5']:
        sys.exit(f"{file_format} is not correct! It should be 'pickle' or 'HDF5'.")

    if name is None:
        name = f"train_{self.name}"

    if file_format == "pickle":
        print(f"Saving train as {name}.p")

        # context manager guarantees the handle is closed even if dump fails
        with open(f"{save_path}{name}.p", "wb") as picklefile:
            pickle.dump(self, picklefile)

    if file_format == "HDF5":
        print(f"Saving train as {name}.h5")

        # honor save_path here too (previously only the pickle branch used it)
        with h5py.File(f"{save_path}{name}.h5", "w") as f:
            # models: one group per top model, keyed by list position
            models = f.create_group("models")
            for i in range(len(self.top_models)):
                # read the fitted LDA through a local reference instead of
                # overwriting top_models[i].model as a side effect of saving
                lda = self.top_models[i].rLDA

                model = models.create_group(str(i))
                model['components_'] = lda.components_
                model['exp_dirichlet_component_'] = lda.exp_dirichlet_component_
                model['n_batch_iter_'] = int(lda.n_batch_iter_)
                model['n_features_in_'] = lda.n_features_in_
                model['n_iter_'] = int(lda.n_iter_)
                model['bound_'] = float(lda.bound_)
                model['doc_topic_prior_'] = float(lda.doc_topic_prior_)
                model['topic_word_prior_'] = float(lda.topic_word_prior_)

            # np.bytes_ is the non-deprecated spelling of np.string_
            # (np.string_ / np.float_ were removed in NumPy 2.0)
            f['name'] = np.bytes_(self.name)
            f['k'] = int(self.k)
            f['n_runs'] = int(self.n_runs)
            f['random_state_range'] = np.array(list(self.random_state_range))
67 changes: 54 additions & 13 deletions Topyfic/utilsMakeModel.py
Original file line number Diff line number Diff line change
def read_train(file):
    """
    Load a previously saved Train object from a pickle (.p) or HDF5 (.h5) file.

    :param file: path to the saved Train file (must end in '.p' or '.h5')
    :type file: str

    :return: the reconstructed Train object
    :rtype: Train class

    :raises ValueError: if the path does not exist or has an unsupported extension
    """
    if not os.path.isfile(file):
        raise ValueError('Train file not found at given path!')
    if not file.endswith('.p') and not file.endswith('.h5'):
        raise ValueError('Train file type is not correct!')

    if file.endswith('.p'):
        # context manager closes the handle (the previous version leaked it)
        with open(file, 'rb') as picklefile:
            train = pickle.load(picklefile)

    if file.endswith('.h5'):
        with h5py.File(file, 'r') as f:
            # np.bytes_ replaces the NumPy-2.0-removed np.string_ alias
            name = np.bytes_(f['name']).decode('ascii')
            k = int(np.int_(f['k']))
            n_runs = int(np.int_(f['n_runs']))
            random_state_range = list(f['random_state_range'])

            # models: save_train writes groups keyed by LIST POSITION
            # ("0", "1", ...), not by the random state value, so look them
            # up by position here (previously keyed on random_state, which
            # only worked when random_state_range happened to be range(n))
            top_models = []
            for idx, random_state in enumerate(random_state_range):
                grp = f[f"models/{idx}"]

                components = pd.DataFrame(np.array(grp["components_"]))
                exp_dirichlet_component = pd.DataFrame(np.array(grp["exp_dirichlet_component_"]))

                others = pd.DataFrame()
                others.loc[0, 'n_batch_iter'] = np.int_(grp["n_batch_iter_"])
                others.loc[0, 'n_features_in'] = np.array(grp["n_features_in_"])
                others.loc[0, 'n_iter'] = np.int_(grp["n_iter_"])
                others.loc[0, 'bound'] = np.float64(grp["bound_"])
                others.loc[0, 'doc_topic_prior'] = np.array(grp["doc_topic_prior_"])
                others.loc[0, 'topic_word_prior'] = np.array(grp["topic_word_prior_"])

                model = initialize_lda_model(components, exp_dirichlet_component, others)

                top_model = TopModel(name=f"{name}_{random_state}",
                                     N=k,
                                     gene_weights=components,
                                     model=model)
                top_models.append(top_model)

            train = Train(name=name,
                          k=k,
                          n_runs=n_runs,
                          random_state_range=random_state_range)
            train.top_models = top_models

    print(f"Reading Train done!")
    return train
Expand Down Expand Up @@ -553,11 +594,11 @@ def read_topModel(file):
gene_weights.index = gene_information.index.tolist()
gene_weights.columns = topic_information.index.tolist()

topic = Topyfic.Topic(topic_id=topic_id,
topic_name=topic_name,
topic_gene_weights=gene_weights,
gene_information=gene_information,
topic_information=topic_information)
topic = Topic(topic_id=topic_id,
topic_name=topic_name,
topic_gene_weights=gene_weights,
gene_information=gene_information,
topic_information=topic_information)
topics[topic_id] = topic

# model
Expand All @@ -572,12 +613,12 @@ def read_topModel(file):
others.loc[0, 'doc_topic_prior'] = np.array(f['model']['doc_topic_prior_'])
others.loc[0, 'topic_word_prior'] = np.array(f['model']['topic_word_prior_'])

model = Topyfic.initialize_lda_model(components, exp_dirichlet_component, others)
model = initialize_lda_model(components, exp_dirichlet_component, others)

top_model = Topyfic.TopModel(name=name,
N=N,
topics=topics,
model=model)
top_model = TopModel(name=name,
N=N,
topics=topics,
model=model)

f.close()

Expand Down

0 comments on commit c370c7b

Please sign in to comment.