Merge pull request #8749 from dolthub/nicktobey/parquet
Allow importing parquet fields containing repeated elements.
nicktobey authored Jan 16, 2025
2 parents fcc8f3e + 09d42a4 commit c5a7a11
Showing 4 changed files with 166 additions and 30 deletions.
176 changes: 146 additions & 30 deletions go/libraries/doltcore/table/typed/parquet/reader.go
@@ -26,6 +26,7 @@ import (
	gmstypes "github.com/dolthub/go-mysql-server/sql/types"
	"github.com/xitongsys/parquet-go-source/local"
	"github.com/xitongsys/parquet-go/common"
+	"github.com/xitongsys/parquet-go/parquet"
	"github.com/xitongsys/parquet-go/reader"
	"github.com/xitongsys/parquet-go/source"

@@ -38,14 +39,22 @@

// ParquetReader implements TableReader. It reads parquet files and returns rows.
type ParquetReader struct {
-	fileReader     source.ParquetFile
-	pReader        *reader.ParquetReader
-	sch            schema.Schema
-	vrw            types.ValueReadWriter
-	numRow         int
-	rowReadCounter int
-	fileData       map[string][]interface{}
-	columnName     []string
+	fileReader source.ParquetFile
+	pReader    *reader.ParquetReader
+	sch        schema.Schema
+	vrw        types.ValueReadWriter
+	numRow     int
+	rowsRead   int
+	// rowReadCounters tracks offsets into each column. Necessary because of repeated fields.
+	rowReadCounters map[string]int
+	fileData        map[string][]interface{}
+	// rLevels indicate whether a value in a column is a repeat of a repeated type.
+	// We only include these for repeated fields.
+	rLevels map[string][]int32
+	// dLevels are used for interpreting null values by indicating the deepest level in
+	// a nested field that's defined.
+	dLevels map[string][]int32
+	columnName []string
}

var _ table.SqlTableReader = (*ParquetReader)(nil)
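
A note on the two new level maps: Parquet's Dremel encoding stores a repeated column as a flat array of values plus per-value repetition levels (rLevels) and definition levels (dLevels). The sketch below is an illustration rather than code from this PR; it reassembles such arrays into rows, assuming an optional list of optional elements so that the maximum definition level is 3:

```go
package main

import "fmt"

func main() {
	// Hedged illustration: a flattened repeated column holding the rows
	// [1], [2,3], NULL, []. The levels assume an optional LIST of optional
	// elements, so the maximum definition level is 3.
	vals := []interface{}{int64(1), int64(2), int64(3), nil, nil}
	// rLevel 0 starts a new row; rLevel 1 continues the current row's list.
	rLevels := []int32{0, 0, 1, 0, 0}
	// dLevel 0: the row itself is NULL; 1: the list exists but is empty;
	// 3: the element is fully defined (2, a NULL element, is omitted here).
	dLevels := []int32{3, 3, 3, 0, 1}

	var rows [][]interface{}
	for i, v := range vals {
		if rLevels[i] == 0 {
			rows = append(rows, nil) // rLevel 0 marks a new record
		}
		cur := len(rows) - 1
		switch {
		case dLevels[i] == 0:
			rows[cur] = nil // NULL row
		case dLevels[i] == 1:
			rows[cur] = []interface{}{} // empty list
		default:
			rows[cur] = append(rows[cur], v)
		}
	}
	fmt.Println(rows) // [[1] [2 3] [] []] (row 3 is NULL, row 4 is empty)
}
```

ReadSqlRow further down applies the same rules, with a dLevel of 2 additionally marking a NULL element inside a non-empty list.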
@@ -75,28 +84,95 @@ func NewParquetReader

	// TODO : need to solve for getting single row data in readRow (storing all columns data in memory right now)
	data := make(map[string][]interface{})
+	rLevels := make(map[string][]int32)
+	dLevels := make(map[string][]int32)
+	rowReadCounters := make(map[string]int)
	var colName []string
	for _, col := range columns {
-		colData, _, _, cErr := pr.ReadColumnByPath(common.ReformPathStr(fmt.Sprintf("%s.%s", rootName, col.Name)), num)
+		pathName := common.ReformPathStr(fmt.Sprintf("%s.%s", rootName, col.Name))
+		resolvedColumnName, found, isRepeated, err := resolveColumnPrefix(pr, pathName)
+		if err != nil {
+			return nil, fmt.Errorf("cannot read column: %s", err.Error())
+		}
+		if !found {
+			if resolvedColumnName != "" {
+				return nil, fmt.Errorf("cannot read column: %s is ambiguous", resolvedColumnName)
+			}
+			return nil, fmt.Errorf("cannot read column: %s Column not found", col.Name)
+		}
+		colData, rLevel, dLevel, cErr := pr.ReadColumnByPath(resolvedColumnName, num)
		if cErr != nil {
			return nil, fmt.Errorf("cannot read column: %s", cErr.Error())
		}
		data[col.Name] = colData
+		if isRepeated {
+			rLevels[col.Name] = rLevel
+		}
+		dLevels[col.Name] = dLevel
+		rowReadCounters[col.Name] = 0
		colName = append(colName, col.Name)
	}

	return &ParquetReader{
-		fileReader:     fr,
-		pReader:        pr,
-		sch:            sche,
-		vrw:            vrw,
-		numRow:         int(num),
-		rowReadCounter: 0,
-		fileData:       data,
-		columnName:     colName,
+		fileReader:      fr,
+		pReader:         pr,
+		sch:             sche,
+		vrw:             vrw,
+		numRow:          int(num),
+		rowsRead:        0,
+		rowReadCounters: rowReadCounters,
+		fileData:        data,
+		rLevels:         rLevels,
+		dLevels:         dLevels,
+		columnName:      colName,
	}, nil
}

+// resolveColumnPrefix takes a path into a parquet schema and determines:
+// - whether there is exactly one leaf column corresponding to that path
+// - whether any of the types after the prefix are repeated.
+func resolveColumnPrefix(pr *reader.ParquetReader, columnPrefix string) (columnName string, found bool, isRepeated bool, err error) {
+	inPath, err := pr.SchemaHandler.ConvertToInPathStr(columnPrefix)
+	if err != nil {
+		return "", false, false, err
+	}
+
+	segments := strings.Split(inPath, "\x01")
+	pathMapType := pr.SchemaHandler.PathMap
+	for _, segment := range segments[1:] {
+		pathMapType, found = pathMapType.Children[segment]
+		if !found {
+			return "", false, isRepeated, nil
+		}
+	}
+
+	for {
+		if len(pathMapType.Children) == 0 {
+			// type has no children: we've reached the leaf
+			return pathMapType.Path, true, isRepeated, nil
+		}
+		if len(pathMapType.Children) > 1 {
+			// type has many children: ambiguous
+			return pathMapType.Path, false, isRepeated, nil
+		}
+		// type has exactly one child; descend into it
+		for _, child := range pathMapType.Children {
+			pathMapType = child
+			repetitionType, err := pr.SchemaHandler.GetRepetitionType([]string{pathMapType.Path})
+			if err != nil {
+				return "", false, false, err
+			}
+			if repetitionType == parquet.FieldRepetitionType_REPEATED {
+				if isRepeated {
+					// We can't currently parse fields with multiple repeated fields.
+					return "", false, false, fmt.Errorf("%s has multiple repeated fields", columnPrefix)
+				}
+				isRepeated = true
+			}
+		}
+	}
+}
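
To make the descent concrete, here is a self-contained sketch that mimics resolveColumnPrefix over a toy version of parquet-go's PathMap tree; the node type and resolve function are simplified stand-ins, not the library's API:

```go
package main

import (
	"fmt"
	"strings"
)

// node mimics the shape of parquet-go's PathMapType: a path plus children.
// (Simplified stand-in for illustration, not the real library type.)
type node struct {
	path     string
	repeated bool // stand-in for FieldRepetitionType_REPEATED
	children map[string]*node
}

// resolve walks the same way resolveColumnPrefix does: follow the given
// segments, then keep descending while there is exactly one child.
func resolve(root *node, prefix string) (leaf string, found, isRepeated bool, err error) {
	cur := root
	for _, seg := range strings.Split(prefix, ".")[1:] {
		next, ok := cur.children[seg]
		if !ok {
			return "", false, false, nil
		}
		cur = next
	}
	for len(cur.children) > 0 {
		if len(cur.children) > 1 {
			return cur.path, false, isRepeated, nil // ambiguous
		}
		for _, child := range cur.children {
			cur = child
			if cur.repeated {
				if isRepeated {
					return "", false, false, fmt.Errorf("%s has multiple repeated fields", prefix)
				}
				isRepeated = true
			}
		}
	}
	return cur.path, true, isRepeated, nil
}

func main() {
	// root.embeddings -> list (REPEATED) -> element, as in a LIST column.
	element := &node{path: "root.embeddings.list.element"}
	list := &node{path: "root.embeddings.list", repeated: true, children: map[string]*node{"element": element}}
	embeddings := &node{path: "root.embeddings", children: map[string]*node{"list": list}}
	root := &node{path: "root", children: map[string]*node{"embeddings": embeddings}}

	fmt.Println(resolve(root, "root.embeddings")) // root.embeddings.list.element true true <nil>
}
```

For a LIST column whose leaf is root.embeddings.list.element, the prefix root.embeddings resolves to that single leaf, and the REPEATED list group sets isRepeated.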

func (pr *ParquetReader) ReadRow(ctx context.Context) (row.Row, error) {
	panic("deprecated")
}
@@ -131,34 +207,74 @@ func DecimalByteArrayToString(dec []byte, prec int, scale int) string {
}

func (pr *ParquetReader) ReadSqlRow(ctx context.Context) (sql.Row, error) {
-	if pr.rowReadCounter >= pr.numRow {
+	if pr.rowsRead >= pr.numRow {
		return nil, io.EOF
	}

	allCols := pr.sch.GetAllCols()
	row := make(sql.Row, allCols.Size())
	allCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
-		val := pr.fileData[col.Name][pr.rowReadCounter]
-		if val != nil {
-			switch col.TypeInfo.GetTypeIdentifier() {
-			case typeinfo.DatetimeTypeIdentifier:
-				val = time.UnixMicro(val.(int64))
-			case typeinfo.TimeTypeIdentifier:
-				val = gmstypes.Timespan(time.Duration(val.(int64)).Microseconds())
-			}
-		}
-
-		if col.Kind == types.DecimalKind {
-			prec, scale := col.TypeInfo.ToSqlType().(gmstypes.DecimalType_).Precision(), col.TypeInfo.ToSqlType().(gmstypes.DecimalType_).Scale()
-			val = DecimalByteArrayToString([]byte(val.(string)), int(prec), int(scale))
-		}
+		rowReadCounter := pr.rowReadCounters[col.Name]
+		readVal := func() interface{} {
+			val := pr.fileData[col.Name][rowReadCounter]
+			rowReadCounter++
+			if val != nil {
+				switch col.TypeInfo.GetTypeIdentifier() {
+				case typeinfo.DatetimeTypeIdentifier:
+					val = time.UnixMicro(val.(int64))
+				case typeinfo.TimeTypeIdentifier:
+					val = gmstypes.Timespan(time.Duration(val.(int64)).Microseconds())
+				}
+			}
+
+			if col.Kind == types.DecimalKind {
+				prec, scale := col.TypeInfo.ToSqlType().(gmstypes.DecimalType_).Precision(), col.TypeInfo.ToSqlType().(gmstypes.DecimalType_).Scale()
+				val = DecimalByteArrayToString([]byte(val.(string)), int(prec), int(scale))
+			}
+			return val
+		}
+		var val interface{}
+		rLevels, isRepeated := pr.rLevels[col.Name]
+		dLevels, _ := pr.dLevels[col.Name]
+		readVals := func() (val interface{}) {
+			var vals []interface{}
+			for {
+				dLevel := dLevels[rowReadCounter]
+				subVal := readVal()
+				if subVal == nil {
+					// dLevels tells us how to interpret this nil value:
+					// 0 -> the column value is NULL
+					// 1 -> the column exists but is empty
+					// 2 -> the column contains an empty value
+					// 3+ -> the column contains a non-empty value
+					switch dLevel {
+					case 0:
+						return nil
+					case 1:
+						return []interface{}{}
+					}
+				}
+				vals = append(vals, subVal)
+				// an rLevel of 0 marks the start of a new record.
+				if rowReadCounter >= len(rLevels) || rLevels[rowReadCounter] == 0 {
+					break
+				}
+			}
+			return vals
+		}
+		if !isRepeated {
+			val = readVal()
+		} else {
+			val = readVals()
+		}
+
+		pr.rowReadCounters[col.Name] = rowReadCounter
		row[allCols.TagToIdx[tag]] = val

		return false, nil
	})

-	pr.rowReadCounter++
+	pr.rowsRead++

	return row, nil
}
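
On the Decimal branch retained above: Parquet encodes DECIMAL values as big-endian two's-complement bytes, which DecimalByteArrayToString formats using the column's precision and scale. A rough, self-contained sketch of that conversion for intuition (an illustration, not Dolt's implementation):

```go
package main

import (
	"fmt"
	"math/big"
)

// decimalBytesToString renders big-endian two's-complement DECIMAL bytes
// with the given scale. Hypothetical helper for illustration; Dolt's
// DecimalByteArrayToString (which also takes precision) is what the reader uses.
func decimalBytesToString(dec []byte, scale int) string {
	v := new(big.Int).SetBytes(dec)
	// SetBytes reads the bytes as unsigned; restore the sign when the
	// two's-complement high bit is set.
	if len(dec) > 0 && dec[0]&0x80 != 0 {
		v.Sub(v, new(big.Int).Lsh(big.NewInt(1), uint(len(dec)*8)))
	}
	r := new(big.Rat).SetInt(v)
	pow := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(scale)), nil)
	r.Quo(r, new(big.Rat).SetInt(pow))
	return r.FloatString(scale)
}

func main() {
	fmt.Println(decimalBytesToString([]byte{0x04, 0xD2}, 2)) // 12.34
	fmt.Println(decimalBytesToString([]byte{0xFB, 0x2E}, 2)) // -12.34
}
```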
Binary file integration-tests/bats/helper/parquet/sequences.parquet not shown.
1 change: 1 addition & 0 deletions integration-tests/bats/helper/parquet/sequences.sql
@@ -0,0 +1 @@
+create table sequences(`pk` int primary key, `name` varchar(20), `embeddings` json);
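
The fixture maps the repeated embeddings field onto a JSON column, so the lists produced by ReadSqlRow presumably serialize into the JSON arrays asserted by the new test below; a sketch with encoding/json, with the fixture's contents inferred from those assertions:

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Row values shaped the way ReadSqlRow shapes a repeated column:
	// nil for a NULL row, an empty slice for an empty list.
	rows := []interface{}{
		[]interface{}{},                   // 1,empty         -> []
		[]interface{}{int64(1)},           // 2,single        -> [1]
		nil,                               // 3,null          -> NULL
		[]interface{}{int64(2), int64(3)}, // 4,double        -> [2,3]
		[]interface{}{int64(4), nil},      // 5,contains null -> [4,null]
	}
	for _, r := range rows {
		if r == nil {
			fmt.Println("NULL")
			continue
		}
		b, err := json.Marshal(r)
		if err != nil {
			panic(err)
		}
		fmt.Println(string(b))
	}
}
```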
19 changes: 19 additions & 0 deletions integration-tests/bats/import-create-tables.bats
@@ -930,4 +930,23 @@
    [[ "$output" =~ "text" ]] || false
    [[ "$output" =~ "hello foo" ]] || false
    [[ "$output" =~ "hello world" ]] || false
}

@test "import-create-tables: import sequences as JSON arrays" {
# The file strings.parquet uses a different name for the root column than the one generated by `dolt table export`,
# but Dolt should still be able to import it.
run dolt table import -c -s `batshelper parquet/sequences.sql` sequences `batshelper parquet/sequences.parquet`
[ "$status" -eq 0 ]

dolt sql -r csv -q "select * from sequences;"
run dolt sql -r csv -q "select * from sequences;"
[ "$status" -eq 0 ]
[ "${#lines[@]}" -eq 7 ]
[[ "$output" =~ '1,empty,[]' ]] || false
[[ "$output" =~ "2,single,[1]" ]] || false
[[ "$output" =~ "3,null," ]] || false
[[ "$output" =~ '4,double,"[2,3]"' ]] || false
[[ "$output" =~ '5,contains null,"[4,null]"' ]] || false
[[ "$output" =~ '6,empty,[]' ]] || false

}
