diff --git a/src/com/amazon/corretto/crypto/provider/AccessibleByteArrayOutputStream.java b/src/com/amazon/corretto/crypto/provider/AccessibleByteArrayOutputStream.java index 0a4124ba..61243af8 100644 --- a/src/com/amazon/corretto/crypto/provider/AccessibleByteArrayOutputStream.java +++ b/src/com/amazon/corretto/crypto/provider/AccessibleByteArrayOutputStream.java @@ -8,6 +8,7 @@ class AccessibleByteArrayOutputStream extends OutputStream implements Cloneable { private final int limit; + private final BufferShrinkStrategy shrinkStrategy; private byte[] buf; private int count; @@ -19,6 +20,10 @@ class AccessibleByteArrayOutputStream extends OutputStream implements Cloneable } AccessibleByteArrayOutputStream(final int capacity, final int limit) { + this(capacity, limit, new BufferShrinkStrategy.BasicThreshold()); + } + + AccessibleByteArrayOutputStream(final int capacity, final int limit, BufferShrinkStrategy shrinkStrategy) { if (limit < 0) { throw new IllegalArgumentException("Limit must be non-negative"); } @@ -27,6 +32,7 @@ class AccessibleByteArrayOutputStream extends OutputStream implements Cloneable } buf = capacity == 0 ? Utils.EMPTY_ARRAY : new byte[capacity]; this.limit = limit; + this.shrinkStrategy = shrinkStrategy; count = 0; } @@ -83,11 +89,14 @@ byte[] getDataBuffer() { } void reset() { - Arrays.fill(buf, 0, count, (byte) 0); + int sizeUsed = count; + Arrays.fill(buf, 0, sizeUsed, (byte) 0); count = 0; - // TODO: Consider keeping track of length at reset. - // If it is consistently below the maximum value we may want to trim - // down to save on memory. + + if (shrinkStrategy.shouldShrink(sizeUsed, buf.length)) { + // Shrink the buffer. + buf = new byte[buf.length / 2]; + } } void write(final ByteBuffer bbuff) { diff --git a/src/com/amazon/corretto/crypto/provider/BufferShrinkStrategy.java b/src/com/amazon/corretto/crypto/provider/BufferShrinkStrategy.java new file mode 100644 index 00000000..aae3059c --- /dev/null +++ b/src/com/amazon/corretto/crypto/provider/BufferShrinkStrategy.java @@ -0,0 +1,103 @@ +package com.amazon.corretto.crypto.provider; + +/** + * Represents a strategy for downsizing/shrinking a reusable buffer. + * For example, {@link AccessibleByteArrayOutputStream} holds a reusable buffer which grows as needed + * to accommodate different sized payloads. If this buffer never shrinks, it may waste space (for example + * if we see one rare very large payload but subsequent payloads are small, the buffer will remain unnecessarily + * large). + * Note: implementations are usually stateful and thus instances cannot be safely shared. + */ +public interface BufferShrinkStrategy { + // TODO: could return 'recommended buffer size' instead of boolean. Value/simplicity? + // TODO: should the strategy also handle growth? + + /** + * Buffer owners should call this after consumption of every payload. + * E.g. handlePayload(byte[] payload) { + * ... maybe grow buffer + * ... encrypt + * if (shouldShrink(payload.length, buffer.length)) shrinkBuf(); + * } + * @param payloadSize The size of the payload processed. + * @param bufferSize The size of the buffer. + * @return true if the strategy recommends shrinking the buffer. false otherwise. + */ + boolean shouldShrink(int payloadSize, int bufferSize); + + /** + * Shrink the buffer when it is too large for the payload ('over-sized') N times in a row. + * E.g. if an 800KB payload grows the buffer to 1MB, we shrink the buffer after seeing N consecutive + * payloads under 500KB. + */ + class BasicThreshold implements BufferShrinkStrategy { + private final int timesOversizedThreshold; + private int timesOversized = 0; + + public BasicThreshold() { + this.timesOversizedThreshold = 1024; + } + + public BasicThreshold(int timesOversizedThreshold) { + this.timesOversizedThreshold = timesOversizedThreshold; + } + + @Override + public boolean shouldShrink(int payloadSize, int bufferSize) { + if (bufferSize / 2 > payloadSize) { + // The buffer was over-sized for this usage. + if (timesOversized++ > timesOversizedThreshold) { + timesOversized = 0; + return true; + } + } else { + // Buffer was not over-sized, reset counter. + timesOversized = 0; + } + return false; + } + } + + /** + * Similar to {@link BasicThreshold}, but the threshold starts at 1 and increases + * to the chosen limit. This has the benefit of being somewhat adaptive; it starts + * out eager to shrink quickly after a large payload, but slows down every time + * re-growth is needed, up to the chosen limit. + */ + class IncreasingThreshold implements BufferShrinkStrategy { + private final int maxOversizedThreshold; + private int timesOversizedThreshold = 1; + private int timesOversized = 0; + private int previousBufferSize = Integer.MAX_VALUE; + + public IncreasingThreshold() { + this.maxOversizedThreshold = 1024; + } + + public IncreasingThreshold(int maxOversizedThreshold) { + this.maxOversizedThreshold = maxOversizedThreshold; + } + + @Override + public boolean shouldShrink(int payloadSize, int bufferSize) { + if (bufferSize > previousBufferSize) { + // Every time we need to grow, make it harder to shrink in the future (up to a limit). + timesOversizedThreshold = Math.min(maxOversizedThreshold, timesOversizedThreshold * 2); + } + previousBufferSize = bufferSize; + + if (bufferSize / 2 > payloadSize) { + // The buffer was over-sized for this usage. + if (timesOversized++ > timesOversizedThreshold) { + timesOversized = 0; + return true; + } + } else { + // Buffer was not over-sized, reset counter. + timesOversized = 0; + } + + return false; + } + } +}