Skip to content

Commit

Permalink
Merge branch 'mainline' into dev/murilommen/zip-segmented-reference-with-songbird
Browse files Browse the repository at this point in the history
  • Loading branch information
richard-rogers authored and murilommen committed Feb 21, 2024
2 parents 794d79c + 3fbb983 commit 4acbefc
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 17 deletions.
2 changes: 1 addition & 1 deletion python/.bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 1.3.23
current_version = 1.3.24
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
serialize =
Expand Down
2 changes: 1 addition & 1 deletion python/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ src.proto.dir := ../proto/src
src.proto := $(shell find $(src.proto.dir) -type f -name "*.proto")
src.proto.v0.dir := ../proto/v0
src.proto.v0 := $(shell find $(src.proto.v0.dir) -type f -name "*.proto")
version := 1.3.23
version := 1.3.24

dist.dir := dist
egg.dir := .eggs
Expand Down
2 changes: 1 addition & 1 deletion python/docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
print("Pandoc is required to build our documentation.")
sys.exit(1)

version = "1.3.23"
version = "1.3.24"

project = "whylogs"
author = "whylogs developers"
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "whylogs"
version = "1.3.23"
version = "1.3.24"
description = "Profile and monitor your ML data pipeline end-to-end"
authors = ["WhyLabs.ai <[email protected]>"]
license = "Apache-2.0"
Expand Down
137 changes: 136 additions & 1 deletion python/tests/api/writer/test_whylabs_integration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import os
import time
from uuid import uuid4

import numpy as np
import pandas as pd
import pytest
from whylabs_client.api.dataset_profile_api import DatasetProfileApi
Expand All @@ -13,7 +15,7 @@
)

import whylogs as why
from whylogs.api.writer.whylabs import WhyLabsWriter
from whylogs.api.writer.whylabs import WhyLabsTransaction, WhyLabsWriter
from whylogs.core import DatasetProfileView
from whylogs.core.feature_weights import FeatureWeights
from whylogs.core.schema import DatasetSchema
Expand All @@ -28,6 +30,8 @@

SLEEP_TIME = 30

logger = logging.getLogger(__name__)


@pytest.mark.load
def test_whylabs_writer():
Expand Down Expand Up @@ -217,3 +221,134 @@ def test_transactions():
downloaded_profile = writer._s3_pool.request("GET", download_url, headers=headers, timeout=writer._timeout_seconds)
deserialized_view = DatasetProfileView.deserialize(downloaded_profile.data)
assert deserialized_view.get_columns().keys() == data.keys()


@pytest.mark.load
def test_transaction_context():
    """Log several profiles inside a WhyLabsTransaction context manager,
    then verify each one is retrievable from the platform by its trace id."""
    org_id = os.environ.get("WHYLABS_DEFAULT_ORG_ID")
    dataset_id = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
    why.init(force_local=True)
    schema = DatasetSchema()
    csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
    frames = np.array_split(pd.read_csv(csv_url), 7)
    writer = WhyLabsWriter(dataset_id=dataset_id)
    trace_ids = []
    try:
        # start_transaction()/commit_transaction() inside the context manager
        # throw on failure; a failed write aborts the whole batch here.
        with WhyLabsTransaction(writer):
            for frame in frames:
                trace_id = str(uuid4())
                trace_ids.append(trace_id)
                result = why.log(frame, schema=schema, trace_id=trace_id)
                status, profile_id = writer.write(result.profile())
                if not status:
                    raise Exception()  # or retry the profile...
    except Exception:
        logger.exception("Logging transaction failed")

    time.sleep(SLEEP_TIME)  # platform needs time to become aware of the profile
    dataset_api = DatasetProfileApi(writer._api_client)
    for trace_id in trace_ids:
        response: ProfileTracesResponse = dataset_api.get_profile_traces(
            org_id=org_id,
            dataset_id=dataset_id,
            trace_id=trace_id,
        )
        download_url = response.get("traces")[0]["download_url"]
        headers = {"Content-Type": "application/octet-stream"}
        downloaded_profile = writer._s3_pool.request(
            "GET", download_url, headers=headers, timeout=writer._timeout_seconds
        )
        assert DatasetProfileView.deserialize(downloaded_profile.data) is not None


@pytest.mark.load
def test_transaction_segmented():
    """Write a segmented result set within an explicit start/commit
    transaction and verify both segments are retrievable by trace id."""
    org_id = os.environ.get("WHYLABS_DEFAULT_ORG_ID")
    dataset_id = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
    why.init(force_local=True)
    schema = DatasetSchema(segments=segment_on_column("Gender"))
    csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
    frame = pd.read_csv(csv_url)
    writer = WhyLabsWriter(dataset_id=dataset_id)
    trace_id = str(uuid4())
    try:
        # start_transaction() throws on failure; a failed write aborts the batch.
        writer.start_transaction()
        result = why.log(frame, schema=schema, trace_id=trace_id)
        status, profile_id = writer.write(result)
        if not status:
            raise Exception()  # or retry the profile...
    except Exception:
        logger.exception("Logging transaction failed")

    writer.commit_transaction()
    time.sleep(SLEEP_TIME)  # platform needs time to become aware of the profile
    dataset_api = DatasetProfileApi(writer._api_client)
    response: ProfileTracesResponse = dataset_api.get_profile_traces(
        org_id=org_id,
        dataset_id=dataset_id,
        trace_id=trace_id,
    )
    traces = response.get("traces")
    assert len(traces) == 2  # one uploaded profile per Gender segment
    for trace in traces:
        headers = {"Content-Type": "application/octet-stream"}
        downloaded_profile = writer._s3_pool.request(
            "GET", trace.get("download_url"), headers=headers, timeout=writer._timeout_seconds
        )
        assert downloaded_profile is not None


@pytest.mark.load
def test_transaction_distributed():
    """Simulate several writers on different machines joining one transaction
    by sharing its transaction id, then verify every uploaded profile."""
    org_id = os.environ.get("WHYLABS_DEFAULT_ORG_ID")
    dataset_id = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
    why.init(force_local=True)
    schema = DatasetSchema()
    csv_url = "https://whylabs-public.s3.us-west-2.amazonaws.com/datasets/tour/current.csv"
    frames = np.array_split(pd.read_csv(csv_url), 7)
    writer = WhyLabsWriter(dataset_id=dataset_id)
    trace_ids = []
    try:
        transaction_id = writer.start_transaction()
        for frame in frames:  # pretend each iteration is run on a different machine
            dist_writer = WhyLabsWriter(dataset_id=dataset_id)
            dist_writer.start_transaction(transaction_id)
            trace_id = str(uuid4())
            trace_ids.append(trace_id)
            result = why.log(frame, schema=schema, trace_id=trace_id)
            status, profile_id = dist_writer.write(result.profile())
            if not status:
                raise Exception()  # or retry the profile...
        writer.commit_transaction()
    except Exception:
        # start_transaction() or commit_transaction() throw on failure.
        logger.exception("Logging transaction failed")

    time.sleep(SLEEP_TIME)  # platform needs time to become aware of the profile
    dataset_api = DatasetProfileApi(writer._api_client)
    for trace_id in trace_ids:
        response: ProfileTracesResponse = dataset_api.get_profile_traces(
            org_id=org_id,
            dataset_id=dataset_id,
            trace_id=trace_id,
        )
        download_url = response.get("traces")[0]["download_url"]
        headers = {"Content-Type": "application/octet-stream"}
        downloaded_profile = writer._s3_pool.request(
            "GET", download_url, headers=headers, timeout=writer._timeout_seconds
        )
        assert DatasetProfileView.deserialize(downloaded_profile.data) is not None
2 changes: 1 addition & 1 deletion python/tests/smoketest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
not a development environment.
"""

current_version = "1.3.23"
current_version = "1.3.24"


def test_package_version() -> None:
Expand Down
103 changes: 92 additions & 11 deletions python/whylogs/api/writer/whylabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def _check_whylabs_condition_count_uncompound() -> bool:
else:
logger.info(f"Got response code {response.status_code} but expected 200, so running uncompound")
except Exception:
logger.warning("Error trying to read whylabs config, falling back to defaults for uncompounding")
pass

_WHYLABS_SKIP_CONFIG_READ = "True"
return True

Expand Down Expand Up @@ -581,6 +582,65 @@ def _write_segmented_reference_result_set(
else:
return False, "Failed to upload all segments"

def _flatten_tags(self, tags: Union[List, Dict]) -> List[SegmentTag]:
    """Convert JSON-style segment tags ({"key": ..., "value": ...} dicts)
    into SegmentTag objects, preserving one level of list nesting.

    Given a flat list of tag dicts, returns a list of SegmentTag. Given a
    list of such lists (one entry per segmented profile view, as built by
    _write_segmented_result_set_transaction), returns a list of
    List[SegmentTag] by converting each sub-list in place.
    """
    if not tags:
        # Previously an empty list raised IndexError on tags[0]; there is
        # nothing to convert, so return an empty result instead.
        return []
    if isinstance(tags[0], list):  # isinstance, not type(...) == list
        return [self._flatten_tags(t) for t in tags]

    return [SegmentTag(t["key"], t["value"]) for t in tags]

def _write_segmented_result_set_transaction(self, file: SegmentedResultSet, **kwargs: Any) -> Tuple[bool, str]:
    """Upload each segmented profile view in `file` as part of the writer's
    current transaction (`self._transaction_id`).

    Only views belonging to the first partition are uploaded; additional
    partitions are skipped with a warning. For each view, the segment tags
    are registered with the transaction via log_transaction(), then the
    serialized profile is uploaded to the returned presigned URL.

    Returns a (status, message) tuple: status is True only if every upload
    succeeded, and message joins the per-upload messages with "; ".
    """
    utc_now = datetime.datetime.now(datetime.timezone.utc)

    files = file.get_writables()
    partitions = file.partitions
    if len(partitions) > 1:
        logger.warning(
            "SegmentedResultSet contains more than one partition. Only the first partition will be uploaded. "
        )
    partition = partitions[0]
    # Build one list of {"key", "value"} tag dicts per view, in the same
    # order as `files`, so they can be zipped back together below.
    whylabs_tags = list()
    for view in files:
        view_tags = list()
        # Fall back to "now" for views without an explicit dataset timestamp.
        dataset_timestamp = view.dataset_timestamp or utc_now
        if view.partition.id != partition.id:
            continue  # skip views from partitions other than the first
        _, segment_tags, _ = _generate_segment_tags_metadata(view.segment, view.partition)
        for segment_tag in segment_tags:
            # Strip the internal "whylogs.tag." prefix before sending to WhyLabs.
            tag_key = segment_tag.key.replace("whylogs.tag.", "")
            tag_value = segment_tag.value
            view_tags.append({"key": tag_key, "value": tag_value})
        whylabs_tags.append(view_tags)
    # NOTE(review): dataset_timestamp here is whatever the *last* iterated
    # view carried, and that single epoch is applied to every upload below —
    # confirm all views in a result set share one timestamp.
    stamp = dataset_timestamp.timestamp()
    dataset_timestamp_epoch = int(stamp * 1000)

    region = os.getenv("WHYLABS_UPLOAD_REGION", None)
    client: TransactionsApi = self._get_or_create_transaction_client()
    messages: List[str] = list()
    and_status: bool = True  # AND of every per-segment upload status
    for view, tags in zip(files, self._flatten_tags(whylabs_tags)):
        with tempfile.NamedTemporaryFile() as tmp_file:
            # Serialize the view to a temp file, then rewind so the upload
            # helper reads it from the start.
            view.write(file=tmp_file)
            tmp_file.flush()
            tmp_file.seek(0)
            request = TransactionLogRequest(
                dataset_timestamp=dataset_timestamp_epoch, segment_tags=tags, region=region
            )
            # Register this profile with the transaction; the response carries
            # the profile id and presigned upload URL.
            result: AsyncLogResponse = client.log_transaction(self._transaction_id, request, **kwargs)
            logger.info(f"Added profile {result.id} to transaction {self._transaction_id}")
            bool_status, message = self._do_upload(
                dataset_timestamp=dataset_timestamp_epoch,
                upload_url=result.upload_url,
                profile_id=result.id,
                profile_file=tmp_file,
            )
            and_status = and_status and bool_status
            messages.append(message)

    return and_status, "; ".join(messages)

def _write_segmented_result_set(self, file: SegmentedResultSet, **kwargs: Any) -> Tuple[bool, str]:
"""Put segmented result set for the specified dataset.
Expand All @@ -593,6 +653,9 @@ def _write_segmented_result_set(self, file: SegmentedResultSet, **kwargs: Any) -
-------
Tuple[bool, str]
"""
if self._transaction_id is not None:
return self._write_segmented_result_set_transaction(file, **kwargs)

# multi-profile writer
files = file.get_writables()
messages: List[str] = list()
Expand All @@ -615,39 +678,46 @@ def _get_or_create_transactions_api(self) -> TransactionsApi:
self._refresh_client()
return TransactionsApi(self._api_client)

def start_transaction(self, **kwargs) -> None:
def start_transaction(self, transaction_id: Optional[str] = None, **kwargs) -> str:
"""
Initiates a transaction -- any profiles subsequently written by calling write()
will be uploaded to WhyLabs atomically when commit_transaction() is called. Throws
will be uploaded to WhyLabs, but not ingested until commit_transaction() is called. Throws
on failure.
"""
if self._transaction_id is not None:
logger.error("Must end current transaction with commit_transaction() before starting another")
return
return self._transaction_id

if kwargs.get("dataset_id") is not None:
self._dataset_id = kwargs.get("dataset_id")

transactions_api = self._get_or_create_transactions_api()
if transaction_id is not None:
self._transaction_id = transaction_id # type: ignore
return transaction_id

client: TransactionsApi = self._get_or_create_transaction_client()
request = TransactionStartRequest(dataset_id=self._dataset_id)
result: LogTransactionMetadata = transactions_api.start_transaction(request, **kwargs)
result: LogTransactionMetadata = client.start_transaction(request, **kwargs)
self._transaction_id = result["transaction_id"]
logger.info(f"Starting transaction {self._transaction_id}, expires {result['expiration_time']}")
return self._transaction_id # type: ignore

def commit_transaction(self, **kwargs) -> None:
"""
Atomically upload any profiles written since the previous start_transaction().
Ingest any profiles written since the previous start_transaction().
Throws on failure.
"""
if self._transaction_id is None:
logger.error("Must call start_transaction() before commit_transaction()")
return

logger.info(f"Committing transaction {self._transaction_id}")
client = self._get_or_create_transactions_api()
request = TransactionCommitRequest(verbose=True)
client.commit_transaction(self._transaction_id, request, **kwargs)
id = self._transaction_id
self._transaction_id = None
logger.info(f"Committing transaction {id}")
client = self._get_or_create_transaction_client()
request = TransactionCommitRequest(verbose=True)
# We abandon the transaction if this throws
client.commit_transaction(id, request, **kwargs)

def _get_uncompounded_view(self, view: DatasetProfileView) -> DatasetProfileView:
self._tag_custom_perf_metrics(view)
Expand Down Expand Up @@ -1114,3 +1184,14 @@ def _get_upload_url(self, dataset_timestamp: int, zip_file: bool = False) -> Tup
logger.debug(f"Replaced URL with our private domain. New URL: {upload_url}")

return upload_url, profile_id


class WhyLabsTransaction:
    """Context manager that brackets WhyLabsWriter.write() calls in a
    WhyLabs transaction: the transaction starts on entry and is committed
    on exit. Both start_transaction() and commit_transaction() throw on
    failure, so callers typically wrap the `with` block in try/except.
    """

    def __init__(self, writer: WhyLabsWriter):
        # Writer whose transaction lifecycle this context manager drives.
        self._writer = writer

    def __enter__(self) -> None:
        # Throws if the transaction cannot be started.
        self._writer.start_transaction()

    def __exit__(self, exc_type, exc_value, exc_tb) -> None:
        # NOTE(review): commits unconditionally — even when the body raised
        # (exc_type is ignored). Confirm that committing a partially-written
        # transaction on exception is the intended behavior.
        self._writer.commit_transaction()

0 comments on commit 4acbefc

Please sign in to comment.