diff --git a/src/PersistentOrderedSet.mo b/src/PersistentOrderedSet.mo index d28be44e..a65bef5a 100644 --- a/src/PersistentOrderedSet.mo +++ b/src/PersistentOrderedSet.mo @@ -32,8 +32,10 @@ module { /// Red-black tree of nodes with ordered set elements. /// Leaves are considered implicitly black. + /// Nat represents the black height, + /// the black height of red nodes is equal to the parent node. public type Set = { - #node : (Color, Set, T, Set); + #node : (Color, Nat, Set, T, Set); #leaf }; @@ -181,7 +183,7 @@ module { switch (rbSet1, rbSet2) { case (#leaf, rbSet) { rbSet }; case (rbSet, #leaf) { rbSet }; - case (#node (_, l1, x, r1), _) { + case (#node (_, _, l1, x, r1), _) { let (l2, _, r2) = Internal.split(x, rbSet2, compare); Internal.join(union(l1, l2), x, union(r1, r2)) }; @@ -212,7 +214,7 @@ module { switch (rbSet1, rbSet2) { case (#leaf, _) { #leaf }; case (_, #leaf) { #leaf }; - case (#node (_, l1, x, r1), _) { + case (#node (_, _, l1, x, r1), _) { let (l2, b2, r2) = Internal.split(x, rbSet2, compare); let l = intersect(l1, l2); let r = intersect(r1, r2); @@ -246,7 +248,7 @@ module { switch (rbSet1, rbSet2) { case (#leaf, _) { #leaf }; case (rbSet, #leaf) { rbSet }; - case (_, (#node(_, l2, x, r2))) { + case (_, (#node(_, _, l2, x, r2))) { let (l1, _, r1) = Internal.split(x, rbSet1, compare); Internal.join2(diff(l1, l2), diff(r1, r2)); } @@ -437,11 +439,11 @@ module { trees := ts; ?x }; // TODO: Let's float-out case on direction - case (#fwd, ?(#tr(#node(_, l, x, r)), ts)) { + case (#fwd, ?(#tr(#node(_, _, l, x, r)), ts)) { trees := ?(#tr(l), ?(#x(x), ?(#tr(r), ts))); next() }; - case (#bwd, ?(#tr(#node(_, l, x, r)), ts)) { + case (#bwd, ?(#tr(#node(_, _, l, x, r)), ts)) { trees := ?(#tr(r), ?(#x(x), ?(#tr(l), ts))); next() } @@ -515,7 +517,7 @@ module { public func size(t : Set) : Nat { switch t { case (#leaf) { 0 }; - case (#node(_, l, _, r)) { + case (#node(_, _, l, _, r)) { size(l) + size(r) + 1 } } @@ -625,7 +627,7 @@ module { public func contains(t : Set, compare : (T, T) -> O.Order, x : T) : Bool { switch t { case (#leaf) { false }; - case (#node(_c, l, x1, r)) { + case (#node(_c, _, l, x1, r)) { switch (compare(x, x1)) { case (#less) { contains(l, compare, x) }; case (#equal) { true }; @@ -637,8 +639,8 @@ module { func redden(t : Set) : Set { switch t { - case (#node (#B, l, x, r)) { - (#node (#R, l, x, r)) + case (#node (#B, bh, l, x, r)) { + (#node (#R, bh, l, x, r)) }; case _ { Debug.trap "RBTree.red" @@ -646,46 +648,50 @@ module { } }; - func lbalance(left : Set, x : T, right : Set) : Set { + func lbalance(bh : Nat, left : Set, x : T, right : Set) : Set { switch (left, right) { - case (#node(#R, #node(#R, l1, x1, r1), x2, r2), r) { + case (#node(#R, _, #node(#R, _, l1, x1, r1), x2, r2), r) { #node( #R, - #node(#B, l1, x1, r1), + bh + 1, + #node(#B, bh, l1, x1, r1), x2, - #node(#B, r2, x, r)) + #node(#B, bh, r2, x, r)) }; - case (#node(#R, l1, x1, #node(#R, l2, x2, r2)), r) { + case (#node(#R, _, l1, x1, #node(#R, _, l2, x2, r2)), r) { #node( #R, - #node(#B, l1, x1, l2), + bh + 1, + #node(#B, bh, l1, x1, l2), x2, - #node(#B, r2, x, r)) + #node(#B, bh, r2, x, r)) }; case _ { - #node(#B, left, x, right) + #node(#B, bh, left, x, right) } } }; - func rbalance(left : Set, x : T, right : Set) : Set { + func rbalance(bh : Nat, left : Set, x : T, right : Set) : Set { switch (left, right) { - case (l, #node(#R, l1, x1, #node(#R, l2, x2, r2))) { + case (l, #node(#R, _, l1, x1, #node(#R, _, l2, x2, r2))) { #node( #R, - #node(#B, l, x, l1), + bh + 1, + #node(#B, bh, l, x, l1), x1, - #node(#B, l2, x2, r2)) + #node(#B, bh, l2, x2, r2)) }; - case (l, #node(#R, #node(#R, l1, x1, r1), x2, r2)) { + case (l, #node(#R, _, #node(#R, _, l1, x1, r1), x2, r2)) { #node( #R, - #node(#B, l, x, l1), + bh + 1, + #node(#B, bh, l, x, l1), x1, - #node(#B, r1, x2, r2)) + #node(#B, bh, r1, x2, r2)) }; case _ { - #node(#B, left, x, right) + #node(#B, bh, left, x, right) }; } }; @@ -699,118 +705,125 @@ module { func ins(tree : Set) : Set { switch tree { case (#leaf) { - #node(#R, #leaf, elem, #leaf) + #node(#R, 1, #leaf, elem, #leaf) }; - case (#node(#B, left, x, right)) { + case (#node(#B, bh, left, x, right)) { switch (compare (elem, x)) { case (#less) { - lbalance(ins left, x, right) + lbalance(bh, ins left, x, right) }; case (#greater) { - rbalance(left, x, ins right) + rbalance(bh, left, x, ins right) }; case (#equal) { - #node(#B, left, x, right) + #node(#B, bh, left, x, right) } } }; - case (#node(#R, left, x, right)) { + case (#node(#R, bh, left, x, right)) { switch (compare (elem, x)) { case (#less) { - #node(#R, ins left, x, right) + #node(#R, bh, ins left, x, right) }; case (#greater) { - #node(#R, left, x, ins right) + #node(#R, bh, left, x, ins right) }; case (#equal) { - #node(#R, left, x, right) + #node(#R, bh, left, x, right) } } } }; }; switch (ins s) { - case (#node(#R, left, x, right)) { - #node(#B, left, x, right); + case (#node(#R, bh, left, x, right)) { + #node(#B, bh, left, x, right); }; case other { other }; }; }; - func balLeft(left : Set, x : T, right : Set) : Set { + func balLeft(bh : Nat, left : Set, x : T, right : Set) : Set { switch (left, right) { - case (#node(#R, l1, x1, r1), r) { - #node(#R, #node(#B, l1, x1, r1), x, r) + case (#node(#R, _, l1, x1, r1), r) { + #node(#R, bh + 1, #node(#B, bh, l1, x1, r1), x, r) }; - case (_, #node(#B, l2, x2, r2)) { - rbalance(left, x, #node(#R, l2, x2, r2)) + case (_, #node(#B, rbh, l2, x2, r2)) { + rbalance(rbh, left, x, #node(#R, bh, l2, x2, r2)) }; - case (_, #node(#R, #node(#B, l2, x2, r2), x3, r3)) { - #node(#R, - #node(#B, left, x, l2), + case (_, #node(#R, _, #node(#B, lbh, l2, x2, r2), x3, r3)) { + #node( + #R, + bh, + #node(#B, lbh, left, x, l2), x2, - rbalance(r2, x3, redden r3)) + rbalance(bh, r2, x3, redden r3)) }; case _ { Debug.trap "balLeft" }; } }; - func balRight(left : Set, x : T, right : Set) : Set { + func balRight(bh : Nat, left : Set, x : T, right : Set) : Set { switch (left, right) { - case (l, #node(#R, l1, x1, r1)) { - #node(#R, l, x, #node(#B, l1, x1, r1)) + case (l, #node(#R, _, l1, x1, r1)) { + #node(#R, bh + 1, l, x, #node(#B, bh, l1, x1, r1)) }; - case (#node(#B, l1, x1, r1), r) { - lbalance(#node(#R, l1, x1, r1), x, r); + case (#node(#B, lbh, l1, x1, r1), r) { + lbalance(lbh, #node(#R, bh, l1, x1, r1), x, r); }; - case (#node(#R, l1, x1, #node(#B, l2, x2, r2)), r3) { - #node(#R, - lbalance(redden l1, x1, l2), + case (#node(#R, _, l1, x1, #node(#B, rbh, l2, x2, r2)), r3) { + #node( + #R, + bh, + lbalance(bh, redden l1, x1, l2), x2, - #node(#B, r2, x, r3)) + #node(#B, rbh, r2, x, r3)) }; case _ { Debug.trap "balRight" }; } }; - func append(left : Set, right: Set) : Set { + func append(bh : Nat, left : Set, right: Set) : Set { switch (left, right) { case (#leaf, _) { right }; case (_, #leaf) { left }; - case (#node (#R, l1, x1, r1), - #node (#R, l2, x2, r2)) { - switch (append (r1, l2)) { - case (#node (#R, l3, x3, r3)) { + case (#node (#R, lbh, l1, x1, r1), + #node (#R, rbh, l2, x2, r2)) { + switch (append (bh, r1, l2)) { + case (#node (#R, _, l3, x3, r3)) { #node( #R, - #node(#R, l1, x1, l3), + bh, + #node(#R, lbh, l1, x1, l3), x3, - #node(#R, r3, x2, r2)) + #node(#R, rbh, r3, x2, r2)) }; case r1l2 { - #node(#R, l1, x1, #node(#R, r1l2, x2, r2)) + #node(#R, bh - 1 : Nat, l1, x1, #node(#R, rbh, r1l2, x2, r2)) } } }; - case (t1, #node(#R, l2, x2, r2)) { - #node(#R, append(t1, l2), x2, r2) + case (t1, #node(#R, rbh, l2, x2, r2)) { + #node(#R, bh, append(rbh, t1, l2), x2, r2) }; - case (#node(#R, l1, x1, r1), t2) { - #node(#R, l1, x1, append(r1, t2)) + case (#node(#R, lbh, l1, x1, r1), t2) { + #node(#R, bh, l1, x1, append(lbh, r1, t2)) }; - case (#node(#B, l1, x1, r1), #node (#B, l2, x2, r2)) { - switch (append (r1, l2)) { - case (#node (#R, l3, x3, r3)) { - #node(#R, - #node(#B, l1, x1, l3), + case (#node(#B, lbh, l1, x1, r1), #node (#B, rbh, l2, x2, r2)) { + switch (append (bh, r1, l2)) { + case (#node (#R, _, l3, x3, r3)) { + #node(#R, + bh, + #node(#B, lbh, l1, x1, l3), x3, - #node(#B, r3, x2, r2)) + #node(#B, rbh, r3, x2, r2)) }; - case r1l2 { - balLeft ( + case r1l2 { + balLeft ( + bh - 1 : Nat, l1, x1, - #node(#B, r1l2, x2, r2) + #node(#B, rbh, r1l2, x2, r2) ) } } @@ -819,32 +832,32 @@ module { }; public func delete(tree : Set, compare : (T, T) -> O.Order, x : T) : Set { - func delNode(left : Set, x1 : T, right : Set) : Set { + func delNode(bh : Nat, left : Set, x1 : T, right : Set) : Set { switch (compare (x, x1)) { case (#less) { let newLeft = del left; switch left { - case (#node(#B, _, _, _)) { - balLeft(newLeft, x1, right) + case (#node(#B, _, _, _, _)) { + balLeft(bh - 1 : Nat, newLeft, x1, right) }; case _ { - #node(#R, newLeft, x1, right) + #node(#R, bh, newLeft, x1, right) } } }; case (#greater) { let newRight = del right; switch right { - case (#node(#B, _, _, _)) { - balRight(left, x1, newRight) + case (#node(#B, _, _, _, _)) { + balRight(bh - 1 : Nat, left, x1, newRight) }; case _ { - #node(#R, left, x1, newRight) + #node(#R, bh, left, x1, newRight) } } }; case (#equal) { - append(left, right) + append(bh, left, right) }; } }; @@ -853,48 +866,43 @@ module { case (#leaf) { tree }; - case (#node(_, left, x1, right)) { - delNode(left, x1, right) + case (#node(_, bh, left, x1, right)) { + delNode(bh, left, x1, right) } }; }; switch (del(tree)) { - case (#node(#R, left, x1, right)) { - #node(#B, left, x1, right); + case (#node(#R, bh, left, x1, right)) { + #node(#B, bh, left, x1, right); }; case other { other }; }; }; - // TODO: Instead, consider storing the black height in the node constructor public func blackHeight (rbSet : Set) : Nat { - func f (node : Set, acc : Nat) : Nat { - switch node { - case (#leaf) { acc }; - case (#node (#R, l1, _, _)) { f(l1, acc) }; - case (#node (#B, l1, _, _)) { f(l1, acc + 1) } + switch rbSet { + case (#leaf) { 0 }; + case (#node (_, bh, _, _, _)) { bh }; } - }; - f (rbSet, 0) }; - public func joinL(l : Set, x : T, r : Set) : Set { - if (blackHeight r <= blackHeight l) { (#node (#R, l, x, r)) } + public func joinL(bh: Nat, l : Set, x : T, r : Set) : Set { + if (blackHeight r <= blackHeight l) { (#node (#R, bh + 1, l, x, r)) } else { switch r { - case (#node (#R, rl, rx, rr)) { (#node (#R, joinL(l, x, rl) , rx, rr)) }; - case (#node (#B, rl, rx, rr)) { balLeft (joinL(l, x, rl), rx, rr) }; + case (#node (#R, rbh, rl, rx, rr)) { (#node (#R, rbh , joinL(bh, l, x, rl) , rx, rr)) }; + case (#node (#B, rbh, rl, rx, rr)) { lbalance (rbh, joinL(bh, l, x, rl), rx, rr) }; case _ { Debug.trap "joinL" }; } } }; - public func joinR(l : Set, x : T, r : Set) : Set { - if (blackHeight l <= blackHeight r) { (#node (#R, l, x, r)) } + public func joinR(bh : Nat, l : Set, x : T, r : Set) : Set { + if (blackHeight l <= blackHeight r) { (#node (#R, bh + 1, l, x, r)) } else { switch l { - case (#node (#R, ll, lx, lr)) { (#node (#R, ll , lx, joinR (lr, x, r))) }; - case (#node (#B, ll, lx, lr)) { balRight (ll, lx, joinR (lr, x, r)) }; + case (#node (#R, lbh, ll, lx, lr)) { (#node (#R, lbh, ll , lx, joinR (bh, lr, x, r))) }; + case (#node (#B, lbh, ll, lx, lr)) { rbalance (lbh, ll, lx, joinR (bh, lr, x, r)) }; case _ { Debug.trap "joinR" }; } } @@ -903,15 +911,15 @@ module { public func paint(color : Color, rbMap : Set) : Set { switch rbMap { case (#leaf) { #leaf }; - case (#node (_, l, x, r)) { (#node (color, l, x, r)) }; + case (#node (_, bh, l, x, r)) { (#node (color, bh, l, x, r)) }; } }; public func splitMin (rbSet : Set) : (T, Set) { switch rbSet { case (#leaf) { Debug.trap "splitMin" }; - case (#node(_, #leaf, x, r)) { (x, r) }; - case (#node(_, l, x, r)) { + case (#node(_, _, #leaf, x, r)) { (x, r) }; + case (#node(_, _, l, x, r)) { let (m, l2) = splitMin l; (m, join(l2, x, r)) }; @@ -921,13 +929,15 @@ module { // Joins an element and two trees. // See Tobias Nipkow's "Functional Data Structures and Algorithms", 117 public func join(l : Set, x : T, r : Set) : Set { - if (Internal.blackHeight r < Internal.blackHeight l) { - return Internal.paint(#B, Internal.joinR(l, x, r)) + let rbh = Internal.blackHeight r; + let lbh = Internal.blackHeight l; + if (rbh < lbh) { + return Internal.paint(#B, Internal.joinR(rbh, l, x, r)) }; - if (Internal.blackHeight l < Internal.blackHeight r) { - return Internal.paint(#B, Internal.joinL(l, x, r)) + if (lbh < rbh) { + return Internal.paint(#B, Internal.joinL(lbh, l, x, r)) }; - return (#node (#B, l, x, r)) + return (#node (#B, lbh + 1, l, x, r)) }; // Joins two trees. @@ -949,7 +959,7 @@ module { public func split(x : T, rbSet : Set, compare : (T, T) -> O.Order) : (Set, Bool, Set) { switch rbSet { case (#leaf) { (#leaf, false, #leaf)}; - case (#node (_, l, x1, r)) { + case (#node (_, _, l, x1, r)) { switch (compare(x, x1)) { case (#less) { let (l1, b, l2) = split(x, l, compare); diff --git a/test/PersistentOrderedSet.test.mo b/test/PersistentOrderedSet.test.mo index 756dc2d2..b2979523 100644 --- a/test/PersistentOrderedSet.test.mo +++ b/test/PersistentOrderedSet.test.mo @@ -32,7 +32,7 @@ func checkSet(rbSet : Set.Set) { func blackDepth(node : Set.Set) : Nat { switch node { case (#leaf) 0; - case (#node(color, left, x1, right)) { + case (#node(color, _, left, x1, right)) { checkElem(left, func(x) { x < x1 }); checkElem(right, func(x) { x > x1 }); let leftBlacks = blackDepth(left); @@ -56,14 +56,14 @@ func blackDepth(node : Set.Set) : Nat { func isRed(node : Set.Set) : Bool { switch node { case (#leaf) false; - case (#node(color, _, _, _)) color == #R + case (#node(color, _, _, _, _)) color == #R } }; func checkElem(node : Set.Set, isValid : Nat -> Bool) { switch node { case (#leaf) {}; - case (#node(_, _, elem, _)) { + case (#node(_, _, _, elem, _)) { assert (isValid(elem)) } }