Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send push timings to the server #2152

Merged
merged 17 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module github.com/replicate/cog

go 1.23
go 1.23.0

toolchain go1.23.2

require (
Expand All @@ -18,6 +19,7 @@ require (
github.com/mattn/go-isatty v0.0.20
github.com/mitchellh/go-homedir v1.1.0
github.com/moby/term v0.5.0
github.com/replicate/go v0.0.0-20250205165008-b772d7cd506b
github.com/spf13/cobra v1.8.1
github.com/spf13/pflag v1.0.6
github.com/stretchr/testify v1.10.0
Expand Down Expand Up @@ -115,7 +117,6 @@ require (
github.com/gobwas/glob v0.2.3 // indirect
github.com/gofrs/flock v0.12.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/golangci/dupl v0.0.0-20180902072040-3e9179ac440a // indirect
github.com/golangci/go-printf-func-name v0.1.0 // indirect
github.com/golangci/gofmt v0.0.0-20250106114630-d62b90e6713d // indirect
Expand Down Expand Up @@ -144,7 +145,7 @@ require (
github.com/karamaru-alpha/copyloopvar v1.2.1 // indirect
github.com/kisielk/errcheck v1.8.0 // indirect
github.com/kkHAIKE/contextcheck v1.1.5 // indirect
github.com/klauspost/compress v1.16.5 // indirect
github.com/klauspost/compress v1.17.9 // indirect
github.com/kulti/thelper v0.6.3 // indirect
github.com/kunwardeep/paralleltest v1.0.10 // indirect
github.com/lasiar/canonicalheader v1.1.2 // indirect
Expand All @@ -162,12 +163,12 @@ require (
github.com/matoous/godox v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/mgechev/revive v1.6.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
github.com/moricho/tparallel v0.3.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/nakabonne/nestif v0.3.1 // indirect
github.com/nishanths/exhaustive v0.12.0 // indirect
github.com/nishanths/predeclared v0.2.2 // indirect
Expand All @@ -181,10 +182,10 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/polyfloyd/go-errorlint v1.7.1 // indirect
github.com/prometheus/client_golang v1.12.1 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.32.1 // indirect
github.com/prometheus/procfs v0.7.3 // indirect
github.com/prometheus/client_golang v1.20.5 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.61.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/quasilyte/go-ruleguard v0.4.3-0.20240823090925-0fe6f58b47b1 // indirect
github.com/quasilyte/go-ruleguard/dsl v0.3.22 // indirect
github.com/quasilyte/gogrep v0.5.0 // indirect
Expand Down Expand Up @@ -233,10 +234,9 @@ require (
gitlab.com/bosi/decorder v0.4.2 // indirect
go-simpler.org/musttag v0.13.0 // indirect
go-simpler.org/sloglint v0.9.0 // indirect
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/multierr v1.6.0 // indirect
go.uber.org/zap v1.24.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
golang.org/x/exp/typeparams v0.0.0-20241108190413-2d47ceb2692f // indirect
golang.org/x/mod v0.23.0 // indirect
golang.org/x/net v0.35.0 // indirect
Expand Down
441 changes: 20 additions & 421 deletions go.sum

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion pkg/cli/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cli
import (
"fmt"
"strings"
"time"

"github.com/spf13/cobra"

Expand Down Expand Up @@ -59,17 +60,21 @@ func push(cmd *cobra.Command, args []string) error {
}
}

startBuildTime := time.Now()

if err := image.Build(cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile, buildDockerfileFile, DetermineUseCogBaseImage(cmd), buildStrip, buildPrecompile, buildFast); err != nil {
return err
}

buildDuration := time.Since(startBuildTime)

console.Infof("\nPushing image '%s'...", imageName)
if buildFast {
console.Info("Fast push enabled.")
}

command := docker.NewDockerCommand()
err = docker.Push(imageName, buildFast, projectDir, command)
err = docker.Push(imageName, buildFast, projectDir, command, buildDuration)
if err != nil {
if strings.Contains(err.Error(), "404") {
return fmt.Errorf("Unable to find existing Replicate model for %s. "+
Expand Down
1 change: 1 addition & 0 deletions pkg/docker/command/manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ type Config struct {

// Manifest is a JSON-decodable description of an image.
// NOTE(review): presumably populated from image-inspect output; confirm
// against the code that decodes it.
type Manifest struct {
// Config is decoded from the "Config" JSON key.
Config Config `json:"Config"`
// ID is decoded from the "Id" JSON key.
ID string `json:"Id"`
}

const UvPythonInstallDirEnvVarName = "UV_PYTHON_INSTALL_DIR"
Expand Down
4 changes: 2 additions & 2 deletions pkg/docker/fast_push.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const weightsObjectType = "weights"
const filesObjectType = "files"
const requirementsTarFile = "requirements.tar.zst"

func FastPush(ctx context.Context, image string, projectDir string, command command.Command, webClient *web.Client, monobeamClient *monobeam.Client) error {
func FastPush(ctx context.Context, image string, projectDir string, command command.Command, webClient *web.Client, monobeamClient *monobeam.Client, uploadID string) error {
g, _ := errgroup.WithContext(ctx)
p := mpb.New(
mpb.WithRefreshRate(180 * time.Millisecond),
Expand Down Expand Up @@ -126,7 +126,7 @@ func FastPush(ctx context.Context, image string, projectDir string, command comm
}

// Tell replicate about our new version
return webClient.PostNewVersion(ctx, image, createWeightsFilesFromWeightsManifest(weights), files)
return webClient.PostNewVersion(ctx, image, createWeightsFilesFromWeightsManifest(weights), files, uploadID)
}

func createPythonPackagesTarFile(image string, tmpDir string, command command.Command) (string, error) {
Expand Down
4 changes: 2 additions & 2 deletions pkg/docker/fast_push_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestFastPush(t *testing.T) {
monobeamClient := monobeam.NewClient(client)

// Run fast push
err = FastPush(context.Background(), "r8.im/username/modelname", dir, command, webClient, monobeamClient)
err = FastPush(context.Background(), "r8.im/username/modelname", dir, command, webClient, monobeamClient, "")
require.NoError(t, err)
}

Expand Down Expand Up @@ -118,6 +118,6 @@ func TestFastPushWithWeight(t *testing.T) {
monobeamClient := monobeam.NewClient(client)

// Run fast push
err = FastPush(context.Background(), "r8.im/username/modelname", dir, command, webClient, monobeamClient)
err = FastPush(context.Background(), "r8.im/username/modelname", dir, command, webClient, monobeamClient, "")
require.NoError(t, err)
}
51 changes: 46 additions & 5 deletions pkg/docker/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,63 @@ package docker

import (
"context"
"fmt"
"math/rand/v2"
"strings"
"time"

"github.com/replicate/cog/pkg/docker/command"
"github.com/replicate/cog/pkg/http"
"github.com/replicate/cog/pkg/monobeam"
"github.com/replicate/cog/pkg/util/console"
"github.com/replicate/cog/pkg/web"
)

func Push(image string, fast bool, projectDir string, command command.Command) error {
func Push(image string, fast bool, projectDir string, command command.Command, buildTime time.Duration) error {
ctx := context.Background()
client, err := http.ProvideHTTPClient(command)
if err != nil {
return err
}
webClient := web.NewClient(command, client)

// For the timing flow, on error we will just log and continue since
// this is just a loss of push timing information
imageID := ""
if fast {
client, err := http.ProvideHTTPClient(command)
imageID = buildRandomHash256()
} else {
imageMeta, err := command.Inspect(image)
if err != nil {
return err
console.Warnf("Failed to inspect image: %v", err)
}
_, hash, ok := strings.Cut(imageMeta.ID, ":")
if !ok {
console.Warn("Image ID was not of the form sha:hash")
} else {
imageID = hash
}
webClient := web.NewClient(command, client)
}

if err := webClient.PostBuildStart(ctx, imageID, buildTime); err != nil {
console.Warnf("Failed to send build timings to server: %v", err)
}

if fast {
monobeamClient := monobeam.NewClient(client)
return FastPush(context.Background(), image, projectDir, command, webClient, monobeamClient)
return FastPush(ctx, image, projectDir, command, webClient, monobeamClient, imageID)
}
return StandardPush(image, command)
}

func buildRandomHash256() string {
out := ""
// Generate 256 bit random hash (4x64 bits) to use as an upload ID
for i := 0; i < 4; i++ {
// Ignoring the linter warning about math/rand/v2 not being cryptographically secure
// because this just needs to be a "unique enough" ID for a cache between when the
// push starts and ends, which should only be ~a week max.
out = fmt.Sprintf("%s%x", out, rand.Int64()) //nolint:gosec
}
return out
}
5 changes: 3 additions & 2 deletions pkg/docker/push_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"testing"
"time"

"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -54,7 +55,7 @@ func TestPush(t *testing.T) {
command := dockertest.NewMockCommand()

// Run fast push
err = Push("r8.im/username/modelname", true, dir, command)
err = Push("r8.im/username/modelname", true, dir, command, time.Duration(0))
require.NoError(t, err)
}

Expand Down Expand Up @@ -108,6 +109,6 @@ func TestPushWithWeight(t *testing.T) {
command := dockertest.NewMockCommand()

// Run fast push
err = Push("r8.im/username/modelname", true, dir, command)
err = Push("r8.im/username/modelname", true, dir, command, time.Duration(0))
require.NoError(t, err)
}
58 changes: 54 additions & 4 deletions pkg/web/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (
"os"
"strconv"
"strings"
"time"

"github.com/replicate/go/types"

"github.com/replicate/cog/pkg/config"
"github.com/replicate/cog/pkg/docker/command"
Expand All @@ -18,6 +21,17 @@ import (
"github.com/replicate/cog/pkg/util"
)

const (
// buildTimingURLPath is the URL path that PostBuildStart posts build
// timing information to.
buildTimingURLPath = "/api/models/build-start"
)

// Sentinel errors returned by this package.
var (
// ErrorBadResponseNewVersionEndpoint is wrapped (with the status code) by
// PostNewVersion when the response status is not 201 Created.
ErrorBadResponseNewVersionEndpoint = errors.New("Bad response from new version endpoint")
// ErrorBadResponseBuildStartEndpoint is wrapped (with the status code) by
// PostBuildStart when the response status is not 200 OK.
ErrorBadResponseBuildStartEndpoint = errors.New("Bad response from build start endpoint")
// ErrorBadRegistryURL is returned by newVersionURL when the image name
// does not split into exactly three "/"-separated components.
ErrorBadRegistryURL = errors.New("The image URL must have 3 components in the format of " + global.ReplicateRegistryHost + "/your-username/your-model")
// ErrorBadRegistryHost is returned by newVersionURL when the first
// component of the image name is not global.ReplicateRegistryHost.
ErrorBadRegistryHost = errors.New("The image name must have the " + global.ReplicateRegistryHost + " prefix when using --x-fast.")
)

type Client struct {
dockerCommand command.Command
client *http.Client
Expand Down Expand Up @@ -56,6 +70,7 @@ type Version struct {
OpenAPISchema map[string]any `json:"openapi_schema"`
RuntimeConfig RuntimeConfig `json:"runtime_config"`
Virtual bool `json:"virtual"`
UploadID string `json:"upload_id"`
}

func NewClient(dockerCommand command.Command, client *http.Client) *Client {
Expand All @@ -65,12 +80,47 @@ func NewClient(dockerCommand command.Command, client *http.Client) *Client {
}
}

func (c *Client) PostNewVersion(ctx context.Context, image string, weights []File, files []File) error {
// PostBuildStart POSTs build timing information for an image to the
// build-start endpoint. The JSON payload carries the image hash, the build
// duration rendered as a string, and the current UTC time as the push start
// time.
//
// It returns an error if the payload cannot be marshalled, if the request
// cannot be created or sent, or if the server responds with any status other
// than 200 OK (in which case ErrorBadResponseBuildStartEndpoint is wrapped
// with the status code).
func (c *Client) PostBuildStart(ctx context.Context, imageHash string, buildTime time.Duration) error {
	payload, err := json.Marshal(map[string]any{
		"image_hash":      imageHash,
		"build_duration":  types.Duration(buildTime).String(),
		"push_start_time": time.Now().UTC(),
	})
	if err != nil {
		return util.WrapError(err, "failed to marshal JSON for build start")
	}

	// Start from the web base URL and point it at the build timing path.
	endpoint := webBaseURL()
	endpoint.Path = buildTimingURLPath

	request, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), bytes.NewReader(payload))
	if err != nil {
		return err
	}

	response, err := c.client.Do(request)
	if err != nil {
		return err
	}
	defer response.Body.Close()

	if response.StatusCode == http.StatusOK {
		return nil
	}
	return util.WrapError(ErrorBadResponseBuildStartEndpoint, strconv.Itoa(response.StatusCode))
}

func (c *Client) PostNewVersion(ctx context.Context, image string, weights []File, files []File, uploadID string) error {
version, err := c.versionFromManifest(image, weights, files)
if err != nil {
return util.WrapError(err, "failed to build new version from manifest")
}

version.UploadID = uploadID

jsonData, err := json.Marshal(version)
if err != nil {
return util.WrapError(err, "failed to marshal JSON for new version")
Expand All @@ -93,7 +143,7 @@ func (c *Client) PostNewVersion(ctx context.Context, image string, weights []Fil
defer resp.Body.Close()

if resp.StatusCode != http.StatusCreated {
return errors.New("Bad response from new version endpoint: " + strconv.Itoa(resp.StatusCode))
return util.WrapError(ErrorBadResponseNewVersionEndpoint, strconv.Itoa(resp.StatusCode))
}

return nil
Expand Down Expand Up @@ -216,10 +266,10 @@ func newVersionURL(image string) (url.URL, error) {
imageComponents := strings.Split(image, "/")
newVersionUrl := webBaseURL()
if len(imageComponents) != 3 {
return newVersionUrl, errors.New("The image URL must have 3 components in the format of " + global.ReplicateRegistryHost + "/your-username/your-model")
return newVersionUrl, ErrorBadRegistryURL
}
if imageComponents[0] != global.ReplicateRegistryHost {
return newVersionUrl, errors.New("The image name must have the " + global.ReplicateRegistryHost + " prefix when using --x-fast.")
return newVersionUrl, ErrorBadRegistryHost
}
newVersionUrl.Path = strings.Join([]string{"", "api", "models", imageComponents[1], imageComponents[2], "versions"}, "/")
return newVersionUrl, nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/web/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestPostNewVersion(t *testing.T) {

client := NewClient(command, http.DefaultClient)
ctx := context.Background()
err = client.PostNewVersion(ctx, "r8.im/user/test", []File{}, []File{})
err = client.PostNewVersion(ctx, "r8.im/user/test", []File{}, []File{}, "")
require.NoError(t, err)
}

Expand Down
Loading