Skip to content

Commit

Permalink
preffer public IP if any forwwarded
Browse files Browse the repository at this point in the history
  • Loading branch information
umputun committed Nov 26, 2023
1 parent 28c47e1 commit 6fdaabb
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 9 deletions.
53 changes: 51 additions & 2 deletions app/proxy/only_from.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package proxy

import (
"bytes"
"net"
"net/http"
"strings"
Expand Down Expand Up @@ -69,8 +70,10 @@ func (o *OnlyFrom) realIP(ipLookups []OFLookup, r *http.Request) string {

if lookup == OFForwarded && forwardedFor != "" {
// X-Forwarded-For is potentially a list of addresses separated with ","
parts := strings.Split(forwardedFor, ",")
return strings.TrimSpace(parts[len(parts)-1])
// The left-most being the original client, and each successive proxy that passed the request
// adding the IP address where it received the request from.
// In case if the original IP is a private behind a proxy, we need to get the first public IP from the list
return preferPublicIP(strings.Split(forwardedFor, ","))
}

if lookup == OFRealIP && realIP != "" {
Expand Down Expand Up @@ -98,3 +101,49 @@ func (o *OnlyFrom) matchRemoteIP(remoteIP string, allowedIPs []string) bool {
}
return false
}

// preferPublicIP returns first public IP from the list of IPs
// if no public IP found, returns first IP from the list
func preferPublicIP(ips []string) string {
for _, ip := range ips {
ip = strings.TrimSpace(ip)
if net.ParseIP(ip).IsGlobalUnicast() && !isPrivateSubnet(net.ParseIP(ip)) {
return ip
}
}
return strings.TrimSpace(ips[0])
}

type ipRange struct {
start net.IP
end net.IP
}

var privateRanges = []ipRange{
{start: net.ParseIP("10.0.0.0"), end: net.ParseIP("10.255.255.255")},
{start: net.ParseIP("100.64.0.0"), end: net.ParseIP("100.127.255.255")},
{start: net.ParseIP("172.16.0.0"), end: net.ParseIP("172.31.255.255")},
{start: net.ParseIP("192.0.0.0"), end: net.ParseIP("192.0.0.255")},
{start: net.ParseIP("192.168.0.0"), end: net.ParseIP("192.168.255.255")},
{start: net.ParseIP("198.18.0.0"), end: net.ParseIP("198.19.255.255")},
{start: net.ParseIP("::1"), end: net.ParseIP("::1")},
{start: net.ParseIP("fc00::"), end: net.ParseIP("fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")},
{start: net.ParseIP("fe80::"), end: net.ParseIP("febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff")},
}

// isPrivateSubnet - check to see if this ip is in a private subnet
func isPrivateSubnet(ipAddress net.IP) bool {
inRange := func(r ipRange, ipAddress net.IP) bool {
// ensure the IPs are in the same format for comparison
ipAddress = ipAddress.To16()
r.start = r.start.To16()
r.end = r.end.To16()
return bytes.Compare(ipAddress, r.start) >= 0 && bytes.Compare(ipAddress, r.end) <= 0
}
for _, r := range privateRanges {
if inRange(r, ipAddress) {
return true
}
}
return false
}
7 changes: 7 additions & 0 deletions app/proxy/only_from_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ func TestOnlyFrom_Handler(t *testing.T) {
forwardedFor: "192.168.1.1",
expectedStatusCode: http.StatusOK,
},
{
name: "allowed IP with Forwarded lookup, mix private and public IPs",
lookups: []OFLookup{OFForwarded},
allowedIPs: []string{"8.8.8.8"},
forwardedFor: "192.168.1.1, 10.0.0.5, 8.8.8.8, 10.10.10.10",
expectedStatusCode: http.StatusOK,
},
{
name: "disallowed IP with Forwarded lookup",
lookups: []OFLookup{OFForwarded},
Expand Down
14 changes: 7 additions & 7 deletions app/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,22 +401,22 @@ func (h *Http) makeHTTPServer(addr string, router http.Handler) *http.Server {
}

func (h *Http) setXRealIP(r *http.Request) {

remoteIP := r.Header.Get("X-Forwarded-For")
if remoteIP == "" {
remoteIP = r.RemoteAddr
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
// use the left-most non-private client IP address
// if there is no any non-private IP address, use the left-most address
r.Header.Set("X-Real-IP", preferPublicIP(strings.Split(forwarded, ",")))
return
}

ip, _, err := net.SplitHostPort(remoteIP)
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return
}

userIP := net.ParseIP(ip)
if userIP == nil {
return
}
r.Header.Add("X-Real-IP", ip)
r.Header.Set("X-Real-IP", ip)
}

// discoveredServers gets the list of servers discovered by providers.
Expand Down

0 comments on commit 6fdaabb

Please sign in to comment.