Skip to content

Commit

Permalink
more work on MNIST classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Dec 4, 2023
1 parent 438a8f1 commit dcb67de
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 71 deletions.
166 changes: 95 additions & 71 deletions examples/MNISTClassifier.lean
Original file line number Diff line number Diff line change
@@ -1,53 +1,93 @@
import Lean
import SciLean
import Lean
-- import SciLean
import SciLean.Data.DataArray
import SciLean.Core.FloatAsReal
import SciLean.Modules.ML.Dense
import SciLean.Modules.ML.Convolution
import SciLean.Modules.ML.Pool

open SciLean
open IO FS System

open NotationOverField
set_default_scalar Float
set_option synthInstance.maxSize 2000

open IO FS System

/-- Succeeds silently when `path` exists and raises an `IO.userError` otherwise.

The error text names the missing file and tells the user where the MNIST
binaries can be downloaded and where they have to be unpacked. -/
def checkFileExists (path : FilePath) : IO Unit := do
  let found ← path.pathExists
  unless found do
    throw <| IO.userError s!"MNIST data file '{path}' not found. Please download binary version from https://git-disl.github.io/GTDLBench/datasets/mnist_datasets/ and extract it in 'data' directory"

open IO FS System in
/-- Reads the images of the MNIST training set as raw byte blocks.

All four MNIST files are checked for existence first (a missing one raises the
error produced by `checkFileExists`), but only the training image file is read.
Each element of the result holds exactly `28*28` bytes, one byte per pixel.
Reading stops at the end of the file; an incomplete trailing block is dropped
instead of being returned as a short image. -/
def loadData : IO (Array ByteArray) := do

  let trainImages : FilePath := "data/train-images.idx3-ubyte"
  let trainLabels : FilePath := "data/train-labels.idx1-ubyte"
  let testImages : FilePath := "data/t10k-images.idx3-ubyte"
  let testLabels : FilePath := "data/t10k-labels.idx1-ubyte"

  checkFileExists trainImages
  checkFileExists trainLabels
  checkFileExists testImages
  checkFileExists testLabels

  IO.FS.withFile trainImages .read fun m => do

    let mut data : Array ByteArray := #[]

    -- the file starts with a 16 byte header made of four 4-byte fields:
    -- magic number, number of images, x dimension, y dimension
    let _header ← m.read 16

    -- the loop bound is only an upper limit; the loop normally ends at end of file
    for _ in [0:1000000] do
      let n : Nat := 28
      let nums ← m.read (n*n).toUSize
      -- an empty read means end of file, a short read means a truncated image;
      -- neither is a complete 28x28 block
      if nums.size ≠ n*n then
        break

      -- byte data to floats
      -- let mut d : Float^[28,28] := 0
      -- for i in fullRange (Idx 28 × Idx 28) do
      --   let li := (toIdx i).1
      --   d[i] := nums[li]!.toNat.toFloat / 256.0

      data := data.push nums

    return data

/-- Converts the bytes of `b` into a `Float^ι` array.

The indices of `ι` are visited in the order given by `fullRange ι`; the `k`-th
visited index receives the `k`-th byte of `b` divided by `256.0`, so every
entry lies in `[0, 1)`.

The bytes are fetched with the unchecked `uget`, whose bounds obligation is
closed with `sorry_proof`; `b` must therefore hold at least one byte per index.
NOTE(review): the size hypothesis is stated with `sizeOf ι` and the caller in
`loadImages` discharges it with `sorry_proof` — confirm that it expresses the
intended "one byte per index" requirement. -/
def toFloatRepr (b : ByteArray) (ι : Type _) [Index ι] (_ : b.size = sizeOf ι) : Float^ι := Id.run do
  let mut out : Float^ι := 0
  let mut pos : USize := 0
  for i in fullRange ι do
    -- unchecked read of the next byte
    let byte := b.uget pos sorry_proof
    out[i] := byte.toNat.toFloat / 256.0
    pos := pos + 1
  out

open IO FS System in
/-- Loads the first `maxImages` images from the MNIST image file `path`.

* `path` — image file; a missing file raises the error of `checkFileExists`.
* `maxImages` — number of images to read; `0` returns `#[]` without opening
  the file.

Returns one `Float^[28,28]` per image with pixel values scaled by `1/256`
(see `toFloatRepr`). Raises an `IO.userError` when the file holds fewer than
`maxImages` complete images. Progress and timing information is printed to
standard output. -/
def loadImages (path : FilePath) (maxImages : Nat) : IO (Array (Float^[28,28])) := do

  checkFileExists path

  if maxImages = 0 then
    return #[]

  let start ← IO.monoMsNow
  IO.print s!"loading images from {path} ... "
  let data ←
    IO.FS.withFile path .read fun m => do
      let _header ← m.read 16 -- discard 16 byte header
      let mut data : Array ByteArray := #[]
      for _ in [0:maxImages] do
        let n : Nat := 28
        let nums ← m.read (n*n).toUSize
        -- Stop on end of file (empty read) and also on a truncated image
        -- (short read). Only complete blocks may be kept, because the
        -- conversion below reads all 28*28 bytes without bounds checks.
        if nums.size ≠ n*n then
          break
        data := data.push nums
      pure data

  if data.size ≠ maxImages then
    throw <| IO.userError s!"file {path} contains only {data.size} images"

  IO.println s!"loaded in {(← IO.monoMsNow) - start}ms"

  let start ← IO.monoMsNow
  IO.print "converting to float format ... "
  let data := data.map (toFloatRepr · (Idx 28 × Idx 28) sorry_proof)
  IO.println s!"converted in {(← IO.monoMsNow) - start}ms"

  return data


/-- Loads the first `maxLabels` labels from the MNIST label file `path` and
one-hot encodes them.

* `path` — label file; a missing file raises the error of `checkFileExists`.
* `maxLabels` — number of labels to read; `0` returns `#[]` without opening
  the file.

Each label byte `d` becomes `oneHot d 1 : Float^[10]`. Raises an
`IO.userError` when the file holds fewer than `maxLabels` labels or when a
label byte is not in the range `0..9`. Progress and timing information is
printed to standard output. -/
def loadLabels (path : FilePath) (maxLabels : Nat) : IO (Array (Float^[10])) := do
  checkFileExists path

  if maxLabels = 0 then
    return #[]

  let start ← IO.monoMsNow
  IO.print s!"loading labels from {path} ... "
  let data ← IO.FS.withFile path .read fun m => do
    let _header ← m.read 8 -- discard 8 byte header
    m.read maxLabels.toUSize
  if data.size ≠ maxLabels then
    throw <| IO.userError s!"file {path} contains only {data.size} labels"
  IO.println s!"loaded in {(← IO.monoMsNow) - start}ms"

  let start ← IO.monoMsNow
  IO.print "converting to float format ... "
  let mut labels : Array (Float^[10]) := .mkEmpty data.size
  for b in data do
    -- The index below is built with `sorry_proof`, so nothing else checks the
    -- range; reject bytes that cannot be a digit before constructing it.
    if b.toNat ≥ 10 then
      throw <| IO.userError s!"file {path} contains invalid label {b.toNat}"
    let i : Idx 10 := ⟨b.toUSize, sorry_proof⟩
    labels := labels.push (oneHot i 1)
  IO.println s!"converted in {(← IO.monoMsNow) - start}ms"

  return labels


def printDigit (digit : Float^[28,28]) : IO Unit := do

Expand All @@ -68,41 +108,25 @@ def printDigit (digit : Float^[28,28]) : IO Unit := do

IO.println "|"

/-- Prints a raw 28x28 image as ASCII art, one text row per image row.

Bytes are consumed from `digit` in order, 28 per row, and every row is framed
by `|` on both sides. A byte is drawn as `#` above 200, `$` above 150, `o`
above 50, `.` above 1 and as a blank otherwise. -/
def printDigit' (digit : ByteArray) : IO Unit := do
  let mut pos := 0
  for _ in fullRange (Idx 28) do
    IO.print "|"
    for _ in fullRange (Idx 28) do
      let px := digit[pos]!
      -- pick the glyph for this brightness level
      let glyph :=
        if px > 200 then "#"
        else if px > 150 then "$"
        else if px > 50 then "o"
        else if px > 1 then "."
        else " "
      IO.print glyph
      pos := pos + 1
    IO.println "|"



/-- Entry point: loads MNIST data and prints a few digits as ASCII art.

NOTE(review): this body comes from a diff view in which lines of the old and
of the new version are shown together, so it uses both the raw-byte path
(`loadData`, `printDigit'`) and the float path (`loadImages`, `loadLabels`,
`printDigit`). The two `←` binds that were lost in that view are restored;
every statement is otherwise kept as shown. -/
def main : IO Unit := do

  IO.print "loading data ... "
  let data ← loadData
  IO.println "data loaded"
  let trainNum := 1000
  let trainImages ← loadImages "data/train-images.idx3-ubyte" trainNum
  let trainLabels ← loadLabels "data/train-labels.idx1-ubyte" trainNum

  IO.println ""
  IO.println s!"number of images: {data.size}"
  -- with `testNum = 0` both loaders return empty arrays
  let testNum := 0
  let testImages ← loadImages "data/t10k-images.idx3-ubyte" testNum
  let testLabels ← loadLabels "data/t10k-labels.idx1-ubyte" testNum

  IO.println ""

  IO.println s!"label: {trainLabels[2]!}"
  IO.println "+----------------------------+"
  printDigit' data[400]!
  IO.println "+----------------------------+"
  printDigit' data[600]!
  printDigit trainImages[2]!
  IO.println "+----------------------------+"


  -- show every loaded training image, pausing 100 ms after each one
  for img in trainImages do
    printDigit img
    IO.println "+----------------------------+"
    IO.sleep 100
103 changes: 103 additions & 0 deletions examples/MNISTClassifier/DataUtil.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import Lean
-- import SciLean
import SciLean.Data.DataArray
import SciLean.Core.FloatAsReal

open SciLean
open IO FS System

/-- Succeeds silently when `path` exists and raises an `IO.userError` otherwise.

The error text names the missing file and tells the user where the MNIST
binaries can be downloaded and where they have to be unpacked. -/
def checkFileExists (path : FilePath) : IO Unit := do
  let found ← path.pathExists
  unless found do
    throw <| IO.userError s!"MNIST data file '{path}' not found. Please download binary version from https://git-disl.github.io/GTDLBench/datasets/mnist_datasets/ and extract it in 'data' directory"

/-- Converts the bytes of `b` into a `Float^ι` array.

The indices of `ι` are visited in the order given by `fullRange ι`; the `k`-th
visited index receives the `k`-th byte of `b` divided by `256.0`, so every
entry lies in `[0, 1)`.

The bytes are fetched with the unchecked `uget`, whose bounds obligation is
closed with `sorry_proof`; `b` must therefore hold at least one byte per index.
NOTE(review): the size hypothesis is stated with `sizeOf ι` and the caller in
`loadImages` discharges it with `sorry_proof` — confirm that it expresses the
intended "one byte per index" requirement. -/
def toFloatRepr (b : ByteArray) (ι : Type _) [Index ι] (_ : b.size = sizeOf ι) : Float^ι := Id.run do
  let mut out : Float^ι := 0
  let mut pos : USize := 0
  for i in fullRange ι do
    -- unchecked read of the next byte
    let byte := b.uget pos sorry_proof
    out[i] := byte.toNat.toFloat / 256.0
    pos := pos + 1
  out

open IO FS System in
/-- Loads the first `maxImages` images from the MNIST image file `path`.

* `path` — image file; a missing file raises the error of `checkFileExists`.
* `maxImages` — number of images to read; `0` returns `#[]` without opening
  the file.

Returns one `Float^[28,28]` per image with pixel values scaled by `1/256`
(see `toFloatRepr`). Raises an `IO.userError` when the file holds fewer than
`maxImages` complete images. Progress and timing information is printed to
standard output. -/
def loadImages (path : FilePath) (maxImages : Nat) : IO (Array (Float^[28,28])) := do

  checkFileExists path

  if maxImages = 0 then
    return #[]

  let start ← IO.monoMsNow
  IO.print s!"loading images from {path} ... "
  let data ←
    IO.FS.withFile path .read fun m => do
      let _header ← m.read 16 -- discard 16 byte header
      let mut data : Array ByteArray := #[]
      for _ in [0:maxImages] do
        let n : Nat := 28
        let nums ← m.read (n*n).toUSize
        -- Stop on end of file (empty read) and also on a truncated image
        -- (short read). Only complete blocks may be kept, because the
        -- conversion below reads all 28*28 bytes without bounds checks.
        if nums.size ≠ n*n then
          break
        data := data.push nums
      pure data

  if data.size ≠ maxImages then
    throw <| IO.userError s!"file {path} contains only {data.size} images"

  IO.println s!"loaded in {(← IO.monoMsNow) - start}ms"

  let start ← IO.monoMsNow
  IO.print "converting to float format ... "
  let data := data.map (toFloatRepr · (Idx 28 × Idx 28) sorry_proof)
  IO.println s!"converted in {(← IO.monoMsNow) - start}ms"

  return data


/-- Loads the first `maxLabels` labels from the MNIST label file `path` and
one-hot encodes them.

* `path` — label file; a missing file raises the error of `checkFileExists`.
* `maxLabels` — number of labels to read; `0` returns `#[]` without opening
  the file.

Each label byte `d` becomes `oneHot d 1 : Float^[10]`. Raises an
`IO.userError` when the file holds fewer than `maxLabels` labels or when a
label byte is not in the range `0..9`. Progress and timing information is
printed to standard output. -/
def loadLabels (path : FilePath) (maxLabels : Nat) : IO (Array (Float^[10])) := do
  checkFileExists path

  if maxLabels = 0 then
    return #[]

  let start ← IO.monoMsNow
  IO.print s!"loading labels from {path} ... "
  let data ← IO.FS.withFile path .read fun m => do
    let _header ← m.read 8 -- discard 8 byte header
    m.read maxLabels.toUSize
  if data.size ≠ maxLabels then
    throw <| IO.userError s!"file {path} contains only {data.size} labels"
  IO.println s!"loaded in {(← IO.monoMsNow) - start}ms"

  let start ← IO.monoMsNow
  IO.print "converting to float format ... "
  let mut labels : Array (Float^[10]) := .mkEmpty data.size
  for b in data do
    -- The index below is built with `sorry_proof`, so nothing else checks the
    -- range; reject bytes that cannot be a digit before constructing it.
    if b.toNat ≥ 10 then
      throw <| IO.userError s!"file {path} contains invalid label {b.toNat}"
    let i : Idx 10 := ⟨b.toUSize, sorry_proof⟩
    labels := labels.push (oneHot i 1)
  IO.println s!"converted in {(← IO.monoMsNow) - start}ms"

  return labels


/-- Prints a 28x28 float image as ASCII art, one text row per image row.

Every row is framed by `|` on both sides. The entry at `(i,j)` is drawn as
`#` above 0.8, `$` above 0.6, `o` above 0.4, `.` above 0.1 and as a blank
otherwise. -/
def printDigit (digit : Float^[28,28]) : IO Unit := do
  for i in fullRange (Idx 28) do
    IO.print "|"
    for j in fullRange (Idx 28) do
      let intensity := digit[(i,j)]
      -- pick the glyph for this brightness level
      let glyph :=
        if intensity > 0.8 then "#"
        else if intensity > 0.6 then "$"
        else if intensity > 0.4 then "o"
        else if intensity > 0.1 then "."
        else " "
      IO.print glyph
    IO.println "|"


23 changes: 23 additions & 0 deletions examples/MNISTClassifier/Main.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import examples.MNISTClassifier.DataUtil
import examples.MNISTClassifier.Model

open SciLean


/-- Entry point: loads a slice of the MNIST data and shows one training sample.

Reads 1000 training images with their labels and 0 test samples, then prints
the label at position 2 followed by the matching image as framed ASCII art. -/
def main : IO Unit := do

  -- number of samples taken from each split
  let trainNum := 1000
  let testNum := 0

  let trainImages ← loadImages "data/train-images.idx3-ubyte" trainNum
  let trainLabels ← loadLabels "data/train-labels.idx1-ubyte" trainNum

  let testImages ← loadImages "data/t10k-images.idx3-ubyte" testNum
  let testLabels ← loadLabels "data/t10k-labels.idx1-ubyte" testNum

  IO.println ""

  -- horizontal frame line, as wide as one printed image row
  let border := "+----------------------------+"

  IO.println s!"label: {trainLabels[2]!}"
  IO.println border
  printDigit trainImages[2]!
  IO.println border

33 changes: 33 additions & 0 deletions examples/MNISTClassifier/Model.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import Lean
-- import SciLean
import SciLean.Core.FloatAsReal
import SciLean.Modules.ML.Dense
import SciLean.Modules.ML.Convolution
import SciLean.Modules.ML.Pool

open SciLean
open IO FS System

open NotationOverField
set_default_scalar Float
set_option synthInstance.maxSize 2000

open ML ArrayType in
/-- The classifier network, applied to parameters `w` and one input image `x`.

`w` is taken apart into three pairs `(w₁,b₁)`, `(w₂,b₂)`, `(w₃,b₃)` and
`x : Float^[1,28,28]` is passed through the pipeline
`conv2d 32 1 w₁ b₁` → elementwise square → `avgPool` → `dense 100 w₂ b₂` →
elementwise square → `dense 10 w₃ b₃`.

The body is written as a pattern-matching lambda applied to `w x`; the
derivative proofs for this definition refer to the auxiliary `model.match_1`,
so the shape of the definition should be kept when editing.
NOTE(review): the arguments `32 1` of `conv2d` and the sizes `100`/`10` of
`dense` are presumably layer dimensions — confirm against the SciLean ML
modules. -/
def model (w x) :=
(fun ((w₁,b₁),(w₂,b₂),(w₃,b₃)) (x : Float^[1,28,28]) =>
x |> conv2d 32 1 w₁ b₁
|> map (fun x => x^2)
|> avgPool
|> dense 100 w₂ b₂
|> map (fun x => x^2)
|> dense 10 w₃ b₃) w x
-- final stage is currently disabled:
-- |> softMax


-- Generates the reverse derivative of `model` in the arguments `w` and `x`.
-- Both tactic scripts first unfold `model` and simplify with `model.match_1`
-- (the matcher created by the pattern-matching lambda in `model`).
-- NOTE(review): presumably `prop_by` discharges the differentiability
-- obligation (via `fprop`) and `trans_by` computes the derivative expression
-- (via `ftrans`) — confirm against the SciLean documentation.
#generate_revDeriv model w x
prop_by unfold model; simp[model.match_1]; fprop
trans_by unfold model; simp[model.match_1]; ftrans


/-- Loss of `model` with parameters `w` on one batch.

For every batch index `i` the model output for `images[i]` is compared with
`labels[i]`, and the `‖·‖₂` values of the differences are summed over the
batch.
NOTE(review): `batchSize` is not bound explicitly and is presumably picked up
as an implicit argument; the summand is `‖·‖₂` exactly as written (not
squared) — confirm that this is the intended loss. -/
def batchLoss (w) (images : Float^[1,28,28]^[batchSize]) (labels : Float^[10]^[batchSize]) :=
∑ i, ‖(model w images[i] - labels[i])‖₂

0 comments on commit dcb67de

Please sign in to comment.