Skip to content

Commit

Permalink
remove prints
Browse files Browse the repository at this point in the history
  • Loading branch information
root authored and root committed Aug 14, 2024
1 parent 8873663 commit 8d0107a
Showing 1 changed file with 1 addition and 12 deletions.
13 changes: 1 addition & 12 deletions egs/reazonspeech/ASR/zipformer/streaming_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""
Usage:
./zipformer/streaming_decode.py--epoch 28 --avg 15 --causal 1 --chunk-size 32 --left-context-frames 256 --exp-dir ./zipformer/exp-large --lang data/lang_char --num-encoder-layers 2,2,4,5,4,2 --feedforward-dim 512,768,1536,2048,1536,768 --encoder-dim 192,256,512,768,512,256 --encoder-unmasked-dim 192,192,256,320,256,192
"""

import pdb
Expand Down Expand Up @@ -456,9 +457,7 @@ def decode_one_chunk(
states.append(stream.states)
processed_lens.append(stream.done_frames)

print(feature_lens)
feature_lens = torch.tensor(feature_lens, device=model.device)
print(feature_lens)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)

# Make sure the length after encoder_embed is at least 1.
Expand Down Expand Up @@ -522,7 +521,6 @@ def decode_one_chunk(
# finished_streams.append(i)
finished_streams.append(i)

print(finished_streams)
return finished_streams


Expand Down Expand Up @@ -618,8 +616,6 @@ def decode_dataset(
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")

print("cuts processed finished")
print(len(decode_streams))
# decode final chunks of last sequences
while len(decode_streams):
# print("INSIDE LEN DECODE STREAMS")
Expand Down Expand Up @@ -691,9 +687,6 @@ def save_results(
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
print("error stats")
print("results")
print(results)
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
Expand Down Expand Up @@ -871,8 +864,6 @@ def main():

test_sets = ["valid", "test"]
test_cuts = [valid_cuts, test_cuts]
print('test cuts')
print(test_cuts)

for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
Expand All @@ -882,8 +873,6 @@ def main():
sp=sp,
decoding_graph=decoding_graph,
)
print(r"esults_dict")
print(results_dict)
save_results(
params=params,
test_set_name=test_set,
Expand Down

0 comments on commit 8d0107a

Please sign in to comment.