Commit 3086089: polish
FANNG1 committed Mar 18, 2024
1 parent 2ab6749 commit 3086089
Showing 4 changed files with 42 additions and 45 deletions.
@@ -423,14 +423,13 @@ private void checkTableColumns(
.check(tableInfo);
}

// Datasource partition column is in schema when creating table
@Test
public void testCreateDatasourceFormatPartitionTable() {
String tableName = "datasource_partition_table";

dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL = createTableSQL + " USING PARQUET PARTITIONED BY (name, age)";
createTableSQL = createTableSQL + "USING PARQUET PARTITIONED BY (name, age)";
sql(createTableSQL);
SparkTableInfo tableInfo = getTableInfo(tableName);
SparkTableInfoChecker checker =
@@ -449,7 +448,7 @@ public void testCreateHiveFormatPartitionTable() {

dropTableIfExists(tableName);
String createTableSQL = getCreateSimpleTableString(tableName);
createTableSQL = createTableSQL + "PARTITIONED BY (age_p1 int, age_p2 STRING)";
createTableSQL = createTableSQL + "PARTITIONED BY (age_p1 INT, age_p2 STRING)";
sql(createTableSQL);

List<SparkColumnInfo> columns = new ArrayList<>(getSimpleTableColumn());
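
The two tests above cover Spark's two partitioning flavors. A minimal standalone sketch of the difference (table and column names are illustrative, not from the repo):

// Hedged sketch, not repo code: datasource (USING) tables pick partition
// columns out of the declared schema, while Hive-format tables declare
// brand-new, typed partition columns in the PARTITIONED BY clause.
public class PartitionDdlSketch {
  public static void main(String[] args) {
    // Datasource format: name and age must already exist in the schema.
    String datasourceDdl =
        "CREATE TABLE datasource_partition_table (id INT, name STRING, age INT) "
            + "USING PARQUET PARTITIONED BY (name, age)";
    // Hive format: age_p1 and age_p2 are new columns appended to the schema.
    String hiveDdl =
        "CREATE TABLE hive_partition_table (id INT, name STRING, age INT) "
            + "PARTITIONED BY (age_p1 INT, age_p2 STRING)";
    System.out.println(datasourceDdl);
    System.out.println(hiveDdl);
  }
}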
@@ -69,6 +69,8 @@ void addPartition(Transform partition) {
this.partitions.add(partition);
if (partition instanceof IdentityTransform) {
partitionColumnNames.add(((IdentityTransform) partition).reference().fieldNames()[0]);
+    } else {
+      throw new NotSupportedException(partition.name() + " is not supported yet.");
}
}
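
Only identity transforms yield a partition column name in the branch above; anything else now fails fast. A minimal standalone sketch of the same logic against Spark's connector expressions API (the class name and the JDK exception are stand-ins):

import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.IdentityTransform;
import org.apache.spark.sql.connector.expressions.Transform;

public class IdentityPartitionSketch {
  static String partitionColumnName(Transform partition) {
    if (partition instanceof IdentityTransform) {
      // fieldNames() holds the field path; [0] keeps the top-level column name.
      return ((IdentityTransform) partition).reference().fieldNames()[0];
    }
    // Mirrors the new else branch above, with a JDK exception swapped in.
    throw new UnsupportedOperationException(partition.name() + " is not supported yet.");
  }

  public static void main(String[] args) {
    System.out.println(partitionColumnName(Expressions.identity("name"))); // prints "name"
  }
}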

@@ -31,8 +31,8 @@
/**
* SparkTransformConverter translate between Spark transform and Gravitino partition, distribution,
* sort orders. There may be multi partition transforms, but should be only one bucket transform.
- * Spark bucket transform corresponding Gravitino Hash distribution without sort orders. Spark
- * sorted bucket transform corresponding Gravitino Hash distribution with sort orders.
+ * Spark bucket transform is corresponding to Gravitino Hash distribution without sort orders. Spark
+ * sorted bucket transform is corresponding to Gravitino Hash distribution with sort orders.
*/
public class SparkTransformConverter {
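
A hedged usage sketch of the mapping this javadoc describes; the bundle getters appear in the tests later in this commit, and the identity-to-partition behavior is assumed from those tests:

// Sketch only: one identity partition plus one bucket transform.
org.apache.spark.sql.connector.expressions.Transform[] sparkTransforms =
    new org.apache.spark.sql.connector.expressions.Transform[] {
      org.apache.spark.sql.connector.expressions.Expressions.identity("dt"),
      org.apache.spark.sql.connector.expressions.Expressions.bucket(16, "id")
    };
GravitinoTransformBundles bundles =
    SparkTransformConverter.toGravitinoTransform(sparkTransforms);
// Per the javadoc: a plain bucket transform maps to a Hash distribution
// with no sort orders.
// bundles.getPartitions()   -> identity partition on "dt"
// bundles.getDistribution() -> hash distribution, 16 buckets, on "id"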

@@ -86,8 +86,7 @@ public static GravitinoTransformBundles toGravitinoTransform(
Distribution distribution = toGravitinoDistribution(bucketTransform);
bundles.setDistribution(distribution);
} else {
-              throw new NotSupportedException(
-                  "Not support Spark transform, class name: " + transform.name());
+              throw new NotSupportedException("Not support Spark transform: " + transform.name());
}
});

@@ -138,14 +137,16 @@ private static Distribution toGravitinoDistribution(BucketTransform bucketTransf
return Distributions.hash(bucketNum, expressions);
}

+  // Spark datasourceV2 doesn't support specify sort order direction, use ASCENDING as default.
private static Pair<Distribution, SortOrder[]> toGravitinoDistributionAndSortOrders(
SortedBucketTransform sortedBucketTransform) {
int bucketNum = (Integer) sortedBucketTransform.numBuckets().value();
Expression[] bucketColumns =
-        transToGravitinoExpression(JavaConverters.seqAsJavaList(sortedBucketTransform.columns()));
+        transToGravitinoNamedReference(
+            JavaConverters.seqAsJavaList(sortedBucketTransform.columns()));

Expression[] sortColumns =
-        transToGravitinoExpression(
+        transToGravitinoNamedReference(
JavaConverters.seqAsJavaList(sortedBucketTransform.sortedColumns()));
SortOrder[] sortOrders =
Arrays.stream(sortColumns)
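
To illustrate the new comment: a sorted bucket transform names its sort columns but carries no direction, so the converter has to assume one. A minimal sketch using the same Spark API the tests in this commit rely on:

import org.apache.spark.sql.connector.expressions.FieldReference;
import org.apache.spark.sql.connector.expressions.LogicalExpressions;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.SortedBucketTransform;

public class SortedBucketSketch {
  public static void main(String[] args) {
    // Bucket columns and sort columns are bare references, with no direction
    // anywhere, so every resulting Gravitino SortOrder defaults to ASCENDING.
    SortedBucketTransform sortedBucket =
        LogicalExpressions.bucket(
            16,
            new NamedReference[] {FieldReference.apply("id")},   // bucket columns
            new NamedReference[] {FieldReference.apply("ts")});  // sort columns
    System.out.println(sortedBucket.describe());
  }
}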
@@ -171,36 +172,31 @@ private static org.apache.spark.sql.connector.expressions.Transform toSparkBucke
String[] bucketFields =
Arrays.stream(distribution.expressions())
.map(
-                expression -> {
-                  NamedReference namedReference = (NamedReference) expression;
-                  String fieldName[] = namedReference.fieldName();
-                  return String.join(ConnectorConstants.DOT, fieldName);
-                })
+                expression ->
+                    getFieldNameFromGravitinoNamedReference((NamedReference) expression))
.toArray(String[]::new);
if (sortOrders == null || sortOrders.length == 0) {
return Expressions.bucket(bucketNum, bucketFields);
} else {
String[] sortOrderFields =
Arrays.stream(sortOrders)
.map(
-                    sortOrder -> {
-                      NamedReference gravitinoNameReference =
-                          (NamedReference) sortOrder.expression();
-                      return String.join(
-                          ConnectorConstants.DOT, gravitinoNameReference.fieldName());
-                    })
+                    sortOrder ->
+                        getFieldNameFromGravitinoNamedReference(
+                            (NamedReference) sortOrder.expression()))
.toArray(String[]::new);
return createSortBucketTransform(bucketNum, bucketFields, sortOrderFields);
}
// Spark doesn't support EVEN or RANGE distribution
default:
throw new NotSupportedException("Not support " + distribution.strategy());
throw new NotSupportedException(
"Not support distribution strategy: " + distribution.strategy());
}
}

-  private static Expression[] transToGravitinoExpression(
-      List<org.apache.spark.sql.connector.expressions.NamedReference> references) {
-    return references.stream()
+  private static Expression[] transToGravitinoNamedReference(
+      List<org.apache.spark.sql.connector.expressions.NamedReference> sparkNamedReferences) {
+    return sparkNamedReferences.stream()
.map(sparkReference -> NamedReference.field(sparkReference.fieldNames()))
.toArray(Expression[]::new);
}
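
In this reverse direction, only HASH maps cleanly onto Spark; EVEN and RANGE fall through to the exception in the default branch. A hedged fragment illustrating the unsorted case (Gravitino types unqualified, as in the surrounding file):

// Gravitino side: hash distribution over "id" into 16 buckets.
Distribution distribution = Distributions.hash(16, NamedReference.field("id"));
// Spark side: the equivalent plain bucket transform built above, with nested
// field names collapsed to dotted strings first.
org.apache.spark.sql.connector.expressions.Transform bucket =
    org.apache.spark.sql.connector.expressions.Expressions.bucket(16, "id");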
@@ -222,4 +218,10 @@ public static IdentityTransform createSparkIdentityTransform(String columnName)
.map(Expressions::column)
.toArray(org.apache.spark.sql.connector.expressions.NamedReference[]::new);
}

+  // Gravitino use ["a","b"] for nested fields while Spark use "a.b";
+  private static String getFieldNameFromGravitinoNamedReference(
+      NamedReference gravitinoNamedReference) {
+    return String.join(ConnectorConstants.DOT, gravitinoNamedReference.fieldName());
+  }
}
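
The new helper pins down the naming difference: Gravitino keeps nested field names as string arrays, Spark joins them with dots. A tiny standalone sketch of the round trip, assuming ConnectorConstants.DOT is the literal ".":

import java.util.Arrays;

public class FieldNameSketch {
  public static void main(String[] args) {
    String[] gravitinoFieldName = {"a", "b"};  // Gravitino: one entry per path segment
    String sparkFieldName = String.join(".", gravitinoFieldName);
    System.out.println(sparkFieldName);        // Spark: "a.b"
    // And back, as the renamed test helper in this commit does:
    System.out.println(Arrays.toString(sparkFieldName.split("\\.")));  // [a, b]
  }
}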
@@ -9,7 +9,6 @@
import com.datastrato.gravitino.rel.expressions.NamedReference;
import com.datastrato.gravitino.rel.expressions.distributions.Distribution;
import com.datastrato.gravitino.rel.expressions.distributions.Distributions;
-import com.datastrato.gravitino.rel.expressions.distributions.Strategy;
import com.datastrato.gravitino.rel.expressions.sorts.SortDirection;
import com.datastrato.gravitino.rel.expressions.sorts.SortOrder;
import com.datastrato.gravitino.rel.expressions.sorts.SortOrders;
@@ -73,7 +72,7 @@ void testPartition() {
@Test
void testGravitinoToSparkDistributionWithoutSortOrder() {
int bucketNum = 16;
-    String[][] columnNames = createFieldReference("a", "b.c");
+    String[][] columnNames = createGravitinoFieldReferenceNames("a", "b.c");
Distribution gravitinoDistribution = createHashDistribution(bucketNum, columnNames);

org.apache.spark.sql.connector.expressions.Transform[] sparkTransforms =
Expand Down Expand Up @@ -107,11 +106,10 @@ void testGravitinoToSparkDistributionWithoutSortOrder() {
@Test
void testSparkToGravitinoDistributionWithoutSortOrder() {
int bucketNum = 16;
-    String[] columnNames = new String[] {"a", "b.c"};
-    String[][] columnNames2 = createFieldReference(columnNames);
+    String[] sparkFieldReferences = new String[] {"a", "b.c"};

org.apache.spark.sql.connector.expressions.Transform sparkBucket =
-        Expressions.bucket(bucketNum, columnNames);
+        Expressions.bucket(bucketNum, sparkFieldReferences);
GravitinoTransformBundles bundles =
SparkTransformConverter.toGravitinoTransform(
new org.apache.spark.sql.connector.expressions.Transform[] {sparkBucket});
@@ -120,25 +118,21 @@ void testSparkToGravitinoDistributionWithoutSortOrder() {
Assertions.assertNull(bundles.getPartitions());

Distribution distribution = bundles.getDistribution();
-    Assertions.assertEquals(Strategy.HASH, distribution.strategy());
-    Assertions.assertEquals(bucketNum, distribution.number());
-    String[][] fieldNames =
-        Arrays.stream(distribution.expressions())
-            .map(expression -> ((NamedReference) expression).fieldName())
-            .toArray(String[][]::new);
-    Assertions.assertArrayEquals(columnNames2, fieldNames);
+    String[][] gravitinoFieldReferences = createGravitinoFieldReferenceNames(sparkFieldReferences);
+    Assertions.assertTrue(
+        distribution.equals(createHashDistribution(bucketNum, gravitinoFieldReferences)));
}

@Test
void testSparkToGravitinoSortOrder() {
int bucketNum = 16;
-    String[][] bucketColumnNames = createFieldReference("a", "b.c");
-    String[][] sortColumnNames = createFieldReference("f", "m.n");
+    String[][] bucketColumnNames = createGravitinoFieldReferenceNames("a", "b.c");
+    String[][] sortColumnNames = createGravitinoFieldReferenceNames("f", "m.n");
SortedBucketTransform sortedBucketTransform =
LogicalExpressions.bucket(
bucketNum,
-            getSparkFieldReference(bucketColumnNames),
-            getSparkFieldReference(sortColumnNames));
+            createSparkFieldReference(bucketColumnNames),
+            createSparkFieldReference(sortColumnNames));

GravitinoTransformBundles bundles =
SparkTransformConverter.toGravitinoTransform(
@@ -162,8 +156,8 @@ void testSparkToGravitinoSortOrder() {
@Test
void testGravitinoToSparkSortOrder() {
int bucketNum = 16;
-    String[][] bucketColumnNames = createFieldReference("a", "b.c");
-    String[][] sortColumnNames = createFieldReference("f", "m.n");
+    String[][] bucketColumnNames = createGravitinoFieldReferenceNames("a", "b.c");
+    String[][] sortColumnNames = createGravitinoFieldReferenceNames("f", "m.n");
Distribution distribution = createHashDistribution(bucketNum, bucketColumnNames);
SortOrder[] sortOrders =
createSortOrders(sortColumnNames, ConnectorConstants.SPARK_DEFAULT_SORT_DIRECTION);
@@ -189,15 +183,15 @@ void testGravitinoToSparkSortOrder() {
Assertions.assertArrayEquals(sortColumnNames, sparkSortColumns);
}

-  private org.apache.spark.sql.connector.expressions.NamedReference[] getSparkFieldReference(
+  private org.apache.spark.sql.connector.expressions.NamedReference[] createSparkFieldReference(
String[][] fields) {
return Arrays.stream(fields)
.map(field -> FieldReference.apply(String.join(ConnectorConstants.DOT, field)))
.toArray(org.apache.spark.sql.connector.expressions.NamedReference[]::new);
}

-  // split column name
-  private String[][] createFieldReference(String... columnNames) {
+  // split column name for Gravitino
+  private String[][] createGravitinoFieldReferenceNames(String... columnNames) {
return Arrays.stream(columnNames)
.map(columnName -> columnName.split("\\."))
.toArray(String[][]::new);