Skip to content

Commit

Permalink
s3 multipart upload: read and send next part in parallel
Browse files: browse the repository at this point in the history
* use tee reader

Signed-off-by: Alex Aizman <[email protected]>
  • Loading branch information
VirrageS authored and alex-aizman committed Apr 23, 2024
1 parent ebc957e commit ac33920
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 33 deletions.
33 changes: 17 additions & 16 deletions ais/backend/awsmpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"io"
"net/http"
"net/url"
"os"

aiss3 "github.com/NVIDIA/aistore/ais/s3"
"github.com/NVIDIA/aistore/cmn"
Expand All @@ -39,7 +38,8 @@ func StartMpt(lom *core.LOM, oreq *http.Request, oq url.Values) (id string, ecod
resp, err := pts.Do(core.T.DataClient())
if err != nil {
return "", resp.StatusCode, err
} else if resp != nil {
}
if resp != nil {
result, err := decodeXML[aiss3.InitiateMptUploadResult](resp.Body)
if err != nil {
return "", http.StatusBadRequest, err
Expand Down Expand Up @@ -68,14 +68,15 @@ func StartMpt(lom *core.LOM, oreq *http.Request, oq url.Values) (id string, ecod
return id, ecode, err
}

func PutMptPart(lom *core.LOM, fh *os.File, oreq *http.Request, uploadID string, partNum int32, size int64) (etag string, ecode int, _ error) {
func PutMptPart(lom *core.LOM, r io.ReadCloser, oreq *http.Request, oq url.Values, uploadID string, size int64, partNum int32) (etag string,
ecode int, _ error) {
if lom.IsFeatureSet(feat.PresignedS3Req) && oreq != nil {
q := oreq.URL.Query() // TODO: optimize-out
pts := aiss3.NewPresignedReq(oreq, lom, fh, q)
pts := aiss3.NewPresignedReq(oreq, lom, r, oq)
resp, err := pts.Do(core.T.DataClient())
if err != nil {
return "", resp.StatusCode, err
} else if resp != nil {
}
if resp != nil {
ecode = resp.StatusCode
etag = cmn.UnquoteCEV(resp.Header.Get(cos.HdrETag))
return
Expand All @@ -87,7 +88,7 @@ func PutMptPart(lom *core.LOM, fh *os.File, oreq *http.Request, uploadID string,
input = s3.UploadPartInput{
Bucket: aws.String(cloudBck.Name),
Key: aws.String(lom.ObjName),
Body: fh,
Body: r,
UploadId: aws.String(uploadID),
PartNumber: &partNum,
ContentLength: &size,
Expand All @@ -108,19 +109,19 @@ func PutMptPart(lom *core.LOM, fh *os.File, oreq *http.Request, uploadID string,
return etag, ecode, err
}

func CompleteMpt(lom *core.LOM, oreq *http.Request, uploadID string, parts *aiss3.CompleteMptUpload) (etag string, ecode int, _ error) {
func CompleteMpt(lom *core.LOM, oreq *http.Request, oq url.Values, uploadID string, parts *aiss3.CompleteMptUpload) (etag string,
ecode int, _ error) {
if lom.IsFeatureSet(feat.PresignedS3Req) && oreq != nil {
q := oreq.URL.Query() // TODO: optimize-out

body, err := xml.Marshal(parts)
if err != nil {
return "", http.StatusBadRequest, err
}
pts := aiss3.NewPresignedReq(oreq, lom, io.NopCloser(bytes.NewReader(body)), q)
pts := aiss3.NewPresignedReq(oreq, lom, io.NopCloser(bytes.NewReader(body)), oq)
resp, err := pts.Do(core.T.DataClient())
if err != nil {
return "", resp.StatusCode, err
} else if resp != nil {
}
if resp != nil {
result, err := decodeXML[aiss3.CompleteMptUploadResult](resp.Body)
if err != nil {
return "", http.StatusBadRequest, err
Expand Down Expand Up @@ -164,14 +165,14 @@ func CompleteMpt(lom *core.LOM, oreq *http.Request, uploadID string, parts *aiss
return etag, ecode, err
}

func AbortMpt(lom *core.LOM, oreq *http.Request, uploadID string) (ecode int, err error) {
func AbortMpt(lom *core.LOM, oreq *http.Request, oq url.Values, uploadID string) (ecode int, err error) {
if lom.IsFeatureSet(feat.PresignedS3Req) && oreq != nil {
q := oreq.URL.Query() // TODO: optimize-out
pts := aiss3.NewPresignedReq(oreq, lom, oreq.Body, q)
pts := aiss3.NewPresignedReq(oreq, lom, oreq.Body, oq)
resp, err := pts.Do(core.T.DataClient())
if err != nil {
return resp.StatusCode, err
} else if resp != nil {
}
if resp != nil {
return resp.StatusCode, nil
}
}
Expand Down
8 changes: 4 additions & 4 deletions ais/backend/mock_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
package backend

import (
"io"
"net/http"
"net/url"
"os"

s3types "github.com/NVIDIA/aistore/ais/s3"
"github.com/NVIDIA/aistore/api/apc"
Expand All @@ -25,14 +25,14 @@ func StartMpt(*core.LOM, *http.Request, url.Values) (string, int, error) {
return "", http.StatusBadRequest, cmn.NewErrUnsupp("start-mpt", mock)
}

func PutMptPart(*core.LOM, *os.File, *http.Request, string, int32, int64) (string, int, error) {
func PutMptPart(*core.LOM, io.ReadCloser, *http.Request, url.Values, string, int64, int32) (string, int, error) {
return "", http.StatusBadRequest, cmn.NewErrUnsupp("put-mpt-part", mock)
}

func CompleteMpt(*core.LOM, *http.Request, string, *s3types.CompleteMptUpload) (string, int, error) {
func CompleteMpt(*core.LOM, *http.Request, url.Values, string, *s3types.CompleteMptUpload) (string, int, error) {
return "", http.StatusBadRequest, cmn.NewErrUnsupp("complete-part", mock)
}

func AbortMpt(*core.LOM, *http.Request, string) (int, error) {
func AbortMpt(*core.LOM, *http.Request, url.Values, string) (int, error) {
return http.StatusBadRequest, cmn.NewErrUnsupp("abort-mpt", mock)
}
31 changes: 18 additions & 13 deletions ais/tgts3mpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,12 @@ func (t *target) putMptPart(w http.ResponseWriter, r *http.Request, items []stri
return
}

// 3. write
var (
etag string
size int64
ecode int
partSHA = r.Header.Get(cos.S3HdrContentSHA256)
checkPartSHA = partSHA != "" && partSHA != cos.S3UnsignedPayload
buf, slab = t.gmm.Alloc()
cksumSHA = &cos.CksumHash{}
cksumMD5 = &cos.CksumHash{}
remote = bck.IsRemoteS3()
Expand All @@ -145,15 +144,22 @@ func (t *target) putMptPart(w http.ResponseWriter, r *http.Request, items []stri
if !remote {
cksumMD5 = cos.NewCksumHash(cos.ChecksumMD5)
}

// 3. write
mw := multiWriter(cksumMD5.H, cksumSHA.H, partFh)
size, err := io.CopyBuffer(mw, r.Body, buf)
slab.Free(buf)

// 4. rewind and call s3 API
if err == nil && remote {
if _, err = partFh.Seek(0, io.SeekStart); err == nil {
etag, ecode, err = backend.PutMptPart(lom, partFh, r, uploadID, partNum, size)
}
if !remote {
// write locally
buf, slab := t.gmm.Alloc()
size, err = io.CopyBuffer(mw, r.Body, buf)
slab.Free(buf)
} else {
// write locally and utilize TeeReader to simultaneously send data to S3
tr := io.NopCloser(io.TeeReader(r.Body, mw))
size = r.ContentLength
debug.Assert(size > 0, "mpt upload: expecting positive content-length")

etag, ecode, err = backend.PutMptPart(lom, tr, r, q, uploadID, size, partNum)
}

cos.Close(partFh)
Expand All @@ -165,8 +171,7 @@ func (t *target) putMptPart(w http.ResponseWriter, r *http.Request, items []stri
return
}

// 5. finalize part
// expecting the part's remote etag to be md5 checksum, not computing otherwise
// 4. finalize the part (expecting the part's remote etag to be md5 checksum)
md5 := etag
if cksumMD5.H != nil {
debug.Assert(etag == "")
Expand Down Expand Up @@ -243,7 +248,7 @@ func (t *target) completeMpt(w http.ResponseWriter, r *http.Request, items []str
remote = bck.IsRemoteS3()
)
if remote {
v, ecode, err := backend.CompleteMpt(lom, r, uploadID, partList)
v, ecode, err := backend.CompleteMpt(lom, r, q, uploadID, partList)
if err != nil {
s3.WriteMptErr(w, r, err, ecode, lom, uploadID)
return
Expand Down Expand Up @@ -398,7 +403,7 @@ func (t *target) abortMpt(w http.ResponseWriter, r *http.Request, items []string
uploadID := q.Get(s3.QparamMptUploadID)

if bck.IsRemoteS3() {
ecode, err := backend.AbortMpt(lom, r, uploadID)
ecode, err := backend.AbortMpt(lom, r, q, uploadID)
if err != nil {
s3.WriteErr(w, r, err, ecode)
return
Expand Down

0 comments on commit ac33920

Please sign in to comment.