Skip to content

Commit

Permalink
Authenticate guard-service requests (#77)
Browse files Browse the repository at this point in the history
* get sid and ns from jwt

* get sid and ns from jwt

* get sid and ns from jwt

* get sid and ns from jwt

* remvoe race on log initialization during tests

* support service auth config param

* support service auth config param

* merge

* merge

* log cleanups

* log cleanups

* use ServiceAudience const
  • Loading branch information
davidhadas authored Oct 21, 2022
1 parent 4459d49 commit 7be60b6
Show file tree
Hide file tree
Showing 43 changed files with 2,816 additions and 147 deletions.
22 changes: 10 additions & 12 deletions cmd/guard-rproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"runtime/debug"

"github.com/kelseyhightower/envconfig"
"go.uber.org/zap"

"knative.dev/pkg/signals"
_ "knative.dev/security-guard/pkg/guard-gate"
Expand Down Expand Up @@ -84,7 +83,7 @@ func (p *GuardGate) Transport(t http.RoundTripper) http.RoundTripper {
return p
}

func preMain() (guardGate *GuardGate, mux *http.ServeMux, target string, plugConfig map[string]string, sid string, ns string, log *zap.SugaredLogger) {
func preMain() (guardGate *GuardGate, mux *http.ServeMux, target string, plugConfig map[string]string, sid string, ns string) {
var env config
if err := envconfig.Process("", &env); err != nil {
fmt.Fprintf(os.Stderr, "Failed to process environment: %s\n", err.Error())
Expand All @@ -94,9 +93,7 @@ func preMain() (guardGate *GuardGate, mux *http.ServeMux, target string, plugCon
plugConfig = make(map[string]string)
guardGate = new(GuardGate)

log = utils.CreateLogger(env.LogLevel)
defer log.Sync()
pi.Log = log
utils.CreateLogger(env.LogLevel)

if env.GuardUrl == "" {
// use default
Expand Down Expand Up @@ -124,13 +121,13 @@ func preMain() (guardGate *GuardGate, mux *http.ServeMux, target string, plugCon
sid = env.ServiceName
ns = env.Namespace

log.Infof("guard-proxy serving serviceName: %s, namespace: %s, serviceUrl: %s", sid, ns, env.ServiceUrl)
pi.Log.Infof("guard-proxy serving serviceName: %s, namespace: %s, serviceUrl: %s", sid, ns, env.ServiceUrl)
parsedUrl, err := url.Parse(env.ServiceUrl)
if err != nil {
log.Errorf("Failed to parse serviceUrl: %s", err.Error())
pi.Log.Errorf("Failed to parse serviceUrl: %s", err.Error())
return
}
log.Infof("guard-proxy parsedUrl: %v", parsedUrl)
pi.Log.Infof("guard-proxy parsedUrl: %v", parsedUrl)

proxy := httputil.NewSingleHostReverseProxy(parsedUrl)

Expand All @@ -148,19 +145,20 @@ func preMain() (guardGate *GuardGate, mux *http.ServeMux, target string, plugCon

mux = http.NewServeMux()
mux.Handle("/", proxy)
log.Infof("Starting Reverse Proxy on port %s", target)
pi.Log.Infof("Starting Reverse Proxy on port %s", target)
return
}

func main() {
guardGate, mux, target, plugConfig, sid, ns, log := preMain()
guardGate, mux, target, plugConfig, sid, ns := preMain()
if mux == nil {
os.Exit(1)
}
defer utils.SyncLogger()

guardGate.securityPlug.Init(signals.NewContext(), plugConfig, sid, ns, log)
guardGate.securityPlug.Init(signals.NewContext(), plugConfig, sid, ns, pi.Log)
defer guardGate.securityPlug.Shutdown()

err := http.ListenAndServe(target, mux)
log.Fatalf("Failed to open http local service: %s", err.Error())
pi.Log.Errorf("Failed to open http local service: %s", err.Error())
}
2 changes: 1 addition & 1 deletion cmd/guard-rproxy/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func Test_preMain(t *testing.T) {
os.Setenv(k, v)
}
//guardGate, mux, target, plugConfig, sid, ns, log := preMain()
_, mux, target, _, _, _, _ := preMain()
_, mux, target, _, _, _ := preMain()
if (mux != nil) != tt.mux {
t.Errorf("preMain() mux expected %t, received %t", tt.mux, mux != nil)
}
Expand Down
125 changes: 89 additions & 36 deletions cmd/guard-service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,117 +25,171 @@ import (
"strings"
"time"

"go.uber.org/zap"
"github.com/kelseyhightower/envconfig"
spec "knative.dev/security-guard/pkg/apis/guard/v1alpha1"
utils "knative.dev/security-guard/pkg/guard-utils"

"github.com/kelseyhightower/envconfig"
pi "knative.dev/security-guard/pkg/pluginterfaces"
)

var log *zap.SugaredLogger

const (
serviceIntervalDefault = 5 * time.Minute
)

type config struct {
GuardServiceLogLevel string `split_words:"true" required:"false"`
GuardServiceInterval string `split_words:"true" required:"false"`
GuardServiceAuth bool `split_words:"true" required:"false"`
}

type learner struct {
services *services
pileLearnTicker *utils.Ticker
}

var env config

func (l *learner) authenticate(req *http.Request) (sid string, ns string, err error) {
token := req.Header.Get("Authorization")
if !strings.HasPrefix(token, "Bearer ") {
err = fmt.Errorf("missing token")
return
}
token = token[7:]
sid, ns, err = l.services.kmgr.TokenData(token)
if err != nil {
err = fmt.Errorf("cant verify token %w", err)
return
}
if sid == "ns-"+ns {
err = fmt.Errorf("token of a service with illegal name %s", sid)
return
}
return
}

// Common method used for parsing ns, sid, cmFlag from all requests
func (l *learner) baseHandler(query url.Values) (record *serviceRecord, err error) {
func (l *learner) queryData(query url.Values) (cmFlag bool, sid string, ns string, err error) {
cmFlagSlice := query["cm"]
sidSlice := query["sid"]
nsSlice := query["ns"]

if len(sidSlice) != 1 || len(nsSlice) != 1 || len(cmFlagSlice) > 1 {
err = fmt.Errorf("wrong data sid %d ns %d cmflag %d", len(sidSlice), len(nsSlice), len(cmFlagSlice))
err = fmt.Errorf("query has wrong cmflag/sid/ns length")
return
}

// extract and sanitize sid and ns
sid := utils.Sanitize(sidSlice[0])
ns := utils.Sanitize(nsSlice[0])
sid = utils.Sanitize(sidSlice[0])
ns = utils.Sanitize(nsSlice[0])

if strings.HasPrefix(sid, "ns-") {
sid = ""
err = fmt.Errorf("illegal sid %s", sid)
if sid == "ns-"+ns {
err = fmt.Errorf("query sid of a service with illegal name that starts with ns-")
return
}

if len(sid) < 1 {
err = fmt.Errorf("wrong sid %s", sidSlice[0])
err = fmt.Errorf("query missing sid")
return
}

if len(ns) < 1 {
err = fmt.Errorf("wrong ns %s", nsSlice[0])
err = fmt.Errorf("query missing ns")
return
}

// extract and sanitize cmFlag
var cmFlag bool
if len(cmFlagSlice) > 0 {
cmFlag = (cmFlagSlice[0] == "true")
}

return
}

func (l *learner) baseHandler(w http.ResponseWriter, req *http.Request) (record *serviceRecord, err error) {
var sid, ns, querySid, queryNs string
var cmFlag bool

cmFlag, querySid, queryNs, err = l.queryData(req.URL.Query())
if err != nil {
pi.Log.Infof("baseHandler queryData failed with %v", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
pi.Log.Debugf("queryData ns %s, sid %s cmFlag %t", queryNs, querySid, cmFlag)

if env.GuardServiceAuth {
sid, ns, err = l.authenticate(req)
if err != nil {
pi.Log.Infof("baseHandler authenticate failed with %v", err)
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
pi.Log.Debugf("Authorized ns %s, sid %s", ns, sid)
} else {
sid = querySid
ns = queryNs
pi.Log.Debugf("Authorization skipped ns %s, sid %s", ns, sid)
}

// get session record, create one if does not exist
log.Debugf("** baseHandler ** ns %s, sid %s, cmFlag %t", ns, sid, cmFlag)
record = l.services.get(ns, sid, cmFlag)
if record == nil {
// should never happen
err = fmt.Errorf("internal error no record created")
err = fmt.Errorf("no record created")
pi.Log.Infof("internal error %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
}
pi.Log.Debugf("record found for ns %s, sid %s", ns, sid)
return
}

func (l *learner) fetchConfig(w http.ResponseWriter, req *http.Request) {
if req.Method != "GET" || req.URL.Path != "/config" {
http.Error(w, "404 not found.", http.StatusNotFound)
}

record, err := l.baseHandler(req.URL.Query())
record, err := l.baseHandler(w, req)
if err != nil {
log.Infof("fetchConfig Missing data %v", err)
http.Error(w, "Missing data", http.StatusBadRequest)
return
}

if req.Method != "GET" || req.URL.Path != "/config" {
http.Error(w, "404 not found.", http.StatusNotFound)
}

buf, err := json.Marshal(record.guardianSpec)
if err != nil {
// should never happen
log.Infof("Servicing fetchConfig error while JSON Marshal %v", err)
pi.Log.Infof("Servicing fetchConfig error while JSON Marshal %v", err)
http.Error(w, "Failed to marshal data", http.StatusInternalServerError)
return
}
pi.Log.Debugf("Servicing fetchConfig success")
w.Write(buf)
}

func (l *learner) processPile(w http.ResponseWriter, req *http.Request) {
var pile spec.SessionDataPile
var err error
record, err := l.baseHandler(req.URL.Query())
record, err := l.baseHandler(w, req)
if err != nil {
log.Infof("fetchConfig Missing data %v", err)
http.Error(w, "processPile Missing data", http.StatusBadRequest)
return
}
if req.Method != "POST" || req.URL.Path != "/pile" {
http.Error(w, "404 not found.", http.StatusNotFound)
return
}

if req.ContentLength == 0 || req.Body == nil {
http.Error(w, "400 not found.", http.StatusBadRequest)
return
}

var pile spec.SessionDataPile
err = json.NewDecoder(req.Body).Decode(&pile)
if err != nil {
log.Infof("processPile error: %s", err.Error())
pi.Log.Infof("processPile error: %v", err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
l.services.merge(record, &pile)

log.Debugf("Successful setting record.wsgate")
pi.Log.Debugf("Successful merging pile")

w.Write([]byte{})
}
Expand All @@ -146,20 +200,19 @@ func (l *learner) mainEventLoop(quit chan string) {
case <-l.pileLearnTicker.Ch():
l.services.tick()
case reason := <-quit:
log.Infof("mainEventLoop was asked to quit! - Reason: %s", reason)
pi.Log.Infof("mainEventLoop was asked to quit! - Reason: %s", reason)
return
}
}
}

// Set network policies to ensure that only pods in your trust domain can use the service!
func preMain(minimumInterval time.Duration) (*learner, *http.ServeMux, string, chan string) {
var env config
if err := envconfig.Process("", &env); err != nil {
fmt.Fprintf(os.Stderr, "Failed to process environment: %s\n", err.Error())
os.Exit(1)
}
log = utils.CreateLogger(env.GuardServiceLogLevel)
utils.CreateLogger(env.GuardServiceLogLevel)

l := new(learner)
l.pileLearnTicker = utils.NewTicker(minimumInterval)
Expand All @@ -176,7 +229,7 @@ func preMain(minimumInterval time.Duration) (*learner, *http.ServeMux, string, c

quit := make(chan string)

log.Infof("Starting guard-service on %s", target)
pi.Log.Infof("Starting guard-service on %s", target)
return l, mux, target, quit
}

Expand All @@ -189,6 +242,6 @@ func main() {
go l.mainEventLoop(quit)

err := http.ListenAndServe(target, mux)
log.Infof("Using target: %s - Failed to start %v", target, err)
pi.Log.Infof("Using target: %s - Failed to start %v", target, err)
quit <- "ListenAndServe failed"
}
Loading

0 comments on commit 7be60b6

Please sign in to comment.