diff --git a/packages/inference/src/providers/replicate.ts b/packages/inference/src/providers/replicate.ts index 0bfeddfa8..cd55c2c3e 100644 --- a/packages/inference/src/providers/replicate.ts +++ b/packages/inference/src/providers/replicate.ts @@ -21,7 +21,8 @@ export const REPLICATE_SUPPORTED_MODEL_IDS: ProviderMapping = { "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc", }, "text-to-speech": { - "OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26", + "OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:3c645149db020c85d080e2f8cfe482a0e68189a922cde964fa9e80fb179191f3", + "hexgrad/Kokoro-82M": "jaaari/kokoro-82m:dfdf537ba482b029e0a761699e6f55e9162cfd159270bfe0e44857caa5f275a6", }, "text-to-video": { "genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460", diff --git a/packages/inference/src/tasks/audio/textToSpeech.ts b/packages/inference/src/tasks/audio/textToSpeech.ts index d981ae7c8..153ed648f 100644 --- a/packages/inference/src/tasks/audio/textToSpeech.ts +++ b/packages/inference/src/tasks/audio/textToSpeech.ts @@ -1,8 +1,8 @@ import type { TextToSpeechInput } from "@huggingface/tasks"; import { InferenceOutputError } from "../../lib/InferenceOutputError"; import type { BaseArgs, Options } from "../../types"; +import { omit } from "../../utils/omit"; import { request } from "../custom/request"; - type TextToSpeechArgs = BaseArgs & TextToSpeechInput; interface OutputUrlTextToSpeechGeneration { @@ -13,7 +13,16 @@ interface OutputUrlTextToSpeechGeneration { * Recommended model: espnet/kan-bayashi_ljspeech_vits */ export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise { - const res = await request(args, { + // Replicate models expects "text" instead of "inputs" + const payload = + args.provider === "replicate" + ? { + ...omit(args, ["inputs", "parameters"]), + ...args.parameters, + text: args.inputs, + } + : args; + const res = await request(payload, { ...options, taskHint: "text-to-speech", }); diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts index 1f5e6b863..613786028 100644 --- a/packages/inference/src/types.ts +++ b/packages/inference/src/types.ts @@ -1,5 +1,4 @@ -import type { PipelineType } from "@huggingface/tasks"; -import type { ChatCompletionInput } from "@huggingface/tasks"; +import type { ChatCompletionInput, PipelineType } from "@huggingface/tasks"; /** * HF model id, like "meta-llama/Llama-3.3-70B-Instruct" @@ -88,6 +87,7 @@ export type RequestArgs = BaseArgs & | { data: Blob | ArrayBuffer } | { inputs: unknown } | { prompt: string } + | { text: string } | { audio_url: string } | ChatCompletionInput ) & { diff --git a/packages/inference/test/HfInference.spec.ts b/packages/inference/test/HfInference.spec.ts index 04deb0be4..df2538286 100644 --- a/packages/inference/test/HfInference.spec.ts +++ b/packages/inference/test/HfInference.spec.ts @@ -1,11 +1,11 @@ -import { expect, it, describe, assert } from "vitest"; +import { assert, describe, expect, it } from "vitest"; import type { ChatCompletionStreamOutput } from "@huggingface/tasks"; import { chatCompletion, FAL_AI_SUPPORTED_MODEL_IDS, HfInference } from "../src"; -import "./vcr"; -import { readTestFile } from "./test-files"; import { textToVideo } from "../src/tasks/cv/textToVideo"; +import { readTestFile } from "./test-files"; +import "./vcr"; const TIMEOUT = 60000 * 3; const env = import.meta.env; @@ -939,11 +939,21 @@ describe.concurrent("HfInference", () => { expect(res).toBeInstanceOf(Blob); }); - it("textToSpeech OuteTTS", async () => { + it.skip("textToSpeech OuteTTS - usually Cold", async () => { const res = await client.textToSpeech({ model: "OuteAI/OuteTTS-0.3-500M", provider: "replicate", - inputs: "OuteTTS is a frontier TTS model for its size of 1 Billion parameters", + text: "OuteTTS is a frontier TTS model for its size of 1 Billion parameters", + }); + + expect(res).toBeInstanceOf(Blob); + }); + + it("textToSpeech Kokoro", async () => { + const res = await client.textToSpeech({ + model: "hexgrad/Kokoro-82M", + provider: "replicate", + text: "Kokoro is a frontier TTS model for its size of 1 Billion parameters", }); expect(res).toBeInstanceOf(Blob);