Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[controller] Log remote address in controller audit log #1507

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ public class AuditInfo {
private String url;
private Map<String, String> params;
private String method;
private String clientIp;

public AuditInfo(Request request) {
this.url = request.url();
Expand All @@ -18,49 +19,31 @@ public AuditInfo(Request request) {
this.params.put(param, request.queryParams(param));
}
this.method = request.requestMethod();
this.clientIp = request.ip() + ":" + request.raw().getRemotePort();
}

/**
* @return a string representation of {@link AuditInfo} object.
*/
@Override
public String toString() {
StringJoiner joiner = new StringJoiner(" ");
joiner.add("[AUDIT]");
joiner.add(method);
joiner.add(url);
joiner.add(params.toString());
return joiner.toString();
return formatAuditMessage("[AUDIT]", null);
}

/**
* @return a audit-successful string.
*/
public String successString() {
return toString(true, null);
return formatAuditMessage("[AUDIT]", "SUCCESS");
}

/**
* @return a audit-failure string.
*/
public String failureString(String errMsg) {
return toString(false, errMsg);
return formatAuditMessage("[AUDIT]", "FAILURE: " + (errMsg != null ? errMsg : ""));
}

private String toString(boolean success, String errMsg) {
StringJoiner joiner = new StringJoiner(" ");
joiner.add("[AUDIT]");
if (success) {
joiner.add("SUCCESS");
} else {
joiner.add("FAILURE: ");
if (errMsg != null) {
joiner.add(errMsg);
}
private String formatAuditMessage(String prefix, String status) {
StringJoiner joiner = new StringJoiner(" ").add(prefix);

if (status != null) {
joiner.add(status);
}
joiner.add(method);
joiner.add(url);
joiner.add(params.toString());

joiner.add(method).add(url).add(params.toString()).add("ClientIP: " + clientIp);

return joiner.toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package com.linkedin.venice.controller;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;

import java.util.HashSet;
import java.util.Set;
import javax.servlet.http.HttpServletRequest;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
import spark.Request;


public class AuditInfoTest {
private static final String TEST_URL = "http://localhost/test";
private static final String METHOD_GET = "GET";
private static final String CLIENT_IP = "127.0.0.1";
private static final int CLIENT_PORT = 8080;
private static final String PARAM_1 = "param1";
private static final String PARAM_2 = "param2";
private static final String VALUE_1 = "value1";
private static final String VALUE_2 = "value2";
private static final String AUDIT_PREFIX = "[AUDIT]";
private static final String SUCCESS = "SUCCESS";
private static final String FAILURE = "FAILURE";
private static final String ERROR_MESSAGE = "Some error";

private Request request;
private AuditInfo auditInfo;
private HttpServletRequest httpServletRequest;

@BeforeMethod
public void setUp() {
request = mock(Request.class);
when(request.url()).thenReturn(TEST_URL);
when(request.requestMethod()).thenReturn(METHOD_GET);
when(request.ip()).thenReturn(CLIENT_IP);

Set<String> queryParams = new HashSet<>();
queryParams.add(PARAM_1);
queryParams.add(PARAM_2);

when(request.queryParams()).thenReturn(queryParams);
when(request.queryParams(PARAM_1)).thenReturn(VALUE_1);
when(request.queryParams(PARAM_2)).thenReturn(VALUE_2);

httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getRemotePort()).thenReturn(CLIENT_PORT);
when(request.raw()).thenReturn(httpServletRequest);

auditInfo = new AuditInfo(request);
}

@Test
public void testToStringReturnsExpectedFormat() {
String result = auditInfo.toString();
assertTrue(result.contains(AUDIT_PREFIX));
assertTrue(result.contains(METHOD_GET));
assertTrue(result.contains(TEST_URL));
assertTrue(result.contains(PARAM_1 + "=" + VALUE_1));
assertTrue(result.contains(PARAM_2 + "=" + VALUE_2));
assertTrue(result.contains("ClientIP: " + CLIENT_IP + ":" + CLIENT_PORT));
}

@Test
public void testSuccessStringReturnsExpectedFormat() {
String result = auditInfo.successString();
assertTrue(result.contains(AUDIT_PREFIX));
assertTrue(result.contains(SUCCESS));
assertTrue(result.contains(METHOD_GET));
assertTrue(result.contains(TEST_URL));
assertTrue(result.contains("ClientIP: " + CLIENT_IP));
}

@Test
public void testFailureStringReturnsExpectedFormat() {
String result = auditInfo.failureString(ERROR_MESSAGE);
assertTrue(result.contains(AUDIT_PREFIX));
assertTrue(result.contains(FAILURE));
assertTrue(result.contains(ERROR_MESSAGE));
assertTrue(result.contains(METHOD_GET));
assertTrue(result.contains(TEST_URL));
assertTrue(result.contains("ClientIP: " + CLIENT_IP));
}

@Test
public void testFailureStringHandlesNullErrorMessage() {
String result = auditInfo.failureString(null);
assertTrue(result.contains(AUDIT_PREFIX));
assertFalse(result.contains("null"));
}
}
Loading