-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain.go
328 lines (263 loc) · 8.2 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
package main
import (
"encoding/json"
"flag"
"fmt"
"log"
"os"
"path"
"strings"
"time"
"github.com/BurntSushi/toml"
"github.com/pkg/errors"
)
type (
	// oidcConfig holds the settings main decodes from ~/.oidc2aws/oidcconfig.
	// There are no struct tags, so the TOML decoder matches keys against the
	// field names.
	oidcConfig struct {
		Provider     string
		ClientID     string
		ClientSecret string
		HostedDomain string
		// Alias maps a name accepted by -alias to the role(s) it expands to.
		Alias        map[string]alias
	}

	// alias describes what a named alias expands to. main uses RoleChain when
	// it is non-empty. Otherwise it uses [SourceRole, Arn] when SourceRole is
	// set, or just [Arn].
	alias struct {
		Arn        string
		SourceRole string
		RoleChain  []string
	}
)
// Command-line flags. main calls flag.Parse before acting on their values.
var (
	envFormat   = flag.Bool("env", false, "output credentials in format suitable for use with $()")
	loginFormat = flag.Bool("login", false, "generate login link for AWS web console")
	verbose     = flag.Bool("verbose", false, "log verbose/debug output")
	sourceRole  = flag.String("sourcerole", "", "source role to assume before assuming target role")
	aliasFlag   = flag.String("alias", "", "alias configured in ~/.oidc2aws/oidcconfig")
	shell       = flag.String("shell", "", "shell type, possible values: bash, zsh, sh, fish, csh, tcsh")
)
// arnFilename maps a cache key (a role ARN, or a fixed key such as
// "id-token") to a file directly under $HOME/.oidc2aws. Every "/" and ":"
// in the key becomes "-", so the key turns into one flat filename.
func arnFilename(arn string) string {
	flattened := strings.NewReplacer("/", "-", ":", "-").Replace(arn)
	return path.Join(os.Getenv("HOME"), ".oidc2aws", flattened)
}
// printCredentials writes result to stdout in the format selected by flags:
//
//   - -env: three statements that set AWS_ACCESS_KEY_ID,
//     AWS_SECRET_ACCESS_KEY and AWS_SESSION_TOKEN. The syntax follows the
//     shell named by -shell, falling back to the basename of $SHELL. fish
//     uses "set -x", csh/tcsh use "setenv", and every other shell uses
//     "export".
//   - -login: delegates to fetchSigninToken.
//   - neither: result marshalled to JSON, written with no trailing newline.
func printCredentials(result *result) error {
	switch {
	case *envFormat:
		shellName := path.Base(os.Getenv("SHELL"))
		if *shell != "" {
			// An explicit -shell overrides whatever $SHELL says.
			shellName = *shell
		}

		// setLine is the per-shell statement template. It takes the variable
		// name first, then its value.
		var setLine string
		switch shellName {
		case "fish":
			setLine = "set -x %s %s\n"
		case "csh", "tcsh":
			setLine = "setenv %s %s\n"
		default: // bash, zsh, sh and any unrecognised shell
			setLine = "export %s=%s\n"
		}
		fmt.Printf(setLine, "AWS_ACCESS_KEY_ID", *result.Credentials.AccessKeyId)
		fmt.Printf(setLine, "AWS_SECRET_ACCESS_KEY", *result.Credentials.SecretAccessKey)
		fmt.Printf(setLine, "AWS_SESSION_TOKEN", *result.Credentials.SessionToken)
		return nil
	case *loginFormat:
		return fetchSigninToken(result)
	}

	payload, err := json.Marshal(result)
	if err != nil {
		return errors.Wrap(err, "error serialising credentials to json")
	}
	// Write credentials to stdout.
	_, err = os.Stdout.Write(payload)
	return err
}
// debug and debugf are no-op loggers by default. main replaces them with
// log.Println and log.Printf when -verbose is set.
var (
	debug  = func(...interface{}) {}
	debugf = func(string, ...interface{}) {}
)
func main() {
oc := oidcConfig{}
if _, err := toml.DecodeFile(path.Join(os.Getenv("HOME"), ".oidc2aws", "oidcconfig"), &oc); err != nil {
log.Fatal(errors.Wrap(err, "error loading OIDC config"))
}
flag.Parse()
if *verbose {
debug = log.Println
debugf = log.Printf
}
args := flag.Args()
roles := []string{}
// 1. len(args) == 1 => initialRole = args[0]
// 2. len(args) == 1, sourceRole != nil => roles = [sourceRole, ...args] <= legacy
// 3. alias(arn) => roles = [arn]
// 4. alias(arn, sourceRole) => initialRole = sourceRole, roles = [sourceRole, arn] <= legacy
// 5. alias(rolechain) => roles = rolechain
// 5. len(args) > 1 => roles = args
if *aliasFlag != "" {
alias, ok := oc.Alias[*aliasFlag]
if !ok {
log.Fatalf("unknown alias: %s", *aliasFlag)
}
if len(alias.RoleChain) > 0 {
roles = alias.RoleChain
} else {
if alias.SourceRole != "" {
roles = []string{alias.SourceRole, alias.Arn}
} else {
roles = []string{alias.Arn}
}
}
} else if *sourceRole != "" {
roles = []string{*sourceRole}
roles = append(roles, args...)
} else {
if len(args) > 0 {
roles = args
}
}
if len(roles) == 0 {
log.Fatal("no roles provided, please add roles as args or use -alias")
}
var r fetcher = cache{oidcFetcher{arn: roles[0], oc: oc}}
for _, role := range roles[1:] {
r = cache{assumedRole{arn: role, upstream: r}}
}
result, err := r.fetchCredentials()
if err != nil {
log.Fatal(err)
}
if err := printCredentials(result); err != nil {
log.Fatal(errors.Wrap(err, "error printing credentials"))
}
}
// fetcher produces AWS credentials. main chains implementations: an
// oidcFetcher for the first role, an assumedRole for each later one, and
// each of them wrapped in a cache.
type fetcher interface {
	// fetchCredentials returns the credentials, or an error describing why
	// they could not be obtained.
	fetchCredentials() (creds *result, err error)
}
type (
	// cacheable is implemented by fetchers whose results can be cached. key
	// returns the name of the cache entry, which arnFilename turns into a
	// file path.
	cacheable interface {
		key() string
	}

	// expiring is implemented by cached values that carry an expiry time.
	// get treats an entry as missing when Expiry is nil or less than five
	// minutes away.
	expiring interface {
		Expiry() *time.Time
	}
)
// assumedRole fetches credentials for arn by assuming that role with the
// credentials produced by upstream, which is the previous link in the chain.
type assumedRole struct {
	arn      string
	upstream fetcher
}

// fetchCredentials obtains upstream's credentials and uses them to assume
// a.arn.
func (a assumedRole) fetchCredentials() (*result, error) {
	debugf("assuming role %s", a.arn)
	upstreamCreds, err := a.upstream.fetchCredentials()
	if err != nil {
		return nil, errors.Wrap(err, "error fetching credentials from upstream")
	}
	if upstreamCreds == nil {
		// err is nil on this path, and errors.Wrap(nil, ...) returns nil.
		// A fresh error is needed, or this would report success with no
		// credentials.
		return nil, errors.New("upstream didn't return credentials")
	}
	// The upstream access key ID is passed as assumeRole's second argument
	// (presumably the role session name; confirm in assumeRole). An earlier
	// variant prefixed it with the id token's email:
	// return assumeRole(a.arn, idToken.email+","+*upstreamCreds.Credentials.AccessKeyId, upstreamCreds)
	return assumeRole(a.arn, *upstreamCreds.Credentials.AccessKeyId, upstreamCreds)
}

// key names the cache entry for this role's credentials: the role ARN.
func (a assumedRole) key() string {
	return a.arn
}
// oidcFetcher fetches credentials for arn by exchanging an OIDC id token for
// AWS credentials. The id token is cached on disk under the key "id-token".
type oidcFetcher struct {
	oc  oidcConfig
	arn string
}

// fetchCredentials returns AWS credentials for o.arn. It reuses a cached id
// token when get reports a usable one. Otherwise it fetches a new token with
// fetchIDToken and caches it. A failure to write that cache entry is
// returned as an error.
func (o oidcFetcher) fetchCredentials() (*result, error) {
	const key = "id-token"
	idToken := &idTokenResult{}
	err := get(key, idToken)
	switch {
	case err == errNotFound:
		debug("no cached id token, fetching...")
		// Best-effort removal of any stale entry. The file usually does not
		// exist here, so the error is deliberately ignored.
		_ = del(key)
		idToken, err = fetchIDToken(o.oc)
		if err != nil {
			return nil, errors.Wrap(err, "error fetching id token")
		}
		// Check put's own error. Previously it was discarded and a stale err
		// was re-checked instead.
		if err := put(key, idToken); err != nil {
			return nil, errors.Wrap(err, "error caching id token")
		}
	case err != nil:
		return nil, errors.Wrap(err, "error reading cached file")
	default:
		debug("using cached id token...")
	}

	debugf("swapping id token for credentials for role %s...", o.arn)
	creds, err := fetchAWSCredentials(o.arn, idToken.RawToken, idToken.Email)
	if err != nil {
		return nil, errors.Wrapf(err, `error fetching aws credentials (is %q an allowed value for "accounts.google.com:sub" in Trust relationship conditions for role?)`, idToken.Sub)
	}
	return creds, nil
}

// key names the cache entry for this role's credentials: the role ARN.
func (o oidcFetcher) key() string {
	return o.arn
}
type (
	// cacheableFetcher is a fetcher whose results can be stored in the cache
	// under its key.
	cacheableFetcher interface {
		cacheable
		fetcher
	}

	// cache wraps f. It serves credentials from the on-disk entry named by
	// f.key() when that entry is usable, and otherwise calls f and stores the
	// result.
	cache struct {
		f cacheableFetcher
	}
)
// fetchCredentials returns credentials for c.f.key(). It returns the cached
// entry when get accepts it; get rejects entries that are missing or within
// five minutes of expiry. Otherwise it returns fresh credentials from c.f,
// which it writes to the cache first. A cache-write failure is returned as
// an error.
func (c cache) fetchCredentials() (*result, error) {
	key := c.f.key()

	cached := &result{}
	err := get(key, cached)
	if err == nil {
		debugf("have cached credentials for %s", key)
		return cached, nil
	}
	if err != errNotFound {
		return nil, errors.Wrap(err, "error reading cached credentials for role")
	}

	debugf("no cached credentials for %s, fetching...", key)
	fresh, err := c.f.fetchCredentials()
	if err != nil {
		return nil, errors.Wrap(err, "error fetching credentials")
	}
	// Check put's own error. Previously it was discarded and a stale err was
	// re-checked instead.
	if err := put(key, fresh); err != nil {
		return nil, errors.Wrap(err, "error caching credentials")
	}
	return fresh, nil
}
// put marshals val to JSON and writes it to the cache file for key (the path
// comes from arnFilename). A new file is created with mode 0600. Per
// os.WriteFile, an existing file is truncated and keeps its permissions.
func put(key string, val interface{}) error {
	debugf("writing credentials for %s to cache", key)
	encoded, err := json.Marshal(val)
	if err != nil {
		return errors.Wrap(err, "error serialising credentials to json")
	}
	if err := os.WriteFile(arnFilename(key), encoded, 0600); err != nil {
		return errors.Wrapf(err, "error writing credentials to file %q", key)
	}
	return nil
}
// errNotFound is returned by get when there is no usable cached value for a
// key. Either the cache file does not exist, or it held a value that had
// expired or was about to.
var errNotFound = errors.New("no value for key")

// get loads the cached JSON value for key into val, which must be a
// non-nil pointer.
//
// If val implements expiring, an entry whose Expiry is nil or falls within
// the next five minutes is deleted and reported as errNotFound, so callers
// refresh it before it lapses. Values that do not implement expiring are
// returned as-is. Other read and decode failures are returned wrapped.
func get(key string, val interface{}) error {
	debugf("fetching cached credentials for %s...", key)
	data, err := os.ReadFile(arnFilename(key))
	if err != nil {
		if os.IsNotExist(err) {
			return errNotFound
		}
		return errors.Wrapf(err, "error reading credential cache file %q", key)
	}
	// val is already a pointer, so pass it directly. Decoding into &val goes
	// through an extra interface indirection, and a JSON null would then nil
	// out the local interface instead of leaving *val alone.
	if err := json.Unmarshal(data, val); err != nil {
		return errors.Wrapf(err, "error decoding credential cache from file %q", key)
	}

	// Check ok before calling Expiry. Previously Expiry was called on a nil
	// interface whenever val did not implement expiring, which panicked.
	exp, ok := val.(expiring)
	if !ok {
		return nil
	}
	expiry := exp.Expiry()
	debugf("credentials expire at %v", expiry)
	if expiry == nil || expiry.Add(-5*time.Minute).Before(time.Now()) {
		debug("credentials nil or expiring in less than 5 minutes")
		// Best-effort removal. The caller refetches either way.
		_ = del(key)
		return errNotFound
	}
	return nil
}
// del removes the cache file for key and returns os.Remove's error unchanged.
// That includes a not-exist error when nothing was cached under key.
func del(key string) error {
	target := arnFilename(key)
	return os.Remove(target)
}