Skip to content

Commit

Permalink
✨ Add support for custom fetch function in @huggingface/hub (#204)
Browse files Browse the repository at this point in the history
  • Loading branch information
coyotte508 authored May 30, 2023
1 parent 25ec1b8 commit 203fd4c
Show file tree
Hide file tree
Showing 18 changed files with 101 additions and 28 deletions.
18 changes: 11 additions & 7 deletions packages/hub/src/lib/commit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ export interface CommitParams {
parentCommit?: string;
isPullRequest?: boolean;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}

export interface CommitOutput {
Expand Down Expand Up @@ -113,7 +117,7 @@ async function* commitIter(params: CommitParams): AsyncGenerator<unknown, Commit
return { ...operation, content: operation.content };
}

const lazyBlob = await createBlob(operation.content);
const lazyBlob = await createBlob(operation.content, { fetch: params.fetch });

return {
...operation,
Expand All @@ -136,7 +140,7 @@ async function* commitIter(params: CommitParams): AsyncGenerator<unknown, Commit
),
};

const res = await fetch(
const res = await (params.fetch ?? fetch)(
`${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/preupload/${encodeURIComponent(
params.branch ?? "main"
)}` + (params.isPullRequest ? "?create_pr=1" : ""),
Expand Down Expand Up @@ -194,7 +198,7 @@ async function* commitIter(params: CommitParams): AsyncGenerator<unknown, Commit
})),
};

const res = await fetch(
const res = await (params.fetch ?? fetch)(
`${params.hubUrl ?? HUB_URL}/${repoId.type === "model" ? "" : repoId.type + "s/"}${
repoId.name
}.git/info/lfs/objects/batch`,
Expand Down Expand Up @@ -263,7 +267,7 @@ async function* commitIter(params: CommitParams): AsyncGenerator<unknown, Commit
const index = parseInt(part) - 1;
const slice = content.slice(index * chunkSize, (index + 1) * chunkSize);

const res = await fetch(header[part], {
const res = await (params.fetch ?? fetch)(header[part], {
method: "PUT",
/** Unfortunately, browsers don't support our inherited version of Blob in fetch calls */
body: slice instanceof WebBlob && isFrontend ? await slice.arrayBuffer() : slice,
Expand All @@ -289,7 +293,7 @@ async function* commitIter(params: CommitParams): AsyncGenerator<unknown, Commit
MULTIPART_PARALLEL_UPLOAD
);

const res = await fetch(completionUrl, {
const res = await (params.fetch ?? fetch)(completionUrl, {
method: "POST",
body: JSON.stringify(completeReq),
headers: {
Expand All @@ -305,7 +309,7 @@ async function* commitIter(params: CommitParams): AsyncGenerator<unknown, Commit
});
}
} else {
const res = await fetch(obj.actions.upload.href, {
const res = await (params.fetch ?? fetch)(obj.actions.upload.href, {
method: "PUT",
headers: {
...(batchRequestId ? { "X-Request-Id": batchRequestId } : undefined),
Expand All @@ -328,7 +332,7 @@ async function* commitIter(params: CommitParams): AsyncGenerator<unknown, Commit

yield "committing";

const res = await fetch(
const res = await (params.fetch ?? fetch)(
`${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/commit/${encodeURIComponent(
params.branch ?? "main"
)}` + (params.isPullRequest ? "?create_pr=1" : ""),
Expand Down
6 changes: 5 additions & 1 deletion packages/hub/src/lib/create-repo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ export async function createRepo(params: {
/** @required for when {@link repo.type} === "space" */
sdk?: SpaceSdk;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): Promise<{ repoUrl: string }> {
checkCredentials(params.credentials);
const repoId = toRepoId(params.repo);
Expand All @@ -29,7 +33,7 @@ export async function createRepo(params: {
);
}

const res = await fetch(`${params.hubUrl ?? HUB_URL}/api/repos/create`, {
const res = await (params.fetch ?? fetch)(`${params.hubUrl ?? HUB_URL}/api/repos/create`, {
method: "POST",
body: JSON.stringify({
name: repoName,
Expand Down
2 changes: 2 additions & 0 deletions packages/hub/src/lib/delete-file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export function deleteFile(params: {
commitTitle?: CommitParams["title"];
commitDescription?: CommitParams["description"];
hubUrl?: CommitParams["hubUrl"];
fetch?: CommitParams["fetch"];
branch?: CommitParams["branch"];
isPullRequest?: CommitParams["isPullRequest"];
parentCommit?: CommitParams["parentCommit"];
Expand All @@ -28,5 +29,6 @@ export function deleteFile(params: {
branch: params.branch,
isPullRequest: params.isPullRequest,
parentCommit: params.parentCommit,
fetch: params.fetch,
});
}
2 changes: 2 additions & 0 deletions packages/hub/src/lib/delete-files.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export function deleteFiles(params: {
branch?: CommitParams["branch"];
isPullRequest?: CommitParams["isPullRequest"];
parentCommit?: CommitParams["parentCommit"];
fetch?: CommitParams["fetch"];
}): Promise<CommitOutput> {
return commit({
credentials: params.credentials,
Expand All @@ -26,5 +27,6 @@ export function deleteFiles(params: {
branch: params.branch,
isPullRequest: params.isPullRequest,
parentCommit: params.parentCommit,
fetch: params.fetch,
});
}
6 changes: 5 additions & 1 deletion packages/hub/src/lib/delete-repo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@ export async function deleteRepo(params: {
repo: RepoDesignation;
credentials: Credentials;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): Promise<void> {
checkCredentials(params.credentials);
const repoId = toRepoId(params.repo);
const [namespace, repoName] = repoId.name.split("/");

const res = await fetch(`${params.hubUrl ?? HUB_URL}/api/repos/delete`, {
const res = await (params.fetch ?? fetch)(`${params.hubUrl ?? HUB_URL}/api/repos/delete`, {
method: "DELETE",
body: JSON.stringify({
name: repoName,
Expand Down
6 changes: 5 additions & 1 deletion packages/hub/src/lib/download-file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,18 @@ export async function downloadFile(params: {
range?: [number, number];
credentials?: Credentials;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): Promise<Response | null> {
checkCredentials(params.credentials);
const repoId = toRepoId(params.repo);
const url = `${params.hubUrl ?? HUB_URL}/${repoId.type === "model" ? "" : `${repoId.type}s/`}${repoId.name}/${
params.raw ? "raw" : "resolve"
}/${encodeURIComponent(params.revision ?? "main")}/${params.path}`;

const resp = await fetch(url, {
const resp = await (params.fetch ?? fetch)(url, {
headers: {
...(params.credentials
? {
Expand Down
6 changes: 5 additions & 1 deletion packages/hub/src/lib/file-download-info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ export async function fileDownloadInfo(params: {
revision?: string;
credentials?: Credentials;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
/**
* To get the raw pointer file behind a LFS file
*/
Expand All @@ -42,7 +46,7 @@ export async function fileDownloadInfo(params: {
}/${encodeURIComponent(params.revision ?? "main")}/${params.path}` +
(params.noContentDisposition ? "?noContentDisposition=1" : "");

const resp = await fetch(url, {
const resp = await (params.fetch ?? fetch)(url, {
method: "HEAD",
headers: params.credentials
? {
Expand Down
6 changes: 5 additions & 1 deletion packages/hub/src/lib/file-exists.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ export async function fileExists(params: {
revision?: string;
credentials?: Credentials;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): Promise<boolean> {
const info = await fileDownloadInfo({ ...params, raw: true });
const info = await fileDownloadInfo({ ...params, raw: true, fetch: params.fetch });
// ^use raw to not redirect and save some time for LFS files
return !!info;
}
6 changes: 5 additions & 1 deletion packages/hub/src/lib/list-datasets.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ export async function* listDatasets(params?: {
};
credentials?: Credentials;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): AsyncGenerator<DatasetEntry> {
checkCredentials(params?.credentials);
const search = new URLSearchParams([
Expand All @@ -35,7 +39,7 @@ export async function* listDatasets(params?: {
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/datasets` + (search ? "?" + search : "");

while (url) {
const res: Response = await fetch(url, {
const res: Response = await (params?.fetch ?? fetch)(url, {
headers: {
accept: "application/json",
...(params?.credentials ? { Authorization: `Bearer ${params.credentials.accessToken}` } : undefined),
Expand Down
6 changes: 5 additions & 1 deletion packages/hub/src/lib/list-files.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ export async function* listFiles(params: {
revision?: string;
credentials?: Credentials;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): AsyncGenerator<ListFileEntry> {
checkCredentials(params.credentials);
const repoId = toRepoId(params.repo);
Expand All @@ -61,7 +65,7 @@ export async function* listFiles(params: {
}${params.path ? "/" + params.path : ""}?recursive=${!!params.recursive}&expand=${!!params.expand}`;

while (url) {
const res: Response = await fetch(url, {
const res: Response = await (params.fetch ?? fetch)(url, {
headers: {
accept: "application/json",
...(params.credentials ? { Authorization: `Bearer ${params.credentials.accessToken}` } : undefined),
Expand Down
6 changes: 5 additions & 1 deletion packages/hub/src/lib/list-models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ export async function* listModels(params?: {
};
credentials?: Credentials;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): AsyncGenerator<ModelEntry> {
checkCredentials(params?.credentials);
const search = new URLSearchParams([
Expand All @@ -38,7 +42,7 @@ export async function* listModels(params?: {
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`;

while (url) {
const res: Response = await fetch(url, {
const res: Response = await (params?.fetch ?? fetch)(url, {
headers: {
accept: "application/json",
...(params?.credentials ? { Authorization: `Bearer ${params.credentials.accessToken}` } : undefined),
Expand Down
6 changes: 5 additions & 1 deletion packages/hub/src/lib/list-spaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ export async function* listSpaces(params?: {
};
credentials?: Credentials;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): AsyncGenerator<SpaceEntry> {
checkCredentials(params?.credentials);
const search = new URLSearchParams([
Expand All @@ -31,7 +35,7 @@ export async function* listSpaces(params?: {
let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/spaces?${search}`;

while (url) {
const res: Response = await fetch(url, {
const res: Response = await (params?.fetch ?? fetch)(url, {
headers: {
accept: "application/json",
...(params?.credentials ? { Authorization: `Bearer ${params.credentials.accessToken}` } : undefined),
Expand Down
12 changes: 12 additions & 0 deletions packages/hub/src/lib/parse-safetensors-metadata.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ async function parseSingleFile(
revision?: string;
credentials?: Credentials;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}
): Promise<SafetensorsFileHeader> {
const firstResp = await downloadFile({
Expand Down Expand Up @@ -90,6 +94,10 @@ async function parseShardedIndex(
revision?: string;
credentials?: Credentials;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}
): Promise<{ index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders }> {
const indexResp = await downloadFile({
Expand Down Expand Up @@ -152,6 +160,10 @@ export async function parseSafetensorsMetadata(params: {
hubUrl?: string;
credentials?: Credentials;
revision?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): Promise<SafetensorsParseFromRepo> {
checkCredentials(params.credentials);
const repoId = toRepoId(params.repo);
Expand Down
2 changes: 2 additions & 0 deletions packages/hub/src/lib/upload-file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export function uploadFile(params: {
branch?: CommitParams["branch"];
isPullRequest?: CommitParams["isPullRequest"];
parentCommit?: CommitParams["parentCommit"];
fetch?: CommitParams["fetch"];
}): Promise<CommitOutput> {
const path =
params.file instanceof URL
Expand All @@ -36,5 +37,6 @@ export function uploadFile(params: {
branch: params.branch,
isPullRequest: params.isPullRequest,
parentCommit: params.parentCommit,
fetch: params.fetch,
});
}
2 changes: 2 additions & 0 deletions packages/hub/src/lib/upload-files.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export function uploadFiles(params: {
branch?: CommitParams["branch"];
isPullRequest?: CommitParams["isPullRequest"];
parentCommit?: CommitParams["parentCommit"];
fetch?: CommitParams["fetch"];
}): Promise<CommitOutput> {
return commit({
credentials: params.credentials,
Expand All @@ -27,5 +28,6 @@ export function uploadFiles(params: {
branch: params.branch,
isPullRequest: params.isPullRequest,
parentCommit: params.parentCommit,
fetch: params.fetch,
});
}
6 changes: 5 additions & 1 deletion packages/hub/src/lib/who-am-i.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,14 @@ export interface AuthInfo {
export async function whoAmI(params: {
credentials: Credentials;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): Promise<WhoAmI & { auth: AuthInfo }> {
checkCredentials(params.credentials);

const res = await fetch(`${params.hubUrl ?? HUB_URL}/api/whoami-v2`, {
const res = await (params.fetch ?? fetch)(`${params.hubUrl ?? HUB_URL}/api/whoami-v2`, {
headers: {
Authorization: `Bearer ${params.credentials.accessToken}`,
},
Expand Down
Loading

0 comments on commit 203fd4c

Please sign in to comment.