Skip to content

Commit

Permalink
update hyper related
Browse files Browse the repository at this point in the history
  • Loading branch information
Min authored and Min committed Nov 24, 2023
1 parent 9561f7b commit 7e43e3a
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion easygraph/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def generate_mask_tensor(mask):
"""
assert isinstance(
mask, np.ndarray
), "input for generate_mask_tensorshould be an numpy ndarray"
), "input for generate_mask_tensor should be an numpy ndarray"
return tensor(mask, dtype=data_type_dict()["bool"])


Expand Down
6 changes: 3 additions & 3 deletions easygraph/experiments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ class BaseTask:
``data`` (``dict``): The dictionary to store input data that used in the experiment.
``model_builder`` (``Callable``): The function to build a model with a fixed parameter ``trial``.
``train_builder`` (``Callable``): The function to build a training configuration with two fixed parameters ``trial`` and ``model``.
``evaluator`` (``dhg.ml_metrics.BaseEvaluator``): The EasyGraph evaluator object to evaluate performance of the model in the experiment.
``evaluator`` (``eg.ml_metrics.BaseEvaluator``): The EasyGraph evaluator object to evaluate performance of the model in the experiment.
``device`` (``torch.device``): The target device to run the experiment.
``structure_builder`` (``Optional[Callable]``): The function to build a structure with a fixed parameter ``trial``. The structure can be ``dhg.Graph``, ``dhg.DiGraph``, ``dhg.BiGraph``, and ``dhg.Hypergraph``.
``structure_builder`` (``Optional[Callable]``): The function to build a structure with a fixed parameter ``trial``. The structure can be ``eg.Graph``, ``eg.DiGraph``, ``eg.BiGraph``, and ``eg.Hypergraph``.
``study_name`` (``Optional[str]``): The name of this study. If set to ``None``, the study name will be generated automatically according to current time. Defaults to ``None``.
``overwrite`` (``bool``): The flag that whether to overwrite the existing study. Different studies are identified by the ``study_name``. Defaults to ``True``.
"""
Expand Down Expand Up @@ -199,6 +199,6 @@ def test(self, data: Optional[dict] = None, model: Optional[nn.Module] = None):
r"""Test the model.
Args:
``data`` (``dict``, optional): The input data if set to ``None``, the specified ``data`` in the intialization of the experiments will be used. Defaults to ``None``.
``data`` (``dict``, optional): The input data if set to ``None``, the specified ``data`` in the initialization of the experiments will be used. Defaults to ``None``.
``model`` (``nn.Module``, optional): The model if set to ``None``, the trained best model will be used. Defaults to ``None``.
"""
2 changes: 1 addition & 1 deletion easygraph/experiments/vertex_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test(self, data: Optional[dict] = None, model: Optional[nn.Module] = None):
r"""Test the model.
Args:
``data`` (``dict``, optional): The input data if set to ``None``, the specified ``data`` in the intialization of the experiments will be used. Defaults to ``None``.
``data`` (``dict``, optional): The input data if set to ``None``, the specified ``data`` in the initialization of the experiments will be used. Defaults to ``None``.
``model`` (``nn.Module``, optional): The model if set to ``None``, the trained best model will be used. Defaults to ``None``.
"""
if data is None:
Expand Down
6 changes: 3 additions & 3 deletions easygraph/functions/hypergraph/hypergraph_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def hypergraph_local_clustering_coefficient(H):
if len(D1.union(D2)) == 0:
eo = 0
else:
# otherwise we have to look at their neighbours
# the neighbours of D1 and D2, respectively.
# otherwise we have to look at their neighbors
# the neighbors of D1 and D2, respectively.
neighD1 = {i for d in D1 for i in H.neighbor_of_node(d)}
neighD2 = {i for d in D2 for i in H.neighbor_of_node(d)}
# compute extra overlap [len() is used for cardinality of edges]
Expand All @@ -169,7 +169,7 @@ def hypergraph_local_clustering_coefficient(H):
# add it up
total_eo = total_eo + eo

# include normalisation by degree k*(k-1)/2
# include normalization by degree k*(k-1)/2
result[n] = 2 * total_eo / (dv * (dv - 1))
return result

Expand Down

0 comments on commit 7e43e3a

Please sign in to comment.