Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inference] Better types for HfInference #1121

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 73 additions & 12 deletions packages/inference/src/HfInference.ts
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@ import type { DistributiveOmit } from "./utils/distributive-omit";
/* eslint-disable @typescript-eslint/no-empty-interface */
/* eslint-disable @typescript-eslint/no-unsafe-declaration-merging */

type Task = typeof tasks;
type Task = Omit<typeof tasks, "request" | "streamingRequest">;

type TaskWithNoAccessToken = {
[key in keyof Task]: (
@@ -21,22 +21,83 @@ type TaskWithNoAccessTokenNoEndpointUrl = {
) => ReturnType<Task[key]>;
};

export class HfInference {
private readonly accessToken: string;
private readonly defaultOptions: Options;
export class HfInference implements Task {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

export class HfInference implements Task {

That bit makes sure HfInference defines a method for every task (textToImage, chatCompletion, etc)

protected readonly accessToken: string;
protected readonly defaultOptions: Options;

speechToText: typeof tasks.speechToText;
audioClassification: typeof tasks.audioClassification;
automaticSpeechRecognition: typeof tasks.automaticSpeechRecognition;
textToSpeech: typeof tasks.textToSpeech;
audioToAudio: typeof tasks.audioToAudio;
imageClassification: typeof tasks.imageClassification;
imageSegmentation: typeof tasks.imageSegmentation;
imageToText: typeof tasks.imageToText;
objectDetection: typeof tasks.objectDetection;
textToImage: typeof tasks.textToImage;
imageToImage: typeof tasks.imageToImage;
zeroShotImageClassification: typeof tasks.zeroShotImageClassification;
featureExtraction: typeof tasks.featureExtraction;
fillMask: typeof tasks.fillMask;
questionAnswering: typeof tasks.questionAnswering;
sentenceSimilarity: typeof tasks.sentenceSimilarity;
summarization: typeof tasks.summarization;
tableQuestionAnswering: typeof tasks.tableQuestionAnswering;
textClassification: typeof tasks.textClassification;
textGeneration: typeof tasks.textGeneration;
tokenClassification: typeof tasks.tokenClassification;
translation: typeof tasks.translation;
zeroShotClassification: typeof tasks.zeroShotClassification;
chatCompletion: typeof tasks.chatCompletion;
documentQuestionAnswering: typeof tasks.documentQuestionAnswering;
visualQuestionAnswering: typeof tasks.visualQuestionAnswering;
tabularRegression: typeof tasks.tabularRegression;
tabularClassification: typeof tasks.tabularClassification;
textGenerationStream: typeof tasks.textGenerationStream;
chatCompletionStream: typeof tasks.chatCompletionStream;

static mapInferenceFn<TOut, TArgs>(instance: HfInference, func: (...args: [TArgs, Options?]) => TOut) {
return function (...[args, options]: Parameters<(...args: [TArgs, Options?]) => TOut>): TOut {
return func({ ...args, accessToken: instance.accessToken }, { ...instance.defaultOptions, ...(options ?? {}) });
};
}
Comment on lines +59 to +63
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This helper functions adds some default arguments when calling HfInference.<task>() while keeping the exact type for inputs and outputs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So jealous about how Typescript's seems so much powerful than Python's typing system... 😒


constructor(accessToken = "", defaultOptions: Options = {}) {
this.accessToken = accessToken;
this.defaultOptions = defaultOptions;

for (const [name, fn] of Object.entries(tasks)) {
Object.defineProperty(this, name, {
enumerable: false,
value: (params: RequestArgs, options: Options) =>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
fn({ ...params, accessToken } as any, { ...defaultOptions, ...options }),
});
}
this.speechToText = HfInference.mapInferenceFn(this, tasks.speechToText);
this.audioClassification = HfInference.mapInferenceFn(this, tasks.audioClassification);
this.automaticSpeechRecognition = HfInference.mapInferenceFn(this, tasks.automaticSpeechRecognition);
this.textToSpeech = HfInference.mapInferenceFn(this, tasks.textToSpeech);
this.audioToAudio = HfInference.mapInferenceFn(this, tasks.audioToAudio);
this.imageClassification = HfInference.mapInferenceFn(this, tasks.imageClassification);
this.imageSegmentation = HfInference.mapInferenceFn(this, tasks.imageSegmentation);
this.imageToText = HfInference.mapInferenceFn(this, tasks.imageToText);
this.objectDetection = HfInference.mapInferenceFn(this, tasks.objectDetection);
this.textToImage = HfInference.mapInferenceFn(this, tasks.textToImage);
this.imageToImage = HfInference.mapInferenceFn(this, tasks.imageToImage);
this.zeroShotImageClassification = HfInference.mapInferenceFn(this, tasks.zeroShotImageClassification);
this.featureExtraction = HfInference.mapInferenceFn(this, tasks.featureExtraction);
this.fillMask = HfInference.mapInferenceFn(this, tasks.fillMask);
this.questionAnswering = HfInference.mapInferenceFn(this, tasks.questionAnswering);
this.sentenceSimilarity = HfInference.mapInferenceFn(this, tasks.sentenceSimilarity);
this.summarization = HfInference.mapInferenceFn(this, tasks.summarization);
this.tableQuestionAnswering = HfInference.mapInferenceFn(this, tasks.tableQuestionAnswering);
this.textClassification = HfInference.mapInferenceFn(this, tasks.textClassification);
this.textGeneration = HfInference.mapInferenceFn(this, tasks.textGeneration);
this.tokenClassification = HfInference.mapInferenceFn(this, tasks.tokenClassification);
this.translation = HfInference.mapInferenceFn(this, tasks.translation);
this.zeroShotClassification = HfInference.mapInferenceFn(this, tasks.zeroShotClassification);
this.chatCompletion = HfInference.mapInferenceFn(this, tasks.chatCompletion);
this.documentQuestionAnswering = HfInference.mapInferenceFn(this, tasks.documentQuestionAnswering);
this.visualQuestionAnswering = HfInference.mapInferenceFn(this, tasks.visualQuestionAnswering);
this.tabularRegression = HfInference.mapInferenceFn(this, tasks.tabularRegression);
this.tabularClassification = HfInference.mapInferenceFn(this, tasks.tabularClassification);

/// Streaming methods
this.textGenerationStream = HfInference.mapInferenceFn(this, tasks.textGenerationStream);
this.chatCompletionStream = HfInference.mapInferenceFn(this, tasks.chatCompletionStream);
}

/**
4 changes: 4 additions & 0 deletions packages/inference/src/tasks/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { automaticSpeechRecognition } from "./audio/automaticSpeechRecognition";

// Custom tasks with arbitrary inputs and outputs
export * from "./custom/request";
export * from "./custom/streamingRequest";
@@ -40,3 +42,5 @@ export * from "./multimodal/visualQuestionAnswering";
// Tabular tasks
export * from "./tabular/tabularRegression";
export * from "./tabular/tabularClassification";

export const speechToText = automaticSpeechRecognition;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(unrelated) aliasing automaticSpeechRecognition as speechToText

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did it exist before or that's new?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(why not i guess, though we should try to push the "official" task names otherwise it's hard to standardize one way of doing things)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(why not i guess, though we should try to push the "official" task names otherwise it's hard to standardize one way of doing things)

yeah, good point

24 changes: 12 additions & 12 deletions packages/inference/test/HfInference.spec.ts
Original file line number Diff line number Diff line change
@@ -589,18 +589,18 @@ describe.concurrent("HfInference", () => {
generated_text: "a large brown and white giraffe standing in a field ",
});
});
it("request - openai-community/gpt2", async () => {
expect(
await hf.request({
model: "openai-community/gpt2",
inputs: "one plus two equals",
})
).toMatchObject([
{
generated_text: expect.any(String),
},
]);
});
// it("request - openai-community/gpt2", async () => {
// expect(
// await hf.request({
// model: "openai-community/gpt2",
// inputs: "one plus two equals",
// })
// ).toMatchObject([
// {
// generated_text: expect.any(String),
// },
// ]);
// });
Comment on lines +592 to +603
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still need to find a way around request and streamingRequest... They use a generic type T which messes thing up I think


// Skipped at the moment because takes forever
it.skip("tabularRegression", async () => {