From f553650dd54bb232da65adf0da0f39170ed7eedc Mon Sep 17 00:00:00 2001 From: jloveric Date: Tue, 7 May 2024 20:51:05 -0700 Subject: [PATCH] Stubs for testing parquet dataset --- .../single_image_dataset.py | 17 +++++++---------- tests/test_single_image_dataset.py | 3 +++ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/high_order_implicit_representation/single_image_dataset.py b/high_order_implicit_representation/single_image_dataset.py index fd6546b..625b38a 100644 --- a/high_order_implicit_representation/single_image_dataset.py +++ b/high_order_implicit_representation/single_image_dataset.py @@ -49,9 +49,7 @@ def image_to_dataset( return torch_image_flat, torch_position, torch_image -def simple_image_to_dataset( - image: Tensor, device="cpu" -): +def simple_image_to_dataset(image: Tensor, device="cpu"): """ Read in an image file and return the flattened position input flattened output and torch array of the original image.def image_to_dataset(filename: str, peano: str = False, rotations: int = 1): @@ -71,7 +69,7 @@ def simple_image_to_dataset( rotations=rotations, normalize=True, ) - + torch_position = torch.stack(line_list2, dim=2) torch_position = torch_position.reshape(-1, 2 * rotations) @@ -80,7 +78,6 @@ def simple_image_to_dataset( return torch_image_flat, torch_position, image - def image_neighborhood_dataset( image: Tensor, width: int = 3, @@ -338,16 +335,15 @@ def __init__(self, files: list[str]): def __call__(self): for file in self.files: data = pd.read_parquet(file) - + for index, row in data.iterrows(): - caption = row['caption'] + caption = row["caption"] jpg_0 = row["jpg_0"] img = Image.open(io.BytesIO(jpg_0)) arr = np.asarray(img) yield caption, torch.from_numpy(arr) - jpg_1 = row["jpg_1"] img = Image.open(io.BytesIO(jpg_1)) arr = np.asarray(img) @@ -364,5 +360,6 @@ def __len__(self): def __getitem__(self, idx): # I'm totally ignoring the index caption, image = self.dataset() - - + flattened_image, flattened_position = simple_image_to_dataset(image) + for index, rgb in enumerate(flattened_image): + yield caption, flattened_position[index], rgb diff --git a/tests/test_single_image_dataset.py b/tests/test_single_image_dataset.py index 863ac07..20f9d57 100644 --- a/tests/test_single_image_dataset.py +++ b/tests/test_single_image_dataset.py @@ -32,3 +32,6 @@ def test_image_neighborhood_reader(): assert ind.lastx == 27 assert ind.lasty == 27 assert ind.image.shape == torch.Size([3, 32, 32]) + +def test_parquet_dataset(): + pass \ No newline at end of file