Skip to content

Commit

Permalink
use device arg
Browse files Browse the repository at this point in the history
function has a arg for device which defaults to 'cuda'
  • Loading branch information
hiramf authored Mar 25, 2024
1 parent bc75c9d commit 5f283be
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _mask_predict_by_feat_single(self, mask_feat, kernels, priors):
batch_size = priors.shape[0]
hw = mask_feat.size()[-2:]
coord = self.prior_generator.single_level_grid_priors(
hw, level_idx=0).to(mask_feat.device)
hw, level_idx=0, device=mask_feat.device)
coord = coord.unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1, 1)
priors = priors.unsqueeze(2)
points = priors[..., :2]
Expand Down

0 comments on commit 5f283be

Please sign in to comment.