[BUG] Fix gradient_boosting.py
xuyxu committed Nov 11, 2020
1 parent 2c850a9 commit 474d50f
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions torchensemble/gradient_boosting.py
@@ -34,16 +34,15 @@ def __init__(self, estimator, n_estimators, output_dim,
         # Base estimators
         self.estimators_ = nn.ModuleList()
         for _ in range(self.n_estimators):
-            self.estimators_.append(
-                estimator(output_dim=output_dim).to(self.device))
+            self.estimators_.append(estimator().to(self.device))
 
     def forward(self, X):
         batch_size = X.size()[0]
         y_pred = torch.zeros(batch_size, self.output_dim).to(self.device)
 
         # The output of `GradientBoostingClassifier` is the summation of output
-        # from all base estimators, with each of them multipled by the shrinkage
-        # rate.
+        # from all base estimators, with each of them multiplied by the
+        # shrinkage rate.
         for estimator in self.estimators_:
             y_pred += self.shrinkage_rate * estimator(X)
 
@@ -189,8 +188,7 @@ def __init__(self, estimator, n_estimators, output_dim,
         # Base estimators
         self.estimators_ = nn.ModuleList()
         for _ in range(self.n_estimators):
-            self.estimators_.append(
-                estimator(output_dim=output_dim).to(self.device))
+            self.estimators_.append(estimator().to(self.device))
 
     def forward(self, X):
         batch_size = X.size()[0]
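For context, here is a minimal sketch (not torchensemble's actual API) of the behavior this commit touches: base estimators are now constructed with no arguments, so each estimator class is expected to fix its own output dimension, and forward() accumulates every estimator's output scaled by the shrinkage rate. The class name TinyEstimator and its layer sizes are hypothetical.

    import torch
    import torch.nn as nn


    class TinyEstimator(nn.Module):
        """Hypothetical base estimator that fixes its own output dimension."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(8, 3)  # output_dim = 3 is baked into the estimator

        def forward(self, X):
            return self.linear(X)


    n_estimators = 5
    shrinkage_rate = 0.1

    # After this commit, estimators are created with no constructor arguments.
    estimators_ = nn.ModuleList([TinyEstimator() for _ in range(n_estimators)])

    X = torch.randn(4, 8)                 # batch of 4 samples
    y_pred = torch.zeros(4, 3)            # (batch_size, output_dim)
    for estimator in estimators_:
        # Each base estimator's output is scaled by the shrinkage rate and summed,
        # matching the comment fixed in the diff above.
        y_pred += shrinkage_rate * estimator(X)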
