Skip to content

Commit

Permalink
test: use go api in integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasten committed Jul 16, 2024
1 parent 193065f commit 3ea7b8c
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 237 deletions.
210 changes: 40 additions & 170 deletions test/framework/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ package framework

import (
"bufio"
"bytes"
"context"
"crypto/rsa"
"crypto/tls"
Expand All @@ -19,22 +18,20 @@ import (
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"syscall"
"testing"
"time"

"github.com/edgelesssys/marblerun/api"
"github.com/edgelesssys/marblerun/coordinator/constants"
"github.com/edgelesssys/marblerun/coordinator/manifest"
"github.com/edgelesssys/marblerun/coordinator/store/stdstore"
mconfig "github.com/edgelesssys/marblerun/marble/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)

// IntegrationTest is a testing framework for MarbleRun tests.
Expand All @@ -43,16 +40,15 @@ type IntegrationTest struct {
assert *assert.Assertions
require *require.Assertions

Ctx context.Context
TestManifest manifest.Manifest
UpdatedManifest manifest.Manifest
BuildDir string
SimulationFlag string
NoEnclave bool
MeshServerAddr string
ClientServerAddr string
MarbleTestAddr string
transportSkipVerify http.RoundTripper
Ctx context.Context
TestManifest manifest.Manifest
UpdatedManifest manifest.Manifest
BuildDir string
SimulationFlag string
NoEnclave bool
MeshServerAddr string
ClientServerAddr string
MarbleTestAddr string
}

// New creates a new IntegrationTest.
Expand All @@ -68,14 +64,13 @@ func New(t *testing.T, buildDir, simulation string, noenclave bool,
assert: assert.New(t),
require: require.New(t),

Ctx: ctx,
BuildDir: buildDir,
SimulationFlag: simulation,
NoEnclave: noenclave,
MeshServerAddr: meshServerAddr,
ClientServerAddr: clientServerAddr,
MarbleTestAddr: marbleTestAddr,
transportSkipVerify: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}},
Ctx: ctx,
BuildDir: buildDir,
SimulationFlag: simulation,
NoEnclave: noenclave,
MeshServerAddr: meshServerAddr,
ClientServerAddr: clientServerAddr,
MarbleTestAddr: marbleTestAddr,
}

i.require.NoError(json.Unmarshal([]byte(testManifest), &i.TestManifest))
Expand Down Expand Up @@ -154,9 +149,6 @@ func (i IntegrationTest) StartCoordinator(ctx context.Context, cfg CoordinatorCo
cmd.Env = append(cmd.Env, cfg.extraEnv...)
cmdErr := i.StartCommand("coor", cmd)

client := http.Client{Transport: i.transportSkipVerify}
url := url.URL{Scheme: "https", Host: i.ClientServerAddr, Path: "status"}

i.t.Log("Coordinator starting...")
for {
time.Sleep(10 * time.Millisecond)
Expand All @@ -167,14 +159,8 @@ func (i IntegrationTest) StartCoordinator(ctx context.Context, cfg CoordinatorCo
default:
}

req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url.String(), http.NoBody)
i.require.NoError(err)

resp, err := client.Do(req)
if err == nil {
if _, _, err := api.GetStatus(context.Background(), i.ClientServerAddr, nil); err == nil {
i.t.Log("Coordinator started")
resp.Body.Close()
i.require.Equal(http.StatusOK, resp.StatusCode)
return func() {
_ = cmd.Cancel()
<-cmdErr
Expand Down Expand Up @@ -214,131 +200,37 @@ func (i IntegrationTest) StartCommand(friendlyName string, cmd *exec.Cmd) chan e
}

// SetManifest sets the manifest of the Coordinator.
func (i IntegrationTest) SetManifest(manifest manifest.Manifest) ([]byte, error) {
// Use ClientAPI to set Manifest
client := http.Client{Transport: i.transportSkipVerify}
clientAPIURL := url.URL{
Scheme: "https",
Host: i.ClientServerAddr,
Path: "manifest",
}

func (i IntegrationTest) SetManifest(manifest manifest.Manifest) (map[string][]byte, error) {
manifestRaw, err := json.Marshal(manifest)
i.require.NoError(err)

req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, clientAPIURL.String(), bytes.NewReader(manifestRaw))
i.require.NoError(err)
req.Header.Set("Content-Type", "application/json")

resp, err := client.Do(req)
i.require.NoError(err)
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
i.require.NoError(err)

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("expected %v, but /manifest returned %v: %v", http.StatusOK, resp.Status, string(body))
}

return body, nil
return api.ManifestSet(context.Background(), i.ClientServerAddr, nil, manifestRaw)
}

// SetUpdateManifest sets performs a manifest update for the Coordinator.
func (i IntegrationTest) SetUpdateManifest(manifest manifest.Manifest, certPEM []byte, key *rsa.PrivateKey) ([]byte, error) {
func (i IntegrationTest) SetUpdateManifest(manifest manifest.Manifest, certPEM []byte, key *rsa.PrivateKey) error {
// Setup requied client certificate for authentication
privk, err := x509.MarshalPKCS8PrivateKey(key)
i.require.NoError(err)

cert, err := tls.X509KeyPair(certPEM, pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privk}))
i.require.NoError(err)

tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: true,
}

transport := &http.Transport{TLSClientConfig: tlsConfig}

// Use ClientAPI to set Manifest
client := http.Client{Transport: transport}
clientAPIURL := url.URL{
Scheme: "https",
Host: i.ClientServerAddr,
Path: "update",
}

manifestRaw, err := json.Marshal(manifest)
i.require.NoError(err)

req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, clientAPIURL.String(), bytes.NewReader(manifestRaw))
i.require.NoError(err)
req.Header.Set("Content-Type", "application/json")

resp, err := client.Do(req)
i.require.NoError(err)
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
i.require.NoError(err)

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("expected %v, but /manifest returned %v: %v", http.StatusOK, resp.Status, string(body))
}

return body, nil
return api.ManifestUpdateApply(context.Background(), i.ClientServerAddr, nil, manifestRaw, &cert)
}

// SetRecover sets the recovery key of the Coordinator.
func (i IntegrationTest) SetRecover(recoveryKey []byte) error {
client := http.Client{Transport: i.transportSkipVerify}
clientAPIURL := url.URL{
Scheme: "https",
Host: i.ClientServerAddr,
Path: "recover",
}

req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, clientAPIURL.String(), bytes.NewReader(recoveryKey))
i.require.NoError(err)
req.Header.Set("Content-Type", "application/octet-stream")

resp, err := client.Do(req)
i.require.NoError(err)
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
i.require.NoError(err)

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("expected %v, but /recover returned %v: %v", http.StatusOK, resp.Status, string(body))
}

return nil
_, _, err := api.Recover(context.Background(), i.ClientServerAddr, api.VerifyOptions{InsecureSkipVerify: true}, recoveryKey)
return err
}

// GetStatus returns the status of the Coordinator.
func (i IntegrationTest) GetStatus() (string, error) {
client := http.Client{Transport: i.transportSkipVerify}
clientAPIURL := url.URL{
Scheme: "https",
Host: i.ClientServerAddr,
Path: "status",
}

req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, clientAPIURL.String(), http.NoBody)
i.require.NoError(err)
resp, err := client.Do(req)
i.require.NoError(err)
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
i.require.NoError(err)

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("expected %v, but /status returned %v: %v", http.StatusOK, resp.Status, string(body))
}

return string(body), nil
func (i IntegrationTest) GetStatus() (int, error) {
code, _, err := api.GetStatus(context.Background(), i.ClientServerAddr, nil)
return code, err
}

// MarbleConfig contains the configuration for a Marble.
Expand Down Expand Up @@ -438,21 +330,11 @@ func (i IntegrationTest) StartMarbleClient(ctx context.Context, cfg MarbleConfig
}

// TriggerRecovery triggers a recovery.
func (i IntegrationTest) TriggerRecovery(coordinatorCfg CoordinatorConfig, cancelCoordinator func()) (func(), string) {
func (i IntegrationTest) TriggerRecovery(coordinatorCfg CoordinatorConfig, cancelCoordinator func()) (func(), *x509.Certificate) {
// get certificate
i.t.Log("Save certificate before we try to recover.")
client := http.Client{Transport: i.transportSkipVerify}
clientAPIURL := url.URL{Scheme: "https", Host: i.ClientServerAddr, Path: "quote"}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, clientAPIURL.String(), http.NoBody)
i.require.NoError(err)
resp, err := client.Do(req)
cert, _, _, err := api.VerifyCoordinator(context.Background(), i.ClientServerAddr, api.VerifyOptions{InsecureSkipVerify: true})
i.require.NoError(err)
i.require.Equal(http.StatusOK, resp.StatusCode)
quote, err := io.ReadAll(resp.Body)
resp.Body.Close()
i.require.NoError(err)
cert := gjson.GetBytes(quote, "data.Cert").String()
i.require.NotEmpty(cert)

// simulate restart of coordinator
i.t.Log("Simulating a restart of the coordinator enclave...")
Expand All @@ -469,27 +351,19 @@ func (i IntegrationTest) TriggerRecovery(coordinatorCfg CoordinatorConfig, cance

// Query status API, check if status response begins with Code 1 (recovery state)
i.t.Log("Checking status...")
statusResponse, err := i.GetStatus()
statusCode, err := i.GetStatus()
i.require.NoError(err)
i.assert.EqualValues(1, gjson.Get(statusResponse, "data.StatusCode").Int(), "Server is not in recovery state, but should be.")
i.assert.EqualValues(1, statusCode, "Server is not in recovery state, but should be.")

return cancelCoordinator, cert
}

// VerifyCertAfterRecovery verifies the certificate after a recovery.
func (i IntegrationTest) VerifyCertAfterRecovery(cert string, cancelCoordinator func(), cfg CoordinatorConfig) func() {
func (i IntegrationTest) VerifyCertAfterRecovery(cert *x509.Certificate, cancelCoordinator func(), cfg CoordinatorConfig) func() {
// Test with certificate
i.t.Log("Verifying certificate after recovery, without a restart.")
pool := x509.NewCertPool()
i.require.True(pool.AppendCertsFromPEM([]byte(cert)))
client := http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{RootCAs: pool}}}
clientAPIURL := url.URL{Scheme: "https", Host: i.ClientServerAddr, Path: "status"}
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, clientAPIURL.String(), http.NoBody)
i.require.NoError(err)
resp, err := client.Do(req)
_, _, err := api.GetStatus(context.Background(), i.ClientServerAddr, cert)
i.require.NoError(err)
resp.Body.Close()
i.require.Equal(http.StatusOK, resp.StatusCode)

// Simulate restart of coordinator
i.t.Log("Simulating a restart of the coordinator enclave...")
Expand All @@ -502,18 +376,14 @@ func (i IntegrationTest) VerifyCertAfterRecovery(cert string, cancelCoordinator

// Finally, check if we survive a restart.
i.t.Log("Restarted instance, now let's see if the state can be restored again successfully.")
statusResponse, err := i.GetStatus()
statusCode, err := i.GetStatus()
i.require.NoError(err)
i.assert.EqualValues(3, gjson.Get(statusResponse, "data.StatusCode").Int(), "Server is in wrong status after recovery.")
i.assert.EqualValues(3, statusCode, "Server is in wrong status after recovery.")

// test with certificate
i.t.Log("Verifying certificate after restart.")
req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, clientAPIURL.String(), http.NoBody)
i.require.NoError(err)
resp, err = client.Do(req)
_, _, err = api.GetStatus(context.Background(), i.ClientServerAddr, cert)
i.require.NoError(err)
resp.Body.Close()
i.require.Equal(http.StatusOK, resp.StatusCode)

return cancelCoordinator
}
Expand All @@ -522,9 +392,9 @@ func (i IntegrationTest) VerifyCertAfterRecovery(cert string, cancelCoordinator
func (i IntegrationTest) VerifyResetAfterRecovery(cancelCoordinator func(), cfg CoordinatorConfig) func() {
// Check status after setting a new manifest, we should be able
i.t.Log("Check if the manifest was accepted and we are ready to accept Marbles")
statusResponse, err := i.GetStatus()
statusCode, err := i.GetStatus()
i.require.NoError(err)
i.assert.EqualValues(3, gjson.Get(statusResponse, "data.StatusCode").Int(), "Server is in wrong status after recovery.")
i.assert.EqualValues(3, statusCode, "Server is in wrong status after recovery.")

// simulate restart of coordinator
i.t.Log("Simulating a restart of the coordinator enclave...")
Expand All @@ -537,9 +407,9 @@ func (i IntegrationTest) VerifyResetAfterRecovery(cancelCoordinator func(), cfg

// Finally, check if we survive a restart.
i.t.Log("Restarted instance, now let's see if the new state can be decrypted successfully...")
statusResponse, err = i.GetStatus()
statusCode, err = i.GetStatus()
i.require.NoError(err)
i.assert.EqualValues(3, gjson.Get(statusResponse, "data.StatusCode").Int(), "Server is in wrong status after recovery.")
i.assert.EqualValues(3, statusCode, "Server is in wrong status after recovery.")

return cancelCoordinator
}
Expand Down
Loading

0 comments on commit 3ea7b8c

Please sign in to comment.