Skip to content

Commit

Permalink
[system-a] inline GradientDescentOptions
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyuheng committed May 16, 2024
1 parent 82a745f commit 9967949
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
2 changes: 2 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# system-a

[system-a] `samplingObjective`

# the-book

6: An Apple a Day
Expand Down
15 changes: 8 additions & 7 deletions src/system-a/gradientDescent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@ import {
} from "./index.js"
import { mul, sub } from "./toys/index.js"

export type GradientDescentOptions = {
revs: number
learningRate: number
}

export function gradientDescent(
objective: (...ps: Array<Tensor>) => Scalar,
ps: Array<Tensor>,
options: GradientDescentOptions,
options: {
revs: number
learningRate: number
},
): Array<Tensor> {
const step = gradientDescentStep(objective, options)
const rs = revise(step, options.revs, ps)
Expand All @@ -27,7 +25,10 @@ export function gradientDescent(

export function gradientDescentStep(
objective: (...ps: Array<Tensor>) => Scalar,
options: GradientDescentOptions,
options: {
revs: number
learningRate: number
},
): (ps: Array<Tensor>) => Array<Tensor> {
return function step(ps: Array<Tensor>): Array<Tensor> {
const gs = gradient(objective, ps)
Expand Down
17 changes: 17 additions & 0 deletions src/system-a/samplingObjective.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { assertTensorArray, type Tensor } from "./Tensor.js"
import type { Expectant, Objective } from "./loss.js"

export function samplingObjective(
expectant: Expectant,
xs: Tensor,
ys: Tensor,
options: {
batchSize: number
},
): Objective {
assertTensorArray(xs)
const size = xs.length
return (...ps) => {
throw new Error()
}
}

0 comments on commit 9967949

Please sign in to comment.