diff --git a/document-store/build.gradle.kts b/document-store/build.gradle.kts index cc0028a4..6b0be8e3 100644 --- a/document-store/build.gradle.kts +++ b/document-store/build.gradle.kts @@ -19,6 +19,7 @@ dependencies { implementation("org.apache.commons:commons-lang3:3.10") implementation("net.jodah:failsafe:2.4.0") implementation("com.google.guava:guava:31.1-jre") + testImplementation("org.junit.jupiter:junit-jupiter:5.8.2") testImplementation("org.mockito:mockito-core:4.4.0") testImplementation("org.mockito:mockito-junit-jupiter:4.4.0") diff --git a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/PostgresClient.java b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/PostgresClient.java new file mode 100644 index 00000000..d5c805a7 --- /dev/null +++ b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/PostgresClient.java @@ -0,0 +1,114 @@ +package org.hypertrace.core.documentstore.postgres; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class PostgresClient { + + private static final Logger log = LoggerFactory.getLogger(PostgresClient.class); + private static final int VALIDATION_QUERY_TIMEOUT_SECONDS = 5; + + private final String url; + private final String user; + private final String password; + private final int maxConnectionAttempts; + private final Duration connectionRetryBackoff; + + private int count = 0; + private Connection connection; + + public PostgresClient( + String url, + String user, + String password, + int maxConnectionAttempts, + Duration connectionRetryBackoff) { + this.url = url; + this.user = user; + this.password = password; + this.maxConnectionAttempts = maxConnectionAttempts; + this.connectionRetryBackoff = connectionRetryBackoff; + } + + public synchronized Connection getConnection() { + try { + if (connection == null) { + newConnection(); + } else if (!isConnectionValid(connection)) { + log.info("The database connection is invalid. Reconnecting..."); + close(); + newConnection(); + } + } catch (SQLException sqle) { + throw new RuntimeException(sqle); + } + return connection; + } + + private boolean isConnectionValid(Connection connection) { + try { + if (connection.getMetaData().getJDBCMajorVersion() >= 4) { + return connection.isValid(VALIDATION_QUERY_TIMEOUT_SECONDS); + } else { + try (PreparedStatement preparedStatement = connection.prepareStatement("SELECT 1"); + ResultSet resultSet = preparedStatement.executeQuery()) { + return true; + } + } + } catch (SQLException sqle) { + log.debug("Unable to check if the underlying connection is valid", sqle); + return false; + } + } + + private void newConnection() throws SQLException { + ++count; + int attempts = 0; + while (attempts < maxConnectionAttempts) { + try { + ++attempts; + log.info("Attempting(attempt #{}) to open connection #{} to {}", attempts, count, url); + connection = DriverManager.getConnection(url, user, password); + return; + } catch (SQLException sqle) { + attempts++; + if (attempts < maxConnectionAttempts) { + log.info( + "Unable to connect(#{}) to database on attempt {}/{}. Will retry in {} ms.", + count, + attempts, + maxConnectionAttempts, + connectionRetryBackoff, + sqle); + try { + TimeUnit.MILLISECONDS.sleep(connectionRetryBackoff.toMillis()); + } catch (InterruptedException e) { + // this is ok because just woke up early + } + } else { + throw sqle; + } + } + } + } + + private void close() { + if (connection != null) { + try { + log.info("Closing connection #{} to {}", count, url); + connection.close(); + } catch (SQLException sqle) { + log.warn("Ignoring error closing connection", sqle); + } finally { + connection = null; + } + } + } +} diff --git a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/PostgresCollection.java b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/PostgresCollection.java index be86284c..29970a25 100644 --- a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/PostgresCollection.java +++ b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/PostgresCollection.java @@ -9,7 +9,6 @@ import com.google.common.annotations.VisibleForTesting; import java.io.IOException; import java.sql.BatchUpdateException; -import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.ResultSetMetaData; @@ -63,10 +62,10 @@ public class PostgresCollection implements Collection { private static final ObjectMapper MAPPER = new ObjectMapper(); private static final CloseableIterator EMPTY_ITERATOR = createEmptyIterator(); - private final Connection client; + private final PostgresClient client; private final String collectionName; - public PostgresCollection(Connection client, String collectionName) { + public PostgresCollection(PostgresClient client, String collectionName) { this.client = client; this.collectionName = collectionName; } @@ -74,7 +73,7 @@ public PostgresCollection(Connection client, String collectionName) { @Override public boolean upsert(Key key, Document document) throws IOException { try (PreparedStatement preparedStatement = - client.prepareStatement(getUpsertSQL(), Statement.RETURN_GENERATED_KEYS)) { + client.getConnection().prepareStatement(getUpsertSQL(), Statement.RETURN_GENERATED_KEYS)) { String jsonString = prepareDocument(key, document); preparedStatement.setString(1, key.toString()); preparedStatement.setString(2, jsonString); @@ -127,7 +126,7 @@ public UpdateResult update(Key key, Document document, Filter condition) throws @Override public CreateResult create(Key key, Document document) throws IOException { try (PreparedStatement preparedStatement = - client.prepareStatement(getInsertSQL(), Statement.RETURN_GENERATED_KEYS)) { + client.getConnection().prepareStatement(getInsertSQL(), Statement.RETURN_GENERATED_KEYS)) { String jsonString = prepareDocument(key, document); preparedStatement.setString(1, key.toString()); preparedStatement.setString(2, jsonString); @@ -226,7 +225,7 @@ private boolean updateSubDocInternal(Key key, String subDocPath, Document subDoc String jsonString = subDocument.toJson(); try (PreparedStatement preparedStatement = - client.prepareStatement(updateSubDocSQL, Statement.RETURN_GENERATED_KEYS)) { + client.getConnection().prepareStatement(updateSubDocSQL, Statement.RETURN_GENERATED_KEYS)) { preparedStatement.setString(1, jsonSubDocPath); preparedStatement.setString(2, jsonString); preparedStatement.setString(3, key.toString()); @@ -272,7 +271,8 @@ private BulkUpdateSubDocsInternalResult bulkUpdateSubDocsInternal( "UPDATE %s SET %s=jsonb_set(%s, ?::text[], ?::jsonb) WHERE %s = ?", collectionName, DOCUMENT, DOCUMENT, ID); try { - PreparedStatement preparedStatement = client.prepareStatement(updateSubDocSQL); + PreparedStatement preparedStatement = + client.getConnection().prepareStatement(updateSubDocSQL); for (Key key : documents.keySet()) { orderList.add(key); Map subDocuments = documents.get(key); @@ -410,7 +410,7 @@ public long count(org.hypertrace.core.documentstore.query.Query query) { @Override public boolean delete(Key key) { String deleteSQL = String.format("DELETE FROM %s WHERE %s = ?", collectionName, ID); - try (PreparedStatement preparedStatement = client.prepareStatement(deleteSQL)) { + try (PreparedStatement preparedStatement = client.getConnection().prepareStatement(deleteSQL)) { preparedStatement.setString(1, key.toString()); preparedStatement.executeUpdate(); return true; @@ -460,7 +460,7 @@ public BulkDeleteResult delete(Set keys) { .append(ids) .append(")") .toString(); - try (PreparedStatement preparedStatement = client.prepareStatement(deleteSQL)) { + try (PreparedStatement preparedStatement = client.getConnection().prepareStatement(deleteSQL)) { int deletedCount = preparedStatement.executeUpdate(); return new BulkDeleteResult(deletedCount); } catch (SQLException e) { @@ -477,7 +477,7 @@ public boolean deleteSubDoc(Key key, String subDocPath) { String jsonSubDocPath = getJsonSubDocPath(subDocPath); try (PreparedStatement preparedStatement = - client.prepareStatement(deleteSubDocSQL, Statement.RETURN_GENERATED_KEYS)) { + client.getConnection().prepareStatement(deleteSubDocSQL, Statement.RETURN_GENERATED_KEYS)) { preparedStatement.setString(1, jsonSubDocPath); preparedStatement.setString(2, key.toString()); int resultSet = preparedStatement.executeUpdate(); @@ -497,7 +497,7 @@ public boolean deleteSubDoc(Key key, String subDocPath) { @Override public boolean deleteAll() { String deleteSQL = String.format("DELETE FROM %s", collectionName); - try (PreparedStatement preparedStatement = client.prepareStatement(deleteSQL)) { + try (PreparedStatement preparedStatement = client.getConnection().prepareStatement(deleteSQL)) { preparedStatement.executeUpdate(); return true; } catch (SQLException e) { @@ -510,7 +510,7 @@ public boolean deleteAll() { public long count() { String countSQL = String.format("SELECT COUNT(*) FROM %s", collectionName); long count = -1; - try (PreparedStatement preparedStatement = client.prepareStatement(countSQL)) { + try (PreparedStatement preparedStatement = client.getConnection().prepareStatement(countSQL)) { ResultSet resultSet = preparedStatement.executeQuery(); while (resultSet.next()) { count = resultSet.getLong(1); @@ -596,7 +596,7 @@ public CloseableIterator bulkUpsertAndReturnOlderDocuments(Map bulkUpsertAndReturnOlderDocuments(Map documents) throws SQLException, IOException { try (PreparedStatement preparedStatement = - client.prepareStatement(getUpsertSQL(), Statement.RETURN_GENERATED_KEYS)) { + client.getConnection().prepareStatement(getUpsertSQL(), Statement.RETURN_GENERATED_KEYS)) { for (Map.Entry entry : documents.entrySet()) { Key key = entry.getKey(); @@ -902,7 +903,7 @@ private long bulkUpdateRequestsWithoutFilter(List requestsWit long totalRowsUpdated = 0; try { - PreparedStatement ps = client.prepareStatement(getUpdateSQL()); + PreparedStatement ps = client.getConnection().prepareStatement(getUpdateSQL()); for (BulkUpdateRequest req : requestsWithoutFilter) { Key key = req.getKey(); @@ -974,7 +975,7 @@ private long updateLastModifiedTime(Set keys) { long now = System.currentTimeMillis(); try { PreparedStatement preparedStatement = - client.prepareStatement(updateSubDocSQL, Statement.RETURN_GENERATED_KEYS); + client.getConnection().prepareStatement(updateSubDocSQL, Statement.RETURN_GENERATED_KEYS); for (Key key : keys) { preparedStatement.setString(1, String.valueOf(now)); preparedStatement.setString(2, key.toString()); diff --git a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/PostgresDatastore.java b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/PostgresDatastore.java index 79d54a83..94c078dd 100644 --- a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/PostgresDatastore.java +++ b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/PostgresDatastore.java @@ -12,6 +12,7 @@ import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; +import java.time.Duration; import java.util.HashSet; import java.util.Map; import java.util.Set; @@ -28,8 +29,10 @@ public class PostgresDatastore implements Datastore { private static final String DEFAULT_USER = "postgres"; private static final String DEFAULT_PASSWORD = "postgres"; private static final String DEFAULT_DB_NAME = "postgres"; + private static final int DEFAULT_MAX_CONNECTION_ATTEMPTS = 200; + private static final Duration DEFAULT_CONNECTION_RETRY_BACKOFF = Duration.ofSeconds(5); - private Connection client; + private PostgresClient client; private String database; @Override @@ -47,13 +50,22 @@ public boolean init(Config config) { url = String.format("jdbc:postgresql://%s:%s/", hostName, port); } - String DEFAULT_USER = "postgres"; String user = config.hasPath("user") ? config.getString("user") : DEFAULT_USER; String password = config.hasPath("password") ? config.getString("password") : DEFAULT_PASSWORD; + int maxConnectionAttempts = + config.hasPath("maxConnectionAttempts") + ? config.getInt("maxConnectionAttempts") + : DEFAULT_MAX_CONNECTION_ATTEMPTS; + Duration connectionRetryBackoff = + config.hasPath("connectionRetryBackoff") + ? config.getDuration("connectionRetryBackoff") + : DEFAULT_CONNECTION_RETRY_BACKOFF; String finalUrl = url + this.database; - client = DriverManager.getConnection(finalUrl, user, password); + client = + new PostgresClient( + finalUrl, user, password, maxConnectionAttempts, connectionRetryBackoff); } catch (IllegalArgumentException e) { throw new IllegalArgumentException( @@ -69,7 +81,7 @@ public boolean init(Config config) { public Set listCollections() { Set collections = new HashSet<>(); try { - DatabaseMetaData metaData = client.getMetaData(); + DatabaseMetaData metaData = client.getConnection().getMetaData(); ResultSet tables = metaData.getTables(null, null, "%", new String[] {"TABLE"}); while (tables.next()) { collections.add(database + "." + tables.getString("TABLE_NAME")); @@ -91,7 +103,8 @@ public boolean createCollection(String collectionName, Map optio + "%s TIMESTAMPTZ NOT NULL DEFAULT NOW()" + ");", collectionName, ID, DOCUMENT, CREATED_AT, UPDATED_AT); - try (PreparedStatement preparedStatement = client.prepareStatement(createTableSQL)) { + try (PreparedStatement preparedStatement = + client.getConnection().prepareStatement(createTableSQL)) { preparedStatement.executeUpdate(); } catch (SQLException e) { LOGGER.error("Exception creating table name: {}", collectionName); @@ -103,7 +116,8 @@ public boolean createCollection(String collectionName, Map optio @Override public boolean deleteCollection(String collectionName) { String dropTableSQL = String.format("DROP TABLE IF EXISTS %s", collectionName); - try (PreparedStatement preparedStatement = client.prepareStatement(dropTableSQL)) { + try (PreparedStatement preparedStatement = + client.getConnection().prepareStatement(dropTableSQL)) { int result = preparedStatement.executeUpdate(); return result >= 0; } catch (SQLException e) { @@ -124,7 +138,8 @@ public Collection getCollection(String collectionName) { @Override public boolean healthCheck() { String healtchCheckSQL = "SELECT 1;"; - try (PreparedStatement preparedStatement = client.prepareStatement(healtchCheckSQL)) { + try (PreparedStatement preparedStatement = + client.getConnection().prepareStatement(healtchCheckSQL)) { return preparedStatement.execute(); } catch (SQLException e) { LOGGER.error("Exception executing health check"); @@ -133,6 +148,6 @@ public boolean healthCheck() { } public Connection getPostgresClient() { - return client; + return client.getConnection(); } }