-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathproxy.go
208 lines (173 loc) · 5.47 KB
/
proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
package main
// Note: This was built off of an example reverse proxy created by Ben Church
// https://github.com/bechurch/reverse-proxy-demo
import (
"encoding/json"
"io/ioutil"
"log"
"net/http"
"net/http/httputil"
"net/url"
"os"
"time"
)
var (
auth_endpoint_url string
auth_client_id string
auth_client_secret string
auth_scope string
proxy_downstream_url string
proxy_port string
access_token string
token_type string
token_refresh_time time.Time
api_key string
api_key_header string
)
// Structure for storing results from a
type AuthReponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
}
// Proxies the incoming request to the downstream, adding Authorization
// header and optional API Key header
func handleRequestAndRedirect(res http.ResponseWriter, req *http.Request) {
url, err := url.Parse(proxy_downstream_url)
if err != nil {
log.Print(err)
http.Error(res, "Failure calling downstream proxy", http.StatusInternalServerError)
return
}
proxy := httputil.NewSingleHostReverseProxy(url)
req.URL.Host = url.Host
req.URL.Scheme = url.Scheme
req.Host = url.Host
req.Header.Set("Authorization", token_type+" "+access_token)
if api_key != "" {
req.Header.Set(api_key_header, api_key)
}
log.Printf("Proxy %s %s", req.Method, req.URL)
proxy.ServeHTTP(res, req)
}
// Gets (or refreshes) the access token using a jittered backed-off retry
func getOuath2AuthAccessToken() {
request_body := url.Values{
"grant_type": {"client_credentials"},
"client_id": {auth_client_id},
"client_secret": {auth_client_secret},
}
if auth_scope != "" {
request_body.Set("scope", auth_scope)
}
retry_number := -1
for {
retry_number++
if retry_number > 5 {
log.Print("Failed to acquire access token; exiting")
break
} else if retry_number > 0 {
seconds_to_wait := retry_number*retry_number + 1
log.Printf("Failed to aquired token; awaiting retry #%v in %v seconds", retry_number, seconds_to_wait)
retry_time := time.Duration(seconds_to_wait) * time.Second
time.Sleep(retry_time)
log.Printf("Retry #%v", retry_number)
}
log.Printf("Sending authentication request via POST to %s", auth_endpoint_url)
resp, err := http.PostForm(auth_endpoint_url, request_body)
if err != nil {
log.Print(err)
continue
}
if resp.StatusCode != 200 && resp.StatusCode != 201 {
log.Printf("Received non-200 status code: %s", resp.Status)
continue
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Print(err)
continue
}
//TODO: Error handling on unmarshalling the JSON payload
auth_response := AuthReponse{}
err = json.Unmarshal(body, &auth_response)
if err != nil {
log.Print(err)
continue
}
if auth_response.AccessToken == "" || auth_response.TokenType == "" || auth_response.ExpiresIn == 0 {
log.Print("Returned JSON document did not contain required fields")
continue
}
access_token = auth_response.AccessToken
token_type = auth_response.TokenType
expires := auth_response.ExpiresIn - (60 * 5)
token_refresh_time = time.Now().UTC().Add(time.Second * time.Duration(expires))
log.Print("Access token updated")
log.Printf("Token refresh scheduled at %s", token_refresh_time)
break
}
}
// Go routine to handle token refresh on a loop
func handleTokenRefresh() {
for {
current_time := time.Now().UTC()
if current_time.After(token_refresh_time) {
log.Print("Refreshing access token")
getOuath2AuthAccessToken()
}
time.Sleep(30 * time.Second)
}
}
// Retrieves a named environment variable. validates that required
// variables are supplied, and supplies defaults for missing values
func getEnvironmentVariable(key string, required bool, secret bool, fallback string) string {
if value, ok := os.LookupEnv(key); ok {
if secret {
log.Printf("%s=**************", key)
} else {
log.Printf("%s=%s", key, value)
}
return value
}
if required {
log.Fatalf("Environment variable %s must be supplied", key)
}
if fallback != "" {
log.Printf("%s=%s (Default Value)", key, fallback)
}
return fallback
}
// Initialize variables from environment
func initVariables() {
auth_endpoint_url = getEnvironmentVariable("AUTH_ENDPOINT_URL", true, false, "")
auth_client_id = getEnvironmentVariable("AUTH_CLIENT_ID", true, true, "")
auth_client_secret = getEnvironmentVariable("AUTH_CLIENT_SECRET", true, true, "")
auth_scope = getEnvironmentVariable("AUTH_SCOPE", false, false, "")
proxy_downstream_url = getEnvironmentVariable("PROXY_DOWNSTREAM_URL", true, false, "")
proxy_port = getEnvironmentVariable("PROXY_PORT", false, false, "10801")
api_key = getEnvironmentVariable("PROXY_API_KEY", false, true, "")
if api_key != "" {
api_key_header = getEnvironmentVariable("PROXY_API_KEY_HEADER", false, false, "x-api-key")
}
}
// Main program entrypoint
func main() {
log.SetFlags(log.LstdFlags | log.LUTC)
log.Print("Initializing proxy")
initVariables()
log.Print("Getting initial access token")
getOuath2AuthAccessToken()
if access_token == "" {
log.Fatal("Failed to acquire initial access token - terminating")
}
log.Print("Starting access token refresh background routine")
go handleTokenRefresh()
listen_address := ":" + proxy_port
http.HandleFunc("/", handleRequestAndRedirect)
log.Printf("Listening to path / on %s", listen_address)
if err := http.ListenAndServe(listen_address, nil); err != nil {
panic(err)
}
}