Commit b03ed91

mark todos
justheuristic committed Jan 12, 2024
1 parent 8598f78 · commit b03ed91
Showing 5 changed files with 5 additions and 8 deletions.
4 changes: 1 addition & 3 deletions notebooks/aq_simple.ipynb
@@ -40,7 +40,6 @@
"in_group_size = 8\n",
"batch_size = 16384\n",
"beam_size = 1\n",
"rrr_rank = 0\n",
"beam_search_epochs = 100\n",
"sparsity_regularizer = 0\n",
"print_frequency = 10\n",
@@ -82,7 +81,6 @@
" \"scale_nbits\": scale_nbits,\n",
" \"beam_search_epochs\": beam_search_epochs,\n",
" \"sparsity_regularizer\": sparsity_regularizer,\n",
" \"rrr_rank\": rrr_rank,\n",
" \"init_max_iter\": init_max_iter,\n",
" }\n",
")"
@@ -119,7 +117,7 @@
"source": [
"quantized_weight = QuantizedWeight(\n",
" XTX=XTX, reference_weight=reference_weight, num_codebooks=num_codebooks,\n",
" nbits_per_codebook=nbits_per_codebook, scale_nbits=scale_nbits, rrr_rank=rrr_rank,\n",
" nbits_per_codebook=nbits_per_codebook, scale_nbits=scale_nbits, \n",
" out_group_size=out_group_size, in_group_size=in_group_size,\n",
" verbose=True, max_iter=init_max_iter, # faster init, not tested\n",
")\n",
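For reference, the updated constructor call can be exercised end-to-end. Below is a minimal sketch with toy shapes and hyperparameter values (all illustrative assumptions); the keyword arguments mirror the notebook cell above, with rrr_rank now gone, and the import path assumes the repository root is on PYTHONPATH.

# Sketch only: toy shapes/values, not a tested configuration.
import torch
from src.aq import QuantizedWeight  # this repository's src/aq.py

in_features, out_features = 64, 32
reference_weight = torch.randn(out_features, in_features)
X = torch.randn(256, in_features)       # stand-in calibration inputs
XTX = X.T @ X / len(X)                  # input second-moment matrix

quantized_weight = QuantizedWeight(
    XTX=XTX, reference_weight=reference_weight, num_codebooks=1,
    nbits_per_codebook=8, scale_nbits=0,
    out_group_size=1, in_group_size=8,
    verbose=True, max_iter=100,
)
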
3 changes: 1 addition & 2 deletions src/aq.py
@@ -19,8 +19,7 @@ def __init__(self, quantized_weight, bias: Optional[nn.Parameter]):
         self.bias = bias

     def forward(self, input: torch.Tensor):
-        # TODO this can be optimized! (after we're sure the idea works)
-        # TODO maybe integrate with QuantizedLinear?
+        # TODO[aqlm] this can be optimized! maybe integrate with QuantizedLinear?
         return F.linear(input, self.quantized_weight(), self.bias)


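For context, a self-contained sketch of the pattern this forward pass implements: calling quantized_weight() materializes a dense weight, which then feeds a plain F.linear. The class name and the lambda weight factory below are illustrative, not the repository's API.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

class DequantizingLinear(nn.Module):
    """Toy stand-in for the wrapper above."""
    def __init__(self, quantized_weight, bias: Optional[nn.Parameter]):
        super().__init__()
        self.quantized_weight = quantized_weight
        self.bias = bias

    def forward(self, input: torch.Tensor):
        # dequantizes on every call -- the part TODO[aqlm] wants to optimize,
        # e.g. by caching the dense weight or fusing with QuantizedLinear
        return F.linear(input, self.quantized_weight(), self.bias)

dense = torch.randn(4, 8)                      # pretend-dequantized weight
layer = DequantizingLinear(lambda: dense, bias=None)
print(layer(torch.randn(2, 8)).shape)          # torch.Size([2, 4])
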
2 changes: 1 addition & 1 deletion src/kmeans.py
@@ -209,7 +209,7 @@ def fit_kmeans_1d(
     - indices are integers [0, k) in the same shape as data; they denote the index of the nearest centroid
     - restored_data is a floating point tensor in the same shape as data; they are dequantized(quantized(data))
     :note: to reconstruct clusters manually, call clusters.gather(-1, indices)
-    :TODO: torch.jit.script / torch.compile
+    :TODO[aqlm]: torch.jit.script / torch.compile
     """
     assert groupwise_data.ndim == 2
     assert 0 <= offset_rate < 0.5
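The :note: line is easy to demonstrate. A toy sketch with made-up tensors and k=2, not calling fit_kmeans_1d itself:

import torch

groupwise_data = torch.tensor([[0.10, 0.90, 0.20],
                               [1.10, 0.00, 0.95]])  # [num_groups, group_size]
clusters = torch.tensor([[0.15, 0.95],
                         [0.05, 1.00]])              # [num_groups, k] centroids

# nearest centroid per element (what fit_kmeans_1d returns as indices)
indices = (groupwise_data.unsqueeze(-1) - clusters.unsqueeze(1)).abs().argmin(-1)
restored_data = clusters.gather(-1, indices)         # dequantized(quantized(data))
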
2 changes: 1 addition & 1 deletion src/modelutils.py
@@ -48,7 +48,7 @@ def get_model(model_path, load_quantized=None, dtype="auto", model_seqlen=2048):
         low_cpu_mem_usage=True,
         local_files_only=True,
     )
-    # Please verify correctness #TODO
+    # Please verify correctness #TODO[aqlm]
     model.seqlen = model_seqlen
     print("Model loaded successfully ...")
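A hypothetical usage sketch of get_model, following the signature in the hunk header; the checkpoint path is a placeholder and the import assumes the repository root is on PYTHONPATH.

from src.modelutils import get_model

# load_quantized=None loads a regular (non-quantized) checkpoint
model = get_model("path/to/llama-checkpoint", load_quantized=None,
                  dtype="auto", model_seqlen=2048)
assert model.seqlen == 2048  # get_model attaches seqlen to the returned model
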
2 changes: 1 addition & 1 deletion src/utils.py
@@ -63,7 +63,7 @@ def using_tf32(enabled: bool):
     was_cudnn = torch.backends.cudnn.allow_tf32
     was_matmul = torch.backends.cuda.matmul.allow_tf32
     torch.backends.cudnn.allow_tf32 = enabled
-    torch.backends.cuda.matmul.allow_tf32 = enabled # TODO unhardcode
+    torch.backends.cuda.matmul.allow_tf32 = enabled
     yield
     torch.backends.cudnn.allow_tf32 = was_cudnn
     torch.backends.cuda.matmul.allow_tf32 = was_matmul
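Since the hunk shows using_tf32 is a context manager that saves, sets, and restores both TF32 flags, a short usage sketch follows (requires a CUDA device; the import path assumes the repository root is on PYTHONPATH).

import torch
from src.utils import using_tf32

a = torch.randn(512, 512, device="cuda")
b = torch.randn(512, 512, device="cuda")

with using_tf32(enabled=True):
    fast = a @ b    # matmul may use TF32 tensor cores inside this block
exact = a @ b       # prior cuDNN / matmul TF32 flags are restored on exit
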
