From 55ed2c8c64f11f417ee086f5d47cd8d922c9a0e6 Mon Sep 17 00:00:00 2001 From: isaac Date: Mon, 20 May 2019 12:52:15 +0900 Subject: [PATCH] add support for multi param and multi header adds support for multiple parameter for requests adds support for multiple headers in request and response adds tests --- request.go | 10 +++++++++ request_test.go | 55 ++++++++++++++++++++++++++++++++++++++++++++++++ response.go | 8 +++++-- response_test.go | 13 ++++++++++++ 4 files changed, 84 insertions(+), 2 deletions(-) diff --git a/request.go b/request.go index e4ca424..427d41e 100644 --- a/request.go +++ b/request.go @@ -26,6 +26,12 @@ func NewRequest(ctx context.Context, e events.APIGatewayProxyRequest) (*http.Req for k, v := range e.QueryStringParameters { q.Set(k, v) } + + for k, values := range e.MultiValueQueryStringParameters { + for _, v := range values { + q.Add(k, v) + } + } u.RawQuery = q.Encode() // base64 encoded body @@ -52,6 +58,10 @@ func NewRequest(ctx context.Context, e events.APIGatewayProxyRequest) (*http.Req req.Header.Set(k, v) } + for k, values := range e.MultiValueHeaders { + req.Header[k] = values + } + // content-length if req.Header.Get("Content-Length") == "" && body != "" { req.Header.Set("Content-Length", strconv.Itoa(len(body))) diff --git a/request_test.go b/request_test.go index 91db508..582a24f 100644 --- a/request_test.go +++ b/request_test.go @@ -51,6 +51,28 @@ func TestNewRequest_queryString(t *testing.T) { assert.Equal(t, `desc`, r.URL.Query().Get("order")) } +func TestNewRequest_multiValueQueryString(t *testing.T) { + e := events.APIGatewayProxyRequest{ + HTTPMethod: "GET", + Path: "/pets", + MultiValueQueryStringParameters: map[string][]string{ + "multi_fields": []string{"name", "species"}, + "multi_arr[]": []string{"arr1", "arr2"}, + }, + QueryStringParameters: map[string]string{ + "order": "desc", + "fields": "name,species", + }, + } + + r, err := NewRequest(context.Background(), e) + assert.NoError(t, err) + + assert.Equal(t, `/pets?fields=name%2Cspecies&multi_arr%5B%5D=arr1&multi_arr%5B%5D=arr2&multi_fields=name&multi_fields=species&order=desc`, r.URL.String()) + assert.Equal(t, []string{"name", "species"}, r.URL.Query()["multi_fields"]) + assert.Equal(t, []string{"arr1", "arr2"}, r.URL.Query()["multi_arr[]"]) +} + func TestNewRequest_remoteAddr(t *testing.T) { e := events.APIGatewayProxyRequest{ HTTPMethod: "GET", @@ -95,6 +117,39 @@ func TestNewRequest_header(t *testing.T) { assert.Equal(t, `bar`, r.Header.Get("X-Foo")) } +func TestNewRequest_multiHeader(t *testing.T) { + e := events.APIGatewayProxyRequest{ + HTTPMethod: "POST", + Path: "/pets", + Body: `{ "name": "Tobi" }`, + MultiValueHeaders: map[string][]string{ + "X-APEX": []string{"apex1", "apex2"}, + "X-APEX-2": []string{"apex-1", "apex-2"}, + }, + Headers: map[string]string{ + "Content-Type": "application/json", + "X-Foo": "bar", + "Host": "example.com", + }, + RequestContext: events.APIGatewayProxyRequestContext{ + RequestID: "1234", + Stage: "prod", + }, + } + + r, err := NewRequest(context.Background(), e) + assert.NoError(t, err) + + assert.Equal(t, `example.com`, r.Host) + assert.Equal(t, `prod`, r.Header.Get("X-Stage")) + assert.Equal(t, `1234`, r.Header.Get("X-Request-Id")) + assert.Equal(t, `18`, r.Header.Get("Content-Length")) + assert.Equal(t, `application/json`, r.Header.Get("Content-Type")) + assert.Equal(t, `bar`, r.Header.Get("X-Foo")) + assert.Equal(t, []string{"apex1", "apex2"}, r.Header["X-APEX"]) + assert.Equal(t, []string{"apex-1", "apex-2"}, r.Header["X-APEX-2"]) +} + func TestNewRequest_body(t *testing.T) { e := events.APIGatewayProxyRequest{ HTTPMethod: "POST", diff --git a/response.go b/response.go index a79fa35..8f8b344 100644 --- a/response.go +++ b/response.go @@ -58,14 +58,18 @@ func (w *ResponseWriter) WriteHeader(status int) { w.out.StatusCode = status h := make(map[string]string) + mvh := make(map[string][]string) for k, v := range w.Header() { - if len(v) > 0 { - h[k] = v[len(v)-1] + if len(v) == 1 { + h[k] = v[0] + } else if len(v) > 1 { + mvh[k] = v } } w.out.Headers = h + w.out.MultiValueHeaders = mvh w.wroteHeader = true } diff --git a/response_test.go b/response_test.go index a644a72..47e1357 100644 --- a/response_test.go +++ b/response_test.go @@ -30,6 +30,19 @@ func TestResponseWriter_Header(t *testing.T) { assert.Equal(t, "Bar: baz\r\nFoo: bar\r\n", buf.String()) } +func TestResponseWriter_multiHeader(t *testing.T) { + w := NewResponse() + w.Header().Set("Foo", "bar") + w.Header().Set("Bar", "baz") + w.Header().Add("X-APEX", "apex1") + w.Header().Add("X-APEX", "apex2") + + var buf bytes.Buffer + w.header.Write(&buf) + + assert.Equal(t, "Bar: baz\r\nFoo: bar\r\nX-Apex: apex1\r\nX-Apex: apex2\r\n", buf.String()) +} + func TestResponseWriter_Write_text(t *testing.T) { types := []string{ "text/x-custom",