Created
November 1, 2022 15:07
-
-
Save jxsl13/70c237f84adaa637fe0591a6a128afc3 to your computer and use it in GitHub Desktop.
OAuth2 wrapper for resty (automatic token refresh)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package client | |
import ( | |
"sync" | |
"github.com/go-resty/resty/v2" | |
) | |
var ( | |
factoryOnce sync.Once | |
factory *ClientFactory | |
) | |
type Config struct { | |
Insecure bool | |
TokenUrl string | |
ClientId string | |
ClientSecret string | |
} | |
func Init(config Config) (err error) { | |
factoryOnce.Do(func() { | |
factory = NewClientFactory(config) | |
// directly fetch a token in order to see | |
// whether the initialization process was successful | |
// as well as to see whether the credentials are correct | |
_, err = factory.tokenProvider.UpdateToken() | |
}) | |
return err | |
} | |
func newFactory() *ClientFactory { | |
factoryOnce.Do(func() { | |
panic("client factory not initialized") | |
}) | |
return factory | |
} | |
func New() *resty.Client { | |
return newFactory().New() | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package client | |
// NewUserCredentials builds the form parameters for an OAuth2 password grant.
// The client id is fixed to "public".
func NewUserCredentials(username, password string) map[string]string {
	creds := make(map[string]string, 4)
	creds["grant_type"] = "password"
	creds["client_id"] = "public"
	creds["username"] = username
	creds["password"] = password
	return creds
}
// NewClientCredentials builds the form parameters for an OAuth2
// client-credentials grant.
func NewClientCredentials(clientID, clientSecret string) map[string]string {
	creds := make(map[string]string, 3)
	creds["grant_type"] = "client_credentials"
	creds["client_id"] = clientID
	creds["client_secret"] = clientSecret
	return creds
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package client | |
// UnexpectedTokenResponse is the error reported when fetching a new JWT does
// not yield a usable token response. Msg is a short description and Err holds
// further detail. In this package, the token provider puts the raw response
// body in Err.
type UnexpectedTokenResponse struct {
	Msg string
	Err string
}

// Error implements the error interface, formatting the value as "<Msg>: <Err>".
func (e UnexpectedTokenResponse) Error() string {
	const sep = ": "
	return e.Msg + sep + e.Err
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package client | |
import ( | |
"crypto/tls" | |
"time" | |
"github.com/go-resty/resty/v2" | |
) | |
func NewClientFactory(config Config) *ClientFactory { | |
return &ClientFactory{ | |
tokenProvider: getTokenProvider(config), | |
insecure: config.Insecure, | |
} | |
} | |
type ClientFactory struct { | |
tokenProvider *TokenProvider | |
insecure bool | |
} | |
func (cf *ClientFactory) UpdateToken() (*JWT, error) { | |
return cf.tokenProvider.UpdateToken() | |
} | |
// Ping can be use din order to check if a connection to the configured | |
// api can be properly established or not | |
func (cf *ClientFactory) Ping() error { | |
_, err := cf.tokenProvider.UpdateToken() | |
return err | |
} | |
func getTokenProvider(config Config) *TokenProvider { | |
service := NewTokenProvider( | |
config.TokenUrl, | |
NewClientCredentials( | |
config.ClientId, | |
config.ClientSecret, | |
), | |
) | |
service.WithTLSClientConfig( | |
&tls.Config{ | |
InsecureSkipVerify: config.Insecure, | |
}, | |
) | |
return service | |
} | |
func (cs *ClientFactory) New() *resty.Client { | |
return resty.New(). | |
AddRetryCondition(NewUnauthorizedCondition(cs.tokenProvider)). | |
OnBeforeRequest(NewBearerTokenMiddleware(cs.tokenProvider)). | |
SetRetryCount(3). | |
SetRetryAfter(NewRetryAfter(10 * time.Second)). | |
SetTLSClientConfig( | |
&tls.Config{ | |
InsecureSkipVerify: cs.insecure, | |
}, | |
) | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package client | |
import ( | |
"net/http" | |
"strconv" | |
"time" | |
"github.com/go-resty/resty/v2" | |
) | |
func NewBearerTokenMiddleware(t *TokenProvider) resty.RequestMiddleware { | |
return func(_ *resty.Client, req *resty.Request) error { | |
jwt, err := t.Token() | |
if jwt != nil { | |
req.SetAuthToken(jwt.AccessToken) | |
} | |
return err | |
} | |
} | |
func NewUnauthorizedCondition(t *TokenProvider) resty.RetryConditionFunc { | |
return func(r *resty.Response, err error) bool { | |
if err != nil { | |
return true | |
} | |
if r != nil { | |
if r.StatusCode() == http.StatusUnauthorized { | |
t.UpdateToken() | |
} | |
return r.IsError() && r.StatusCode() != http.StatusForbidden | |
} | |
return true | |
} | |
} | |
func NewServiceUnavailableCondition() resty.RetryConditionFunc { | |
return func(r *resty.Response, err error) bool { | |
return r.StatusCode() == http.StatusServiceUnavailable | |
} | |
} | |
func NewRetryAfter(defaultSleep time.Duration) resty.RetryAfterFunc { | |
return func(_ *resty.Client, resp *resty.Response) (time.Duration, error) { | |
// see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After | |
retryAfter := resp.Header().Get("Retry-After") | |
// check if header is present | |
if retryAfter != "" { | |
// try int convertion for "seconds only" format | |
retrySec, errRetrySec := strconv.Atoi(retryAfter) | |
if errRetrySec == nil { | |
if retrySec > -1 { | |
return time.Duration(retrySec) * time.Second, nil | |
} | |
} else { | |
// try absolute time conversion | |
nowTS := time.Now() | |
retryTS, errParse := time.Parse(time.RFC1123, retryAfter) | |
if errParse == nil && retryTS.After(nowTS) { | |
dur := retryTS.Sub(nowTS) | |
return dur, nil | |
} | |
} | |
} | |
// return default | |
return defaultSleep, nil | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package client | |
import ( | |
"crypto/tls" | |
"errors" | |
"sync" | |
"time" | |
"stoxdog/internal/logging" | |
"github.com/go-resty/resty/v2" | |
) | |
// JWT is the decoded body of a successful token-endpoint response.
type JWT struct {
	// AccessToken is sent as the bearer token on API requests.
	AccessToken string `json:"access_token"`
	// IDToken is the OpenID Connect ID token, if the server returns one.
	IDToken string `json:"id_token"`
	// ExpiresIn is the access-token lifetime in seconds; 0 means unknown.
	ExpiresIn int `json:"expires_in"`
	// RefreshExpiresIn is the refresh-token lifetime, presumably also in
	// seconds (looks like a Keycloak extension; TODO confirm).
	RefreshExpiresIn int    `json:"refresh_expires_in"`
	RefreshToken     string `json:"refresh_token"`
	TokenType        string `json:"token_type"`
	// NotBeforePolicy is interpreted by the token provider as a Unix
	// timestamp in seconds.
	NotBeforePolicy int64  `json:"not-before-policy"`
	SessionState    string `json:"session_state"`
	Scope           string `json:"scope"`
}
type TokenProvider struct { | |
restyClient *resty.Client | |
tokenUrl string | |
credentials map[string]string | |
currentJWT *JWT | |
tokenExpireTime time.Time | |
mu sync.Mutex | |
logger logging.Logger | |
} | |
func NewTokenProvider(tokenUrl string, credentials map[string]string) *TokenProvider { | |
return &TokenProvider{ | |
restyClient: resty.New(), | |
tokenUrl: tokenUrl, | |
credentials: credentials, | |
logger: &logging.NoOpLogger{}, | |
} | |
} | |
func (t *TokenProvider) Token() (*JWT, error) { | |
if t == nil { | |
return nil, errors.New("Token service not initialized") | |
} | |
t.mu.Lock() | |
defer t.mu.Unlock() | |
if t.currentJWT == nil { | |
return t.updateToken() | |
} | |
if !t.tokenExpireTime.IsZero() && time.Now().After(t.tokenExpireTime) { | |
t.logger.Debug("Token expired. Updating token.") | |
return t.updateToken() | |
} | |
return t.currentJWT, nil | |
} | |
func (t *TokenProvider) UpdateToken() (*JWT, error) { | |
if t == nil { | |
return nil, errors.New("Token service not initialized") | |
} | |
t.mu.Lock() | |
defer t.mu.Unlock() | |
return t.updateToken() | |
} | |
func (t *TokenProvider) updateToken() (*JWT, error) { | |
resultJWT := JWT{} | |
resp, err := t.restyClient.R(). | |
SetFormData(t.credentials). | |
SetHeader("Cache-Control", "no-cache"). | |
SetResult(&resultJWT). | |
Post(t.tokenUrl) | |
if err != nil { | |
return nil, err | |
} | |
if resp.StatusCode()/100 != 2 { | |
return nil, UnexpectedTokenResponse{ | |
Msg: "unexpected authentication error", | |
Err: string(resp.Body()), | |
} | |
} | |
t.currentJWT = &resultJWT | |
if t.currentJWT.ExpiresIn > 0 { | |
t.tokenExpireTime = time.Now().Add(time.Duration(t.currentJWT.ExpiresIn-30) * time.Second) | |
} else { | |
t.tokenExpireTime = time.Time{} | |
} | |
if t.currentJWT.NotBeforePolicy > 0 { | |
wait := time.Until(time.Unix(t.currentJWT.NotBeforePolicy, 0)) | |
time.Sleep(wait) | |
} | |
return t.currentJWT, nil | |
} | |
func (t *TokenProvider) WithTLSClientConfig(config *tls.Config) *TokenProvider { | |
t.mu.Lock() | |
defer t.mu.Unlock() | |
t.restyClient.SetTLSClientConfig(config) | |
return t | |
} | |
func (t *TokenProvider) WithLogger(logger logging.Logger) *TokenProvider { | |
t.mu.Lock() | |
defer t.mu.Unlock() | |
t.logger = logger | |
t.restyClient.SetLogger(logger) | |
return t | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment