Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/chain #172

Merged
merged 22 commits into from
Jan 10, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor and update util.py
AHReccese committed Jan 4, 2025
commit 79ffd9dedd7a7e14e5f7982e00e6a5b28c680986
66 changes: 34 additions & 32 deletions pymilo/chains/util.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
# -*- coding: utf-8 -*-
"""useful utilities for chains."""
from .linear_model_chain import transport_linear_model, is_linear_model
from .neural_network_chain import transport_neural_network, is_neural_network
from .decision_tree_chain import transport_decision_tree, is_decision_tree
from .clustering_chain import transport_clusterer, is_clusterer
from .naive_bayes_chain import transport_naive_bayes, is_naive_bayes
from .svm_chain import transport_svm, is_svm
from .neighbours_chain import transport_neighbor, is_neighbors
from .cross_decomposition_chain import transport_cross_decomposition, is_cross_decomposition

from .linear_model_chain import linear_chain
from .neural_network_chain import neural_network_chain
from .decision_tree_chain import decision_trees_chain
from .clustering_chain import clustering_chain
from .naive_bayes_chain import naive_bayes_chain
from .svm_chain import svm_chain
from .neighbours_chain import neighbors_chain
from .cross_decomposition_chain import cross_decomposition_chain


MODEL_TYPE_TRANSPORTER = {
"LINEAR_MODEL": transport_linear_model,
"NEURAL_NETWORK": transport_neural_network,
"DECISION_TREE": transport_decision_tree,
"CLUSTERING": transport_clusterer,
"NAIVE_BAYES": transport_naive_bayes,
"SVM": transport_svm,
"NEIGHBORS": transport_neighbor,
"CROSS_DECOMPOSITION": transport_cross_decomposition,
"LINEAR_MODEL": linear_chain.transport,
"NEURAL_NETWORK": neural_network_chain.transport,
"DECISION_TREE": decision_trees_chain.transport,
"CLUSTERING": clustering_chain.transport,
"NAIVE_BAYES": naive_bayes_chain.transport,
"SVM": svm_chain.transport,
"NEIGHBORS": neighbors_chain.transport,
"CROSS_DECOMPOSITION": cross_decomposition_chain.transport,
}


@@ -34,21 +36,21 @@ def get_concrete_transporter(model):
if upper_model in MODEL_TYPE_TRANSPORTER.keys():
return upper_model, MODEL_TYPE_TRANSPORTER[upper_model]

if is_linear_model(model):
return "LINEAR_MODEL", transport_linear_model
elif is_neural_network(model):
return "NEURAL_NETWORK", transport_neural_network
elif is_decision_tree(model):
return "DECISION_TREE", transport_decision_tree
elif is_clusterer(model):
return "CLUSTERING", transport_clusterer
elif is_naive_bayes(model):
return "NAIVE_BAYES", transport_naive_bayes
elif is_svm(model):
return "SVM", transport_svm
elif is_neighbors(model):
return "NEIGHBORS", transport_neighbor
elif is_cross_decomposition(model):
return "CROSS_DECOMPOSITION", transport_cross_decomposition
if linear_chain.is_supported(model):
return "LINEAR_MODEL", linear_chain.transport
elif neural_network_chain.is_supported(model):
return "NEURAL_NETWORK", neural_network_chain.transport
elif decision_trees_chain.is_supported(model):
return "DECISION_TREE", decision_trees_chain.transport
elif clustering_chain.is_supported(model):
return "CLUSTERING", clustering_chain.transport
elif naive_bayes_chain.is_supported(model):
return "NAIVE_BAYES", naive_bayes_chain.transport
elif svm_chain.is_supported(model):
return "SVM", svm_chain.transport
elif neighbors_chain.is_supported(model):
return "NEIGHBORS", neighbors_chain.transport
elif cross_decomposition_chain.is_supported(model):
return "CROSS_DECOMPOSITION", cross_decomposition_chain.transport
else:
return None, None