Skip to content

Commit

Permalink
enh: add normaliztion and standardization to workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
maffettone committed Oct 3, 2024
1 parent 24c9512 commit 1b09650
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions pdf_agents/scientific_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from botorch.acquisition import UpperConfidenceBound, qUpperConfidenceBound
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf # noqa: F401
from botorch.utils.transforms import normalize, standardize, unnormalize
from gpytorch.mlls import ExactMarginalLogLikelihood
from scipy.spatial import distance_matrix

Expand Down Expand Up @@ -132,7 +133,9 @@ def ask(self, batch_size: int = 1):
train_x = torch.tensor(self.independent_cache, dtype=torch.double, device=self.device)
if train_x.dim() == 1:
train_x = train_x.view(-1, 1)
train_y = torch.tensor(value, dtype=torch.double, device=self.device)
norm_bounds = torch.stack([train_x.min(dim=0).values, train_x.max(dim=0).values])
train_x = normalize(train_x, norm_bounds)
train_y = standardize(torch.tensor(value, dtype=torch.double, device=self.device))
gp = SingleTaskGP(train_x, train_y).to(self.device)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp).to(self.device)
fit_gpytorch_mll(mll)
Expand All @@ -146,12 +149,15 @@ def ask(self, batch_size: int = 1):
# candidates, acq_value = optimize_acqf(
# acq, bounds=self.bounds, q=batch_size, num_restarts=self.num_restarts, raw_samples=self.raw_samples
# )
grid = torch.tensor(make_wafer_grid_list(*self.bounds.cpu().numpy().ravel(), step=self.motor_resolution))[
:, None, :
]
grid = normalize(
torch.tensor(make_wafer_grid_list(*self.bounds.cpu().numpy().ravel(), step=self.motor_resolution))[
:, None, :
],
norm_bounds,
)
acq_grid = acq(grid)
top_indicies = torch.argsort(acq_grid, descending=True, dim=0)[:batch_size]
candidates = grid[top_indicies].squeeze(1)
candidates = unnormalize(grid, norm_bounds)[top_indicies].squeeze(1)
acq_value = acq_grid[top_indicies]

if batch_size == 1:
Expand Down

0 comments on commit 1b09650

Please sign in to comment.