diff --git a/kronfluence/score/dot_product.py b/kronfluence/score/dot_product.py index 8467e00..cd96f3e 100644 --- a/kronfluence/score/dot_product.py +++ b/kronfluence/score/dot_product.py @@ -114,7 +114,8 @@ def compute_dot_products_with_loader( if score_args.compute_per_token_scores: raise RuntimeError(DIMENSION_NOT_MATCH_ERROR_MSG) from exc raise - score_chunks[ALL_MODULE_NAME].append(pairwise_scores.cpu()) + pairwise_scores = pairwise_scores.cpu() + score_chunks[ALL_MODULE_NAME].append(pairwise_scores) accumulate_iterations(model=model, tracked_module_names=tracked_module_names) if state.use_distributed and total_steps % DISTRIBUTED_SYNC_INTERVAL == 0: