-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
450e82a
commit 5c96d00
Showing
5 changed files
with
89 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
module Main where | ||
|
||
import Control.Monad (replicateM) | ||
import System.Random (randomRIO) | ||
|
||
import qualified Data.Array.Accelerate as A | ||
import qualified Data.Array.Accelerate.Interpreter as I | ||
import qualified Data.Array.Accelerate.TensorFlow as TFCPU | ||
import qualified Data.Array.Accelerate.TensorFlow.Lite as TPU | ||
|
||
|
||
-- matrix is a vector of rows | ||
matmat :: A.Acc (A.Matrix Float) -> A.Acc (A.Matrix Float) -> A.Acc (A.Matrix Float) | ||
matmat a b = | ||
let A.I2 k m = A.shape a | ||
A.I2 _ n = A.shape b | ||
in A.sum $ | ||
A.generate (A.I3 k n m) $ \(A.I3 i j p) -> | ||
a A.! A.I2 i p * b A.! A.I2 p j | ||
|
||
genmatrix :: A.DIM2 -> IO (A.Matrix Float) | ||
genmatrix dim@(A.Z A.:. n A.:. m) = A.fromList dim <$> replicateM (n * m) (randomRIO (0, 10)) | ||
|
||
main :: IO () | ||
main = do | ||
-- Inputs | ||
let dimA = A.Z A.:. 3 A.:. 2 | ||
dimB = A.Z A.:. 2 A.:. 4 | ||
dimC = A.Z A.:. 3 A.:. 4 -- result dimension | ||
let a1 = A.fromList dimA | ||
[1, 2 | ||
,3, 4 | ||
,5, 6] | ||
b1 = A.fromList dimB | ||
[1, 0, 1, 0 | ||
,0, 1, 0, 1] | ||
|
||
-- representative sample input | ||
samples <- replicateM 10 $ do | ||
a <- genmatrix dimA | ||
b <- genmatrix dimB | ||
return (a TPU.:-> b TPU.:-> TPU.Result dimC) | ||
|
||
-- First let's try it in the accelerate interpreter | ||
putStrLn "## Running in the interpreter" | ||
print $ I.runN matmat a1 b1 | ||
|
||
-- First let's run it on the CPU using TensorFlow | ||
putStrLn "## Running on TensorFlow native CPU" | ||
print $ TFCPU.runN matmat a1 b1 | ||
|
||
-- Then let's run it on a TPU, the easy way | ||
putStrLn "## Running on TPU, easy" | ||
do let model = TPU.compile matmat samples | ||
print $ TPU.execute model a1 b1 | ||
|
||
-- And then the hard way, which scales better to multiple model executions | ||
putStrLn "## Running on TPU, better" | ||
do TPU.withConverterPy $ \converter -> do | ||
TPU.withDeviceContext $ do | ||
let model = TPU.compileWith converter matmat samples | ||
print $ TPU.execute model a1 b1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters