Skip to content

Commit

Permalink
Add additional generic types to DataFrame methods
Browse files Browse the repository at this point in the history
  • Loading branch information
controversial committed Dec 9, 2024
1 parent 4f9cafd commit cbeedf8
Showing 1 changed file with 45 additions and 40 deletions.
85 changes: 45 additions & 40 deletions polars/dataframe.ts
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* @param other DataFrame to vertically add.
*/
extend(other: DataFrame): DataFrame;
extend(other: DataFrame<T>): DataFrame<T>;
/**
* Fill null/missing values by a filling strategy
*
Expand All @@ -480,7 +480,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* - "one"
* @returns DataFrame with None replaced with the filling strategy.
*/
fillNull(strategy: FillNullStrategy): DataFrame;
fillNull(strategy: FillNullStrategy): DataFrame<T>;
/**
* Filter the rows in the DataFrame based on a predicate expression.
* ___
Expand Down Expand Up @@ -519,7 +519,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* └─────┴─────┴─────┘
* ```
*/
filter(predicate: any): DataFrame;
filter(predicate: any): DataFrame<T>;
/**
* Find the index of a column by name.
* ___
Expand Down Expand Up @@ -764,7 +764,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
/**
* Interpolate intermediate values. The interpolation method is linear.
*/
interpolate(): DataFrame;
interpolate(): DataFrame<T>;
/**
* Get a mask of all duplicated rows in this DataFrame.
*/
Expand Down Expand Up @@ -937,8 +937,11 @@ export interface DataFrame<T extends Record<string, Series> = any>
* Get first N rows as DataFrame.
* @see {@link head}
*/
limit(length?: number): DataFrame;
map(func: (...args: any[]) => any): any[];
limit(length?: number): DataFrame<T>;
map<ReturnT>(
// TODO: strong types for the mapping function
func: (row: any[], i: number, arr: any[][]) => ReturnT
): ReturnT[];

/**
* Aggregate the columns of this DataFrame to their maximum value.
Expand All @@ -962,8 +965,8 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────┴─────┴──────╯
* ```
*/
max(): DataFrame;
max(axis: 0): DataFrame;
max(): DataFrame<T>;
max(axis: 0): DataFrame<T>;
max(axis: 1): Series;
/**
* Aggregate the columns of this DataFrame to their mean value.
Expand All @@ -972,8 +975,8 @@ export interface DataFrame<T extends Record<string, Series> = any>
* @param axis - either 0 or 1
* @param nullStrategy - this argument is only used if axis == 1
*/
mean(): DataFrame;
mean(axis: 0): DataFrame;
mean(): DataFrame<T>;
mean(axis: 0): DataFrame<T>;
mean(axis: 1): Series;
mean(axis: 1, nullStrategy?: "ignore" | "propagate"): Series;
/**
Expand All @@ -997,7 +1000,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────┴─────┴──────╯
* ```
*/
median(): DataFrame;
median(): DataFrame<T>;
/**
* Unpivot a DataFrame from wide to long format.
* @deprecated *since 0.13.0* use {@link unpivot}
Expand Down Expand Up @@ -1059,8 +1062,8 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────┴─────┴──────╯
* ```
*/
min(): DataFrame;
min(axis: 0): DataFrame;
min(): DataFrame<T>;
min(axis: 0): DataFrame<T>;
min(axis: 1): Series;
/**
* Get number of chunks used by the ChunkedArrays of this DataFrame.
Expand All @@ -1087,12 +1090,12 @@ export interface DataFrame<T extends Record<string, Series> = any>
* └─────┴─────┴─────┘
* ```
*/
nullCount(): DataFrame;
nullCount(): DataFrame<{ [K in keyof T]: Series<JsToDtype<number>, K & string> }>;
partitionBy(
cols: string | string[],
stable?: boolean,
includeKey?: boolean,
): DataFrame[];
): DataFrame<T>[];
partitionBy<T>(
cols: string | string[],
stable: boolean,
Expand Down Expand Up @@ -1210,13 +1213,13 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────┴─────┴──────╯
* ```
*/
quantile(quantile: number): DataFrame;
quantile(quantile: number): DataFrame<T>;
/**
* __Rechunk the data in this DataFrame to a contiguous allocation.__
*
* This will make sure all subsequent operations have optimal and predictable performance.
*/
rechunk(): DataFrame;
rechunk(): DataFrame<T>;
/**
* __Rename column names.__
* ___
Expand Down Expand Up @@ -1443,12 +1446,12 @@ export interface DataFrame<T extends Record<string, Series> = any>
* └─────┴─────┴─────┘
* ```
*/
shiftAndFill(n: number, fillValue: number): DataFrame;
shiftAndFill({ n, fillValue }: { n: number; fillValue: number }): DataFrame;
shiftAndFill(n: number, fillValue: number): DataFrame<T>;
shiftAndFill({ n, fillValue }: { n: number; fillValue: number }): DataFrame<T>;
/**
* Shrink memory usage of this DataFrame to fit the exact capacity needed to hold the data.
*/
shrinkToFit(): DataFrame;
shrinkToFit(): DataFrame<T>;
shrinkToFit(inPlace: true): void;
shrinkToFit({ inPlace }: { inPlace: true }): void;
/**
Expand Down Expand Up @@ -1477,8 +1480,8 @@ export interface DataFrame<T extends Record<string, Series> = any>
* └─────┴─────┴─────┘
* ```
*/
slice({ offset, length }: { offset: number; length: number }): DataFrame;
slice(offset: number, length: number): DataFrame;
slice({ offset, length }: { offset: number; length: number }): DataFrame<T>;
slice(offset: number, length: number): DataFrame<T>;
/**
* Sort the DataFrame by column.
* ___
Expand All @@ -1494,7 +1497,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
descending?: boolean,
nullsLast?: boolean,
maintainOrder?: boolean,
): DataFrame;
): DataFrame<T>;
sort({
by,
reverse, // deprecated
Expand All @@ -1504,7 +1507,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
reverse?: boolean; // deprecated
nullsLast?: boolean;
maintainOrder?: boolean;
}): DataFrame;
}): DataFrame<T>;
sort({
by,
descending,
Expand All @@ -1514,7 +1517,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
descending?: boolean;
nullsLast?: boolean;
maintainOrder?: boolean;
}): DataFrame;
}): DataFrame<T>;
/**
* Aggregate the columns of this DataFrame to their standard deviation value.
* ___
Expand All @@ -1536,16 +1539,16 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────┴─────┴──────╯
* ```
*/
std(): DataFrame;
std(): DataFrame<T>;
/**
* Aggregate the columns of this DataFrame to their mean value.
* ___
*
* @param axis - either 0 or 1
* @param nullStrategy - this argument is only used if axis == 1
*/
sum(): DataFrame;
sum(axis: 0): DataFrame;
sum(): DataFrame<T>;
sum(axis: 0): DataFrame<T>;
sum(axis: 1): Series;
sum(axis: 1, nullStrategy?: "ignore" | "propagate"): Series;
/**
Expand Down Expand Up @@ -1595,7 +1598,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────────┴─────╯
* ```
*/
tail(length?: number): DataFrame;
tail(length?: number): DataFrame<T>;
/**
* @deprecated *since 0.4.0* use {@link writeCSV}
* @category Deprecated
Expand All @@ -1614,7 +1617,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ```
* @category IO
*/
toRecords(): Record<string, any>[];
toRecords(): { [K in keyof T]: DTypeToJs<T[K]["dtype"]> | null }[];

/**
* compat with `JSON.stringify`
Expand Down Expand Up @@ -1644,7 +1647,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ```
* @category IO
*/
toObject(): { [K in keyof T]: DTypeToJs<T[K]["dtype"]>[] };
toObject(): { [K in keyof T]: DTypeToJs<T[K]["dtype"] | null>[] };

/**
* @deprecated *since 0.4.0* use {@link writeIPC}
Expand All @@ -1656,7 +1659,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* @category IO Deprecated
*/
toParquet(destination?, options?);
toSeries(index?: number): Series;
toSeries(index?: number): T[keyof T];
toString(): string;
/**
* Convert a ``DataFrame`` to a ``Series`` of type ``Struct``
Expand Down Expand Up @@ -1768,12 +1771,12 @@ export interface DataFrame<T extends Record<string, Series> = any>
maintainOrder?: boolean,
subset?: ColumnSelection,
keep?: "first" | "last",
): DataFrame;
): DataFrame<T>;
unique(opts: {
maintainOrder?: boolean;
subset?: ColumnSelection;
keep?: "first" | "last";
}): DataFrame;
}): DataFrame<T>;
/**
Decompose a struct into its fields. The fields will be inserted in to the `DataFrame` on the
location of the `struct` type.
Expand Down Expand Up @@ -1833,7 +1836,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────┴─────┴──────╯
* ```
*/
var(): DataFrame;
var(): DataFrame<T>;
/**
* Grow this DataFrame vertically by stacking a DataFrame to it.
* @param df - DataFrame to stack.
Expand Down Expand Up @@ -1866,12 +1869,14 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────┴─────┴─────╯
* ```
*/
vstack(df: DataFrame): DataFrame;
vstack(df: DataFrame<T>): DataFrame<T>;
/**
* Return a new DataFrame with the column added or replaced.
* @param column - Series, where the name of the Series refers to the column in the DataFrame.
*/
withColumn(column: Series | Expr): DataFrame;
withColumn<SeriesTypeT extends DataType, SeriesNameT extends string>(
column: Series<SeriesTypeT, SeriesNameT>
): DataFrame<Simplify<T & { [K in SeriesNameT]: Series<SeriesTypeT, SeriesNameT> }>>;
withColumn(column: Series | Expr): DataFrame;
withColumns(...columns: (Expr | Series)[]): DataFrame;
/**
Expand All @@ -1896,7 +1901,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
*/
withRowCount(name?: string): DataFrame;
/** @see {@link filter} */
where(predicate: any): DataFrame;
where(predicate: any): DataFrame<T>;
/**
Upsample a DataFrame at a regular frequency.
Expand Down Expand Up @@ -1972,13 +1977,13 @@ shape: (7, 3)
every: string,
by?: string | string[],
maintainOrder?: boolean,
): DataFrame;
): DataFrame<T>;
upsample(opts: {
timeColumn: string;
every: string;
by?: string | string[];
maintainOrder?: boolean;
}): DataFrame;
}): DataFrame<T>;
}

function prepareOtherArg(anyValue: any): Series {
Expand Down

0 comments on commit cbeedf8

Please sign in to comment.