diff --git a/pdu/pdu.go b/pdu/pdu.go index 90222b8..cc57faa 100644 --- a/pdu/pdu.go +++ b/pdu/pdu.go @@ -310,36 +310,40 @@ func NewDecoder(r io.Reader) *Decoder { // Decode reads data from reader and populates PDU. func (d *Decoder) Decode() (Header, PDU, error) { // Read header first. - h := make([]byte, 16) - n, err := d.r.Read(h) - if err != nil { + var headerBytes [16]byte + if _, err := io.ReadFull(d.r, headerBytes[:]); err != nil { return nil, nil, err } - if n != 16 { - return nil, nil, errors.New("smpp: invalid pdu header byte length") + + header := &header{} + if err := header.UnmarshalBinary(headerBytes[:]); err != nil { + return header, nil, err } - he := &header{} - if err := he.UnmarshalBinary(h); err != nil { - return nil, nil, err + + if header.length < 16 { + return header, nil, fmt.Errorf("smpp: invalid pdu header byte length: %d", header.length) } - p := NewPDU(he.commandID) - if he.length == 16 { - return he, p, nil + + pdu := NewPDU(header.commandID) + if header.length == 16 { + // not expecting body to read - we're done. + return header, pdu, nil } // Read rest of the PDU. - buf := make([]byte, he.length-16) - n, err = d.r.Read(buf) - if err != nil { - return he, nil, err - } - if n != int(he.length-16) { - return he, nil, fmt.Errorf("smpp: pdu length doesn't match read body length %d != %d", he.length, n) + bodyBytes := make([]byte, header.length-16) + if len(bodyBytes) > 0 { + if _, err := io.ReadFull(d.r, bodyBytes); err != nil { + return header, pdu, fmt.Errorf("smpp: pdu length doesn't match read body length %d != %d", header.length, len(bodyBytes)) + } } - if err := p.UnmarshalBinary(buf); err != nil { - return he, nil, err + + // Unmarshal binary + if err := pdu.UnmarshalBinary(bodyBytes); err != nil { + return header, pdu, err } - return he, p, nil + + return header, pdu, nil } // NewPDU creates new PDU from CommandID. diff --git a/pdu/pdu_test.go b/pdu/pdu_test.go index 781a2b2..3a23f47 100644 --- a/pdu/pdu_test.go +++ b/pdu/pdu_test.go @@ -3,9 +3,11 @@ package pdu import ( "bytes" "encoding/hex" + "net" "reflect" "strings" "testing" + "time" ) var pduTT = []struct { @@ -300,3 +302,60 @@ func TestPDUDecoding(t *testing.T) { }) } } + +func TestPDUDecodingIncompleteBuffers(t *testing.T) { + + var pdus []byte + mtu := 8 // Likely e.g. 1500 in the real world + + for _, row := range codingTT { + pduBytes, _ := hex.DecodeString(toHexStr(row.headerHex + pduTT[row.pduIndex].hexStr)) + pdus = append(pdus, pduBytes...) + } + + buf, wr := net.Pipe() + dec := NewDecoder(buf) + + go func() { + for i := 0; i < len(pdus); { + + j := i + mtu + if j > len(pdus) { + j = len(pdus) + } + time.Sleep(time.Millisecond * 10) // Similulate some network (rather than coordinate with .Decode() for the purpose of this test) + _, err := wr.Write(pdus[i:j]) + if err != nil { + panic("error writing to net.Pipe") + } + i = j + } + + wr.Close() + }() + + for _, row := range codingTT { + t.Run(row.desc, func(t *testing.T) { + + h, p, err := dec.Decode() + + if err != nil { + if !row.err { + t.Fatalf("unexpected error %s", err) + } + return + } + + if h.Sequence() != row.seq { + t.Errorf("Decode() => seq %d expected %d", h.Sequence(), row.seq) + } + if h.Status() != row.status { + t.Errorf("Decode() => status %d expected %d", h.Status(), row.status) + } + if !reflect.DeepEqual(p, pduTT[row.pduIndex].pdu) { + t.Errorf("Decode() => pdu\n%+v\nexpected \n%+v", p, pduTT[row.pduIndex].pdu) + } + + }) + } +}