Commit 0f3a19a
Change function according to huawei-noah/HEBO#61 (comment) but leave
the function `default_kern_rd` in the file as well.
Dimitri Rusin committed Nov 16, 2023
1 parent 124dc21 commit 0f3a19a
Showing 1 changed file with 12 additions and 9 deletions.

HEBO/hebo/models/gp/gp_util.py
@@ -10,14 +10,14 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from gpytorch.kernels import (AdditiveKernel, MaternKernel, ProductKernel,
-                              ScaleKernel)
-from gpytorch.priors import GammaPrior
-from torch import FloatTensor, LongTensor
-
-from ..layers import EmbTransform
-from ..util import get_random_graph
+from gpytorch.kernels import MaternKernel, ScaleKernel, ProductKernel
+from gpytorch.priors import GammaPrior
+from gpytorch.constraints.constraints import LessThan
+
+from ..layers import EmbTransform
 
 class DummyFeatureExtractor(nn.Module):
     def __init__(self, num_cont, num_enum, num_uniqs = None, emb_sizes = None):
@@ -43,7 +43,8 @@ def default_kern(x, xe, y, total_dim = None, ard_kernel = True, fe = None, max_x
         kerns = []
         if has_num:
             ard_num_dims = x.shape[1] if ard_kernel else None
-            kernel = MaternKernel(nu = 1.5, ard_num_dims = ard_num_dims, active_dims = torch.arange(x.shape[1]))
+            kernel = MaternKernel(nu = 1.5, ard_num_dims = ard_num_dims, active_dims = torch.arange(x.shape[1]),
+                                  lengthscale_constraint=LessThan(5))
             if ard_kernel:
                 lscales = kernel.lengthscale.detach().clone().view(1, -1)
                 for i in range(x.shape[1]):
@@ -52,19 +53,21 @@ def default_kern(x, xe, y, total_dim = None, ard_kernel = True, fe = None, max_x
                 kernel.lengthscale = lscales
             kerns.append(kernel)
         if has_enum:
-            kernel = MaternKernel(nu = 1.5, active_dims = torch.arange(x.shape[1], total_dim))
+            kernel = MaternKernel(nu = 1.5, active_dims = torch.arange(x.shape[1], total_dim),
+                                  lengthscale_constraint=LessThan(5))
             kerns.append(kernel)
         final_kern = ScaleKernel(ProductKernel(*kerns), outputscale_prior = GammaPrior(0.5, 0.5))
         final_kern.outputscale = y[torch.isfinite(y)].var()
         return final_kern
     else:
         if ard_kernel:
-            kernel = ScaleKernel(MaternKernel(nu = 1.5, ard_num_dims = total_dim))
+            kernel = ScaleKernel(MaternKernel(nu = 1.5, ard_num_dims = total_dim,
+                                              lengthscale_constraint=LessThan(5)))
         else:
             kernel = ScaleKernel(MaternKernel(nu = 1.5))
         kernel.outputscale = y[torch.isfinite(y)].var()
         return kernel
 
 def default_kern_rd(x, xe, y, total_dim = None, ard_kernel = True, fe = None, max_x = 1000, E=0.2):
     '''
     Get a default kernel with random decompositions. 0 <= E <= 1 specifies random tree connectivity.
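
Aside, not part of the commit: a minimal sketch of what the added
lengthscale_constraint=LessThan(5) does. GPyTorch stores an unconstrained
raw_lengthscale parameter and maps it through the constraint's transform, so
the effective lengthscale can approach but never exceed the bound, no matter
where the optimizer drives the raw value:

import torch
from gpytorch.kernels import MaternKernel
from gpytorch.constraints.constraints import LessThan

# A Matern-3/2 kernel with the same bound this commit adds in default_kern.
kernel = MaternKernel(nu = 1.5, ard_num_dims = 3,
                      lengthscale_constraint = LessThan(5.0))

# Drive the unconstrained raw parameter to a large value; the transformed
# lengthscale still stays strictly below the bound of 5.
with torch.no_grad():
    kernel.raw_lengthscale.fill_(20.0)
print(kernel.lengthscale)  # every ARD entry is just under 5

Capping the lengthscale keeps the Matern kernel from going nearly constant:
very long lengthscales make the GP insensitive to its inputs, which is
presumably the failure mode the bound from huawei-noah/HEBO#61 guards against.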