diff --git a/pyproject.toml b/pyproject.toml index 090155e..0c5d478 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "zeno-client" -version = "0.1.10" +version = "0.1.11" description = "Python client for creating new Zeno projects and uploading data." authors = ["Zeno Team "] license = "MIT" diff --git a/zeno_client/client.py b/zeno_client/client.py index cb7261f..8a0f524 100644 --- a/zeno_client/client.py +++ b/zeno_client/client.py @@ -1,10 +1,11 @@ """Functions to upload data to Zeno's backend.""" import io +import json import re import urllib from importlib.metadata import version as package_version from json import JSONDecodeError -from typing import List, Optional +from typing import Dict, List, Optional, Union from urllib.parse import quote import pandas as pd @@ -72,7 +73,7 @@ def upload_dataset( df: pd.DataFrame, *, id_column: str, - data_column: str, + data_column: Optional[str] = None, label_column: Optional[str] = None, ): """Upload a dataset to a Zeno project. @@ -80,7 +81,7 @@ def upload_dataset( Args: df (pd.DataFrame): The dataset to upload as a Pandas DataFrame. id_column (str): Column name containing unique instance IDs. - data_column (str): Column containing the + data_column (str | None, optional): Column containing the instance data. This can be raw data for data types such as text, or URLs for large media data such as images and videos. label_column (str | None, optional): Column containing the @@ -89,7 +90,11 @@ def upload_dataset( if ( id_column == label_column or id_column == data_column - or label_column == data_column + or ( + label_column == data_column + and label_column is not None + and data_column is not None + ) ): raise ValueError( "ERROR: ID, data, and label column names must be unique." @@ -100,7 +105,7 @@ def upload_dataset( if id_column not in df.columns: raise ValueError("ERROR: id_column not found in dataframe") - if data_column not in df.columns: + if data_column and data_column not in df.columns: raise ValueError("ERROR: data_column not found in dataframe") if label_column and label_column not in df.columns: @@ -306,7 +311,7 @@ def create_project( self, *, name: str, - view: str, + view: Union[str, Dict] = "", description: str = "", metrics: List[ZenoMetric] = [], samples_per_page: int = 10, @@ -317,8 +322,8 @@ def create_project( Args: name (str): The name of the project to be created. The project will be created under the current user, e.g. username/name. - project: str, - view (str): The view to use for the project. + view (Union[str, Dict], optional): The view to use for the project. + Defaults to "". description (str, optional): The description of the project. Defaults to "". metrics (list[ZenoMetric], optional): The metrics to calculate for the project. Defaults to []. @@ -336,8 +341,14 @@ def create_project( """ if name == "": raise ValueError("Project name cannot be empty") + if re.findall("[/]", name): raise ValueError("Project name cannot contain a '/'.") + + # if view is dict, dump to json + if isinstance(view, dict): + view = json.dumps(view) + response = requests.post( f"{self.endpoint}/api/project", json={