From 0629733fe83993d930852e68e17058bbf20a70b9 Mon Sep 17 00:00:00 2001 From: wangchen Date: Wed, 15 Jan 2025 03:46:36 +0800 Subject: [PATCH] feat: add lark app oauth support (#550) Co-authored-by: camdenwang --- providers/lark/lark.go | 307 +++++++++++++++++++++++++++++++++ providers/lark/lark_test.go | 185 ++++++++++++++++++++ providers/lark/session.go | 71 ++++++++ providers/lark/session_test.go | 112 ++++++++++++ 4 files changed, 675 insertions(+) create mode 100644 providers/lark/lark.go create mode 100644 providers/lark/lark_test.go create mode 100644 providers/lark/session.go create mode 100644 providers/lark/session_test.go diff --git a/providers/lark/lark.go b/providers/lark/lark.go new file mode 100644 index 00000000..d9900b9c --- /dev/null +++ b/providers/lark/lark.go @@ -0,0 +1,307 @@ +package lark + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/markbates/goth" + "golang.org/x/oauth2" +) + +const ( + appAccessTokenURL string = "https://open.feishu.cn/open-apis/auth/v3/app_access_token/internal/" // get app_access_token + + authURL string = "https://open.feishu.cn/open-apis/authen/v1/authorize" // obtain authorization code + tokenURL string = "https://open.feishu.cn/open-apis/authen/v1/oidc/access_token" // get user_access_token + refreshTokenURL string = "https://open.feishu.cn/open-apis/authen/v1/oidc/refresh_access_token" // refresh user_access_token + endpointProfile string = "https://open.feishu.cn/open-apis/authen/v1/user_info" // get user info +) + +// Lark is the implementation of `goth.Provider` for accessing Lark +type Lark interface { + GetAppAccessToken() error // get app access token +} + +// Provider is the implementation of `goth.Provider` for accessing Lark +type Provider struct { + ClientKey string + Secret string + CallbackURL string + HTTPClient *http.Client + config *oauth2.Config + providerName string + + appAccessToken *appAccessToken +} + +// New creates a new Lark provider and sets up important connection details. +func New(clientKey, secret, callbackURL string, scopes ...string) *Provider { + p := &Provider{ + ClientKey: clientKey, + Secret: secret, + CallbackURL: callbackURL, + providerName: "lark", + appAccessToken: &appAccessToken{}, + } + p.config = newConfig(p, authURL, tokenURL, scopes) + return p +} + +func newConfig(provider *Provider, authURL, tokenURL string, scopes []string) *oauth2.Config { + c := &oauth2.Config{ + ClientID: provider.ClientKey, + ClientSecret: provider.Secret, + RedirectURL: provider.CallbackURL, + Endpoint: oauth2.Endpoint{ + AuthURL: authURL, + TokenURL: tokenURL, + }, + Scopes: []string{}, + } + + if len(scopes) > 0 { + c.Scopes = append(c.Scopes, scopes...) + } + return c +} + +func (p *Provider) Client() *http.Client { + return goth.HTTPClientWithFallBack(p.HTTPClient) +} + +func (p *Provider) Name() string { + return p.providerName +} + +func (p *Provider) SetName(name string) { + p.providerName = name +} + +type appAccessToken struct { + Token string + ExpiresAt time.Time + rMutex sync.RWMutex +} + +type appAccessTokenReq struct { + AppID string `json:"app_id"` // 自建应用的 app_id + AppSecret string `json:"app_secret"` // 自建应用的 app_secret +} + +type appAccessTokenResp struct { + Code int `json:"code"` // 错误码 + Msg string `json:"msg"` // 错误信息 + AppAccessToken string `json:"app_access_token"` // 用于调用应用级接口的 app_access_token + Expire int64 `json:"expire"` // app_access_token 的过期时间 +} + +// GetAppAccessToken get lark app access token +func (p *Provider) GetAppAccessToken() error { + // get from cache app access token + p.appAccessToken.rMutex.RLock() + if time.Now().Before(p.appAccessToken.ExpiresAt) { + p.appAccessToken.rMutex.RUnlock() + return nil + } + p.appAccessToken.rMutex.RUnlock() + + reqBody, err := json.Marshal(&appAccessTokenReq{ + AppID: p.ClientKey, + AppSecret: p.Secret, + }) + if err != nil { + return fmt.Errorf("failed to marshal request body: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, appAccessTokenURL, bytes.NewBuffer(reqBody)) + if err != nil { + return fmt.Errorf("failed to create app access token request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := p.Client().Do(req) + if err != nil { + return fmt.Errorf("failed to send app access token request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code while fetching app access token: %d", resp.StatusCode) + } + + tokenResp := new(appAccessTokenResp) + if err = json.NewDecoder(resp.Body).Decode(tokenResp); err != nil { + return fmt.Errorf("failed to decode app access token response: %w", err) + } + + if tokenResp.Code != 0 { + return fmt.Errorf("failed to get app access token: code:%v msg: %s", tokenResp.Code, tokenResp.Msg) + } + + // update local cache + expirationDuration := time.Duration(tokenResp.Expire) * time.Second + p.appAccessToken.rMutex.Lock() + p.appAccessToken.Token = tokenResp.AppAccessToken + p.appAccessToken.ExpiresAt = time.Now().Add(expirationDuration) + p.appAccessToken.rMutex.Unlock() + + return nil +} + +func (p *Provider) BeginAuth(state string) (goth.Session, error) { + // build lark auth url + u, err := url.Parse(p.config.AuthCodeURL(state)) + if err != nil { + panic(err) + } + query := u.Query() + query.Del("response_type") + query.Del("client_id") + query.Add("app_id", p.ClientKey) + u.RawQuery = query.Encode() + + return &Session{ + AuthURL: u.String(), + }, nil +} + +func (p *Provider) UnmarshalSession(data string) (goth.Session, error) { + s := &Session{} + err := json.NewDecoder(strings.NewReader(data)).Decode(s) + return s, err +} + +func (p *Provider) Debug(b bool) { +} + +type getUserAccessTokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshExpiresIn int `json:"refresh_expires_in"` + Scope string `json:"scope"` +} + +func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) { + if err := p.GetAppAccessToken(); err != nil { + return nil, fmt.Errorf("failed to get app access token: %w", err) + } + reqBody := strings.NewReader(`{"grant_type":"refresh_token","refresh_token":"` + refreshToken + `"}`) + + req, err := http.NewRequest(http.MethodPost, refreshTokenURL, reqBody) + if err != nil { + return nil, fmt.Errorf("failed to create refresh token request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", p.appAccessToken.Token)) + + resp, err := p.Client().Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send refresh token request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code while refreshing token: %d", resp.StatusCode) + } + + var oauthResp commResponse[getUserAccessTokenResp] + err = json.NewDecoder(resp.Body).Decode(&oauthResp) + if err != nil { + return nil, fmt.Errorf("failed to decode refreshed token: %w", err) + } + if oauthResp.Code != 0 { + return nil, fmt.Errorf("failed to refresh token: code:%v msg: %s", oauthResp.Code, oauthResp.Msg) + } + + token := oauth2.Token{ + AccessToken: oauthResp.Data.AccessToken, + RefreshToken: oauthResp.Data.RefreshToken, + Expiry: time.Now().Add(time.Duration(oauthResp.Data.ExpiresIn) * time.Second), + } + + return &token, nil +} + +func (p *Provider) RefreshTokenAvailable() bool { + return true +} + +type commResponse[T any] struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data T `json:"data"` +} + +type larkUser struct { + OpenID string `json:"open_id"` + UnionID string `json:"union_id"` + UserID string `json:"user_id"` + Name string `json:"name"` + Email string `json:"enterprise_email"` + AvatarURL string `json:"avatar_url"` + Mobile string `json:"mobile,omitempty"` +} + +// FetchUser will go to Lark and access basic information about the user. +func (p *Provider) FetchUser(session goth.Session) (goth.User, error) { + sess := session.(*Session) + user := goth.User{ + AccessToken: sess.AccessToken, + Provider: p.Name(), + RefreshToken: sess.RefreshToken, + ExpiresAt: sess.ExpiresAt, + } + if user.AccessToken == "" { + return user, fmt.Errorf("%s cannot get user information without accessToken", p.providerName) + } + + req, err := http.NewRequest("GET", endpointProfile, nil) + if err != nil { + return user, fmt.Errorf("%s failed to create request: %w", p.providerName, err) + } + req.Header.Set("Authorization", "Bearer "+user.AccessToken) + + resp, err := p.Client().Do(req) + if err != nil { + return user, fmt.Errorf("%s failed to get user information: %w", p.providerName, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, resp.StatusCode) + } + + responseBytes, err := io.ReadAll(resp.Body) + if err != nil { + return user, fmt.Errorf("failed to read response body: %w", err) + } + + var oauthResp commResponse[larkUser] + if err = json.Unmarshal(responseBytes, &oauthResp); err != nil { + return user, fmt.Errorf("failed to decode user info: %w", err) + } + if oauthResp.Code != 0 { + return user, fmt.Errorf("failed to get user info: code:%v msg: %s", oauthResp.Code, oauthResp.Msg) + } + + u := oauthResp.Data + user.UserID = u.UserID + user.Name = u.Name + user.Email = u.Email + user.AvatarURL = u.AvatarURL + user.NickName = u.Name + + if err = json.Unmarshal(responseBytes, &user.RawData); err != nil { + return user, err + } + return user, nil +} diff --git a/providers/lark/lark_test.go b/providers/lark/lark_test.go new file mode 100644 index 00000000..cda49e52 --- /dev/null +++ b/providers/lark/lark_test.go @@ -0,0 +1,185 @@ +package lark_test + +import ( + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "os" + "strings" + "testing" + + "github.com/markbates/goth/providers/lark" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type MockedHTTPClient struct { + mock.Mock +} + +func (m *MockedHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) { + args := m.Mock.Called(req) + return args.Get(0).(*http.Response), args.Error(1) +} + +func Test_New(t *testing.T) { + t.Parallel() + a := assert.New(t) + p := larkProvider() + + a.Equal(p.ClientKey, os.Getenv("LARK_APP_ID")) + a.Equal(p.Secret, os.Getenv("LARK_APP_SECRET")) + a.Equal(p.CallbackURL, "/foo") +} + +func Test_BeginAuth(t *testing.T) { + t.Parallel() + a := assert.New(t) + p := larkProvider() + session, err := p.BeginAuth("test_state") + s := session.(*lark.Session) + a.NoError(err) + a.Contains(s.AuthURL, "https://open.feishu.cn/open-apis/authen/v1/authorize") + a.Contains(s.AuthURL, "app_id="+os.Getenv("LARK_APP_ID")) + a.Contains(s.AuthURL, "state=test_state") + a.Contains(s.AuthURL, fmt.Sprintf("redirect_uri=%s", url.QueryEscape("/foo"))) +} + +func Test_GetAppAccessToken(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := larkProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader(`{"code":0,"msg":"ok","app_access_token":"test_token","expire":3600}`)), + }, nil) + + err := p.GetAppAccessToken() + assert.NoError(t, err) + }) + + t.Run("error on request", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := larkProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{}, errors.New("request error")) + + err := p.GetAppAccessToken() + assert.Error(t, err) + }) + + t.Run("non-200 status code", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := larkProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusForbidden, + Body: ioutil.NopCloser(strings.NewReader(``)), + }, nil) + + err := p.GetAppAccessToken() + assert.Error(t, err) + }) + + t.Run("error on response decode", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := larkProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader(`not a json`)), + }, nil) + + err := p.GetAppAccessToken() + assert.Error(t, err) + }) + + t.Run("error code in response", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := larkProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader(`{"code":1,"msg":"error message"}`)), + }, nil) + + err := p.GetAppAccessToken() + assert.Error(t, err) + }) +} + +func Test_FetchUser(t *testing.T) { + session := &lark.Session{ + AccessToken: "user_access_token", + } + + t.Run("happy path", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := larkProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader(`{"code":0,"msg":"ok","data":{"user_id":"test_user_id","name":"test_name","avatar_url":"test_avatar_url","enterprise_email":"test_email"}}`)), + }, nil) + user, err := p.FetchUser(session) + require.NoError(t, err) + assert.Equal(t, user.UserID, "test_user_id") + assert.Equal(t, user.Name, "test_name") + assert.Equal(t, user.AvatarURL, "test_avatar_url") + assert.Equal(t, user.Email, "test_email") + }) + t.Run("error on request", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := larkProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{}, errors.New("request error")) + _, err := p.FetchUser(session) + require.Error(t, err) + }) + t.Run("non-200 status code", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := larkProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusForbidden, + Body: ioutil.NopCloser(strings.NewReader(``)), + }, nil) + _, err := p.FetchUser(session) + require.Error(t, err) + }) + t.Run("error on response decode", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := larkProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader(`not a json`)), + }, nil) + _, err := p.FetchUser(session) + require.Error(t, err) + }) + t.Run("error code in response", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := larkProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader(`{"code":1,"msg":"error message"}`)), + }, nil) + _, err := p.FetchUser(session) + require.Error(t, err) + }) +} + +func larkProvider() *lark.Provider { + return lark.New(os.Getenv("LARK_APP_ID"), os.Getenv("LARK_APP_SECRET"), "/foo") +} diff --git a/providers/lark/session.go b/providers/lark/session.go new file mode 100644 index 00000000..2fdf260c --- /dev/null +++ b/providers/lark/session.go @@ -0,0 +1,71 @@ +package lark + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/markbates/goth" +) + +type Session struct { + AuthURL string + AccessToken string + RefreshToken string + ExpiresAt time.Time + RefreshTokenExpiresAt time.Time +} + +func (s *Session) GetAuthURL() (string, error) { + if s.AuthURL == "" { + return "", errors.New("lark: missing AuthURL") + } + return s.AuthURL, nil +} + +func (s *Session) Marshal() string { + b, _ := json.Marshal(s) + return string(b) +} + +func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) { + p := provider.(*Provider) + reqBody := strings.NewReader(`{"grant_type":"authorization_code","code":"` + params.Get("code") + `"}`) + req, err := http.NewRequest(http.MethodPost, tokenURL, reqBody) + if err != nil { + return "", fmt.Errorf("failed to create refresh token request: %w", err) + } + if err = p.GetAppAccessToken(); err != nil { + return "", fmt.Errorf("failed to get app access token: %w", err) + } + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", p.appAccessToken.Token)) + req.Header.Add("Content-Type", "application/json; charset=utf-8") + + resp, err := p.Client().Do(req) + if err != nil { + return "", fmt.Errorf("failed to send refresh token request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected status code while authorizing: %d", resp.StatusCode) + } + + var larkCommResp commResponse[getUserAccessTokenResp] + err = json.NewDecoder(resp.Body).Decode(&larkCommResp) + if err != nil { + return "", fmt.Errorf("failed to decode commResponse: %w", err) + } + if larkCommResp.Code != 0 { + return "", fmt.Errorf("failed to get accessToken: code:%v msg: %s", larkCommResp.Code, larkCommResp.Msg) + } + + s.AccessToken = larkCommResp.Data.AccessToken + s.RefreshToken = larkCommResp.Data.RefreshToken + s.ExpiresAt = time.Now().Add(time.Duration(larkCommResp.Data.ExpiresIn) * time.Second) + s.RefreshTokenExpiresAt = time.Now().Add(time.Duration(larkCommResp.Data.RefreshExpiresIn) * time.Second) + return s.AccessToken, nil +} diff --git a/providers/lark/session_test.go b/providers/lark/session_test.go new file mode 100644 index 00000000..59dc53f2 --- /dev/null +++ b/providers/lark/session_test.go @@ -0,0 +1,112 @@ +package lark_test + +import ( + "errors" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/markbates/goth" + "github.com/markbates/goth/providers/lark" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type MockParams struct { + params map[string]string +} + +func (m *MockParams) Get(key string) string { + return m.params[key] +} + +func Test_Implements_Session(t *testing.T) { + t.Parallel() + a := assert.New(t) + s := &lark.Session{} + + a.Implements((*goth.Session)(nil), s) +} + +func Test_GetAuthURL(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + session := &lark.Session{ + AuthURL: "https://auth.url", + } + url, err := session.GetAuthURL() + assert.NoError(t, err) + assert.Equal(t, "https://auth.url", url) + }) + + t.Run("missing AuthURL", func(t *testing.T) { + session := &lark.Session{} + _, err := session.GetAuthURL() + assert.Error(t, err) + }) +} + +func Test_Marshal(t *testing.T) { + session := &lark.Session{ + AuthURL: "https://auth.url", + AccessToken: "access_token", + } + marshaled := session.Marshal() + assert.Contains(t, marshaled, "https://auth.url") + assert.Contains(t, marshaled, "access_token") +} + +func Test_Authorize(t *testing.T) { + session := &lark.Session{} + params := &MockParams{ + params: map[string]string{ + "code": "authorization_code", + }, + } + + t.Run("error on request", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := larkProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{}, errors.New("request error")) + _, err := session.Authorize(p, params) + require.Error(t, err) + }) + + t.Run("non-200 status code", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := larkProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusForbidden, + Body: ioutil.NopCloser(strings.NewReader(``)), + }, nil) + _, err := session.Authorize(p, params) + require.Error(t, err) + }) + + t.Run("error on response decode", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := larkProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader(`not a json`)), + }, nil) + _, err := session.Authorize(p, params) + require.Error(t, err) + }) + + t.Run("error code in response", func(t *testing.T) { + mockClient := new(MockedHTTPClient) + p := larkProvider() + p.HTTPClient = &http.Client{Transport: mockClient} + mockClient.On("RoundTrip", mock.Anything).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(strings.NewReader(`{"code":1,"msg":"error message"}`)), + }, nil) + _, err := session.Authorize(p, params) + require.Error(t, err) + }) +}