diff --git a/octarine/visuals.py b/octarine/visuals.py index 0d79b88..8b27828 100644 --- a/octarine/visuals.py +++ b/octarine/visuals.py @@ -317,6 +317,9 @@ def points2gfx(points, color, size=2, marker=None, size_space="screen"): geometry_kwargs = {} material_kwargs = {} + # material_kwargs["pick_write"] = True # for picking + + # Parse sizes if utils.is_iterable(size): if len(size) != len(points): raise ValueError( @@ -324,17 +327,30 @@ def points2gfx(points, color, size=2, marker=None, size_space="screen"): "an array of the same length as `points`." ) geometry_kwargs["sizes"] = np.asarray(size).astype(np.float32, copy=False) - material_kwargs["size_mode"] = 'vertex' + material_kwargs["size_mode"] = "vertex" else: material_kwargs["size"] = size + # Parse color(s) + if isinstance(color, np.ndarray) and color.ndim == 2: + # If colors are provided for each node we have to make sure + # that we also include `None` for the breaks in the segments + n_points = len(points) + if len(color) != n_points: + raise ValueError(f"Got {len(color)} colors for {n_points} points.") + color = color.astype(np.float32, copy=False) + geometry_kwargs["colors"] = color + material_kwargs["color_mode"] = "vertex" + else: + if isinstance(color, np.ndarray): + color = color.astype(np.float32, copy=False) + material_kwargs["color"] = color + if marker is None: - material = gfx.PointsMaterial( - color=color, size_space=size_space, **material_kwargs - ) + material = gfx.PointsMaterial(size_space=size_space, **material_kwargs) else: material = gfx.PointsMarkerMaterial( - color=color, marker=marker, size_space=size_space, **material_kwargs + marker=marker, size_space=size_space, **material_kwargs ) vis = gfx.Points(gfx.Geometry(positions=points, **geometry_kwargs), material)