diff --git a/go/internal/feast/featurestore.go b/go/internal/feast/featurestore.go index d2127c4d4f..7250bbd669 100644 --- a/go/internal/feast/featurestore.go +++ b/go/internal/feast/featurestore.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/rs/zerolog/log" "os" "strings" @@ -85,7 +86,7 @@ func (fs *FeatureStore) GetOnlineFeatures( joinKeyToEntityValues map[string]*prototypes.RepeatedValue, requestData map[string]*prototypes.RepeatedValue, fullFeatureNames bool) ([]*onlineserving.FeatureVector, error) { - fvs, odFvs, err := fs.listAllViews() + fvs, odFvs, err := fs.ListAllViews() if err != nil { return nil, err } @@ -104,6 +105,7 @@ func (fs *FeatureStore) GetOnlineFeatures( requestedFeatureViews, requestedOnDemandFeatureViews, err = onlineserving.GetFeatureViewsToUseByFeatureRefs(featureRefs, fvs, odFvs) } + log.Info().Msgf("requestedOnDemandFeatureViews %v", requestedOnDemandFeatureViews) if err != nil { return nil, err } @@ -230,7 +232,7 @@ func (fs *FeatureStore) GetFeatureService(name string) (*model.FeatureService, e return fs.registry.GetFeatureService(fs.config.Project, name) } -func (fs *FeatureStore) listAllViews() (map[string]*model.FeatureView, map[string]*model.OnDemandFeatureView, error) { +func (fs *FeatureStore) ListAllViews() (map[string]*model.FeatureView, map[string]*model.OnDemandFeatureView, error) { fvs := make(map[string]*model.FeatureView) odFvs := make(map[string]*model.OnDemandFeatureView) @@ -291,6 +293,38 @@ func (fs *FeatureStore) ListEntities(hideDummyEntity bool) ([]*model.Entity, err return entities, nil } +func (fs *FeatureStore) GetEntityKeyTypeMaps() (map[string]prototypes.ValueType_Enum, error) { + + entityKeyTypeMap := make(map[string]prototypes.ValueType_Enum, 0) + allEntities, _ := fs.registry.ListEntities(fs.config.Project) + if allEntities == nil || len(allEntities) <= 0 { + return nil, fmt.Errorf("No entities found for project %s", fs.config.Project) + } + for _, entity := range allEntities { + entityKeyTypeMap[entity.JoinKey] = entity.ValueType + } + return entityKeyTypeMap, nil +} +func (fs *FeatureStore) GetRequestSources(fVList []string) (map[string]prototypes.ValueType_Enum, error) { + + requestSources := make(map[string]prototypes.ValueType_Enum, 0) + if fVList != nil && len(fVList) > 0 { + for _, fvName := range fVList { + odfv, err := fs.GetOnDemandFeatureView(fvName) + if err == nil { + schema := odfv.GetRequestDataSchema() + for name, dtype := range schema { + requestSources[name] = dtype + } + } + } + } + if len(requestSources) > 0 { + return requestSources, nil + } + return nil, fmt.Errorf("Request sources for feature views %v not found", fVList) +} + func (fs *FeatureStore) ListOnDemandFeatureViews() ([]*model.OnDemandFeatureView, error) { return fs.registry.ListOnDemandFeatureViews(fs.config.Project) } @@ -311,6 +345,14 @@ func (fs *FeatureStore) GetFeatureView(featureViewName string, hideDummyEntity b return fv, nil } +func (fs *FeatureStore) GetOnDemandFeatureView(featureViewName string) (*model.OnDemandFeatureView, error) { + fv, err := fs.registry.GetOnDemandFeatureView(fs.config.Project, featureViewName) + if err != nil { + return nil, err + } + return fv, nil +} + func (fs *FeatureStore) readFromOnlineStore(ctx context.Context, entityRows []*prototypes.EntityKey, requestedFeatureViewNames []string, requestedFeatureNames []string, diff --git a/go/internal/feast/featurestore_test.go b/go/internal/feast/featurestore_test.go index dd08bc287e..11dcf6696b 100644 --- a/go/internal/feast/featurestore_test.go +++ b/go/internal/feast/featurestore_test.go @@ -2,6 +2,7 @@ package feast import ( "context" + "github.com/feast-dev/feast/go/protos/feast/core" "path/filepath" "runtime" "testing" @@ -10,7 +11,7 @@ import ( "github.com/feast-dev/feast/go/internal/feast/onlinestore" "github.com/feast-dev/feast/go/internal/feast/registry" - "github.com/feast-dev/feast/go/protos/feast/types" + types "github.com/feast-dev/feast/go/protos/feast/types" ) // Return absolute path to the test_repo registry regardless of the working directory @@ -70,3 +71,165 @@ func TestGetOnlineFeaturesRedis(t *testing.T) { assert.Nil(t, err) assert.Len(t, response, 4) // 3 Features + 1 entity = 4 columns (feature vectors) in response } + +func getRepoConfig() (config registry.RepoConfig) { + return registry.RepoConfig{ + Project: "feature_repo", + Registry: getRegistryPath(), + Provider: "local", + OnlineStore: map[string]interface{}{ + "type": "redis", + "connection_string": "localhost:6379", + }, + } +} +func TestGetEntityKeyTypeMapsReturnsExpectedResult(t *testing.T) { + + config := getRepoConfig() + fs, _ := NewFeatureStore(&config, nil) + entity1 := &core.Entity{ + Spec: &core.EntitySpecV2{ + Name: "entity1", + JoinKey: "joinKey1", + ValueType: types.ValueType_INT64, + }, + } + entity2 := &core.Entity{ + Spec: &core.EntitySpecV2{ + Name: "entity2", + JoinKey: "joinKey2", + ValueType: types.ValueType_INT32, + }, + } + cachedEntities := make(map[string]map[string]*core.Entity) + cachedEntities["feature_repo"] = make(map[string]*core.Entity) + cachedEntities["feature_repo"]["entity1"] = entity1 + cachedEntities["feature_repo"]["entity2"] = entity2 + + fs.registry.CachedEntities = cachedEntities + + entityKeyTypeMap, err := fs.GetEntityKeyTypeMaps() + + assert.Nil(t, err) + assert.Equal(t, 2, len(entityKeyTypeMap)) + assert.Equal(t, types.ValueType_INT64, entityKeyTypeMap["joinKey1"]) + assert.Equal(t, types.ValueType_INT32, entityKeyTypeMap["joinKey2"]) +} + +func TestGetEntityKeyTypeMapsReturnsErrorWhenNoEntities(t *testing.T) { + + config := getRepoConfig() + fs, _ := NewFeatureStore(&config, nil) + + cachedEntities := make(map[string]map[string]*core.Entity) + fs.registry.CachedEntities = cachedEntities + + entityKeyTypeMap, err := fs.GetEntityKeyTypeMaps() + + assert.NotNil(t, err) + assert.Equal(t, 0, len(entityKeyTypeMap)) +} +func TestGetRequestSourcesWithValidFeatures(t *testing.T) { + config := getRepoConfig() + fs, _ := NewFeatureStore(&config, nil) + fVList := []string{"odfv1", "fv1"} + + odfv := &core.OnDemandFeatureView{ + Spec: &core.OnDemandFeatureViewSpec{ + Name: "odfv1", + Project: "feature_repo", + Sources: map[string]*core.OnDemandSource{ + "odfv1": { + Source: &core.OnDemandSource_RequestDataSource{ + RequestDataSource: &core.DataSource{ + Name: "request_source_1", + Type: core.DataSource_REQUEST_SOURCE, + Options: &core.DataSource_RequestDataOptions_{ + RequestDataOptions: &core.DataSource_RequestDataOptions{ + DeprecatedSchema: map[string]types.ValueType_Enum{ + "feature1": types.ValueType_INT64, + }, + Schema: []*core.FeatureSpecV2{ + { + Name: "feat1", + ValueType: types.ValueType_INT64, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + cachedOnDemandFVs := make(map[string]map[string]*core.OnDemandFeatureView) + cachedOnDemandFVs["feature_repo"] = make(map[string]*core.OnDemandFeatureView) + cachedOnDemandFVs["feature_repo"]["odfv1"] = odfv + fs.registry.CachedOnDemandFeatureViews = cachedOnDemandFVs + requestSources, err := fs.GetRequestSources(fVList) + + assert.Nil(t, err) + assert.Equal(t, 1, len(requestSources)) + assert.Equal(t, types.ValueType_INT64.Enum(), requestSources["feat1"].Enum()) +} + +func TestGetRequestSourcesWithInvalidFeatures(t *testing.T) { + + config := getRepoConfig() + fs, _ := NewFeatureStore(&config, nil) + fVList := []string{"invalidFV", "fv1"} + + odfv := &core.OnDemandFeatureView{ + Spec: &core.OnDemandFeatureViewSpec{ + Name: "odfv1", + Project: "feature_repo", + Sources: map[string]*core.OnDemandSource{ + "odfv1": { + Source: &core.OnDemandSource_RequestDataSource{ + RequestDataSource: &core.DataSource{ + Name: "request_source_1", + Type: core.DataSource_REQUEST_SOURCE, + Options: &core.DataSource_RequestDataOptions_{ + RequestDataOptions: &core.DataSource_RequestDataOptions{ + DeprecatedSchema: map[string]types.ValueType_Enum{ + "feature1": types.ValueType_INT64, + }, + Schema: []*core.FeatureSpecV2{ + { + Name: "feature1", + ValueType: types.ValueType_INT64, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + cachedOnDemandFVs := make(map[string]map[string]*core.OnDemandFeatureView) + cachedOnDemandFVs["feature_repo"] = make(map[string]*core.OnDemandFeatureView) + cachedOnDemandFVs["feature_repo"]["odfv1"] = odfv + fs.registry.CachedOnDemandFeatureViews = cachedOnDemandFVs + + requestSources, err := fs.GetRequestSources(fVList) + + assert.NotNil(t, err) + assert.Equal(t, 0, len(requestSources)) +} + +func TestGetRequestSourcesWithNoFeatures(t *testing.T) { + + config := getRepoConfig() + fs, _ := NewFeatureStore(&config, nil) + var fvList []string + + requestSources, err := fs.GetRequestSources(fvList) + + assert.NotNil(t, err) + assert.Equal(t, 0, len(requestSources)) +} diff --git a/go/internal/feast/model/entity.go b/go/internal/feast/model/entity.go index 5a09edb655..736920346d 100644 --- a/go/internal/feast/model/entity.go +++ b/go/internal/feast/model/entity.go @@ -2,16 +2,19 @@ package model import ( "github.com/feast-dev/feast/go/protos/feast/core" + "github.com/feast-dev/feast/go/protos/feast/types" ) type Entity struct { - Name string - JoinKey string + Name string + JoinKey string + ValueType types.ValueType_Enum } func NewEntityFromProto(proto *core.Entity) *Entity { return &Entity{ - Name: proto.Spec.Name, - JoinKey: proto.Spec.JoinKey, + Name: proto.Spec.Name, + JoinKey: proto.Spec.JoinKey, + ValueType: proto.Spec.ValueType, } } diff --git a/go/internal/feast/registry/registry.go b/go/internal/feast/registry/registry.go index 830e528f9d..ace99de6c3 100644 --- a/go/internal/feast/registry/registry.go +++ b/go/internal/feast/registry/registry.go @@ -32,10 +32,10 @@ type Registry struct { project string registryStore RegistryStore cachedFeatureServices map[string]map[string]*core.FeatureService - cachedEntities map[string]map[string]*core.Entity + CachedEntities map[string]map[string]*core.Entity cachedFeatureViews map[string]map[string]*core.FeatureView cachedStreamFeatureViews map[string]map[string]*core.StreamFeatureView - cachedOnDemandFeatureViews map[string]map[string]*core.OnDemandFeatureView + CachedOnDemandFeatureViews map[string]map[string]*core.OnDemandFeatureView cachedRegistry *core.Registry cachedRegistryProtoLastUpdated time.Time cachedRegistryProtoTtl time.Duration @@ -114,10 +114,10 @@ func (r *Registry) load(registry *core.Registry) { defer r.mu.Unlock() r.cachedRegistry = registry r.cachedFeatureServices = make(map[string]map[string]*core.FeatureService) - r.cachedEntities = make(map[string]map[string]*core.Entity) + r.CachedEntities = make(map[string]map[string]*core.Entity) r.cachedFeatureViews = make(map[string]map[string]*core.FeatureView) r.cachedStreamFeatureViews = make(map[string]map[string]*core.StreamFeatureView) - r.cachedOnDemandFeatureViews = make(map[string]map[string]*core.OnDemandFeatureView) + r.CachedOnDemandFeatureViews = make(map[string]map[string]*core.OnDemandFeatureView) r.loadEntities(registry) r.loadFeatureServices(registry) r.loadFeatureViews(registry) @@ -130,10 +130,10 @@ func (r *Registry) loadEntities(registry *core.Registry) { entities := registry.Entities for _, entity := range entities { // fmt.Println("Entity load: ", entity.Spec.Name) - if _, ok := r.cachedEntities[r.project]; !ok { - r.cachedEntities[r.project] = make(map[string]*core.Entity) + if _, ok := r.CachedEntities[r.project]; !ok { + r.CachedEntities[r.project] = make(map[string]*core.Entity) } - r.cachedEntities[r.project][entity.Spec.Name] = entity + r.CachedEntities[r.project][entity.Spec.Name] = entity } } @@ -174,10 +174,10 @@ func (r *Registry) loadOnDemandFeatureViews(registry *core.Registry) { onDemandFeatureViews := registry.OnDemandFeatureViews for _, onDemandFeatureView := range onDemandFeatureViews { // fmt.Println("onDemandFeatureView load: ", onDemandFeatureView.Spec.Name) - if _, ok := r.cachedOnDemandFeatureViews[r.project]; !ok { - r.cachedOnDemandFeatureViews[r.project] = make(map[string]*core.OnDemandFeatureView) + if _, ok := r.CachedOnDemandFeatureViews[r.project]; !ok { + r.CachedOnDemandFeatureViews[r.project] = make(map[string]*core.OnDemandFeatureView) } - r.cachedOnDemandFeatureViews[r.project][onDemandFeatureView.Spec.Name] = onDemandFeatureView + r.CachedOnDemandFeatureViews[r.project][onDemandFeatureView.Spec.Name] = onDemandFeatureView } } @@ -189,7 +189,7 @@ func (r *Registry) loadOnDemandFeatureViews(registry *core.Registry) { func (r *Registry) ListEntities(project string) ([]*model.Entity, error) { r.mu.RLock() defer r.mu.RUnlock() - if cachedEntities, ok := r.cachedEntities[project]; !ok { + if cachedEntities, ok := r.CachedEntities[project]; !ok { return []*model.Entity{}, nil } else { entities := make([]*model.Entity, len(cachedEntities)) @@ -273,7 +273,7 @@ func (r *Registry) ListFeatureServices(project string) ([]*model.FeatureService, func (r *Registry) ListOnDemandFeatureViews(project string) ([]*model.OnDemandFeatureView, error) { r.mu.RLock() defer r.mu.RUnlock() - if cachedOnDemandFeatureViews, ok := r.cachedOnDemandFeatureViews[project]; !ok { + if cachedOnDemandFeatureViews, ok := r.CachedOnDemandFeatureViews[project]; !ok { return []*model.OnDemandFeatureView{}, nil } else { onDemandFeatureViews := make([]*model.OnDemandFeatureView, len(cachedOnDemandFeatureViews)) @@ -289,7 +289,7 @@ func (r *Registry) ListOnDemandFeatureViews(project string) ([]*model.OnDemandFe func (r *Registry) GetEntity(project, entityName string) (*model.Entity, error) { r.mu.RLock() defer r.mu.RUnlock() - if cachedEntities, ok := r.cachedEntities[project]; !ok { + if cachedEntities, ok := r.CachedEntities[project]; !ok { return nil, fmt.Errorf("no cached entities found for project %s", project) } else { if entity, ok := cachedEntities[entityName]; !ok { @@ -345,7 +345,7 @@ func (r *Registry) GetFeatureService(project, featureServiceName string) (*model func (r *Registry) GetOnDemandFeatureView(project, onDemandFeatureViewName string) (*model.OnDemandFeatureView, error) { r.mu.RLock() defer r.mu.RUnlock() - if cachedOnDemandFeatureViews, ok := r.cachedOnDemandFeatureViews[project]; !ok { + if cachedOnDemandFeatureViews, ok := r.CachedOnDemandFeatureViews[project]; !ok { return nil, fmt.Errorf("no cached on demand feature views found for project %s", project) } else { if onDemandFeatureViewProto, ok := cachedOnDemandFeatureViews[onDemandFeatureViewName]; !ok { diff --git a/go/internal/feast/server/http_server.go b/go/internal/feast/server/http_server.go index 87e1522842..bd1ee83fff 100644 --- a/go/internal/feast/server/http_server.go +++ b/go/internal/feast/server/http_server.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "errors" "fmt" "net/http" "os" @@ -33,10 +34,12 @@ type httpServer struct { // Some Feast types aren't supported during JSON conversion type repeatedValue struct { stringVal []string + int32Val []int32 int64Val []int64 doubleVal []float64 boolVal []bool stringListVal [][]string + int32ListVal [][]int32 int64ListVal [][]int64 doubleListVal [][]float64 boolListVal [][]bool @@ -102,6 +105,11 @@ func (u *repeatedValue) ToProto() *prototypes.RepeatedValue { proto.Val = append(proto.Val, &prototypes.Value{Val: &prototypes.Value_Int64Val{Int64Val: val}}) } } + if u.int32Val != nil { + for _, val := range u.int32Val { + proto.Val = append(proto.Val, &prototypes.Value{Val: &prototypes.Value_Int32Val{Int32Val: val}}) + } + } if u.doubleVal != nil { for _, val := range u.doubleVal { proto.Val = append(proto.Val, &prototypes.Value{Val: &prototypes.Value_DoubleVal{DoubleVal: val}}) @@ -117,6 +125,11 @@ func (u *repeatedValue) ToProto() *prototypes.RepeatedValue { proto.Val = append(proto.Val, &prototypes.Value{Val: &prototypes.Value_StringListVal{StringListVal: &prototypes.StringList{Val: val}}}) } } + if u.int32ListVal != nil { + for _, val := range u.int32ListVal { + proto.Val = append(proto.Val, &prototypes.Value{Val: &prototypes.Value_Int32ListVal{Int32ListVal: &prototypes.Int32List{Val: val}}}) + } + } if u.int64ListVal != nil { for _, val := range u.int64ListVal { proto.Val = append(proto.Val, &prototypes.Value{Val: &prototypes.Value_Int64ListVal{Int64ListVal: &prototypes.Int64List{Val: val}}}) @@ -158,6 +171,22 @@ func logWithSpanContext(span tracer.Span) zerolog.Logger { return logger } + +/* +* +This function ensures that the entity value type aligns with the entity schema type specified in the feature definitions. +*/ +func typecastToEntitySchemaType(val *repeatedValue, entityType prototypes.ValueType_Enum) { + if val.int64Val != nil { + if entityType == prototypes.ValueType_INT32 { + for _, v := range val.int64Val { + val.int32Val = append(val.int32Val, int32(v)) + } + val.int64Val = nil + } + } +} + func (s *httpServer) getOnlineFeatures(w http.ResponseWriter, r *http.Request) { var err error @@ -192,7 +221,12 @@ func (s *httpServer) getOnlineFeatures(w http.ResponseWriter, r *http.Request) { writeJSONError(w, fmt.Errorf("Error decoding JSON request data: %+v", err), http.StatusInternalServerError) return } + var featureService *model.FeatureService + var entitiesProto = make(map[string]*prototypes.RepeatedValue) + var requestContextProto = make(map[string]*prototypes.RepeatedValue) + var fVList = make([]string, 0) + if request.FeatureService != nil { featureService, err = s.fs.GetFeatureService(*request.FeatureService) if err != nil { @@ -200,14 +234,73 @@ func (s *httpServer) getOnlineFeatures(w http.ResponseWriter, r *http.Request) { writeJSONError(w, fmt.Errorf("Error getting feature service from registry: %+v", err), http.StatusInternalServerError) return } + for _, fv := range featureService.Projections { + fVList = append(fVList, fv.Name) + } + } else if request.Features != nil && len(request.Features) > 0 { + log.Info().Msgf("request.Features %v", request.Features) + for _, featureName := range request.Features { + fvName, _, err := onlineserving.ParseFeatureReference(featureName) + if err != nil { + logSpanContext.Error().Err(err) + writeJSONError(w, fmt.Errorf("Error parsing feature reference %s", featureName), http.StatusBadRequest) + return + } + fVList = append(fVList, fvName) + } + } else { + logSpanContext.Error().Msg("No feature views or feature services specified in your request") + writeJSONError(w, errors.New("No feature views or feature services specified in your request"), http.StatusBadRequest) + return + } + requestSources, err2 := s.fs.GetRequestSources(fVList) + entityKeyTypeMap, err1 := s.fs.GetEntityKeyTypeMaps() + unifiedMap := make(map[string]prototypes.ValueType_Enum, 0) + if entityKeyTypeMap != nil && len(entityKeyTypeMap) > 0 { + for key, value := range entityKeyTypeMap { + unifiedMap[key] = value + } } - entitiesProto := make(map[string]*prototypes.RepeatedValue) - for key, value := range request.Entities { - entitiesProto[key] = value.ToProto() + if requestSources != nil && len(requestSources) > 0 { + for key, value := range requestSources { + unifiedMap[key] = value + } } - requestContextProto := make(map[string]*prototypes.RepeatedValue) - for key, value := range request.RequestContext { - requestContextProto[key] = value.ToProto() + if len(unifiedMap) > 0 { + if request.Entities != nil && len(request.Entities) > 0 { + for key, value := range request.Entities { + fieldType, ok1 := unifiedMap[key] + if ok1 { + typecastToEntitySchemaType(&value, fieldType) + } else { + logSpanContext.Error().Msgf("Entity type/request source type for key %s not found. Check if your join key names or request sources are correct", key) + writeJSONError(w, fmt.Errorf("Entity type/request source type for key %s not found. Check if your join key names or request sources are correct", key), http.StatusNotFound) + return + } + entitiesProto[key] = value.ToProto() + } + } else { + logSpanContext.Error().Msg("No entities specified in your request.") + writeJSONError(w, errors.New("No entities specified in your request"), http.StatusBadRequest) + return + } + if request.RequestContext != nil && len(request.RequestContext) > 0 { + for key, value := range request.RequestContext { + if requestSourceType, ok := unifiedMap[key]; !ok { + logSpanContext.Error().Msgf("No request source type found for key %s", key) + writeJSONError(w, fmt.Errorf("No request source type found for key %s", key), http.StatusNotFound) + return + } else { + typecastToEntitySchemaType(&value, requestSourceType) + requestContextProto[key] = value.ToProto() + } + } + } + } else { + logSpanContext.Error().Err(err1).Msg("Error when getting entities") + logSpanContext.Error().Err(err2).Msg("Error when getting request sources") + writeJSONError(w, errors.New("Error getting entities/request sources"), http.StatusNotFound) + return } featureVectors, err := s.fs.GetOnlineFeatures( diff --git a/go/internal/feast/server/http_server_test.go b/go/internal/feast/server/http_server_test.go index 67ba1c60f9..995b0b089c 100644 --- a/go/internal/feast/server/http_server_test.go +++ b/go/internal/feast/server/http_server_test.go @@ -1,6 +1,7 @@ package server import ( + prototypes "github.com/feast-dev/feast/go/protos/feast/types" "github.com/stretchr/testify/assert" "testing" ) @@ -38,3 +39,11 @@ func TestUnmarshalJSON(t *testing.T) { assert.Nil(t, u.UnmarshalJSON([]byte("[[true, false, true], [false, true, false]]"))) assert.Equal(t, [][]bool{{true, false, true}, {false, true, false}}, u.boolListVal) } + +func testTypecastToCorrectTypeWithInt32Val(t *testing.T) { + val := repeatedValue{} + val.int64Val = append(val.int64Val, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + typecastToEntitySchemaType(&val, prototypes.ValueType_INT32) + assert.Equal(t, nil, val.int64Val) + assert.NotEqual(t, nil, val.int32Val) +}