Skip to content

Commit

Permalink
[Feature][Transform-V2] Spark support transform with multi-table (#8340)
Browse files (browse the repository at this point in the history)
  • Loading branch information
Carl-Zhou-CN authored Dec 23, 2024
1 parent a53d809 commit e128ccc
Show file tree
Hide file tree
Showing 17 changed files with 84 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import org.apache.seatunnel.api.common.JobContext;
import org.apache.seatunnel.api.configuration.ReadonlyConfig;
import org.apache.seatunnel.api.configuration.util.ConfigValidator;
import org.apache.seatunnel.api.table.catalog.CatalogTable;
import org.apache.seatunnel.api.table.factory.TableTransformFactory;
import org.apache.seatunnel.api.table.factory.TableTransformFactoryContext;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.transform.SeaTunnelFlatMapTransform;
import org.apache.seatunnel.api.transform.SeaTunnelMapTransform;
Expand All @@ -34,23 +34,20 @@
import org.apache.seatunnel.plugin.discovery.seatunnel.SeaTunnelFactoryDiscovery;
import org.apache.seatunnel.plugin.discovery.seatunnel.SeaTunnelTransformPluginDiscovery;
import org.apache.seatunnel.translation.spark.execution.DatasetTableInfo;
import org.apache.seatunnel.translation.spark.serialization.SeaTunnelRowConverter;
import org.apache.seatunnel.translation.spark.utils.TypeConverterUtils;
import org.apache.seatunnel.translation.spark.execution.MultiTableManager;

import org.apache.commons.collections.CollectionUtils;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.catalyst.expressions.GenericRow;

import lombok.extern.slf4j.Slf4j;

import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
Expand Down Expand Up @@ -141,7 +138,7 @@ public List<DatasetTableInfo> execute(List<DatasetTableInfo> upstreamDataStreams
pluginOutputIdentifier,
new DatasetTableInfo(
inputDataset,
Collections.singletonList(transform.getProducedCatalogTable()),
transform.getProducedCatalogTables(),
pluginOutputIdentifier));
} catch (Exception e) {
throw new TaskExecuteException(
Expand All @@ -155,55 +152,52 @@ public List<DatasetTableInfo> execute(List<DatasetTableInfo> upstreamDataStreams
}

/**
 * Applies {@code transform} to every row of the dataset carried by {@code tableInfo} using
 * {@code Dataset.flatMap} with a {@code TransformMapPartitionsFunction}, then filters out
 * null rows.
 *
 * <p>NOTE(review): this span is a flattened diff hunk without +/- markers, so two variants of
 * the method body are interleaved. The {@code SeaTunnelRowConverter}/{@code StructType} lines
 * appear to be the removed side and the {@code MultiTableManager} lines the added side —
 * confirm against the repository before treating any single line as live code.
 *
 * @param transform the transform whose produced catalog table(s) define the output schema
 * @param tableInfo the upstream dataset together with its catalog tables
 * @return the transformed dataset with null rows removed
 */
private Dataset<Row> sparkTransform(SeaTunnelTransform transform, DatasetTableInfo tableInfo) {
// Manager-based variant: one manager over all upstream catalog tables, one over all tables
// produced by the transform.
MultiTableManager inputManager =
new MultiTableManager(tableInfo.getCatalogTables().toArray(new CatalogTable[0]));
MultiTableManager outputManager =
new MultiTableManager(
(CatalogTable[])
transform.getProducedCatalogTables().toArray(new CatalogTable[0]));
Dataset<Row> stream = tableInfo.getDataset();
// Converter-based variant: reads only the first upstream catalog table (get(0)) and the
// single produced catalog table, and derives the Spark schema from the output row type.
SeaTunnelDataType<?> inputDataType =
tableInfo.getCatalogTables().get(0).getSeaTunnelRowType();
SeaTunnelDataType<?> outputDataTYpe =
transform.getProducedCatalogTable().getSeaTunnelRowType();
StructType outputSchema = (StructType) TypeConverterUtils.parcel(outputDataTYpe);
SeaTunnelRowConverter inputRowConverter = new SeaTunnelRowConverter(inputDataType);
SeaTunnelRowConverter outputRowConverter = new SeaTunnelRowConverter(outputDataTYpe);
ExpressionEncoder<Row> encoder = RowEncoder.apply(outputSchema);

// Manager-based variant: the encoder schema comes from the output manager instead.
ExpressionEncoder<Row> encoder = RowEncoder.apply(outputManager.getTableSchema());
return stream.flatMap(
new TransformMapPartitionsFunction(
transform, inputRowConverter, outputRowConverter),
new TransformMapPartitionsFunction(transform, inputManager, outputManager),
encoder)
.filter(Objects::nonNull);
}

private static class TransformMapPartitionsFunction implements FlatMapFunction<Row, Row> {
private SeaTunnelTransform<SeaTunnelRow> transform;
private SeaTunnelRowConverter inputRowConverter;
private SeaTunnelRowConverter outputRowConverter;
private MultiTableManager inputManager;
private MultiTableManager outputManager;

/**
 * Stores the transform and the input-side and output-side row conversion helpers in the
 * corresponding fields.
 *
 * <p>NOTE(review): this span is a flattened diff hunk without +/- markers; the
 * {@code SeaTunnelRowConverter} parameters/assignments and the {@code MultiTableManager}
 * parameters/assignments appear to be the removed and added sides of the same constructor —
 * confirm against the repository.
 *
 * @param transform the transform assigned to the {@code transform} field
 */
public TransformMapPartitionsFunction(
SeaTunnelTransform<SeaTunnelRow> transform,
// Converter-based parameter list.
SeaTunnelRowConverter inputRowConverter,
SeaTunnelRowConverter outputRowConverter) {
// Manager-based parameter list.
MultiTableManager inputManager,
MultiTableManager outputManager) {
this.transform = transform;
this.inputRowConverter = inputRowConverter;
this.outputRowConverter = outputRowConverter;
this.inputManager = inputManager;
this.outputManager = outputManager;
}

@Override
public Iterator<Row> call(Row row) throws Exception {
List<Row> rows = new ArrayList<>();

SeaTunnelRow seaTunnelRow = inputRowConverter.unpack((GenericRowWithSchema) row);
SeaTunnelRow seaTunnelRow = inputManager.reconvert((GenericRow) row);
if (transform instanceof SeaTunnelFlatMapTransform) {
List<SeaTunnelRow> seaTunnelRows =
((SeaTunnelFlatMapTransform<SeaTunnelRow>) transform).flatMap(seaTunnelRow);
if (CollectionUtils.isNotEmpty(seaTunnelRows)) {
for (SeaTunnelRow seaTunnelRowTransform : seaTunnelRows) {
rows.add(outputRowConverter.parcel(seaTunnelRowTransform));
rows.add(outputManager.convert(seaTunnelRowTransform));
}
}
} else if (transform instanceof SeaTunnelMapTransform) {
SeaTunnelRow seaTunnelRowTransform =
((SeaTunnelMapTransform<SeaTunnelRow>) transform).map(seaTunnelRow);
if (seaTunnelRowTransform != null) {
rows.add(outputRowConverter.parcel(seaTunnelRowTransform));
rows.add(outputManager.convert(seaTunnelRowTransform));
}
}
return rows.iterator();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

package org.apache.seatunnel.e2e.transform;

import org.apache.seatunnel.e2e.common.container.EngineType;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.TestTemplate;
Expand All @@ -35,10 +33,6 @@ public void testCopy(TestContainer container) throws IOException, InterruptedExc
Assertions.assertEquals(0, execResult.getExitCode());
}

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK},
disabledReason = "Currently SPARK do not multi table transform")
@TestTemplate
public void testCopyMultiTable(TestContainer container)
throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,6 @@ public void testEmbedding(TestContainer container) throws IOException, Interrupt
Assertions.assertEquals(0, execResult.getExitCode());
}

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK},
disabledReason = "Currently SPARK do not multi table transform")
@TestTemplate
public void testEmbeddingMultiTable(TestContainer container)
throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

package org.apache.seatunnel.e2e.transform;

import org.apache.seatunnel.e2e.common.container.EngineType;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.TestTemplate;
Expand All @@ -35,10 +33,6 @@ public void testFilter(TestContainer container) throws IOException, InterruptedE
Assertions.assertEquals(0, execResult.getExitCode());
}

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK},
disabledReason = "Currently SPARK do not multi table transform")
@TestTemplate
public void testFilterMultiTable(TestContainer container)
throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

package org.apache.seatunnel.e2e.transform;

import org.apache.seatunnel.e2e.common.container.EngineType;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.TestTemplate;
Expand All @@ -43,10 +41,6 @@ public void testFilterRowKind(TestContainer container)
Assertions.assertEquals(0, execResult3.getExitCode());
}

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK},
disabledReason = "Currently SPARK do not multi table transform")
@TestTemplate
public void testFilterRowKindMultiTable(TestContainer container)
throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
package org.apache.seatunnel.e2e.transform;

import org.apache.seatunnel.e2e.common.TestResource;
import org.apache.seatunnel.e2e.common.container.EngineType;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
Expand Down Expand Up @@ -90,10 +88,6 @@ public void testLLMWithOpenAI(TestContainer container)
Assertions.assertEquals(0, execResult.getExitCode());
}

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK},
disabledReason = "Currently SPARK do not multi table transform")
@TestTemplate
public void testLLMWithOpenAIMultiTable(TestContainer container)
throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

package org.apache.seatunnel.e2e.transform;

import org.apache.seatunnel.e2e.common.container.EngineType;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.TestTemplate;
Expand All @@ -40,10 +38,6 @@ public void testRowKindExtractorTransform(TestContainer container)
Assertions.assertEquals(0, execResult2.getExitCode());
}

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK},
disabledReason = "Currently SPARK do not multi table transform")
@TestTemplate
public void testRowKindExtractorMultiTable(TestContainer container)
throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

package org.apache.seatunnel.e2e.transform;

import org.apache.seatunnel.e2e.common.container.EngineType;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.TestTemplate;
Expand All @@ -35,10 +33,6 @@ public void testSplit(TestContainer container) throws IOException, InterruptedEx
Assertions.assertEquals(0, execResult.getExitCode());
}

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK},
disabledReason = "Currently SPARK do not multi table transform")
@TestTemplate
public void testSplitMultiTable(TestContainer container)
throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@

import org.apache.seatunnel.e2e.common.TestResource;
import org.apache.seatunnel.e2e.common.container.ContainerExtendedFactory;
import org.apache.seatunnel.e2e.common.container.EngineType;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;
import org.apache.seatunnel.e2e.common.junit.TestContainerExtension;

import org.junit.jupiter.api.AfterAll;
Expand Down Expand Up @@ -118,10 +116,6 @@ public void testDynamicSingleCompileJava(TestContainer container)
Assertions.assertEquals(0, execResult.getExitCode());
}

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK},
disabledReason = "Currently SPARK do not multi table transform")
@TestTemplate
public void testDynamicSingleCompileJavaMultiTable(TestContainer container)
throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

package org.apache.seatunnel.e2e.transform;

import org.apache.seatunnel.e2e.common.container.EngineType;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.TestTemplate;
Expand All @@ -39,10 +37,6 @@ public void testFieldMapper(TestContainer container) throws IOException, Interru
Assertions.assertEquals(0, execResult1.getExitCode());
}

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK},
disabledReason = "Currently SPARK do not multi table transform")
@TestTemplate
public void testFieldMapperMultiTable(TestContainer container)
throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
*/
package org.apache.seatunnel.e2e.transform;

import org.apache.seatunnel.e2e.common.container.EngineType;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.TestTemplate;
Expand All @@ -33,10 +31,6 @@ public void testBasicType(TestContainer container) throws Exception {
Assertions.assertEquals(0, execResult.getExitCode());
}

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK},
disabledReason = "Currently SPARK do not multi table transform")
@TestTemplate
public void testBasicTypeMultiTable(TestContainer container) throws Exception {
Container.ExecResult execResult =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

package org.apache.seatunnel.e2e.transform;

import org.apache.seatunnel.e2e.common.container.EngineType;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.TestTemplate;
Expand All @@ -29,10 +27,6 @@

public class TestMetadataIT extends TestSuiteBase {

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK},
disabledReason = "Currently SPARK do not multi table transform")
@TestTemplate
public void testMetadataMultiTable(TestContainer container)
throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

package org.apache.seatunnel.e2e.transform;

import org.apache.seatunnel.e2e.common.container.EngineType;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.TestTemplate;
Expand All @@ -29,10 +27,6 @@

public class TestRenameIT extends TestSuiteBase {

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK},
disabledReason = "Currently SPARK do not multi table transform")
@TestTemplate
public void testRenameMultiTable(TestContainer container)
throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

package org.apache.seatunnel.e2e.transform;

import org.apache.seatunnel.e2e.common.container.EngineType;
import org.apache.seatunnel.e2e.common.container.TestContainer;
import org.apache.seatunnel.e2e.common.junit.DisabledOnContainer;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.TestTemplate;
Expand All @@ -35,10 +33,6 @@ public void testReplace(TestContainer container) throws IOException, Interrupted
Assertions.assertEquals(0, execResult.getExitCode());
}

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK},
disabledReason = "Currently SPARK do not multi table transform")
@TestTemplate
public void testReplaceMultiTable(TestContainer container)
throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ public void testSQLTransform(TestContainer container) throws IOException, Interr
Assertions.assertEquals(0, splitSql.getExitCode());
}

@DisabledOnContainer(
value = {},
type = {EngineType.SPARK},
disabledReason = "Currently SPARK do not multi table transform")
@TestTemplate
public void testSQLTransformMultiTable(TestContainer container)
throws IOException, InterruptedException {
Expand Down
Loading

0 comments on commit e128ccc

Please sign in to comment.