Skip to content

Commit

Permalink
work on MNIST classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Dec 6, 2023
1 parent 3010c78 commit a76f302
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 50 deletions.
24 changes: 15 additions & 9 deletions SciLean/Data/DataArray/DataArray.lean
Original file line number Diff line number Diff line change
Expand Up @@ -106,30 +106,36 @@ def DataArray.reverse (arr : DataArray α) : DataArray α := Id.run do
arr := arr.swap i' j'
arr


@[irreducible]
def DataArray.intro (f : ι → α) : DataArray α := Id.run do
let bytes := (pd.bytes (Index.size ι))
let mut d : ByteArray := ByteArray.mkArray bytes.toNat 0
let d : ByteArray := ByteArray.mkArray bytes.toNat 0
let mut d' : DataArray α := ⟨d, (Index.size ι), sorry_proof⟩
let mut li : USize := 0
for i in fullRange ι do
d' := d'.set ⟨li, sorry_proof⟩ (f i)
li := li + 1
d' := d'.set ⟨(toIdx i).1,sorry_proof⟩ (f i)
d'

-- let d' : DataArray α := ⟨d, (Index.size ι), sorry_proof⟩
-- let rec @[specialize] go : Nat → DataArray α → DataArray α
-- | 0, d => d
-- | n+1, d =>
-- go n (d.set ⟨n.toUSize, sorry_proof⟩ (f (fromIdx ⟨n.toUSize, sorry_proof⟩)))
-- go (Index.size ι).toNat d'

/-- A `DataArray α` bundled with a proof that its size equals `Index.size ι`,
i.e. an array whose length is tied to the index type `ι`. -/
structure DataArrayN (α : Type) [pd : PlainDataType α] (ι : Type) [Index ι] where
-- Underlying untyped-length storage.
data : DataArray α
-- Size invariant: the number of elements of `ι` matches `data.size`.
h_size : Index.size ι = data.size

@[irreducible]
@[inline]
instance : GetElem (DataArrayN α ι) ι α (λ _ _ => True) where
getElem xs i _ := xs.1.get (xs.2 ▸ toIdx i)
getElem xs i _ := xs.1.get ((toIdx i).cast xs.2)

@[irreducible]
@[inline]
instance : SetElem (DataArrayN α ι) ι α where
setElem xs i xi := ⟨xs.1.set (xs.2 ▸ toIdx i) xi, sorry_proof⟩
setElem xs i xi := ⟨xs.1.set ((toIdx i).cast xs.2) xi, sorry_proof⟩

@[irreducible]
@[inline]
instance : IntroElem (DataArrayN α ι) ι α where
introElem f := ⟨DataArray.intro f, sorry_proof⟩

Expand Down
6 changes: 6 additions & 0 deletions SciLean/Data/Idx.lean
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def toFin' {n : Nat} (i : Idx n.toUSize) : Fin n := ⟨i.1.toNat, sorry_proof⟩
def _root_.USize.toFloat (n : USize) : Float := n.toNat.toFloat
def toFloat {n} (i : Idx n) : Float := i.1.toFloat

/-- Reinterprets `i : Idx n` as an element of `Idx m`, given `h : n = m`.

The value component `i.1` is reused unchanged; the proof component is obtained by
rewriting the goal with `h` right-to-left and closing it with `i.2`. -/
@[macro_inline]
def cast (i : Idx n) (h : n = m) : Idx m := ⟨i.1, by rw[← h]; apply i.2⟩

/-- Variant of `cast` taking the equality in the opposite direction: reinterprets
`i : Idx n` as an element of `Idx m`, given `h : m = n`.

The value component `i.1` is reused unchanged; the proof component is obtained by
rewriting the goal with `h` and closing it with `i.2`. -/
@[macro_inline]
def cast' (i : Idx n) (h : m = n) : Idx m := ⟨i.1, by rw[h]; apply i.2⟩

def shiftPos (x : Idx n) (s : USize) := x + s
def shiftNeg (x : Idx n) (s : USize) := x - s
def shift (x : Idx n) (s : Int) :=
Expand Down
26 changes: 22 additions & 4 deletions SciLean/Modules/ML/MNIST.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,42 @@ import SciLean.Modules.ML.Convolution
import SciLean.Modules.ML.Pool
import SciLean.Modules.ML.Activation
import SciLean.Modules.ML.SoftMax
import SciLean.Data.Random

set_option synthInstance.maxSize 2000

namespace SciLean.ML

open ArrayType

/-- Default element of `Idx 10` is `0`.
NOTE(review): presumably needed because `softMax` asks for `[Inhabited ι]` on its
index type — confirm against `SciLean.Modules.ML.SoftMax`. -/
instance : Inhabited (Idx 10) := ⟨0⟩

def mnist (w x) :=
(fun ((w₁,b₁),(w₂,b₂),(w₃,b₃)) (x : Float^[1,28,28]) =>
x |> conv2d 32 1 w₁ b₁
x |> conv2d 8 1 w₁ b₁
|> map gelu
|> avgPool
|> dense 100 w₂ b₂
|> dense 30 w₂ b₂
|> map gelu
|> dense 10 w₃ b₃
|> softMax 1) w x
|> softMax 0.1) w x

#generate_revDeriv mnist w x
#generate_revDeriv mnist w
prop_by unfold mnist; simp[mnist.match_1]; fprop
trans_by unfold mnist; simp[mnist.match_1]; ftrans


/-- The type `α` of the first argument of a curried two-argument function.
The function value itself is ignored; it only drives type inference. -/
abbrev weightsType (_f : α → β → γ) := α
/-- The type `β` of the second argument of a curried two-argument function.
The function value itself is ignored; it only drives type inference. -/
abbrev inputType (_f : α → β → γ) := β
/-- The result type `γ` of a curried two-argument function.
The function value itself is ignored; it only drives type inference. -/
abbrev outputType (_f : α → β → γ) := γ

/-- Produces a value of `weightsType ML.mnist` by passing `Random.rand` to `IO.runRand`.
NOTE(review): presumably this samples random initial weights for `mnist`; the
distribution is whatever the `Random` instance provides — confirm in `SciLean.Data.Random`. -/
def mnist.initWeights := Random.rand (weightsType ML.mnist) |> IO.runRand


#eval 0


set_option trace.Meta.Tactic.simp.rewrite true in
#check (fun x => (revDerivUpdate Float fun w => mnist w x))
rewrite_by
unfold mnist; ftrans
36 changes: 21 additions & 15 deletions SciLean/Modules/ML/Pool.lean
Original file line number Diff line number Diff line change
@@ -1,38 +1,44 @@
import SciLean.Core
import SciLean.Core.Functions.Exp
import SciLean.Data.DataArray
import SciLean.Data.ArrayType
import SciLean.Data.Prod
import Mathlib

namespace SciLean.ML

set_option synthInstance.maxSize 200000

variable
{R : Type} [RealScalar R] [PlainDataType R]

set_default_scalar R


/-- Merges `i : Idx (n/m)` and `j : Idx m` into an `Idx n` with value `i.1 * m + j.1`
(`i` selects the block of size `m`, `j` the offset inside it).
The bound proof is left as `sorry_proof`. -/
def _root_.SciLean.Idx.prodMerge (i : Idx (n/m)) (j : Idx m) : Idx n := ⟨i.1 * m + j.1, sorry_proof⟩
/-- Merges `i : Idx (n/m)` and `j : Idx m` into an `Idx n` with value `i.1 + j.1 * (n/m)`
(the transposed layout of `prodMerge`: `j` selects the block of size `n/m`, `i` the offset).
The bound proof is left as `sorry_proof`. -/
def _root_.SciLean.Idx.prodMerge' (i : Idx (n/m)) (j : Idx m) : Idx n := ⟨i.1 + j.1 * (n/m), sorry_proof⟩


def avgPool {n m : USize} {ι} [Index ι] (x : R^[ι,n,m]) : R^[ι,n/2,m/2] :=
⊞ (k,i,j) => (1/4 : R) * ∑ i' j' : Idx 2, x[(k, ⟨2*i.1+i'.1, sorry_proof⟩, ⟨2*j.1+j'.1,sorry_proof⟩)]
⊞ (k,i,j) => (1/4 : R) * ∑ i' j' : Idx 2, x[(k,i.prodMerge i', j.prodMerge j')]

-- NOTE(review): presumably generates the reverse-mode derivative of `avgPool` with respect
-- to `x`; both obligations are discharged by unfolding `avgPool` and running the
-- `fprop` / `ftrans` tactics.
#generate_revDeriv avgPool x
prop_by unfold avgPool; fprop
trans_by unfold avgPool; ftrans


-- TODO: needs exp and division working
-- def softMaxPool {n m : USize} {ι} [Index ι] (scale : R) (x : R^[ι,n,m]) : R^[ι,n/2,m/2] :=
-- introElem fun (k,i,j) =>
-- let a := x[(k,⟨2*i.1 ,sorry_proof⟩, ⟨2*j.1 ,sorry_proof⟩)]
-- let b := x[(k,⟨2*i.1 ,sorry_proof⟩, ⟨2*j.1+1,sorry_proof⟩)]
-- let c := x[(k,⟨2*i.1+1,sorry_proof⟩, ⟨2*j.1 ,sorry_proof⟩)]
-- let d := x[(k,⟨2*i.1+1,sorry_proof⟩, ⟨2*j.1+1,sorry_proof⟩)]
-- let ea := Scalar.exp (scale*a)
-- let eb := Scalar.exp (scale*b)
-- let ec := Scalar.exp (scale*c)
-- let ed := Scalar.exp (scale*d)
-- have : (ea + eb + ec + ed) ≠ 0 := sorry_proof
-- let w := 1 / (ea + eb + ec + ed)
-- (a*ea+b*eb+c*ec+d*ed) * w
/-- Soft-max weighted pooling that halves the last two dimensions of `x`.

For every output position `(k,i,j)` the four inputs addressed by
`(k, i.prodMerge i', j.prodMerge j')` with `i' j' : Idx 2` are combined as
`∑ xi * (exp (scale*xi) / w)`, where `w` is the sum of the four exponentials. -/
def softMaxPool {n m : USize} {ι} [Index ι] (scale : R) (x : R^[ι,n,m]) : R^[ι,n/2,m/2] :=
introElem fun (k,i,j) =>
-- Exponentials of the scaled inputs, indexed by the pair of offsets `(i', j')`.
let ex := ⊞ (ij : Idx 2 × Idx 2) => Scalar.exp (scale*x[(k, i.prodMerge ij.1, j.prodMerge ij.2)])
-- Normalisation constant: sum of all entries of `ex`.
let w := ∑ i', ex[i']
-- Reciprocal computed once and reused for every term of the sum below.
let w' := 1/w
∑ i' j' : Idx 2,
let xi := x[(k, i.prodMerge i', j.prodMerge j')]
let exi := ex[(i',j')]
xi * (w'*exi)


-- NOTE(review): presumably generates the reverse-mode derivative of `softMaxPool` with
-- respect to `x`; both obligations are discharged by unfolding `softMaxPool` and running
-- the `fprop` / `ftrans` tactics.
#generate_revDeriv softMaxPool x
prop_by unfold softMaxPool; fprop
trans_by unfold softMaxPool; ftrans

12 changes: 7 additions & 5 deletions SciLean/Modules/ML/SoftMax.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@ import Mathlib
namespace SciLean.ML

variable
{R : Type} [RealScalar R] [PlainDataType R]
{R : Type} [RealScalar R] [PlainDataType R] [LT R] [∀ x y : R, Decidable (x < y)]

set_default_scalar R

def softMax
{ι} [Index ι] (r : R) (x : R^ι) : R^ι :=
let x := ArrayType.map (fun xi => Scalar.exp (r*xi)) x
def softMax [RealScalar R]
{ι} [Index ι] [Inhabited ι] (r : R) (x : R^ι) : R^ι :=
let m := ArrayType.max x
let x := ArrayType.map (fun xi => Scalar.exp (r*(xi-m))) x
let w := ∑ i, x[i]
(1/w) • x

#generate_revDeriv softMax x
prop_by unfold softMax; fprop
trans_by unfold softMax; ftrans
trans_by
unfold softMax; ftrans

111 changes: 94 additions & 17 deletions examples/MNISTClassifier.lean
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
import Lean
-- import SciLean
import SciLean.Data.ArrayType.Algebra
import SciLean.Data.DataArray
import SciLean.Core.FloatAsReal
import SciLean.Data.Random

import SciLean.Modules.ML.Dense
import SciLean.Modules.ML.Convolution
import SciLean.Modules.ML.Pool
import SciLean.Modules.ML.MNIST

import SciLean.Util.Profile


open SciLean
open IO FS System

open NotationOverField
set_default_scalar Float
set_option synthInstance.maxSize 2000

/-- Returns its argument unchanged inside `IO` (it is just `pure`), marked `noinline`.
NOTE(review): presumably used to stop the compiler from specialising or hoisting the
timed computations in `main` — confirm intent. -/
@[noinline] def blackbox : α → IO α := pure

/-- Verifies that the MNIST data file at `path` exists.

Throws an `IO.userError` with download instructions when the path is missing;
otherwise does nothing. -/
def checkFileExists (path : FilePath) : IO Unit := do
  let found ← path.pathExists
  unless found do
    throw (IO.userError s!"MNIST data file '{path}' not found. Please download binary version from https://git-disl.github.io/GTDLBench/datasets/mnist_datasets/ and extract it in 'data' directory")

def toFloatRepr (b : ByteArray) (ι : Type _) [Index ι] (_ : b.size = sizeOf ι) : Float^ι := Id.run do
let mut idx : USize := 0
let mut x : Float^ι := 0
for i in fullRange ι do
let val := b.uget idx sorry_proof
x[i] := val.toNat.toFloat / 256.0
idx := idx + 1
x
def toFloatRepr (b : ByteArray) (ι : Type _) [Index ι] (_ : b.size = sizeOf ι) : Float^ι :=
⊞ i =>
let val := b.uget (toIdx i).1 sorry_proof
val.toUSize.toFloat / 256.0

open IO FS System in
def loadImages (path : FilePath) (maxImages : Nat) : IO (Array (Float^[28,28])) := do
Expand Down Expand Up @@ -109,9 +109,27 @@ def printDigit (digit : Float^[28,28]) : IO Unit := do
IO.println "|"


/-- Views a `28×28` image as a `1×28×28` array by reusing the underlying data `x.1`.
No elements are copied or reordered here; the size proof is left as `sorry_proof`. -/
def prependDim (x : Float^[28,28]) : Float^[1,28,28] := ⟨x.1,sorry_proof⟩

/-- Assembles the `i`-th batch of `batchSize` samples.

Slot `j` of the result is filled from position `i*batchSize.toNat + j` of `images`
(with a leading dimension prepended via `prependDim`) and of `labels`. Lookups use
the panicking-with-default `[...]!` accessor, so out-of-range positions are not
rejected here. -/
def getBatch (i : Nat) (batchSize : USize) (images : Array (Float^[28,28])) (labels : Array (Float^[10]))
    : Float^[1,28,28]^[batchSize] × Float^[10]^[batchSize] := Id.run do
  -- Position of the first sample of this batch in the source arrays.
  let offset := i*batchSize.toNat
  let mut imgs : Float^[1,28,28]^[batchSize] := 0
  let mut lbls : Float^[10]^[batchSize] := 0
  for j in fullRange (Idx batchSize) do
    let src := offset + j.1.toNat
    imgs[j] := prependDim images[src]!
    lbls[j] := labels[src]!
  return (imgs, lbls)


/-- Produces a value of type `α` in `BaseIO` by passing `Random.rand α` to `IO.runRand`.
NOTE(review): presumably draws one random sample using the `Random α` instance and the
global generator — confirm in `SciLean.Data.Random`. -/
def _root_.IO.getRand (α : Type) [Random α] : BaseIO α := Random.rand α |> IO.runRand


def main : IO Unit := do

let trainNum := 1000
let trainNum := 100
let trainImages ← loadImages "data/train-images.idx3-ubyte" trainNum
let trainLabels ← loadLabels "data/train-labels.idx1-ubyte" trainNum

Expand All @@ -125,8 +143,67 @@ def main : IO Unit := do
IO.println "+----------------------------+"
printDigit trainImages[2]!
IO.println "+----------------------------+"


-- let start ← IO.monoMsNow
-- IO.print "generating initial random weights ... "
-- let w ← ML.mnist.initWeights
-- IO.println s!"took {(← IO.monoMsNow) - start}ms"

let (images,labels) := getBatch 0 5 trainImages trainLabels

for img in trainImages do
printDigit img
IO.println "+----------------------------+"
IO.sleep 100
-- let start ← IO.monoMsNow
-- IO.print "evaluating network ... "
-- let l := ML.mnist w (← blackbox images[1])
-- IO.println s!"took {(← IO.monoMsNow) - start}ms"
-- IO.println l


-- let l := images[1]
-- |> ML.conv2d 8 1 w.1.1 w.1.2
-- |> ArrayType.map ML.gelu
-- |> ML.avgPool
-- |> ML.dense 30 w.2.1.1 w.2.1.2
-- |> ArrayType.map ML.gelu
-- |> ML.dense 10 w.2.2.1 w.2.2.2

-- IO.println l


-- let start ← IO.monoMsNow
-- IO.print "evaluating network gradient... "
-- let dw := (ML.mnist.arg_wx.revDeriv w (← blackbox images[0])).2 (← IO.getRand _)
-- IO.println s!"took {(← IO.monoMsNow) - start}ms"
-- IO.println dw.1.2.2.2



let start ← IO.monoMsNow
IO.print "evaluating conv2d gradient... "
let dweightsbiasx := (ML.conv2d.arg_weightsbiasx.revDeriv 2 1 (← IO.getRand _) (← IO.getRand _) images[0]).2 (← IO.getRand _)
IO.println s!"took {(← IO.monoMsNow) - start}ms"
IO.println dweightsbiasx.1


let start ← IO.monoMsNow
IO.print "evaluating avgPool gradient... "
let dimg := (ML.avgPool.arg_x.revDeriv images[0]).2 (← IO.getRand _)
IO.println s!"took {(← IO.monoMsNow) - start}ms"
IO.println dimg[0]


let start ← IO.monoMsNow
IO.print "evaluating map gelu gradient... "
let dimg := ((revDeriv Float (fun x => ArrayType.map ML.gelu x) images[0]).2 (← IO.getRand _)) rewrite_by ftrans
IO.println s!"took {(← IO.monoMsNow) - start}ms"
IO.println dimg[0]


let start ← IO.monoMsNow
IO.print "evaluating dense gradient... "
let dimg := (ML.dense.arg_weightsbiasx.revDeriv 10 (← IO.getRand _) (← IO.getRand _) labels[0]).2 (← IO.getRand _)
IO.println s!"took {(← IO.monoMsNow) - start}ms"
IO.println dimg.1



0 comments on commit a76f302

Please sign in to comment.