From ac339205fb8384dbc7991c02ecf23b0f14ef98f3 Mon Sep 17 00:00:00 2001 From: Janusz Marcinkiewicz Date: Tue, 23 Apr 2024 12:24:17 -0400 Subject: [PATCH] s3 multipart upload: read and send next part in parallel * use tee reader Signed-off-by: Alex Aizman --- ais/backend/awsmpt.go | 33 +++++++++++++++++---------------- ais/backend/mock_aws.go | 8 ++++---- ais/tgts3mpt.go | 31 ++++++++++++++++++------------- 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/ais/backend/awsmpt.go b/ais/backend/awsmpt.go index 7269075e35a..7e3b8bd437b 100644 --- a/ais/backend/awsmpt.go +++ b/ais/backend/awsmpt.go @@ -13,7 +13,6 @@ import ( "io" "net/http" "net/url" - "os" aiss3 "github.com/NVIDIA/aistore/ais/s3" "github.com/NVIDIA/aistore/cmn" @@ -39,7 +38,8 @@ func StartMpt(lom *core.LOM, oreq *http.Request, oq url.Values) (id string, ecod resp, err := pts.Do(core.T.DataClient()) if err != nil { return "", resp.StatusCode, err - } else if resp != nil { + } + if resp != nil { result, err := decodeXML[aiss3.InitiateMptUploadResult](resp.Body) if err != nil { return "", http.StatusBadRequest, err @@ -68,14 +68,15 @@ func StartMpt(lom *core.LOM, oreq *http.Request, oq url.Values) (id string, ecod return id, ecode, err } -func PutMptPart(lom *core.LOM, fh *os.File, oreq *http.Request, uploadID string, partNum int32, size int64) (etag string, ecode int, _ error) { +func PutMptPart(lom *core.LOM, r io.ReadCloser, oreq *http.Request, oq url.Values, uploadID string, size int64, partNum int32) (etag string, + ecode int, _ error) { if lom.IsFeatureSet(feat.PresignedS3Req) && oreq != nil { - q := oreq.URL.Query() // TODO: optimize-out - pts := aiss3.NewPresignedReq(oreq, lom, fh, q) + pts := aiss3.NewPresignedReq(oreq, lom, r, oq) resp, err := pts.Do(core.T.DataClient()) if err != nil { return "", resp.StatusCode, err - } else if resp != nil { + } + if resp != nil { ecode = resp.StatusCode etag = cmn.UnquoteCEV(resp.Header.Get(cos.HdrETag)) return @@ -87,7 +88,7 @@ func PutMptPart(lom *core.LOM, fh *os.File, oreq *http.Request, uploadID string, input = s3.UploadPartInput{ Bucket: aws.String(cloudBck.Name), Key: aws.String(lom.ObjName), - Body: fh, + Body: r, UploadId: aws.String(uploadID), PartNumber: &partNum, ContentLength: &size, @@ -108,19 +109,19 @@ func PutMptPart(lom *core.LOM, fh *os.File, oreq *http.Request, uploadID string, return etag, ecode, err } -func CompleteMpt(lom *core.LOM, oreq *http.Request, uploadID string, parts *aiss3.CompleteMptUpload) (etag string, ecode int, _ error) { +func CompleteMpt(lom *core.LOM, oreq *http.Request, oq url.Values, uploadID string, parts *aiss3.CompleteMptUpload) (etag string, + ecode int, _ error) { if lom.IsFeatureSet(feat.PresignedS3Req) && oreq != nil { - q := oreq.URL.Query() // TODO: optimize-out - body, err := xml.Marshal(parts) if err != nil { return "", http.StatusBadRequest, err } - pts := aiss3.NewPresignedReq(oreq, lom, io.NopCloser(bytes.NewReader(body)), q) + pts := aiss3.NewPresignedReq(oreq, lom, io.NopCloser(bytes.NewReader(body)), oq) resp, err := pts.Do(core.T.DataClient()) if err != nil { return "", resp.StatusCode, err - } else if resp != nil { + } + if resp != nil { result, err := decodeXML[aiss3.CompleteMptUploadResult](resp.Body) if err != nil { return "", http.StatusBadRequest, err @@ -164,14 +165,14 @@ func CompleteMpt(lom *core.LOM, oreq *http.Request, uploadID string, parts *aiss return etag, ecode, err } -func AbortMpt(lom *core.LOM, oreq *http.Request, uploadID string) (ecode int, err error) { +func AbortMpt(lom *core.LOM, oreq *http.Request, oq url.Values, uploadID string) (ecode int, err error) { if lom.IsFeatureSet(feat.PresignedS3Req) && oreq != nil { - q := oreq.URL.Query() // TODO: optimize-out - pts := aiss3.NewPresignedReq(oreq, lom, oreq.Body, q) + pts := aiss3.NewPresignedReq(oreq, lom, oreq.Body, oq) resp, err := pts.Do(core.T.DataClient()) if err != nil { return resp.StatusCode, err - } else if resp != nil { + } + if resp != nil { return resp.StatusCode, nil } } diff --git a/ais/backend/mock_aws.go b/ais/backend/mock_aws.go index 2630d57136f..ba5f35957cf 100644 --- a/ais/backend/mock_aws.go +++ b/ais/backend/mock_aws.go @@ -7,9 +7,9 @@ package backend import ( + "io" "net/http" "net/url" - "os" s3types "github.com/NVIDIA/aistore/ais/s3" "github.com/NVIDIA/aistore/api/apc" @@ -25,14 +25,14 @@ func StartMpt(*core.LOM, *http.Request, url.Values) (string, int, error) { return "", http.StatusBadRequest, cmn.NewErrUnsupp("start-mpt", mock) } -func PutMptPart(*core.LOM, *os.File, *http.Request, string, int32, int64) (string, int, error) { +func PutMptPart(*core.LOM, io.ReadCloser, *http.Request, url.Values, string, int64, int32) (string, int, error) { return "", http.StatusBadRequest, cmn.NewErrUnsupp("put-mpt-part", mock) } -func CompleteMpt(*core.LOM, *http.Request, string, *s3types.CompleteMptUpload) (string, int, error) { +func CompleteMpt(*core.LOM, *http.Request, url.Values, string, *s3types.CompleteMptUpload) (string, int, error) { return "", http.StatusBadRequest, cmn.NewErrUnsupp("complete-part", mock) } -func AbortMpt(*core.LOM, *http.Request, string) (int, error) { +func AbortMpt(*core.LOM, *http.Request, url.Values, string) (int, error) { return http.StatusBadRequest, cmn.NewErrUnsupp("abort-mpt", mock) } diff --git a/ais/tgts3mpt.go b/ais/tgts3mpt.go index 073d9a7b4e4..2d6953ccada 100644 --- a/ais/tgts3mpt.go +++ b/ais/tgts3mpt.go @@ -128,13 +128,12 @@ func (t *target) putMptPart(w http.ResponseWriter, r *http.Request, items []stri return } - // 3. write var ( etag string + size int64 ecode int partSHA = r.Header.Get(cos.S3HdrContentSHA256) checkPartSHA = partSHA != "" && partSHA != cos.S3UnsignedPayload - buf, slab = t.gmm.Alloc() cksumSHA = &cos.CksumHash{} cksumMD5 = &cos.CksumHash{} remote = bck.IsRemoteS3() @@ -145,15 +144,22 @@ func (t *target) putMptPart(w http.ResponseWriter, r *http.Request, items []stri if !remote { cksumMD5 = cos.NewCksumHash(cos.ChecksumMD5) } + + // 3. write mw := multiWriter(cksumMD5.H, cksumSHA.H, partFh) - size, err := io.CopyBuffer(mw, r.Body, buf) - slab.Free(buf) - // 4. rewind and call s3 API - if err == nil && remote { - if _, err = partFh.Seek(0, io.SeekStart); err == nil { - etag, ecode, err = backend.PutMptPart(lom, partFh, r, uploadID, partNum, size) - } + if !remote { + // write locally + buf, slab := t.gmm.Alloc() + size, err = io.CopyBuffer(mw, r.Body, buf) + slab.Free(buf) + } else { + // write locally and utilize TeeReader to simultaneously send data to S3 + tr := io.NopCloser(io.TeeReader(r.Body, mw)) + size = r.ContentLength + debug.Assert(size > 0, "mpt upload: expecting positive content-length") + + etag, ecode, err = backend.PutMptPart(lom, tr, r, q, uploadID, size, partNum) } cos.Close(partFh) @@ -165,8 +171,7 @@ func (t *target) putMptPart(w http.ResponseWriter, r *http.Request, items []stri return } - // 5. finalize part - // expecting the part's remote etag to be md5 checksum, not computing otherwise + // 4. finalize the part (expecting the part's remote etag to be md5 checksum) md5 := etag if cksumMD5.H != nil { debug.Assert(etag == "") @@ -243,7 +248,7 @@ func (t *target) completeMpt(w http.ResponseWriter, r *http.Request, items []str remote = bck.IsRemoteS3() ) if remote { - v, ecode, err := backend.CompleteMpt(lom, r, uploadID, partList) + v, ecode, err := backend.CompleteMpt(lom, r, q, uploadID, partList) if err != nil { s3.WriteMptErr(w, r, err, ecode, lom, uploadID) return @@ -398,7 +403,7 @@ func (t *target) abortMpt(w http.ResponseWriter, r *http.Request, items []string uploadID := q.Get(s3.QparamMptUploadID) if bck.IsRemoteS3() { - ecode, err := backend.AbortMpt(lom, r, uploadID) + ecode, err := backend.AbortMpt(lom, r, q, uploadID) if err != nil { s3.WriteErr(w, r, err, ecode) return