diff --git a/dag.go b/dag.go index 2efbfa7..c75918b 100644 --- a/dag.go +++ b/dag.go @@ -28,6 +28,7 @@ type DAG struct { verticesLocked *dMutex ancestorsCache map[interface{}]map[interface{}]struct{} descendantsCache map[interface{}]map[interface{}]struct{} + options Options } // NewDAG creates / initializes a new DAG. @@ -40,6 +41,7 @@ func NewDAG() *DAG { verticesLocked: newDMutex(), ancestorsCache: make(map[interface{}]map[interface{}]struct{}), descendantsCache: make(map[interface{}]map[interface{}]struct{}), + options: defaultOptions(), } } @@ -79,12 +81,13 @@ func (d *DAG) AddVertexByID(id string, v interface{}) error { } func (d *DAG) addVertexByID(id string, v interface{}) error { + vHash := d.hashVertex(v) // sanity checking if v == nil { return VertexNilError{} } - if _, exists := d.vertices[v]; exists { + if _, exists := d.vertices[vHash]; exists { return VertexDuplicateError{v} } @@ -92,7 +95,7 @@ func (d *DAG) addVertexByID(id string, v interface{}) error { return IDDuplicateError{id} } - d.vertices[v] = id + d.vertices[vHash] = id d.vertexIds[id] = v return nil @@ -128,43 +131,44 @@ func (d *DAG) DeleteVertex(id string) error { } v := d.vertexIds[id] + vHash := d.hashVertex(v) // get descendents and ancestors as they are now - descendants := copyMap(d.getDescendants(v)) - ancestors := copyMap(d.getAncestors(v)) + descendants := copyMap(d.getDescendants(vHash)) + ancestors := copyMap(d.getAncestors(vHash)) // delete v in outbound edges of parents - if _, exists := d.inboundEdge[v]; exists { - for parent := range d.inboundEdge[v] { - delete(d.outboundEdge[parent], v) + if _, exists := d.inboundEdge[vHash]; exists { + for parent := range d.inboundEdge[vHash] { + delete(d.outboundEdge[parent], vHash) } } // delete v in inbound edges of children - if _, exists := d.outboundEdge[v]; exists { - for child := range d.outboundEdge[v] { - delete(d.inboundEdge[child], v) + if _, exists := d.outboundEdge[vHash]; exists { + for child := range d.outboundEdge[vHash] { + delete(d.inboundEdge[child], vHash) } } // delete in- and outbound of v itself - delete(d.inboundEdge, v) - delete(d.outboundEdge, v) + delete(d.inboundEdge, vHash) + delete(d.outboundEdge, vHash) // for v and all its descendants delete cached ancestors for descendant := range descendants { delete(d.ancestorsCache, descendant) } - delete(d.ancestorsCache, v) + delete(d.ancestorsCache, vHash) // for v and all its ancestors delete cached descendants for ancestor := range ancestors { delete(d.descendantsCache, ancestor) } - delete(d.descendantsCache, v) + delete(d.descendantsCache, vHash) // delete v itself - delete(d.vertices, v) + delete(d.vertices, vHash) delete(d.vertexIds, id) return nil @@ -191,48 +195,50 @@ func (d *DAG) AddEdge(srcID, dstID string) error { } src := d.vertexIds[srcID] + srcHash := d.hashVertex(src) dst := d.vertexIds[dstID] + dstHash := d.hashVertex(dst) // if the edge is already known, there is nothing else to do - if d.isEdge(src, dst) { + if d.isEdge(srcHash, dstHash) { return EdgeDuplicateError{srcID, dstID} } // get descendents and ancestors as they are now - descendants := copyMap(d.getDescendants(dst)) - ancestors := copyMap(d.getAncestors(src)) + descendants := copyMap(d.getDescendants(dstHash)) + ancestors := copyMap(d.getAncestors(srcHash)) - if _, exists := descendants[src]; exists { + if _, exists := descendants[srcHash]; exists { return EdgeLoopError{srcID, dstID} } // prepare d.outbound[src], iff needed - if _, exists := d.outboundEdge[src]; !exists { - d.outboundEdge[src] = make(map[interface{}]struct{}) + if _, exists := d.outboundEdge[srcHash]; !exists { + d.outboundEdge[srcHash] = make(map[interface{}]struct{}) } // dst is a child of src - d.outboundEdge[src][dst] = struct{}{} + d.outboundEdge[srcHash][dstHash] = struct{}{} // prepare d.inboundEdge[dst], iff needed - if _, exists := d.inboundEdge[dst]; !exists { - d.inboundEdge[dst] = make(map[interface{}]struct{}) + if _, exists := d.inboundEdge[dstHash]; !exists { + d.inboundEdge[dstHash] = make(map[interface{}]struct{}) } // src is a parent of dst - d.inboundEdge[dst][src] = struct{}{} + d.inboundEdge[dstHash][srcHash] = struct{}{} // for dst and all its descendants delete cached ancestors for descendant := range descendants { delete(d.ancestorsCache, descendant) } - delete(d.ancestorsCache, dst) + delete(d.ancestorsCache, dstHash) // for src and all its ancestors delete cached descendants for ancestor := range ancestors { delete(d.descendantsCache, ancestor) } - delete(d.descendantsCache, src) + delete(d.descendantsCache, srcHash) return nil } @@ -254,21 +260,23 @@ func (d *DAG) IsEdge(srcID, dstID string) (bool, error) { return false, SrcDstEqualError{srcID, dstID} } - return d.isEdge(d.vertexIds[srcID], d.vertexIds[dstID]), nil + src := d.vertexIds[srcID] + dst := d.vertexIds[dstID] + return d.isEdge(d.hashVertex(src), d.hashVertex(dst)), nil } -func (d *DAG) isEdge(src, dst interface{}) bool { +func (d *DAG) isEdge(srcHash, dstHash interface{}) bool { - if _, exists := d.outboundEdge[src]; !exists { + if _, exists := d.outboundEdge[srcHash]; !exists { return false } - if _, exists := d.outboundEdge[src][dst]; !exists { + if _, exists := d.outboundEdge[srcHash][dstHash]; !exists { return false } - if _, exists := d.inboundEdge[dst]; !exists { + if _, exists := d.inboundEdge[dstHash]; !exists { return false } - if _, exists := d.inboundEdge[dst][src]; !exists { + if _, exists := d.inboundEdge[dstHash][srcHash]; !exists { return false } return true @@ -293,31 +301,33 @@ func (d *DAG) DeleteEdge(srcID, dstID string) error { } src := d.vertexIds[srcID] + srcHash := d.hashVertex(src) dst := d.vertexIds[dstID] + dstHash := d.hashVertex(dst) - if !d.isEdge(src, dst) { + if !d.isEdge(srcHash, dstHash) { return EdgeUnknownError{srcID, dstID} } // get descendents and ancestors as they are now - descendants := copyMap(d.getDescendants(src)) - ancestors := copyMap(d.getAncestors(dst)) + descendants := copyMap(d.getDescendants(srcHash)) + ancestors := copyMap(d.getAncestors(dstHash)) // delete outbound and inbound - delete(d.outboundEdge[src], dst) - delete(d.inboundEdge[dst], src) + delete(d.outboundEdge[srcHash], dstHash) + delete(d.inboundEdge[dstHash], srcHash) // for src and all its descendants delete cached ancestors for descendant := range descendants { delete(d.ancestorsCache, descendant) } - delete(d.ancestorsCache, src) + delete(d.ancestorsCache, srcHash) // for dst and all its ancestors delete cached descendants for ancestor := range ancestors { delete(d.descendantsCache, ancestor) } - delete(d.descendantsCache, dst) + delete(d.descendantsCache, dstHash) return nil } @@ -380,7 +390,8 @@ func (d *DAG) IsLeaf(id string) (bool, error) { func (d *DAG) isLeaf(id string) bool { v := d.vertexIds[id] - dstIDs, ok := d.outboundEdge[v] + vHash := d.hashVertex(v) + dstIDs, ok := d.outboundEdge[vHash] if !ok || len(dstIDs) == 0 { return true } @@ -396,11 +407,11 @@ func (d *DAG) GetRoots() map[string]interface{} { func (d *DAG) getRoots() map[string]interface{} { roots := make(map[string]interface{}) - for v := range d.vertices { - srcIDs, ok := d.inboundEdge[v] + for vHash := range d.vertices { + srcIDs, ok := d.inboundEdge[vHash] if !ok || len(srcIDs) == 0 { - id := d.vertices[v] - roots[id] = v + id := d.vertices[vHash] + roots[id] = vHash } } return roots @@ -419,7 +430,8 @@ func (d *DAG) IsRoot(id string) (bool, error) { func (d *DAG) isRoot(id string) bool { v := d.vertexIds[id] - srcIDs, ok := d.inboundEdge[v] + vHash := d.hashVertex(v) + srcIDs, ok := d.inboundEdge[vHash] if !ok || len(srcIDs) == 0 { return true } @@ -446,8 +458,9 @@ func (d *DAG) GetParents(id string) (map[string]interface{}, error) { return nil, err } v := d.vertexIds[id] + vHash := d.hashVertex(v) parents := make(map[string]interface{}) - for pv := range d.inboundEdge[v] { + for pv := range d.inboundEdge[vHash] { pid := d.vertices[pv] parents[pid] = pv } @@ -467,8 +480,9 @@ func (d *DAG) getChildren(id string) (map[string]interface{}, error) { return nil, err } v := d.vertexIds[id] + vHash := d.hashVertex(v) children := make(map[string]interface{}) - for cv := range d.outboundEdge[v] { + for cv := range d.outboundEdge[vHash] { cid := d.vertices[cv] children[cid] = cv } @@ -488,32 +502,33 @@ func (d *DAG) GetAncestors(id string) (map[string]interface{}, error) { return nil, err } v := d.vertexIds[id] + vHash := d.hashVertex(v) ancestors := make(map[string]interface{}) - for av := range d.getAncestors(v) { + for av := range d.getAncestors(vHash) { aid := d.vertices[av] ancestors[aid] = av } return ancestors, nil } -func (d *DAG) getAncestors(v interface{}) map[interface{}]struct{} { +func (d *DAG) getAncestors(vHash interface{}) map[interface{}]struct{} { // in the best case we have already a populated cache d.muCache.RLock() - cache, exists := d.ancestorsCache[v] + cache, exists := d.ancestorsCache[vHash] d.muCache.RUnlock() if exists { return cache } // lock this vertex to work on it exclusively - d.verticesLocked.lock(v) - defer d.verticesLocked.unlock(v) + d.verticesLocked.lock(vHash) + defer d.verticesLocked.unlock(vHash) // now as we have locked this vertex, check (again) that no one has // meanwhile populated the cache d.muCache.RLock() - cache, exists = d.ancestorsCache[v] + cache, exists = d.ancestorsCache[vHash] d.muCache.RUnlock() if exists { return cache @@ -522,7 +537,7 @@ func (d *DAG) getAncestors(v interface{}) map[interface{}]struct{} { // as there is no cache, we start from scratch and collect all ancestors locally cache = make(map[interface{}]struct{}) var mu sync.Mutex - if parents, ok := d.inboundEdge[v]; ok { + if parents, ok := d.inboundEdge[vHash]; ok { // for each parent collect its ancestors for parent := range parents { @@ -538,7 +553,7 @@ func (d *DAG) getAncestors(v interface{}) map[interface{}]struct{} { // remember the collected descendents d.muCache.Lock() - d.ancestorsCache[v] = cache + d.ancestorsCache[vHash] = cache d.muCache.Unlock() return cache } @@ -582,7 +597,8 @@ func (d *DAG) AncestorsWalker(id string) (chan string, chan bool, error) { go func() { d.muDAG.RLock() v := d.vertexIds[id] - d.walkAncestors(v, ids, signal) + vHash := d.hashVertex(v) + d.walkAncestors(vHash, ids, signal) d.muDAG.RUnlock() close(ids) close(signal) @@ -590,11 +606,11 @@ func (d *DAG) AncestorsWalker(id string) (chan string, chan bool, error) { return ids, signal, nil } -func (d *DAG) walkAncestors(v interface{}, ids chan string, signal chan bool) { +func (d *DAG) walkAncestors(vHash interface{}, ids chan string, signal chan bool) { var fifo []interface{} visited := make(map[interface{}]struct{}) - for parent := range d.inboundEdge[v] { + for parent := range d.inboundEdge[vHash] { visited[parent] = struct{}{} fifo = append(fifo, parent) } @@ -634,34 +650,34 @@ func (d *DAG) GetDescendants(id string) (map[string]interface{}, error) { return nil, err } v := d.vertexIds[id] - //return copyMap(d.getAncestors(v)), nil + vHash := d.hashVertex(v) descendants := make(map[string]interface{}) - for dv := range d.getDescendants(v) { + for dv := range d.getDescendants(vHash) { did := d.vertices[dv] descendants[did] = dv } return descendants, nil } -func (d *DAG) getDescendants(v interface{}) map[interface{}]struct{} { +func (d *DAG) getDescendants(vHash interface{}) map[interface{}]struct{} { // in the best case we have already a populated cache d.muCache.RLock() - cache, exists := d.descendantsCache[v] + cache, exists := d.descendantsCache[vHash] d.muCache.RUnlock() if exists { return cache } // lock this vertex to work on it exclusively - d.verticesLocked.lock(v) - defer d.verticesLocked.unlock(v) + d.verticesLocked.lock(vHash) + defer d.verticesLocked.unlock(vHash) // now as we have locked this vertex, check (again) that no one has // meanwhile populated the cache d.muCache.RLock() - cache, exists = d.descendantsCache[v] + cache, exists = d.descendantsCache[vHash] d.muCache.RUnlock() if exists { return cache @@ -671,7 +687,7 @@ func (d *DAG) getDescendants(v interface{}) map[interface{}]struct{} { // locally cache = make(map[interface{}]struct{}) var mu sync.Mutex - if children, ok := d.outboundEdge[v]; ok { + if children, ok := d.outboundEdge[vHash]; ok { // for each child use a goroutine to collect its descendants //var waitGroup sync.WaitGroup @@ -693,7 +709,7 @@ func (d *DAG) getDescendants(v interface{}) map[interface{}]struct{} { // remember the collected descendents d.muCache.Lock() - d.descendantsCache[v] = cache + d.descendantsCache[vHash] = cache d.muCache.Unlock() return cache } @@ -751,6 +767,7 @@ func (d *DAG) getRelativesGraph(id string, asc bool) (*DAG, string, error) { return nil, "", IDEmptyError{} } v, exists := d.vertexIds[id] + vHash := d.hashVertex(v) if !exists { return nil, "", IDUnknownError{id} } @@ -763,27 +780,27 @@ func (d *DAG) getRelativesGraph(id string, asc bool) (*DAG, string, error) { defer d.muDAG.RUnlock() // recursively add the current vertex and all its relatives - newId, err := d.getRelativesGraphRec(v, newDAG, make(map[interface{}]string), asc) + newId, err := d.getRelativesGraphRec(vHash, newDAG, make(map[interface{}]string), asc) return newDAG, newId, err } -func (d *DAG) getRelativesGraphRec(v interface{}, newDAG *DAG, visited map[interface{}]string, asc bool) (newId string, err error) { +func (d *DAG) getRelativesGraphRec(vHash interface{}, newDAG *DAG, visited map[interface{}]string, asc bool) (newId string, err error) { // copy this vertex to the new graph - if newId, err = newDAG.AddVertex(v); err != nil { + if newId, err = newDAG.AddVertex(vHash); err != nil { return } // mark this vertex as visited - visited[v] = newId + visited[vHash] = newId // get the direct relatives (depending on the direction either parents or children) var relatives map[interface{}]struct{} var ok bool if asc { - relatives, ok = d.inboundEdge[v] + relatives, ok = d.inboundEdge[vHash] } else { - relatives, ok = d.outboundEdge[v] + relatives, ok = d.outboundEdge[vHash] } // for all direct relatives in the original graph @@ -817,7 +834,7 @@ func (d *DAG) getRelativesGraphRec(v interface{}, newDAG *DAG, visited map[inter } // DescendantsWalker returns a channel and subsequently returns / walks all -// descendants of the vertex with id id in a breath first order. The second +// descendants of the vertex with id in a breath first order. The second // channel returned may be used to stop further walking. DescendantsWalker // returns an error, if id is empty or unknown. // @@ -834,7 +851,8 @@ func (d *DAG) DescendantsWalker(id string) (chan string, chan bool, error) { go func() { d.muDAG.RLock() v := d.vertexIds[id] - d.walkDescendants(v, ids, signal) + vHash := d.hashVertex(v) + d.walkDescendants(vHash, ids, signal) d.muDAG.RUnlock() close(ids) close(signal) @@ -842,10 +860,10 @@ func (d *DAG) DescendantsWalker(id string) (chan string, chan bool, error) { return ids, signal, nil } -func (d *DAG) walkDescendants(v interface{}, ids chan string, signal chan bool) { +func (d *DAG) walkDescendants(vHash interface{}, ids chan string, signal chan bool) { var fifo []interface{} visited := make(map[interface{}]struct{}) - for child := range d.outboundEdge[v] { + for child := range d.outboundEdge[vHash] { visited[child] = struct{}{} fifo = append(fifo, child) } @@ -1032,13 +1050,13 @@ func (d *DAG) ReduceTransitively() { } // for each vertex - for v := range d.vertices { + for vHash := range d.vertices { // map of descendants of the children of v descendentsOfChildrenOfV := make(map[interface{}]struct{}) // for each child of v - for childOfV := range d.outboundEdge[v] { + for childOfV := range d.outboundEdge[vHash] { // collect child descendants for descendent := range d.descendantsCache[childOfV] { @@ -1047,13 +1065,13 @@ func (d *DAG) ReduceTransitively() { } // for each child of v - for childOfV := range d.outboundEdge[v] { + for childOfV := range d.outboundEdge[vHash] { // remove the edge between v and child, iff child is a // descendant of any of the children of v if _, exists := descendentsOfChildrenOfV[childOfV]; exists { - delete(d.outboundEdge[v], childOfV) - delete(d.inboundEdge[childOfV], v) + delete(d.outboundEdge[vHash], childOfV) + delete(d.inboundEdge[childOfV], vHash) graphChanged = true } } @@ -1132,6 +1150,10 @@ func (d *DAG) saneID(id string) error { return nil } +func (d *DAG) hashVertex(v interface{}) interface{} { + return d.options.VertexHashFunc(v) +} + func copyMap(in map[interface{}]struct{}) map[interface{}]struct{} { out := make(map[interface{}]struct{}) for id, value := range in { diff --git a/marshal.go b/marshal.go index 6f4fa4f..b44ac11 100644 --- a/marshal.go +++ b/marshal.go @@ -42,12 +42,13 @@ func (d *DAG) UnmarshalJSON(_ []byte) error { // } // // For more specific information please read the test code. -func UnmarshalJSON(data []byte, wd StorableDAG) (*DAG, error) { +func UnmarshalJSON(data []byte, wd StorableDAG, options Options) (*DAG, error) { err := json.Unmarshal(data, &wd) if err != nil { return nil, err } dag := NewDAG() + dag.Options(options) for _, v := range wd.Vertices() { errVertex := dag.AddVertexByID(v.Vertex()) if errVertex != nil { diff --git a/marshal_test.go b/marshal_test.go index f2b72c8..2b1c1eb 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -53,7 +53,7 @@ func testMarshalUnmarshalJSON(t *testing.T, d *DAG, expected string) { } var wd testStorableDAG - dag, err := UnmarshalJSON(data, &wd) + dag, err := UnmarshalJSON(data, &wd, defaultOptions()) if err != nil { t.Fatal(err) } diff --git a/options.go b/options.go new file mode 100644 index 0000000..43d12f0 --- /dev/null +++ b/options.go @@ -0,0 +1,27 @@ +package dag + +// Options is the configuration for the DAG. +type Options struct { + // VertexHashFunc is the function that calculates the hash value of a vertex. + // This can be useful when the vertex contains not comparable types such as maps. + // If VertexHashFunc is nil, the defaultVertexHashFunc is used. + VertexHashFunc func(v interface{}) interface{} +} + +// Options sets the options for the DAG. +// Options must be called before any other method of the DAG is called. +func (d *DAG) Options(options Options) { + d.muDAG.Lock() + defer d.muDAG.Unlock() + d.options = options +} + +func defaultOptions() Options { + return Options{ + VertexHashFunc: defaultVertexHashFunc, + } +} + +func defaultVertexHashFunc(v interface{}) interface{} { + return v +} diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..c78d2fa --- /dev/null +++ b/options_test.go @@ -0,0 +1,207 @@ +package dag + +import ( + "encoding/json" + "testing" +) + +type testNonComparableVertexType struct { + ID string `json:"i"` + NotComparableField map[string]string `json:"v"` +} + +func TestOverrideVertexHashFunOption(t *testing.T) { + dag := NewDAG() + /* 1 4 + * |\ / + * | 2 + * |/ + * 3 + */ + + dag.Options(Options{ + VertexHashFunc: func(v interface{}) interface{} { + return v.(testNonComparableVertexType).ID + }}) + + testVertex1 := testNonComparableVertexType{ + ID: "1", + NotComparableField: map[string]string{"not": "comparable"}, + } + vertexId1, err := dag.addVertex(testVertex1) + if err != nil { + t.Errorf("Should create a vertex with a not comparable field when a correct VertexHashFunc option is set") + } + + testVertex2 := testNonComparableVertexType{ + ID: "2", + NotComparableField: map[string]string{"stillNot": "comparable"}, + } + vertexId2, err := dag.addVertex(testVertex2) + if err != nil { + t.Errorf("Should create a vertex with a not comparable field when a correct VertexHashFunc option is set") + } + err = dag.AddEdge(vertexId1, vertexId2) + if err != nil { + t.Errorf("Should create an edge between vertices with not comparable fields when a correct VertexHashFunc option is set") + } + + testVertex3 := testNonComparableVertexType{ + ID: "3", + NotComparableField: map[string]string{"stillNot": "comparable"}, + } + vertexId3, err := dag.addVertex(testVertex3) + if err != nil { + t.Errorf("Should create a vertex with a not comparable field when a correct VertexHashFunc option is set") + } + + err = dag.AddEdge(vertexId1, vertexId3) + if err != nil { + t.Errorf("Should create an edge between vertices with not comparable fields when a correct VertexHashFunc option is set") + } + err = dag.AddEdge(vertexId2, vertexId3) + if err != nil { + t.Errorf("Should create an edge between vertices with not comparable fields when a correct VertexHashFunc option is set") + } + + testVertex4 := testNonComparableVertexType{ + ID: "4", + NotComparableField: map[string]string{"stillNot": "comparable"}, + } + vertexId4, err := dag.addVertex(testVertex4) + if err != nil { + t.Errorf("Should create a vertex with a not comparable field when a correct VertexHashFunc option is set") + } + + err = dag.AddEdge(vertexId4, vertexId2) + if err != nil { + t.Errorf("Should create an edge between vertices with not comparable fields when a correct VertexHashFunc option is set") + } + + isEdge, err := dag.IsEdge(vertexId1, vertexId2) + if !isEdge || err != nil { + t.Errorf("Should return true for edge between vertices with not comparable fields when a correct VertexHashFunc option is set") + } + + err = dag.DeleteEdge(vertexId1, vertexId3) + if err != nil { + t.Errorf("Should delete an edge between vertices with not comparable fields when a correct VertexHashFunc option is set") + } + + isEdge, err = dag.IsEdge(vertexId1, vertexId3) + if isEdge || err != nil { + t.Errorf("Should return false for edge between vertices with not comparable fields when a correct VertexHashFunc option is set") + } + + err = dag.DeleteVertex(vertexId2) + if err != nil { + t.Errorf("Should delete a vertex with not comparable fields when a correct VertexHashFunc option is set") + } + + vertexId2, err = dag.addVertex(testVertex2) + if err != nil { + t.Errorf("Should create a vertex with a not comparable field when a correct VertexHashFunc option is set") + } + + _ = dag.AddEdge(vertexId1, vertexId2) + _ = dag.AddEdge(vertexId2, vertexId3) + err = dag.AddEdge(vertexId4, vertexId2) + if err != nil { + t.Errorf("Should create an edge between vertices with not comparable fields when a correct VertexHashFunc option is set") + } + + roots := dag.GetRoots() + if len(roots) != 2 { + t.Errorf("Should return 2 roots") + } + for rootId := range roots { + if isRoot, err := dag.IsRoot(rootId); !isRoot || err != nil { + t.Errorf("Should return true for root") + } + } + + leaves := dag.GetLeaves() + if len(leaves) != 1 { + t.Errorf("Should return 1 leaf") + } + for leafId := range leaves { + if isLeaf, err := dag.IsLeaf(leafId); !isLeaf || err != nil { + t.Errorf("Should return true for leaf") + } + } + + vertex2Parents, err := dag.GetParents(vertexId2) + if len(vertex2Parents) != 2 || err != nil { + t.Errorf("Should return 2 parents for vertex 2") + } + + vertex2Children, err := dag.GetChildren(vertexId2) + if len(vertex2Children) != 1 || err != nil { + t.Errorf("Should return 1 child for vertex 2") + } + + vertex3Ancestors, err := dag.GetAncestors(vertexId3) + if len(vertex3Ancestors) != 3 || err != nil { + t.Errorf("Should return 3 ancestors for vertex 3, received %d", len(vertex3Ancestors)) + } + + vertex3OrderedAncestors, err := dag.GetOrderedAncestors(vertexId3) + if len(vertex3OrderedAncestors) != 3 || err != nil { + t.Errorf("Should return 3 ancestors for vertex 3, received %d", len(vertex3OrderedAncestors)) + } + + vertex4Descendants, err := dag.GetDescendants(vertexId4) + if len(vertex4Descendants) != 2 || err != nil { + t.Errorf("Should return 2 descendants for vertex 4, received %d", len(vertex4Descendants)) + } + + vertex4OrderedDescendants, err := dag.GetOrderedDescendants(vertexId4) + if len(vertex4OrderedDescendants) != 2 || err != nil { + t.Errorf("Should return 2 descendants for vertex 4, received %d", len(vertex4OrderedDescendants)) + } + + _, _, err = dag.GetDescendantsGraph(vertexId1) + if err != nil { + t.Errorf("Should return a string representation of the descendants graph") + } + + _, _, err = dag.GetAncestorsGraph(vertexId1) + if err != nil { + t.Errorf("Should return a string representation of the ancestors graph") + } + + _, err = dag.Copy() + if err != nil { + t.Errorf("Should return a copy of the DAG") + } + + dagString := dag.String() + if dagString == "" { + t.Errorf("Should return a string representation of the DAG") + } + + dag.ReduceTransitively() + dag.FlushCaches() + dag.DescendantsWalker(vertexId1) + + mv := newMarshalVisitor(dag) + dag.DFSWalk(mv) + dag.BFSWalk(mv) + dag.OrderedWalk(mv) + + _, err = dag.MarshalJSON() + if err != nil { + t.Error(err) + } + + data, err := json.Marshal(dag) + if err != nil { + t.Error(err) + } + + var wd testNonComparableStorableDAG + _, err = UnmarshalJSON(data, &wd, dag.options) + if err != nil { + t.Fatal(err) + } +} diff --git a/storage_test.go b/storage_test.go index d268e9e..4d77921 100644 --- a/storage_test.go +++ b/storage_test.go @@ -33,3 +33,33 @@ func (g testStorableDAG) Edges() []Edger { } return l } + +type testNonComparableStorableVertex struct { + Id string `json:"i"` + NotComparableVertex testNonComparableVertexType `json:"v"` +} + +func (tv testNonComparableStorableVertex) Vertex() (id string, value interface{}) { + return tv.Id, tv.NotComparableVertex +} + +type testNonComparableStorableDAG struct { + StorableVertices []testNonComparableStorableVertex `json:"vs"` + StorableEdges []storableEdge `json:"es"` +} + +func (g testNonComparableStorableDAG) Vertices() []Vertexer { + l := make([]Vertexer, 0, len(g.StorableVertices)) + for _, v := range g.StorableVertices { + l = append(l, v) + } + return l +} + +func (g testNonComparableStorableDAG) Edges() []Edger { + l := make([]Edger, 0, len(g.StorableEdges)) + for _, v := range g.StorableEdges { + l = append(l, v) + } + return l +} diff --git a/visitor.go b/visitor.go index f73e64d..613968f 100644 --- a/visitor.go +++ b/visitor.go @@ -25,7 +25,7 @@ func (d *DAG) DFSWalk(visitor Visitor) { vertices := d.getRoots() for _, id := range reversedVertexIDs(vertices) { - v := vertices[id] + v := d.vertexIds[id] sv := storableVertex{WrappedID: id, Value: v} stack.Push(sv) } @@ -43,7 +43,7 @@ func (d *DAG) DFSWalk(visitor Visitor) { vertices, _ := d.getChildren(sv.WrappedID) for _, id := range reversedVertexIDs(vertices) { - v := vertices[id] + v := d.vertexIds[id] sv := storableVertex{WrappedID: id, Value: v} stack.Push(sv) }