Skip to content

Commit

Permalink
COCOS-132 - Add progress bar for algo and data uploads (ultravioletrs…
Browse files Browse the repository at this point in the history
…#162)

* add progress bar to CLI

* fix error handling

* fix comments errors

* add header

* add wraper for AlgoClient and DataClient

* add compile time check for wrapper structs

* refactor code
  • Loading branch information
danko-miladinovic authored Jul 9, 2024
1 parent 006897a commit 654be60
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 39 deletions.
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ require (
go.opentelemetry.io/proto/otlp v1.0.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/net v0.20.0 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/term v0.21.0
golang.org/x/text v0.14.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240108191215-35c7eff3a6b1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240108191215-35c7eff3a6b1 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,12 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
Expand Down
237 changes: 237 additions & 0 deletions pkg/progressbar/progressbar.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package progressbar

import (
"bytes"
"fmt"
"io"
"os"
"strings"

"github.com/ultravioletrs/cocos/agent"
"golang.org/x/term"
)

const (
progressBarDots = "... "
leftBracket = "["
rightBracket = "]"
head = ">"
body = "="
bodyPadding = "."
bufferSize = 1024 * 1024
)

var (
_ streamSender = (*algoClientWrapper)(nil)
_ streamSender = (*dataClientWrapper)(nil)
)

type streamSender interface {
Send(interface{}) error
CloseAndRecv() (interface{}, error)
}

type algoClientWrapper struct {
client *agent.AgentService_AlgoClient
}

func (a *algoClientWrapper) Send(req interface{}) error {
algoReq, ok := req.(*agent.AlgoRequest)
if !ok {
return fmt.Errorf("expected *AlgoRequest, got %T", req)
}

return (*a.client).Send(algoReq)
}

func (a *algoClientWrapper) CloseAndRecv() (interface{}, error) {
return (*a.client).CloseAndRecv()
}

type dataClientWrapper struct {
client *agent.AgentService_DataClient
}

func (a *dataClientWrapper) Send(req interface{}) error {
dataReq, ok := req.(*agent.DataRequest)
if !ok {
return fmt.Errorf("expected *DataRequest, got %T", req)
}

return (*a.client).Send(dataReq)
}

func (a *dataClientWrapper) CloseAndRecv() (interface{}, error) {
return (*a.client).CloseAndRecv()
}

type ProgressBar struct {
numberOfBytes int
currentUploadedBytes int
currentUploadPercentage int
description string
maxWidth int
}

func New() *ProgressBar {
return &ProgressBar{}
}

func (p *ProgressBar) SendAlgorithm(description string, buffer *bytes.Buffer, stream *agent.AgentService_AlgoClient) error {
return p.sendData(description, buffer, &algoClientWrapper{client: stream}, func(data []byte) interface{} {
return &agent.AlgoRequest{Algorithm: data}
})
}

func (p *ProgressBar) SendData(description string, buffer *bytes.Buffer, stream *agent.AgentService_DataClient) error {
return p.sendData(description, buffer, &dataClientWrapper{client: stream}, func(data []byte) interface{} {
return &agent.DataRequest{Dataset: data}
})
}

func (p *ProgressBar) sendData(description string, buffer *bytes.Buffer, stream streamSender, createRequest func([]byte) interface{}) error {
p.reset(description, buffer.Len())

buf := make([]byte, bufferSize)

for {
n, err := buffer.Read(buf)
if err == io.EOF {
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
return err
}
break
}
if err != nil {
return err
}

p.updateProgress(n)

if err := stream.Send(createRequest(buf[:n])); err != nil {
return err
}

if err := p.renderProgressBar(); err != nil {
return err
}
}

_, err := stream.CloseAndRecv()
return err
}

func (p *ProgressBar) reset(description string, totalBytes int) {
p.currentUploadedBytes = 0
p.currentUploadPercentage = 0
p.numberOfBytes = totalBytes
p.description = description
}

func (p *ProgressBar) updateProgress(bytesRead int) {
if p.currentUploadedBytes < p.numberOfBytes {
p.currentUploadedBytes += bytesRead
p.currentUploadPercentage = p.currentUploadedBytes * 100 / p.numberOfBytes
}
}

// Progress bar example: Uploading algorithm... 25% [==> ].
func (p *ProgressBar) renderProgressBar() error {
var builder strings.Builder

// Get terminal width.
width, err := terminalWidth()
if err != nil {
return err
}

if p.maxWidth < width {
p.maxWidth = width
}

if err := p.clearProgressBar(); err != nil {
return err
}

// The progress bar starts with the description.
if _, err := builder.WriteString(p.description); err != nil {
return err
}

// Add dots to progress bar.
if _, err := builder.WriteString(progressBarDots); err != nil {
return err
}

// Add uploaded percentage.
strCurrentUploadPercentage := fmt.Sprintf("%4d%% ", p.currentUploadPercentage)
if _, err := builder.WriteString(strCurrentUploadPercentage); err != nil {
return err
}

// Add letf bracket and space to progress bar.
if _, err := builder.WriteString(leftBracket); err != nil {
return err
}

progressWidth := width - builder.Len() - len(rightBracket+" ")
numOfCharactersBody := progressWidth * p.currentUploadPercentage / 100
if numOfCharactersBody == 0 {
numOfCharactersBody = 1
}

numOfCharactersPadding := progressWidth - numOfCharactersBody

// Add body which represents the percentage.
progress := strings.Repeat(body, numOfCharactersBody-1)

// Add progress to the progress bar.
if _, err := builder.WriteString(progress); err != nil {
return err
}

// Add head to progress bar.
if _, err := builder.WriteString(head); err != nil {
return err
}

// Add padding to end of bar.
padding := strings.Repeat(bodyPadding, numOfCharactersPadding)

// Add padding to progress bar.
if _, err := builder.WriteString(padding); err != nil {
return err
}

// Add right bracket to progress bar.
if _, err := builder.WriteString(rightBracket); err != nil {
return err
}

// Write progress bar.
if _, err := io.WriteString(os.Stdout, builder.String()); err != nil {
return err
}

return nil
}

func terminalWidth() (int, error) {
width, _, err := term.GetSize(int(os.Stdout.Fd()))
if err == nil {
return width, nil
}

return 0, err
}

func (p *ProgressBar) clearProgressBar() error {
emptySpace := fmt.Sprintf("\r%s\r", strings.Repeat(" ", p.maxWidth))
if _, err := io.WriteString(os.Stdout, emptySpace); err != nil {
return err
}

return nil
}
49 changes: 11 additions & 38 deletions pkg/sdk/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ import (
"crypto/sha256"
"encoding/base64"
"errors"
"io"
"log/slog"

"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/agent/auth"
"github.com/ultravioletrs/cocos/pkg/progressbar"
"google.golang.org/grpc/metadata"
)

Expand All @@ -29,8 +29,9 @@ type SDK interface {
}

const (
size64 = 64
bufferSize = 1024 * 1024
size64 = 64
algoProgressBarDescription = "Uploading algorithm"
dataProgressBarDescription = "Uploading data"
)

type agentSDK struct {
Expand Down Expand Up @@ -60,23 +61,9 @@ func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKe
}
algoBuffer := bytes.NewBuffer(algorithm.Algorithm)

buf := make([]byte, bufferSize)
for {
n, err := algoBuffer.Read(buf)
if err == io.EOF {
break
}
if err != nil {
return err
}

err = stream.Send(&agent.AlgoRequest{Algorithm: buf[:n]})
if err != nil {
return err
}
}

if _, err := stream.CloseAndRecv(); err != nil {
progressbar := progressbar.New()
if err := progressbar.SendAlgorithm(algoProgressBarDescription, algoBuffer, &stream); err != nil {
sdk.logger.Error("Failed to send Algorithm")
return err
}

Expand All @@ -93,28 +80,14 @@ func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey an
ctx = metadata.NewOutgoingContext(ctx, md)
stream, err := sdk.client.Data(ctx)
if err != nil {
sdk.logger.Error("Failed to call Algo RPC")
sdk.logger.Error("Failed to call Data RPC")
return err
}
dataBuffer := bytes.NewBuffer(dataset.Dataset)

buf := make([]byte, bufferSize)
for {
n, err := dataBuffer.Read(buf)
if err == io.EOF {
break
}
if err != nil {
return err
}

err = stream.Send(&agent.DataRequest{Dataset: buf[:n]})
if err != nil {
return err
}
}

if _, err := stream.CloseAndRecv(); err != nil {
progressbar := progressbar.New()
if err := progressbar.SendData(dataProgressBarDescription, dataBuffer, &stream); err != nil {
sdk.logger.Error("Failed to send Data")
return err
}

Expand Down

0 comments on commit 654be60

Please sign in to comment.