From 6c943036f5e47b1d0205493c9a5a818ca695c39f Mon Sep 17 00:00:00 2001 From: Martin Sucha Date: Fri, 22 Sep 2023 16:56:56 +0200 Subject: [PATCH] Use correct type when observing result rows The type assertion never matched because we use pointer there. Adding a test as well. --- conn_test.go | 96 ++++++++++++++++++++++++++++++++++++++++++++++++++++ frame.go | 2 +- 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/conn_test.go b/conn_test.go index 84b1b3f0f..ce37ed3e2 100644 --- a/conn_test.go +++ b/conn_test.go @@ -929,6 +929,89 @@ func TestWriteCoalescing_WriteAfterClose(t *testing.T) { } } +type frameObserverFunc func(ctx context.Context, observedFrame ObservedFrame) + +func (fn frameObserverFunc) ObserveFrame(ctx context.Context, observedFrame ObservedFrame) { + fn(ctx, observedFrame) +} + +func TestFrameObserver(t *testing.T) { + srv := NewTestServer(t, protoVersion4, context.Background()) + defer srv.Stop() + + var frameCount int64 + var rowsCount int64 + var rowsBytes int64 + var frameLengthBytes int64 + var frameUncompressedBytes int64 + + cluster := testCluster(protoVersion4, srv.Address) + cluster.FrameObserver = frameObserverFunc(func(ctx context.Context, observedFrame ObservedFrame) { + if observedFrame.Opcode != FrameOpcodeResult { + return + } + atomic.AddInt64(&frameCount, 1) + atomic.AddInt64(&rowsCount, int64(observedFrame.RowCount)) + atomic.AddInt64(&rowsBytes, int64(observedFrame.RowsSize)) + atomic.AddInt64(&frameLengthBytes, int64(observedFrame.Length)) + atomic.AddInt64(&frameUncompressedBytes, int64(observedFrame.UncompressedSize)) + }) + db, err := cluster.CreateSession() + if err != nil { + t.Fatalf("create session: %v", err) + } + + // Reset the counters, we are not interested in session setup. + atomic.SwapInt64(&frameCount, 0) + atomic.SwapInt64(&rowsCount, 0) + atomic.SwapInt64(&rowsBytes, 0) + atomic.SwapInt64(&frameLengthBytes, 0) + atomic.SwapInt64(&frameUncompressedBytes, 0) + + it := db.Query("rows").Iter() + + var items []string + + for { + var column1 string + if !it.Scan(&column1) { + break + } + items = append(items, column1) + } + + if err := it.Close(); err != nil { + t.Fatalf("close: %v", err) + } + + if len(items) != 2 || items[0] != "hello" || items[1] != "world" { + t.Errorf("unexpected items: %+v", items) + } + + gotFrameCount := atomic.LoadInt64(&frameCount) + gotRowsCount := atomic.LoadInt64(&rowsCount) + gotRowsBytes := atomic.LoadInt64(&rowsBytes) + gotFrameLengthBytes := atomic.LoadInt64(&frameLengthBytes) + gotFrameUncompressedBytes := atomic.LoadInt64(&frameUncompressedBytes) + + if gotFrameCount != 1 { + t.Errorf("unexpected frame count, got %d", gotFrameCount) + } + if gotRowsCount != 2 { + t.Errorf("unexpected row count, got %d", gotRowsCount) + } + if gotRowsBytes != 18 { + t.Errorf("unexpected rows bytes, got %d", gotRowsBytes) + } + if gotFrameLengthBytes != 61 { + t.Errorf("unexpected frame length, got %d", gotFrameLengthBytes) + } + // compression was not used. + if gotFrameUncompressedBytes != 0 { + t.Errorf("unexpected frame uncompressed size, got %d", gotFrameUncompressedBytes) + } +} + type recordingFrameHeaderObserver struct { t *testing.T mu sync.Mutex @@ -1270,6 +1353,19 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer, exts map[string] rand.Seed(time.Now().UnixNano()) <-time.After(time.Millisecond * 120) } + case "rows": + // https://martin-sucha.github.io/cqlprotodoc/native_protocol_v4.html#s4.2.5.2 + respFrame.writeHeader(0, FrameOpcodeResult, head.stream) + respFrame.writeInt(resultKindRows) + respFrame.writeInt(int32(flagGlobalTableSpec)) // flags + respFrame.writeInt(1) // column count + respFrame.writeString("keyspace") // global table spec: keyspace name + respFrame.writeString("table") // global table spec: table name + respFrame.writeString("column") // column1 name + respFrame.writeShort(uint16(TypeVarchar)) // column1 type + respFrame.writeInt(2) // rows_count + respFrame.writeBytes([]byte("hello")) // row1 + respFrame.writeBytes([]byte("world")) // row2 default: respFrame.writeHeader(0, FrameOpcodeResult, head.stream) respFrame.writeInt(resultKindVoid) diff --git a/frame.go b/frame.go index 3640cd98f..bcf748b76 100644 --- a/frame.go +++ b/frame.go @@ -420,7 +420,7 @@ func (fpo *frameParseObserver) observeFrame(ff *framer, f frame) { ObservedFrameHeader: fpo.head, UncompressedSize: ff.uncompressedSize, } - if rows, ok := f.(resultRowsFrame); ok { + if rows, ok := f.(*resultRowsFrame); ok { of.IsRowsResult = true of.RowCount = rows.numRows of.RowsSize = rows.rowsContentSize