Skip to content

Commit

Permalink
feat: switch protocol base type to string
Browse files Browse the repository at this point in the history
  • Loading branch information
lvlcn-t committed May 20, 2024
1 parent 4af2e83 commit 5bc24af
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 56 deletions.
17 changes: 16 additions & 1 deletion internal/traceroute/icmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net"
"time"

"github.com/caas-team/sparrow/internal/logger"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
Expand All @@ -17,12 +18,17 @@ var _ hopper = (*icmpHopper)(nil)

type icmpHopper struct{ *tracer }

func (h *icmpHopper) Hop(_ context.Context, destAddr *net.IPAddr, _ uint16, ttl int) (hop Hop, err error) {
func (h *icmpHopper) Hop(ctx context.Context, destAddr *net.IPAddr, _ uint16, ttl int) (hop Hop, err error) {
log := logger.FromContext(ctx)
network, typ := h.resolveType(destAddr)
log.DebugContext(ctx, "Resolved network and ICMP type", "network", network, "type", typ)

recvConn, err := icmp.ListenPacket(network, "")
if err != nil {
log.ErrorContext(ctx, "Error creating ICMP listener", "error", err)
return hop, fmt.Errorf("error creating ICMP listener: %w", err)
}
log.DebugContext(ctx, "ICMP listener created", "address", recvConn.LocalAddr().String())
defer func() {
if cErr := recvConn.Close(); cErr != nil {
err = errors.Join(err, ErrClosingConn{Err: cErr})
Expand All @@ -31,8 +37,10 @@ func (h *icmpHopper) Hop(_ context.Context, destAddr *net.IPAddr, _ uint16, ttl

conn, err := h.newConn(network, destAddr, ttl)
if err != nil {
log.ErrorContext(ctx, "Error creating raw socket", "error", err)
return hop, fmt.Errorf("error creating raw socket: %w", err)
}
log.DebugContext(ctx, "Raw socket created", "address", destAddr.String(), "ttl", ttl)
defer func() {
if cErr := conn.Close(); cErr != nil {
err = errors.Join(err, ErrClosingConn{Err: cErr})
Expand All @@ -48,20 +56,25 @@ func (h *icmpHopper) Hop(_ context.Context, destAddr *net.IPAddr, _ uint16, ttl
Data: []byte("HELLO-R-U-THERE"),
},
}); err != nil {
log.ErrorContext(ctx, "Error sending ICMP message", "error", err)
return hop, fmt.Errorf("error sending ICMP message: %w", err)
}
log.DebugContext(ctx, "ICMP message sent", "address", destAddr.String(), "ttl", ttl)

recvBuffer := make([]byte, bufferSize)
err = recvConn.SetReadDeadline(time.Now().Add(h.Timeout))
if err != nil {
log.ErrorContext(ctx, "Error setting read deadline", "error", err)
return hop, fmt.Errorf("error setting read deadline: %w", err)
}

hop, err = h.receive(recvConn, recvBuffer, start)
hop.Tracepoint = ttl
if err != nil {
log.ErrorContext(ctx, "Error receiving ICMP message", "error", err)
return hop, err
}
log.DebugContext(ctx, "ICMP message received", "address", destAddr.String(), "ttl", ttl, "hop", hop)

return hop, nil
}
Expand All @@ -76,6 +89,8 @@ func (*icmpHopper) resolveType(destAddr *net.IPAddr) (network string, typ icmp.T

// newConn creates a raw connection to the given address with the specified TTL
func (*icmpHopper) newConn(network string, destAddr *net.IPAddr, ttl int) (*net.IPConn, error) {
// Unfortunately, the net package does not provide a context-aware DialIP function
// TODO: Switch to the net.DialIPContext function as soon as https://github.com/golang/go/issues/49097 is implemented
conn, err := net.DialIP(network, nil, destAddr)
if err != nil {
return nil, err
Expand Down
77 changes: 58 additions & 19 deletions internal/traceroute/traceroute.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,59 @@ import (
"sort"
"sync"
"time"

"github.com/caas-team/sparrow/internal/helper"
)

var _ Tracer = (*tracer)(nil)

// Protocol defines the protocol used for the traceroute
type Protocol int
type Protocol string

// String returns the string representation of the protocol
func (p Protocol) String() string {
return string(p)
}

// Validate validates the protocol
func (p Protocol) Validate() error {
if p == "" {
return fmt.Errorf("protocol cannot be empty")
}

isValid := false
for _, proto := range Protocols() {
if p == proto {
isValid = true
break
}
}

if !isValid {
return fmt.Errorf("invalid protocol %q, must be one of %v", p, Protocols())
}

if p == ICMP || p == UDP {
if !helper.HasCapabilities(helper.CAP_NET_RAW) {
return fmt.Errorf("protocol %q requires either elevated capabilities (CAP_NET_RAW) or running as root", p)
}
}

return nil
}

// Protocols returns the list of supported protocols
func Protocols() []Protocol {
return []Protocol{ICMP, UDP, TCP}
}

const (
// ICMP represents the ICMP protocol
ICMP Protocol = iota
ICMP Protocol = "icmp"
// UDP represents the UDP protocol
UDP
UDP Protocol = "udp"
// TCP represents the TCP protocol
TCP
TCP Protocol = "tcp"
// bufferSize represents the buffer size for the received data
bufferSize = 1500
)
Expand Down Expand Up @@ -206,6 +245,20 @@ type hopper interface {
Hop(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (Hop, error)
}

// newHopper returns the hopper based on the protocol
func (t *tracer) newHopper() hopper {
switch t.Protocol {
case ICMP:
return &icmpHopper{tracer: t}
case UDP:
return &udpHopper{icmpHopper: &icmpHopper{tracer: t}}
case TCP:
return &tcpHopper{tracer: t}
default:
return nil
}
}

// hop performs a single hop in the traceroute to the given address with the specified TTL.
func (t *tracer) hop(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (Hop, error) {
ctx, cancel := context.WithTimeout(ctx, t.Timeout)
Expand All @@ -220,22 +273,8 @@ func (t *tracer) hop(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl
default:
h := t.newHopper()
if h == nil {
return Hop{}, fmt.Errorf("unsupported protocol: %d", t.Protocol)
return Hop{}, fmt.Errorf("unsupported protocol: %q", t.Protocol)
}
return h.Hop(ctx, destAddr, port, ttl)
}
}

// newHopper returns the hopper based on the protocol
func (t *tracer) newHopper() hopper {
switch t.Protocol {
case ICMP:
return &icmpHopper{tracer: t}
case UDP:
return &udpHopper{icmpHopper: &icmpHopper{tracer: t}}
case TCP:
return &tcpHopper{tracer: t}
default:
return nil
}
}
45 changes: 10 additions & 35 deletions pkg/checks/traceroute/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,12 @@ import (
"github.com/caas-team/sparrow/pkg/checks"
)

// ProtocolType is the type for the protocol to use for the traceroute
type ProtocolType string

// String returns the string representation of the protocol type
func (p ProtocolType) String() string {
return string(p)
}

// ToProtocol converts the protocol type to a traceroute protocol
func (p ProtocolType) ToProtocol() traceroute.Protocol {
switch p {
case "icmp":
return traceroute.ICMP
case "udp":
return traceroute.UDP
case "tcp":
return traceroute.TCP
default:
return -1
}
}

// Config is the configuration for the traceroute check
type Config struct {
// Targets is a list of targets to traceroute to
Targets []Target `json:"targets" yaml:"targets" mapstructure:"targets"`
// Protocol is the protocol to use for the traceroute
Protocol ProtocolType `json:"protocol" yaml:"protocol" mapstructure:"protocol"`
Protocol traceroute.Protocol `json:"protocol" yaml:"protocol" mapstructure:"protocol"`
// Interval is the time to wait between check iterations
Interval time.Duration `json:"interval" yaml:"interval" mapstructure:"interval"`
// Timeout is the maximum time to wait for a response from a hop
Expand All @@ -62,14 +40,8 @@ func (c *Config) For() string {
}

func (c *Config) Validate() error {
switch c.Protocol {
case "tcp":
case "icmp", "udp":
if !helper.HasCapabilities(helper.CAP_NET_RAW) {
return checks.ErrInvalidConfig{CheckName: CheckName, Field: "traceroute.protocol", Reason: fmt.Sprintf("protocol %q requires either elevated capabilities (CAP_NET_RAW) or running as root", c.Protocol)}
}
default:
return checks.ErrInvalidConfig{CheckName: CheckName, Field: "traceroute.protocol", Reason: "must be one of 'icmp', 'udp', 'tcp'"}
if err := c.Protocol.Validate(); err != nil {
return checks.ErrInvalidConfig{CheckName: CheckName, Field: "traceroute.protocol", Reason: err.Error()}
}

if c.Timeout <= 0 {
Expand All @@ -80,15 +52,18 @@ func (c *Config) Validate() error {
}

for i, t := range c.Targets {
ip := net.ParseIP(t.Addr)
if ip != nil {
if ip := net.ParseIP(t.Addr); ip != nil {
continue
}

_, err := url.Parse(t.Addr)
if err != nil {
if _, err := url.Parse(t.Addr); err != nil {
return checks.ErrInvalidConfig{CheckName: CheckName, Field: fmt.Sprintf("traceroute.targets[%d].addr", i), Reason: "invalid url or ip"}
}

if t.Port == 0 {
c.Targets[i].Port = 80
}
}

return nil
}
2 changes: 1 addition & 1 deletion pkg/checks/traceroute/traceroute.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (tr *Traceroute) SetConfig(config checks.Runtime) error {
if err != nil {
return err
}
tr.tracer = traceroute.New(tr.Config.MaxHops, tr.Config.Timeout, tr.Config.Protocol.ToProtocol())
tr.tracer = traceroute.New(tr.Config.MaxHops, tr.Config.Timeout, tr.Config.Protocol)
return nil
}

Expand Down

0 comments on commit 5bc24af

Please sign in to comment.