Skip to content

Commit

Permalink
Add new CoveredNetworks option
Browse files Browse the repository at this point in the history
* Search by CIDR rather than just by IP.

Signed-off-by: Rob Adams <[email protected]>
  • Loading branch information
readams committed Dec 19, 2017
1 parent 9a60958 commit 8c56974
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 0 deletions.
18 changes: 18 additions & 0 deletions brute.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,24 @@ func (b *bruteRanger) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
return results, nil
}

// CoveredNetworks returns the list of RangerEntry(s) the given ipnet
// covers. That is, the networks that are completely subsumed by the
// specified network.
func (b *bruteRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
entries, err := b.getEntriesByVersion(network.IP)
if err != nil {
return nil, err
}
var results []RangerEntry
for _, entry := range entries {
entrynetwork := entry.Network()
if network.Contains(entrynetwork.IP) {
results = append(results, entry)
}
}
return results, nil
}

func (b *bruteRanger) getEntriesByVersion(ip net.IP) (map[string]RangerEntry, error) {
if ip.To4() != nil {
return b.ipV4Entries, nil
Expand Down
31 changes: 31 additions & 0 deletions brute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cidranger

import (
"net"
"sort"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -144,3 +145,33 @@ func TestContainingNetworks(t *testing.T) {
})
}
}

func TestCoveredNetworks(t *testing.T) {
for _, tc := range coveredNetworkTests {
t.Run(tc.name, func(t *testing.T) {
ranger := newBruteRanger()
for _, insert := range tc.inserts {
_, network, _ := net.ParseCIDR(insert)
err := ranger.Insert(NewBasicRangerEntry(*network))
assert.NoError(t, err)
}
var expectedEntries []string
for _, network := range tc.networks {
expectedEntries = append(expectedEntries, network)
}
sort.Strings(expectedEntries)
_, snet, _ := net.ParseCIDR(tc.search)
networks, err := ranger.CoveredNetworks(*snet)
assert.NoError(t, err)

var results []string
for _, result := range networks {
net := result.Network()
results = append(results, net.String())
}
sort.Strings(results)

assert.Equal(t, expectedEntries, results)
})
}
}
1 change: 1 addition & 0 deletions cidranger.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ type Ranger interface {
Remove(network net.IPNet) (RangerEntry, error)
Contains(ip net.IP) (bool, error)
ContainingNetworks(ip net.IP) ([]RangerEntry, error)
CoveredNetworks(network net.IPNet) ([]RangerEntry, error)
}

// NewPCTrieRanger returns a versionedRanger that supports both IPv4 and IPv6
Expand Down
33 changes: 33 additions & 0 deletions trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ func (p *prefixTrie) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
return p.containingNetworks(nn)
}

// CoveredNetworks returns the list of RangerEntry(s) the given ipnet
// covers. That is, the networks that are completely subsumed by the
// specified network.
func (p *prefixTrie) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
net := rnet.NewNetwork(network)
return p.coveredNetworks(net)
}

// String returns string representation of trie, mainly for visualization and
// debugging.
func (p *prefixTrie) String() string {
Expand Down Expand Up @@ -176,6 +184,31 @@ func (p *prefixTrie) containingNetworks(number rnet.NetworkNumber) ([]RangerEntr
return results, nil
}

func (p *prefixTrie) coveredNetworks(network rnet.Network) ([]RangerEntry, error) {
var results []RangerEntry
if p.hasEntry() && network.Contains(p.network.Number) {
results = []RangerEntry{p.entry}
}
if p.targetBitPosition() < 0 {
return results, nil
}

masked := network.Masked(int(p.numBitsSkipped))
if !masked.Equal(p.network) {
return results, nil
}
for _, child := range p.children {
if child != nil {
ranges, err := child.coveredNetworks(network)
if err != nil {
return nil, err
}
results = append(results, ranges...)
}
}
return results, nil
}

func (p *prefixTrie) insert(network rnet.Network, entry RangerEntry) error {
if p.network.Equal(network) {
p.entry = entry
Expand Down
86 changes: 86 additions & 0 deletions trie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,89 @@ func TestPrefixTrieContainingNetworks(t *testing.T) {
})
}
}

type coveredNetworkTest struct {
version rnet.IPVersion
inserts []string
search string
networks []string
name string
}

var coveredNetworkTests = []coveredNetworkTest{
{
rnet.IPv4,
[]string{"192.168.0.0/24"},
"192.168.0.0/16",
[]string{"192.168.0.0/24"},
"basic covered networks",
},
{
rnet.IPv4,
[]string{"192.168.0.0/24"},
"10.1.0.0/16",
nil,
"nothing",
},
{
rnet.IPv4,
[]string{"192.168.0.0/24", "192.168.0.0/25"},
"192.168.0.0/16",
[]string{"192.168.0.0/24", "192.168.0.0/25"},
"multiple networks",
},
{
rnet.IPv4,
[]string{"192.168.0.0/24", "192.168.0.0/25", "192.168.0.1/32"},
"192.168.0.0/16",
[]string{"192.168.0.0/24", "192.168.0.0/25", "192.168.0.1/32"},
"multiple networks 2",
},
{
rnet.IPv4,
[]string{"192.168.1.1/32"},
"192.168.0.0/16",
[]string{"192.168.1.1/32"},
"leaf",
},
{
rnet.IPv4,
[]string{"0.0.0.0/0", "192.168.1.1/32"},
"192.168.0.0/16",
[]string{"192.168.1.1/32"},
"leaf with root",
},
{
rnet.IPv4,
[]string{
"0.0.0.0/0", "192.168.0.0/24", "192.168.1.1/32",
"10.1.0.0/16", "10.1.1.0/24",
},
"192.168.0.0/16",
[]string{"192.168.0.0/24", "192.168.1.1/32"},
"path not taken",
},
}

func TestPrefixTrieCoveredNetworks(t *testing.T) {
for _, tc := range coveredNetworkTests {
t.Run(tc.name, func(t *testing.T) {
trie := newPrefixTree(tc.version)
for _, insert := range tc.inserts {
_, network, _ := net.ParseCIDR(insert)
err := trie.Insert(NewBasicRangerEntry(*network))
assert.NoError(t, err)
}
var expectedEntries []RangerEntry
for _, network := range tc.networks {
_, net, _ := net.ParseCIDR(network)
expectedEntries = append(expectedEntries,
NewBasicRangerEntry(*net))
}
_, snet, _ := net.ParseCIDR(tc.search)
networks, err := trie.CoveredNetworks(*snet)
assert.NoError(t, err)
assert.Equal(t, expectedEntries, networks)
})
}
}
8 changes: 8 additions & 0 deletions version.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ func (v *versionedRanger) ContainingNetworks(ip net.IP) ([]RangerEntry, error) {
return ranger.ContainingNetworks(ip)
}

func (v *versionedRanger) CoveredNetworks(network net.IPNet) ([]RangerEntry, error) {
ranger, err := v.getRangerForIP(network.IP)
if err != nil {
return nil, err
}
return ranger.CoveredNetworks(network)
}

func (v *versionedRanger) getRangerForIP(ip net.IP) (Ranger, error) {
if ip.To4() != nil {
return v.ipV4Ranger, nil
Expand Down

0 comments on commit 8c56974

Please sign in to comment.