replace torch.Tensor with torch.empty (NVIDIA#1578)
* replace torch.Tensor with torch.empty

* nit

* nit
NouamaneTazi authored Feb 13, 2023
1 parent 0145d69 commit ba027dd
Showing 9 changed files with 43 additions and 43 deletions.
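
For context: both torch.Tensor(*sizes) (the legacy constructor this commit removes) and torch.empty(*sizes) (the documented factory function it switches to) allocate an uninitialized tensor of the given shape, so modules built this way still rely on reset_parameters() to fill in real values. Below is a minimal, hypothetical sketch (TinyLinear is not part of apex) of the pattern these diffs follow.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLinear(nn.Module):
    # Hypothetical module illustrating the torch.Tensor -> torch.empty migration.
    def __init__(self, in_features, out_features):
        super().__init__()
        # Old style: nn.Parameter(torch.Tensor(out_features, in_features))
        # New style: nn.Parameter(torch.empty(out_features, in_features))
        # Both allocate uninitialized storage of the given shape.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features))
        self.reset_parameters()  # contents are garbage until initialized here

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

One caveat worth noting: torch.empty interprets its positional arguments as sizes, while torch.Tensor called with a Python list constructs a tensor from that data, so the replacement is only a drop-in when the arguments are dimensions.
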
10 changes: 5 additions & 5 deletions apex/RNN/RNNBackend.py
@@ -254,17 +254,17 @@ def __init__(self, gate_multiplier, input_size, hidden_size, cell, n_hidden_stat
self.gate_size = gate_multiplier * self.hidden_size
self.n_hidden_states = n_hidden_states

- self.w_ih = nn.Parameter(torch.Tensor(self.gate_size, self.input_size))
- self.w_hh = nn.Parameter(torch.Tensor(self.gate_size, self.output_size))
+ self.w_ih = nn.Parameter(torch.empty(self.gate_size, self.input_size))
+ self.w_hh = nn.Parameter(torch.empty(self.gate_size, self.output_size))

#Check if there's recurrent projection
if(self.output_size != self.hidden_size):
- self.w_ho = nn.Parameter(torch.Tensor(self.output_size, self.hidden_size))
+ self.w_ho = nn.Parameter(torch.empty(self.output_size, self.hidden_size))

self.b_ih = self.b_hh = None
if self.bias:
- self.b_ih = nn.Parameter(torch.Tensor(self.gate_size))
- self.b_hh = nn.Parameter(torch.Tensor(self.gate_size))
+ self.b_ih = nn.Parameter(torch.empty(self.gate_size))
+ self.b_hh = nn.Parameter(torch.empty(self.gate_size))

#hidden states for forward
self.hidden = [ None for states in range(self.n_hidden_states)]
4 changes: 2 additions & 2 deletions apex/RNN/cells.py
@@ -18,8 +18,8 @@ def __init__(self, input_size, hidden_size, bias = False, output_size = None):
gate_multiplier = 4
super(mLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, mLSTMCell, n_hidden_states = 2, bias = bias, output_size = output_size)

- self.w_mih = nn.Parameter(torch.Tensor(self.output_size, self.input_size))
- self.w_mhh = nn.Parameter(torch.Tensor(self.output_size, self.output_size))
+ self.w_mih = nn.Parameter(torch.empty(self.output_size, self.input_size))
+ self.w_mhh = nn.Parameter(torch.empty(self.output_size, self.output_size))

self.reset_parameters()

4 changes: 2 additions & 2 deletions apex/amp/compat.py
@@ -6,12 +6,12 @@ def variable_is_tensor():
return isinstance(v, torch.Tensor)

def tensor_is_variable():
- x = torch.Tensor()
+ x = torch.empty()
return type(x) == torch.autograd.Variable

# False for post-0.4
def tensor_is_float_tensor():
- x = torch.Tensor()
+ x = torch.empty()
return type(x) == torch.FloatTensor

# Akin to `torch.is_tensor`, but returns True for Variable
4 changes: 2 additions & 2 deletions apex/contrib/layer_norm/layer_norm.py
@@ -41,8 +41,8 @@ class FastLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5):
super().__init__()
self.epsilon = eps
- self.weight = torch.nn.Parameter(torch.Tensor(hidden_size))
- self.bias = torch.nn.Parameter(torch.Tensor(hidden_size))
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size))
+ self.bias = torch.nn.Parameter(torch.empty(hidden_size))
self.reset_parameters()

def reset_parameters(self):
16 changes: 8 additions & 8 deletions apex/contrib/multihead_attn/encdec_multihead_attn.py
@@ -36,14 +36,14 @@ def __init__(self, embed_dim, num_heads, dropout=0.0, bias=False, include_norm_a
self.impl = impl
self.scaling = self.head_dim ** -0.5

- self.in_proj_weight_q = Parameter(torch.Tensor(embed_dim, embed_dim))
- self.in_proj_weight_kv = Parameter(torch.Tensor(2 * embed_dim, embed_dim))
- self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+ self.in_proj_weight_q = Parameter(torch.empty(embed_dim, embed_dim))
+ self.in_proj_weight_kv = Parameter(torch.empty(2 * embed_dim, embed_dim))
+ self.out_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
if self.bias:
assert impl != "fast", "ERROR! The Fast implementation does not support biases!"
- self.in_proj_bias_q = Parameter(torch.Tensor(embed_dim))
- self.in_proj_bias_kv = Parameter(torch.Tensor(2 * embed_dim))
- self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
+ self.in_proj_bias_q = Parameter(torch.empty(embed_dim))
+ self.in_proj_bias_kv = Parameter(torch.empty(2 * embed_dim))
+ self.out_proj_bias = Parameter(torch.empty(embed_dim))
else:
self.register_parameter("in_proj_bias_q", None)
self.register_parameter("in_proj_bias_kv", None)
@@ -52,8 +52,8 @@ def __init__(self, embed_dim, num_heads, dropout=0.0, bias=False, include_norm_a
self.out_proj_bias = None
if self.include_norm_add:
if impl == "fast":
- self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
- self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
+ self.lyr_nrm_gamma_weights = Parameter(torch.empty(embed_dim))
+ self.lyr_nrm_beta_weights = Parameter(torch.empty(embed_dim))
self.lyr_nrm = None
else:
self.register_parameter("lyr_norm_gamma_weights", None)
24 changes: 12 additions & 12 deletions apex/contrib/multihead_attn/self_multihead_attn.py
@@ -53,20 +53,20 @@ def __init__(
impl == "fast" and bias
), "additive mask not supported for fast mode without bias"
if separate_qkv_params:
- self.q_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
- self.k_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
- self.v_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+ self.q_weight = Parameter(torch.empty(embed_dim, embed_dim))
+ self.k_weight = Parameter(torch.empty(embed_dim, embed_dim))
+ self.v_weight = Parameter(torch.empty(embed_dim, embed_dim))
else:
- self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
- self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
+ self.out_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
if self.bias:
if separate_qkv_params:
- self.q_bias = Parameter(torch.Tensor(embed_dim))
- self.k_bias = Parameter(torch.Tensor(embed_dim))
- self.v_bias = Parameter(torch.Tensor(embed_dim))
+ self.q_bias = Parameter(torch.empty(embed_dim))
+ self.k_bias = Parameter(torch.empty(embed_dim))
+ self.v_bias = Parameter(torch.empty(embed_dim))
else:
- self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
- self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
+ self.out_proj_bias = Parameter(torch.empty(embed_dim))
else:
if separate_qkv_params:
self.register_parameter("q_bias", None)
@@ -82,8 +82,8 @@ def __init__(
self.out_proj_bias = None
if self.include_norm_add:
if impl == "fast":
- self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
- self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
+ self.lyr_nrm_gamma_weights = Parameter(torch.empty(embed_dim))
+ self.lyr_nrm_beta_weights = Parameter(torch.empty(embed_dim))
self.lyr_nrm = None
else:
self.register_parameter("lyr_norm_gamma_weights", None)
6 changes: 3 additions & 3 deletions apex/contrib/sparsity/sparse_masklib.py
@@ -29,7 +29,7 @@ def compute_valid_1d_patterns(m,n):
if m==4 and n==2 and valid_m4n2_1d_patterns is not None: return valid_m4n2_1d_patterns
patterns = torch.zeros(m)
patterns[:n] = 1
- valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist()))))
+ valid_patterns = torch.empty(list(set(permutations(patterns.tolist()))))
if m == 4 and n == 2: valid_m4n2_1d_patterns = valid_patterns
return valid_patterns

@@ -109,10 +109,10 @@ def compute_valid_2d_patterns(m,n):
patterns[:n] = 1
patterns = list(set(permutations(patterns.tolist())))
patterns = patterns + patterns
- patterns = torch.Tensor(list(set(permutations(patterns,m))))
+ patterns = torch.empty(list(set(permutations(patterns,m))))

valid = ((patterns.sum(dim=1) <= n).sum(dim=1) == m).nonzero().view(-1)
- valid_patterns = torch.Tensor(valid.shape[0],m,m)
+ valid_patterns = torch.empty(valid.shape[0],m,m)
valid_patterns[:] = patterns[valid[:]]

if m == 4 and n == 2: valid_m4n2_2d_patterns = valid_patterns
12 changes: 6 additions & 6 deletions apex/fused_dense/fused_dense.py
@@ -66,9 +66,9 @@ def __init__(self, in_features, out_features, bias=True):
super(FusedDense, self).__init__()
self.in_features = in_features
self.out_features = out_features
- self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
+ self.weight = nn.Parameter(torch.empty(out_features, in_features))
if bias:
- self.bias = nn.Parameter(torch.Tensor(out_features))
+ self.bias = nn.Parameter(torch.empty(out_features))
else:
#assert False, "no-bias option not added yet"
self.register_parameter('bias', None)
@@ -86,10 +86,10 @@ def __init__(self, in_features, intermediate_features, out_features, bias=True):
self.in_features = in_features
self.intermediate_features = intermediate_features
self.out_features = out_features
- self.weight1 = nn.Parameter(torch.Tensor(intermediate_features, in_features))
- self.bias1 = nn.Parameter(torch.Tensor(intermediate_features))
- self.weight2 = nn.Parameter(torch.Tensor(out_features, intermediate_features))
- self.bias2 = nn.Parameter(torch.Tensor(out_features))
+ self.weight1 = nn.Parameter(torch.empty(intermediate_features, in_features))
+ self.bias1 = nn.Parameter(torch.empty(intermediate_features))
+ self.weight2 = nn.Parameter(torch.empty(out_features, intermediate_features))
+ self.bias2 = nn.Parameter(torch.empty(out_features))

def forward(self, input):
return _fused_dense_gelu_dense(input, self.weight1, self.bias1, self.weight2, self.bias2)
6 changes: 3 additions & 3 deletions apex/normalization/fused_layer_norm.py
@@ -273,8 +273,8 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
- self.weight = Parameter(torch.Tensor(*normalized_shape))
- self.bias = Parameter(torch.Tensor(*normalized_shape))
+ self.weight = Parameter(torch.empty(*normalized_shape))
+ self.bias = Parameter(torch.empty(*normalized_shape))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
@@ -369,7 +369,7 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
- self.weight = Parameter(torch.Tensor(*normalized_shape))
+ self.weight = Parameter(torch.empty(*normalized_shape))
else:
self.register_parameter("weight", None)
self.reset_parameters()
