From b559ee2b0df76b489816d43be68b820622d4d46c Mon Sep 17 00:00:00 2001 From: cheesecrust Date: Fri, 6 Dec 2024 16:43:16 +0900 Subject: [PATCH] INTERNAL: Limit bulk get keys size --- .../ascii/AsciiMemcachedNodeImpl.java | 51 +++++++++++-------- .../java/net/spy/memcached/OptimizeTest.java | 38 ++++++++++++++ 2 files changed, 67 insertions(+), 22 deletions(-) create mode 100644 src/test/java/net/spy/memcached/OptimizeTest.java diff --git a/src/main/java/net/spy/memcached/protocol/ascii/AsciiMemcachedNodeImpl.java b/src/main/java/net/spy/memcached/protocol/ascii/AsciiMemcachedNodeImpl.java index c85e240c4..76caf82b8 100644 --- a/src/main/java/net/spy/memcached/protocol/ascii/AsciiMemcachedNodeImpl.java +++ b/src/main/java/net/spy/memcached/protocol/ascii/AsciiMemcachedNodeImpl.java @@ -31,6 +31,9 @@ * Memcached node for the ASCII protocol. */ public final class AsciiMemcachedNodeImpl extends TCPMemcachedNodeImpl { + + private static final int GET_BULK_CHUNK_SIZE = 200; + public AsciiMemcachedNodeImpl(String name, SocketAddress sa, int bufSize, BlockingQueue rq, @@ -45,31 +48,35 @@ protected void optimize() { // make sure there are at least two get operations in a row before // attempting to optimize them. Operation nxtOp = writeQ.peek(); - if (nxtOp instanceof GetOperation && nxtOp.getAPIType() != APIType.MGET) { - optimizedOp = writeQ.remove(); - nxtOp = writeQ.peek(); - if (nxtOp instanceof GetOperation && nxtOp.getAPIType() != APIType.MGET) { - OptimizedGetImpl og = new OptimizedGetImpl( - (GetOperation) optimizedOp); - optimizedOp = og; + if (!(nxtOp instanceof GetOperation) || nxtOp.getAPIType() == APIType.MGET || + ((GetOperation) nxtOp).getKeys().size() > GET_BULK_CHUNK_SIZE) { + return; + } - do { - GetOperationImpl o = (GetOperationImpl) writeQ.remove(); - if (!o.isCancelled()) { - og.addOperation(o); - } - nxtOp = writeQ.peek(); - } while (nxtOp instanceof GetOperation && - nxtOp.getAPIType() != APIType.MGET); + int cnt = ((GetOperation) nxtOp).getKeys().size(); + optimizedOp = new OptimizedGetImpl((GetOperation) writeQ.remove()); + nxtOp = writeQ.peek(); + OptimizedGetImpl og = null; - // Initialize the new mega get - optimizedOp.initialize(); - assert optimizedOp.getState() == OperationState.WRITE_QUEUED; - ProxyCallback pcb = (ProxyCallback) og.getCallback(); - getLogger().debug("Set up %s with %s keys and %s callbacks", - this, pcb.numKeys(), pcb.numCallbacks()); + while (nxtOp instanceof GetOperation && nxtOp.getAPIType() != APIType.MGET) { + if (og == null) { + og = (OptimizedGetImpl) optimizedOp; } + cnt += ((GetOperation) nxtOp).getKeys().size(); + if (cnt > GET_BULK_CHUNK_SIZE) { + break; + } + GetOperationImpl currentOp = (GetOperationImpl) writeQ.remove(); + if (!currentOp.isCancelled()) { + og.addOperation(currentOp); + } + nxtOp = writeQ.peek(); } + // Initialize the new mega get + optimizedOp.initialize(); + assert optimizedOp.getState() == OperationState.WRITE_QUEUED; + ProxyCallback pcb = (ProxyCallback) optimizedOp.getCallback(); + getLogger().debug("Set up %s with %s keys and %s callbacks", + this, pcb.numKeys(), pcb.numCallbacks()); } - } diff --git a/src/test/java/net/spy/memcached/OptimizeTest.java b/src/test/java/net/spy/memcached/OptimizeTest.java new file mode 100644 index 000000000..bf72cd26b --- /dev/null +++ b/src/test/java/net/spy/memcached/OptimizeTest.java @@ -0,0 +1,38 @@ +package net.spy.memcached; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Future; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + + +class OptimizeTest { + + @Test + void testParallelGet() throws Throwable { + ConnectionFactoryBuilder builder = new ConnectionFactoryBuilder(); + builder.setShouldOptimize(true); + // Get a connection with the get optimization. + ArcusClientPool client = + ArcusClient.createArcusClientPool("127.0.0.1:2181", "test", builder, 1); + + final List keys = new ArrayList<>(10000); + for (int i = 0; i < 100; i++) { + keys.add("k" + i); + Boolean b = client.set(keys.get(i), 0, "value" + i).get(); + Assertions.assertEquals(true, b); + } + + List> results = new ArrayList<>(10000); + for (int i = 0; i < 100; i++) { + results.add(client.asyncGet(keys.get(i))); + } + + for (int i = 0; i < 100; i++) { + Object o = results.get(i).get(); + Assertions.assertEquals("value" + i, o); + } + } +}