diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index e52e43af..98b03e63 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -500,19 +500,17 @@ def f(self, other): deps = [self] plns = [self.name] if is_dask_collection(other): - task = (op, self.key, *other.__dask_keys__()) deps.append(other) + plns.append(other.name) if inv: - plns.insert(0, other.name) + task = (op, *other.__dask_keys__(), self.key) else: - plns.append(other.name) + task = (op, self.key, *other.__dask_keys__()) else: if inv: task = (op, other, self.key) else: task = (op, self.key, other) - if inv: - plns.reverse() graph = HighLevelGraph.from_collections( name, layer=AwkwardMaterializedLayer( @@ -532,6 +530,11 @@ def f(self, other): meta = op(other, self._meta) else: meta = op(self._meta, other) + if meta.ndim: + divisions = other.divisions if is_dask_collection(other) else [0, 1] + return new_array_object( + graph, name, meta=ak.Array(meta), divisions=divisions + ) return new_scalar_object(graph, name, meta=meta) return f @@ -570,6 +573,15 @@ def f(*args): args = tuple( ak.Array(arg.content) if isinstance(arg, MaybeNone) else arg for arg in args ) + args = tuple( + ( + ak.Array(arg) + if isinstance(arg, ak._nplikes.typetracer.TypeTracerArray) + else arg + ) + for arg in args + ) + result = op(*args) return result @@ -2598,6 +2610,8 @@ def typetracer_array(a: ak.Array | Array) -> ak.Array: behavior=a._behavior, attrs=a._attrs, ) + elif isinstance(a, numbers.Number): + return ak.Array([a]).layout.to_typetracer() else: msg = ( "`a` should be an awkward array or a Dask awkward collection.\n" diff --git a/tests/test_core.py b/tests/test_core.py index 8dd82b9f..e679c051 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -440,8 +440,6 @@ def test_typetracer_function(daa: Array) -> None: tta = typetracer_array(aa) assert tta is not None assert tta.layout.form == aa.layout.form - with pytest.raises(TypeError, match="Got type "): - typetracer_array(3) def test_single_partition(ndjson_points_file: str) -> None: