Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add getColumnDataTypes method to SchemaManager to get datatype for table columns #5135

Merged
merged 3 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class SchemaManager {
Expand All @@ -24,6 +27,7 @@ public class SchemaManager {
static final String BINLOG_POSITION = "Position";
static final int NUM_OF_RETRIES = 3;
static final int BACKOFF_IN_MILLIS = 500;
static final String TYPE_NAME = "TYPE_NAME";
private final ConnectionManager connectionManager;

public SchemaManager(ConnectionManager connectionManager) {
Expand All @@ -35,11 +39,12 @@ public List<String> getPrimaryKeys(final String database, final String table) {
while (retry <= NUM_OF_RETRIES) {
final List<String> primaryKeys = new ArrayList<>();
try (final Connection connection = connectionManager.getConnection()) {
final ResultSet rs = connection.getMetaData().getPrimaryKeys(database, null, table);
while (rs.next()) {
primaryKeys.add(rs.getString(COLUMN_NAME));
try (final ResultSet rs = connection.getMetaData().getPrimaryKeys(database, null, table)) {
while (rs.next()) {
primaryKeys.add(rs.getString(COLUMN_NAME));
}
return primaryKeys;
}
return primaryKeys;
} catch (Exception e) {
LOG.error("Failed to get primary keys for table {}, retrying", table, e);
}
Expand All @@ -50,6 +55,33 @@ public List<String> getPrimaryKeys(final String database, final String table) {
return List.of();
}

/**
 * Looks up the data type name of every column reported for a table.
 *
 * <p>The lookup is attempted up to {@code NUM_OF_RETRIES + 1} times. The first attempt that reads
 * the metadata without error returns immediately; a failed attempt is logged and, unless it was
 * the last one, followed by {@code applyBackoff()} before the next attempt.
 *
 * @param database  passed as the catalog argument of {@link DatabaseMetaData#getColumns}
 * @param tableName passed as the table-name argument of {@link DatabaseMetaData#getColumns}
 *                  (NOTE(review): JDBC treats this argument as a pattern, so names containing
 *                  {@code _} or {@code %} may match more than one table - confirm against callers)
 * @return a mutable map from column name ({@code COLUMN_NAME}) to type name ({@code TYPE_NAME});
 *         empty if the metadata query yields no rows
 * @throws RuntimeException if the final attempt also fails; the last failure is kept as the cause
 */
public Map<String, String> getColumnDataTypes(final String database, final String tableName) {
    for (int retry = 0; retry <= NUM_OF_RETRIES; retry++) {
        // A fresh map per attempt: an attempt that fails halfway through the result set must not
        // leak partially collected (possibly stale) entries into the result of a later attempt.
        final Map<String, String> columnsToDataType = new HashMap<>();
        try (Connection connection = connectionManager.getConnection()) {
            final DatabaseMetaData metaData = connection.getMetaData();

            // Retrieve column metadata
            try (ResultSet columns = metaData.getColumns(database, null, tableName, null)) {
                while (columns.next()) {
                    columnsToDataType.put(
                            columns.getString(COLUMN_NAME),
                            columns.getString(TYPE_NAME)
                    );
                }
            }
            // Success: stop here instead of backing off and re-running the query on every
            // remaining loop iteration.
            return columnsToDataType;
        } catch (final Exception e) {
            LOG.error("Failed to get dataTypes for database {} table {}, retrying", database, tableName, e);
            if (retry == NUM_OF_RETRIES) {
                throw new RuntimeException(String.format("Failed to get dataTypes for database %s table %s after " +
                        "%d retries", database, tableName, retry), e);
            }
        }
        // Only reached after a failed, non-final attempt.
        applyBackoff();
    }
    // Reachable only if the loop body never runs (NUM_OF_RETRIES < 0); keep the empty-map result.
    return new HashMap<>();
}

public Optional<BinlogCoordinate> getCurrentBinaryLogPosition() {
int retry = 0;
while (retry <= NUM_OF_RETRIES) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,29 @@
import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.BINLOG_FILE;
import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.BINLOG_POSITION;
import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.BINLOG_STATUS_QUERY;
import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.COLUMN_NAME;
import static org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager.TYPE_NAME;

@ExtendWith(MockitoExtension.class)
class SchemaManagerTest {
Expand All @@ -41,6 +47,9 @@ class SchemaManagerTest {
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private Connection connection;

@Mock
private DatabaseMetaData databaseMetaData;

@Mock
private ResultSet resultSet;

Expand Down Expand Up @@ -105,6 +114,59 @@ void test_getCurrentBinaryLogPosition_throws_exception_then_returns_empty() thro
assertThat(binlogCoordinate.isPresent(), is(false));
}

/**
 * Stubs {@code DatabaseMetaData.getColumns} to throw a {@link SQLException} and expects
 * {@code getColumnDataTypes} to raise a {@link RuntimeException}.
 */
@Test
public void getColumnDataTypes_whenFailedToRetrieveColumns_shouldThrowException() throws SQLException {
    final String randomDatabase = UUID.randomUUID().toString();
    final String randomTable = UUID.randomUUID().toString();
    final SQLException columnLookupFailure = new SQLException("Test exception");

    // Obtaining the connection and its metadata works; only the column query fails.
    when(connectionManager.getConnection()).thenReturn(connection);
    when(connection.getMetaData()).thenReturn(databaseMetaData);
    when(databaseMetaData.getColumns(randomDatabase, null, randomTable, null))
            .thenThrow(columnLookupFailure);

    assertThrows(
            RuntimeException.class,
            () -> schemaManager.getColumnDataTypes(randomDatabase, randomTable));
}

/**
 * Stubs {@code ConnectionManager.getConnection} to throw a {@link SQLException} and expects
 * {@code getColumnDataTypes} to raise a {@link RuntimeException}.
 */
@Test
public void getColumnDataTypes_whenFailedToGetConnection_shouldThrowException() throws SQLException {
    final String randomDatabase = UUID.randomUUID().toString();
    final String randomTable = UUID.randomUUID().toString();
    final SQLException connectionFailure = new SQLException("Connection failed");

    // No connection can be obtained at all.
    when(connectionManager.getConnection()).thenThrow(connectionFailure);

    assertThrows(
            RuntimeException.class,
            () -> schemaManager.getColumnDataTypes(randomDatabase, randomTable));
}

/**
 * Feeds three column rows through a mocked {@link ResultSet} and expects
 * {@code getColumnDataTypes} to return the matching column-name to type-name map.
 */
@Test
void getColumnDataTypes_whenColumnsExist_shouldReturnValidMapping() throws SQLException {
    final String randomDatabase = UUID.randomUUID().toString();
    final String randomTable = UUID.randomUUID().toString();
    final Map<String, String> expectedTypesByColumn =
            Map.of("id", "INTEGER", "name", "VARCHAR", "created_at", "TIMESTAMP");

    // Wire connection manager -> connection -> metadata -> result set.
    when(connectionManager.getConnection()).thenReturn(connection);
    when(connection.getMetaData()).thenReturn(databaseMetaData);
    when(databaseMetaData.getColumns(randomDatabase, null, randomTable, null)).thenReturn(resultSet);

    // Three rows are available, then the cursor is exhausted.
    when(resultSet.next()).thenReturn(true, true, true, false);
    when(resultSet.getString(COLUMN_NAME)).thenReturn("id", "name", "created_at");
    when(resultSet.getString(TYPE_NAME)).thenReturn("INTEGER", "VARCHAR", "TIMESTAMP");

    final Map<String, String> actualTypesByColumn =
            schemaManager.getColumnDataTypes(randomDatabase, randomTable);

    assertThat(actualTypesByColumn, notNullValue());
    assertThat(actualTypesByColumn.size(), is(expectedTypesByColumn.size()));
    assertThat(actualTypesByColumn, equalTo(expectedTypesByColumn));
}

/**
 * Builds a {@code SchemaManager} wired to this test's {@code connectionManager}.
 *
 * @return a new {@code SchemaManager} instance
 */
private SchemaManager createObjectUnderTest() {
    final SchemaManager objectUnderTest = new SchemaManager(connectionManager);
    return objectUnderTest;
}
Expand Down
Loading