From eba9f4f83e1264473ae31307ee0e43a13aefd9eb Mon Sep 17 00:00:00 2001
From: steviestevepy
Date: Tue, 9 Feb 2021 15:10:28 +0000
Subject: [PATCH] Fixed for CUDA implementation

---
 examples/adaptive/main.py           |  9 +++++----
 pytorch_neat/adaptive_linear_net.py |  6 ++++--
 pytorch_neat/aggregations.py        | 18 ++++++++++++++++--
 3 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/examples/adaptive/main.py b/examples/adaptive/main.py
index fba6e04..3970e52 100644
--- a/examples/adaptive/main.py
+++ b/examples/adaptive/main.py
@@ -18,7 +18,7 @@
 import click
 import neat
 
-# import torch
+import torch
 import numpy as np
 
 from pytorch_neat import t_maze
@@ -29,6 +29,7 @@
 
 batch_size = 4
 DEBUG = True
+DEVICE = "cuda:0"
 
 
 def make_net(genome, config, _batch_size):
@@ -43,7 +44,7 @@
         batch_size=batch_size,
         activation=tanh_activation,
         output_activation=tanh_activation,
-        device="cpu",
+        device=DEVICE,
     )
 
 
@@ -52,13 +53,13 @@ def activate_net(net, states, debug=False, step_num=0):
         print("\n" + "=" * 20 + " DEBUG " + "=" * 20)
         print(net.delta_w_node)
         print("W init: ", net.input_to_output[0])
-    outputs = net.activate(states).numpy()
+    outputs = net.activate(states).numpy() if DEVICE == "cpu" else net.activate(states)
     if debug and (step_num - 1) % 100 == 0:
         print("\nStep {}".format(step_num - 1))
         print("Outputs: ", outputs[0])
         print("Delta W: ", net.delta_w[0])
         print("W: ", net.input_to_output[0])
-    return np.argmax(outputs, axis=1)
+    return np.argmax(outputs, axis=1) if DEVICE == "cpu" else torch.argmax(outputs, dim=1)
 
 
 @click.command()
diff --git a/pytorch_neat/adaptive_linear_net.py b/pytorch_neat/adaptive_linear_net.py
index 0d38203..442fed3 100644
--- a/pytorch_neat/adaptive_linear_net.py
+++ b/pytorch_neat/adaptive_linear_net.py
@@ -104,6 +104,8 @@ def activate(self, inputs):
                 inputs, dtype=torch.float32, device=self.device
             ).unsqueeze(2)
 
+            if self.device == "cuda:0": self.input_to_output = self.input_to_output.to(self.device)  # additional for CUDA
+
             outputs = self.activation(self.input_to_output.matmul(inputs))
 
             input_activs = inputs.transpose(1, 2).expand(
@@ -128,8 +130,8 @@
             )
             self.delta_w = delta_w
-
-            self.input_to_output[self.w_expressed] += delta_w[self.w_expressed]
+
+            self.input_to_output[self.w_expressed] += delta_w[self.w_expressed].to(self.device)
             clamp_weights_(
                 self.input_to_output, weight_threshold=0.0, weight_max=self.weight_max
             )
diff --git a/pytorch_neat/aggregations.py b/pytorch_neat/aggregations.py
index d9660d9..ca19ca3 100644
--- a/pytorch_neat/aggregations.py
+++ b/pytorch_neat/aggregations.py
@@ -17,11 +17,25 @@
 
 
 def sum_aggregation(inputs):
-    return sum(inputs)
+    validatedInputs = []
+    try:
+        for tens in inputs:
+            # Tensor.to() is not in-place; keep the copy it returns on the device
+            validatedInputs.append(tens.to("cuda:0"))
+    except Exception as e:
+        print(f"The following exception occurred: {str(e)}")
+    return sum(validatedInputs)
 
 
 def prod_aggregation(inputs):
-    return reduce(mul, inputs, 1)
+    validatedInputs = []
+    try:
+        for tens in inputs:
+            # Tensor.to() is not in-place; keep the copy it returns on the device
+            validatedInputs.append(tens.to("cuda:0"))
+    except Exception as e:
+        print(f"The following exception occurred: {str(e)}")
+    return reduce(mul, validatedInputs, 1)
 
 
 str_to_aggregation = {
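
Note (not part of the patch): the hunks above branch on DEVICE == "cpu" in activate_net and hard-code "cuda:0". A minimal sketch of a device-agnostic variant on the caller side is given below; DEVICE and to_actions are illustrative names introduced here, not identifiers from the repository, and it assumes net.activate(states) returns a torch tensor as in adaptive_linear_net.py.

    import numpy as np
    import torch

    # Pick the device once and fall back to CPU when CUDA is unavailable,
    # instead of hard-coding "cuda:0".
    DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

    def to_actions(outputs):
        # Detach, move to CPU, and convert to numpy so the same argmax call
        # works whether the network ran on CPU or GPU.
        return np.argmax(outputs.detach().cpu().numpy(), axis=1)

With this, activate_net could return to_actions(net.activate(states)) without the DEVICE == "cpu" conditionals.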