Skip to content

Commit

Permalink
gzipping each session inside cookie
Browse files Browse the repository at this point in the history
  • Loading branch information
Michal Pristas authored and markbates committed Jan 23, 2018
1 parent 953d7c3 commit 12866fa
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 54 deletions.
48 changes: 45 additions & 3 deletions gothic/gothic.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ See https://github.com/markbates/goth/examples/main.go to see this in action.
package gothic

import (
"bytes"
"compress/gzip"
"encoding/base64"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"net/url"
"os"
"strings"
"time"

"github.com/gorilla/mux"
Expand Down Expand Up @@ -283,17 +287,55 @@ func getProviderName(req *http.Request) (string, error) {
func storeInSession(key string, value string, req *http.Request, res http.ResponseWriter) error {
session, _ := Store.Get(req, SessionName)

session.Values[key] = value
if err := updateSessionValue(session, key, value); err != nil {
return err
}

return session.Save(req, res)
}

func getFromSession(key string, req *http.Request) (string, error) {
session, _ := Store.Get(req, SessionName)
value, err := getSessionValue(session, key)
if err != nil {
return "", errors.New("could not find a matching session for this request")
}

return value, nil
}

func getSessionValue(session *sessions.Session, key string) (string, error) {
value := session.Values[key]
if value == nil {
return "", errors.New("could not find a matching session for this request")
return "", fmt.Errorf("could not find a matching session for this request")
}

rdata := strings.NewReader(value.(string))
r, err := gzip.NewReader(rdata)
if err != nil {
return "", err
}
s, err := ioutil.ReadAll(r)
if err != nil {
return "", err
}

return value.(string), nil
return string(s), nil
}

func updateSessionValue(session *sessions.Session, key, value string) error {
var b bytes.Buffer
gz := gzip.NewWriter(&b)
if _, err := gz.Write([]byte(value)); err != nil {
return err
}
if err := gz.Flush(); err != nil {
return err
}
if err := gz.Close(); err != nil {
return err
}

session.Values[key] = b.String()
return nil
}
41 changes: 38 additions & 3 deletions gothic/gothic_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package gothic_test

import (
"bytes"
"compress/gzip"
"fmt"
"html"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/gorilla/sessions"
Expand Down Expand Up @@ -68,12 +72,13 @@ func Test_BeginAuthHandler(t *testing.T) {
if err != nil {
t.Fatalf("error getting faux Gothic session: %v", err)
}

sessStr, ok := sess.Values["faux"].(string)
if !ok {
t.Fatalf("Gothic session not stored as marshalled string; was %T (value %v)",
sess.Values["faux"], sess.Values["faux"])
}
gothSession, err := fauxProvider.UnmarshalSession(sessStr)
gothSession, err := fauxProvider.UnmarshalSession(ungzipString(sessStr))
if err != nil {
t.Fatalf("error unmarshalling faux Gothic session: %v", err)
}
Expand Down Expand Up @@ -124,7 +129,7 @@ func Test_CompleteUserAuth(t *testing.T) {

sess := faux.Session{Name: "Homer Simpson", Email: "[email protected]"}
session, _ := Store.Get(req, "faux"+SessionName)
session.Values["faux"] = sess.Marshal()
session.Values["faux"] = gzipString(sess.Marshal())
err = session.Save(req, res)
a.NoError(err)

Expand All @@ -144,7 +149,7 @@ func Test_Logout(t *testing.T) {

sess := faux.Session{Name: "Homer Simpson", Email: "[email protected]"}
session, _ := Store.Get(req, "faux"+SessionName)
session.Values["faux"] = sess.Marshal()
session.Values["faux"] = gzipString(sess.Marshal())
err = session.Save(req, res)
a.NoError(err)

Expand Down Expand Up @@ -197,3 +202,33 @@ func Test_StateValidation(t *testing.T) {
_, err = CompleteUserAuth(res, req)
a.Error(err)
}

func gzipString(value string) string {
var b bytes.Buffer
gz := gzip.NewWriter(&b)
if _, err := gz.Write([]byte(value)); err != nil {
return "err"
}
if err := gz.Flush(); err != nil {
return "err"
}
if err := gz.Close(); err != nil {
return "err"
}

return b.String()
}

func ungzipString(value string) string {
rdata := strings.NewReader(value)
r, err := gzip.NewReader(rdata)
if err != nil {
return "err"
}
s, err := ioutil.ReadAll(r)
if err != nil {
return "err"
}

return string(s)
}
2 changes: 1 addition & 1 deletion providers/microsoftonline/microsoftonline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func Test_SessionFromJSON(t *testing.T) {
a := assert.New(t)

provider := microsoftonlineProvider()
session, err := provider.UnmarshalSession(string(compressedSession()))
session, err := provider.UnmarshalSession(`{"AuthURL":"https://login.microsoftonline.com/common/oauth2/v2.0/authorize","AccessToken":"1234567890","ExpiresAt":"0001-01-01T00:00:00Z"}`)
a.NoError(err)

s := session.(*microsoftonline.Session)
Expand Down
32 changes: 3 additions & 29 deletions providers/microsoftonline/session.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package microsoftonline

import (
"bytes"
"compress/gzip"
"encoding/json"
"errors"
"io/ioutil"
"strings"
"time"

Expand Down Expand Up @@ -49,20 +46,8 @@ func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string,

// Marshal the session into a string
func (s Session) Marshal() string {
data, _ := json.Marshal(s)

var b bytes.Buffer
gz := gzip.NewWriter(&b)
if _, err := gz.Write([]byte(data)); err != nil {
panic(err)
}
if err := gz.Flush(); err != nil {
panic(err)
}
if err := gz.Close(); err != nil {
panic(err)
}
return b.String()
b, _ := json.Marshal(s)
return string(b)
}

func (s Session) String() string {
Expand All @@ -72,17 +57,6 @@ func (s Session) String() string {
// UnmarshalSession wil unmarshal a JSON string into a session.
func (p *Provider) UnmarshalSession(data string) (goth.Session, error) {
session := &Session{}

rdata := strings.NewReader(data)
r, err := gzip.NewReader(rdata)
if err != nil {
return session, err
}
s, err := ioutil.ReadAll(r)
if err != nil {
return session, err
}

err = json.NewDecoder(bytes.NewReader(s)).Decode(session)
err := json.NewDecoder(strings.NewReader(data)).Decode(session)
return session, err
}
19 changes: 1 addition & 18 deletions providers/microsoftonline/session_test.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,13 @@
package microsoftonline_test

import (
"fmt"
"runtime"
"testing"

"github.com/markbates/goth"
"github.com/markbates/goth/providers/microsoftonline"
"github.com/stretchr/testify/assert"
)

// compressed session of: Session{AuthURL:"https://login.microsoftonline.com/common/oauth2/v2.0/authorize",AccessToken: "1234567890"}
var compressedSession18 = []byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 28, 198, 191, 10, 194, 48, 16, 7, 224, 119, 249, 205, 181, 119, 173, 255, 179, 117, 112, 115, 146, 186, 184, 73, 136, 38, 216, 228, 74, 114, 21, 81, 124, 119, 177, 240, 13, 223, 7, 221, 164, 254, 124, 58, 194, 192, 171, 142, 197, 16, 13, 114, 15, 169, 142, 193, 102, 41, 114, 83, 73, 67, 72, 174, 182, 18, 201, 74, 140, 146, 72, 174, 147, 250, 150, 158, 109, 205, 244, 175, 228, 240, 118, 168, 208, 89, 235, 74, 233, 229, 225, 18, 12, 154, 118, 185, 90, 111, 182, 187, 61, 163, 194, 225, 53, 134, 236, 74, 167, 48, 96, 230, 102, 49, 235, 153, 205, 236, 130, 239, 15, 0, 0, 255, 255, 1, 0, 0, 255, 255, 123, 236, 131, 18, 138, 0, 0, 0}
var compressedSession17 = []byte{31, 139, 8, 0, 0, 9, 110, 136, 0, 255, 28, 198, 191, 10, 194, 48, 16, 7, 224, 119, 249, 205, 181, 119, 173, 255, 179, 117, 112, 115, 146, 186, 184, 73, 136, 38, 216, 228, 74, 114, 21, 81, 124, 119, 177, 240, 13, 223, 7, 221, 164, 254, 124, 58, 194, 192, 171, 142, 197, 16, 13, 114, 15, 169, 142, 193, 102, 41, 114, 83, 73, 67, 72, 174, 182, 18, 201, 74, 140, 146, 72, 174, 147, 250, 150, 158, 109, 205, 244, 175, 228, 240, 118, 168, 208, 89, 235, 74, 233, 229, 225, 18, 12, 154, 118, 185, 90, 111, 182, 187, 61, 163, 194, 225, 53, 134, 236, 74, 167, 48, 96, 230, 102, 49, 235, 153, 205, 236, 130, 239, 15, 0, 0, 255, 255, 1, 0, 0, 255, 255, 123, 236, 131, 18, 138, 0, 0, 0}

// retrieves session based on runtime version as gziped values differs in some bytes between versions
func compressedSession() []byte {
var minor int
fmt.Sscanf(runtime.Version(), "go1.%d", &minor)

if minor <= 7 {
return compressedSession17
}
return compressedSession18
}

func Test_Implements_Session(t *testing.T) {
t.Parallel()
a := assert.New(t)
Expand Down Expand Up @@ -56,7 +39,7 @@ func Test_ToJSON(t *testing.T) {
}

data := s.Marshal()
a.Equal(compressedSession(), []byte(data))
a.Equal(`{"AuthURL":"https://login.microsoftonline.com/common/oauth2/v2.0/authorize","AccessToken":"1234567890","ExpiresAt":"0001-01-01T00:00:00Z"}`, data)
}

func Test_String(t *testing.T) {
Expand Down

0 comments on commit 12866fa

Please sign in to comment.