From a15178c5a31f604e0a7c680b4f78f3228b8bacc3 Mon Sep 17 00:00:00 2001 From: Nick Tobey Date: Tue, 14 Jan 2025 14:30:56 -0800 Subject: [PATCH] Allow importing parquet fields containing lists into JSON columns. --- .../doltcore/table/typed/parquet/reader.go | 152 ++++++++++++++---- 1 file changed, 122 insertions(+), 30 deletions(-) diff --git a/go/libraries/doltcore/table/typed/parquet/reader.go b/go/libraries/doltcore/table/typed/parquet/reader.go index d883fe65707..6617b7d173c 100644 --- a/go/libraries/doltcore/table/typed/parquet/reader.go +++ b/go/libraries/doltcore/table/typed/parquet/reader.go @@ -17,6 +17,7 @@ package parquet import ( "context" "fmt" + "github.com/xitongsys/parquet-go/parquet" "io" "math/big" "strings" @@ -38,14 +39,19 @@ import ( // ParquetReader implements TableReader. It reads parquet files and returns rows. type ParquetReader struct { - fileReader source.ParquetFile - pReader *reader.ParquetReader - sch schema.Schema - vrw types.ValueReadWriter - numRow int - rowReadCounter int - fileData map[string][]interface{} - columnName []string + fileReader source.ParquetFile + pReader *reader.ParquetReader + sch schema.Schema + vrw types.ValueReadWriter + numRow int + rowsRead int + // rowReadCounters tracks offsets into each column. Necessary because of repeated fields. + rowReadCounters map[string]int + fileData map[string][]interface{} + // rLevels indicate whether a value in a column is a repeat of a repeated type. + // We only include these for repeated fields. + rLevels map[string][]int32 + columnName []string } var _ table.SqlTableReader = (*ParquetReader)(nil) @@ -75,28 +81,92 @@ func NewParquetReader(vrw types.ValueReadWriter, fr source.ParquetFile, sche sch // TODO : need to solve for getting single row data in readRow (storing all columns data in memory right now) data := make(map[string][]interface{}) + rLevels := make(map[string][]int32) + rowReadCounters := make(map[string]int) var colName []string for _, col := range columns { - colData, _, _, cErr := pr.ReadColumnByPath(common.ReformPathStr(fmt.Sprintf("%s.%s", rootName, col.Name)), num) + pathName := common.ReformPathStr(fmt.Sprintf("%s.%s", rootName, col.Name)) + resolvedColumnName, found, isRepeated, err := resolveColumnPrefix(pr, pathName) + if err != nil { + return nil, fmt.Errorf("cannot read column: %s", err.Error()) + } + if !found { + if resolvedColumnName != "" { + return nil, fmt.Errorf("cannot read column: %s is ambiguous", resolvedColumnName) + } + return nil, fmt.Errorf("cannot read column: %s Column not found", col.Name) + } + colData, rLevel, _, cErr := pr.ReadColumnByPath(resolvedColumnName, num) if cErr != nil { return nil, fmt.Errorf("cannot read column: %s", cErr.Error()) } data[col.Name] = colData + if isRepeated { + rLevels[col.Name] = rLevel + } + rowReadCounters[col.Name] = 0 colName = append(colName, col.Name) } return &ParquetReader{ - fileReader: fr, - pReader: pr, - sch: sche, - vrw: vrw, - numRow: int(num), - rowReadCounter: 0, - fileData: data, - columnName: colName, + fileReader: fr, + pReader: pr, + sch: sche, + vrw: vrw, + numRow: int(num), + rowsRead: 0, + rowReadCounters: rowReadCounters, + fileData: data, + rLevels: rLevels, + columnName: colName, }, nil } +// resolveColumnPrefix takes a path into a parquet schema and determines: +// - whether there is exactly one leaf column corresponding to that path +// - whether any of the types after the prefix are repeated. +func resolveColumnPrefix(pr *reader.ParquetReader, columnPrefix string) (columnName string, found bool, isRepeated bool, err error) { + inPath, err := pr.SchemaHandler.ConvertToInPathStr(columnPrefix) + if err != nil { + return "", false, false, err + } + + segments := strings.Split(inPath, "\x01") + pathMapType := pr.SchemaHandler.PathMap + for _, segment := range segments[1:] { + pathMapType, found = pathMapType.Children[segment] + if !found { + return "", false, isRepeated, nil + } + } + + for { + if len(pathMapType.Children) == 0 { + // type has no children, we've reached the leaf + return pathMapType.Path, true, isRepeated, nil + } + if len(pathMapType.Children) > 1 { + // type has many children, ambiguous + return pathMapType.Path, false, isRepeated, nil + } + // type has exactly one child; recurse + for _, child := range pathMapType.Children { + pathMapType = child + repetitionType, err := pr.SchemaHandler.GetRepetitionType([]string{pathMapType.Path}) + if err != nil { + return "", false, false, err + } + if repetitionType == parquet.FieldRepetitionType_REPEATED { + if isRepeated { + // We can't currently parse fields with multiple repeated fields. + return "", false, false, fmt.Errorf("%s has multiple repeated fields", columnPrefix) + } + isRepeated = true + } + } + } +} + func (pr *ParquetReader) ReadRow(ctx context.Context) (row.Row, error) { panic("deprecated") } @@ -131,34 +201,56 @@ func DecimalByteArrayToString(dec []byte, prec int, scale int) string { } func (pr *ParquetReader) ReadSqlRow(ctx context.Context) (sql.Row, error) { - if pr.rowReadCounter >= pr.numRow { + if pr.rowsRead >= pr.numRow { return nil, io.EOF } allCols := pr.sch.GetAllCols() row := make(sql.Row, allCols.Size()) allCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) { - val := pr.fileData[col.Name][pr.rowReadCounter] - if val != nil { - switch col.TypeInfo.GetTypeIdentifier() { - case typeinfo.DatetimeTypeIdentifier: - val = time.UnixMicro(val.(int64)) - case typeinfo.TimeTypeIdentifier: - val = gmstypes.Timespan(time.Duration(val.(int64)).Microseconds()) + rowReadCounter := pr.rowReadCounters[col.Name] + readVal := func() interface{} { + val := pr.fileData[col.Name][rowReadCounter] + rowReadCounter++ + if val != nil { + switch col.TypeInfo.GetTypeIdentifier() { + case typeinfo.DatetimeTypeIdentifier: + val = time.UnixMicro(val.(int64)) + case typeinfo.TimeTypeIdentifier: + val = gmstypes.Timespan(time.Duration(val.(int64)).Microseconds()) + } } - } - if col.Kind == types.DecimalKind { - prec, scale := col.TypeInfo.ToSqlType().(gmstypes.DecimalType_).Precision(), col.TypeInfo.ToSqlType().(gmstypes.DecimalType_).Scale() - val = DecimalByteArrayToString([]byte(val.(string)), int(prec), int(scale)) + if col.Kind == types.DecimalKind { + prec, scale := col.TypeInfo.ToSqlType().(gmstypes.DecimalType_).Precision(), col.TypeInfo.ToSqlType().(gmstypes.DecimalType_).Scale() + val = DecimalByteArrayToString([]byte(val.(string)), int(prec), int(scale)) + } + return val + } + var val interface{} + rLevels, isRepeated := pr.rLevels[col.Name] + if !isRepeated { + val = readVal() + } else { + var vals []interface{} + for { + subVal := readVal() + vals = append(vals, subVal) + // an rLevel of 0 marks the start of a new record. + if rLevels[rowReadCounter] == 0 { + break + } + } + val = vals } + pr.rowReadCounters[col.Name] = rowReadCounter row[allCols.TagToIdx[tag]] = val return false, nil }) - pr.rowReadCounter++ + pr.rowsRead++ return row, nil }