diff --git a/openwpm_utils/s3.py b/openwpm_utils/s3.py
index 3e8581c..2242a9e 100644
--- a/openwpm_utils/s3.py
+++ b/openwpm_utils/s3.py
@@ -15,8 +15,8 @@ from openwpm_utils.dataquality import TableFilter
 
 
 
-class S3Dataset(object):
-    def __init__(self, s3_directory, s3_bucket="openwpm-crawls"):
+class S3Dataset:
+    def __init__(self, s3_directory: str, s3_bucket: str = "openwpm-crawls"):
         """Helper class to load OpenWPM datasets from S3 using pandas
 
         This dataset wrapper is safe to use by spark worker processes, as it
@@ -87,8 +87,11 @@ def collect_content(self, content_hash, beautify=False):
 
 class PySparkS3Dataset(S3Dataset):
     def __init__(
-        self, spark_context, s3_directory: str, s3_bucket: str = "openwpm-crawls"
-    ):
+        self,
+        spark_context: SparkContext,
+        s3_directory: str,
+        s3_bucket: str = "openwpm-crawls",
+    ) -> None:
         """Helper class to load OpenWPM datasets from S3 using PySpark
 
         Parameters
@@ -111,7 +114,7 @@ def __init__(
 
     def read_table(
         self, table_name: str, columns: List[str] = None, mode: str = "successful"
-    ):
+    ) -> DataFrame:
         """Read `table_name` from OpenWPM dataset into a pyspark dataframe.
 
         Parameters
@@ -126,7 +129,6 @@ def read_table(
            if one of it's commands failed or if it's in the interrupted table
        """
        table = self._sql_context.read.parquet(self._s3_table_loc % table_name)
-
        if mode == "all":
            table = table
        elif mode == "failed":
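
A minimal usage sketch of the API as annotated above (not part of the patch). It assumes the types named in the new annotations are importable, i.e. `SparkContext` from `pyspark` and `DataFrame` from `pyspark.sql`; this diff does not add those imports itself, so they presumably exist elsewhere in s3.py or in a hunk not shown. The crawl directory, table name, and column names below are hypothetical examples:

    from pyspark import SparkContext
    from openwpm_utils.s3 import PySparkS3Dataset

    sc = SparkContext.getOrCreate()

    # s3_bucket defaults to "openwpm-crawls" per the annotated signature;
    # "2021-06-crawl" is a hypothetical s3_directory.
    dataset = PySparkS3Dataset(sc, s3_directory="2021-06-crawl")

    # read_table() now advertises a pyspark DataFrame return type. Per the
    # docstring above, the default mode="successful" drops visits that had a
    # failed command or appear in the interrupted table; mode="all" skips
    # that filtering.
    requests = dataset.read_table("http_requests", columns=["visit_id", "url"])
    requests.show(5)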