Skip to content

Commit

Permalink
fix type of DifferentiableFn
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyuheng committed Jun 6, 2024
1 parent c93c35a commit 94df3bf
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions src/system-a/gradient/gradient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import {
scalarLink,
scalarTruncate,
tensorMap,
type Scalar,
type Tensor,
} from "../tensor/index.js"
import {
Expand All @@ -14,9 +13,7 @@ import {

// The effect of `gradient` on a `DifferentiableFn`
// is `sum` of all elements of it's result tensor.
export type DifferentiableFn =
| ((...args: Array<Tensor>) => Tensor)
| ((...args: Array<Scalar>) => Tensor)
export type DifferentiableFn = (...args: Array<Tensor>) => Tensor

export function gradient(fn: DifferentiableFn, args: Array<Tensor>): Tensor {
const wrt = tensorMap(args, scalarTruncate)
Expand Down

0 comments on commit 94df3bf

Please sign in to comment.