Skip to content

Commit

Permalink
[system-a] plane -- test samplingObjective
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyuheng committed May 16, 2024
1 parent dc56566 commit 5bcef6c
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
2 changes: 1 addition & 1 deletion TODO.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# system-a

[system-a] `quad` -- test `samplingObjective`
[system-a] `plane` -- test `samplingObjective`
[system-a] `assertTensorAlmostEqual`

# the-book

Expand Down
8 changes: 8 additions & 0 deletions src/system-a/Tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,11 @@ export function tensorAlmostEqual(
): boolean {
return tensorEvery(sub(x, y), (x) => Math.abs(scalarReal(x)) <= epsilon)
}

// export function assertTensorAlmostEqual(
// x: Tensor,
// y: Tensor,
// epsilon: number,
// ): boolean {
// return tensorEvery(sub(x, y), (x) => Math.abs(scalarReal(x)) <= epsilon)
// }
24 changes: 24 additions & 0 deletions src/system-a/targets/plane.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { test } from "node:test"
import { tensorAlmostEqual, tensorReal } from "../Tensor.js"
import { gradientDescent } from "../gradientDescent.js"
import { l2Loss } from "../loss.js"
import { samplingObjective } from "../samplingObjective.js"
import { plane } from "./plane.js"

test("plane", () => {
Expand Down Expand Up @@ -43,3 +44,26 @@ test("plane -- gradientDescent", () => {

assert(tensorAlmostEqual(plane([2, 3.91])([3.98, 2.04], 5.78), 22.4, 1))
})

test("plane -- gradientDescent & samplingObjective", () => {
const xs = [
[1, 2.05],
[1, 3],
[2, 2],
[2, 3.91],
[3, 6.13],
[4, 8.09],
]
const ys = [13.99, 15.99, 18, 22.4, 30.2, 37.94]

const objective = samplingObjective(l2Loss(plane), xs, ys, {
batchSize: 4,
})

const rs = gradientDescent(objective, [[0, 0], 0], {
revs: 15000,
learningRate: 0.001,
})

assert(tensorAlmostEqual(rs, [[3.98, 2.04], 5.78], 0.5))
})

0 comments on commit 5bcef6c

Please sign in to comment.