-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor kopt Module #62
Conversation
procrustes/kopt.py
Outdated
for comb in it.combinations(np.arange(num_row), r=kopt_k): | ||
for comb_perm in it.permutations(comb, r=kopt_k): | ||
for comb in it.combinations(np.arange(num_row), r=k): | ||
for comb_perm in it.permutations(comb, r=k): | ||
if comb_perm != comb: | ||
perm_kopt = deepcopy(perm) | ||
perm_kopt[comb, :] = perm_kopt[comb_perm, :] | ||
e_kopt_new = compute_error(array_a, array_b, perm_kopt, perm_kopt) | ||
if e_kopt_new < kopt_error: | ||
perm = perm_kopt | ||
kopt_error = e_kopt_new | ||
if kopt_error <= kopt_tol: | ||
p_new = deepcopy(p) | ||
p_new[comb, :] = p_new[comb_perm, :] | ||
error_new = compute_error(a, b, p_new, p_new) | ||
if error_new < error: | ||
p, error = p_new, error_new | ||
if error <= tol: | ||
break |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fwmeng88 & @PaulWAyers, I have two questions:
- Upon finding a better permutation matrix, shouldn't we start the k-opt from the beginning? The current code doesn't start the
comb
andcomb_perm
swaps for the newp
matrix. - Isn't this a brute force algorithm? Then, I don't understand the goal of having a
tol
argument. Based on my understanding, we should just try all k-permutations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- For the previous implementation version, we didn't run the k-opt from the beginning again once we find an improvement. But I think about it again and it seems we should.
- This is a greedy algorithm. Using
tol
was an attempt to set up an early stop ask-opt
is computationally expensive. We settol
to be very small (1.e-8
) and in most cases, it won't affect anything.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@FarnazH you are correct. After finding an improved permutation matrix, you should start over at the beginning.* I don't see the point in having a *The only possible exception would be where k = dim(matrix), where this just restarts an exhaustive search, which is inefficient. |
procrustes/kopt.py
Outdated
# compute 2-sided permutation error of the initial p matrix | ||
error = compute_error(a, b, p, p) | ||
# swap rows and columns until the permutation matrix is not improved | ||
search = True | ||
while search: | ||
search = False | ||
for comb in it.combinations(np.arange(p.shape[0]), r=k): | ||
for comb_perm in it.permutations(comb, r=k): | ||
if comb_perm != comb: | ||
p_new = deepcopy(p) | ||
p_new[comb, :] = p_new[comb_perm, :] | ||
error_new = compute_error(a, b, p_new, p_new) | ||
if error_new < error: | ||
search = True | ||
p, error = p_new, error_new | ||
# check whether perfect permutation matrix is found | ||
# TODO: smarter threshold based on norm of matrix | ||
if error <= 1.0e-8: | ||
return p, error | ||
return p, error |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this sound right to you @PaulWAyers & @fwmeng88?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. I would choose the error threshold as something like 1e-8*sqrt(norm(A)*norm(B)). The problem would be revealed by a test matrix that is very ill-scaled, e.g., where A and B are filled with elements of order 1e-30 or something like that.
This isn't quite what I'd intended, but it works very well and is even better (more accurate) on average than what I'd intended. The (cheaper) thing I'd intended would break the loops as soon as search=True. This sort of break-a-double-loop issue is something that requires crazy work-arounds in Python.
https://note.nkmk.me/en/python-break-nested-loops/
But why do we have a double loop at all? Based on my reading of iter-tools, you can generate all the permutations directly. But that requires a lot more thought to do....not totally sure how to do it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, why make a permutation matrix (mostly zeros and ones) and multiply by A (cost of O(n^3)). You can directly compute A with permuted rows/columns, then compute the error in (permuted) A. This has cost O(n^2) (computing the error) instead of O(n^3). IF the error is reduced, you can return the permutation, or even the permutation matrix (which can be constructed at the very end).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point regarding computation cost. Not worth sticking to using the compute_error
function, when the cost can be reduced by a factor of O(n). I will implement it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had some "fun" trying to generate permutations that were inversions up to some order, but that's a HARD problem it seems.
Information Processing Letters; Volume 86, Issue 2, 30 April 2003, Pages 107-112
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can replace the double-for loop with:
for comb_perm in it.permutations(np.arange(5), r=3):
comb = tuple(sorted(comb_perm))
print(comb, comb_perm)
There is some other low-hanging fruit here (the same comb
tuple shows up k! times in a row) but the bottleneck is going to be evaluating the error, so I see no reason to worry about the fact you are sorting a small vector of k entries repeatedly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great suggestion. I fixed the nested for loops.
procrustes/kopt.py
Outdated
for comb_left in it.combinations(np.arange(num_row_left), r=kopt_k): | ||
for comb_perm_left in it.permutations(comb_left, r=kopt_k): | ||
for comb_left in it.combinations(np.arange(num_row_left), r=k): | ||
for comb_perm_left in it.permutations(comb_left, r=k): | ||
if comb_perm_left != comb_left: | ||
perm_kopt_left = deepcopy(perm_p) | ||
perm_kopt_left = deepcopy(p) | ||
# the right hand side permutation | ||
for comb_right in it.combinations(np.arange(num_row_right), r=kopt_k): | ||
for comb_perm_right in it.permutations(comb_right, r=kopt_k): | ||
for comb_right in it.combinations(np.arange(num_row_right), r=k): | ||
for comb_perm_right in it.permutations(comb_right, r=k): | ||
if comb_perm_right != comb_right: | ||
perm_kopt_right = deepcopy(perm_q) | ||
perm_kopt_right = deepcopy(q) | ||
perm_kopt_right[comb_right, :] = perm_kopt_right[comb_perm_right, :] | ||
e_kopt_new_right = compute_error(array_n, array_m, perm_p.T, | ||
e_kopt_new_right = compute_error(b, a, p.T, | ||
perm_kopt_right) | ||
if e_kopt_new_right < kopt_error: | ||
perm_q = perm_kopt_right | ||
q = perm_kopt_right | ||
kopt_error = e_kopt_new_right | ||
if kopt_error <= kopt_tol: | ||
# check whether perfect permutation matrix is found | ||
# TODO: smarter threshold based on norm of matrix | ||
if kopt_error <= 1.0e-8: | ||
break | ||
|
||
perm_kopt_left[comb_left, :] = perm_kopt_left[comb_perm_left, :] | ||
e_kopt_new_left = compute_error(array_n, array_m, perm_kopt_left.T, perm_q) | ||
e_kopt_new_left = compute_error(b, a, perm_kopt_left.T, q) | ||
if e_kopt_new_left < kopt_error: | ||
perm_p = perm_kopt_left | ||
p = perm_kopt_left | ||
kopt_error = e_kopt_new_left | ||
if kopt_error <= kopt_tol: | ||
# check whether perfect permutation matrix is found | ||
# TODO: smarter threshold based on norm of matrix | ||
if kopt_error <= 1.0e-8: | ||
break |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am a bit confused regarding the permutations tried in the kopt_heuristic_double
. It seems to me that the swaps of right-hand-side permutation matrix P and left-hand-side permutation matrix Q are tested separately (to check whether it lowers the error). What about swapping both P & Q, and checking how the error would change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They should be tested together. You want to generate all P and Q and then check for all possible combinations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, it is addressed in a1cfd3b.
procrustes/kopt.py
Outdated
num_row_left = p.shape[0] | ||
num_row_right = q.shape[0] | ||
|
||
for comb_p in it.combinations(np.arange(num_row_left), r=k): | ||
for perm_p in it.permutations(comb_p, r=k): | ||
for comb_q in it.combinations(np.arange(num_row_right), r=k): | ||
for perm_q in it.permutations(comb_q, r=k): | ||
# permute rows of matrix P | ||
p_new = deepcopy(p) | ||
p_new[comb_p, :] = p_new[perm_p, :] | ||
# permute rows of matrix Q | ||
q_new = deepcopy(q) | ||
q_new[comb_q, :] = q_new[perm_q, :] | ||
# compute error with new matrices & compare | ||
error_new = compute_error(b, a, p_new, q_new) | ||
if error_new < error: | ||
p, q, error = p_new, q_new, error_new | ||
# check whether perfect permutation matrix is found | ||
# TODO: smarter threshold based on norm of matrix | ||
if error <= 1.0e-8: | ||
break | ||
return p, q, error |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this sound to you @PaulWAyers & @fwmeng88?
@PaulWAyers I have in mind that your earlier comment regarding the 1.0e-8
and double for loop applies here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My comment on not explicitly constructing permutation matrices also applies here.
The rest of this looks good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a little curious that we do a row-permutation for both p_new and q_new. Since the error is |QAP - B| (didn't we decide to use P1 and P2?) one could argue that row-permutations of P should be paired with column-permutations of Q.
This isn't a big-deal because a k-fold row-permutation of Q is also a k-fold column permutation of Q, so looping over all k-fold row-permutations of Q is equivalent to looping over all k-fold column-permutations of Q. So it doesn't really matter what's being done, but it's more intuitive (to me) to use column permutations of Q.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was just adding mathematical formulas to the docstrings and making the argument names consistent. I have changed the paper to use P1 and P2 and will use the same in the documentation.
The kopt_heuristic_double only worked for square A array, so it was fixed.
The current code is doing what it was supposed to do with clear documentation. This PR is becoming very long, so I defer implementing the trick for changing the cost from O(n^3) to O(n^2) to an issue to be taken care of later. |
procrustes/kopt.py
Outdated
for perm in it.permutations(np.arange(n), r=k): | ||
comb = sorted(perm) | ||
if perm != comb: | ||
# row-swap P matrix & compute error | ||
perm_p = np.copy(p) | ||
perm_p[:, comb] = perm_p[:, perm] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, perm
is a tuple and comb
is a list which means that they are always not equal. Line 91 is not working which results in a lot of redundant calculations. For example, if we use it.combinations(np.arange(5), r=3)
, there are 10 redundant index used for generating the permutation. So, we need to convert it back to a tuple?
comb = tuple(sorted(perm))
This code snippet shows how the comparison operation can help save time.
# the way it's implemented for now
idx = 0
for perm in it.permutations(np.arange(5), 3):
comb = sorted(perm)
if perm != comb:
# print(perm, "<--", comb)
idx += 1
print(idx)
# 60
Once we implement the comparison in the right way,
idx = 0
for perm in it.permutations(np.arange(5), 3):
comb = tuple(sorted(perm))
if perm != comb:
print(perm, "<--", comb)
idx += 1
print(idx)
# 50
Thanks for refactoring the code, making it nicer and more scientifically correct. Hope this helps.
@FarnazH
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch @fwmeng88. It is fixed in a57de40.
procrustes/kopt.py
Outdated
for perm1 in it.permutations(np.arange(n), r=k): | ||
comb1 = sorted(perm1) | ||
for perm2 in it.permutations(np.arange(m), r=k): | ||
comb2 = sorted(perm2) | ||
# permute rows of matrix P1 | ||
perm_p1 = np.copy(p1) | ||
perm_p1[comb1, :] = perm_p1[perm1, :] | ||
# permute rows of matrix P2 | ||
perm_p2 = np.copy(p2) | ||
perm_p2[comb2, :] = perm_p2[perm2, :] | ||
# compute error with new matrices & compare | ||
perm_error = compute_error(b, a, perm_p1, perm_p2) | ||
if perm_error < error: | ||
p1, p2, error = perm_p1, perm_p2, perm_error | ||
# check whether perfect permutation matrix is found | ||
# TODO: smarter threshold based on norm of matrix | ||
if error <= 1.0e-8: | ||
break |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the updated codes! I only have a minor comment here.
We are not comparing or filtering the generated index here, which I think we can save some time (as I did in the old version) just by if not (comb1 == perm1 and comb2 == perm2)
. As far as I tested, if we are in the following situation,
all = 0
equal_part = 0
p1 = np.identity(5)
p2 = np.identity(5)
for perm1 in it.permutations(np.arange(n), r=k):
comb1 = tuple(sorted(perm1))
for perm2 in it.permutations(np.arange(m), r=k):
comb2 = tuple(sorted(perm2))
if comb1 == perm1 and comb2 == perm2:
equal_part += 1
all += 1
print(all) # 3600
print(equal_part) # 100
where we will have 100 operations that are not swapping the columns/rows of left and right permutation matrix as the index are the same. That's saying we can save ~1/36 computation time if the comparison operation is not so expensive.
I don't know if this necessary, but I was assuming it's going to decrease the computation time. Do you have any comments on this? Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good points. No reason to do the extra work. Saving a factor of 1/k! (or 1/k!)**2 is always good, especially for k=2.
The bigger issue is re-writing the norm. I'd also like to be able to short-circuit the loop and do a "greedier" algorithm but re-loops through the permutations as soon as we are close. But .... I think those are new issues.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This refactors the
kopt
module. I removed theref_error
argument because the initial permutation matrix is given, so it can be used to compute the initial error. It is easier to compute it internally than clarify it in the documentation and have the user pre-compute the initial error. Also, the user can erroneously provide an error that does not match the initial permutation matrix. This is how I understand/see things, so please feel free to let me know if anything is missing.I had a few questions that I will raise in this PR.