From 26922f508e4053e742ec3f74136ce43c15f104ec Mon Sep 17 00:00:00 2001 From: Luciano Vernaschi Date: Fri, 17 Jan 2025 15:51:09 +0100 Subject: [PATCH] Add POC to reattach `MultipartFile`s to objects --- .../com/vaadin/hilla/EndpointController.java | 73 +++++++++++++++++-- .../EndpointTransferMapper.java | 34 ++++++++- .../hilla/EndpointControllerDauTest.java | 7 +- .../hilla/EndpointControllerMockBuilder.java | 5 +- .../vaadin/hilla/EndpointControllerTest.java | 6 +- 5 files changed, 109 insertions(+), 16 deletions(-) diff --git a/packages/java/endpoint/src/main/java/com/vaadin/hilla/EndpointController.java b/packages/java/endpoint/src/main/java/com/vaadin/hilla/EndpointController.java index 517b833a3a..e04f7126e0 100644 --- a/packages/java/endpoint/src/main/java/com/vaadin/hilla/EndpointController.java +++ b/packages/java/endpoint/src/main/java/com/vaadin/hilla/EndpointController.java @@ -15,15 +15,13 @@ */ package com.vaadin.hilla; -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; +import java.io.IOException; import java.util.TreeMap; import java.util.stream.Collectors; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.node.ObjectNode; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Import; import org.springframework.http.HttpStatus; @@ -33,7 +31,13 @@ import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.multipart.MultipartHttpServletRequest; +import com.fasterxml.jackson.core.JsonPointer; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.databind.node.POJONode; import com.vaadin.flow.internal.CurrentInstance; import com.vaadin.flow.server.VaadinRequest; import com.vaadin.flow.server.VaadinService; @@ -47,6 +51,10 @@ import com.vaadin.hilla.auth.EndpointAccessChecker; import com.vaadin.hilla.exception.EndpointException; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + /** * The controller that is responsible for processing Vaadin endpoint requests. * Each class that is annotated with {@link Endpoint} or {@link BrowserCallable} @@ -87,6 +95,8 @@ public class EndpointController { private final EndpointInvoker endpointInvoker; + private final ObjectMapper objectMapper; + VaadinService vaadinService; /** @@ -101,13 +111,16 @@ public class EndpointController { * @param csrfChecker * the csrf checker to use */ + public EndpointController(ApplicationContext context, EndpointRegistry endpointRegistry, EndpointInvoker endpointInvoker, - CsrfChecker csrfChecker) { + CsrfChecker csrfChecker, + @Qualifier("hillaEndpointObjectMapper") ObjectMapper objectMapper) { this.context = context; this.endpointInvoker = endpointInvoker; this.csrfChecker = csrfChecker; this.endpointRegistry = endpointRegistry; + this.objectMapper = objectMapper; } /** @@ -169,7 +182,7 @@ public void registerEndpoints() { * the current response * @return execution result as a JSON string or an error message string */ - @PostMapping(path = ENDPOINT_METHODS, produces = MediaType.APPLICATION_JSON_UTF8_VALUE) + @PostMapping(path = ENDPOINT_METHODS, consumes = MediaType.APPLICATION_JSON_VALUE, produces = MediaType.APPLICATION_JSON_UTF8_VALUE) public ResponseEntity serveEndpoint( @PathVariable("endpoint") String endpointName, @PathVariable("method") String methodName, @@ -179,6 +192,16 @@ public ResponseEntity serveEndpoint( response); } + @PostMapping(path = "/{endpoint}/{method}", consumes = MediaType.MULTIPART_FORM_DATA_VALUE, produces = MediaType.APPLICATION_JSON_VALUE) + public ResponseEntity serveMultipartEndpoint( + @PathVariable("endpoint") String endpointName, + @PathVariable("method") String methodName, + HttpServletRequest request, HttpServletResponse response) + throws IOException { + return doServeEndpoint(endpointName, methodName, null, request, + response); + } + /** * Captures and processes the Vaadin endpoint requests. *

@@ -227,6 +250,38 @@ private ResponseEntity doServeEndpoint(String endpointName, if (enforcementResult.isEnforcementNeeded()) { return buildEnforcementResponseEntity(enforcementResult); } + + if (isMultipartRequest(request)) { + try { + var multipartRequest = (MultipartHttpServletRequest) request; + var fileMap = multipartRequest.getFileMap(); + var bodyPart = multipartRequest.getParts().stream() + .filter(part -> part.getSubmittedFileName() == null) + .filter(part -> MediaType.APPLICATION_JSON_VALUE + .equals(part.getContentType())) + .findAny().orElseThrow(); + + body = objectMapper.readValue(bodyPart.getInputStream(), + ObjectNode.class); + + for (var entry : fileMap.entrySet()) { + var partName = entry.getKey(); + var file = entry.getValue(); + var pointer = JsonPointer.valueOf(partName); + var parent = pointer.head(); + var property = pointer.last().getMatchingProperty(); + var parentObject = body.withObject(parent); + parentObject.putPOJO(property, file); + } + } catch (IOException | ServletException e) { + LOGGER.error("Error processing multipart request parts", e); + return ResponseEntity + .status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(endpointInvoker.createResponseErrorObject( + "Error processing multipart request parts")); + } + } + Object returnValue = endpointInvoker.invoke(endpointName, methodName, body, request.getUserPrincipal(), request::isUserInRole); @@ -273,6 +328,12 @@ private ResponseEntity doServeEndpoint(String endpointName, } } + private boolean isMultipartRequest(HttpServletRequest request) { + String contentType = request.getContentType(); + return contentType != null + && contentType.startsWith(MediaType.MULTIPART_FORM_DATA_VALUE); + } + private ResponseEntity buildEnforcementResponseEntity( DAUUtils.EnforcementResult enforcementResult) { EnforcementNotificationMessages messages = enforcementResult.messages(); diff --git a/packages/java/endpoint/src/main/java/com/vaadin/hilla/endpointransfermapper/EndpointTransferMapper.java b/packages/java/endpoint/src/main/java/com/vaadin/hilla/endpointransfermapper/EndpointTransferMapper.java index 631b131632..963e25f265 100644 --- a/packages/java/endpoint/src/main/java/com/vaadin/hilla/endpointransfermapper/EndpointTransferMapper.java +++ b/packages/java/endpoint/src/main/java/com/vaadin/hilla/endpointransfermapper/EndpointTransferMapper.java @@ -15,19 +15,27 @@ */ package com.vaadin.hilla.endpointransfermapper; +import java.io.IOException; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.web.multipart.MultipartFile; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.deser.std.StdDelegatingDeserializer; import com.fasterxml.jackson.databind.module.SimpleModule; +import com.fasterxml.jackson.databind.node.POJONode; import com.fasterxml.jackson.databind.ser.std.StdDelegatingSerializer; import com.fasterxml.jackson.databind.type.TypeFactory; import com.fasterxml.jackson.databind.util.StdConverter; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Defines mappings for certain endpoint types to corresponding transfer types. @@ -174,6 +182,8 @@ public JavaType getOutputType(TypeFactory typeFactory) { }); jacksonModule.addDeserializer(endpointType, deserializer); + jacksonModule.addDeserializer(MultipartFile.class, + new MultipartFileDeserializer()); } /** @@ -318,4 +328,24 @@ private Logger getLogger() { return LoggerFactory.getLogger(getClass()); } + public static class MultipartFileDeserializer + extends JsonDeserializer { + + @Override + public MultipartFile deserialize(JsonParser p, + DeserializationContext ctxt) throws IOException { + JsonNode node = p.getCodec().readTree(p); + + if (node instanceof POJONode) { + Object pojo = ((POJONode) node).getPojo(); + + if (pojo instanceof MultipartFile) { + return (MultipartFile) pojo; + } + } + + throw new IOException( + "Expected a POJONode wrapping a MultipartFile"); + } + } } diff --git a/packages/java/endpoint/src/test/java/com/vaadin/hilla/EndpointControllerDauTest.java b/packages/java/endpoint/src/test/java/com/vaadin/hilla/EndpointControllerDauTest.java index df92783dd5..05f88e80c4 100644 --- a/packages/java/endpoint/src/test/java/com/vaadin/hilla/EndpointControllerDauTest.java +++ b/packages/java/endpoint/src/test/java/com/vaadin/hilla/EndpointControllerDauTest.java @@ -57,12 +57,13 @@ public void setUp() { EndpointRegistry endpointRegistry = new EndpointRegistry( new EndpointNameChecker()); ApplicationContext appCtx = Mockito.mock(ApplicationContext.class); + ObjectMapper objectMapper = new JacksonObjectMapperFactory.Json() + .build(); EndpointInvoker endpointInvoker = new EndpointInvoker(appCtx, - new JacksonObjectMapperFactory.Json().build(), - new ExplicitNullableTypeChecker(), servletContext, + objectMapper, new ExplicitNullableTypeChecker(), servletContext, endpointRegistry); controller = new EndpointController(appCtx, endpointRegistry, - endpointInvoker, csrfChecker); + endpointInvoker, csrfChecker, objectMapper); } @Test diff --git a/packages/java/endpoint/src/test/java/com/vaadin/hilla/EndpointControllerMockBuilder.java b/packages/java/endpoint/src/test/java/com/vaadin/hilla/EndpointControllerMockBuilder.java index 12e0040cb1..bb26f5c940 100644 --- a/packages/java/endpoint/src/test/java/com/vaadin/hilla/EndpointControllerMockBuilder.java +++ b/packages/java/endpoint/src/test/java/com/vaadin/hilla/EndpointControllerMockBuilder.java @@ -37,8 +37,9 @@ public EndpointController build() { EndpointInvoker invoker = Mockito.spy( new EndpointInvoker(applicationContext, endpointObjectMapper, explicitNullableTypeChecker, servletContext, registry)); - EndpointController controller = Mockito.spy(new EndpointController( - applicationContext, registry, invoker, csrfChecker)); + EndpointController controller = Mockito + .spy(new EndpointController(applicationContext, registry, + invoker, csrfChecker, endpointObjectMapper)); Mockito.doReturn(mock(EndpointAccessChecker.class)).when(invoker) .getAccessChecker(); return controller; diff --git a/packages/java/endpoint/src/test/java/com/vaadin/hilla/EndpointControllerTest.java b/packages/java/endpoint/src/test/java/com/vaadin/hilla/EndpointControllerTest.java index 3af6311432..ae02cf7cb6 100644 --- a/packages/java/endpoint/src/test/java/com/vaadin/hilla/EndpointControllerTest.java +++ b/packages/java/endpoint/src/test/java/com/vaadin/hilla/EndpointControllerTest.java @@ -826,8 +826,8 @@ public void should_Never_UseSpringObjectMapper() { endpointObjectMapper, mock(ExplicitNullableTypeChecker.class), mock(ServletContext.class), registry); - new EndpointController(contextMock, registry, invoker, null) - .registerEndpoints(); + new EndpointController(contextMock, registry, invoker, null, + mockOwnObjectMapper).registerEndpoints(); verify(contextMock, never()).getBean(ObjectMapper.class); verify(contextMock, times(1)) @@ -1312,7 +1312,7 @@ private EndpointController createVaadinController(T endpoint, EndpointController connectController = Mockito .spy(new EndpointController(mockApplicationContext, registry, - invoker, csrfChecker)); + invoker, csrfChecker, endpointObjectMapper)); connectController.registerEndpoints(); return connectController; }