Skip to content

Commit

Permalink
compressed session for microsoft online provider
Browse files Browse the repository at this point in the history
  • Loading branch information
Michal Pristas authored and markbates committed Jan 23, 2018
1 parent a13db52 commit 953d7c3
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 22 deletions.
39 changes: 24 additions & 15 deletions providers/microsoftonline/microsoftonline.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,25 @@
package microsoftonline

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/markbates/going/defaults"
"github.com/markbates/goth"
"golang.org/x/oauth2"
)

const (
authURL string = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
tokenURL string = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
endpointProfile string = "https://graph.windows.net/v1.0/me"
endpointProfile string = "https://graph.microsoft.com/v1.0/me"
)

var defaultScopes = []string{"openid", "offline_access", "user.read"}

// New creates a new microsoftonline provider, and sets up important connection details.
// You should always call `microsoftonline.New` to get a new Provider. Never try to create
// one manually.
Expand Down Expand Up @@ -101,6 +105,8 @@ func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode)
}

user.AccessToken = msSession.AccessToken

err = userFromReader(response.Body, &user)
return user, err
}
Expand Down Expand Up @@ -138,40 +144,43 @@ func newConfig(provider *Provider, scopes []string) *oauth2.Config {
Scopes: []string{},
}

if len(scopes) > 0 {
for _, scope := range scopes {
c.Scopes = append(c.Scopes, scope)
}
} else {
c.Scopes = append(c.Scopes,
"openid",
"offline_access",
"user.read")
c.Scopes = append(c.Scopes, scopes...)
if len(scopes) == 0 {
c.Scopes = append(c.Scopes, defaultScopes...)
}

return c
}

func userFromReader(r io.Reader, user *goth.User) error {
buf := &bytes.Buffer{}
tee := io.TeeReader(r, buf)

u := struct {
ID string `json:"id"`
Name string `json:"displayName"`
Email string `json:"mail"`
FirstName string `json:"givenName"`
LastName string `json:"surname"`
UserPrincipalName string `json:"userPrincipalName"`
}{}

err := json.NewDecoder(r).Decode(&u)
if err != nil {
if err := json.NewDecoder(tee).Decode(&u); err != nil {
return err
}

user.Email = u.UserPrincipalName
raw := map[string]interface{}{}
if err := json.NewDecoder(buf).Decode(&raw); err != nil {
return err
}

user.UserID = u.ID
user.Email = defaults.String(u.Email, u.UserPrincipalName)
user.Name = u.Name
user.NickName = u.Name
user.FirstName = u.FirstName
user.LastName = u.LastName
user.NickName = u.Name
user.UserID = u.ID
user.RawData = raw

return nil
}
Expand Down
2 changes: 1 addition & 1 deletion providers/microsoftonline/microsoftonline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func Test_SessionFromJSON(t *testing.T) {
a := assert.New(t)

provider := microsoftonlineProvider()
session, err := provider.UnmarshalSession(`{"AuthURL":"https://login.microsoftonline.com/common/oauth2/v2.0/authorize","AccessToken":"1234567890"}`)
session, err := provider.UnmarshalSession(string(compressedSession()))
a.NoError(err)

s := session.(*microsoftonline.Session)
Expand Down
32 changes: 29 additions & 3 deletions providers/microsoftonline/session.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package microsoftonline

import (
"bytes"
"compress/gzip"
"encoding/json"
"errors"
"io/ioutil"
"strings"
"time"

Expand Down Expand Up @@ -46,8 +49,20 @@ func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string,

// Marshal the session into a string
func (s Session) Marshal() string {
b, _ := json.Marshal(s)
return string(b)
data, _ := json.Marshal(s)

var b bytes.Buffer
gz := gzip.NewWriter(&b)
if _, err := gz.Write([]byte(data)); err != nil {
panic(err)
}
if err := gz.Flush(); err != nil {
panic(err)
}
if err := gz.Close(); err != nil {
panic(err)
}
return b.String()
}

func (s Session) String() string {
Expand All @@ -57,6 +72,17 @@ func (s Session) String() string {
// UnmarshalSession wil unmarshal a JSON string into a session.
func (p *Provider) UnmarshalSession(data string) (goth.Session, error) {
session := &Session{}
err := json.NewDecoder(strings.NewReader(data)).Decode(session)

rdata := strings.NewReader(data)
r, err := gzip.NewReader(rdata)
if err != nil {
return session, err
}
s, err := ioutil.ReadAll(r)
if err != nil {
return session, err
}

err = json.NewDecoder(bytes.NewReader(s)).Decode(session)
return session, err
}
29 changes: 26 additions & 3 deletions providers/microsoftonline/session_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,30 @@
package microsoftonline_test

import (
"fmt"
"runtime"
"testing"

"github.com/markbates/goth"
"github.com/markbates/goth/providers/microsoftonline"
"github.com/stretchr/testify/assert"
)

// compressed session of: Session{AuthURL:"https://login.microsoftonline.com/common/oauth2/v2.0/authorize",AccessToken: "1234567890"}
var compressedSession18 = []byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 28, 198, 191, 10, 194, 48, 16, 7, 224, 119, 249, 205, 181, 119, 173, 255, 179, 117, 112, 115, 146, 186, 184, 73, 136, 38, 216, 228, 74, 114, 21, 81, 124, 119, 177, 240, 13, 223, 7, 221, 164, 254, 124, 58, 194, 192, 171, 142, 197, 16, 13, 114, 15, 169, 142, 193, 102, 41, 114, 83, 73, 67, 72, 174, 182, 18, 201, 74, 140, 146, 72, 174, 147, 250, 150, 158, 109, 205, 244, 175, 228, 240, 118, 168, 208, 89, 235, 74, 233, 229, 225, 18, 12, 154, 118, 185, 90, 111, 182, 187, 61, 163, 194, 225, 53, 134, 236, 74, 167, 48, 96, 230, 102, 49, 235, 153, 205, 236, 130, 239, 15, 0, 0, 255, 255, 1, 0, 0, 255, 255, 123, 236, 131, 18, 138, 0, 0, 0}
var compressedSession17 = []byte{31, 139, 8, 0, 0, 9, 110, 136, 0, 255, 28, 198, 191, 10, 194, 48, 16, 7, 224, 119, 249, 205, 181, 119, 173, 255, 179, 117, 112, 115, 146, 186, 184, 73, 136, 38, 216, 228, 74, 114, 21, 81, 124, 119, 177, 240, 13, 223, 7, 221, 164, 254, 124, 58, 194, 192, 171, 142, 197, 16, 13, 114, 15, 169, 142, 193, 102, 41, 114, 83, 73, 67, 72, 174, 182, 18, 201, 74, 140, 146, 72, 174, 147, 250, 150, 158, 109, 205, 244, 175, 228, 240, 118, 168, 208, 89, 235, 74, 233, 229, 225, 18, 12, 154, 118, 185, 90, 111, 182, 187, 61, 163, 194, 225, 53, 134, 236, 74, 167, 48, 96, 230, 102, 49, 235, 153, 205, 236, 130, 239, 15, 0, 0, 255, 255, 1, 0, 0, 255, 255, 123, 236, 131, 18, 138, 0, 0, 0}

// retrieves session based on runtime version as gziped values differs in some bytes between versions
func compressedSession() []byte {
var minor int
fmt.Sscanf(runtime.Version(), "go1.%d", &minor)

if minor <= 7 {
return compressedSession17
}
return compressedSession18
}

func Test_Implements_Session(t *testing.T) {
t.Parallel()
a := assert.New(t)
Expand All @@ -33,16 +50,22 @@ func Test_GetAuthURL(t *testing.T) {
func Test_ToJSON(t *testing.T) {
t.Parallel()
a := assert.New(t)
s := &microsoftonline.Session{}
s := &microsoftonline.Session{
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
AccessToken: "1234567890",
}

data := s.Marshal()
a.Equal(`{"AuthURL":"","AccessToken":"","ExpiresAt":"0001-01-01T00:00:00Z"}`, data)
a.Equal(compressedSession(), []byte(data))
}

func Test_String(t *testing.T) {
t.Parallel()
a := assert.New(t)
s := &microsoftonline.Session{}
s := &microsoftonline.Session{
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
AccessToken: "1234567890",
}

a.Equal(s.String(), s.Marshal())
}

0 comments on commit 953d7c3

Please sign in to comment.