diff --git a/openeo-geotrellis/src/main/scala/org/openeo/geotrellis/OpenEOProcesses.scala b/openeo-geotrellis/src/main/scala/org/openeo/geotrellis/OpenEOProcesses.scala index f20550303..a28e92a9c 100644 --- a/openeo-geotrellis/src/main/scala/org/openeo/geotrellis/OpenEOProcesses.scala +++ b/openeo-geotrellis/src/main/scala/org/openeo/geotrellis/OpenEOProcesses.scala @@ -705,6 +705,16 @@ class OpenEOProcesses extends Serializable { return None } } + + def maybeBandLabels[K](cube: RDD[(K, MultibandTile)]): Option[Seq[String]] = { + if (cube.isInstanceOf[OpenEORasterCube[K]] && cube.asInstanceOf[OpenEORasterCube[K]].openEOMetadata.bandCount > 0) { + val labels = cube.asInstanceOf[OpenEORasterCube[K]].openEOMetadata.bands + return Some(labels) + }else{ + return None + } + } + /** * Get band count used in RDD (each tile in RDD should have same band count) */ diff --git a/openeo-geotrellis/src/main/scala/org/openeo/geotrellis/aggregate_polygon/AggregatePolygonProcess.scala b/openeo-geotrellis/src/main/scala/org/openeo/geotrellis/aggregate_polygon/AggregatePolygonProcess.scala index 8504f64a2..20662971f 100644 --- a/openeo-geotrellis/src/main/scala/org/openeo/geotrellis/aggregate_polygon/AggregatePolygonProcess.scala +++ b/openeo-geotrellis/src/main/scala/org/openeo/geotrellis/aggregate_polygon/AggregatePolygonProcess.scala @@ -12,7 +12,7 @@ import org.apache.spark.rdd._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, Row, SaveMode, SparkSession} -import org.openeo.geotrellis.SpatialToSpacetimeJoinRdd +import org.openeo.geotrellis.{OpenEOProcesses, SpatialToSpacetimeJoinRdd} import org.openeo.geotrellis.aggregate_polygon.intern.PixelRateValidator.exceedsTreshold import org.openeo.geotrellis.aggregate_polygon.intern._ import org.openeo.geotrellis.layers.LayerProvider @@ -153,7 +153,8 @@ class AggregatePolygonProcess() { } } val cellType = datacube.metadata.cellType - aggregateByDateAndPolygon(pixelRDD, scriptBuilder, bandCount, cellType, outputPath) + val maybeLabels = new OpenEOProcesses().maybeBandLabels(datacube) + aggregateByDateAndPolygon(pixelRDD, scriptBuilder, bandCount, cellType, outputPath,maybeLabels) } def aggregateSpatialForGeometryWithSpatialCube(scriptBuilder: SparkAggregateScriptBuilder, @@ -305,14 +306,15 @@ class AggregatePolygonProcess() { } } val cellType = datacube.metadata.cellType - aggregateByDateAndPolygon(pixelRDD, scriptBuilder, bandCount, cellType, outputPath) + val maybeLabels = new OpenEOProcesses().maybeBandLabels(datacube) + aggregateByDateAndPolygon(pixelRDD, scriptBuilder, bandCount, cellType, outputPath,maybeLabels) }finally{ byIndexMask.unpersist() } } - private def aggregateByDateAndPolygon(pixelRDD: RDD[Row], scriptBuilder: SparkAggregateScriptBuilder, bandCount: Int, cellType: CellType, outputPath: String) = { + private def aggregateByDateAndPolygon(pixelRDD: RDD[Row], scriptBuilder: SparkAggregateScriptBuilder, bandCount: Int, cellType: CellType, outputPath: String, maybeBandLabels: Option[Seq[String]] = Option.empty[Seq[String]]) = { val session = SparkSession.builder().config(pixelRDD.sparkContext.getConf).getOrCreate() val dataType = if (cellType.isFloatingPoint) { @@ -320,7 +322,7 @@ class AggregatePolygonProcess() { } else { IntegerType } - val bandColumns = (0 until bandCount).map(b => f"band_$b") + val bandColumns = maybeBandLabels.getOrElse (0 until bandCount).map(b => f"band_$b") val bandStructs = bandColumns.map(StructField(_, dataType, true))