Skip to content

Commit

Permalink
use WeakHashMap to avoid memory leak
Browse files Browse the repository at this point in the history
  • Loading branch information
OrezzerO committed Jul 9, 2021
1 parent f429064 commit c7f62f8
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 45 deletions.
5 changes: 5 additions & 0 deletions src/main/java/com/alipay/remoting/BaseRemoting.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ protected RemotingCommand invokeSync(final Connection conn, final RemotingComman
final InvokeFuture future = createInvokeFuture(request, request.getInvokeContext());
conn.addInvokeFuture(future);
final int requestId = request.getId();
InvokeContext invokeContext = request.getInvokeContext();
try {
conn.getChannel().writeAndFlush(request).addListener(new ChannelFutureListener() {

Expand All @@ -75,6 +76,10 @@ public void operationComplete(ChannelFuture f) throws Exception {
}

});

if (null != invokeContext) {
invokeContext.put("REQUEST_SEND", System.nanoTime());
}
} catch (Exception e) {
conn.removeInvokeFuture(requestId);
future.putResponse(commandFactory.createSendFailedResponse(conn.getRemoteAddress(), e));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
package com.alipay.remoting.rpc.protocol;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;

import com.alipay.remoting.util.ThreadLocalArriveTimeHolder;
import io.netty.channel.Channel;
import org.slf4j.Logger;

import com.alipay.remoting.CommandCode;
Expand Down Expand Up @@ -98,13 +98,8 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) thro
byte[] clazz = null;
byte[] header = null;
byte[] content = null;
SocketAddress socketAddress = ctx.channel().remoteAddress();
String uniqueKey = null;
if (socketAddress != null) {
String remoteAddress = socketAddress.toString();
uniqueKey = remoteAddress + requestId;
ThreadLocalArriveTimeHolder.arrive(uniqueKey);
}
Channel channel = ctx.channel();
ThreadLocalArriveTimeHolder.arrive(channel, requestId);

if (in.readableBytes() >= classLen + headerLen + contentLen) {
if (classLen > 0) {
Expand All @@ -124,10 +119,14 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) thro
return;
}
RequestCommand command;

long headerArriveTimeInNano = ThreadLocalArriveTimeHolder.getAndClear(
channel, requestId);

if (cmdCode == CommandCode.HEARTBEAT_VALUE) {
command = new HeartbeatCommand();
} else {
command = createRequestCommand(cmdCode, uniqueKey);
command = createRequestCommand(cmdCode, headerArriveTimeInNano);
}
command.setType(type);
command.setVersion(ver2);
Expand Down Expand Up @@ -217,11 +216,11 @@ private ResponseCommand createResponseCommand(short cmdCode) {
return command;
}

private RpcRequestCommand createRequestCommand(short cmdCode, String key) {
private RpcRequestCommand createRequestCommand(short cmdCode, long headerArriveTimeInNano) {
RpcRequestCommand command = new RpcRequestCommand();
command.setCmdCode(RpcCommandCode.valueOf(cmdCode));
command.setArriveTime(System.currentTimeMillis());
command.setArriveHeaderTimeInNano(ThreadLocalArriveTimeHolder.getAndClear(key));
command.setArriveHeaderTimeInNano(headerArriveTimeInNano);
command.setArriveBodyTimeInNano(System.nanoTime());
return command;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
package com.alipay.remoting.rpc.protocol;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;

import com.alipay.remoting.log.BoltLoggerFactory;
import com.alipay.remoting.util.ThreadLocalArriveTimeHolder;
import io.netty.channel.Channel;
import org.slf4j.Logger;

import com.alipay.remoting.CommandCode;
Expand Down Expand Up @@ -105,13 +105,9 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) thro
byte[] header = null;
byte[] content = null;

SocketAddress socketAddress = ctx.channel().remoteAddress();
String uniqueKey = null;
if (socketAddress != null) {
String remoteAddress = socketAddress.toString();
uniqueKey = remoteAddress + requestId;
ThreadLocalArriveTimeHolder.arrive(uniqueKey);
}
Channel channel = ctx.channel();
ThreadLocalArriveTimeHolder.arrive(channel, requestId);

// decide the at-least bytes length for each version
int lengthAtLeastForV1 = classLen + headerLen + contentLen;
boolean crcSwitchOn = ProtocolSwitch.isOn(
Expand Down Expand Up @@ -144,11 +140,15 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) thro
in.resetReaderIndex();
return;
}

long headerArriveTimeInNano = ThreadLocalArriveTimeHolder.getAndClear(
channel, requestId);

RequestCommand command;
if (cmdCode == CommandCode.HEARTBEAT_VALUE) {
command = new HeartbeatCommand();
} else {
command = createRequestCommand(cmdCode, uniqueKey);
command = createRequestCommand(cmdCode, headerArriveTimeInNano);
}
command.setType(type);
command.setVersion(ver2);
Expand Down Expand Up @@ -270,11 +270,11 @@ private ResponseCommand createResponseCommand(short cmdCode) {
return command;
}

private RpcRequestCommand createRequestCommand(short cmdCode, String key) {
private RpcRequestCommand createRequestCommand(short cmdCode, long headerArriveTimeInNano) {
RpcRequestCommand command = new RpcRequestCommand();
command.setCmdCode(RpcCommandCode.valueOf(cmdCode));
command.setArriveTime(System.currentTimeMillis());
command.setArriveHeaderTimeInNano(ThreadLocalArriveTimeHolder.getAndClear(key));
command.setArriveHeaderTimeInNano(headerArriveTimeInNano);
command.setArriveBodyTimeInNano(System.nanoTime());

return command;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,42 +16,47 @@
*/
package com.alipay.remoting.util;

import io.netty.channel.Channel;
import io.netty.util.concurrent.FastThreadLocal;

import java.util.HashMap;
import java.util.Map;
import java.util.WeakHashMap;

/**
* @author zhaowang
* @version : ThreadLocalTimeHolder.java, v 0.1 2021年07月01日 3:05 下午 zhaowang
*/
public class ThreadLocalArriveTimeHolder {
private static FastThreadLocal<Map<String, Long>> arriveTimeInNano = new FastThreadLocal<Map<String, Long>>();
private static FastThreadLocal<WeakHashMap<Channel, Map<Integer, Long>>> arriveTimeInNano = new FastThreadLocal<WeakHashMap<Channel, Map<Integer, Long>>>();

public static void arrive(String key) {
Map<String, Long> map = getArriveTimeMap();
public static void arrive(Channel channel, Integer key) {
Map<Integer, Long> map = getArriveTimeMap(channel);
if (map.get(key) == null) {
map.put(key, System.nanoTime());
}
}

public static long getAndClear(String key) {
Map<String, Long> map = getArriveTimeMap();
public static long getAndClear(Channel channel, Integer key) {
Map<Integer, Long> map = getArriveTimeMap(channel);
Long result = map.remove(key);
if (result == null) {
return -1;
}
return result;
}

private static Map<String, Long> getArriveTimeMap() {
Map<String, Long> map = arriveTimeInNano.get();
private static Map<Integer, Long> getArriveTimeMap(Channel channel) {
WeakHashMap<Channel, Map<Integer, Long>> map = arriveTimeInNano.get();
if (map == null) {
arriveTimeInNano.set(new HashMap<String, Long>(256));
return arriveTimeInNano.get();
} else {
return map;
arriveTimeInNano.set(new WeakHashMap<Channel, Map<Integer, Long>>(256));
map = arriveTimeInNano.get();
}
Map<Integer, Long> subMap = map.get(channel);
if (subMap == null) {
map.put(channel, new HashMap<Integer, Long>());
}
return map.get(channel);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package com.alipay.remoting.util;

import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.Assert;
import org.junit.Test;

Expand All @@ -30,43 +31,46 @@ public class ThreadLocalArriveTimeHolderTest {

@Test
public void test() {
EmbeddedChannel channel = new EmbeddedChannel();
long start = System.nanoTime();
ThreadLocalArriveTimeHolder.arrive("a");
ThreadLocalArriveTimeHolder.arrive(channel, 1);
long end = System.nanoTime();
ThreadLocalArriveTimeHolder.arrive("a");
long time = ThreadLocalArriveTimeHolder.getAndClear("a");
ThreadLocalArriveTimeHolder.arrive(channel, 1);
long time = ThreadLocalArriveTimeHolder.getAndClear(channel, 1);
Assert.assertTrue(time >= start);
Assert.assertTrue(time <= end);
Assert.assertEquals(-1, ThreadLocalArriveTimeHolder.getAndClear("a"));
Assert.assertEquals(-1, ThreadLocalArriveTimeHolder.getAndClear(channel, 1));
}

@Test
public void testRemoveNull() {
Assert.assertEquals(-1, ThreadLocalArriveTimeHolder.getAndClear(null));
EmbeddedChannel channel = new EmbeddedChannel();
Assert.assertEquals(-1, ThreadLocalArriveTimeHolder.getAndClear(channel, 1));
}

@Test
public void testMultiThread() throws InterruptedException {
final EmbeddedChannel channel = new EmbeddedChannel();
final CountDownLatch countDownLatch = new CountDownLatch(1);
long start = System.nanoTime();
ThreadLocalArriveTimeHolder.arrive("a");
ThreadLocalArriveTimeHolder.arrive(channel, 1);
long end = System.nanoTime();
ThreadLocalArriveTimeHolder.arrive("a");
long time = ThreadLocalArriveTimeHolder.getAndClear("a");
ThreadLocalArriveTimeHolder.arrive(channel, 1);
long time = ThreadLocalArriveTimeHolder.getAndClear(channel, 1);
Assert.assertTrue(time >= start);
Assert.assertTrue(time <= end);
Assert.assertEquals(-1, ThreadLocalArriveTimeHolder.getAndClear("a"));
Assert.assertEquals(-1, ThreadLocalArriveTimeHolder.getAndClear(channel, 1));

Runnable runnable = new Runnable() {
@Override
public void run() {
long start = System.nanoTime();
ThreadLocalArriveTimeHolder.arrive("a");
ThreadLocalArriveTimeHolder.arrive(channel, 1);
long end = System.nanoTime();
long time = ThreadLocalArriveTimeHolder.getAndClear("a");
long time = ThreadLocalArriveTimeHolder.getAndClear(channel, 1);
Assert.assertTrue(time >= start);
Assert.assertTrue(time <= end);
Assert.assertEquals(-1, ThreadLocalArriveTimeHolder.getAndClear("a"));
Assert.assertEquals(-1, ThreadLocalArriveTimeHolder.getAndClear(channel, 1));
countDownLatch.countDown();
}
};
Expand Down

0 comments on commit c7f62f8

Please sign in to comment.