Commit b3c818a (1 parent: 5952968), committed by kev on Jun 13, 2024. Showing 2 changed files with 315 additions and 0 deletions.
@@ -0,0 +1,221 @@
import logging
from operator import add, sub, mul, truediv
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from sympy import symbols, simplify, expand, Expr, pretty, Basic, Integer, Rational

logger = logging.getLogger(__name__)


class Env(gym.Env):
    """Environment for solving algebraic equations using RL."""

    metadata = {"render_modes": ["human"]}

    def __init__(self) -> None:
        super().__init__()
        self._setup()
        self.state_sympy = Integer(0)  # current guess for x; reset() also starts from 0
        self.max_length = 10
        self.reward_when_solved = 100

        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(self.max_length,), dtype=np.float32
        )

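    # How observations are built (see the helpers below): a SymPy expression is
    # flattened to a pre-order token list, each token is mapped to an integer via
    # feature_dict (symbols and operator names get codes <= -2 so they never collide
    # with the non-negative integer literals, which keep their face value), and the
    # vector is padded with the PAD code up to max_length.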
    def _setup(self):
        # Define all the terms and operations you need
        operations = [add, sub, mul, truediv]
        a, b, x = symbols("a b x")
        terms = [a, b]

        # Define main equation to solve
        self.eqn_main = a * x + b

        # Create feature dictionary: this will be used to cast the state representation as a vector
        unique_symbols = ['a', 'b', 'Add', 'Pow', 'Mul', '-1']
        self.feature_dict = {symbol: -(i + 2) for i, symbol in enumerate(unique_symbols)}

        # Add PAD as the largest negative number
        self.feature_dict['PAD'] = -(len(self.feature_dict) + 2)

        # Reverse feature dictionary for the from_vec method
        self.reverse_feature_dict = {v: k for k, v in self.feature_dict.items()}

        # Guard against dividing by zero (currently a no-op, since 0 is not among the terms)
        illegal_actions = [(truediv, symbols("0"))]
        self.actions = [
            [op, term]
            for op in operations
            for term in terms
            if (op, term) not in illegal_actions
        ]

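    # With the operations and terms defined above, the discrete action space has
    # 8 actions: guess + a, guess + b, guess - a, guess - b,
    # guess * a, guess * b, guess / a, guess / b.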
    def get_expr_tree(self, expr):
        if isinstance(expr, Basic):
            if expr.is_Atom:
                # Handle negative integers by breaking them down into -1 * positive part
                if isinstance(expr, Integer) and expr < 0:
                    return ['Mul', '-1', str(-expr)]
                # Handle fractions by breaking them down into products of powers
                elif expr.is_Rational and not expr.is_Integer:
                    numer, denom = expr.as_numer_denom()
                    return ['Mul', str(numer), 'Pow', str(denom), '-1']
                return [str(expr)]
            else:
                # Include the function of the expression as a string
                nodes = [expr.func.__name__]
                for arg in expr.args:
                    nodes.append(self.get_expr_tree(arg))  # Recursively get the tree for each argument
                return nodes
        else:
            return [str(expr)]  # Base case for non-SymPy objects (plain Python numbers)

    def flatten_expr_tree(self, tree):
        flat_list = []
        for item in tree:
            if isinstance(item, list):
                flat_list.extend(self.flatten_expr_tree(item))
            else:
                flat_list.append(item)
        return flat_list

    def convert_to_vector(self, expr_list):
        vector = []
        for item in expr_list:
            if item in self.feature_dict:
                vector.append(self.feature_dict[item])
            else:
                try:
                    # Map integers directly to their values
                    vector.append(int(item))
                except ValueError:
                    raise ValueError(f"Unrecognized symbol: {item}")
        return vector

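    # Worked example for the three helpers above, for the guess a + b:
    #   get_expr_tree(a + b)   -> ['Add', ['a'], ['b']]
    #   flatten_expr_tree(...) -> ['Add', 'a', 'b']
    #   convert_to_vector(...) -> [-4, -2, -3]   (codes from feature_dict)
    # to_vec() below then pads the result with the PAD code up to max_length.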
    def from_vec(self, vector):
        # Remove padding
        vector = [v for v in vector if v != self.feature_dict['PAD']]

        # Convert back to the (pre-order) token list
        expr_list = [self.reverse_feature_dict[v] if v in self.reverse_feature_dict else str(v) for v in vector]

        # Rebuild the SymPy expression from the prefix token list by scanning it
        # right to left. Note: this assumes every Add/Mul/Pow node has exactly two
        # arguments, which does not hold for arbitrary SymPy expressions.
        expr_stack = []
        for token in reversed(expr_list):
            if token in ('Add', 'Mul', 'Pow'):
                left = expr_stack.pop()
                right = expr_stack.pop()
                if token == 'Add':
                    expr_stack.append(left + right)
                elif token == 'Mul':
                    expr_stack.append(left * right)
                else:
                    expr_stack.append(left ** right)
            else:
                try:
                    # Integer literals were stored by value
                    expr_stack.append(Integer(int(token)))
                except ValueError:
                    # Everything else is a symbol such as a, b or x
                    expr_stack.append(symbols(token))

        return expr_stack[0]

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.state_sympy = Integer(0)  # start every episode from the guess x = 0
        self.state_vec = self.to_vec(self.state_sympy)
        return self.state_vec, {}

    def get_complexity_of_expression(self, expr: Expr) -> int:
        """
        Calculate the complexity of a SymPy expression based on the number of non-padded elements.

        Parameters:
            expr (Expr): The SymPy expression whose complexity is to be calculated.

        Returns:
            int: The complexity of the expression.
        """
        vec = self.to_vec(expr)
        return len([x for x in vec if x != self.feature_dict['PAD']])

    def get_complexity_of_guess(self, guess: Expr) -> int:
        substituted_eq = self.eqn_main.subs(symbols('x'), guess)
        simplified_eq = simplify(substituted_eq)
        return self.get_complexity_of_expression(simplified_eq)

    def find_reward(self, state_old: Expr, state_new: Expr, is_solved) -> float:
        """
        Calculate the reward based on the reduction in complexity of the equation
        with the current guess substituted.

        Parameters:
            state_old (Expr): The old state of the current guess.
            state_new (Expr): The new state of the current guess after taking an action.
            is_solved (bool): Whether the new guess solves the equation.

        Returns:
            float: The reward, calculated as the reduction in complexity of the equation
            with the guess substituted, or reward_when_solved if the equation is solved.
        """
        if is_solved:
            reward = self.reward_when_solved
        else:
            old_complexity = self.get_complexity_of_guess(state_old)
            new_complexity = self.get_complexity_of_guess(state_new)
            reward = old_complexity - new_complexity
        return reward

    def to_vec(self, state_sympy: Expr) -> np.ndarray:
        """
        Convert a SymPy expression to a numerical vector using the defined feature dictionary.

        Parameters:
            state_sympy (Expr): The SymPy expression to convert.

        Returns:
            np.ndarray: The numerical vector representation of the expression.
        """
        expr_tree = self.get_expr_tree(state_sympy)
        flat_nodes = self.flatten_expr_tree(expr_tree)
        vector = self.convert_to_vector(flat_nodes)

        # Pad the vector to max_length (anything longer is truncated by the slice below)
        vector += [self.feature_dict['PAD']] * (self.max_length - len(vector))

        return np.array(vector[:self.max_length], dtype=np.float32)

    def is_solved(self, state_sympy):
        """Env is solved when the main equation evaluates to 0 after substituting the current guess for x."""
        eqn_after_sub = self.eqn_main.subs(symbols('x'), state_sympy)
        is_solved = eqn_after_sub == 0
        return is_solved

    def is_too_long(self, next_state_as_vector):
        """Checks if the current guess / solution is too long to fit in the observation."""
        padding_int = self.feature_dict['PAD']
        return next_state_as_vector[-1] != padding_int

    def step(self, action: int):
        # Apply the chosen (operation, term) pair to the current guess
        operation, term = self.actions[action]
        state_sympy = self.state_sympy
        next_state_sympy = simplify(operation(state_sympy, term))
        next_state_as_vector = self.to_vec(next_state_sympy)

        is_solved = self.is_solved(next_state_sympy)
        too_long = self.is_too_long(next_state_as_vector)

        reward = self.find_reward(state_sympy, next_state_sympy, is_solved)

        # Episode ends when the equation is solved or the guess no longer fits in the observation
        terminated = bool(is_solved or too_long)
        truncated = False
        info = {}

        # Update
        self.state_sympy = next_state_sympy

        return next_state_as_vector, reward, terminated, truncated, info

    def render(self, mode: str = "human"):
        state = self.state_sympy
        print(f'x_guess = {state}')

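For reference, a minimal manual rollout of the environment above (a sketch only, assuming the file is importable as env_linear, which is the name the training script below uses):

# Manual interaction sketch (not part of the commit)
from env_linear import Env

env = Env()
obs, info = env.reset()
for _ in range(5):
    action = env.action_space.sample()  # random (operation, term) pair
    obs, reward, terminated, truncated, info = env.step(action)
    env.render()  # prints the current guess for x
    if terminated or truncated:
        break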
@@ -0,0 +1,94 @@
import os
import gymnasium as gym
from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Import the custom environment
from env_linear import Env

class RewardLoggingCallback(BaseCallback):
    def __init__(self, log_interval: int = 100, save_dir: str = '.', verbose: int = 1):
        super().__init__(verbose)
        self.log_interval = log_interval
        self.rewards = []
        self.save_dir = save_dir

    def _on_step(self) -> bool:
        # Single-environment setup, so index 0 is the only reward in the vectorized batch
        reward = self.locals["rewards"][0]
        self.rewards.append(reward)
        return True

    def _on_training_end(self) -> None:
        np.save(os.path.join(self.save_dir, 'rewards.npy'), np.array(self.rewards))

def plot_rewards(rewards, save_dir, window=100):
    rewards_df = pd.DataFrame(rewards, columns=['reward'])
    rolling_avg = rewards_df['reward'].rolling(window=window).mean()

    plt.figure(figsize=(12, 6))
    plt.plot(rewards, 'b.', alpha=0.3, label='Rewards')
    plt.plot(rolling_avg, 'r-', linewidth=2, label=f'Rolling Average (window={window})')
    plt.xlabel('Timesteps')
    plt.ylabel('Reward')
    plt.title('Reward over Time')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, 'reward_plot.png'))
    plt.show()

def main():
    # Parameters
    Ntrain = 10**4

    save_dir = 'data_linear'
    os.makedirs(save_dir, exist_ok=True)

    # Create the environment
    env = Env()

    # Wrap the environment with the Monitor wrapper
    env = Monitor(env, filename=os.path.join(save_dir, 'monitor.csv'))

    # Check the environment
    check_env(env)

    # Custom callback
    reward_logging_callback = RewardLoggingCallback(save_dir=save_dir)

    # Create the A2C model
    model = A2C("MlpPolicy", DummyVecEnv([lambda: env]), verbose=1)

    # Train the model
    model.learn(total_timesteps=Ntrain, callback=reward_logging_callback)

    # Save the model
    model.save(os.path.join(save_dir, "a2c_solver"))

    # Load rewards and plot them
    rewards = np.load(os.path.join(save_dir, 'rewards.npy'))
    plot_rewards(rewards, save_dir, window=int(0.1 * Ntrain))

    # Evaluate the model
    eval_env = DummyVecEnv([lambda: Monitor(Env(), filename=os.path.join(save_dir, 'eval_monitor.csv'))])
    # mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10)
    # print(f"Mean reward: {mean_reward} +/- {std_reward}")

    # Run the trained model for one episode
    obs, _ = env.reset()
    for _ in range(100):
        action, _states = model.predict(obs, deterministic=True)
        # Cast to a plain int so it can index the discrete action list
        obs, reward, terminated, truncated, info = env.step(int(action))
        env.render()
        if terminated or truncated:
            print("Episode finished")
            break


if __name__ == "__main__":
    main()
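A possible follow-up (a sketch only, not part of the commit): the saved policy can be reloaded later without retraining, using the same path main() writes to.

# Reload the trained policy from disk (path matches save_dir and the model name above)
from stable_baselines3 import A2C
from env_linear import Env

model = A2C.load("data_linear/a2c_solver")
env = Env()
obs, _ = env.reset()
action, _ = model.predict(obs, deterministic=True)
print(action)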