Skip to content

Commit

Permalink
Stubs for testing parquet dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
jloveric committed May 8, 2024
1 parent 0f3f68c commit f553650
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
17 changes: 7 additions & 10 deletions high_order_implicit_representation/single_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ def image_to_dataset(
return torch_image_flat, torch_position, torch_image


def simple_image_to_dataset(
image: Tensor, device="cpu"
):
def simple_image_to_dataset(image: Tensor, device="cpu"):
"""
Read in an image file and return the flattened position input
flattened output and torch array of the original image.def image_to_dataset(filename: str, peano: str = False, rotations: int = 1):
Expand All @@ -71,7 +69,7 @@ def simple_image_to_dataset(
rotations=rotations,
normalize=True,
)

torch_position = torch.stack(line_list2, dim=2)
torch_position = torch_position.reshape(-1, 2 * rotations)

Expand All @@ -80,7 +78,6 @@ def simple_image_to_dataset(
return torch_image_flat, torch_position, image



def image_neighborhood_dataset(
image: Tensor,
width: int = 3,
Expand Down Expand Up @@ -338,16 +335,15 @@ def __init__(self, files: list[str]):
def __call__(self):
for file in self.files:
data = pd.read_parquet(file)

for index, row in data.iterrows():
caption = row['caption']
caption = row["caption"]

jpg_0 = row["jpg_0"]
img = Image.open(io.BytesIO(jpg_0))
arr = np.asarray(img)
yield caption, torch.from_numpy(arr)


jpg_1 = row["jpg_1"]
img = Image.open(io.BytesIO(jpg_1))
arr = np.asarray(img)
Expand All @@ -364,5 +360,6 @@ def __len__(self):
def __getitem__(self, idx):
# I'm totally ignoring the index
caption, image = self.dataset()


flattened_image, flattened_position = simple_image_to_dataset(image)
for index, rgb in enumerate(flattened_image):
yield caption, flattened_position[index], rgb
3 changes: 3 additions & 0 deletions tests/test_single_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@ def test_image_neighborhood_reader():
assert ind.lastx == 27
assert ind.lasty == 27
assert ind.image.shape == torch.Size([3, 32, 32])

def test_parquet_dataset():
pass

0 comments on commit f553650

Please sign in to comment.