-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathmain.go
140 lines (112 loc) · 3.34 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package main
import (
	"context"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/coreos/go-oidc"
	"github.com/gorilla/handlers"
	"github.com/julienschmidt/httprouter"
	flag "github.com/spf13/pflag"
)
// Claims is the subset of the access token's JSON payload that this service
// reads. authHandler decodes the verified token's claims into it to decide
// which identity to report in X-Auth-User.
type Claims struct {
	// Email is present when the token belongs to an authenticated person.
	Email string `json:"email"`

	// CommonName is used to identify service tokens, which carry no email.
	CommonName string `json:"common_name"`
}
// getEnv returns the value of the environment variable named key, or fallback
// when that variable is not set at all. A variable that is set to the empty
// string yields "", not fallback.
func getEnv(key, fallback string) string {
	value, present := os.LookupEnv(key)
	if !present {
		return fallback
	}
	return value
}
// Service configuration. The values below are the built-in defaults; init
// overrides them first from environment variables and then from command-line
// flags.
var (
	// authDomain is the Access authentication domain, e.g.
	// https://foo.cloudflareaccess.com. It has no default and must be supplied.
	authDomain string
	// address is the host/IP to listen on; empty means all interfaces.
	address string
	// port is the TCP port to listen on.
	port = 8080

	// keySet is the remote key set (built by init from the auth domain's
	// certs endpoint) used to verify JWT signatures.
	keySet oidc.KeySet
)
// init resolves the service configuration and exits the process with status 1
// if it is unusable. Precedence, lowest to highest: the defaults in the
// package var block, the AUTH_DOMAIN / LISTEN_ADDRESS / LISTEN_PORT
// environment variables, then the --auth-domain / --address / --port flags.
// Finally it creates the remote key set that authHandler verifies tokens
// against.
//
// NOTE(review): parsing flags in init means flag.Parse runs before main and in
// any test binary linking this package; consider moving this into main.
func init() {
	authDomain = getEnv("AUTH_DOMAIN", authDomain)
	address = getEnv("LISTEN_ADDRESS", address)

	// A malformed LISTEN_PORT used to be silently turned into 0 (or a huge
	// value on overflow). Remember the parse error instead and report it below
	// unless --port overrides the environment.
	var envPortErr error
	if raw, ok := os.LookupEnv("LISTEN_PORT"); ok {
		p, err := strconv.Atoi(raw)
		if err != nil {
			envPortErr = fmt.Errorf("invalid LISTEN_PORT %q: %w", raw, err)
		} else {
			port = p
		}
	}

	// parse flags. pflag appends "(default ...)" to the usage text itself, so
	// the descriptions must not repeat the default.
	flag.StringVar(&authDomain, "auth-domain", authDomain, "authentication domain (https://foo.cloudflareaccess.com)")
	flag.IntVar(&port, "port", port, "http port to listen on")
	flag.StringVar(&address, "address", address, "http address to listen on (leave empty to listen on all interfaces)")
	flag.Parse()

	// A trailing slash would produce ".../" + "/cdn-cgi/..." for the certs URL
	// and an issuer string (passed to oidc.NewVerifier by authHandler) that
	// presumably no longer matches the token's "iss" claim.
	authDomain = strings.TrimRight(authDomain, "/")

	// --auth-domain is required
	if authDomain == "" {
		fmt.Fprintln(os.Stderr, "ERROR: Please set --auth-domain to the authorization domain you configured on cloudflare. Should be like `https://foo.cloudflareaccess.com`")
		flag.Usage()
		os.Exit(1)
	}
	if envPortErr != nil && !flag.CommandLine.Changed("port") {
		fmt.Fprintf(os.Stderr, "ERROR: %s\n", envPortErr)
		flag.Usage()
		os.Exit(1)
	}
	// TCP ports are 16-bit.
	if port <= 0 || port > 65535 {
		fmt.Fprintf(os.Stderr, "ERROR: Invalid port number %d \n", port)
		flag.Usage()
		os.Exit(1)
	}

	// configure keyset; there is no request in scope here, so the
	// top-level Background context is the appropriate one.
	certsURL := authDomain + "/cdn-cgi/access/certs"
	keySet = oidc.NewRemoteKeySet(context.Background(), certsURL)
}
// main serves the forward-auth endpoint GET /auth/:audience on the configured
// address and port. Every request is logged to stdout via
// handlers.LoggingHandler. main only returns by terminating the process (via
// log.Fatalln) when the server stops with an error.
func main() {
	log.Printf("Authentication domain: %s\n", authDomain)

	// set up routes
	router := httprouter.New()
	router.GET("/auth/:audience", authHandler)

	// JoinHostPort brackets IPv6 literals (e.g. "::1" -> "[::1]:8080"); a
	// plain "%s:%d" format yields "::1:8080", which cannot be listened on.
	addr := net.JoinHostPort(address, strconv.Itoa(port))

	// Bound how long a client may take to send a request or keep an idle
	// connection open, so slow or stalled clients cannot hold connections
	// indefinitely. WriteTimeout leaves headroom for authHandler's token
	// verification, which may presumably need to fetch signing keys from the
	// auth domain's certs endpoint.
	srv := &http.Server{
		Addr:              addr,
		Handler:           handlers.LoggingHandler(os.Stdout, router),
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       120 * time.Second,
	}

	// listen
	log.Printf("Listening on %s", addr)
	log.Fatalln(srv.ListenAndServe())
}
// authHandler checks the JWT in the Cf-Access-Jwt-Assertion request header
// against the audience given in the URL path (/auth/:audience), using
// authDomain as the expected issuer and keySet for signature keys.
//
// Responses:
//   - 401 "No token on the request" when the header is missing or empty.
//   - 401 "Invalid token: <err>" when verifier.Verify rejects the token.
//   - 401 "Invalid claims: <err>" when the payload cannot be decoded into Claims.
//   - 200 "OK!" otherwise, with X-Auth-User set to the email claim, else the
//     common_name claim, else the empty string.
func authHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
	// Only the header is consulted; the token could also be read from the
	// CF_Authorization cookie.
	token := r.Header.Get("Cf-Access-Jwt-Assertion")
	if len(token) == 0 {
		write(w, http.StatusUnauthorized, "No token on the request")
		return
	}

	// The expected audience comes from the request path rather than from
	// global configuration.
	verifier := oidc.NewVerifier(authDomain, keySet, &oidc.Config{
		ClientID: ps.ByName("audience"),
	})

	idToken, verifyErr := verifier.Verify(r.Context(), token)
	if verifyErr != nil {
		write(w, http.StatusUnauthorized, fmt.Sprintf("Invalid token: %s", verifyErr))
		return
	}

	var claims Claims
	if claimsErr := idToken.Claims(&claims); claimsErr != nil {
		write(w, http.StatusUnauthorized, fmt.Sprintf("Invalid claims: %s", claimsErr))
		return
	}

	user := ""
	switch {
	case claims.Email != "": // an authenticated person
		user = claims.Email
	case claims.CommonName != "": // a service token
		user = claims.CommonName
	}

	w.Header().Set("X-Auth-User", user)
	write(w, http.StatusOK, "OK!")
}
// write sends status followed by body as the HTTP response. A failure while
// writing the body is only logged: the status has already been written, so
// there is nothing further the caller could do.
func write(w http.ResponseWriter, status int, body string) {
	w.WriteHeader(status)
	if _, writeErr := w.Write([]byte(body)); writeErr != nil {
		log.Printf("Error writing body: %s\n", writeErr)
	}
}