Skip to content

Commit

Permalink
Fix BigQuery query template to retrieve training data (#182)
Browse files Browse the repository at this point in the history
* Fix BigQuery query template to retrieve training data

* Update expected value BigQuery template test

* Use FeatureInfo to create Features in BigQueryDatasetTemplater so it's neater
  • Loading branch information
davidheryanto authored and feast-ci-bot committed Apr 18, 2019
1 parent d5c3809 commit aeb12cd
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,15 @@
import feast.core.dao.FeatureInfoRepository;
import feast.core.model.FeatureInfo;
import feast.core.model.StorageInfo;
import feast.specs.FeatureSpecProto.FeatureSpec;
import feast.specs.StorageSpecProto.StorageSpec;
import lombok.Getter;

import java.time.Instant;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
import lombok.Getter;

public class BigQueryDatasetTemplater {
private final FeatureInfoRepository featureInfoRepository;
Expand All @@ -59,20 +55,18 @@ public BigQueryDatasetTemplater(
* @param limit limit
* @return SQL query for creating training table.
*/
public String createQuery(
FeatureSet featureSet, Timestamp startDate, Timestamp endDate, long limit) {
String createQuery(FeatureSet featureSet, Timestamp startDate, Timestamp endDate, long limit) {
List<String> featureIds = featureSet.getFeatureIdsList();
List<FeatureInfo> featureInfos = featureInfoRepository.findAllById(featureIds);
Features features = new Features(featureInfos);

if (featureInfos.size() < featureIds.size()) {
Set<String> foundFeatureIds =
featureInfos.stream().map(FeatureInfo::getId).collect(Collectors.toSet());
featureIds.removeAll(foundFeatureIds);
throw new NoSuchElementException("features not found: " + featureIds);
}

String tableId = getBqTableId(featureInfos.get(0));
Features features = new Features(featureIds, tableId);

String startDateStr = formatDateString(startDate);
String endDateStr = formatDateString(endDate);
String limitStr = (limit != 0) ? String.valueOf(limit) : null;
Expand All @@ -90,7 +84,7 @@ private String renderTemplate(
return jinjava.render(template, context);
}

private String getBqTableId(FeatureInfo featureInfo) {
private static String getBqTableId(FeatureInfo featureInfo) {
StorageInfo whStorage = featureInfo.getWarehouseStore();

String type = whStorage.getType();
Expand All @@ -117,12 +111,9 @@ static final class Features {
final List<String> columns;
final String tableId;

public Features(List<String> featureIds, String tableId) {
this.columns = featureIds.stream()
.map(f -> f.replace(".", "_"))
.collect(Collectors.toList());
this.tableId = tableId;
Features(List<FeatureInfo> featureInfos) {
columns = featureInfos.stream().map(FeatureInfo::getName).collect(Collectors.toList());
tableId = featureInfos.size() > 0 ? getBqTableId(featureInfos.get(0)) : "";
}
}

}
10 changes: 4 additions & 6 deletions core/src/main/resources/templates/bq_training.tmpl
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
SELECT
{{ feature_set.tableId }}.id,
{{ feature_set.tableId }}.event_timestamp
{% for feature in feature_set.columns -%}
,{{ feature }}
{%- endfor %}
id,
event_timestamp{%- if feature_set.columns | length > 0 %},{%- endif %}
{{ feature_set.columns | join(',') }}
FROM
{{ feature_set.tableId }}
`{{ feature_set.tableId }}`
WHERE event_timestamp >= TIMESTAMP("{{ start_date }}") AND event_timestamp <= TIMESTAMP(DATETIME_ADD("{{ end_date }}", INTERVAL 1 DAY))
{% if limit is not none -%}
LIMIT {{ limit }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,11 @@ public void shouldPassCorrectArgumentToTemplateEngine() {
Timestamps.fromSeconds(Instant.parse("2019-01-01T00:00:00.00Z").getEpochSecond());
int limit = 100;
String featureId = "myentity.feature1";
String featureName = "feature1";
String tableId = "project.dataset.myentity";

when(featureInfoRespository.findAllById(any(List.class)))
.thenReturn(Collections.singletonList(createFeatureInfo(featureId, tableId)));
.thenReturn(Collections.singletonList(createFeatureInfo(featureId, featureName, tableId)));

FeatureSet fs =
FeatureSet.newBuilder()
Expand All @@ -123,22 +124,25 @@ public void shouldPassCorrectArgumentToTemplateEngine() {

Features features = (Features) actualContext.get("feature_set");
assertThat(features.getColumns().size(), equalTo(1));
assertThat(features.getColumns().get(0), equalTo(featureId.replace(".", "_")));
assertThat(features.getColumns().get(0), equalTo(featureName));
assertThat(features.getTableId(), equalTo(tableId));
}

@Test
public void shouldRenderCorrectQuery1() throws Exception {
String tableId1 = "project.dataset.myentity";
String featureId1 = "myentity.feature1";
String featureName1 = "feature1";
String featureId2 = "myentity.feature2";
String featureName2 = "feature2";

FeatureInfo featureInfo1 = createFeatureInfo(featureId1, tableId1);
FeatureInfo featureInfo2 = createFeatureInfo(featureId2, tableId1);
FeatureInfo featureInfo1 = createFeatureInfo(featureId1, featureName1, tableId1);
FeatureInfo featureInfo2 = createFeatureInfo(featureId2, featureName2, tableId1);

String tableId2 = "project.dataset.myentity";
String featureId3 = "myentity.feature3";
FeatureInfo featureInfo3 = createFeatureInfo(featureId3, tableId2);
String featureName3 = "feature3";
FeatureInfo featureInfo3 = createFeatureInfo(featureId3, featureName3, tableId2);

when(featureInfoRespository.findAllById(any(List.class)))
.thenReturn(Arrays.asList(featureInfo1, featureInfo2, featureInfo3));
Expand Down Expand Up @@ -166,8 +170,9 @@ public void shouldRenderCorrectQuery2() throws Exception {

String tableId = "project.dataset.myentity";
String featureId = "myentity.feature1";
String featureName = "feature1";

featureInfos.add(createFeatureInfo(featureId, tableId));
featureInfos.add(createFeatureInfo(featureId, featureName, tableId));
featureIds.add(featureId);

when(featureInfoRespository.findAllById(any(List.class))).thenReturn(featureInfos);
Expand Down Expand Up @@ -197,7 +202,7 @@ private void checkExpectedQuery(String query, String pathToExpQuery) throws Exce
assertThat(query, equalTo(expQuery));
}

private FeatureInfo createFeatureInfo(String id, String tableId) {
private FeatureInfo createFeatureInfo(String featureId, String featureName, String tableId) {
StorageSpec storageSpec =
StorageSpec.newBuilder()
.setId("BQ")
Expand All @@ -209,11 +214,12 @@ private FeatureInfo createFeatureInfo(String id, String tableId) {

FeatureSpec fs =
FeatureSpec.newBuilder()
.setId(id)
.setId(featureId)
.setName(featureName)
.setDataStores(DataStores.newBuilder().setWarehouse(DataStore.newBuilder().setId("BQ")))
.build();

EntitySpec entitySpec = EntitySpec.newBuilder().setName(id.split("\\.")[0]).build();
EntitySpec entitySpec = EntitySpec.newBuilder().setName(featureId.split("\\.")[0]).build();
EntityInfo entityInfo = new EntityInfo(entitySpec);
return new FeatureInfo(fs, entityInfo, null, storageInfo, null);
}
Expand Down
12 changes: 6 additions & 6 deletions core/src/test/resources/sql/expQuery1.sql
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
SELECT
project.dataset.myentity.id,
project.dataset.myentity.event_timestamp ,
myentity_feature1,
myentity_feature2,
myentity_feature3
id,
event_timestamp,
feature1,
feature2,
feature3
FROM
project.dataset.myentity
`project.dataset.myentity`
WHERE
event_timestamp >= TIMESTAMP("2018-01-02")
AND event_timestamp <= TIMESTAMP(DATETIME_ADD("2018-01-30", INTERVAL 1 DAY)) LIMIT 100
8 changes: 4 additions & 4 deletions core/src/test/resources/sql/expQuery2.sql
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
SELECT
project.dataset.myentity.id,
project.dataset.myentity.event_timestamp ,
myentity_feature1
id,
event_timestamp,
feature1
FROM
project.dataset.myentity
`project.dataset.myentity`
WHERE
event_timestamp >= TIMESTAMP("2018-01-02")
AND event_timestamp <= TIMESTAMP(DATETIME_ADD("2018-01-30", INTERVAL 1 DAY)) LIMIT 1000

0 comments on commit aeb12cd

Please sign in to comment.