
Commit 4c87ca1

We need to ignore sampled learning problems if they lead Top to have 0 quality. It pertains to #447.
Demirrr committed Oct 18, 2024
1 parent 48f9775 commit 4c87ca1
Showing 1 changed file with 19 additions and 12 deletions.
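
The core of this change, sketched below as a minimal, self-contained Python snippet (names such as sample_learning_problems and instances are illustrative, not Ontolearn's API): positives and negatives are sampled from two class expressions whose instance sets may overlap, so the two samples can turn out identical. Such degenerate learning problems are now skipped, since they lead Top to have 0 quality and would trip the root_state.quality > 0 assertion in fit().

import random
from typing import Dict, List, Set, Tuple

def sample_learning_problems(instances: Dict[str, Set[str]],
                             size_of_examples: int,
                             num_learning_problems: int) -> List[Tuple[str, Set[str], Set[str]]]:
    """Sample (target concept, E^+, E^-) triples, skipping degenerate ones."""
    examples = []
    for concept_i, individuals_i in instances.items():
        if len(individuals_i) < size_of_examples:
            continue
        for concept_j, individuals_j in instances.items():
            if concept_i == concept_j or len(individuals_j) < size_of_examples:
                continue
            sampled_positives = set(random.sample(sorted(individuals_i), size_of_examples))
            sampled_negatives = set(random.sample(sorted(individuals_j), size_of_examples))
            if sampled_positives == sampled_negatives:
                # Identical E^+ and E^- would leave Top with 0 quality; ignore this problem.
                continue
            examples.append((concept_i, sampled_positives, sampled_negatives))
            if len(examples) == num_learning_problems:
                return examples
    return examples
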
31 changes: 19 additions & 12 deletions ontolearn/learners/drill.py
@@ -51,8 +51,8 @@
import torch
from ontolearn.data_struct import PrepareBatchOfTraining, PrepareBatchOfPrediction
from tqdm import tqdm
-from owlapy.utils import OWLClassExpressionLengthMetric
from ..utils.static_funcs import make_iterable_verbose
+from owlapy.utils import get_expression_length


class Drill(RefinementBasedConceptLearner): # pragma: no cover
@@ -173,7 +173,8 @@ def initialize_training_class_expression_learning_problem(self,
neg: FrozenSet[OWLNamedIndividual]) -> RL_State:
""" Initialize """
assert isinstance(pos, frozenset) and isinstance(neg, frozenset), "Pos and neg must be sets"
-assert 0 < len(pos) and 0 < len(neg)
+assert 0 < len(pos) and 0 < len(neg), ("Positive and negative examples must have at least a single item\n"
+                                       f"Currently: Pos:{len(pos)}\t Neg:{len(neg)}\n")
# print("Initializing learning problem")
# (2) Obtain embeddings of positive and negative examples.
self.init_embeddings_of_examples(pos_uri=pos, neg_uri=neg)
@@ -247,7 +248,7 @@ def train(self, dataset: Optional[Iterable[Tuple[str, Set, Set]]] = None,
"""
if isinstance(self.heuristic_func, CeloeBasedReward):
print("No training")
print("No training...")
return self.terminate_training()

if self.verbose > 0:
@@ -257,6 +258,9 @@
else:
training_data = self.generate_learning_problems(num_of_target_concepts,
num_learning_problems)
+if isinstance(training_data, Iterable) is False:
+    print(f"We couldn't generate training data on this given knowledge base ({self.kb})")
+    return self.terminate_training()

for (target_owl_ce, positives, negatives) in training_data:
print(f"\nGoal Concept:\t {target_owl_ce}\tE^+:[{len(positives)}]\t E^-:[{len(negatives)}]")
@@ -319,7 +323,7 @@ def fit(self, learning_problem: PosNegLPStandard, max_runtime=None):
root_state = self.initialize_training_class_expression_learning_problem(pos=learning_problem.pos,
neg=learning_problem.neg)
self.operator.set_input_examples(pos=learning_problem.pos, neg=learning_problem.neg)
-assert root_state.quality > 0, f"Root state {root_state} must have quality >0"
+assert root_state.quality > 0, f"Root state {root_state} must have quality >0"
# (5) Add root state into search tree
root_state.heuristic = root_state.quality
self.search_tree.add(root_state)
@@ -337,7 +341,7 @@ def fit(self, learning_problem: PosNegLPStandard, max_runtime=None):
for _ in make_iterable_verbose(range(0, self.iter_bound),
verbose=self.verbose,
desc=f"Learning OWL Class Expression at most {self.iter_bound} iteration"):
-assert len(self.search_tree) > 0
+assert len(self.search_tree) > 0, "Search Tree cannot be empty!"
self.search_tree.show_current_search_tree()
# (6.1) Get the most fitting RL-state.
most_promising = self.next_node_to_expand()
@@ -419,8 +423,7 @@ def create_rl_state(self, c: OWLClassExpression, parent_node: Optional[RL_State]
is_root: bool = False) -> RL_State:
""" Create an RL_State instance."""
rl_state = RL_State(c, parent_node=parent_node, is_root=is_root)
-# TODO: Will be fixed by https://github.com/dice-group/owlapy/issues/35
-rl_state.length = OWLClassExpressionLengthMetric.get_default().length(c)
+rl_state.length = get_expression_length(c)
return rl_state

def compute_quality_of_class_expression(self, state: RL_State) -> None:
@@ -455,8 +458,8 @@ def sequence_of_actions(self, root_rl_state: RL_State) \
current_state = root_rl_state
path_of_concepts = []
rewards = []
-assert current_state.quality > 0
-assert current_state.heuristic is None
+assert current_state.quality > 0, f"Root state ({current_state}) must have quality >0.\tCurrently {current_state.quality}"
+assert current_state.heuristic is None, f"Root state ({current_state}) must not have a heuristic value yet.\tCurrently {current_state.heuristic}"
# (1)
for _ in range(self.num_of_sequential_actions):
assert isinstance(current_state, RL_State)
@@ -745,11 +748,15 @@ def generate_learning_problems(self,
individuals_j = set(self.kb.individuals(j))
if len(individuals_j) < size_of_examples:
continue

+# Generate learning problems from a single target concept
for _ in range(num_of_target_concepts):
-lp = (str_dl_concept_i,
-      set(random.sample(individuals_i, size_of_examples)),
-      set(random.sample(individuals_j, size_of_examples)))
+sampled_positives = set(random.sample(individuals_i, size_of_examples))
+sampled_negatives = set(random.sample(individuals_j, size_of_examples))
+if sampled_positives == sampled_negatives:
+    print("Sampled positives and negatives are identical. Ignoring this learning problem.")
+    continue
+lp = (str_dl_concept_i, sampled_positives, sampled_negatives)
examples.append(lp)
counter += 1
if counter == num_learning_problems:
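
The commit also makes train() bail out when problem generation yields nothing iterable. A hedged sketch of that defensive pattern follows (function and parameter names are illustrative, not the library's API):

from collections.abc import Iterable
from typing import Optional

def train_on_generated_problems(training_data: Optional[Iterable], kb_name: str) -> None:
    # Stop gracefully if no learning problems could be generated,
    # instead of failing inside the training loop.
    if not isinstance(training_data, Iterable):
        print(f"We couldn't generate training data on this given knowledge base ({kb_name})")
        return
    for target_owl_ce, positives, negatives in training_data:
        print(f"\nGoal Concept:\t {target_owl_ce}\tE^+:[{len(positives)}]\t E^-:[{len(negatives)}]")
        # ... run the RL episodes for this learning problem ...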
