diff --git a/stanza/models/constituency/parser_training.py b/stanza/models/constituency/parser_training.py index 927bf89d8..6abf4121c 100644 --- a/stanza/models/constituency/parser_training.py +++ b/stanza/models/constituency/parser_training.py @@ -1,5 +1,6 @@ from collections import Counter, namedtuple import copy +import itertools import logging import os import random @@ -689,13 +690,13 @@ def build_losses(con_values, tree): subtrees = [x for x in tree.children if not x.is_preterminal()] for subtree in subtrees: build_losses(con_values, subtree) - for subtree_idx in range(len(subtrees)-1): - left = str(subtrees[subtree_idx]) - right = str(subtrees[subtree_idx+1]) + for left, right in itertools.combinations(subtrees, 2): + left = str(left) + right = str(right) if left in con_values and right in con_values: left_value = con_values[left].squeeze(0) right_value = con_values[right].squeeze(0) - mse = torch.dot(left_value, right_value) + mse = torch.dot(left_value, right_value) / (len(subtrees) - 1) orthogonal_losses.append(mse) for result in gold_results: gold_constituents = result.constituents