Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
cromoteca committed Jan 27, 2025
1 parent 891d775 commit 7221848
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.POJONode;
import com.vaadin.flow.internal.CurrentInstance;
import com.vaadin.flow.server.VaadinRequest;
import com.vaadin.flow.server.VaadinService;
Expand All @@ -51,7 +50,6 @@
import com.vaadin.hilla.auth.EndpointAccessChecker;
import com.vaadin.hilla.exception.EndpointException;

import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;

Expand All @@ -78,6 +76,8 @@ public class EndpointController {
private static final Logger LOGGER = LoggerFactory
.getLogger(EndpointController.class);

public static final String BODY_PART_NAME = "hilla_body_part";

static final String ENDPOINT_METHODS = "/{endpoint}/{method}";

/**
Expand Down Expand Up @@ -192,7 +192,7 @@ public ResponseEntity<String> serveEndpoint(
response);
}

@PostMapping(path = "/{endpoint}/{method}", consumes = MediaType.MULTIPART_FORM_DATA_VALUE, produces = MediaType.APPLICATION_JSON_VALUE)
@PostMapping(path = ENDPOINT_METHODS, consumes = MediaType.MULTIPART_FORM_DATA_VALUE, produces = MediaType.APPLICATION_JSON_VALUE)
public ResponseEntity<String> serveMultipartEndpoint(
@PathVariable("endpoint") String endpointName,
@PathVariable("method") String methodName,
Expand Down Expand Up @@ -252,33 +252,34 @@ private ResponseEntity<String> doServeEndpoint(String endpointName,
}

if (isMultipartRequest(request)) {
var multipartRequest = (MultipartHttpServletRequest) request;
var bodyPart = multipartRequest.getParameter(BODY_PART_NAME);

if (bodyPart == null) {
return ResponseEntity.badRequest()
.body(endpointInvoker.createResponseErrorObject(
"Missing body part in multipart request"));
}

try {
var multipartRequest = (MultipartHttpServletRequest) request;
var fileMap = multipartRequest.getFileMap();
var bodyPart = multipartRequest.getParts().stream()
.filter(part -> part.getSubmittedFileName() == null)
.filter(part -> MediaType.APPLICATION_JSON_VALUE
.equals(part.getContentType()))
.findAny().orElseThrow();

body = objectMapper.readValue(bodyPart.getInputStream(),
ObjectNode.class);

for (var entry : fileMap.entrySet()) {
var partName = entry.getKey();
var file = entry.getValue();
var pointer = JsonPointer.valueOf(partName);
var parent = pointer.head();
var property = pointer.last().getMatchingProperty();
var parentObject = body.withObject(parent);
parentObject.putPOJO(property, file);
}
} catch (IOException | ServletException e) {
LOGGER.error("Error processing multipart request parts", e);
body = objectMapper.readValue(bodyPart, ObjectNode.class);
} catch (IOException e) {
LOGGER.error("Request body does not contain valid JSON", e);
return ResponseEntity
.status(HttpStatus.INTERNAL_SERVER_ERROR)
.body(endpointInvoker.createResponseErrorObject(
"Error processing multipart request parts"));
"Request body does not contain valid JSON"));
}

var fileMap = multipartRequest.getFileMap();
for (var entry : fileMap.entrySet()) {
var partName = entry.getKey();
var file = entry.getValue();
var pointer = JsonPointer.valueOf(partName);
var parent = pointer.head();
var property = pointer.last().getMatchingProperty();
var parentObject = body.withObject(parent);
parentObject.putPOJO(property, file);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
import com.vaadin.hilla.exception.EndpointValidationException;
import com.vaadin.hilla.exception.EndpointValidationException.ValidationErrorData;
import com.vaadin.hilla.parser.jackson.JacksonObjectMapperFactory;

import jakarta.servlet.ServletContext;
import jakarta.validation.ConstraintViolation;
import jakarta.validation.Validation;
import jakarta.validation.Validator;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
Expand All @@ -46,10 +48,12 @@
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.Type;
import java.security.Principal;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
Expand All @@ -61,6 +65,8 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import io.swagger.v3.oas.annotations.links.Link;

/**
* Handles invocation of endpoint methods after checking the user has proper
* access.
Expand Down Expand Up @@ -303,13 +309,30 @@ private Method getMethod(String endpointName, String methodName) {
return endpointData.getMethod(methodName).orElse(null);
}

private Map<String, JsonNode> getRequestParameters(ObjectNode body) {
private Map<String, JsonNode> getRequestParameters(ObjectNode body,
List<String> parameterNames) {
Map<String, JsonNode> parametersData = new LinkedHashMap<>();
if (body != null) {
body.fields().forEachRemaining(entry -> parametersData
.put(entry.getKey(), entry.getValue()));
}
return parametersData;

// restore the order of parameters
var orderedData = new LinkedHashMap<String, JsonNode>();

for (String parameterName : parameterNames) {
JsonNode parameterData = parametersData.get(parameterName);
if (parameterData != null) {
parametersData.remove(parameterName);
orderedData.put(parameterName, parameterData);
} else {
getLogger().debug("Parameter '{}' not found in request body",
parameterName);
}
}

orderedData.putAll(parametersData);
return orderedData;
}

private Object[] getVaadinEndpointParameters(
Expand Down Expand Up @@ -404,7 +427,10 @@ private Object invokeVaadinEndpointMethod(String endpointName,
endpointName, methodName, checkError));
}

Map<String, JsonNode> requestParameters = getRequestParameters(body);
var parameterNames = Arrays.stream(methodToInvoke.getParameters())
.map(Parameter::getName).toList();
Map<String, JsonNode> requestParameters = getRequestParameters(body,
parameterNames);
Type[] javaParameters = getJavaParameters(methodToInvoke, ClassUtils
.getUserClass(vaadinEndpointData.getEndpointObject()));
if (javaParameters.length != requestParameters.size()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.vaadin.hilla;

import java.io.ByteArrayInputStream;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
Expand Down Expand Up @@ -28,6 +30,7 @@
import java.util.stream.Stream;

import com.vaadin.hilla.engine.EngineConfiguration;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
Expand All @@ -44,6 +47,8 @@
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartHttpServletRequest;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.JsonNode;
Expand Down Expand Up @@ -75,6 +80,7 @@
import jakarta.annotation.security.PermitAll;
import jakarta.annotation.security.RolesAllowed;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.validation.constraints.Min;
Expand Down Expand Up @@ -158,6 +164,14 @@ public String getUserName() {
return VaadinService.getCurrentRequest().getUserPrincipal()
.getName();
}

@AnonymousAllowed
public String checkFileLength(MultipartFile fileToCheck,
long expectedLength) {
return fileToCheck.getSize() == expectedLength
? "Check file length OK"
: "Check file length FAILED";
}
}

@Endpoint("CustomEndpoint")
Expand Down Expand Up @@ -445,6 +459,40 @@ public void should_NotCallMethod_When_DenyAll() {
.contains(EndpointAccessChecker.ACCESS_DENIED_MSG));
}

@Test
public void should_AcceptMultipartFile()
throws IOException, ServletException {
var request = mock(MultipartHttpServletRequest.class);
when(request.getUserPrincipal()).thenReturn(mock(Principal.class));
when(request.getHeader("X-CSRF-Token")).thenReturn("Vaadin Fusion");
when(request.getContentType()).thenReturn("multipart/form-data");

// hilla request body
when(request.getParameter(EndpointController.BODY_PART_NAME))
.thenReturn("{\"expectedLength\":5}");

// uploaded file
var file = mock(MultipartFile.class);
when(request.getFileMap())
.thenReturn(Collections.singletonMap("/fileToCheck", file));
when(file.getOriginalFilename()).thenReturn("hello.txt");
when(file.getSize()).thenReturn(5L);
when(file.getInputStream())
.thenReturn(new ByteArrayInputStream("Hello".getBytes()));

ServletContext servletContext = mockServletContext();
when(request.getServletContext()).thenReturn(servletContext);
when(request.getCookies()).thenReturn(new Cookie[] {
new Cookie(ApplicationConstants.CSRF_TOKEN, "Vaadin Fusion") });

var vaadinController = createVaadinController(TEST_ENDPOINT);
var response = vaadinController.serveMultipartEndpoint(
TEST_ENDPOINT_NAME, "checkFileLength", request, null);

assertEquals(HttpStatus.OK, response.getStatusCode());
assertTrue(response.getBody().contains("Check file length OK"));
}

@Test
@Ignore("FIXME: this test is flaky, it fails when executed fast enough")
public void should_bePossibeToGetPrincipalInEndpoint() {
Expand Down

0 comments on commit 7221848

Please sign in to comment.