diff --git a/internal/traceroute/icmp.go b/internal/traceroute/icmp.go index 14864c87..199d6729 100644 --- a/internal/traceroute/icmp.go +++ b/internal/traceroute/icmp.go @@ -1,30 +1,35 @@ package traceroute import ( + "context" "errors" "fmt" "net" - "os" "time" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.org/x/sys/unix" ) -func (t *tracer) hopICMP(destAddr *net.IPAddr, ttl int) (hop Hop, err error) { - network, icmpType := getNetworkAndICMPType(destAddr) - icmpConn, err := icmp.ListenPacket(network, "") +var _ hopper = (*icmpHopper)(nil) + +type icmpHopper struct{ *tracer } + +func (h *icmpHopper) Hop(_ context.Context, destAddr *net.IPAddr, _ uint16, ttl int) (hop Hop, err error) { + network, typ := h.resolveType(destAddr) + recvConn, err := icmp.ListenPacket(network, "") if err != nil { return hop, fmt.Errorf("error creating ICMP listener: %w", err) } defer func() { - if cErr := icmpConn.Close(); cErr != nil { + if cErr := recvConn.Close(); cErr != nil { err = errors.Join(err, ErrClosingConn{Err: cErr}) } }() - conn, err := createRawConn(network, destAddr, ttl) + conn, err := h.newConn(network, destAddr, ttl) if err != nil { return hop, fmt.Errorf("error creating raw socket: %w", err) } @@ -35,11 +40,11 @@ func (t *tracer) hopICMP(destAddr *net.IPAddr, ttl int) (hop Hop, err error) { }() start := time.Now() - if err = sendICMPMessage(conn, icmp.Message{ - Type: icmpType, + if err = h.send(conn, icmp.Message{ + Type: typ, Code: 0, Body: &icmp.Echo{ - ID: os.Getpid() & 0xffff, Seq: ttl, + ID: unix.Getpid() & 0xffff, Seq: ttl, Data: []byte("HELLO-R-U-THERE"), }, }); err != nil { @@ -47,12 +52,12 @@ func (t *tracer) hopICMP(destAddr *net.IPAddr, ttl int) (hop Hop, err error) { } recvBuffer := make([]byte, bufferSize) - err = icmpConn.SetReadDeadline(time.Now().Add(t.Timeout)) + err = recvConn.SetReadDeadline(time.Now().Add(h.Timeout)) if err != nil { return hop, fmt.Errorf("error setting read deadline: %w", err) } - hop, err = receiveICMPResponse(icmpConn, recvBuffer, start) + hop, err = h.receive(recvConn, recvBuffer, start) hop.Tracepoint = ttl if err != nil { return hop, err @@ -61,74 +66,74 @@ func (t *tracer) hopICMP(destAddr *net.IPAddr, ttl int) (hop Hop, err error) { return hop, nil } -// getNetworkAndICMPType returns the network and ICMP type based on the destination address -func getNetworkAndICMPType(destAddr *net.IPAddr) (string, icmp.Type) { +// resolveType returns the network and ICMP type based on the destination address +func (*icmpHopper) resolveType(destAddr *net.IPAddr) (network string, typ icmp.Type) { if destAddr.IP.To4() != nil { return "ip4:icmp", ipv4.ICMPTypeEcho } return "ip6:ipv6-icmp", ipv6.ICMPTypeEchoRequest } -// createRawConn creates a raw connection to the given address with the specified TTL -func createRawConn(network string, destAddr *net.IPAddr, ttl int) (*net.IPConn, error) { +// newConn creates a raw connection to the given address with the specified TTL +func (*icmpHopper) newConn(network string, destAddr *net.IPAddr, ttl int) (*net.IPConn, error) { conn, err := net.DialIP(network, nil, destAddr) if err != nil { return nil, err } if network == "ip4:icmp" { - pc := ipv4.NewPacketConn(conn) - if err := pc.SetControlMessage(ipv4.FlagTTL, true); err != nil { + cv4 := ipv4.NewPacketConn(conn) + if err := cv4.SetControlMessage(ipv4.FlagTTL, true); err != nil { return nil, err } - if err := pc.SetTTL(ttl); err != nil { + if err := cv4.SetTTL(ttl); err != nil { return nil, err } } else { - pc := ipv6.NewPacketConn(conn) - if err := pc.SetControlMessage(ipv6.FlagHopLimit, true); err != nil { + cv6 := ipv6.NewPacketConn(conn) + if err := cv6.SetControlMessage(ipv6.FlagHopLimit, true); err != nil { return nil, err } - if err := pc.SetHopLimit(ttl); err != nil { + if err := cv6.SetHopLimit(ttl); err != nil { return nil, err } } return conn, nil } -// sendICMPMessage sends an ICMP message to the given connection -func sendICMPMessage(conn *net.IPConn, wm icmp.Message) error { - wb, err := wm.Marshal(nil) +// send sends an ICMP message to the given connection +func (*icmpHopper) send(conn *net.IPConn, msg icmp.Message) error { + b, err := msg.Marshal(nil) if err != nil { return err } - _, err = conn.Write(wb) + _, err = conn.Write(b) return err } -// receiveICMPResponse reads the response from the ICMP connection -func receiveICMPResponse(icmpConn *icmp.PacketConn, recvBuffer []byte, start time.Time) (Hop, error) { +// receive reads the response from the ICMP connection +func (*icmpHopper) receive(conn *icmp.PacketConn, buffer []byte, start time.Time) (Hop, error) { hop := Hop{} - n, peer, err := icmpConn.ReadFrom(recvBuffer) + n, peer, err := conn.ReadFrom(buffer) if err != nil { return hop, fmt.Errorf("error reading from ICMP connection: %w", err) } hop.Duration = time.Since(start).Seconds() - rm, err := icmp.ParseMessage(1, recvBuffer[:n]) + pm, err := icmp.ParseMessage(1, buffer[:n]) if err != nil { return hop, fmt.Errorf("error parsing ICMP message: %w", err) } - switch rm.Type { + switch pm.Type { case ipv4.ICMPTypeTimeExceeded, ipv6.ICMPTypeTimeExceeded: hop.IP = peer.(*net.IPAddr).IP case ipv4.ICMPTypeEchoReply, ipv6.ICMPTypeEchoReply: hop.IP = peer.(*net.IPAddr).IP hop.ReachedTarget = true default: - return hop, fmt.Errorf("unexpected ICMP message type: %v", rm.Type) + return hop, fmt.Errorf("unexpected ICMP message type: %v", pm.Type) } return hop, nil diff --git a/internal/traceroute/tcp.go b/internal/traceroute/tcp.go index 48be684d..825602b7 100644 --- a/internal/traceroute/tcp.go +++ b/internal/traceroute/tcp.go @@ -13,109 +13,107 @@ import ( "golang.org/x/net/ipv6" ) -func (t *tracer) hopTCP(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (hop Hop, err error) { +var _ hopper = (*tcpHopper)(nil) + +type tcpHopper struct{ *tracer } + +func (h *tcpHopper) Hop(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (hop Hop, err error) { log := logger.FromContext(ctx) - conn, err := createTCPConn(ctx, destAddr, port, ttl) + conn, err := h.newConn(ctx, destAddr, port, ttl) if err != nil { - log.Error("Error creating TCP connection", "error", err) + log.ErrorContext(ctx, "Error creating TCP connection", "error", err) return hop, fmt.Errorf("error creating TCP connection: %w", err) } - log.Debug("TCP connection created", "address", destAddr.String(), "port", port, "ttl", ttl) + log.DebugContext(ctx, "TCP connection created", "address", destAddr.String(), "port", port, "ttl", ttl) defer func() { if cErr := conn.Close(); cErr != nil { - log.Error("Error closing TCP connection", "error", cErr) + log.ErrorContext(ctx, "Error closing TCP connection", "error", cErr) err = errors.Join(err, ErrClosingConn{Err: cErr}) } }() start := time.Now() - if err = sendTCPPacket(ctx, conn); err != nil { - log.Error("Error sending TCP packet", "error", err) + if err = h.sendSYN(ctx, conn); err != nil { + log.ErrorContext(ctx, "Error sending TCP packet", "error", err) return hop, fmt.Errorf("error sending TCP packet: %w", err) } - log.Debug("TCP packet sent", "address", destAddr.String(), "port", port, "ttl", ttl) + log.DebugContext(ctx, "TCP packet sent", "address", destAddr.String(), "port", port, "ttl", ttl) - hop, err = receiveTCPResponse(ctx, conn, t.Timeout, start) + hop, err = h.receive(ctx, conn, h.Timeout, start) hop.Tracepoint = ttl - log.Debug("TCP response received", "address", destAddr.String(), "port", port, "ttl", ttl, "hop", hop, "error", err) + log.DebugContext(ctx, "TCP response received", "address", destAddr.String(), "port", port, "ttl", ttl, "hop", hop, "error", err) return hop, err } -// createTCPConn creates a TCP connection to the given address with the specified TTL -func createTCPConn(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (*net.TCPConn, error) { +// newConn creates a TCP connection to the given address with the specified TTL +func (*tcpHopper) newConn(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (*net.TCPConn, error) { log := logger.FromContext(ctx) + // Unfortunately, the net package does not provide a context-aware DialTCP function + // TODO: Switch to the net.DialTCPContext function as soon as https://github.com/golang/go/issues/49097 is implemented conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: destAddr.IP, Port: int(port)}) if err != nil { - log.Error("Error dialing TCP connection", "error", err) + log.ErrorContext(ctx, "Error dialing TCP connection", "error", err) return nil, err } if destAddr.IP.To4() != nil { - pc := ipv4.NewConn(conn) - if err := pc.SetTTL(ttl); err != nil { - log.Error("Error setting TTL on IPv4 connection", "error", err) + cv4 := ipv4.NewConn(conn) + if err := cv4.SetTTL(ttl); err != nil { + log.ErrorContext(ctx, "Error setting TTL on IPv4 connection", "error", err) return nil, err } } else { - pc := ipv6.NewConn(conn) - if err := pc.SetHopLimit(ttl); err != nil { - log.Error("Error setting hop limit on IPv6 connection", "error", err) + cv6 := ipv6.NewConn(conn) + if err := cv6.SetHopLimit(ttl); err != nil { + log.ErrorContext(ctx, "Error setting hop limit on IPv6 connection", "error", err) return nil, err } } return conn, nil } -// sendTCPPacket sends a TCP SYN packet to the given destination -func sendTCPPacket(ctx context.Context, conn *net.TCPConn) error { +// sendSYN writes a TCP SYN packet to the given connection's file descriptor +func (*tcpHopper) sendSYN(ctx context.Context, conn *net.TCPConn) error { log := logger.FromContext(ctx) err := conn.SetWriteDeadline(time.Now().Add(1 * time.Second)) if err != nil { - log.Error("Error setting write deadline", "error", err) + log.ErrorContext(ctx, "Error setting write deadline", "error", err) return fmt.Errorf("error setting write deadline: %w", err) } - _, err = conn.Write([]byte("HELLO-R-U-THERE")) + // To initiate a TCP connection, we need to send a SYN packet + // In this case we want this to be as small as possible, so we send a single byte + _, err = conn.Write([]byte{0}) if err != nil { - log.Error("Error writing TCP packet", "error", err) + log.ErrorContext(ctx, "Error writing TCP packet", "error", err) return fmt.Errorf("error writing TCP packet: %w", err) } return nil } -// receiveTCPResponse waits for a TCP response to the sent SYN packet -func receiveTCPResponse(ctx context.Context, conn *net.TCPConn, timeout time.Duration, start time.Time) (Hop, error) { - hop := Hop{} +// receive waits for a TCP response to the sent SYN packet +func (*tcpHopper) receive(ctx context.Context, conn *net.TCPConn, timeout time.Duration, start time.Time) (Hop, error) { log := logger.FromContext(ctx) err := conn.SetReadDeadline(time.Now().Add(timeout)) if err != nil { - log.Error("Error setting read deadline", "error", err) - return hop, fmt.Errorf("error setting read deadline: %w", err) + log.ErrorContext(ctx, "Error setting read deadline", "error", err) + return Hop{}, fmt.Errorf("error setting read deadline: %w", err) } - buf := make([]byte, 1) - _, err = conn.Read(buf) + _, err = io.ReadAll(conn) if err != nil { - // Timeout means the TTL expired + // Timeout means the TTL expired and the packet was dropped by a router if nerr, ok := err.(net.Error); ok && nerr.Timeout() { - hop.Duration = time.Since(start).Seconds() - hop.IP = conn.RemoteAddr().(*net.TCPAddr).IP - return hop, nil - } - // EOF means the target sent a TCP RST response - if err == io.EOF { - hop.Duration = time.Since(start).Seconds() - hop.IP = conn.RemoteAddr().(*net.TCPAddr).IP - hop.ReachedTarget = true - return hop, nil + return Hop{Duration: time.Since(start).Seconds()}, nil } - log.Error("Error reading TCP response", "error", err) - return hop, fmt.Errorf("error reading TCP response: %w", err) + log.ErrorContext(ctx, "Error reading TCP response", "error", err) + return Hop{}, fmt.Errorf("error reading TCP response: %w", err) } - hop.Duration = time.Since(start).Seconds() - hop.IP = conn.RemoteAddr().(*net.TCPAddr).IP - hop.ReachedTarget = true - return hop, nil + return Hop{ + Duration: time.Since(start).Seconds(), + IP: conn.RemoteAddr().(*net.TCPAddr).IP, + ReachedTarget: true, + }, nil } diff --git a/internal/traceroute/traceroute.go b/internal/traceroute/traceroute.go index ffba68cb..4052cdef 100644 --- a/internal/traceroute/traceroute.go +++ b/internal/traceroute/traceroute.go @@ -2,7 +2,6 @@ package traceroute import ( "context" - "errors" "fmt" "net" "sort" @@ -202,6 +201,11 @@ func (t *tracer) filterHops(hops []Hop) []Hop { return filtered } +// hopper represents an interface for a hopper +type hopper interface { + Hop(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (Hop, error) +} + // hop performs a single hop in the traceroute to the given address with the specified TTL. func (t *tracer) hop(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (Hop, error) { ctx, cancel := context.WithTimeout(ctx, t.Timeout) @@ -214,15 +218,24 @@ func (t *tracer) hop(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl Error: fmt.Sprintf("timeout after %fs", t.Timeout.Seconds()), }, ctx.Err() default: - switch t.Protocol { - case ICMP: - return t.hopICMP(destAddr, ttl) - case UDP: - return Hop{}, errors.New("UDP not supported yet") - case TCP: - return t.hopTCP(ctx, destAddr, port, ttl) - default: - return Hop{}, errors.New("protocol not supported") + h := t.newHopper() + if h == nil { + return Hop{}, fmt.Errorf("unsupported protocol: %d", t.Protocol) } + return h.Hop(ctx, destAddr, port, ttl) + } +} + +// newHopper returns the hopper based on the protocol +func (t *tracer) newHopper() hopper { + switch t.Protocol { + case ICMP: + return &icmpHopper{tracer: t} + case UDP: + return &udpHopper{tracer: t} + case TCP: + return &tcpHopper{tracer: t} + default: + return nil } } diff --git a/internal/traceroute/udp.go b/internal/traceroute/udp.go index d4ec93c3..22359b0e 100644 --- a/internal/traceroute/udp.go +++ b/internal/traceroute/udp.go @@ -1 +1,15 @@ package traceroute + +import ( + "context" + "errors" + "net" +) + +var _ hopper = (*udpHopper)(nil) + +type udpHopper struct{ *tracer } + +func (h *udpHopper) Hop(_ context.Context, _ *net.IPAddr, _ uint16, _ int) (hop Hop, err error) { + return hop, errors.New("udp protocol is not supported yet") +}