Skip to content

Commit

Permalink
Re-implement play mode (#80)
Browse files Browse the repository at this point in the history
* play mode

* fix play mode default tree_type to match menu (g)

* make display_list more legible

* Apply suggestions from code review

Co-authored-by: Ezio Melotti <[email protected]>

* fix variable names and isinstance

Co-authored-by: Ezio Melotti <[email protected]>
  • Loading branch information
granawkins and ezio-melotti authored Jul 6, 2022
1 parent db2c47a commit addcdb5
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 17 deletions.
23 changes: 14 additions & 9 deletions karoo-gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
import pandas as pd

from karoo_gp import pause as menu
from karoo_gp import __version__, MultiClassifierGP, RegressorGP, MatchingGP
from karoo_gp import __version__, MultiClassifierGP, RegressorGP, MatchingGP, BaseGP

#++++++++++++++++++++++++++++++++++++++++++
# User Interface for Configuation |
Expand Down Expand Up @@ -102,7 +102,7 @@
try:
query = input('\t Select (f)ull or (g)row (default g): ')
if query in ['f', 'g', '']:
tree_type = query or 'f'
tree_type = query or 'g'
break
else:
raise ValueError()
Expand Down Expand Up @@ -134,7 +134,7 @@
tree_pop_max = 1
gen_max = 1
tourn_size = 0
display = 'm'
display = 's' # for play mode, initialize, print fittest tree and quit
# evolve_repro, evolve_point, evolve_branch, evolve_cross,
# tourn_size, precision, filename are not required

Expand Down Expand Up @@ -560,7 +560,9 @@ def fx_karoo_pause(model):
#++++++++++++++++++++++++++++++++++++++++++

# Select the correct class for kernel
cls = {'c': MultiClassifierGP, 'r': RegressorGP, 'm': MatchingGP}[kernel]
cls = {
'c': MultiClassifierGP, 'r': RegressorGP, 'm': MatchingGP, 'p': BaseGP
}[kernel]

# Initialize the model
gp = cls(
Expand Down Expand Up @@ -590,11 +592,14 @@ def fx_karoo_pause(model):
# Fit to the data
gp.fit(X, y)

# TODO: Relocated from BaseGP.__init__(). Need to test/debug
# if self.kernel == 'p':
# self.fx_data_tree_write(self.population.trees, 'a')
# sys.exit()

if kernel == 'p':
tree = gp.population.trees[0]
print(f'\nTree ID {tree.id}')
print(f' yields (raw): {tree.raw_expression}')
print(f' yields (sym): {tree.expression}\n')
print(gp.population.trees[0].display(method='viz'))
print(gp.population.trees[0].display(method='list'))
# self.fx_data_tree_write(self.population.trees, 'a')

# Save files and exit
gp.fx_karoo_terminate()
8 changes: 5 additions & 3 deletions karoo_gp/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,11 @@ def log(self, msg, display={'i', 'g', 'm', 'db'}):
def pause(self, display={'i', 'g', 'm', 'db'}):
if not self.pause_callback:
self.log('No pause callback function provided')
return
if self.display in display or display is None:
return 0
elif self.display in display:
self.pause_callback(self)
else:
return 0

def error(self, msg, display={'i', 'g', 'm', 'db'}):
self.log(msg, display)
Expand Down Expand Up @@ -343,7 +345,7 @@ def fx_karoo_terminate(self):
used via ContextManager, or manually.
'''
kernel = {
RegressorGP: 'r', MultiClassifierGP: 'c', MatchingGP: 'm'
RegressorGP: 'r', MultiClassifierGP: 'c', MatchingGP: 'm', BaseGP: 'p'
}[type(self)]
self.fx_data_params_write(kernel)
self.fx_data_params_write_json(kernel)
Expand Down
14 changes: 12 additions & 2 deletions karoo_gp/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, node, tree_type, parent=None):
self.parent = parent
self.children = None
self.bfs_ref = None
self.id = None

@classmethod
def load(cls, expr: str, tree_type, parent=None):
Expand Down Expand Up @@ -230,9 +231,18 @@ def display(self, *args, method='viz', **kwargs):
return self.display_viz(*args, **kwargs)

def display_list(self, prefix=''):
output = prefix + repr(self.node) + '\n'
node_type = 'term' if isinstance(self.node, Terminal) else 'func'
symbol = self.node.symbol
parent = '' if self.parent is None else self.parent.id
arity = 0 if node_type == 'term' else self.node.arity
children = [] if not self.children else [c.id for c in self.children]
output = (
f'{prefix}NODE ID: {self.id}\n'
f'{prefix} type: {node_type}\n'
f'{prefix} label: {symbol}\tparent node: {parent}\n'
f'{prefix} arity: {arity}\tchild node(s): {children}\n\n')
if self.children:
output += ''.join(child.display_list(prefix=prefix+' ')
output += ''.join(child.display_list(prefix=prefix+'\t')
for child in self.children)
return output

Expand Down
6 changes: 5 additions & 1 deletion karoo_gp/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ def _type_check(self, trees):
#++++++++++++++++++++++++++++

def inf_to_zero_divide(a, b):
return np.where(b==0, 0, a / b)
# This continued to raise 'RuntimeWarning: divide by zero' errors. This
# is a proposed solution. ref: https://stackoverflow.com/a/64747978
with np.errstate(divide='ignore'):
out = np.where(b==0, 0., a / b)
return out

class NumpyEngine(Engine):
def __init__(self, model):
Expand Down
13 changes: 11 additions & 2 deletions karoo_gp/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, id, root, tree_type='g', score=None):
self.root = root # The top Branch (depth = 0)
self.tree_type = tree_type
self.score = score or {}
self.bfs_ref = None
self.renumber()

@classmethod
def load(cls, id, expr, tree_type='f'):
Expand Down Expand Up @@ -105,14 +105,20 @@ def get_child(self, i_child, **kwargs):
# Modify |
#++++++++++++++++++++++++++++

def renumber(self, method='BFS'):
"""Set the id of each branch of the subtree"""
self.root.bfs_ref = None
for i in range(0, self.n_children + 1):
self.get_child(i, method=method).id = i

def set_child(self, i_child, branch, **kwargs):
if i_child == 0:
self.root = branch
n_ch = self.n_children
if i_child > n_ch:
raise ValueError(f'Index "{i_child}" out of range ({n_ch}')
self.root.set_child(i_child, branch, **kwargs)
self.bfs_ref = None
self.renumber()

def point_mutate(self, rng, functions, terminals, log):
"""Replace a node (including root) with random node of same type"""
Expand Down Expand Up @@ -155,8 +161,10 @@ def prune(self, rng, terminals, max_depth):
return
elif max_depth == 0 and type(self.root.node) != Terminal: # Replace the root
self.root = Branch(rng.choice(terminals.get()), self.tree_type)
self.renumber()
elif max_depth == 1: # Prune the root
self.root.prune(rng, terminals)
self.renumber()
else: # Cycle through (BFS order), prune second-to-last depth
last_depth_nodes = [self.root]
for d in range(max_depth - 1, 0, -1):
Expand All @@ -168,6 +176,7 @@ def prune(self, rng, terminals, max_depth):
if d == 1: # second to last row
for node in this_depth_nodes:
node.prune(rng, terminals)
self.renumber()

def crossover(self, i, mate, i_mate, rng, terminals, tree_depth_max,
log, pause):
Expand Down

0 comments on commit addcdb5

Please sign in to comment.