Skip to content

Commit

Permalink
[system-a] line -- test samplingObjective
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyuheng committed May 16, 2024
1 parent 70dffb6 commit 66d8b39
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
3 changes: 2 additions & 1 deletion TODO.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# system-a

[system-a] using `samplingObjective` in target tests
[system-a] `quad` -- test `samplingObjective`
[system-a] `plane` -- test `samplingObjective`

# the-book

Expand Down
17 changes: 17 additions & 0 deletions src/system-a/targets/line.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { test } from "node:test"
import { tensorAlmostEqual, tensorReal } from "../Tensor.js"
import { gradientDescent } from "../gradientDescent.js"
import { l2Loss } from "../loss.js"
import { samplingObjective } from "../samplingObjective.js"
import { line } from "./line.js"

test("line", () => {
Expand All @@ -26,3 +27,19 @@ test("line -- gradientDescent", () => {
assert(tensorAlmostEqual(rs, [1, 0], 10e-1))
assert(tensorAlmostEqual(rs, [1.05, 0], 10e-6))
})

test("line -- gradientDescent + samplingObjective ", () => {
const xs = [2, 1, 4, 3]
const ys = [1.8, 1.2, 4.2, 3.3]

const objective = samplingObjective(l2Loss(line), xs, ys, {
batchSize: 4,
})

const rs = gradientDescent(objective, [0, 0], {
revs: 1000,
learningRate: 0.01,
})

assert(tensorAlmostEqual(rs, [1, 0], 10e-1))
})

0 comments on commit 66d8b39

Please sign in to comment.