Skip to content

Commit

Permalink
require user hash
Browse files Browse the repository at this point in the history
cenkalti committed Sep 5, 2018
1 parent 41bc1fd commit 3af7950
Showing 6 changed files with 95 additions and 3 deletions.
2 changes: 2 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
@@ -13,6 +13,8 @@ type Config struct {
ShutdownTimeout uint
// MySQL database DSN.
MySQLDSN string
// Secret for signing user IDs.
Secret string
}

func NewConfig() (*Config, error) {
1 change: 1 addition & 0 deletions internal/pas/event.go
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ import (

type Event struct {
UserID UserID `json:"user_id"`
UserHash string `json:"user_hash"`
Timestamp *time.Time `json:"timestamp"`
Name EventName `json:"name"`
Properties []Property `json:"properties"`
29 changes: 28 additions & 1 deletion internal/pas/handler.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
package pas

import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"net/http"
)

type Handler struct {
http.Handler
analytics *Analytics
secret []byte
}

func NewHandler(analytics *Analytics) *Handler {
func NewHandler(analytics *Analytics, secret string) *Handler {
h := &Handler{
analytics: analytics,
secret: []byte(secret),
}
mux := http.NewServeMux()
mux.HandleFunc("/api/events", h.handleEvents)
@@ -34,6 +39,17 @@ func (s *Handler) handleEvents(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if len(s.secret) > 0 {
hash := hmac.New(sha256.New, s.secret)
for _, e := range events.Events {
hash.Write([]byte(e.UserID))
if hex.EncodeToString(hash.Sum(nil)) != e.UserHash {
http.Error(w, "invalid user_hash", http.StatusBadRequest)
return
}
hash.Reset()
}
}
_, err = s.analytics.InsertEvents(events.Events)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -53,6 +69,17 @@ func (s *Handler) handleUsers(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if len(s.secret) > 0 {
hash := hmac.New(sha256.New, s.secret)
for _, u := range users.Users {
hash.Write([]byte(u.ID))
if hex.EncodeToString(hash.Sum(nil)) != u.Hash {
http.Error(w, "invalid user_hash", http.StatusBadRequest)
return
}
hash.Reset()
}
}
_, err = s.analytics.UpdateUsers(users.Users)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
63 changes: 62 additions & 1 deletion internal/pas/handler_test.go
Original file line number Diff line number Diff line change
@@ -2,7 +2,11 @@ package pas_test

import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"database/sql"
"encoding/hex"
"fmt"
"log"
"net/http"
"net/http/httptest"
@@ -23,7 +27,7 @@ func init() {

analytics := pas.NewAnalytics(db)

handler = pas.NewHandler(analytics)
handler = pas.NewHandler(analytics, "")
}

func TestPostEvents(t *testing.T) {
@@ -67,3 +71,60 @@ func TestPostUsers(t *testing.T) {
status, http.StatusOK)
}
}

func TestUserHash(t *testing.T) {
const secret = "foobar"

db, err := sql.Open("mysql", localDSN)
if err != nil {
log.Fatal(err)
}
defer db.Close()

analytics := pas.NewAnalytics(db)

handler := pas.NewHandler(analytics, secret)

s0 := `{
"events": [
{"name": "test_done", "user_id": "1234", "user_hash": "%s", "timestamp": "2000-01-01T01:02:03Z", "properties": [
{"name": "foo", "value": "bar", "type": "string"}
]}]}
`

// Test invalid secret
s := fmt.Sprintf(s0, generateUserHash("1234", "invalid"))
var postBody = bytes.NewBufferString(s)
req, err := http.NewRequest("POST", "/api/events", postBody)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusBadRequest {
t.Log(rr.Body.String())
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}

// Test correct secret
s = fmt.Sprintf(s0, generateUserHash("1234", secret))
postBody = bytes.NewBufferString(s)
req, err = http.NewRequest("POST", "/api/events", postBody)
if err != nil {
t.Fatal(err)
}
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if status := rr.Code; status != http.StatusOK {
t.Log(rr.Body.String())
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
}

func generateUserHash(userID, secret string) string {
hash := hmac.New(sha256.New, []byte(secret))
hash.Write([]byte(userID))
return hex.EncodeToString(hash.Sum(nil))
}
1 change: 1 addition & 0 deletions internal/pas/user.go
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ import (

type User struct {
ID UserID `json:"id"`
Hash string `json:"hash"`
Properties []Property `json:"properties"`
}

2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
@@ -52,7 +52,7 @@ func main() {
}()

analytics := pas.NewAnalytics(db)
handler := pas.NewHandler(analytics)
handler := pas.NewHandler(analytics, config.Secret)
server := pas.NewServer(config.ListenAddress, handler)

go server.ListenAndServe()

0 comments on commit 3af7950

Please sign in to comment.