From 5bc24af3ab7760fb7985e4e145df971d2cf78b26 Mon Sep 17 00:00:00 2001 From: lvlcn-t <75443136+lvlcn-t@users.noreply.github.com> Date: Mon, 20 May 2024 22:06:46 +0200 Subject: [PATCH] feat: switch protocol base type to string --- internal/traceroute/icmp.go | 17 ++++++- internal/traceroute/traceroute.go | 77 ++++++++++++++++++++++------- pkg/checks/traceroute/config.go | 45 ++++------------- pkg/checks/traceroute/traceroute.go | 2 +- 4 files changed, 85 insertions(+), 56 deletions(-) diff --git a/internal/traceroute/icmp.go b/internal/traceroute/icmp.go index fa67e6dd..ad057beb 100644 --- a/internal/traceroute/icmp.go +++ b/internal/traceroute/icmp.go @@ -7,6 +7,7 @@ import ( "net" "time" + "github.com/caas-team/sparrow/internal/logger" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -17,12 +18,17 @@ var _ hopper = (*icmpHopper)(nil) type icmpHopper struct{ *tracer } -func (h *icmpHopper) Hop(_ context.Context, destAddr *net.IPAddr, _ uint16, ttl int) (hop Hop, err error) { +func (h *icmpHopper) Hop(ctx context.Context, destAddr *net.IPAddr, _ uint16, ttl int) (hop Hop, err error) { + log := logger.FromContext(ctx) network, typ := h.resolveType(destAddr) + log.DebugContext(ctx, "Resolved network and ICMP type", "network", network, "type", typ) + recvConn, err := icmp.ListenPacket(network, "") if err != nil { + log.ErrorContext(ctx, "Error creating ICMP listener", "error", err) return hop, fmt.Errorf("error creating ICMP listener: %w", err) } + log.DebugContext(ctx, "ICMP listener created", "address", recvConn.LocalAddr().String()) defer func() { if cErr := recvConn.Close(); cErr != nil { err = errors.Join(err, ErrClosingConn{Err: cErr}) @@ -31,8 +37,10 @@ func (h *icmpHopper) Hop(_ context.Context, destAddr *net.IPAddr, _ uint16, ttl conn, err := h.newConn(network, destAddr, ttl) if err != nil { + log.ErrorContext(ctx, "Error creating raw socket", "error", err) return hop, fmt.Errorf("error creating raw socket: %w", err) } + log.DebugContext(ctx, "Raw socket created", "address", destAddr.String(), "ttl", ttl) defer func() { if cErr := conn.Close(); cErr != nil { err = errors.Join(err, ErrClosingConn{Err: cErr}) @@ -48,20 +56,25 @@ func (h *icmpHopper) Hop(_ context.Context, destAddr *net.IPAddr, _ uint16, ttl Data: []byte("HELLO-R-U-THERE"), }, }); err != nil { + log.ErrorContext(ctx, "Error sending ICMP message", "error", err) return hop, fmt.Errorf("error sending ICMP message: %w", err) } + log.DebugContext(ctx, "ICMP message sent", "address", destAddr.String(), "ttl", ttl) recvBuffer := make([]byte, bufferSize) err = recvConn.SetReadDeadline(time.Now().Add(h.Timeout)) if err != nil { + log.ErrorContext(ctx, "Error setting read deadline", "error", err) return hop, fmt.Errorf("error setting read deadline: %w", err) } hop, err = h.receive(recvConn, recvBuffer, start) hop.Tracepoint = ttl if err != nil { + log.ErrorContext(ctx, "Error receiving ICMP message", "error", err) return hop, err } + log.DebugContext(ctx, "ICMP message received", "address", destAddr.String(), "ttl", ttl, "hop", hop) return hop, nil } @@ -76,6 +89,8 @@ func (*icmpHopper) resolveType(destAddr *net.IPAddr) (network string, typ icmp.T // newConn creates a raw connection to the given address with the specified TTL func (*icmpHopper) newConn(network string, destAddr *net.IPAddr, ttl int) (*net.IPConn, error) { + // Unfortunately, the net package does not provide a context-aware DialIP function + // TODO: Switch to the net.DialIPContext function as soon as https://github.com/golang/go/issues/49097 is implemented conn, err := net.DialIP(network, nil, destAddr) if err != nil { return nil, err diff --git a/internal/traceroute/traceroute.go b/internal/traceroute/traceroute.go index 2c489b53..e9749873 100644 --- a/internal/traceroute/traceroute.go +++ b/internal/traceroute/traceroute.go @@ -7,20 +7,59 @@ import ( "sort" "sync" "time" + + "github.com/caas-team/sparrow/internal/helper" ) var _ Tracer = (*tracer)(nil) // Protocol defines the protocol used for the traceroute -type Protocol int +type Protocol string + +// String returns the string representation of the protocol +func (p Protocol) String() string { + return string(p) +} + +// Validate validates the protocol +func (p Protocol) Validate() error { + if p == "" { + return fmt.Errorf("protocol cannot be empty") + } + + isValid := false + for _, proto := range Protocols() { + if p == proto { + isValid = true + break + } + } + + if !isValid { + return fmt.Errorf("invalid protocol %q, must be one of %v", p, Protocols()) + } + + if p == ICMP || p == UDP { + if !helper.HasCapabilities(helper.CAP_NET_RAW) { + return fmt.Errorf("protocol %q requires either elevated capabilities (CAP_NET_RAW) or running as root", p) + } + } + + return nil +} + +// Protocols returns the list of supported protocols +func Protocols() []Protocol { + return []Protocol{ICMP, UDP, TCP} +} const ( // ICMP represents the ICMP protocol - ICMP Protocol = iota + ICMP Protocol = "icmp" // UDP represents the UDP protocol - UDP + UDP Protocol = "udp" // TCP represents the TCP protocol - TCP + TCP Protocol = "tcp" // bufferSize represents the buffer size for the received data bufferSize = 1500 ) @@ -206,6 +245,20 @@ type hopper interface { Hop(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (Hop, error) } +// newHopper returns the hopper based on the protocol +func (t *tracer) newHopper() hopper { + switch t.Protocol { + case ICMP: + return &icmpHopper{tracer: t} + case UDP: + return &udpHopper{icmpHopper: &icmpHopper{tracer: t}} + case TCP: + return &tcpHopper{tracer: t} + default: + return nil + } +} + // hop performs a single hop in the traceroute to the given address with the specified TTL. func (t *tracer) hop(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (Hop, error) { ctx, cancel := context.WithTimeout(ctx, t.Timeout) @@ -220,22 +273,8 @@ func (t *tracer) hop(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl default: h := t.newHopper() if h == nil { - return Hop{}, fmt.Errorf("unsupported protocol: %d", t.Protocol) + return Hop{}, fmt.Errorf("unsupported protocol: %q", t.Protocol) } return h.Hop(ctx, destAddr, port, ttl) } } - -// newHopper returns the hopper based on the protocol -func (t *tracer) newHopper() hopper { - switch t.Protocol { - case ICMP: - return &icmpHopper{tracer: t} - case UDP: - return &udpHopper{icmpHopper: &icmpHopper{tracer: t}} - case TCP: - return &tcpHopper{tracer: t} - default: - return nil - } -} diff --git a/pkg/checks/traceroute/config.go b/pkg/checks/traceroute/config.go index d455d58f..b6269a23 100644 --- a/pkg/checks/traceroute/config.go +++ b/pkg/checks/traceroute/config.go @@ -11,34 +11,12 @@ import ( "github.com/caas-team/sparrow/pkg/checks" ) -// ProtocolType is the type for the protocol to use for the traceroute -type ProtocolType string - -// String returns the string representation of the protocol type -func (p ProtocolType) String() string { - return string(p) -} - -// ToProtocol converts the protocol type to a traceroute protocol -func (p ProtocolType) ToProtocol() traceroute.Protocol { - switch p { - case "icmp": - return traceroute.ICMP - case "udp": - return traceroute.UDP - case "tcp": - return traceroute.TCP - default: - return -1 - } -} - // Config is the configuration for the traceroute check type Config struct { // Targets is a list of targets to traceroute to Targets []Target `json:"targets" yaml:"targets" mapstructure:"targets"` // Protocol is the protocol to use for the traceroute - Protocol ProtocolType `json:"protocol" yaml:"protocol" mapstructure:"protocol"` + Protocol traceroute.Protocol `json:"protocol" yaml:"protocol" mapstructure:"protocol"` // Interval is the time to wait between check iterations Interval time.Duration `json:"interval" yaml:"interval" mapstructure:"interval"` // Timeout is the maximum time to wait for a response from a hop @@ -62,14 +40,8 @@ func (c *Config) For() string { } func (c *Config) Validate() error { - switch c.Protocol { - case "tcp": - case "icmp", "udp": - if !helper.HasCapabilities(helper.CAP_NET_RAW) { - return checks.ErrInvalidConfig{CheckName: CheckName, Field: "traceroute.protocol", Reason: fmt.Sprintf("protocol %q requires either elevated capabilities (CAP_NET_RAW) or running as root", c.Protocol)} - } - default: - return checks.ErrInvalidConfig{CheckName: CheckName, Field: "traceroute.protocol", Reason: "must be one of 'icmp', 'udp', 'tcp'"} + if err := c.Protocol.Validate(); err != nil { + return checks.ErrInvalidConfig{CheckName: CheckName, Field: "traceroute.protocol", Reason: err.Error()} } if c.Timeout <= 0 { @@ -80,15 +52,18 @@ func (c *Config) Validate() error { } for i, t := range c.Targets { - ip := net.ParseIP(t.Addr) - if ip != nil { + if ip := net.ParseIP(t.Addr); ip != nil { continue } - _, err := url.Parse(t.Addr) - if err != nil { + if _, err := url.Parse(t.Addr); err != nil { return checks.ErrInvalidConfig{CheckName: CheckName, Field: fmt.Sprintf("traceroute.targets[%d].addr", i), Reason: "invalid url or ip"} } + + if t.Port == 0 { + c.Targets[i].Port = 80 + } } + return nil } diff --git a/pkg/checks/traceroute/traceroute.go b/pkg/checks/traceroute/traceroute.go index ed3455ca..30115dc2 100644 --- a/pkg/checks/traceroute/traceroute.go +++ b/pkg/checks/traceroute/traceroute.go @@ -57,7 +57,7 @@ func (tr *Traceroute) SetConfig(config checks.Runtime) error { if err != nil { return err } - tr.tracer = traceroute.New(tr.Config.MaxHops, tr.Config.Timeout, tr.Config.Protocol.ToProtocol()) + tr.tracer = traceroute.New(tr.Config.MaxHops, tr.Config.Timeout, tr.Config.Protocol) return nil }