Skip to content

Commit

Permalink
simplify dataset.py
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Mar 6, 2025
1 parent 7d33019 commit b571817
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions bfabric/src/bfabric/entities/dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

import tempfile
from pathlib import Path
from typing import Any, TYPE_CHECKING
import io
from typing import TYPE_CHECKING, Any

from polars import DataFrame

from bfabric.entities.core.entity import Entity

if TYPE_CHECKING:
from pathlib import Path
from bfabric import Bfabric


Expand Down Expand Up @@ -47,22 +47,16 @@ def write_csv(self, path: Path, separator: str = ",") -> None:
"""Writes the dataset to a csv file at `path`, using the specified column `separator`."""
self.to_polars().write_csv(path, separator=separator)

def get_csv(self, separator: str = ",") -> str:
"""Returns the dataset as a csv string, using the specified column `separator`."""
return self.to_polars().write_csv(separator=separator)

def write_parquet(self, path: Path) -> None:
"""Writes the dataset to a parquet file at `path`."""
self.to_polars().write_parquet(path)

def get_csv(self, separator: str = ",") -> str:
"""Returns the dataset as a csv string, using the specified column `separator`."""
with tempfile.NamedTemporaryFile() as tmp_file:
self.write_csv(Path(tmp_file.name), separator=separator)
tmp_file.flush()
tmp_file.seek(0)
return tmp_file.read().decode()

def get_parquet(self) -> bytes:
"""Returns the dataset as a parquet bytes object."""
with tempfile.NamedTemporaryFile() as tmp_file:
self.write_parquet(Path(tmp_file.name))
tmp_file.flush()
tmp_file.seek(0)
return tmp_file.read()
bytes_io = io.BytesIO()
self.to_polars().write_parquet(bytes_io)
return bytes_io.getvalue()

0 comments on commit b571817

Please sign in to comment.