Skip to content

Commit

Permalink
Merge pull request #3616 from aws/dongie/buffer-frominputstream
Browse files Browse the repository at this point in the history
Buffer if necessary in fromInputStream
  • Loading branch information
dagnir authored Jan 30, 2025
2 parents 6127626 + 7ebc71b commit 2966950
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,20 +123,28 @@ public static RequestBody fromFile(File file) {
* To support resetting via {@link ContentStreamProvider}, this uses {@link InputStream#reset()} and uses a read limit of
* 128 KiB. If you need more control, use {@link #fromContentProvider(ContentStreamProvider, long, String)} or
* {@link #fromContentProvider(ContentStreamProvider, String)}.
* <p>
* <b>Important:</b> If {@code inputStream} does not support mark and reset, the stream will be buffered.
*
* @param inputStream Input stream to send to the service. The stream will not be closed by the SDK.
* @param contentLength Content length of data in input stream.
* @return RequestBody instance.
*/
public static RequestBody fromInputStream(InputStream inputStream, long contentLength) {
// NOTE: does not have an effect if mark not supported
IoUtils.markStreamWithMaxReadLimit(inputStream);
InputStream nonCloseable = nonCloseableInputStream(inputStream);
ContentStreamProvider provider = () -> {
if (nonCloseable.markSupported()) {
ContentStreamProvider provider;
if (nonCloseable.markSupported()) {
// stream supports mark + reset
provider = () -> {
invokeSafely(nonCloseable::reset);
}
return nonCloseable;
};
return nonCloseable;
};
} else {
// stream doesn't support mark + reset, make sure to buffer it
provider = new BufferingContentStreamProvider(() -> nonCloseable, contentLength);
}
return new RequestBody(provider, contentLength, Mimetype.MIMETYPE_OCTET_STREAM);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,22 @@
import java.nio.file.FileSystem;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Random;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.mockito.Mockito;
import software.amazon.awssdk.checksums.DefaultChecksumAlgorithm;
import software.amazon.awssdk.checksums.SdkChecksum;
import software.amazon.awssdk.core.internal.sync.BufferingContentStreamProvider;
import software.amazon.awssdk.core.internal.util.Mimetype;
import software.amazon.awssdk.utils.BinaryUtils;
import software.amazon.awssdk.utils.IoUtils;
import software.amazon.awssdk.utils.StringInputStream;


public class RequestBodyTest {
private static final SdkChecksum CRC32 = SdkChecksum.forAlgorithm(DefaultChecksumAlgorithm.CRC32);

@Rule
public TemporaryFolder folder = new TemporaryFolder();
Expand Down Expand Up @@ -140,4 +147,58 @@ public void remainingByteBufferConstructorOnlyRemainingBytesCopied() throws IOEx
byte[] requestBodyBytes = IoUtils.toByteArray(requestBody.contentStreamProvider().newStream());
assertThat(ByteBuffer.wrap(requestBodyBytes)).isEqualTo(bb);
}

@Test
public void fromInputStream_streamSupportMarkReset_doesNotBuffer() {
byte[] newData = new byte[16536];
new Random().nextBytes(newData);

ByteArrayInputStream stream = new ByteArrayInputStream(newData);

RequestBody requestBody = RequestBody.fromInputStream(stream, newData.length);
assertThat(requestBody.contentStreamProvider()).isNotInstanceOf(BufferingContentStreamProvider.class);
}

@Test
public void fromInputStream_streamDoesNotSupportMarkReset_buffers() {
byte[] newData = new byte[16536];
new Random().nextBytes(newData);

ByteArrayInputStream stream = Mockito.spy(new ByteArrayInputStream(newData));
Mockito.when(stream.markSupported()).thenReturn(false);

RequestBody requestBody = RequestBody.fromInputStream(stream, newData.length);
assertThat(requestBody.contentStreamProvider()).isInstanceOf(BufferingContentStreamProvider.class);
}

@Test
public void fromInputStream_streamSupportsReset_resetsTheStream() {
byte[] newData = new byte[16536];
new Random().nextBytes(newData);

String streamCrc32 = getCrc32(new ByteArrayInputStream(newData));

ByteArrayInputStream stream = new ByteArrayInputStream(newData);
assertThat(stream.markSupported()).isTrue();
RequestBody requestBody = RequestBody.fromInputStream(stream, newData.length);

assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(streamCrc32);
assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(streamCrc32);
}

private static String getCrc32(InputStream inputStream) {
byte[] buff = new byte[1024];
int read;

CRC32.reset();
try {
while ((read = inputStream.read(buff)) != -1) {
CRC32.update(buff, 0, read);
}
} catch (IOException e) {
throw new RuntimeException(e);
}

return BinaryUtils.toHex(CRC32.getChecksumBytes());
}
}

0 comments on commit 2966950

Please sign in to comment.