diff --git a/interceptor_test.go b/interceptor_test.go index 624024f..4d27610 100644 --- a/interceptor_test.go +++ b/interceptor_test.go @@ -37,11 +37,14 @@ import ( "net" "runtime" "strings" + "sync" "testing" "github.com/containerd/otelttrpc/internal" "github.com/containerd/ttrpc" "github.com/stretchr/testify/assert" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" sdktrace "go.opentelemetry.io/otel/sdk/trace" "go.opentelemetry.io/otel/sdk/trace/tracetest" "go.opentelemetry.io/otel/trace" @@ -88,9 +91,71 @@ func (s *testingServer) Test(ctx context.Context, req *internal.TestPayload) (*i return tp, nil } +func TestClientCallServerConcurrent(t *testing.T) { + var ( + ctx = ttrpc.WithMetadata(context.Background(), ttrpc.MD{"test-key": []string{"test-val"}}) + exp, tp = newTracerProvider() + server = mustServer(t)(newServerWithTTRPCInterceptor(tp)) + testImpl = &testingServer{} + addr, listener = newTestListener(t) + payload = &internal.TestPayload{ + Foo: "bar", + } + ) + + concurrency := 30 + testClients := make([]*testingClient, 0, concurrency) + for i := 0; i < concurrency; i++ { + client, cleanup := newTestClient(t, addr, tp) + testClients = append(testClients, newTestingClient(client)) + defer cleanup() + } + defer listener.Close() + defer func() { _ = tp.Shutdown(ctx) }() + + registerTestingService(server, testImpl) + + go func() { + _ = server.Serve(ctx, listener) + }() + defer func() { + _ = server.Shutdown(ctx) + }() + + var wg sync.WaitGroup + var errs []error + var mu sync.Mutex + + for _, testClient := range testClients { + // capture range variable + // TODO: we can remove this once we upgrade golang to >= 1.22 + testClient := testClient + wg.Add(1) + go func() { + defer wg.Done() + if _, err := testClient.Test(ctx, payload); err != nil { + mu.Lock() + defer mu.Unlock() + errs = append(errs, err) + } + }() + } + + wg.Wait() + if len(errs) > 0 { + t.Fatalf("unexpected errors: %v", errs) + } + + // get exported spans + snapshots := exp.GetSpans().Snapshots() + // we should capture `concurrency * 2` spans, one each from client and server side + // TODO: validate individual spans and their attributes + assert.Equal(t, concurrency*2, len(snapshots), "Number of spans mismatched") +} + func TestClientCallServer(t *testing.T) { var ( - ctx = context.Background() + ctx = ttrpc.WithMetadata(context.Background(), ttrpc.MD{"test-key": []string{"test-val"}}) exp, tp = newTracerProvider() server = mustServer(t)(newServerWithTTRPCInterceptor(tp)) testImpl = &testingServer{} @@ -153,6 +218,8 @@ func newTracerProvider() (*tracetest.InMemoryExporter, *sdktrace.TracerProvider) tp := sdktrace.NewTracerProvider( sdktrace.WithSyncer(exp), ) + + otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})) return exp, tp }