Skip to content

Commit

Permalink
feat: create new connection on connection validation error in postgres (
Browse files Browse the repository at this point in the history
  • Loading branch information
saxenakshitiz authored Aug 25, 2022
1 parent e80f843 commit a42561f
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 26 deletions.
1 change: 1 addition & 0 deletions document-store/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies {
implementation("org.apache.commons:commons-lang3:3.10")
implementation("net.jodah:failsafe:2.4.0")
implementation("com.google.guava:guava:31.1-jre")

testImplementation("org.junit.jupiter:junit-jupiter:5.8.2")
testImplementation("org.mockito:mockito-core:4.4.0")
testImplementation("org.mockito:mockito-junit-jupiter:4.4.0")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package org.hypertrace.core.documentstore.postgres;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class PostgresClient {

private static final Logger log = LoggerFactory.getLogger(PostgresClient.class);
private static final int VALIDATION_QUERY_TIMEOUT_SECONDS = 5;

private final String url;
private final String user;
private final String password;
private final int maxConnectionAttempts;
private final Duration connectionRetryBackoff;

private int count = 0;
private Connection connection;

public PostgresClient(
String url,
String user,
String password,
int maxConnectionAttempts,
Duration connectionRetryBackoff) {
this.url = url;
this.user = user;
this.password = password;
this.maxConnectionAttempts = maxConnectionAttempts;
this.connectionRetryBackoff = connectionRetryBackoff;
}

public synchronized Connection getConnection() {
try {
if (connection == null) {
newConnection();
} else if (!isConnectionValid(connection)) {
log.info("The database connection is invalid. Reconnecting...");
close();
newConnection();
}
} catch (SQLException sqle) {
throw new RuntimeException(sqle);
}
return connection;
}

private boolean isConnectionValid(Connection connection) {
try {
if (connection.getMetaData().getJDBCMajorVersion() >= 4) {
return connection.isValid(VALIDATION_QUERY_TIMEOUT_SECONDS);
} else {
try (PreparedStatement preparedStatement = connection.prepareStatement("SELECT 1");
ResultSet resultSet = preparedStatement.executeQuery()) {
return true;
}
}
} catch (SQLException sqle) {
log.debug("Unable to check if the underlying connection is valid", sqle);
return false;
}
}

private void newConnection() throws SQLException {
++count;
int attempts = 0;
while (attempts < maxConnectionAttempts) {
try {
++attempts;
log.info("Attempting(attempt #{}) to open connection #{} to {}", attempts, count, url);
connection = DriverManager.getConnection(url, user, password);
return;
} catch (SQLException sqle) {
attempts++;
if (attempts < maxConnectionAttempts) {
log.info(
"Unable to connect(#{}) to database on attempt {}/{}. Will retry in {} ms.",
count,
attempts,
maxConnectionAttempts,
connectionRetryBackoff,
sqle);
try {
TimeUnit.MILLISECONDS.sleep(connectionRetryBackoff.toMillis());
} catch (InterruptedException e) {
// this is ok because just woke up early
}
} else {
throw sqle;
}
}
}
}

private void close() {
if (connection != null) {
try {
log.info("Closing connection #{} to {}", count, url);
connection.close();
} catch (SQLException sqle) {
log.warn("Ignoring error closing connection", sqle);
} finally {
connection = null;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.sql.BatchUpdateException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
Expand Down Expand Up @@ -63,18 +62,18 @@ public class PostgresCollection implements Collection {
private static final ObjectMapper MAPPER = new ObjectMapper();
private static final CloseableIterator<Document> EMPTY_ITERATOR = createEmptyIterator();

private final Connection client;
private final PostgresClient client;
private final String collectionName;

public PostgresCollection(Connection client, String collectionName) {
public PostgresCollection(PostgresClient client, String collectionName) {
this.client = client;
this.collectionName = collectionName;
}

@Override
public boolean upsert(Key key, Document document) throws IOException {
try (PreparedStatement preparedStatement =
client.prepareStatement(getUpsertSQL(), Statement.RETURN_GENERATED_KEYS)) {
client.getConnection().prepareStatement(getUpsertSQL(), Statement.RETURN_GENERATED_KEYS)) {
String jsonString = prepareDocument(key, document);
preparedStatement.setString(1, key.toString());
preparedStatement.setString(2, jsonString);
Expand Down Expand Up @@ -127,7 +126,7 @@ public UpdateResult update(Key key, Document document, Filter condition) throws
@Override
public CreateResult create(Key key, Document document) throws IOException {
try (PreparedStatement preparedStatement =
client.prepareStatement(getInsertSQL(), Statement.RETURN_GENERATED_KEYS)) {
client.getConnection().prepareStatement(getInsertSQL(), Statement.RETURN_GENERATED_KEYS)) {
String jsonString = prepareDocument(key, document);
preparedStatement.setString(1, key.toString());
preparedStatement.setString(2, jsonString);
Expand Down Expand Up @@ -226,7 +225,7 @@ private boolean updateSubDocInternal(Key key, String subDocPath, Document subDoc
String jsonString = subDocument.toJson();

try (PreparedStatement preparedStatement =
client.prepareStatement(updateSubDocSQL, Statement.RETURN_GENERATED_KEYS)) {
client.getConnection().prepareStatement(updateSubDocSQL, Statement.RETURN_GENERATED_KEYS)) {
preparedStatement.setString(1, jsonSubDocPath);
preparedStatement.setString(2, jsonString);
preparedStatement.setString(3, key.toString());
Expand Down Expand Up @@ -272,7 +271,8 @@ private BulkUpdateSubDocsInternalResult bulkUpdateSubDocsInternal(
"UPDATE %s SET %s=jsonb_set(%s, ?::text[], ?::jsonb) WHERE %s = ?",
collectionName, DOCUMENT, DOCUMENT, ID);
try {
PreparedStatement preparedStatement = client.prepareStatement(updateSubDocSQL);
PreparedStatement preparedStatement =
client.getConnection().prepareStatement(updateSubDocSQL);
for (Key key : documents.keySet()) {
orderList.add(key);
Map<String, Document> subDocuments = documents.get(key);
Expand Down Expand Up @@ -410,7 +410,7 @@ public long count(org.hypertrace.core.documentstore.query.Query query) {
@Override
public boolean delete(Key key) {
String deleteSQL = String.format("DELETE FROM %s WHERE %s = ?", collectionName, ID);
try (PreparedStatement preparedStatement = client.prepareStatement(deleteSQL)) {
try (PreparedStatement preparedStatement = client.getConnection().prepareStatement(deleteSQL)) {
preparedStatement.setString(1, key.toString());
preparedStatement.executeUpdate();
return true;
Expand Down Expand Up @@ -460,7 +460,7 @@ public BulkDeleteResult delete(Set<Key> keys) {
.append(ids)
.append(")")
.toString();
try (PreparedStatement preparedStatement = client.prepareStatement(deleteSQL)) {
try (PreparedStatement preparedStatement = client.getConnection().prepareStatement(deleteSQL)) {
int deletedCount = preparedStatement.executeUpdate();
return new BulkDeleteResult(deletedCount);
} catch (SQLException e) {
Expand All @@ -477,7 +477,7 @@ public boolean deleteSubDoc(Key key, String subDocPath) {
String jsonSubDocPath = getJsonSubDocPath(subDocPath);

try (PreparedStatement preparedStatement =
client.prepareStatement(deleteSubDocSQL, Statement.RETURN_GENERATED_KEYS)) {
client.getConnection().prepareStatement(deleteSubDocSQL, Statement.RETURN_GENERATED_KEYS)) {
preparedStatement.setString(1, jsonSubDocPath);
preparedStatement.setString(2, key.toString());
int resultSet = preparedStatement.executeUpdate();
Expand All @@ -497,7 +497,7 @@ public boolean deleteSubDoc(Key key, String subDocPath) {
@Override
public boolean deleteAll() {
String deleteSQL = String.format("DELETE FROM %s", collectionName);
try (PreparedStatement preparedStatement = client.prepareStatement(deleteSQL)) {
try (PreparedStatement preparedStatement = client.getConnection().prepareStatement(deleteSQL)) {
preparedStatement.executeUpdate();
return true;
} catch (SQLException e) {
Expand All @@ -510,7 +510,7 @@ public boolean deleteAll() {
public long count() {
String countSQL = String.format("SELECT COUNT(*) FROM %s", collectionName);
long count = -1;
try (PreparedStatement preparedStatement = client.prepareStatement(countSQL)) {
try (PreparedStatement preparedStatement = client.getConnection().prepareStatement(countSQL)) {
ResultSet resultSet = preparedStatement.executeQuery();
while (resultSet.next()) {
count = resultSet.getLong(1);
Expand Down Expand Up @@ -596,7 +596,7 @@ public CloseableIterator<Document> bulkUpsertAndReturnOlderDocuments(Map<Key, Do
.append(")")
.toString();

PreparedStatement preparedStatement = client.prepareStatement(query);
PreparedStatement preparedStatement = client.getConnection().prepareStatement(query);
ResultSet resultSet = preparedStatement.executeQuery();

// Now go ahead and bulk upsert the documents.
Expand All @@ -618,7 +618,7 @@ public CloseableIterator<Document> bulkUpsertAndReturnOlderDocuments(Map<Key, Do
@VisibleForTesting
protected PreparedStatement buildPreparedStatement(String sqlQuery, Params params)
throws SQLException, RuntimeException {
PreparedStatement preparedStatement = client.prepareStatement(sqlQuery);
PreparedStatement preparedStatement = client.getConnection().prepareStatement(sqlQuery);
enrichPreparedStatementWithParams(preparedStatement, params);
return preparedStatement;
}
Expand All @@ -645,7 +645,8 @@ protected void enrichPreparedStatementWithParams(
@Override
public void drop() {
String dropTableSQL = String.format("DROP TABLE IF EXISTS %s", collectionName);
try (PreparedStatement preparedStatement = client.prepareStatement(dropTableSQL)) {
try (PreparedStatement preparedStatement =
client.getConnection().prepareStatement(dropTableSQL)) {
preparedStatement.executeUpdate();
} catch (SQLException e) {
LOGGER.error("Exception deleting table name: {}", collectionName);
Expand Down Expand Up @@ -823,7 +824,7 @@ private String getJsonSubDocPath(String subDocPath) {

private int[] bulkUpsertImpl(Map<Key, Document> documents) throws SQLException, IOException {
try (PreparedStatement preparedStatement =
client.prepareStatement(getUpsertSQL(), Statement.RETURN_GENERATED_KEYS)) {
client.getConnection().prepareStatement(getUpsertSQL(), Statement.RETURN_GENERATED_KEYS)) {
for (Map.Entry<Key, Document> entry : documents.entrySet()) {

Key key = entry.getKey();
Expand Down Expand Up @@ -902,7 +903,7 @@ private long bulkUpdateRequestsWithoutFilter(List<BulkUpdateRequest> requestsWit
long totalRowsUpdated = 0;
try {

PreparedStatement ps = client.prepareStatement(getUpdateSQL());
PreparedStatement ps = client.getConnection().prepareStatement(getUpdateSQL());

for (BulkUpdateRequest req : requestsWithoutFilter) {
Key key = req.getKey();
Expand Down Expand Up @@ -974,7 +975,7 @@ private long updateLastModifiedTime(Set<Key> keys) {
long now = System.currentTimeMillis();
try {
PreparedStatement preparedStatement =
client.prepareStatement(updateSubDocSQL, Statement.RETURN_GENERATED_KEYS);
client.getConnection().prepareStatement(updateSubDocSQL, Statement.RETURN_GENERATED_KEYS);
for (Key key : keys) {
preparedStatement.setString(1, String.valueOf(now));
preparedStatement.setString(2, key.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.time.Duration;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
Expand All @@ -28,8 +29,10 @@ public class PostgresDatastore implements Datastore {
private static final String DEFAULT_USER = "postgres";
private static final String DEFAULT_PASSWORD = "postgres";
private static final String DEFAULT_DB_NAME = "postgres";
private static final int DEFAULT_MAX_CONNECTION_ATTEMPTS = 200;
private static final Duration DEFAULT_CONNECTION_RETRY_BACKOFF = Duration.ofSeconds(5);

private Connection client;
private PostgresClient client;
private String database;

@Override
Expand All @@ -47,13 +50,22 @@ public boolean init(Config config) {
url = String.format("jdbc:postgresql://%s:%s/", hostName, port);
}

String DEFAULT_USER = "postgres";
String user = config.hasPath("user") ? config.getString("user") : DEFAULT_USER;
String password =
config.hasPath("password") ? config.getString("password") : DEFAULT_PASSWORD;
int maxConnectionAttempts =
config.hasPath("maxConnectionAttempts")
? config.getInt("maxConnectionAttempts")
: DEFAULT_MAX_CONNECTION_ATTEMPTS;
Duration connectionRetryBackoff =
config.hasPath("connectionRetryBackoff")
? config.getDuration("connectionRetryBackoff")
: DEFAULT_CONNECTION_RETRY_BACKOFF;

String finalUrl = url + this.database;
client = DriverManager.getConnection(finalUrl, user, password);
client =
new PostgresClient(
finalUrl, user, password, maxConnectionAttempts, connectionRetryBackoff);

} catch (IllegalArgumentException e) {
throw new IllegalArgumentException(
Expand All @@ -69,7 +81,7 @@ public boolean init(Config config) {
public Set<String> listCollections() {
Set<String> collections = new HashSet<>();
try {
DatabaseMetaData metaData = client.getMetaData();
DatabaseMetaData metaData = client.getConnection().getMetaData();
ResultSet tables = metaData.getTables(null, null, "%", new String[] {"TABLE"});
while (tables.next()) {
collections.add(database + "." + tables.getString("TABLE_NAME"));
Expand All @@ -91,7 +103,8 @@ public boolean createCollection(String collectionName, Map<String, String> optio
+ "%s TIMESTAMPTZ NOT NULL DEFAULT NOW()"
+ ");",
collectionName, ID, DOCUMENT, CREATED_AT, UPDATED_AT);
try (PreparedStatement preparedStatement = client.prepareStatement(createTableSQL)) {
try (PreparedStatement preparedStatement =
client.getConnection().prepareStatement(createTableSQL)) {
preparedStatement.executeUpdate();
} catch (SQLException e) {
LOGGER.error("Exception creating table name: {}", collectionName);
Expand All @@ -103,7 +116,8 @@ public boolean createCollection(String collectionName, Map<String, String> optio
@Override
public boolean deleteCollection(String collectionName) {
String dropTableSQL = String.format("DROP TABLE IF EXISTS %s", collectionName);
try (PreparedStatement preparedStatement = client.prepareStatement(dropTableSQL)) {
try (PreparedStatement preparedStatement =
client.getConnection().prepareStatement(dropTableSQL)) {
int result = preparedStatement.executeUpdate();
return result >= 0;
} catch (SQLException e) {
Expand All @@ -124,7 +138,8 @@ public Collection getCollection(String collectionName) {
@Override
public boolean healthCheck() {
String healtchCheckSQL = "SELECT 1;";
try (PreparedStatement preparedStatement = client.prepareStatement(healtchCheckSQL)) {
try (PreparedStatement preparedStatement =
client.getConnection().prepareStatement(healtchCheckSQL)) {
return preparedStatement.execute();
} catch (SQLException e) {
LOGGER.error("Exception executing health check");
Expand All @@ -133,6 +148,6 @@ public boolean healthCheck() {
}

public Connection getPostgresClient() {
return client;
return client.getConnection();
}
}

0 comments on commit a42561f

Please sign in to comment.