feat: add support for getting the prefix list file from a third bucket and support configuring SSE encryption when putting objects
YikaiHu committed Jan 5, 2024
1 parent 18b7f76 commit 3553479
Showing 6 changed files with 112 additions and 21 deletions.
1 change: 1 addition & 0 deletions Dockerfile
@@ -24,6 +24,7 @@ ENV SINGLE_PART_TABLE_NAME ''
ENV SRC_BUCKET ''
ENV SRC_PREFIX ''
ENV SRC_PREFIX_LIST ''
ENV SRC_PREFIX_LIST_BUCKET ''
ENV SRC_REGION ''
ENV SRC_ENDPOINT ''
ENV SRC_CREDENTIALS ''
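Note: the new SRC_PREFIX_LIST_BUCKET variable defaults to empty, so the third-bucket lookup stays disabled unless it is explicitly set; when it is left empty, the Finder keeps reading the prefix list from the source bucket (see dth/job.go below).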
9 changes: 9 additions & 0 deletions cmd/root.go
@@ -72,11 +72,14 @@ func initConfig() {
viper.SetDefault("destStorageClass", "STANDARD")
viper.SetDefault("srcPrefix", "")
viper.SetDefault("srcPrefixList", "")
viper.SetDefault("srcPrefixListBucket", "")
viper.SetDefault("srcCredential", "")
viper.SetDefault("srcEndpoint", "")
viper.SetDefault("destPrefix", "")
viper.SetDefault("destCredential", "")
viper.SetDefault("destAcl", "bucket-owner-full-control")
viper.SetDefault("destSSEType", "None")
viper.SetDefault("destSSEKMSKeyId", "")

viper.SetDefault("options.chunkSize", dth.DefaultChunkSize)
viper.SetDefault("options.multipartThreshold", dth.DefaultMultipartThreshold)
@@ -92,6 +95,7 @@ func initConfig() {
viper.BindEnv("srcBucket", "SRC_BUCKET")
viper.BindEnv("srcPrefix", "SRC_PREFIX")
viper.BindEnv("srcPrefixList", "SRC_PREFIX_LIST")
viper.BindEnv("srcPrefixListBucket", "SRC_PREFIX_LIST_BUCKET")
viper.BindEnv("srcRegion", "SRC_REGION")
viper.BindEnv("srcEndpoint", "SRC_ENDPOINT")
viper.BindEnv("srcCredential", "SRC_CREDENTIALS")
@@ -106,6 +110,8 @@ func initConfig() {
viper.BindEnv("destInCurrentAccount", "DEST_IN_CURRENT_ACCOUNT")
viper.BindEnv("destStorageClass", "DEST_STORAGE_CLASS")
viper.BindEnv("destAcl", "DEST_ACL")
viper.BindEnv("destSSEType", "DEST_SSE_TYPE")
viper.BindEnv("destSSEKMSKeyId", "DEST_SSE_KMS_KEY_ID")

viper.BindEnv("jobTableName", "JOB_TABLE_NAME")
viper.BindEnv("jobQueueName", "JOB_QUEUE_NAME")
@@ -156,6 +162,7 @@ func initConfig() {
SrcBucket: viper.GetString("srcBucket"),
SrcPrefix: viper.GetString("srcPrefix"),
SrcPrefixList: viper.GetString("srcPrefixList"),
SrcPrefixListBucket: viper.GetString("srcPrefixListBucket"),
SrcRegion: viper.GetString("srcRegion"),
SrcEndpoint: viper.GetString("srcEndpoint"),
SrcCredential: viper.GetString("srcCredential"),
@@ -168,6 +175,8 @@
DestCredential: viper.GetString("destCredential"),
DestStorageClass: viper.GetString("destStorageClass"),
DestAcl: viper.GetString("destAcl"),
DestSSEType: viper.GetString("destSSEType"),
DestSSEKMSKeyId: viper.GetString("destSSEKMSKeyId"),
DestInCurrentAccount: viper.GetBool("destInCurrentAccount"),
JobTableName: viper.GetString("jobTableName"),
JobQueueName: viper.GetString("jobQueueName"),
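A minimal sketch (not part of the commit) of how these new bindings resolve: viper serves the default until the bound environment variable is set, which is how ENV values set in the Dockerfile reach the job config.

package main

import (
	"fmt"
	"os"

	"github.com/spf13/viper"
)

func main() {
	viper.SetDefault("destSSEType", "None")
	viper.BindEnv("destSSEType", "DEST_SSE_TYPE")
	fmt.Println(viper.GetString("destSSEType")) // "None" (default applies)

	os.Setenv("DEST_SSE_TYPE", "AWS_KMS")
	fmt.Println(viper.GetString("destSSEType")) // "AWS_KMS" (env var overrides)
}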
3 changes: 3 additions & 0 deletions config-example.yaml
@@ -1,6 +1,7 @@
srcType: Amazon_S3
srcBucket: src-bucket
srcPrefix:
srcPrefixListBucket:
srcRegion: us-west-2
srcEndpoint:
srcCredential: src
@@ -15,6 +16,8 @@ destCredential:
destStorageClass: STANDARD
destInCurrentAccount: true
destAcl: bucket-owner-full-control
destSSEType: None
destSSEKMSKeyId:

jobTableName: test-table
jobQueueName: test-queue
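Based on the switch statements added in dth/client.go below, destSSEType accepts None (the default; no encryption header is sent), AES256 (SSE-S3), or AWS_KMS; destSSEKMSKeyId is only read when destSSEType is AWS_KMS, and should then hold the KMS key ID to encrypt with.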
74 changes: 70 additions & 4 deletions dth/client.go
@@ -47,10 +47,11 @@ type Client interface {
ListParts(ctx context.Context, key, uploadID *string) (parts map[int]*Part)
GetUploadID(ctx context.Context, key *string) (uploadID *string)
ListSelectedPrefixes(ctx context.Context, key *string) (prefixes []*string)
ListSelectedPrefixesFromThirdBucket(ctx context.Context, bucket *string, key *string) (prefixes []*string)

// WRITE
PutObject(ctx context.Context, key *string, body []byte, storageClass, acl *string, meta *Metadata) (etag *string, err error)
CreateMultipartUpload(ctx context.Context, key, storageClass, acl *string, meta *Metadata) (uploadID *string, err error)
PutObject(ctx context.Context, key *string, body []byte, storageClass, acl *string, sseType *string, sseKMSKeyId *string, meta *Metadata) (etag *string, err error)
CreateMultipartUpload(ctx context.Context, key, storageClass, acl *string, sseType *string, sseKMSKeyId *string, meta *Metadata) (uploadID *string, err error)
CompleteMultipartUpload(ctx context.Context, key, uploadID *string, parts []*Part) (etag *string, err error)
UploadPart(ctx context.Context, key *string, body []byte, uploadID *string, partNumber int) (etag *string, err error)
AbortMultipartUpload(ctx context.Context, key, uploadID *string) (err error)
@@ -166,6 +167,19 @@ func NewS3Client(ctx context.Context, bucket, prefix, prefixList, endpoint, regi

}

// NewS3ClientWithEC2Role creates an S3Client instance that uses the EC2 role to access S3
func NewS3ClientWithEC2Role(ctx context.Context, bucket, prefixList string) *S3Client {
cfg := loadDefaultConfig(ctx)

client := s3.NewFromConfig(cfg)

return &S3Client{
bucket: bucket,
prefixList: prefixList,
client: client,
}
}

// GetObject is a function to get (download) object from Amazon S3
func (c *S3Client) GetObject(ctx context.Context, key *string, size, start, chunkSize int64, version string) ([]byte, error) {
// log.Printf("S3> Downloading %s with %d bytes start from %d\n", key, size, start)
@@ -459,8 +473,45 @@ func (c *S3Client) ListSelectedPrefixes(ctx context.Context, key *string) (prefi
return
}

// ListSelectedPrefixesFromThirdBucket is a function to list prefixes from a list file in a specific bucket.
func (c *S3Client) ListSelectedPrefixesFromThirdBucket(ctx context.Context, bucket *string, key *string) (prefixes []*string) {
downloader := manager.NewDownloader(c.client)
getBuf := manager.NewWriteAtBuffer([]byte{})

input := &s3.GetObjectInput{
Bucket: bucket,
Key: key,
}

downloadStart := time.Now()
log.Printf("Start downloading the Prefix List File from bucket: %s", *bucket)
_, err := downloader.Download(ctx, getBuf, input)
downloadEnd := time.Since(downloadStart)
if err != nil {
log.Printf("Error downloading the Prefix List File: %s", err)
return nil
}
log.Printf("Download the Prefix List File Completed in %v\n", downloadEnd)

start := time.Now()

for _, line := range strings.Split(string(getBuf.Bytes()), "\n") {
if len(line) > 0 {
// copy the line so each stored pointer remains valid
prefix := line
prefixes = append(prefixes, &prefix)
}
}

end := time.Since(start)
log.Printf("Got %d prefixes from the customized list file in %v", len(prefixes), end)
return
}
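The prefix list file is read as plain text, one prefix per line, with blank lines skipped. For illustration (hypothetical prefixes), a file containing the two lines project-a/logs/ and project-b/images/ yields two prefixes.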


// PutObject is a function to put (upload) an object to Amazon S3
func (c *S3Client) PutObject(ctx context.Context, key *string, body []byte, storageClass, acl *string, meta *Metadata) (etag *string, err error) {
func (c *S3Client) PutObject(ctx context.Context, key *string, body []byte, storageClass, acl *string, sseType *string, sseKMSKeyId *string, meta *Metadata) (etag *string, err error) {
// log.Printf("S3> Uploading object %s to bucket %s\n", key, c.bucket)

md5Bytes := md5.Sum(body)
@@ -481,6 +532,14 @@ func (c *S3Client) PutObject(ctx context.Context, key *string, body []byte, stor
StorageClass: types.StorageClass(*storageClass),
ACL: types.ObjectCannedACL(*acl),
}
switch *sseType {
case "AES256":
input.ServerSideEncryption = types.ServerSideEncryptionAes256
case "AWS_KMS":
input.ServerSideEncryption = types.ServerSideEncryptionAwsKms
input.SSEKMSKeyId = sseKMSKeyId
}

if meta != nil {
input.ContentType = meta.ContentType
input.ContentEncoding = meta.ContentEncoding
@@ -522,7 +581,7 @@ func (c *S3Client) DeleteObject(ctx context.Context, key *string) (err error) {
// CreateMultipartUpload is a function to initialize a multipart upload process.
// This func returns an upload ID used to indicate the multipart upload.
// All parts will be uploaded with this upload ID, after that, all parts by this ID will be combined to create the full object.
func (c *S3Client) CreateMultipartUpload(ctx context.Context, key, storageClass, acl *string, meta *Metadata) (uploadID *string, err error) {
func (c *S3Client) CreateMultipartUpload(ctx context.Context, key, storageClass, acl *string, sseType *string, sseKMSKeyId *string, meta *Metadata) (uploadID *string, err error) {
// log.Printf("S3> Create Multipart Upload for %s\n", *key)
if *acl == "" {
*acl = string(types.ObjectCannedACLBucketOwnerFullControl)
@@ -534,6 +593,13 @@ func (c *S3Client) CreateMultipartUpload(ctx context.Context, key, storageClass,
StorageClass: types.StorageClass(*storageClass),
ACL: types.ObjectCannedACL(*acl),
}
switch *sseType {
case "AES256":
input.ServerSideEncryption = types.ServerSideEncryptionAes256
case "AWS_KMS":
input.ServerSideEncryption = types.ServerSideEncryptionAwsKms
input.SSEKMSKeyId = sseKMSKeyId
}
if meta != nil {
input.ContentType = meta.ContentType
input.ContentEncoding = meta.ContentEncoding
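To make the encryption plumbing above concrete, here is a minimal runnable sketch of the same mapping from config strings to SDK fields; the applySSE helper is hypothetical, but the switch body is exactly what the commit adds to PutObject and CreateMultipartUpload.

package main

import (
	"fmt"

	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/aws/aws-sdk-go-v2/service/s3/types"
)

// applySSE mirrors the switch added above: "None" (or any other
// unrecognized value) leaves the request without an encryption header.
func applySSE(input *s3.PutObjectInput, sseType, sseKMSKeyId *string) {
	switch *sseType {
	case "AES256":
		input.ServerSideEncryption = types.ServerSideEncryptionAes256
	case "AWS_KMS":
		input.ServerSideEncryption = types.ServerSideEncryptionAwsKms
		input.SSEKMSKeyId = sseKMSKeyId
	}
}

func main() {
	sseType, keyID := "AWS_KMS", "alias/example-key" // hypothetical key alias
	input := &s3.PutObjectInput{}
	applySSE(input, &sseType, &keyID)
	fmt.Println(input.ServerSideEncryption) // aws:kms
}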
8 changes: 4 additions & 4 deletions dth/config.go
@@ -70,10 +70,10 @@ type JobOptions struct {

// JobConfig is General Job Info
type JobConfig struct {
SrcType, SrcBucket, SrcPrefix, SrcPrefixList, SrcRegion, SrcEndpoint, SrcCredential string
DestBucket, DestPrefix, DestRegion, DestCredential, DestStorageClass, DestAcl string
JobTableName, JobQueueName, SinglePartTableName, SfnArn string
SrcInCurrentAccount, DestInCurrentAccount, SkipCompare, PayerRequest bool
SrcType, SrcBucket, SrcPrefix, SrcPrefixList, SrcPrefixListBucket, SrcRegion, SrcEndpoint, SrcCredential string
DestBucket, DestPrefix, DestRegion, DestCredential, DestStorageClass, DestAcl, DestSSEType, DestSSEKMSKeyId string
JobTableName, JobQueueName, SinglePartTableName, SfnArn string
SrcInCurrentAccount, DestInCurrentAccount, SkipCompare, PayerRequest bool
*JobOptions
}

38 changes: 25 additions & 13 deletions dth/job.go
@@ -54,10 +54,10 @@ type Job interface {
// Finder is an implementation of the Job interface
// Finder compares the differences of source and destination and sends the delta to SQS
type Finder struct {
srcClient, desClient Client
sqs *SqsService
cfg *JobConfig
sfn *SfnService
srcClient, desClient, ec2RoleS3Client Client
sqs *SqsService
cfg *JobConfig
sfn *SfnService
}

// Worker is an implementation of the Job interface
@@ -148,6 +148,8 @@ func NewFinder(ctx context.Context, cfg *JobConfig) (f *Finder) {
srcClient := NewS3Client(ctx, cfg.SrcBucket, cfg.SrcPrefix, cfg.SrcPrefixList, cfg.SrcEndpoint, cfg.SrcRegion, cfg.SrcType, srcCred)
desClient := NewS3Client(ctx, cfg.DestBucket, cfg.DestPrefix, "", "", cfg.DestRegion, "Amazon_S3", desCred)

ec2RoleS3Client := NewS3ClientWithEC2Role(ctx, cfg.SrcPrefixListBucket, cfg.SrcPrefixList)

if srcClient != nil {
srcClient.isSrcClient = true
}
@@ -160,11 +162,12 @@ func NewFinder(ctx context.Context, cfg *JobConfig) (f *Finder) {
DST_CRED = desCred

f = &Finder{
srcClient: srcClient,
desClient: desClient,
sfn: sfn,
sqs: sqs,
cfg: cfg,
srcClient: srcClient,
desClient: desClient,
ec2RoleS3Client: ec2RoleS3Client,
sfn: sfn,
sqs: sqs,
cfg: cfg,
}
return
}
@@ -227,7 +230,9 @@ func (f *Finder) Run(ctx context.Context) {
log.Printf("Enable Payer Request Mode")
}

if len(f.cfg.SrcPrefixList) > 0 {
if f.cfg.SrcPrefixListBucket != "" && len(f.cfg.SrcPrefixList) > 0 {
prefixes = f.ec2RoleS3Client.ListSelectedPrefixesFromThirdBucket(ctx, &f.cfg.SrcPrefixListBucket, &f.cfg.SrcPrefixList)
} else if len(f.cfg.SrcPrefixList) > 0 {
prefixes = f.srcClient.ListSelectedPrefixes(ctx, &f.cfg.SrcPrefixList)
} else {
prefixes = f.srcClient.ListCommonPrefixes(ctx, f.cfg.FinderDepth, f.cfg.MaxKeys)
@@ -843,7 +848,9 @@ func (w *Worker) generateMultiPartTransferJobs(ctx context.Context, obj *Object,
meta = w.srcClient.HeadObject(ctx, &obj.Key)
}

uploadID, err = w.desClient.CreateMultipartUpload(ctx, destKey, &w.cfg.DestStorageClass, &w.cfg.DestAcl, meta)
uploadID, err = w.desClient.CreateMultipartUpload(
ctx, destKey, &w.cfg.DestStorageClass, &w.cfg.DestAcl, &w.cfg.DestSSEType, &w.cfg.DestSSEKMSKeyId, meta,
)
if err != nil {
log.Printf("Failed to create upload ID - %s for %s\n", err.Error(), *destKey)
return 0, err
@@ -949,7 +956,9 @@ func (w *Worker) migrateBigFile(ctx context.Context, obj *Object, destKey *strin
meta = w.srcClient.HeadObject(ctx, &obj.Key)
}

uploadID, err = w.desClient.CreateMultipartUpload(ctx, destKey, &w.cfg.DestStorageClass, &w.cfg.DestAcl, meta)
uploadID, err = w.desClient.CreateMultipartUpload(
ctx, destKey, &w.cfg.DestStorageClass, &w.cfg.DestAcl, &w.cfg.DestSSEType, &w.cfg.DestSSEKMSKeyId, meta,
)
if err != nil {
log.Printf("Failed to create upload ID - %s for %s\n", err.Error(), *destKey)
return &TransferResult{
@@ -1100,7 +1109,10 @@ func (w *Worker) transfer(ctx context.Context, obj *Object, destKey *string, sta

} else {
log.Printf("----->Uploading %d Bytes to %s/%s\n", chunkSize, w.cfg.DestBucket, *destKey)
etag, err = w.desClient.PutObject(ctx, destKey, body, &w.cfg.DestStorageClass, &w.cfg.DestAcl, meta)
etag, err = w.desClient.PutObject(
ctx, destKey, body, &w.cfg.DestStorageClass, &w.cfg.DestAcl,
&w.cfg.DestSSEType, &w.cfg.DestSSEKMSKeyId, meta,
)
}

body = nil // release memory
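Taken together, the Finder now resolves its prefix source in priority order: third bucket, then source-bucket list file, then common-prefix listing. A condensed restatement (selectPrefixes is an illustrative helper, not code from the commit, and assumes the dth package's types):

func selectPrefixes(ctx context.Context, f *Finder) []*string {
	switch {
	case f.cfg.SrcPrefixListBucket != "" && len(f.cfg.SrcPrefixList) > 0:
		// list file lives in a third bucket; fetched with the EC2-role client
		return f.ec2RoleS3Client.ListSelectedPrefixesFromThirdBucket(ctx, &f.cfg.SrcPrefixListBucket, &f.cfg.SrcPrefixList)
	case len(f.cfg.SrcPrefixList) > 0:
		// list file lives in the source bucket itself
		return f.srcClient.ListSelectedPrefixes(ctx, &f.cfg.SrcPrefixList)
	default:
		// no list file configured: enumerate common prefixes
		return f.srcClient.ListCommonPrefixes(ctx, f.cfg.FinderDepth, f.cfg.MaxKeys)
	}
}

Note that NewFinder constructs the EC2-role client unconditionally; it is only exercised when SrcPrefixListBucket is set.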
