diff --git a/rl_equation_solver/dev_kev/env_linear.py b/rl_equation_solver/dev_kev/env_linear.py
new file mode 100644
index 0000000..e4d26d2
--- /dev/null
+++ b/rl_equation_solver/dev_kev/env_linear.py
@@ -0,0 +1,221 @@
+import logging
+from operator import add, sub, mul, truediv
+
+import gymnasium as gym
+import numpy as np
+from gymnasium import spaces
+from sympy import Basic, Expr, Integer, simplify, symbols, sympify
+
+logger = logging.getLogger(__name__)
+
+
+class Env(gym.Env):
+    """Environment for solving algebraic equations using RL."""
+
+    metadata = {"render_modes": ["human"]}
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._setup()
+        self.state_sympy = Integer(0)
+        self.max_length = 10
+        self.reward_when_solved = 100
+
+        self.action_space = spaces.Discrete(len(self.actions))
+        self.observation_space = spaces.Box(
+            low=-np.inf, high=np.inf, shape=(self.max_length,), dtype=np.float32
+        )
+
+    def _setup(self):
+        # Define all the terms and operations you need
+        operations = [add, sub, mul, truediv]
+        a, b, x = symbols("a b x")
+        terms = [a, b]
+
+        # Define the main equation to solve
+        self.eqn_main = a * x + b
+
+        # Feature dictionary: used to cast the state representation as a vector.
+        # Codes are negative so they cannot collide with non-negative integer literals.
+        unique_symbols = ['a', 'b', 'Add', 'Pow', 'Mul', '-1']
+        self.feature_dict = {symbol: -(i + 2) for i, symbol in enumerate(unique_symbols)}
+
+        # Add PAD as the largest negative number
+        self.feature_dict['PAD'] = -(len(self.feature_dict) + 2)
+
+        # Reverse feature dictionary for the from_vec method
+        self.reverse_feature_dict = {v: k for k, v in self.feature_dict.items()}
+
+        # Guard against dividing by zero if 0 is ever added to the term list
+        illegal_actions = [(truediv, Integer(0))]
+        self.actions = [
+            [op, term]
+            for op in operations
+            for term in terms
+            if (op, term) not in illegal_actions
+        ]
+
+    def get_expr_tree(self, expr):
+        """Return a nested prefix-order list of node labels for ``expr``."""
+        if isinstance(expr, Basic):
+            if expr.is_Atom:
+                # Break negative integers into -1 * positive part
+                if isinstance(expr, Integer) and expr < 0:
+                    return ['Mul', '-1', str(-expr)]
+                # Break fractions into numerator * denominator**-1
+                elif expr.is_Rational and not expr.is_Integer:
+                    numer, denom = expr.as_numer_denom()
+                    if numer < 0:
+                        # Keep numerators non-negative so they never collide
+                        # with the negative feature codes
+                        return ['Mul', '-1', str(-numer), 'Pow', str(denom), '-1']
+                    return ['Mul', str(numer), 'Pow', str(denom), '-1']
+                return [str(expr)]
+            else:
+                # Include the head of the expression (Add, Mul, Pow, ...) as a string
+                nodes = [expr.func.__name__]
+                for arg in expr.args:
+                    # Recursively get the tree for each argument
+                    nodes.append(self.get_expr_tree(arg))
+                return nodes
+        else:
+            return [str(expr)]  # Base case for plain Python numbers
+
+    def flatten_expr_tree(self, tree):
+        """Flatten a nested expression tree into a flat list of labels."""
+        flat_list = []
+        for item in tree:
+            if isinstance(item, list):
+                flat_list.extend(self.flatten_expr_tree(item))
+            else:
+                flat_list.append(item)
+        return flat_list
+
+    def convert_to_vector(self, expr_list):
+        """Map a flat list of node labels to integer feature codes."""
+        vector = []
+        for item in expr_list:
+            if item in self.feature_dict:
+                vector.append(self.feature_dict[item])
+            else:
+                try:
+                    # Map integer literals directly to their values
+                    vector.append(int(item))
+                except ValueError:
+                    raise ValueError(f"Unrecognized symbol: {item}")
+        return vector
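+    # Worked example of the encoding (a sketch for orientation, not used by
+    # the environment itself): with the codes built in _setup (a -> -2,
+    # b -> -3, Mul -> -6, PAD -> -8), the guess a*b yields the prefix tree
+    # ['Mul', 'a', 'b'], which convert_to_vector maps to [-6, -2, -3] and
+    # to_vec pads to [-6, -2, -3, -8, -8, -8, -8, -8, -8, -8].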
+    def from_vec(self, vector):
+        """Invert to_vec: rebuild a SymPy expression from a feature vector.
+
+        The flattened tree is in prefix order, so the vector is walked from
+        the right, pushing operands and combining them whenever an operator
+        is reached. Assumes every Add/Mul/Pow node is binary.
+        """
+        # Remove padding and convert codes back to node labels
+        tokens = [
+            self.reverse_feature_dict.get(v, str(int(v)))
+            for v in vector
+            if v != self.feature_dict['PAD']
+        ]
+
+        expr_stack = []
+        for token in reversed(tokens):
+            if token == 'Add':
+                left = expr_stack.pop()
+                right = expr_stack.pop()
+                expr_stack.append(left + right)
+            elif token == 'Mul':
+                left = expr_stack.pop()
+                right = expr_stack.pop()
+                expr_stack.append(left * right)
+            elif token == 'Pow':
+                base = expr_stack.pop()
+                exponent = expr_stack.pop()
+                expr_stack.append(base ** exponent)
+            else:
+                # sympify turns integer literals into Integers and names into Symbols
+                expr_stack.append(sympify(token))
+
+        return expr_stack[0]
+
+    def reset(self, seed=None, options=None):
+        super().reset(seed=seed)
+        self.state_sympy = Integer(0)
+        self.state_vec = self.to_vec(self.state_sympy)
+        return self.state_vec, {}
+
+    def get_complexity_of_expression(self, expr: Expr) -> int:
+        """
+        Calculate the complexity of a SymPy expression as the number of
+        non-padded elements in its vector encoding.
+
+        Parameters:
+            expr (Expr): The SymPy expression whose complexity is to be calculated.
+
+        Returns:
+            int: The complexity of the expression.
+        """
+        vec = self.to_vec(expr)
+        return len([x for x in vec if x != self.feature_dict['PAD']])
+
+    def get_complexity_of_guess(self, guess: Expr) -> int:
+        """Complexity of the main equation after substituting the guess for x."""
+        substituted_eq = self.eqn_main.subs(symbols('x'), guess)
+        simplified_eq = simplify(substituted_eq)
+        return self.get_complexity_of_expression(simplified_eq)
+
+    def find_reward(self, state_old: Expr, state_new: Expr, is_solved) -> float:
+        """
+        Calculate the reward based on the reduction in complexity of the
+        equation with the current guess substituted.
+
+        Parameters:
+            state_old (Expr): The old state of the current guess.
+            state_new (Expr): The new state of the current guess after taking an action.
+            is_solved (bool): Whether the new state solves the equation.
+
+        Returns:
+            float: The solved bonus if is_solved, otherwise the reduction in
+            complexity of the equation with the guess substituted.
+        """
+        if is_solved:
+            reward = self.reward_when_solved
+        else:
+            old_complexity = self.get_complexity_of_guess(state_old)
+            new_complexity = self.get_complexity_of_guess(state_new)
+            reward = old_complexity - new_complexity
+        return reward
+
+    def to_vec(self, state_sympy: Expr) -> np.ndarray:
+        """
+        Convert a SymPy expression to a numerical vector using the feature dictionary.
+
+        Parameters:
+            state_sympy (Expr): The SymPy expression to convert.
+
+        Returns:
+            np.ndarray: The numerical vector representation of the expression,
+            padded (or truncated) to max_length.
+        """
+        expr_tree = self.get_expr_tree(state_sympy)
+        flat_nodes = self.flatten_expr_tree(expr_tree)
+        vector = self.convert_to_vector(flat_nodes)
+
+        # Pad the vector to max_length; anything longer is truncated
+        vector += [self.feature_dict['PAD']] * (self.max_length - len(vector))
+        return np.array(vector[:self.max_length], dtype=np.float32)
+
+    def is_solved(self, state_sympy):
+        """The env is solved when the current guess zeroes the main equation."""
+        eqn_after_sub = self.eqn_main.subs(symbols('x'), state_sympy)
+        return eqn_after_sub == 0
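+    # For example (a sketch, not used by the environment itself): the exact
+    # solution x = -b/a is reachable from the initial guess 0 via the actions
+    # (sub, b) and then (truediv, a); substituting it gives a*(-b/a) + b == 0,
+    # so is_solved returns True and step pays reward_when_solved. A non-solving
+    # move is scored by complexity instead, e.g. the guess a gives a**2 + b
+    # (complexity 5) while the guess 0 gives b (complexity 1).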
+ """ + padding_int = self.feature_dict['PAD'] + return next_state_as_vector[-1] != padding_int + + def step(self, action: int): + # Take step + operation, term = self.actions[action] + state_sympy = self.state_sympy + next_state_sympy = simplify(operation(state_sympy, term)) + next_state_as_vector = self.to_vec(next_state_sympy) + + is_solved = self.is_solved(next_state_sympy) + too_long = self.is_too_long(next_state_as_vector) + + reward = self.find_reward(state_sympy, next_state_sympy, is_solved) + + # Check terminal condition + terminated = bool(is_solved or too_long) + truncated = False + info = {} + + # Update + self.state_sympy = next_state_sympy + + return next_state_as_vector, reward, terminated, truncated, info + + def render(self, mode: str = "human"): + state = self.state_sympy + print(f'x_guess = {state}') + diff --git a/rl_equation_solver/dev_kev/linear_train.py b/rl_equation_solver/dev_kev/linear_train.py new file mode 100644 index 0000000..e002bc3 --- /dev/null +++ b/rl_equation_solver/dev_kev/linear_train.py @@ -0,0 +1,94 @@ +import os +import gymnasium as gym +from stable_baselines3 import A2C +from stable_baselines3.common.env_checker import check_env +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.vec_env import DummyVecEnv +from stable_baselines3.common.callbacks import BaseCallback +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd + +# Import the custom environment +from env_linear import Env + +class RewardLoggingCallback(BaseCallback): + def __init__(self, log_interval: int = 100, save_dir: str = '.', verbose: int = 1): + super(RewardLoggingCallback, self).__init__(verbose) + self.log_interval = log_interval + self.rewards = [] + self.save_dir = save_dir + + def _on_step(self) -> bool: + reward = self.locals["rewards"][0] + self.rewards.append(reward) + return True + + def _on_training_end(self) -> None: + np.save(os.path.join(self.save_dir, 'rewards.npy'), np.array(self.rewards)) + +def plot_rewards(rewards, save_dir, window=100): + rewards_df = pd.DataFrame(rewards, columns=['reward']) + rolling_avg = rewards_df['reward'].rolling(window=window).mean() + + plt.figure(figsize=(12, 6)) + plt.plot(rewards, 'b.', alpha=0.3, label='Rewards') + plt.plot(rolling_avg, 'r-', linewidth=2, label=f'Rolling Average (window={window})') + plt.xlabel('Timesteps') + plt.ylabel('Reward') + plt.title('Reward over Time') + plt.legend() + plt.grid(True) + plt.savefig(os.path.join(save_dir, 'reward_plot.png')) + plt.show() + +def main(): + # Parameters + Ntrain = 10**4 + + save_dir = 'data_linear' + os.makedirs(save_dir, exist_ok=True) + + # Create the environment + env = Env() + + # Wrap the environment with Monitor wrapper + env = Monitor(env, filename=os.path.join(save_dir, 'monitor.csv')) + + # Check the environment + check_env(env) + + # Custom callback + reward_logging_callback = RewardLoggingCallback(save_dir=save_dir) + + # Create the A2C model + model = A2C("MlpPolicy", DummyVecEnv([lambda: env]), verbose=1) + + # Train the model + model.learn(total_timesteps=Ntrain, callback=reward_logging_callback) + + # Save the model + model.save(os.path.join(save_dir, "a2c_solver")) + + # Load rewards and plot them + rewards = np.load(os.path.join(save_dir, 'rewards.npy')) + plot_rewards(rewards, save_dir, window=int(0.1*Ntrain)) + + # Evaluate the model + eval_env = DummyVecEnv([lambda: Monitor(Env(), filename=os.path.join(save_dir, 'eval_monitor.csv'))]) + 
+def main():
+    # Parameters
+    Ntrain = 10**4
+
+    save_dir = 'data_linear'
+    os.makedirs(save_dir, exist_ok=True)
+
+    # Create the environment
+    env = Env()
+
+    # Wrap the environment with the Monitor wrapper
+    env = Monitor(env, filename=os.path.join(save_dir, 'monitor.csv'))
+
+    # Check the environment
+    check_env(env)
+
+    # Custom callback
+    reward_logging_callback = RewardLoggingCallback(save_dir=save_dir)
+
+    # Create the A2C model
+    model = A2C("MlpPolicy", DummyVecEnv([lambda: env]), verbose=1)
+
+    # Train the model
+    model.learn(total_timesteps=Ntrain, callback=reward_logging_callback)
+
+    # Save the model
+    model.save(os.path.join(save_dir, "a2c_solver"))
+
+    # Load rewards and plot them
+    rewards = np.load(os.path.join(save_dir, 'rewards.npy'))
+    plot_rewards(rewards, save_dir, window=int(0.1 * Ntrain))
+
+    # Evaluate the model
+    eval_env = DummyVecEnv([lambda: Monitor(Env(), filename=os.path.join(save_dir, 'eval_monitor.csv'))])
+    #mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
+    #print(f"Mean reward: {mean_reward} +/- {std_reward}")
+
+    # Run the trained model
+    obs, _ = env.reset()
+    for i in range(100):
+        action, _states = model.predict(obs, deterministic=True)
+        obs, reward, done, truncated, info = env.step(action)
+        env.render()
+        if done or truncated:
+            print("Episode finished")
+            break
+
+
+if __name__ == "__main__":
+    main()
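+# Reloading sketch (assumes data_linear/a2c_solver.zip from a previous run;
+# not executed here):
+#   model = A2C.load(os.path.join('data_linear', 'a2c_solver'))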