From c012314ffd8bb9b54278f63d8eb3dd4c9f469dde Mon Sep 17 00:00:00 2001 From: Mitchell Kutchuk Date: Thu, 1 Sep 2022 16:35:36 -0400 Subject: [PATCH] Implement setting the response size limit on Android --- .../java/com/reactnativegrpc/GrpcModule.java | 362 +++++++++--------- 1 file changed, 186 insertions(+), 176 deletions(-) diff --git a/android/src/main/java/com/reactnativegrpc/GrpcModule.java b/android/src/main/java/com/reactnativegrpc/GrpcModule.java index fe80128..8f421fd 100644 --- a/android/src/main/java/com/reactnativegrpc/GrpcModule.java +++ b/android/src/main/java/com/reactnativegrpc/GrpcModule.java @@ -25,248 +25,258 @@ import io.grpc.Status; public class GrpcModule extends ReactContextBaseJavaModule { - private final ReactApplicationContext context; - private final HashMap callsMap = new HashMap<>(); + private final ReactApplicationContext context; + private final HashMap callsMap = new HashMap<>(); - private String host; - private boolean isInsecure = false; - private ManagedChannel managedChannel = null; + private String host; + private boolean isInsecure = false; + private Integer responseSizeLimit = null; + private ManagedChannel managedChannel = null; - public GrpcModule(ReactApplicationContext context) { - this.context = context; - } + public GrpcModule(ReactApplicationContext context) { + this.context = context; + } + + @NonNull + @Override + public String getName() { + return "Grpc"; + } - @NonNull - @Override - public String getName() { - return "Grpc"; - } + @ReactMethod() + public void getHost(final Promise promise) { + promise.resolve(this.host); + } - @ReactMethod() - public void getHost(final Promise promise) { - promise.resolve(this.host); - } + @ReactMethod() + public void getIsInsecure(final Promise promise) { + promise.resolve(this.isInsecure); + } - @ReactMethod() - public void getIsInsecure(final Promise promise) { - promise.resolve(this.isInsecure); - } + @ReactMethod + public void setHost(String host) { + this.host = host; + } - @ReactMethod - public void setHost(String host) { - this.host = host; - } + @ReactMethod + public void setInsecure(boolean insecure) { + this.isInsecure = insecure; + } - @ReactMethod - public void setInsecure(boolean insecure) { - this.isInsecure = insecure; - } + @ReactMethod + public void setResponseSizeLimit(int limit) { + this.responseSizeLimit = limit; + } - @ReactMethod - public void unaryCall(int id, String path, ReadableMap obj, ReadableMap headers, final Promise promise) { - ClientCall call = this.startGrpcCall(id, path, MethodDescriptor.MethodType.UNARY, headers); + @ReactMethod + public void unaryCall(int id, String path, ReadableMap obj, ReadableMap headers, final Promise promise) { + ClientCall call = this.startGrpcCall(id, path, MethodDescriptor.MethodType.UNARY, headers); - byte[] data = Base64.decode(obj.getString("data"), Base64.NO_WRAP); + byte[] data = Base64.decode(obj.getString("data"), Base64.NO_WRAP); - call.sendMessage(data); - call.request(1); - call.halfClose(); + call.sendMessage(data); + call.request(1); + call.halfClose(); - callsMap.put(id, call); + callsMap.put(id, call); - promise.resolve(null); - } + promise.resolve(null); + } - @ReactMethod - public void serverStreamingCall(int id, String path, ReadableMap obj, ReadableMap headers, final Promise promise) { - ClientCall call = this.startGrpcCall(id, path, MethodDescriptor.MethodType.SERVER_STREAMING, headers); + @ReactMethod + public void serverStreamingCall(int id, String path, ReadableMap obj, ReadableMap headers, final Promise promise) { + ClientCall call = this.startGrpcCall(id, path, MethodDescriptor.MethodType.SERVER_STREAMING, headers); - byte[] data = Base64.decode(obj.getString("data"), Base64.NO_WRAP); + byte[] data = Base64.decode(obj.getString("data"), Base64.NO_WRAP); - call.sendMessage(data); - call.request(1); - call.halfClose(); + call.sendMessage(data); + call.request(1); + call.halfClose(); - callsMap.put(id, call); + callsMap.put(id, call); - promise.resolve(null); - } + promise.resolve(null); + } - @ReactMethod - public void clientStreamingCall(int id, String path, ReadableMap obj, ReadableMap headers, final Promise promise) { - ClientCall call = callsMap.get(id); + @ReactMethod + public void clientStreamingCall(int id, String path, ReadableMap obj, ReadableMap headers, final Promise promise) { + ClientCall call = callsMap.get(id); - if (call == null) { - call = this.startGrpcCall(id, path, MethodDescriptor.MethodType.CLIENT_STREAMING, headers); + if (call == null) { + call = this.startGrpcCall(id, path, MethodDescriptor.MethodType.CLIENT_STREAMING, headers); - callsMap.put(id, call); - } + callsMap.put(id, call); + } - byte[] data = Base64.decode(obj.getString("data"), Base64.NO_WRAP); + byte[] data = Base64.decode(obj.getString("data"), Base64.NO_WRAP); - call.sendMessage(data); - call.request(1); + call.sendMessage(data); + call.request(1); - promise.resolve(null); - } + promise.resolve(null); + } - @ReactMethod - public void finishClientStreaming(int id, final Promise promise) { - if (callsMap.containsKey(id)) { - ClientCall call = callsMap.get(id); + @ReactMethod + public void finishClientStreaming(int id, final Promise promise) { + if (callsMap.containsKey(id)) { + ClientCall call = callsMap.get(id); - call.halfClose(); + call.halfClose(); - promise.resolve(true); - } else { - promise.resolve(false); + promise.resolve(true); + } else { + promise.resolve(false); + } } - } - @ReactMethod - public void cancelGrpcCall(int id, final Promise promise) { - if (callsMap.containsKey(id)) { - ClientCall call = callsMap.get(id); - call.cancel("Cancelled", new Exception("Cancelled by app")); + @ReactMethod + public void cancelGrpcCall(int id, final Promise promise) { + if (callsMap.containsKey(id)) { + ClientCall call = callsMap.get(id); + call.cancel("Cancelled", new Exception("Cancelled by app")); - promise.resolve(true); - } else { - promise.resolve(false); + promise.resolve(true); + } else { + promise.resolve(false); + } } - } - private ClientCall startGrpcCall(int id, String path, MethodDescriptor.MethodType methodType, ReadableMap headers) { - path = normalizePath(path); + private ClientCall startGrpcCall(int id, String path, MethodDescriptor.MethodType methodType, ReadableMap headers) { + path = normalizePath(path); - final Metadata headersMetadata = new Metadata(); + final Metadata headersMetadata = new Metadata(); - for (Map.Entry headerEntry : headers.toHashMap().entrySet()) { - headersMetadata.put(Metadata.Key.of(headerEntry.getKey(), Metadata.ASCII_STRING_MARSHALLER), headerEntry.getValue().toString()); - } + for (Map.Entry headerEntry : headers.toHashMap().entrySet()) { + headersMetadata.put(Metadata.Key.of(headerEntry.getKey(), Metadata.ASCII_STRING_MARSHALLER), headerEntry.getValue().toString()); + } - MethodDescriptor.Marshaller marshaller = new GrpcMarshaller(); + MethodDescriptor.Marshaller marshaller = new GrpcMarshaller(); - MethodDescriptor descriptor = MethodDescriptor.newBuilder() - .setFullMethodName(path) - .setType(methodType) - .setRequestMarshaller(marshaller) - .setResponseMarshaller(marshaller) - .build(); + MethodDescriptor descriptor = MethodDescriptor.newBuilder() + .setFullMethodName(path) + .setType(methodType) + .setRequestMarshaller(marshaller) + .setResponseMarshaller(marshaller) + .build(); - CallOptions callOptions = CallOptions.DEFAULT; + CallOptions callOptions = CallOptions.DEFAULT; - ClientCall call = this.getManagedChannel().newCall(descriptor, callOptions); + ClientCall call = this.getManagedChannel().newCall(descriptor, callOptions); - call.start(new ClientCall.Listener() { - @Override - public void onHeaders(Metadata headers) { - super.onHeaders(headers); + call.start(new ClientCall.Listener() { + @Override + public void onHeaders(Metadata headers) { + super.onHeaders(headers); - WritableMap event = Arguments.createMap(); - WritableMap payload = Arguments.createMap(); + WritableMap event = Arguments.createMap(); + WritableMap payload = Arguments.createMap(); - for (String key : headers.keys()) { - if (key.endsWith("-bin")) { - byte[] data = headers.get(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + for (String key : headers.keys()) { + if (key.endsWith("-bin")) { + byte[] data = headers.get(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); - payload.putString(key, new String(Base64.encode(data, Base64.NO_WRAP))); - } else { - String data = headers.get(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); + payload.putString(key, new String(Base64.encode(data, Base64.NO_WRAP))); + } else { + String data = headers.get(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); - payload.putString(key, data); - } - } + payload.putString(key, data); + } + } - event.putInt("id", id); - event.putString("type", "headers"); - event.putMap("payload", payload); + event.putInt("id", id); + event.putString("type", "headers"); + event.putMap("payload", payload); - emitEvent("grpc-call", event); - } + emitEvent("grpc-call", event); + } - @Override - public void onMessage(Object messageObj) { - super.onMessage(messageObj); + @Override + public void onMessage(Object messageObj) { + super.onMessage(messageObj); - byte[] data = (byte[]) messageObj; + byte[] data = (byte[]) messageObj; - WritableMap event = Arguments.createMap(); + WritableMap event = Arguments.createMap(); - event.putInt("id", id); - event.putString("type", "response"); - event.putString("payload", Base64.encodeToString(data, Base64.NO_WRAP)); + event.putInt("id", id); + event.putString("type", "response"); + event.putString("payload", Base64.encodeToString(data, Base64.NO_WRAP)); - emitEvent("grpc-call", event); + emitEvent("grpc-call", event); - if (methodType == MethodDescriptor.MethodType.SERVER_STREAMING) { - call.request(1); - } - } + if (methodType == MethodDescriptor.MethodType.SERVER_STREAMING) { + call.request(1); + } + } - @Override - public void onClose(Status status, Metadata trailers) { - super.onClose(status, trailers); + @Override + public void onClose(Status status, Metadata trailers) { + super.onClose(status, trailers); - callsMap.remove(id); + callsMap.remove(id); - WritableMap event = Arguments.createMap(); - event.putInt("id", id); + WritableMap event = Arguments.createMap(); + event.putInt("id", id); - WritableMap trailersMap = Arguments.createMap(); + WritableMap trailersMap = Arguments.createMap(); - for (String key : trailers.keys()) { - if (key.endsWith("-bin")) { - byte[] data = trailers.get(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + for (String key : trailers.keys()) { + if (key.endsWith("-bin")) { + byte[] data = trailers.get(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); - trailersMap.putString(key, new String(Base64.encode(data, Base64.NO_WRAP))); - } else { - String data = trailers.get(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); + trailersMap.putString(key, new String(Base64.encode(data, Base64.NO_WRAP))); + } else { + String data = trailers.get(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); - trailersMap.putString(key, data); - } - } + trailersMap.putString(key, data); + } + } - if (!status.isOk()) { - event.putString("type", "error"); - event.putString("error", status.asException(trailers).getLocalizedMessage()); - event.putInt("code", status.getCode().value()); - event.putMap("trailers", trailersMap); - } else { - event.putString("type", "trailers"); - event.putMap("payload", trailersMap); - } + if (!status.isOk()) { + event.putString("type", "error"); + event.putString("error", status.asException(trailers).getLocalizedMessage()); + event.putInt("code", status.getCode().value()); + event.putMap("trailers", trailersMap); + } else { + event.putString("type", "trailers"); + event.putMap("payload", trailersMap); + } - emitEvent("grpc-call", event); - } - }, headersMetadata); + emitEvent("grpc-call", event); + } + }, headersMetadata); - return call; - } + return call; + } + + private void emitEvent(String eventName, Object params) { + context + .getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class) + .emit(eventName, params); + } - private void emitEvent(String eventName, Object params) { - context - .getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class) - .emit(eventName, params); - } + private static String normalizePath(String path) { + if (path.startsWith("/")) { + path = path.substring(1); + } - private static String normalizePath(String path) { - if (path.startsWith("/")) { - path = path.substring(1); + return path; } - return path; - } + private ManagedChannel getManagedChannel() { + if (managedChannel != null) return managedChannel; - private ManagedChannel getManagedChannel(){ - if (managedChannel != null) return managedChannel; + ManagedChannelBuilder channelBuilder = ManagedChannelBuilder.forTarget(this.host); - ManagedChannelBuilder channelBuilder = ManagedChannelBuilder.forTarget(this.host); + if (this.responseSizeLimit != null) { + channelBuilder = channelBuilder.maxInboundMessageSize(this.responseSizeLimit); + } - if (this.isInsecure) { - channelBuilder = channelBuilder.usePlaintext(); - } + if (this.isInsecure) { + channelBuilder = channelBuilder.usePlaintext(); + } - managedChannel = channelBuilder.build(); - return managedChannel; - } + managedChannel = channelBuilder.build(); + return managedChannel; + } }