Skip to content

Commit

Permalink
Add POC to reattach MultipartFiles to objects
Browse files Browse the repository at this point in the history
  • Loading branch information
cromoteca committed Jan 17, 2025
1 parent 1d37be8 commit 26922f5
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@
*/
package com.vaadin.hilla;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.TreeMap;
import java.util.stream.Collectors;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Import;
import org.springframework.http.HttpStatus;
Expand All @@ -33,7 +31,13 @@
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartHttpServletRequest;

import com.fasterxml.jackson.core.JsonPointer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.POJONode;
import com.vaadin.flow.internal.CurrentInstance;
import com.vaadin.flow.server.VaadinRequest;
import com.vaadin.flow.server.VaadinService;
Expand All @@ -47,6 +51,10 @@
import com.vaadin.hilla.auth.EndpointAccessChecker;
import com.vaadin.hilla.exception.EndpointException;

import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;

/**
* The controller that is responsible for processing Vaadin endpoint requests.
* Each class that is annotated with {@link Endpoint} or {@link BrowserCallable}
Expand Down Expand Up @@ -87,6 +95,8 @@ public class EndpointController {

private final EndpointInvoker endpointInvoker;

private final ObjectMapper objectMapper;

VaadinService vaadinService;

/**
Expand All @@ -101,13 +111,16 @@ public class EndpointController {
* @param csrfChecker
* the csrf checker to use
*/

public EndpointController(ApplicationContext context,
EndpointRegistry endpointRegistry, EndpointInvoker endpointInvoker,
CsrfChecker csrfChecker) {
CsrfChecker csrfChecker,
@Qualifier("hillaEndpointObjectMapper") ObjectMapper objectMapper) {
this.context = context;
this.endpointInvoker = endpointInvoker;
this.csrfChecker = csrfChecker;
this.endpointRegistry = endpointRegistry;
this.objectMapper = objectMapper;
}

/**
Expand Down Expand Up @@ -169,7 +182,7 @@ public void registerEndpoints() {
* the current response
* @return execution result as a JSON string or an error message string
*/
@PostMapping(path = ENDPOINT_METHODS, produces = MediaType.APPLICATION_JSON_UTF8_VALUE)
@PostMapping(path = ENDPOINT_METHODS, consumes = MediaType.APPLICATION_JSON_VALUE, produces = MediaType.APPLICATION_JSON_UTF8_VALUE)
public ResponseEntity<String> serveEndpoint(
@PathVariable("endpoint") String endpointName,
@PathVariable("method") String methodName,
Expand All @@ -179,6 +192,16 @@ public ResponseEntity<String> serveEndpoint(
response);
}

@PostMapping(path = "/{endpoint}/{method}", consumes = MediaType.MULTIPART_FORM_DATA_VALUE, produces = MediaType.APPLICATION_JSON_VALUE)
public ResponseEntity<String> serveMultipartEndpoint(
@PathVariable("endpoint") String endpointName,
@PathVariable("method") String methodName,
HttpServletRequest request, HttpServletResponse response)
throws IOException {
return doServeEndpoint(endpointName, methodName, null, request,
response);
}

/**
* Captures and processes the Vaadin endpoint requests.
* <p>
Expand Down Expand Up @@ -227,6 +250,38 @@ private ResponseEntity<String> doServeEndpoint(String endpointName,
if (enforcementResult.isEnforcementNeeded()) {
return buildEnforcementResponseEntity(enforcementResult);
}

if (isMultipartRequest(request)) {
try {
var multipartRequest = (MultipartHttpServletRequest) request;
var fileMap = multipartRequest.getFileMap();
var bodyPart = multipartRequest.getParts().stream()
.filter(part -> part.getSubmittedFileName() == null)
.filter(part -> MediaType.APPLICATION_JSON_VALUE
.equals(part.getContentType()))
.findAny().orElseThrow();

body = objectMapper.readValue(bodyPart.getInputStream(),
ObjectNode.class);

for (var entry : fileMap.entrySet()) {
var partName = entry.getKey();
var file = entry.getValue();
var pointer = JsonPointer.valueOf(partName);
var parent = pointer.head();
var property = pointer.last().getMatchingProperty();
var parentObject = body.withObject(parent);
parentObject.putPOJO(property, file);
}
} catch (IOException | ServletException e) {
LOGGER.error("Error processing multipart request parts", e);
return ResponseEntity
.status(HttpStatus.INTERNAL_SERVER_ERROR)
.body(endpointInvoker.createResponseErrorObject(
"Error processing multipart request parts"));
}
}

Object returnValue = endpointInvoker.invoke(endpointName,
methodName, body, request.getUserPrincipal(),
request::isUserInRole);
Expand Down Expand Up @@ -273,6 +328,12 @@ private ResponseEntity<String> doServeEndpoint(String endpointName,
}
}

private boolean isMultipartRequest(HttpServletRequest request) {
String contentType = request.getContentType();
return contentType != null
&& contentType.startsWith(MediaType.MULTIPART_FORM_DATA_VALUE);
}

private ResponseEntity<String> buildEnforcementResponseEntity(
DAUUtils.EnforcementResult enforcementResult) {
EnforcementNotificationMessages messages = enforcementResult.messages();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,27 @@
*/
package com.vaadin.hilla.endpointransfermapper;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.multipart.MultipartFile;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.deser.std.StdDelegatingDeserializer;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.databind.node.POJONode;
import com.fasterxml.jackson.databind.ser.std.StdDelegatingSerializer;
import com.fasterxml.jackson.databind.type.TypeFactory;
import com.fasterxml.jackson.databind.util.StdConverter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Defines mappings for certain endpoint types to corresponding transfer types.
Expand Down Expand Up @@ -174,6 +182,8 @@ public JavaType getOutputType(TypeFactory typeFactory) {
});

jacksonModule.addDeserializer(endpointType, deserializer);
jacksonModule.addDeserializer(MultipartFile.class,
new MultipartFileDeserializer());
}

/**
Expand Down Expand Up @@ -318,4 +328,24 @@ private Logger getLogger() {
return LoggerFactory.getLogger(getClass());
}

public static class MultipartFileDeserializer
extends JsonDeserializer<MultipartFile> {

@Override
public MultipartFile deserialize(JsonParser p,
DeserializationContext ctxt) throws IOException {
JsonNode node = p.getCodec().readTree(p);

if (node instanceof POJONode) {
Object pojo = ((POJONode) node).getPojo();

if (pojo instanceof MultipartFile) {
return (MultipartFile) pojo;
}
}

throw new IOException(
"Expected a POJONode wrapping a MultipartFile");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ public void setUp() {
EndpointRegistry endpointRegistry = new EndpointRegistry(
new EndpointNameChecker());
ApplicationContext appCtx = Mockito.mock(ApplicationContext.class);
ObjectMapper objectMapper = new JacksonObjectMapperFactory.Json()
.build();
EndpointInvoker endpointInvoker = new EndpointInvoker(appCtx,
new JacksonObjectMapperFactory.Json().build(),
new ExplicitNullableTypeChecker(), servletContext,
objectMapper, new ExplicitNullableTypeChecker(), servletContext,
endpointRegistry);
controller = new EndpointController(appCtx, endpointRegistry,
endpointInvoker, csrfChecker);
endpointInvoker, csrfChecker, objectMapper);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ public EndpointController build() {
EndpointInvoker invoker = Mockito.spy(
new EndpointInvoker(applicationContext, endpointObjectMapper,
explicitNullableTypeChecker, servletContext, registry));
EndpointController controller = Mockito.spy(new EndpointController(
applicationContext, registry, invoker, csrfChecker));
EndpointController controller = Mockito
.spy(new EndpointController(applicationContext, registry,
invoker, csrfChecker, endpointObjectMapper));
Mockito.doReturn(mock(EndpointAccessChecker.class)).when(invoker)
.getAccessChecker();
return controller;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -826,8 +826,8 @@ public void should_Never_UseSpringObjectMapper() {
endpointObjectMapper, mock(ExplicitNullableTypeChecker.class),
mock(ServletContext.class), registry);

new EndpointController(contextMock, registry, invoker, null)
.registerEndpoints();
new EndpointController(contextMock, registry, invoker, null,
mockOwnObjectMapper).registerEndpoints();

verify(contextMock, never()).getBean(ObjectMapper.class);
verify(contextMock, times(1))
Expand Down Expand Up @@ -1312,7 +1312,7 @@ private <T> EndpointController createVaadinController(T endpoint,

EndpointController connectController = Mockito
.spy(new EndpointController(mockApplicationContext, registry,
invoker, csrfChecker));
invoker, csrfChecker, endpointObjectMapper));
connectController.registerEndpoints();
return connectController;
}
Expand Down

0 comments on commit 26922f5

Please sign in to comment.