Commit 9dc69a9
Lucho Farje authored and committed on Feb 28, 2022
1 parent: 5d78119
Showing 8 changed files with 736 additions and 0 deletions.
infrastructure/src/main/resources/bigquery_to_featurestore.py (57 additions, 0 deletions)
@@ -0,0 +1,57 @@
from typing import NamedTuple


def import_feature_values(
    project: str,
    featurestore_id: str,
    entity_type_id: str,
    bigquery_uri: str,
    entity_id_field: str,
    bigquery_table_id: str,
    worker_count: int = 1,
    location: str = "europe-west3",
    api_endpoint: str = "europe-west3-aiplatform.googleapis.com",
    timeout: int = 500) -> NamedTuple("Outputs", [
        ("featurestore_id", str),
    ]):
  """Imports feature values from a BigQuery table into a Vertex AI Featurestore entity type."""
  import collections
  import datetime

  from google.cloud import aiplatform
  from google.protobuf.timestamp_pb2 import Timestamp

  # Use the current time as the feature timestamp for all imported values.
  time_now = datetime.datetime.now().timestamp()
  seconds = int(time_now)
  proto_timestamp = Timestamp(seconds=seconds)

  client_options = {"api_endpoint": api_endpoint}
  client = aiplatform.gapic.FeaturestoreServiceClient(
      client_options=client_options)

  entity_type = f"projects/{project}/locations/{location}/featurestores/{featurestore_id}/entityTypes/{entity_type_id}"
  # Note: this overrides the entity_id_field argument; the entity id column in
  # the MovieLens training table is user_id.
  entity_id_field = "user_id"

  bigquery_source = aiplatform.gapic.BigQuerySource(input_uri=bigquery_uri)

  # Map the BigQuery columns to Featurestore feature ids.
  feature_specs = [
      aiplatform.gapic.ImportFeatureValuesRequest.FeatureSpec(id="user_id"),
      aiplatform.gapic.ImportFeatureValuesRequest.FeatureSpec(id="item_id"),
      aiplatform.gapic.ImportFeatureValuesRequest.FeatureSpec(id="rating"),
      aiplatform.gapic.ImportFeatureValuesRequest.FeatureSpec(id="timestamp"),
  ]
  import_feature_values_request = aiplatform.gapic.ImportFeatureValuesRequest(
      entity_type=entity_type,
      bigquery_source=bigquery_source,
      feature_specs=feature_specs,
      entity_id_field=entity_id_field,
      feature_time=proto_timestamp,
      worker_count=worker_count,
      disable_online_serving=True,
  )
  lro_response = client.import_feature_values(
      request=import_feature_values_request)
  print("Long running operation:", lro_response.operation.name)
  import_feature_values_response = lro_response.result(timeout=timeout)
  print("import_feature_values_response:", import_feature_values_response)

  outputs = collections.namedtuple(
      "Outputs",
      ["featurestore_id"])

  return outputs(featurestore_id)
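For illustration only (not part of this diff): a minimal sketch of calling the component function directly, assuming the Featurestore, entity type, and BigQuery table already exist. Every resource name below is a placeholder, not a value taken from this repository.

# Hypothetical placeholder values; adjust to your own project and resources.
import_feature_values(
    project="my-gcp-project",
    featurestore_id="movielens_featurestore",
    entity_type_id="users",
    bigquery_uri="bq://my-gcp-project.movielens.training_dataset",
    entity_id_field="user_id",
    bigquery_table_id="my-gcp-project.movielens.training_dataset")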
infrastructure/src/main/resources/generator/generator_component.py (168 additions, 0 deletions)
@@ -0,0 +1,168 @@
from typing import NamedTuple


def generate_movielens_dataset_for_bigquery(
    project_id: str,
    raw_data_path: str,
    batch_size: int,
    rank_k: int,
    num_actions: int,
    driver_steps: int,
    bigquery_tmp_file: str,
    bigquery_dataset_id: str,
    bigquery_location: str,
    bigquery_table_id: str
) -> NamedTuple("Outputs", [
    ("bigquery_dataset_id", str),
    ("bigquery_location", str),
    ("bigquery_table_id", str),
]):
  """Generates MovieLens trajectory data with a random policy and loads it into BigQuery."""
  # pylint: disable=g-import-not-at-top
  import collections
  import json
  from typing import Any, Dict

  from google.cloud import bigquery

  from tf_agents import replay_buffers
  from tf_agents import trajectories
  from tf_agents.bandits.agents.examples.v2 import trainer
  from tf_agents.bandits.environments import movielens_py_environment
  from tf_agents.drivers import dynamic_step_driver
  from tf_agents.environments import tf_py_environment
  from tf_agents.policies import random_tf_policy

  def generate_simulation_data(
      raw_data_path: str,
      batch_size: int,
      rank_k: int,
      num_actions: int,
      driver_steps: int) -> replay_buffers.TFUniformReplayBuffer:

    # Create the MovieLens simulation environment.
    env = movielens_py_environment.MovieLensPyEnvironment(
        raw_data_path,
        rank_k,
        batch_size,
        num_movies=num_actions,
        csv_delimiter="\t")
    environment = tf_py_environment.TFPyEnvironment(env)

    # Define a random policy for collecting data.
    random_policy = random_tf_policy.RandomTFPolicy(
        action_spec=environment.action_spec(),
        time_step_spec=environment.time_step_spec())

    # Use a replay buffer and observers to keep track of Trajectory data.
    data_spec = random_policy.trajectory_spec
    replay_buffer = trainer.get_replay_buffer(data_spec, environment.batch_size,
                                              driver_steps)
    observers = [replay_buffer.add_batch]

    # Run the driver to apply the random policy in the simulation environment.
    driver = dynamic_step_driver.DynamicStepDriver(
        env=environment,
        policy=random_policy,
        num_steps=driver_steps * environment.batch_size,
        observers=observers)
    driver.run()

    return replay_buffer

  def build_dict_from_trajectory(
      trajectory: trajectories.Trajectory) -> Dict[str, Any]:

    trajectory_dict = {
        "step_type": trajectory.step_type.numpy().tolist(),
        "observation": [{
            "observation_batch": batch
        } for batch in trajectory.observation.numpy().tolist()],
        "action": trajectory.action.numpy().tolist(),
        "policy_info": trajectory.policy_info,
        "next_step_type": trajectory.next_step_type.numpy().tolist(),
        "reward": trajectory.reward.numpy().tolist(),
        "discount": trajectory.discount.numpy().tolist(),
    }
    return trajectory_dict

  def write_replay_buffer_to_file(
      replay_buffer: replay_buffers.TFUniformReplayBuffer,
      batch_size: int,
      dataset_file: str) -> None:

    # Dump the replay buffer as newline-delimited JSON for the BigQuery load job.
    dataset = replay_buffer.as_dataset(sample_batch_size=batch_size)
    dataset_size = replay_buffer.num_frames().numpy()

    with open(dataset_file, "w") as f:
      for example in dataset.take(count=dataset_size):
        traj_dict = build_dict_from_trajectory(example[0])
        f.write(json.dumps(traj_dict) + "\n")

  def load_dataset_into_bigquery(
      project_id: str,
      dataset_file: str,
      bigquery_dataset_id: str,
      bigquery_location: str,
      bigquery_table_id: str) -> None:

    # Construct a BigQuery client object.
    client = bigquery.Client(project=project_id)

    # Construct a full Dataset object to send to the API.
    dataset = bigquery.Dataset(bigquery_dataset_id)

    # Specify the geographic location where the dataset should reside.
    dataset.location = bigquery_location

    # Create the dataset, or get the dataset if it exists.
    dataset = client.create_dataset(dataset, exists_ok=True, timeout=30)

    job_config = bigquery.LoadJobConfig(
        schema=[
            bigquery.SchemaField("step_type", "INT64", mode="REPEATED"),
            bigquery.SchemaField(
                "observation",
                "RECORD",
                mode="REPEATED",
                fields=[
                    bigquery.SchemaField("observation_batch", "FLOAT64",
                                         "REPEATED")
                ]),
            bigquery.SchemaField("action", "INT64", mode="REPEATED"),
            bigquery.SchemaField("policy_info", "FLOAT64", mode="REPEATED"),
            bigquery.SchemaField("next_step_type", "INT64", mode="REPEATED"),
            bigquery.SchemaField("reward", "FLOAT64", mode="REPEATED"),
            bigquery.SchemaField("discount", "FLOAT64", mode="REPEATED"),
        ],
        source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
        create_disposition=bigquery.CreateDisposition.CREATE_IF_NEEDED,
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
    )

    with open(dataset_file, "rb") as source_file:
      load_job = client.load_table_from_file(
          source_file, bigquery_table_id, job_config=job_config)

    load_job.result()  # Wait for the job to complete.

  replay_buffer = generate_simulation_data(
      raw_data_path=raw_data_path,
      batch_size=batch_size,
      rank_k=rank_k,
      num_actions=num_actions,
      driver_steps=driver_steps)

  write_replay_buffer_to_file(
      replay_buffer=replay_buffer,
      batch_size=batch_size,
      dataset_file=bigquery_tmp_file)

  load_dataset_into_bigquery(project_id, bigquery_tmp_file, bigquery_dataset_id,
                             bigquery_location, bigquery_table_id)

  outputs = collections.namedtuple(
      "Outputs",
      ["bigquery_dataset_id", "bigquery_location", "bigquery_table_id"])

  return outputs(bigquery_dataset_id, bigquery_location, bigquery_table_id)
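For illustration only (not part of this diff): a sketch of invoking the generator directly, assuming the MovieLens ratings file (u.data) is available on Cloud Storage and the caller can create BigQuery datasets. All paths, identifiers, and hyperparameter values below are placeholders.

# Hypothetical placeholder values; adjust to your own project and bucket.
generate_movielens_dataset_for_bigquery(
    project_id="my-gcp-project",
    raw_data_path="gs://my-bucket/movielens/u.data",
    batch_size=8,
    rank_k=20,
    num_actions=20,
    driver_steps=3,
    bigquery_tmp_file="/tmp/trajectories.json",
    bigquery_dataset_id="my-gcp-project.movielens_dataset",
    bigquery_location="europe-west3",
    bigquery_table_id="my-gcp-project.movielens_dataset.training_dataset")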
infrastructure/src/main/resources/ingester/ingester_component.py (98 additions, 0 deletions)
@@ -0,0 +1,98 @@
from typing import NamedTuple, Optional


def ingest_bigquery_dataset_into_tfrecord(
    project_id: str,
    bigquery_table_id: str,
    tfrecord_file: str,
    bigquery_max_rows: Optional[int] = None
) -> NamedTuple("Outputs", [
    ("tfrecord_file", str),
]):
  """Reads trajectory data from a BigQuery table and writes it to a TFRecord file."""
  # pylint: disable=g-import-not-at-top
  import collections
  import logging
  from typing import Optional

  from google.cloud import bigquery

  import tensorflow as tf

  def read_data_from_bigquery(
      project_id: str,
      bigquery_table_id: str,
      bigquery_max_rows: Optional[int]) -> bigquery.table.RowIterator:

    # Construct a BigQuery client object.
    client = bigquery.Client(project=project_id)

    # Read the full table, optionally capped at bigquery_max_rows rows.
    query_job = client.query(
        f"""
        SELECT * FROM `{bigquery_table_id}`
        """
    )
    table = query_job.result(max_results=bigquery_max_rows)

    return table

  def _bytes_feature(tensor: tf.Tensor) -> tf.train.Feature:

    # Serialize the tensor (or list) and wrap it in a bytes feature.
    value = tf.io.serialize_tensor(tensor)
    if isinstance(value, type(tf.constant(0))):
      value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

  def build_example(data_row: bigquery.table.Row) -> tf.train.Example:

    # Each trajectory field becomes a serialized-tensor bytes feature.
    feature = {
        "step_type":
            _bytes_feature(data_row.get("step_type")),
        "observation":
            _bytes_feature([
                observation["observation_batch"]
                for observation in data_row.get("observation")
            ]),
        "action":
            _bytes_feature(data_row.get("action")),
        "policy_info":
            _bytes_feature(data_row.get("policy_info")),
        "next_step_type":
            _bytes_feature(data_row.get("next_step_type")),
        "reward":
            _bytes_feature(data_row.get("reward")),
        "discount":
            _bytes_feature(data_row.get("discount")),
    }

    example_proto = tf.train.Example(
        features=tf.train.Features(feature=feature))
    return example_proto

  def write_tfrecords(
      tfrecord_file: str,
      table: bigquery.table.RowIterator) -> None:

    with tf.io.TFRecordWriter(tfrecord_file) as writer:
      for data_row in table:
        example = build_example(data_row)
        writer.write(example.SerializeToString())

  table = read_data_from_bigquery(
      project_id=project_id,
      bigquery_table_id=bigquery_table_id,
      bigquery_max_rows=bigquery_max_rows)

  logging.info("Writing TFRecord file: %s", tfrecord_file)

  write_tfrecords(tfrecord_file, table)

  outputs = collections.namedtuple(
      "Outputs",
      ["tfrecord_file"])
  logging.info(outputs)

  return outputs(tfrecord_file)
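For illustration only (not part of this diff): a usage sketch with placeholder identifiers, assuming the BigQuery table was produced by the generator component above and that the output path is writable.

# Hypothetical placeholder values; adjust to your own project and bucket.
result = ingest_bigquery_dataset_into_tfrecord(
    project_id="my-gcp-project",
    bigquery_table_id="my-gcp-project.movielens_dataset.training_dataset",
    tfrecord_file="gs://my-bucket/movielens/trajectories.tfrecord",
    bigquery_max_rows=10000)
print(result.tfrecord_file)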