Skip to content

Commit

Permalink
feat: add protocol wrapper and interface
Browse files Browse the repository at this point in the history
  • Loading branch information
lvlcn-t committed May 20, 2024
1 parent f0e5e9b commit 0994313
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 89 deletions.
67 changes: 36 additions & 31 deletions internal/traceroute/icmp.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,35 @@
package traceroute

import (
"context"
"errors"
"fmt"
"net"
"os"
"time"

"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)

func (t *tracer) hopICMP(destAddr *net.IPAddr, ttl int) (hop Hop, err error) {
network, icmpType := getNetworkAndICMPType(destAddr)
icmpConn, err := icmp.ListenPacket(network, "")
var _ hopper = (*icmpHopper)(nil)

type icmpHopper struct{ *tracer }

func (h *icmpHopper) Hop(_ context.Context, destAddr *net.IPAddr, _ uint16, ttl int) (hop Hop, err error) {
network, typ := h.resolveType(destAddr)
recvConn, err := icmp.ListenPacket(network, "")
if err != nil {
return hop, fmt.Errorf("error creating ICMP listener: %w", err)
}
defer func() {
if cErr := icmpConn.Close(); cErr != nil {
if cErr := recvConn.Close(); cErr != nil {
err = errors.Join(err, ErrClosingConn{Err: cErr})
}
}()

conn, err := createRawConn(network, destAddr, ttl)
conn, err := h.newConn(network, destAddr, ttl)
if err != nil {
return hop, fmt.Errorf("error creating raw socket: %w", err)
}
Expand All @@ -35,24 +40,24 @@ func (t *tracer) hopICMP(destAddr *net.IPAddr, ttl int) (hop Hop, err error) {
}()

start := time.Now()
if err = sendICMPMessage(conn, icmp.Message{
Type: icmpType,
if err = h.send(conn, icmp.Message{
Type: typ,
Code: 0,
Body: &icmp.Echo{
ID: os.Getpid() & 0xffff, Seq: ttl,
ID: unix.Getpid() & 0xffff, Seq: ttl,
Data: []byte("HELLO-R-U-THERE"),
},
}); err != nil {
return hop, fmt.Errorf("error sending ICMP message: %w", err)
}

recvBuffer := make([]byte, bufferSize)
err = icmpConn.SetReadDeadline(time.Now().Add(t.Timeout))
err = recvConn.SetReadDeadline(time.Now().Add(h.Timeout))
if err != nil {
return hop, fmt.Errorf("error setting read deadline: %w", err)
}

hop, err = receiveICMPResponse(icmpConn, recvBuffer, start)
hop, err = h.receive(recvConn, recvBuffer, start)
hop.Tracepoint = ttl
if err != nil {
return hop, err
Expand All @@ -61,74 +66,74 @@ func (t *tracer) hopICMP(destAddr *net.IPAddr, ttl int) (hop Hop, err error) {
return hop, nil
}

// getNetworkAndICMPType returns the network and ICMP type based on the destination address
func getNetworkAndICMPType(destAddr *net.IPAddr) (string, icmp.Type) {
// resolveType returns the network and ICMP type based on the destination address
func (*icmpHopper) resolveType(destAddr *net.IPAddr) (network string, typ icmp.Type) {
if destAddr.IP.To4() != nil {
return "ip4:icmp", ipv4.ICMPTypeEcho
}
return "ip6:ipv6-icmp", ipv6.ICMPTypeEchoRequest
}

// createRawConn creates a raw connection to the given address with the specified TTL
func createRawConn(network string, destAddr *net.IPAddr, ttl int) (*net.IPConn, error) {
// newConn creates a raw connection to the given address with the specified TTL
func (*icmpHopper) newConn(network string, destAddr *net.IPAddr, ttl int) (*net.IPConn, error) {
conn, err := net.DialIP(network, nil, destAddr)
if err != nil {
return nil, err
}

if network == "ip4:icmp" {
pc := ipv4.NewPacketConn(conn)
if err := pc.SetControlMessage(ipv4.FlagTTL, true); err != nil {
cv4 := ipv4.NewPacketConn(conn)
if err := cv4.SetControlMessage(ipv4.FlagTTL, true); err != nil {
return nil, err
}
if err := pc.SetTTL(ttl); err != nil {
if err := cv4.SetTTL(ttl); err != nil {
return nil, err
}
} else {
pc := ipv6.NewPacketConn(conn)
if err := pc.SetControlMessage(ipv6.FlagHopLimit, true); err != nil {
cv6 := ipv6.NewPacketConn(conn)
if err := cv6.SetControlMessage(ipv6.FlagHopLimit, true); err != nil {
return nil, err
}
if err := pc.SetHopLimit(ttl); err != nil {
if err := cv6.SetHopLimit(ttl); err != nil {
return nil, err
}
}
return conn, nil
}

// sendICMPMessage sends an ICMP message to the given connection
func sendICMPMessage(conn *net.IPConn, wm icmp.Message) error {
wb, err := wm.Marshal(nil)
// send sends an ICMP message to the given connection
func (*icmpHopper) send(conn *net.IPConn, msg icmp.Message) error {
b, err := msg.Marshal(nil)
if err != nil {
return err
}

_, err = conn.Write(wb)
_, err = conn.Write(b)
return err
}

// receiveICMPResponse reads the response from the ICMP connection
func receiveICMPResponse(icmpConn *icmp.PacketConn, recvBuffer []byte, start time.Time) (Hop, error) {
// receive reads the response from the ICMP connection
func (*icmpHopper) receive(conn *icmp.PacketConn, buffer []byte, start time.Time) (Hop, error) {
hop := Hop{}
n, peer, err := icmpConn.ReadFrom(recvBuffer)
n, peer, err := conn.ReadFrom(buffer)
if err != nil {
return hop, fmt.Errorf("error reading from ICMP connection: %w", err)
}
hop.Duration = time.Since(start).Seconds()

rm, err := icmp.ParseMessage(1, recvBuffer[:n])
pm, err := icmp.ParseMessage(1, buffer[:n])
if err != nil {
return hop, fmt.Errorf("error parsing ICMP message: %w", err)
}

switch rm.Type {
switch pm.Type {
case ipv4.ICMPTypeTimeExceeded, ipv6.ICMPTypeTimeExceeded:
hop.IP = peer.(*net.IPAddr).IP
case ipv4.ICMPTypeEchoReply, ipv6.ICMPTypeEchoReply:
hop.IP = peer.(*net.IPAddr).IP
hop.ReachedTarget = true
default:
return hop, fmt.Errorf("unexpected ICMP message type: %v", rm.Type)
return hop, fmt.Errorf("unexpected ICMP message type: %v", pm.Type)
}

return hop, nil
Expand Down
94 changes: 46 additions & 48 deletions internal/traceroute/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,109 +13,107 @@ import (
"golang.org/x/net/ipv6"
)

func (t *tracer) hopTCP(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (hop Hop, err error) {
var _ hopper = (*tcpHopper)(nil)

type tcpHopper struct{ *tracer }

func (h *tcpHopper) Hop(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (hop Hop, err error) {
log := logger.FromContext(ctx)
conn, err := createTCPConn(ctx, destAddr, port, ttl)
conn, err := h.newConn(ctx, destAddr, port, ttl)
if err != nil {
log.Error("Error creating TCP connection", "error", err)
log.ErrorContext(ctx, "Error creating TCP connection", "error", err)
return hop, fmt.Errorf("error creating TCP connection: %w", err)
}
log.Debug("TCP connection created", "address", destAddr.String(), "port", port, "ttl", ttl)
log.DebugContext(ctx, "TCP connection created", "address", destAddr.String(), "port", port, "ttl", ttl)
defer func() {
if cErr := conn.Close(); cErr != nil {
log.Error("Error closing TCP connection", "error", cErr)
log.ErrorContext(ctx, "Error closing TCP connection", "error", cErr)
err = errors.Join(err, ErrClosingConn{Err: cErr})
}
}()

start := time.Now()
if err = sendTCPPacket(ctx, conn); err != nil {
log.Error("Error sending TCP packet", "error", err)
if err = h.sendSYN(ctx, conn); err != nil {
log.ErrorContext(ctx, "Error sending TCP packet", "error", err)
return hop, fmt.Errorf("error sending TCP packet: %w", err)
}
log.Debug("TCP packet sent", "address", destAddr.String(), "port", port, "ttl", ttl)
log.DebugContext(ctx, "TCP packet sent", "address", destAddr.String(), "port", port, "ttl", ttl)

hop, err = receiveTCPResponse(ctx, conn, t.Timeout, start)
hop, err = h.receive(ctx, conn, h.Timeout, start)
hop.Tracepoint = ttl
log.Debug("TCP response received", "address", destAddr.String(), "port", port, "ttl", ttl, "hop", hop, "error", err)
log.DebugContext(ctx, "TCP response received", "address", destAddr.String(), "port", port, "ttl", ttl, "hop", hop, "error", err)
return hop, err
}

// createTCPConn creates a TCP connection to the given address with the specified TTL
func createTCPConn(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (*net.TCPConn, error) {
// newConn creates a TCP connection to the given address with the specified TTL
func (*tcpHopper) newConn(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (*net.TCPConn, error) {
log := logger.FromContext(ctx)
// Unfortunately, the net package does not provide a context-aware DialTCP function
// TODO: Switch to the net.DialTCPContext function as soon as https://github.com/golang/go/issues/49097 is implemented
conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: destAddr.IP, Port: int(port)})
if err != nil {
log.Error("Error dialing TCP connection", "error", err)
log.ErrorContext(ctx, "Error dialing TCP connection", "error", err)
return nil, err
}

if destAddr.IP.To4() != nil {
pc := ipv4.NewConn(conn)
if err := pc.SetTTL(ttl); err != nil {
log.Error("Error setting TTL on IPv4 connection", "error", err)
cv4 := ipv4.NewConn(conn)
if err := cv4.SetTTL(ttl); err != nil {
log.ErrorContext(ctx, "Error setting TTL on IPv4 connection", "error", err)
return nil, err
}
} else {
pc := ipv6.NewConn(conn)
if err := pc.SetHopLimit(ttl); err != nil {
log.Error("Error setting hop limit on IPv6 connection", "error", err)
cv6 := ipv6.NewConn(conn)
if err := cv6.SetHopLimit(ttl); err != nil {
log.ErrorContext(ctx, "Error setting hop limit on IPv6 connection", "error", err)
return nil, err
}
}
return conn, nil
}

// sendTCPPacket sends a TCP SYN packet to the given destination
func sendTCPPacket(ctx context.Context, conn *net.TCPConn) error {
// sendSYN writes a TCP SYN packet to the given connection's file descriptor
func (*tcpHopper) sendSYN(ctx context.Context, conn *net.TCPConn) error {
log := logger.FromContext(ctx)
err := conn.SetWriteDeadline(time.Now().Add(1 * time.Second))
if err != nil {
log.Error("Error setting write deadline", "error", err)
log.ErrorContext(ctx, "Error setting write deadline", "error", err)
return fmt.Errorf("error setting write deadline: %w", err)
}

_, err = conn.Write([]byte("HELLO-R-U-THERE"))
// To initiate a TCP connection, we need to send a SYN packet
// In this case we want this to be as small as possible, so we send a single byte
_, err = conn.Write([]byte{0})
if err != nil {
log.Error("Error writing TCP packet", "error", err)
log.ErrorContext(ctx, "Error writing TCP packet", "error", err)
return fmt.Errorf("error writing TCP packet: %w", err)
}

return nil
}

// receiveTCPResponse waits for a TCP response to the sent SYN packet
func receiveTCPResponse(ctx context.Context, conn *net.TCPConn, timeout time.Duration, start time.Time) (Hop, error) {
hop := Hop{}
// receive waits for a TCP response to the sent SYN packet
func (*tcpHopper) receive(ctx context.Context, conn *net.TCPConn, timeout time.Duration, start time.Time) (Hop, error) {
log := logger.FromContext(ctx)
err := conn.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
log.Error("Error setting read deadline", "error", err)
return hop, fmt.Errorf("error setting read deadline: %w", err)
log.ErrorContext(ctx, "Error setting read deadline", "error", err)
return Hop{}, fmt.Errorf("error setting read deadline: %w", err)
}

buf := make([]byte, 1)
_, err = conn.Read(buf)
_, err = io.ReadAll(conn)
if err != nil {
// Timeout means the TTL expired
// Timeout means the TTL expired and the packet was dropped by a router
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
hop.Duration = time.Since(start).Seconds()
hop.IP = conn.RemoteAddr().(*net.TCPAddr).IP
return hop, nil
}
// EOF means the target sent a TCP RST response
if err == io.EOF {
hop.Duration = time.Since(start).Seconds()
hop.IP = conn.RemoteAddr().(*net.TCPAddr).IP
hop.ReachedTarget = true
return hop, nil
return Hop{Duration: time.Since(start).Seconds()}, nil
}
log.Error("Error reading TCP response", "error", err)
return hop, fmt.Errorf("error reading TCP response: %w", err)
log.ErrorContext(ctx, "Error reading TCP response", "error", err)
return Hop{}, fmt.Errorf("error reading TCP response: %w", err)
}

hop.Duration = time.Since(start).Seconds()
hop.IP = conn.RemoteAddr().(*net.TCPAddr).IP
hop.ReachedTarget = true
return hop, nil
return Hop{
Duration: time.Since(start).Seconds(),
IP: conn.RemoteAddr().(*net.TCPAddr).IP,
ReachedTarget: true,
}, nil
}
33 changes: 23 additions & 10 deletions internal/traceroute/traceroute.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package traceroute

import (
"context"
"errors"
"fmt"
"net"
"sort"
Expand Down Expand Up @@ -202,6 +201,11 @@ func (t *tracer) filterHops(hops []Hop) []Hop {
return filtered
}

// hopper represents an interface for a hopper
type hopper interface {
Hop(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (Hop, error)
}

// hop performs a single hop in the traceroute to the given address with the specified TTL.
func (t *tracer) hop(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl int) (Hop, error) {
ctx, cancel := context.WithTimeout(ctx, t.Timeout)
Expand All @@ -214,15 +218,24 @@ func (t *tracer) hop(ctx context.Context, destAddr *net.IPAddr, port uint16, ttl
Error: fmt.Sprintf("timeout after %fs", t.Timeout.Seconds()),
}, ctx.Err()
default:
switch t.Protocol {
case ICMP:
return t.hopICMP(destAddr, ttl)
case UDP:
return Hop{}, errors.New("UDP not supported yet")
case TCP:
return t.hopTCP(ctx, destAddr, port, ttl)
default:
return Hop{}, errors.New("protocol not supported")
h := t.newHopper()
if h == nil {
return Hop{}, fmt.Errorf("unsupported protocol: %d", t.Protocol)
}
return h.Hop(ctx, destAddr, port, ttl)
}
}

// newHopper returns the hopper based on the protocol
func (t *tracer) newHopper() hopper {
switch t.Protocol {
case ICMP:
return &icmpHopper{tracer: t}
case UDP:
return &udpHopper{tracer: t}
case TCP:
return &tcpHopper{tracer: t}
default:
return nil
}
}
14 changes: 14 additions & 0 deletions internal/traceroute/udp.go
Original file line number Diff line number Diff line change
@@ -1 +1,15 @@
package traceroute

import (
"context"
"errors"
"net"
)

var _ hopper = (*udpHopper)(nil)

type udpHopper struct{ *tracer }

func (h *udpHopper) Hop(_ context.Context, _ *net.IPAddr, _ uint16, _ int) (hop Hop, err error) {
return hop, errors.New("udp protocol is not supported yet")
}

0 comments on commit 0994313

Please sign in to comment.