Skip to content
This repository was archived by the owner on Apr 10, 2020. It is now read-only.

Commit

Permalink
fromExample should return Option (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewsmartin authored Dec 12, 2018
1 parent 5e1dc23 commit 7087d42
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.tensorflow.example.Example

trait ExampleConverter[T] {
def toExample(record: T): Example
def fromExample(example: Example): T
def fromExample(example: Example): Option[T]
}

object ExampleConverter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ trait Implicits {
.build()
}

override def fromExample(example: Example): T = fb.fromFeatures(example.getFeatures, None)
override def fromExample(example: Example): Option[T] =
Try(fb.fromFeatures(example.getFeatures, None)).toOption
}

private def featuresOf(name: Option[String], feature: Feature): Features.Builder =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ExampleConverterTest extends FlatSpec with Matchers {
.putFeature("string", stringFeat("world"))
.build)
featuresOf(actual) shouldEqual featuresOf(expected)
converter.fromExample(actual) shouldEqual record
converter.fromExample(actual) shouldEqual Some(record)
}

it should "support nested case class" in {
Expand All @@ -60,7 +60,7 @@ class ExampleConverterTest extends FlatSpec with Matchers {
features.get("f1").getInt64List shouldEqual Int64List.newBuilder().addValue(1).build
features.get("f2").getInt64List shouldEqual Int64List.newBuilder().addValue(2).build
features.get("inner.f3").getInt64List shouldEqual Int64List.newBuilder().addValue(3).build
converter.fromExample(example) shouldEqual record
converter.fromExample(example) shouldEqual Some(record)
}

it should "handle duplicate feature names" in {
Expand All @@ -77,7 +77,7 @@ class ExampleConverterTest extends FlatSpec with Matchers {
"middle.inner.f" -> longFeat(3L)
).asJava
featuresOf(example) shouldEqual expectedFeatures
converter.fromExample(example) shouldEqual record
converter.fromExample(example) shouldEqual Some(record)
}

it should "support collection types" in {
Expand All @@ -98,7 +98,7 @@ class ExampleConverterTest extends FlatSpec with Matchers {
.build)
featuresOf(example) shouldEqual featuresOf(expected)
// Test round trip
val newRecord = converter.fromExample(example)
val newRecord = converter.fromExample(example).get
newRecord.int shouldEqual 1
newRecord.ints shouldEqual List(1, 2, 3)
newRecord.inner.bools.toList shouldEqual List(true, false)
Expand All @@ -123,7 +123,7 @@ class ExampleConverterTest extends FlatSpec with Matchers {
.build())
featuresOf(example) shouldEqual featuresOf(expected)
val newRecord = converter.fromExample(example)
newRecord shouldEqual record
newRecord shouldEqual Some(record)
}

it should "support option types" in {
Expand All @@ -147,7 +147,7 @@ class ExampleConverterTest extends FlatSpec with Matchers {
val example1 = converter.toExample(record1)
featuresOf(example1) shouldEqual featuresOf(expected1)
val newRecord = converter.fromExample(example1)
converter.fromExample(example1) shouldEqual record1
newRecord shouldEqual Some(record1)

// All None
val record2 = OptionRecord(None, Middle(None, Some(Inner(None))))
Expand All @@ -156,7 +156,7 @@ class ExampleConverterTest extends FlatSpec with Matchers {
.build()
val example2 = converter.toExample(record2)
featuresOf(example2) shouldEqual featuresOf(expected2)
converter.fromExample(example2) shouldEqual record2
converter.fromExample(example2) shouldEqual Some(record2)
}

it should "support custom TensorflowMapping on case class" in {
Expand All @@ -181,7 +181,14 @@ class ExampleConverterTest extends FlatSpec with Matchers {
val example = converter.toExample(record)
featuresOf(example) shouldEqual featuresOf(expected)
val record2 = converter.fromExample(example)
record2 shouldEqual record
record2 shouldEqual Some(record)
}

it should "safely return None for bad example" in {
case class Record(xs: List[Int])
val converter = ExampleConverter[Record]
val badExample = Example.newBuilder().build()
converter.fromExample(badExample) shouldBe None
}

private def featureOfKeyPrefix(fMap: Map[String, Feature], prefix: String): Option[Feature] =
Expand Down

0 comments on commit 7087d42

Please sign in to comment.