solved linear

kev committed Jun 13, 2024
1 parent 5952968 commit b3c818a
Showing 2 changed files with 315 additions and 0 deletions.
221 changes: 221 additions & 0 deletions rl_equation_solver/dev_kev/env_linear.py
@@ -0,0 +1,221 @@
import logging
from operator import add, sub, mul, truediv
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from sympy import symbols, simplify, sympify, Expr, Basic, Integer

logger = logging.getLogger(__name__)

class Env(gym.Env):
"""Environment for solving algebraic equations using RL."""

metadata = {"render_modes": ["human"]}

def __init__(self) -> None:
super().__init__()
self._setup()
        self.state_sympy = Integer(0)  # current guess for x; reset() also starts from 0
self.max_length = 10
self.reward_when_solved = 100

self.action_space = spaces.Discrete(len(self.actions))
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(self.max_length,), dtype=np.float32
)

def _setup(self):
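        """Define the operations, terms, target equation, and the feature
        codes used to vectorise expressions."""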
# Define all the terms and operations you need
operations = [add, sub, mul, truediv]
a, b, x = symbols("a b x")
terms = [a, b]

# Define main equation to solve
self.eqn_main = a * x + b

# Create feature dictionary: this will be used to cast state representation as a vector
unique_symbols = ['a', 'b', 'Add', 'Pow', 'Mul', '-1']
self.feature_dict = {symbol: -(i + 2) for i, symbol in enumerate(unique_symbols)}

        # Add PAD as the most negative feature code
self.feature_dict['PAD'] = -(len(self.feature_dict) + 2)

# Reverse feature dictionary for from_vec method
self.reverse_feature_dict = {v: k for k, v in self.feature_dict.items()}

        # Each action is an (operation, term) pair applied to the current guess.
        # The illegal-action filter only matters if 0 is ever added to `terms`.
        illegal_actions = [(truediv, symbols("0"))]
        self.actions = [[op, term] for op in operations for term in terms if (op, term) not in illegal_actions]

def get_expr_tree(self, expr):
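        """Return a nested, prefix-ordered list of node labels for `expr`.
        Operator nodes use their SymPy function names (e.g. 'Add', 'Mul', 'Pow');
        negative integers become ['Mul', '-1', n] and rationals become
        ['Mul', numer, 'Pow', denom, '-1'] so that leaves map onto the feature
        dictionary or plain integers."""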
if isinstance(expr, Basic):
if expr.is_Atom:
# Handle negative integers by breaking them down into -1 * positive part
if isinstance(expr, Integer) and expr < 0:
return ['Mul', '-1', str(-expr)]
# Handle fractions by breaking them down into products of powers
elif expr.is_Rational and not expr.is_Integer:
numer, denom = expr.as_numer_denom()
return ['Mul', str(numer), 'Pow', str(denom), '-1']
return [str(expr)]
else:
# Include the function of the expression as a string
nodes = [expr.func.__name__]
for arg in expr.args:
nodes.append(self.get_expr_tree(arg)) # Recursively get the tree for each argument
return nodes
else:
return [str(expr)] # Base case for symbols and numbers

def flatten_expr_tree(self, tree):
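        """Flatten the nested tree from get_expr_tree into a single prefix-order list of labels."""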
flat_list = []
for item in tree:
if isinstance(item, list):
flat_list.extend(self.flatten_expr_tree(item))
else:
flat_list.append(item)
return flat_list

def convert_to_vector(self, expr_list):
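        """Map each node label to an integer code: known symbols/operators get
        their (negative) feature_dict code, numeric literals map to their own value."""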
vector = []
for item in expr_list:
if item in self.feature_dict:
vector.append(self.feature_dict[item])
else:
try:
# Map integers directly to their values
vector.append(int(item))
except ValueError:
raise ValueError(f"Unrecognized symbol: {item}")
return vector

def from_vec(self, vector):
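        """Inverse of to_vec: rebuild a SymPy expression from a coded vector.
        Assumes each Add/Mul/Pow node is binary; SymPy can produce n-ary nodes,
        in which case the reconstruction is only approximate."""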
        # Remove padding
        vector = [v for v in vector if v != self.feature_dict['PAD']]

        # Convert back to a list of node labels
        expr_list = [self.reverse_feature_dict[v] if v in self.reverse_feature_dict else str(int(v)) for v in vector]

        # Rebuild the expression from prefix order: scan right-to-left, pushing
        # leaves and combining the two most recent operands at each operator
        expr_stack = []
        for token in reversed(expr_list):
            if token == 'Add':
                left = expr_stack.pop()
                right = expr_stack.pop()
                expr_stack.append(left + right)
            elif token == 'Mul':
                left = expr_stack.pop()
                right = expr_stack.pop()
                expr_stack.append(left * right)
            elif token == 'Pow':
                base = expr_stack.pop()
                exponent = expr_stack.pop()
                expr_stack.append(base ** exponent)
            else:
                expr_stack.append(sympify(token))

        return expr_stack[0]

def reset(self, seed=None, options=None):
super().reset(seed=seed)
self.state_sympy = 0
self.state_vec = self.to_vec(self.state_sympy)
return self.state_vec, {}

def get_complexity_of_expression(self, expr: Expr) -> int:
"""
Calculate the complexity of a SymPy expression based on the number of non-padded elements.
Parameters:
expr (Expr): The SymPy expression whose complexity is to be calculated.
Returns:
int: The complexity of the expression.
"""
vec = self.to_vec(expr)
return len([x for x in vec if x != self.feature_dict['PAD']])

def get_complexity_of_guess(self, guess: Expr) -> int:
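        """Complexity of the main equation after substituting the guess for x."""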
substituted_eq = self.eqn_main.subs(symbols('x'), guess)
simplified_eq = simplify(substituted_eq)
return self.get_complexity_of_expression(simplified_eq)

    def find_reward(self, state_old: Expr, state_new: Expr, is_solved: bool) -> float:
        """
        Calculate the reward based on the reduction in complexity of the equation
        with the current guess substituted.
        Parameters:
        state_old (Expr): The old state of the current guess.
        state_new (Expr): The new state of the current guess after taking an action.
        is_solved (bool): Whether the new state solves the equation.
        Returns:
        float: reward_when_solved if the equation is solved, otherwise the reduction
        in complexity of the equation with the guess substituted.
        """
if is_solved:
reward = self.reward_when_solved
else:
old_complexity = self.get_complexity_of_guess(state_old)
new_complexity = self.get_complexity_of_guess(state_new)
reward = old_complexity - new_complexity
return reward

def to_vec(self, state_sympy: Expr) -> np.ndarray:
"""
Convert a SymPy expression to a numerical vector using the defined feature dictionary.
Parameters:
state_sympy (Expr): The SymPy expression to convert.
Returns:
np.ndarray: The numerical vector representation of the expression.
"""
expr_tree = self.get_expr_tree(state_sympy)
flat_nodes = self.flatten_expr_tree(expr_tree)
vector = self.convert_to_vector(flat_nodes)

        # Pad the vector up to max_length; anything longer is truncated here
        # and flagged by is_too_long() in step()
        vector += [self.feature_dict['PAD']] * (self.max_length - len(vector))

        return np.array(vector[:self.max_length], dtype=np.float32)

def is_solved(self, state_sympy):
""" Env is solved when state, which is the current guess, is 0.
"""
eqn_after_sub = self.eqn_main.subs(symbols('x'), state_sympy)
is_solved = eqn_after_sub == 0
return is_solved

def is_too_long(self, next_state_as_vector):
""" Checks if the current guess / solution is too long.
"""
padding_int = self.feature_dict['PAD']
return next_state_as_vector[-1] != padding_int

def step(self, action: int):
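        """Apply the chosen (operation, term) action to the current guess and
        return the standard gymnasium 5-tuple."""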
# Take step
operation, term = self.actions[action]
state_sympy = self.state_sympy
next_state_sympy = simplify(operation(state_sympy, term))
next_state_as_vector = self.to_vec(next_state_sympy)

is_solved = self.is_solved(next_state_sympy)
too_long = self.is_too_long(next_state_as_vector)

reward = self.find_reward(state_sympy, next_state_sympy, is_solved)

# Check terminal condition
terminated = bool(is_solved or too_long)
truncated = False
info = {}

# Update
self.state_sympy = next_state_sympy

return next_state_as_vector, reward, terminated, truncated, info

def render(self, mode: str = "human"):
state = self.state_sympy
print(f'x_guess = {state}')

94 changes: 94 additions & 0 deletions rl_equation_solver/dev_kev/linear_train.py
@@ -0,0 +1,94 @@
import os
import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Import the custom environment
from env_linear import Env

class RewardLoggingCallback(BaseCallback):
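    """Record the per-step reward during training and save the full history to rewards.npy."""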
def __init__(self, log_interval: int = 100, save_dir: str = '.', verbose: int = 1):
        super().__init__(verbose)
self.log_interval = log_interval
self.rewards = []
self.save_dir = save_dir

def _on_step(self) -> bool:
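        # Record the reward from the first (and only) environment in the VecEnv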
reward = self.locals["rewards"][0]
self.rewards.append(reward)
return True

def _on_training_end(self) -> None:
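        # Persist the per-step reward history so it can be plotted after training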
np.save(os.path.join(self.save_dir, 'rewards.npy'), np.array(self.rewards))

def plot_rewards(rewards, save_dir, window=100):
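    """Plot raw per-step rewards with a rolling average and save the figure to save_dir."""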
rewards_df = pd.DataFrame(rewards, columns=['reward'])
rolling_avg = rewards_df['reward'].rolling(window=window).mean()

plt.figure(figsize=(12, 6))
plt.plot(rewards, 'b.', alpha=0.3, label='Rewards')
plt.plot(rolling_avg, 'r-', linewidth=2, label=f'Rolling Average (window={window})')
plt.xlabel('Timesteps')
plt.ylabel('Reward')
plt.title('Reward over Time')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(save_dir, 'reward_plot.png'))
plt.show()

def main():
# Parameters
Ntrain = 10**4

save_dir = 'data_linear'
os.makedirs(save_dir, exist_ok=True)

# Create the environment
env = Env()

# Wrap the environment with Monitor wrapper
env = Monitor(env, filename=os.path.join(save_dir, 'monitor.csv'))

# Check the environment
check_env(env)

# Custom callback
reward_logging_callback = RewardLoggingCallback(save_dir=save_dir)

# Create the A2C model
model = A2C("MlpPolicy", DummyVecEnv([lambda: env]), verbose=1)

# Train the model
model.learn(total_timesteps=Ntrain, callback=reward_logging_callback)

# Save the model
model.save(os.path.join(save_dir, "a2c_solver"))

# Load rewards and plot them
rewards = np.load(os.path.join(save_dir, 'rewards.npy'))
plot_rewards(rewards, save_dir, window=int(0.1*Ntrain))

# Evaluate the model
eval_env = DummyVecEnv([lambda: Monitor(Env(), filename=os.path.join(save_dir, 'eval_monitor.csv'))])
#mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
#print(f"Mean reward: {mean_reward} +/- {std_reward}")

# Run the trained model
obs, _ = env.reset()
for i in range(100):
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, truncated, info = env.step(action)
env.render()
if done or truncated:
print("Episode finished")
break

if __name__ == "__main__":
main()
