Skip to content

Commit

Permalink
PR: simplify type checking
Browse files Browse the repository at this point in the history
Signed-off-by: Hamza El-Saawy <[email protected]>
  • Loading branch information
helsaawy committed Apr 9, 2024
1 parent b034afc commit 873ef3e
Showing 1 changed file with 17 additions and 29 deletions.
46 changes: 17 additions & 29 deletions hvsock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,7 @@ func clientServer(u testUtil) (cl, sv *HvsockConn, _ *HvsockAddr) {
if err != nil {
return fmt.Errorf("listener accept: %w", err)
}
var ok bool
sv, ok = conn.(*HvsockConn)
if !ok {
return fmt.Errorf("expected connection type %T; got %T", new(HvsockConn), conn)
}
sv = mustBeType[*HvsockConn](u.T, conn)
if err := l.Close(); err != nil {
return err
}
Expand Down Expand Up @@ -113,10 +109,7 @@ func TestHvSockListenerAddresses(t *testing.T) {
u := newUtil(t)
l, addr := serverListen(u)

la, ok := (l.Addr()).(*HvsockAddr)
if !ok {
t.Fatalf("expected type %T; got %T", new(HvsockAddr), l.Addr())
}
la := mustBeType[*HvsockAddr](t, l.Addr())
u.Assert(*la == *addr, fmt.Sprintf("give: %v; want: %v", la, addr))

ra := rawHvsockAddr{}
Expand All @@ -130,22 +123,10 @@ func TestHvSockAddresses(t *testing.T) {
u := newUtil(t)
cl, sv, addr := clientServer(u)

sra, ok := (sv.RemoteAddr()).(*HvsockAddr)
if !ok {
t.Fatalf("expected type %T; got %T", new(HvsockAddr), sv.RemoteAddr())
}
sla, ok := (sv.LocalAddr()).(*HvsockAddr)
if !ok {
t.Fatalf("expected type %T; got %T", new(HvsockAddr), sv.LocalAddr())
}
cra, ok := (cl.RemoteAddr()).(*HvsockAddr)
if !ok {
t.Fatalf("expected type %T; got %T", new(HvsockAddr), cl.RemoteAddr())
}
cla, ok := (cl.LocalAddr()).(*HvsockAddr)
if !ok {
t.Fatalf("expected type %T; got %T", new(HvsockAddr), cl.LocalAddr())
}
sra := mustBeType[*HvsockAddr](t, sv.RemoteAddr())
sla := mustBeType[*HvsockAddr](t, sv.LocalAddr())
cra := mustBeType[*HvsockAddr](t, cl.RemoteAddr())
cla := mustBeType[*HvsockAddr](t, cl.LocalAddr())

t.Run("Info", func(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -341,10 +322,7 @@ func TestHvSockCloseReadWriteListener(t *testing.T) {
}
defer c.Close()

hv, ok := c.(*HvsockConn)
if !ok {
t.Fatalf("expected type %T; got %T", new(HvsockConn), c)
}
hv := mustBeType[*HvsockConn](t, c)
//
// test CloseWrite()
//
Expand Down Expand Up @@ -683,3 +661,13 @@ func (u testUtil) Check() {
func msgJoin(pre []string, s string) string {
return strings.Join(append(pre, s), ": ")
}

func mustBeType[T any](tb testing.TB, v any) T {
tb.Helper()

v2, ok := v.(T)
if !ok {
tb.Fatalf("expected type %T; got %T", *new(T), v)
}
return v2
}

0 comments on commit 873ef3e

Please sign in to comment.