Merge pull request #6883 from dolthub/aaron/aws-table-persister-uses-s3manager

Remotes: AWS: Fix a bug where uploading table files to S3 could have unbounded memory usage.
reltuk authored Oct 26, 2023
2 parents ce375e4 + c3af0c6 commit 9c6c976
Showing 6 changed files with 115 additions and 132 deletions.
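
The gist of the change: rather than buffering an entire table file in memory with io.ReadAll before a hand-rolled multipart upload, the persister now hands an io.Reader to s3manager.Uploader, which streams the body to S3 in PartSize-sized pieces. A minimal, self-contained sketch of that streaming pattern (the bucket, key, file path, and 16 MiB part size are placeholder values, not taken from the commit):

```go
package main

import (
	"fmt"
	"io"
	"log"
	"os"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
)

// uploadStreaming sends r to S3 without ever holding the whole object in memory:
// the uploader reads and uploads the body one part at a time, so peak usage is
// roughly PartSize * Concurrency bytes rather than the full object size.
func uploadStreaming(sess *session.Session, bucket, key string, r io.Reader) error {
	uploader := s3manager.NewUploader(sess, func(u *s3manager.Uploader) {
		u.PartSize = 16 * 1024 * 1024 // placeholder; must be >= s3manager.MinUploadPartSize (5 MiB)
	})
	out, err := uploader.Upload(&s3manager.UploadInput{
		Bucket: aws.String(bucket),
		Key:    aws.String(key),
		Body:   r, // streamed; no io.ReadAll
	})
	if err != nil {
		return err
	}
	fmt.Println("uploaded to", out.Location)
	return nil
}

func main() {
	f, err := os.Open("some-table-file") // placeholder path
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()
	sess := session.Must(session.NewSession())
	if err := uploadStreaming(sess, "example-bucket", "example-key", f); err != nil {
		log.Fatal(err)
	}
}
```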
122 changes: 20 additions & 102 deletions go/store/nbs/aws_table_persister.go
@@ -34,6 +34,8 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go/service/s3/s3manager"

"github.com/dolthub/dolt/go/store/atomicerr"
"github.com/dolthub/dolt/go/store/chunks"
@@ -55,7 +57,7 @@ const (
)

type awsTablePersister struct {
s3 s3svc
s3 s3iface.S3API
bucket string
rl chan struct{}
ddb *ddbTableStore
@@ -120,21 +122,20 @@ func (s3p awsTablePersister) CopyTableFile(ctx context.Context, r io.ReadCloser,
}
}()

data, err := io.ReadAll(r)
if err != nil {
return err
}

name, err := parseAddr(fileId)
if err != nil {
return err
}

if s3p.limits.tableFitsInDynamo(name, len(data), chunkCount) {
if s3p.limits.tableFitsInDynamo(name, int(fileSz), chunkCount) {
data, err := io.ReadAll(r)
if err != nil {
return err
}
return s3p.ddb.Write(ctx, name, data)
}

return s3p.multipartUpload(ctx, data, fileId)
return s3p.multipartUpload(ctx, r, fileSz, fileId)
}

func (s3p awsTablePersister) Path() string {
@@ -174,7 +175,7 @@ func (s3p awsTablePersister) Persist(ctx context.Context, mt *memTable, haver ch
return newReaderFromIndexData(ctx, s3p.q, data, name, &dynamoTableReaderAt{ddb: s3p.ddb, h: name}, s3BlockSize)
}

err = s3p.multipartUpload(ctx, data, name.String())
err = s3p.multipartUpload(ctx, bytes.NewReader(data), uint64(len(data)), name.String())

if err != nil {
return emptyChunkSource{}, err
@@ -184,20 +185,16 @@
return newReaderFromIndexData(ctx, s3p.q, data, name, tra, s3BlockSize)
}

func (s3p awsTablePersister) multipartUpload(ctx context.Context, data []byte, key string) error {
uploadID, err := s3p.startMultipartUpload(ctx, key)

if err != nil {
return err
}

multipartUpload, err := s3p.uploadParts(ctx, data, key, uploadID)
if err != nil {
_ = s3p.abortMultipartUpload(ctx, key, uploadID)
return err
}

return s3p.completeMultipartUpload(ctx, key, uploadID, multipartUpload)
func (s3p awsTablePersister) multipartUpload(ctx context.Context, r io.Reader, sz uint64, key string) error {
uploader := s3manager.NewUploaderWithClient(s3p.s3, func(u *s3manager.Uploader) {
u.PartSize = int64(s3p.limits.partTarget)
})
_, err := uploader.Upload(&s3manager.UploadInput{
Bucket: aws.String(s3p.bucket),
Key: aws.String(s3p.key(key)),
Body: r,
})
return err
}

func (s3p awsTablePersister) startMultipartUpload(ctx context.Context, key string) (string, error) {
@@ -234,85 +231,6 @@ func (s3p awsTablePersister) completeMultipartUpload(ctx context.Context, key, u
return err
}

func (s3p awsTablePersister) uploadParts(ctx context.Context, data []byte, key, uploadID string) (*s3.CompletedMultipartUpload, error) {
sent, failed, done := make(chan s3UploadedPart), make(chan error), make(chan struct{})

numParts := getNumParts(uint64(len(data)), s3p.limits.partTarget)

if numParts > maxS3Parts {
return nil, errors.New("exceeded maximum parts")
}

var wg sync.WaitGroup
sendPart := func(partNum, start, end uint64) {
if s3p.rl != nil {
s3p.rl <- struct{}{}
defer func() { <-s3p.rl }()
}
defer wg.Done()

// Check if upload has been terminated
select {
case <-done:
return
default:
}
// Upload the desired part
if partNum == numParts { // If this is the last part, make sure it includes any overflow
end = uint64(len(data))
}
etag, err := s3p.uploadPart(ctx, data[start:end], key, uploadID, int64(partNum))
if err != nil {
failed <- err
return
}
// Try to send along part info. In the case that the upload was aborted, reading from done allows this worker to exit correctly.
select {
case sent <- s3UploadedPart{int64(partNum), etag}:
case <-done:
return
}
}
for i := uint64(0); i < numParts; i++ {
wg.Add(1)
partNum := i + 1 // Parts are 1-indexed
start, end := i*s3p.limits.partTarget, (i+1)*s3p.limits.partTarget
go sendPart(partNum, start, end)
}
go func() {
wg.Wait()
close(sent)
close(failed)
}()

multipartUpload := &s3.CompletedMultipartUpload{}
var firstFailure error
for cont := true; cont; {
select {
case sentPart, open := <-sent:
if open {
multipartUpload.Parts = append(multipartUpload.Parts, &s3.CompletedPart{
ETag: aws.String(sentPart.etag),
PartNumber: aws.Int64(sentPart.idx),
})
}
cont = open

case err := <-failed:
if err != nil && firstFailure == nil { // nil err may happen when failed gets closed
firstFailure = err
close(done)
}
}
}

if firstFailure == nil {
close(done)
}
sort.Sort(partsByPartNum(multipartUpload.Parts))
return multipartUpload, firstFailure
}

func getNumParts(dataLen, minPartSize uint64) uint64 {
numParts := dataLen / minPartSize
if numParts == 0 {
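
For context on why the new multipartUpload above bounds memory: in aws-sdk-go v1, s3manager.Uploader keeps at most Concurrency part buffers of PartSize bytes in flight (the defaults are 5 concurrent parts of 5 MiB), so the old allocation proportional to the file size becomes roughly PartSize × Concurrency. The commit only overrides PartSize; a hypothetical helper that also pins Concurrency, if an even tighter bound were wanted, might look like this (the helper name and parameters are illustrative, not from the commit):

```go
import (
	"github.com/aws/aws-sdk-go/service/s3/s3iface"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
)

// newBoundedUploader caps the uploader's buffering at roughly
// partSize*concurrency bytes, regardless of how large the streamed object is.
func newBoundedUploader(client s3iface.S3API, partSize int64, concurrency int) *s3manager.Uploader {
	return s3manager.NewUploaderWithClient(client, func(u *s3manager.Uploader) {
		u.PartSize = partSize       // e.g. s3p.limits.partTarget, as in the diff above
		u.Concurrency = concurrency // default is s3manager.DefaultUploadConcurrency (5)
	})
}
```

Since S3 multipart uploads allow at most 10,000 parts, PartSize also caps the largest object the uploader can send at roughly PartSize × 10,000, which is worth keeping in mind when choosing partTarget.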
61 changes: 48 additions & 13 deletions go/store/nbs/aws_table_persister_test.go
@@ -23,6 +23,7 @@ package nbs

import (
"context"
crand "crypto/rand"
"io"
"math/rand"
"sync"
@@ -31,30 +32,66 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/dolthub/dolt/go/store/util/sizecache"
)

func randomChunks(t *testing.T, r *rand.Rand, sz int) [][]byte {
buf := make([]byte, sz)
_, err := io.ReadFull(crand.Reader, buf)
require.NoError(t, err)

var ret [][]byte
var i int
for i < len(buf) {
j := int(r.NormFloat64()*1024 + 4096)
if i+j >= len(buf) {
ret = append(ret, buf[i:])
} else {
ret = append(ret, buf[i:i+j])
}
i += j
}

return ret
}

func TestRandomChunks(t *testing.T) {
r := rand.New(rand.NewSource(1024))
res := randomChunks(t, r, 10)
assert.Len(t, res, 1)
res = randomChunks(t, r, 4096+2048)
assert.Len(t, res, 2)
res = randomChunks(t, r, 4096+4096)
assert.Len(t, res, 3)
}

func TestAWSTablePersisterPersist(t *testing.T) {
ctx := context.Background()
calcPartSize := func(rdr chunkReader, maxPartNum uint64) uint64 {
return maxTableSize(uint64(mustUint32(rdr.count())), mustUint64(rdr.uncompressedLen())) / maxPartNum
}

mt := newMemTable(testMemTableSize)
r := rand.New(rand.NewSource(1024))
const sz15mb = 1 << 20 * 15
mt := newMemTable(sz15mb)
testChunks := randomChunks(t, r, 1<<20*12)
for _, c := range testChunks {
assert.Equal(t, mt.addChunk(computeAddr(c), c), chunkAdded)
}

var limits5mb = awsLimits{partTarget: 1 << 20 * 5}
var limits64mb = awsLimits{partTarget: 1 << 20 * 64}

t.Run("PersistToS3", func(t *testing.T) {
testIt := func(t *testing.T, ns string) {
t.Run("InMultipleParts", func(t *testing.T) {
assert := assert.New(t)
s3svc, ddb := makeFakeS3(t), makeFakeDTS(makeFakeDDB(t), nil)
limits := awsLimits{partTarget: calcPartSize(mt, 3)}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits, ns: ns, q: &UnlimitedQuotaProvider{}}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}

src, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
require.NoError(t, err)
@@ -72,8 +109,7 @@ func TestAWSTablePersisterPersist(t *testing.T) {
assert := assert.New(t)

s3svc, ddb := makeFakeS3(t), makeFakeDTS(makeFakeDDB(t), nil)
limits := awsLimits{partTarget: calcPartSize(mt, 1)}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits, ns: ns, q: &UnlimitedQuotaProvider{}}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits64mb, ns: ns, q: &UnlimitedQuotaProvider{}}

src, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
require.NoError(t, err)
@@ -89,17 +125,16 @@
t.Run("NoNewChunks", func(t *testing.T) {
assert := assert.New(t)

mt := newMemTable(testMemTableSize)
existingTable := newMemTable(testMemTableSize)
mt := newMemTable(sz15mb)
existingTable := newMemTable(sz15mb)

for _, c := range testChunks {
assert.Equal(mt.addChunk(computeAddr(c), c), chunkAdded)
assert.Equal(existingTable.addChunk(computeAddr(c), c), chunkAdded)
}

s3svc, ddb := makeFakeS3(t), makeFakeDTS(makeFakeDDB(t), nil)
limits := awsLimits{partTarget: 1 << 10}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits, ns: ns, q: &UnlimitedQuotaProvider{}}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}

src, err := s3p.Persist(context.Background(), mt, existingTable, &Stats{})
require.NoError(t, err)
@@ -115,8 +150,7 @@

s3svc := &failingFakeS3{makeFakeS3(t), sync.Mutex{}, 1}
ddb := makeFakeDTS(makeFakeDDB(t), nil)
limits := awsLimits{partTarget: calcPartSize(mt, 4)}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits, ns: ns, q: &UnlimitedQuotaProvider{}}
s3p := awsTablePersister{s3: s3svc, bucket: "bucket", ddb: ddb, limits: limits5mb, ns: ns, q: &UnlimitedQuotaProvider{}}

_, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
assert.Error(err)
@@ -330,14 +364,15 @@ func TestAWSTablePersisterCalcPartSizes(t *testing.T) {

func TestAWSTablePersisterConjoinAll(t *testing.T) {
ctx := context.Background()
targetPartSize := uint64(1024)
const sz5mb = 1 << 20 * 5
targetPartSize := uint64(sz5mb)
minPartSize, maxPartSize := targetPartSize, 5*targetPartSize
maxItemSize, maxChunkCount := int(targetPartSize/2), uint32(4)

rl := make(chan struct{}, 8)
defer close(rl)

newPersister := func(s3svc s3svc, ddb *ddbTableStore) awsTablePersister {
newPersister := func(s3svc s3iface.S3API, ddb *ddbTableStore) awsTablePersister {
return awsTablePersister{
s3svc,
"bucket",
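
A rough sanity check on the new test sizes: s3manager issues a single PutObject when the body fits in one part and otherwise splits it into about ceil(size/PartSize) parts, so with ~12 MiB of chunk data the 5 MiB partTarget drives the multipart path while the 64 MiB partTarget drives the single-part path. The helper below is hypothetical and only mirrors that splitting rule for illustration:

```go
// expectedParts is a hypothetical helper mirroring how s3manager splits an
// upload: one PutObject when the body fits in a single part, otherwise
// ceil(size/partSize) multipart parts.
func expectedParts(size, partSize uint64) uint64 {
	if size <= partSize {
		return 1
	}
	return (size + partSize - 1) / partSize
}

// expectedParts(12<<20, 5<<20) == 3  // limits5mb: exercises the multipart path
// expectedParts(12<<20, 64<<20) == 1 // limits64mb: exercises the single PutObject path
```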
38 changes: 38 additions & 0 deletions go/store/nbs/s3_fake_test.go
@@ -33,8 +33,10 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/stretchr/testify/assert"

"github.com/dolthub/dolt/go/store/d"
@@ -58,6 +60,8 @@ func makeFakeS3(t *testing.T) *fakeS3 {
}

type fakeS3 struct {
s3iface.S3API

assert *assert.Assertions

mu sync.Mutex
@@ -279,3 +283,37 @@ func (m *fakeS3) PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput,

return &s3.PutObjectOutput{}, nil
}

func (m *fakeS3) GetObjectRequest(input *s3.GetObjectInput) (*request.Request, *s3.GetObjectOutput) {
out := &s3.GetObjectOutput{}
var handlers request.Handlers
handlers.Send.PushBack(func(r *request.Request) {
res, err := m.GetObjectWithContext(r.Context(), input)
r.Error = err
if res != nil {
*(r.Data.(*s3.GetObjectOutput)) = *res
}
})
return request.New(aws.Config{}, metadata.ClientInfo{}, handlers, nil, &request.Operation{
Name: "GetObject",
HTTPMethod: "GET",
HTTPPath: "/{Bucket}/{Key+}",
}, input, out), out
}

func (m *fakeS3) PutObjectRequest(input *s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput) {
out := &s3.PutObjectOutput{}
var handlers request.Handlers
handlers.Send.PushBack(func(r *request.Request) {
res, err := m.PutObjectWithContext(r.Context(), input)
r.Error = err
if res != nil {
*(r.Data.(*s3.PutObjectOutput)) = *res
}
})
return request.New(aws.Config{}, metadata.ClientInfo{}, handlers, nil, &request.Operation{
Name: "PutObject",
HTTPMethod: "PUT",
HTTPPath: "/{Bucket}/{Key+}",
}, input, out), out
}
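
Why the fake grows these two constructors: code that goes through the SDK's request pipeline (s3manager's Uploader, for example, builds a PutObjectRequest and calls Send on it for single-part uploads) never touches the *WithContext helpers directly, so the fake attaches a Send handler that delegates back to its existing in-memory implementations. A hypothetical test exercising that path directly might look like the sketch below (the test name, key, and payload are invented for illustration):

```go
package nbs

import (
	"bytes"
	"testing"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/stretchr/testify/require"
)

// TestFakeS3PutObjectRequestSend (hypothetical, not part of the commit) checks
// that sending a request built by the fake runs the pushed Send handler, which
// delegates to fakeS3.PutObjectWithContext.
func TestFakeS3PutObjectRequestSend(t *testing.T) {
	fake := makeFakeS3(t)

	req, out := fake.PutObjectRequest(&s3.PutObjectInput{
		Bucket: aws.String("bucket"),
		Key:    aws.String("example-key"),
		Body:   bytes.NewReader([]byte("hello")),
	})
	require.NoError(t, req.Send()) // runs the Send handler wired up above
	_ = out                        // populated by that handler, if the fake returns output
}
```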