Skip to content

Commit

Permalink
Convert client forwarder logic to expect async callable
Browse files Browse the repository at this point in the history
Rather than an async generator driven via .asend()
  • Loading branch information
milesgranger committed Apr 16, 2019
1 parent 809a33b commit 2cb4ef7
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 88 deletions.
1 change: 1 addition & 0 deletions gordo_components/cli/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pandas as pd # noqa

from gordo_components.client import Client
from gordo_components.client.client import EndpointMetadata # noqa
from gordo_components import serializer
from gordo_components.data_provider import providers
from gordo_components.cli.custom_types import IsoFormatDateTime, DataProviderParam
Expand Down
62 changes: 18 additions & 44 deletions gordo_components/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def __init__(
metadata: typing.Optional[dict] = None,
data_provider: typing.Optional[GordoBaseDataProvider] = None,
prediction_forwarder: typing.Optional[
typing.Callable[[EndpointMetadata, dict], typing.AsyncGenerator]
typing.Callable[
[pd.DataFrame, EndpointMetadata, dict], typing.Awaitable[None]
]
] = None,
batch_size: int = 1000,
parallelism: int = 10,
Expand Down Expand Up @@ -70,11 +72,9 @@ def __init__(
data_provider: Optional[GordoBaseDataProvider]
The data provider to use for the dataset. If not set, the client
will fall back to using the GET /prediction endpoint
prediction_forwarder: Optional[Callable[[EndpointMetadata, dict], typing.AsyncGenerator]]
Async generator to initialize with two parameters: EndpointMetadata
and a dict of metadata key-value pairs. It should then be able to
process a dataframe of predictions which has a datetime index.
(via .asend()).
prediction_forwarder: Optional[Callable[[pd.DataFrame, EndpointMetadata, dict], typing.Awaitable[None]]]
Async callable which will take a dataframe of predictions, ``EndpointMetadata`` and the metadata
passed into this constructor
batch_size: int
How many samples to send to the server, only applicable when data
provider is supplied.
Expand Down Expand Up @@ -197,14 +197,7 @@ def predict(
# For every endpoint, start making predictions for the time range
jobs = asyncio.gather(
*[
predict_method(
endpoint=endpoint,
start=start,
end=end,
forwarder=self.prediction_forwarder(endpoint, self.metadata)
if self.prediction_forwarder is not None
else None,
)
predict_method(endpoint=endpoint, start=start, end=end)
for endpoint in self.endpoints
]
)
Expand All @@ -221,11 +214,7 @@ def predict(
] # type: ignore

async def _predict_via_post(
self,
endpoint: EndpointMetadata,
start: datetime,
end: datetime,
forwarder: typing.Optional[typing.AsyncGenerator],
self, endpoint: EndpointMetadata, start: datetime, end: datetime
) -> PredictionResult:
"""
Get predictions based on the /prediction POST endpoint of Gordo ML Servers
Expand All @@ -236,15 +225,13 @@ async def _predict_via_post(
Named tuple which has 'endpoint' specifying the full url to the base ml server
start: datetime
end: datetime
Returns
-------
dict
Prediction response from /prediction GET
"""

if forwarder is not None and not forwarder.ag_running:
await forwarder.asend(None) # start async coroutine

# Fetch all of the raw data
X, y = await self._raw_data(endpoint, start, end)

Expand All @@ -258,21 +245,19 @@ async def _predict_via_post(
endpoint=endpoint,
start=start,
end=end,
forwarder=forwarder,
session=session,
)
for i in range(0, X.shape[0], self.batch_size)
]
return await self._accumulate_coroutine_predictions(endpoint, jobs)

@staticmethod
async def _process_post_prediction_task(
self,
X: pd.DataFrame,
chunk: slice,
endpoint: EndpointMetadata,
start: datetime,
end: datetime,
forwarder: typing.Optional[typing.AsyncGenerator] = None,
session: typing.Optional[aiohttp.ClientSession] = None,
):
"""
Expand All @@ -287,7 +272,6 @@ async def _process_post_prediction_task(
endpoint: EndpointMetadata
start: datetime
end: datetime
forwarder: Optional[AsyncGenerator]
Notes
-----
Expand Down Expand Up @@ -327,19 +311,15 @@ async def _process_post_prediction_task(
)

# Forward predictions to any other consumer if registered.
if forwarder is not None:
await forwarder.asend(predictions)
if self.prediction_forwarder is not None:
await self.prediction_forwarder(predictions, endpoint, self.metadata)

return PredictionResult(
name=endpoint.target_name, predictions=predictions, error_messages=[]
)

async def _predict_via_get(
self,
endpoint: EndpointMetadata,
start: datetime,
end: datetime,
forwarder: typing.Optional[typing.AsyncGenerator],
self, endpoint: EndpointMetadata, start: datetime, end: datetime
) -> PredictionResult:
"""
Get predictions based on the /prediction GET endpoint of Gordo ML Servers
Expand All @@ -358,27 +338,21 @@ async def _predict_via_get(
"""
start_end_dates = make_date_ranges(start, end, max_interval_days=1, freq="23H")

if forwarder is not None and not forwarder.ag_running:
await forwarder.asend(None) # start async coroutine

async with aiohttp.ClientSession() as session:

# Create all the jobs which will be done, but don't await them
jobs = [
self._process_get_prediction_task(
endpoint, start, end, forwarder, session
)
self._process_get_prediction_task(endpoint, start, end, session)
for start, end in start_end_dates
]

return await self._accumulate_coroutine_predictions(endpoint, jobs)

@staticmethod
async def _process_get_prediction_task(
self,
endpoint: EndpointMetadata,
start: datetime,
end: datetime,
forwarder: typing.Optional[typing.AsyncGenerator] = None,
session: typing.Optional[aiohttp.ClientSession] = None,
):
"""
Expand All @@ -391,7 +365,6 @@ async def _process_get_prediction_task(
endpoint: EndpointMetadata
start: datetime
end: datetime
forwarder: Optional[AsyncGenerator]
Notes
-----
Expand Down Expand Up @@ -438,8 +411,8 @@ async def _process_get_prediction_task(
# Convert to dataframe
predictions = pd.DataFrame.from_records(records, index="time")

if forwarder is not None:
await forwarder.asend(predictions)
if self.prediction_forwarder is not None:
await self.prediction_forwarder(predictions, endpoint, self.metadata)
return PredictionResult(
name=endpoint.target_name, predictions=predictions, error_messages=[]
)
Expand All @@ -457,6 +430,7 @@ async def _accumulate_coroutine_predictions(
jobs: List[Coroutine]
An awaitable coroutine which will return a single PredictionResult
from a single prediction task
Returns
-------
PredictionResult
Expand Down
81 changes: 37 additions & 44 deletions gordo_components/client/forwarders.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,54 +194,47 @@ def create_anomaly_point(

return anomaly_point

async def __call__(self, endpoint: EndpointMetadata, metadata: dict = dict()):
"""
Insert predictions into the destination influx
"""
while True:

# Predictions is a dataframe consisting of a DatetimeIndex and columns
# either as EndpointMetadata.tag_list values, or double that list with
# 'output_<tag>', 'input_<tag>' names reflecting what the model got
# as input per tag and what it put out, per tag.
predictions = yield
async def __call__(
self,
predictions: pd.DataFrame,
endpoint: EndpointMetadata,
metadata: dict = dict(),
):
# First, let's post all anomalies per sensor
logger.info(f"Calculating points per sensor per record")
data = [
point
for package in predictions.apply(
lambda rec: self.create_anomaly_point_per_sensor(
rec, metadata, endpoint
),
axis=1,
)
for point in package
]

# Async write predictions to influx
with ThreadPoolExecutor(max_workers=1) as executor:
# Write the per-sensor points to influx
logger.info(f"Writing {len(data)} sensor points to Influx")
future = executor.submit(
self.destionation_client.write_points, data, batch_size=10000
)
await asyncio.wrap_future(future)

# First, let's post all anomalies per sensor
logger.info(f"Calculating points per sensor per record")
# Now calculate the error per line from model input vs output
logger.debug(f"Calculating points per record")
data = [
point
for package in predictions.apply(
lambda rec: self.create_anomaly_point_per_sensor(
rec, metadata, endpoint
),
for point in predictions.apply(
lambda rec: self.create_anomaly_point(rec, metadata, endpoint),
axis=1,
)
for point in package
]
logger.info(f"Writing {len(data)} points to Influx")

# Async write predictions to influx
with ThreadPoolExecutor(max_workers=1) as executor:

# Write the per-sensor points to influx
logger.info(f"Writing {len(data)} sensor points to Influx")
future = executor.submit(
self.destionation_client.write_points, data, batch_size=10000
)
await asyncio.wrap_future(future)

# Now calculate the error per line from model input vs output
logger.debug(f"Calculating points per record")
data = [
point
for point in predictions.apply(
lambda rec: self.create_anomaly_point(rec, metadata, endpoint),
axis=1,
)
]
logger.info(f"Writing {len(data)} points to Influx")

# Write the per-sample errors to influx
future = executor.submit(
self.destionation_client.write_points, data, batch_size=10000
)
await asyncio.wrap_future(future)
# Write the per-sample errors to influx
future = executor.submit(
self.destionation_client.write_points, data, batch_size=10000
)
await asyncio.wrap_future(future)

0 comments on commit 2cb4ef7

Please sign in to comment.