-
Notifications
You must be signed in to change notification settings - Fork 1
/
security.go
173 lines (151 loc) · 4.37 KB
/
security.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
package netjail
import (
"context"
"errors"
"net"
"net/netip"
"slices"
)
var (
// ErrDenied is an error returned by a dial function when the address is
// denied by security rules.
ErrDenied = errors.New("address not allowed")
)
// DialFunc is a type of function used to establish network connections.
//
// The function matches the signatures of standard functions like
// net.(*Dialer).DialContext or http.(*Transport).DialContext.
type DialFunc func(context.Context, string, string) (net.Conn, error)
// Resolver is an interface used to abstract the name resolver used by
// security rules to convert logical hostnames to IP addresses.
type Resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
// Rules is a set of rules used to determine whether a network address can be
// accessed.
//
// By default, the rules denies all addresses. Rules can be added to open
// networks, and further block subsets of the open address space.
type Rules struct {
Allow []netip.Prefix
Block []netip.Prefix
}
// Clone returns an independent copy of the rules whose Allow and Block
// slices do not share backing storage with the receiver. Cloning a nil
// *Rules yields nil.
func (rules *Rules) Clone() *Rules {
	if rules == nil {
		return nil
	}
	return &Rules{
		Allow: slices.Clone(rules.Allow),
		Block: slices.Clone(rules.Block),
	}
}
// String formats the rules as text, e.g. "ALLOW 10.0.0.0/8, BLOCK 10.1.0.0/16".
// It is a convenience wrapper around AppendTo.
func (rules *Rules) String() string {
	var buf []byte
	buf = rules.AppendTo(buf)
	return string(buf)
}
// AppendTo writes a textual form of the rules to the end of data and returns
// the extended slice. The Allow prefixes are written first (titled "ALLOW"),
// followed by the Block prefixes (titled "BLOCK"). Empty lists are omitted.
// A nil *Rules appends nothing.
func (rules *Rules) AppendTo(data []byte) []byte {
	if rules == nil {
		return data
	}
	data = appendPrefixes(data, "ALLOW ", rules.Allow)
	return appendPrefixes(data, "BLOCK ", rules.Block)
}
func appendPrefixes(data []byte, title string, prefixes []netip.Prefix) []byte {
if len(prefixes) > 0 {
if len(data) > 0 {
data = append(data, ',', ' ')
}
data = append(data, title...)
for _, prefix := range prefixes {
data = prefix.AppendTo(data)
data = append(data, ' ')
}
data = data[:len(data)-1]
}
return data
}
// Accept reports whether the given address is allowed by the security rules.
//
// An address is accepted when at least one Allow prefix contains it and no
// Block prefix does. A nil *Rules accepts nothing.
//
// IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are matched in both their
// mapped and unmapped forms. netip.Prefix.Contains never matches an
// IPv4-mapped address against an IPv4 prefix. Without this, an address such
// as ::ffff:127.0.0.1 would slip past a 127.0.0.0/8 block whenever an IPv6
// range like ::/0 is allowed, even though dialing it reaches 127.0.0.1.
func (rules *Rules) Accept(addr netip.Addr) bool {
	if rules == nil {
		return false
	}
	// For addresses that are not IPv4-mapped, unmapped == addr and the check
	// is exactly "in some Allow prefix and in no Block prefix".
	unmapped := addr.Unmap()
	return anyPrefixContains(rules.Allow, addr, unmapped) &&
		!anyPrefixContains(rules.Block, addr, unmapped)
}

// anyPrefixContains reports whether any of the prefixes contains addr or its
// unmapped form.
func anyPrefixContains(prefixes []netip.Prefix, addr, unmapped netip.Addr) bool {
	for _, prefix := range prefixes {
		if prefix.Contains(addr) || prefix.Contains(unmapped) {
			return true
		}
	}
	return false
}
// DialFunc returns a dial function using the given resolver and dialer to
// establish connections to addresses that are allowed by the security rules.
//
// When the host part of the address is an IP literal, it is checked directly
// and, if accepted, the original address is passed to the dialer unchanged.
// Otherwise the resolver converts the hostname to IP addresses. Each address
// is checked, and the dialer is given the accepted IP itself, so the name is
// not resolved again between the check and the connection. Accepted
// addresses are tried in the order the resolver returned them until one
// connects or the context is done. If every attempt fails, the first dial
// error is returned.
//
// Denied addresses produce a *net.OpError wrapping ErrDenied. A lookup that
// yields no addresses produces a *net.OpError wrapping a "no such host"
// *net.DNSError.
//
// If the resolver is nil, net.DefaultResolver is used.
//
// If the dialer is nil, a new dialer is created with the default options.
//
// The rules are cloned, so later changes to the receiver do not affect the
// returned function.
func (rules *Rules) DialFunc(rslv Resolver, dial DialFunc) DialFunc {
	if rslv == nil {
		rslv = net.DefaultResolver
	}
	if dial == nil {
		dial = new(net.Dialer).DialContext
	}
	// Clone the rules so we're resistant to buggy applications that would
	// modify the lists after the dial function has been created.
	rules = rules.Clone()

	return func(ctx context.Context, network, address string) (net.Conn, error) {
		dialError := func(err error, addr net.Addr) error {
			return &net.OpError{Op: "dial", Net: network, Addr: addr, Err: err}
		}
		denyError := func(addr netip.Addr) error {
			// Keep the IPv6 zone (if any) so the error names the exact target.
			return dialError(ErrDenied, &net.IPAddr{IP: addr.AsSlice(), Zone: addr.Zone()})
		}

		host, port, err := net.SplitHostPort(address)
		if err != nil {
			return nil, dialError(err, nil)
		}

		if addr, _ := netip.ParseAddr(host); addr.IsValid() {
			// Check the unmapped form so an IPv4-mapped literal such as
			// ::ffff:127.0.0.1 is judged by the IPv4 rules it really targets.
			if !rules.Accept(addr.Unmap()) {
				return nil, denyError(addr)
			}
			return dial(ctx, network, address)
		}

		addrs, err := rslv.LookupNetIP(ctx, ipnet(network), host)
		if err != nil {
			return nil, dialError(err, nil)
		}
		if len(addrs) == 0 {
			return nil, dialError(&net.DNSError{Err: "no such host", Name: host, IsNotFound: true}, nil)
		}

		var firstErr error
		for _, addr := range addrs {
			// Resolvers may report IPv4 results in IPv4-mapped IPv6 form.
			// Unmap them so the IPv4 rules apply and the dialer gets a plain
			// IPv4 address.
			addr = addr.Unmap()
			if !rules.Accept(addr) {
				continue
			}
			conn, dialErr := dial(ctx, network, net.JoinHostPort(addr.String(), port))
			if dialErr == nil {
				return conn, nil
			}
			if firstErr == nil {
				firstErr = dialErr
			}
			// Do not keep trying further addresses once the caller gave up.
			if ctx.Err() != nil {
				break
			}
		}
		if firstErr != nil {
			return nil, firstErr
		}
		return nil, denyError(addrs[0].Unmap())
	}
}
// ipnet maps a transport network name ("tcp", "udp", and their 4/6 variants)
// to the IP network name expected by Resolver.LookupNetIP. Any other network
// name is returned unchanged.
func ipnet(network string) string {
	switch network {
	case "tcp4", "udp4":
		return "ip4"
	case "tcp6", "udp6":
		return "ip6"
	case "tcp", "udp":
		return "ip"
	}
	return network
}