diff --git a/internal/abstractions/container/container.go b/internal/abstractions/container/container.go index 57d5d8917..9b871d2de 100644 --- a/internal/abstractions/container/container.go +++ b/internal/abstractions/container/container.go @@ -143,7 +143,7 @@ func (cs *CmdContainerService) ExecuteCommand(in *pb.CommandExecutionRequest, st return err } - hostname, err := cs.containerRepo.GetContainerWorkerHostname(task.ContainerId) + hostname, err := cs.containerRepo.GetWorkerAddress(task.ContainerId) if err != nil { return err } diff --git a/internal/abstractions/function/function.go b/internal/abstractions/function/function.go index da4f96e6a..f85210ad0 100644 --- a/internal/abstractions/function/function.go +++ b/internal/abstractions/function/function.go @@ -93,7 +93,7 @@ func (fs *RunCFunctionService) FunctionInvoke(in *pb.FunctionInvokeRequest, stre return err } - hostname, err := fs.containerRepo.GetContainerWorkerHostname(task.ContainerId) + hostname, err := fs.containerRepo.GetWorkerAddress(task.ContainerId) if err != nil { return err } diff --git a/internal/abstractions/image/build.go b/internal/abstractions/image/build.go index 8463d5953..4f548ae07 100644 --- a/internal/abstractions/image/build.go +++ b/internal/abstractions/image/build.go @@ -140,11 +140,12 @@ func (b *Builder) Build(ctx context.Context, opts *BuildOpts, outputChan chan co return err } - hostname, err := b.containerRepo.GetContainerWorkerHostname(containerId) + hostname, err := b.containerRepo.GetWorkerAddress(containerId) if err != nil { return err } + log.Println("Retrieved worker address: ", hostname) conn, err := network.ConnectToHost(ctx, hostname, time.Second*30, b.tailscale, b.config.Tailscale) if err != nil { return err diff --git a/internal/common/keys.go b/internal/common/keys.go index 11d530b78..71ea07416 100644 --- a/internal/common/keys.go +++ b/internal/common/keys.go @@ -5,17 +5,17 @@ import ( ) var ( - schedulerPrefix string = "scheduler:" - schedulerContainerRequests string = "scheduler:container_requests" - schedulerWorkerLock string = "scheduler:worker:lock:%s" - schedulerWorkerRequests string = "scheduler:worker:requests:%s" - schedulerWorkerState string = "scheduler:worker:state:%s" - schedulerContainerConfig string = "scheduler:container:config:%s" - schedulerContainerState string = "scheduler:container:state:%s" - schedulerContainerHost string = "scheduler:container:host:%s" - schedulerWorkerContainerHost string = "scheduler:container:worker_host:%s" - schedulerContainerLock string = "scheduler:container:lock:%s" - schedulerContainerExitCode string = "scheduler:container:exit_code:%s" + schedulerPrefix string = "scheduler:" + schedulerContainerRequests string = "scheduler:container_requests" + schedulerWorkerLock string = "scheduler:worker:lock:%s" + schedulerWorkerRequests string = "scheduler:worker:requests:%s" + schedulerWorkerState string = "scheduler:worker:state:%s" + schedulerContainerConfig string = "scheduler:container:config:%s" + schedulerContainerState string = "scheduler:container:state:%s" + schedulerContainerAddress string = "scheduler:container:container_addr:%s" + schedulerWorkerAddress string = "scheduler:container:worker_addr:%s" + schedulerContainerLock string = "scheduler:container:lock:%s" + schedulerContainerExitCode string = "scheduler:container:exit_code:%s" ) var ( @@ -96,12 +96,12 @@ func (rk *redisKeys) SchedulerContainerConfig(containerId string) string { return fmt.Sprintf(schedulerContainerConfig, containerId) } -func (rk *redisKeys) SchedulerContainerHost(containerId string) string { - return fmt.Sprintf(schedulerContainerHost, containerId) +func (rk *redisKeys) SchedulerContainerAddress(containerId string) string { + return fmt.Sprintf(schedulerContainerAddress, containerId) } -func (rk *redisKeys) SchedulerWorkerContainerHost(containerId string) string { - return fmt.Sprintf(schedulerWorkerContainerHost, containerId) +func (rk *redisKeys) SchedulerWorkerAddress(containerId string) string { + return fmt.Sprintf(schedulerWorkerAddress, containerId) } func (rk *redisKeys) SchedulerContainerExitCode(containerId string) string { diff --git a/internal/repository/base.go b/internal/repository/base.go index 98768d901..9ec899175 100644 --- a/internal/repository/base.go +++ b/internal/repository/base.go @@ -37,8 +37,8 @@ type ContainerRepository interface { GetContainerAddress(containerId string) (string, error) UpdateContainerStatus(string, types.ContainerStatus, time.Duration) error DeleteContainerState(*types.ContainerRequest) error - SetContainerWorkerHostname(containerId string, addr string) error - GetContainerWorkerHostname(containerId string) (string, error) + SetWorkerAddress(containerId string, addr string) error + GetWorkerAddress(containerId string) (string, error) GetActiveContainersByPrefix(patternPrefix string) ([]types.ContainerState, error) GetFailedContainerCountByPrefix(patternPrefix string) (int, error) } diff --git a/internal/repository/container_redis.go b/internal/repository/container_redis.go index 0e69e1614..8eef3e0cc 100644 --- a/internal/repository/container_redis.go +++ b/internal/repository/container_redis.go @@ -149,28 +149,28 @@ func (cr *ContainerRedisRepository) DeleteContainerState(request *types.Containe return fmt.Errorf("failed to delete container state <%v>: %w", stateKey, err) } - hostKey := common.RedisKeys.SchedulerContainerHost(containerId) - err = cr.rdb.Del(context.TODO(), hostKey).Err() + addrKey := common.RedisKeys.SchedulerContainerAddress(containerId) + err = cr.rdb.Del(context.TODO(), addrKey).Err() if err != nil { - return fmt.Errorf("failed to delete container host <%v>: %w", hostKey, err) + return fmt.Errorf("failed to delete container addr <%v>: %w", addrKey, err) } return nil } func (cr *ContainerRedisRepository) SetContainerAddress(containerId string, addr string) error { - return cr.rdb.Set(context.TODO(), common.RedisKeys.SchedulerContainerHost(containerId), addr, 0).Err() + return cr.rdb.Set(context.TODO(), common.RedisKeys.SchedulerContainerAddress(containerId), addr, 0).Err() } func (cr *ContainerRedisRepository) GetContainerAddress(containerId string) (string, error) { - return cr.rdb.Get(context.TODO(), common.RedisKeys.SchedulerContainerHost(containerId)).Result() + return cr.rdb.Get(context.TODO(), common.RedisKeys.SchedulerContainerAddress(containerId)).Result() } -func (cr *ContainerRedisRepository) SetContainerWorkerHostname(containerId string, addr string) error { - return cr.rdb.Set(context.TODO(), common.RedisKeys.SchedulerWorkerContainerHost(containerId), addr, 0).Err() +func (cr *ContainerRedisRepository) SetWorkerAddress(containerId string, addr string) error { + return cr.rdb.Set(context.TODO(), common.RedisKeys.SchedulerWorkerAddress(containerId), addr, 0).Err() } -func (cr *ContainerRedisRepository) GetContainerWorkerHostname(containerId string) (string, error) { +func (cr *ContainerRedisRepository) GetWorkerAddress(containerId string) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() @@ -183,9 +183,9 @@ func (cr *ContainerRedisRepository) GetContainerWorkerHostname(containerId strin for { select { case <-ctx.Done(): - return "", errors.New("timeout reached while trying to get worker hostname") + return "", errors.New("timeout reached while trying to get worker addr") case <-ticker.C: - hostname, err = cr.rdb.Get(ctx, common.RedisKeys.SchedulerWorkerContainerHost(containerId)).Result() + hostname, err = cr.rdb.Get(ctx, common.RedisKeys.SchedulerWorkerAddress(containerId)).Result() if err == nil { return hostname, nil } diff --git a/internal/worker/network.go b/internal/worker/network.go index c21cb8883..a65c775b6 100644 --- a/internal/worker/network.go +++ b/internal/worker/network.go @@ -70,5 +70,10 @@ func getIPFromEnv(varName string) (string, error) { return "", errors.New("failed to parse ip address") } + // If the parsed IP is an IPv6 address, encapsulate in brackets + if ip.To4() == nil { + return fmt.Sprintf("[%s]", ip.String()), nil + } + return ip.String(), nil } diff --git a/internal/worker/network_test.go b/internal/worker/network_test.go new file mode 100644 index 000000000..79d928581 --- /dev/null +++ b/internal/worker/network_test.go @@ -0,0 +1,63 @@ +package worker + +import ( + "os" + "testing" +) + +func TestGetIPFromEnv(t *testing.T) { + tests := []struct { + name string + envName string + envValue string + want string + expectErr bool + }{ + { + name: "No IP set", + envName: "EMPTY_IP", + envValue: "", + want: "", + expectErr: true, + }, + { + name: "Invalid IP", + envName: "INVALID_IP", + envValue: "invalid", + want: "", + expectErr: true, + }, + { + name: "Valid IPv4", + envName: "VALID_IPV4", + envValue: "192.168.1.1", + want: "192.168.1.1", + expectErr: false, + }, + { + name: "Valid IPv6", + envName: "VALID_IPV6", + envValue: "2001:db8::1", + want: "[2001:db8::1]", + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Setenv(tt.envName, tt.envValue) + defer os.Unsetenv(tt.envName) + + got, err := getIPFromEnv(tt.envName) + + if (err != nil) != tt.expectErr { + t.Errorf("getIPFromEnv() error = %v, expectErr %v", err, tt.expectErr) + return + } + + if got != tt.want { + t.Errorf("getIPFromEnv() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/worker/runc_server.go b/internal/worker/runc_server.go index 493988710..f8a3f92b6 100644 --- a/internal/worker/runc_server.go +++ b/internal/worker/runc_server.go @@ -24,7 +24,7 @@ import ( const ( defaultWorkingDirectory string = "/mnt/code" - defaultWorkerServerPort int = 1000 + defaultWorkerServerPort int = 1989 ) type RunCServer struct { diff --git a/internal/worker/worker.go b/internal/worker/worker.go index a482791f5..38ea12c7b 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -255,7 +255,7 @@ func (s *Worker) RunContainer(request *types.ContainerRequest) error { bundlePath := filepath.Join(s.userImagePath, request.ImageId) hostname := fmt.Sprintf("%s:%d", s.podAddr, defaultWorkerServerPort) - err := s.containerRepo.SetContainerWorkerHostname(request.ContainerId, hostname) + err := s.containerRepo.SetWorkerAddress(request.ContainerId, hostname) if err != nil { return err }