From 1769260f3675fe1ed6e4c4a517e2ce6ee10ad349 Mon Sep 17 00:00:00 2001 From: Adrian Preston Date: Tue, 1 Aug 2023 22:31:49 +0100 Subject: [PATCH] feat: support up to v4 of the ListGroups API Signed-off-by: Adrian Preston --- broker.go | 1 + list_groups_request.go | 48 ++++++++++++- list_groups_request_test.go | 30 ++++++++ list_groups_response.go | 134 +++++++++++++++++++++++++++++------ list_groups_response_test.go | 26 +++++++ 5 files changed, 213 insertions(+), 26 deletions(-) diff --git a/broker.go b/broker.go index dcd62b9b1..9a6297347 100644 --- a/broker.go +++ b/broker.go @@ -587,6 +587,7 @@ func (b *Broker) Heartbeat(request *HeartbeatRequest) (*HeartbeatResponse, error // ListGroups return a list group response or error func (b *Broker) ListGroups(request *ListGroupsRequest) (*ListGroupsResponse, error) { response := new(ListGroupsResponse) + response.Version = request.Version // Required to ensure use of the correct response header version err := b.sendAndReceive(request, response) if err != nil { diff --git a/list_groups_request.go b/list_groups_request.go index 68b3c8f34..4d5f9e40d 100644 --- a/list_groups_request.go +++ b/list_groups_request.go @@ -1,14 +1,47 @@ package sarama type ListGroupsRequest struct { - Version int16 + Version int16 + StatesFilter []string // version 4 or later } func (r *ListGroupsRequest) encode(pe packetEncoder) error { + if r.Version >= 4 { + pe.putCompactArrayLength(len(r.StatesFilter)) + for _, filter := range r.StatesFilter { + err := pe.putCompactString(filter) + if err != nil { + return err + } + } + } + if r.Version >= 3 { + pe.putEmptyTaggedFieldArray() + } return nil } func (r *ListGroupsRequest) decode(pd packetDecoder, version int16) (err error) { + r.Version = version + if r.Version >= 4 { + filterLen, err := pd.getCompactArrayLength() + if err != nil { + return err + } + if filterLen > 0 { + r.StatesFilter = make([]string, filterLen) + for i := 0; i < filterLen; i++ { + if r.StatesFilter[i], err = pd.getCompactString(); err != nil { + return err + } + } + } + } + if r.Version >= 3 { + if _, err = pd.getEmptyTaggedFieldArray(); err != nil { + return err + } + } return nil } @@ -21,20 +54,29 @@ func (r *ListGroupsRequest) version() int16 { } func (r *ListGroupsRequest) headerVersion() int16 { + if r.Version >= 3 { + return 2 + } return 1 } func (r *ListGroupsRequest) isValidVersion() bool { - return r.Version >= 0 && r.Version <= 2 + return r.Version >= 0 && r.Version <= 4 } func (r *ListGroupsRequest) requiredVersion() KafkaVersion { switch r.Version { + case 4: + return V2_6_0_0 + case 3: + return V2_4_0_0 case 2: return V2_0_0_0 case 1: return V0_11_0_0 - default: + case 0: return V0_9_0_0 + default: + return V2_6_0_0 } } diff --git a/list_groups_request_test.go b/list_groups_request_test.go index 2e977d9a5..fbb3bbd42 100644 --- a/list_groups_request_test.go +++ b/list_groups_request_test.go @@ -4,4 +4,34 @@ import "testing" func TestListGroupsRequest(t *testing.T) { testRequest(t, "ListGroupsRequest", &ListGroupsRequest{}, []byte{}) + + testRequest(t, "ListGroupsRequest", &ListGroupsRequest{ + Version: 1, + }, []byte{}) + + testRequest(t, "ListGroupsRequest", &ListGroupsRequest{ + Version: 2, + }, []byte{}) + + testRequest(t, "ListGroupsRequest", &ListGroupsRequest{ + Version: 3, + }, []byte{ + 0, // 0, // empty tag buffer + }) + + testRequest(t, "ListGroupsRequest", &ListGroupsRequest{ + Version: 4, + }, []byte{ + 1, // compact array length (0) + 0, // empty tag buffer + }) + + testRequest(t, "ListGroupsRequest", &ListGroupsRequest{ + Version: 4, + StatesFilter: []string{"Empty"}, + }, []byte{ + 2, // compact array length (1) + 6, 'E', 'm', 'p', 't', 'y', // compact string + 0, // empty tag buffer + }) } diff --git a/list_groups_response.go b/list_groups_response.go index a4fd15a34..62948c31f 100644 --- a/list_groups_response.go +++ b/list_groups_response.go @@ -1,23 +1,52 @@ package sarama type ListGroupsResponse struct { - Version int16 - Err KError - Groups map[string]string + Version int16 + ThrottleTime int32 + Err KError + Groups map[string]string + GroupsData map[string]GroupData // version 4 or later +} + +type GroupData struct { + GroupState string // version 4 or later } func (r *ListGroupsResponse) encode(pe packetEncoder) error { + if r.Version >= 1 { + pe.putInt32(r.ThrottleTime) + } + pe.putInt16(int16(r.Err)) - if err := pe.putArrayLength(len(r.Groups)); err != nil { - return err - } - for groupId, protocolType := range r.Groups { - if err := pe.putString(groupId); err != nil { + if r.Version <= 2 { + if err := pe.putArrayLength(len(r.Groups)); err != nil { return err } - if err := pe.putString(protocolType); err != nil { - return err + for groupId, protocolType := range r.Groups { + if err := pe.putString(groupId); err != nil { + return err + } + if err := pe.putString(protocolType); err != nil { + return err + } + } + } else { + pe.putCompactArrayLength(len(r.Groups)) + for groupId, protocolType := range r.Groups { + if err := pe.putCompactString(groupId); err != nil { + return err + } + if err := pe.putCompactString(protocolType); err != nil { + return err + } + + if r.Version >= 4 { + groupData := r.GroupsData[groupId] + if err := pe.putCompactString(groupData.GroupState); err != nil { + return err + } + } } } @@ -25,6 +54,14 @@ func (r *ListGroupsResponse) encode(pe packetEncoder) error { } func (r *ListGroupsResponse) decode(pd packetDecoder, version int16) error { + r.Version = version + if r.Version >= 1 { + var err error + if r.ThrottleTime, err = pd.getInt32(); err != nil { + return err + } + } + kerr, err := pd.getInt16() if err != nil { return err @@ -32,26 +69,68 @@ func (r *ListGroupsResponse) decode(pd packetDecoder, version int16) error { r.Err = KError(kerr) - n, err := pd.getArrayLength() + var n int + if r.Version <= 2 { + n, err = pd.getArrayLength() + } else { + n, err = pd.getCompactArrayLength() + } if err != nil { return err } - if n == 0 { - return nil - } - r.Groups = make(map[string]string) for i := 0; i < n; i++ { - groupId, err := pd.getString() - if err != nil { - return err + if i == 0 { + r.Groups = make(map[string]string) + if r.Version >= 4 { + r.GroupsData = make(map[string]GroupData) + } } - protocolType, err := pd.getString() - if err != nil { - return err + + var groupId, protocolType string + if r.Version <= 2 { + groupId, err = pd.getString() + if err != nil { + return err + } + protocolType, err = pd.getString() + if err != nil { + return err + } + } else { + groupId, err = pd.getCompactString() + if err != nil { + return err + } + protocolType, err = pd.getCompactString() + if err != nil { + return err + } } r.Groups[groupId] = protocolType + + if r.Version >= 4 { + groupState, err := pd.getCompactString() + if err != nil { + return err + } + r.GroupsData[groupId] = GroupData{ + GroupState: groupState, + } + } + + if r.Version >= 3 { + if _, err = pd.getEmptyTaggedFieldArray(); err != nil { + return err + } + } + } + + if r.Version >= 3 { + if _, err = pd.getEmptyTaggedFieldArray(); err != nil { + return err + } } return nil @@ -66,20 +145,29 @@ func (r *ListGroupsResponse) version() int16 { } func (r *ListGroupsResponse) headerVersion() int16 { + if r.Version >= 3 { + return 1 + } return 0 } func (r *ListGroupsResponse) isValidVersion() bool { - return r.Version >= 0 && r.Version <= 2 + return r.Version >= 0 && r.Version <= 4 } func (r *ListGroupsResponse) requiredVersion() KafkaVersion { switch r.Version { + case 4: + return V2_6_0_0 + case 3: + return V2_4_0_0 case 2: return V2_0_0_0 case 1: return V0_11_0_0 - default: + case 0: return V0_9_0_0 + default: + return V2_6_0_0 } } diff --git a/list_groups_response_test.go b/list_groups_response_test.go index 29b57c2d5..a7205f854 100644 --- a/list_groups_response_test.go +++ b/list_groups_response_test.go @@ -22,6 +22,17 @@ var ( 0, 3, 'f', 'o', 'o', // group name 0, 8, 'c', 'o', 'n', 's', 'u', 'm', 'e', 'r', // protocol type } + + listGroupResponseV4 = []byte{ + 0, 0, 0, 0, // no throttle time + 0, 0, // no error + 2, // compact array length (1) + 4, 'f', 'o', 'o', // group name (compact string) + 9, 'c', 'o', 'n', 's', 'u', 'm', 'e', 'r', // protocol type (compact string) + 6, 'E', 'm', 'p', 't', 'y', // state (compact string) + 0, // Empty tag buffer + 0, // Empty tag buffer + } ) func TestListGroupsResponse(t *testing.T) { @@ -56,4 +67,19 @@ func TestListGroupsResponse(t *testing.T) { if response.Groups["foo"] != "consumer" { t.Error("Expected foo group to use consumer protocol") } + + response = new(ListGroupsResponse) + testVersionDecodable(t, "no error", response, listGroupResponseV4, 4) + if !errors.Is(response.Err, ErrNoError) { + t.Error("Expected no gerror, found:", response.Err) + } + if len(response.Groups) != 1 { + t.Error("Expected one group") + } + if response.Groups["foo"] != "consumer" { + t.Error("Expected foo group to use consumer protocol") + } + if response.GroupsData["foo"].GroupState != "Empty" { + t.Error("Expected foo grup to have empty state") + } }