Skip to content

Commit

Permalink
append
Browse files Browse the repository at this point in the history
  • Loading branch information
dpvanbalen committed Sep 11, 2023
1 parent ef3f9ab commit 733c3f9
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ module Data.Array.Accelerate.TensorFlow.Lite (
-- Accelerate, and hence work also on other backends (via the fallback
-- implementation).

argMin, argMax,
argMin, argMax, append,

-- * Model serialisation
--
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ test_backpermute tc =
testDIM2 =
testGroup "DIM2"
[ testProperty "transpose" $ prop_transpose tc f32
, testProperty "reverse" $ prop_reverse tc f32
]

prop_transpose
Expand All @@ -58,6 +59,18 @@ prop_transpose tc e =
xs <- forAll (array ForInput sh e)
tpuTestCase tc A.transpose dat xs

prop_reverse
:: (Elt e, Show e, Similar e)
=> TestContext
-> (WhichData -> Gen e)
-> Property
prop_reverse tc e =
property $ do
sh <- forAll dim1
dat <- forAllWith (const "sample-data") (generate_sample_data_reverse sh e)
xs <- forAll (array ForInput sh e)
tpuTestCase tc A.reverse dat xs

generate_sample_data_transpose
:: Elt e
=> DIM2
Expand All @@ -68,3 +81,13 @@ generate_sample_data_transpose sh@(Z :. h :. w) e = do
xs <- Gen.list (Range.singleton i) (array ForSample sh e)
return [ x :-> Result (Z :. w :. h) | x <- xs ]


generate_sample_data_reverse
:: Elt e
=> DIM1
-> (WhichData -> Gen e)
-> Gen (RepresentativeData (Array DIM1 e -> Array DIM1 e))
generate_sample_data_reverse sh@(Z :. sz) e = do
i <- Gen.int (Range.linear 10 16)
xs <- Gen.list (Range.singleton i) (array ForSample sh e)
return [ x :-> Result (Z :. sz) | x <- xs ]
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import qualified Hedgehog.Range as Range
import Test.Tasty
import Test.Tasty.Hedgehog

import Control.Monad.IO.Class
import Prelude as P


Expand All @@ -53,8 +54,34 @@ test_foreign tc =
testGroup ("DIM" P.++ show (rank @sh))
[ testProperty "argmin" $ prop_min tc dim f32
, testProperty "argmax" $ prop_max tc dim i16
, testProperty "append_i32" $ prop_app tc dim i32
, testProperty "append_f32" $ prop_app tc dim f32
]


prop_app
:: (P.Eq sh, Show sh, Shape sh, Elt e, Show e, Similar e)
=> TestContext
-> Gen (sh:.Int)
-> (WhichData -> Gen e)
-> Property
prop_app tc dim e =
property $ do
sh1 <- forAll (Gen.filter (\(_ :. n) -> n P.> 0) dim)
sh2 <- forAll (Gen.filter (\(_ :. n) -> n P.> 0) dim)
ndat <- forAll (Gen.int (Range.linear 10 16))
dat1 <- forAll (Gen.list (Range.singleton ndat) (array ForSample sh1 e))
dat2 <- forAll (Gen.list (Range.singleton ndat) (array ForSample sh2 e))
liftIO $ putStrLn $ show sh1 P.++ " | " P.++ show sh2
liftIO $ print dat1
liftIO $ print dat2
xs <- forAll (array ForInput sh1 e)
ys <- forAll (array ForInput sh2 e)
let sh = appendresultshape sh1 sh2
tpuTestCase tc append (P.zipWith (\a b -> a :-> b :-> Result sh) dat1 dat2) xs ys
where
appendresultshape (sh1:.sz1) (sh2:.sz2) = Data.Array.Accelerate.Sugar.Shape.intersect sh1 sh2 :. sz1+sz2

prop_min
:: (P.Eq sh, Show sh, Shape2 sh, Elt e, Show e, Similar e, A.Ord e, P.Ord e, P.Num e)
=> TestContext
Expand Down
33 changes: 32 additions & 1 deletion accelerate-tensorflow/src/Data/Array/Accelerate/TensorFlow.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ module Data.Array.Accelerate.TensorFlow (
run,
runN,

argMin, argMax
argMin, argMax, append

) where

Expand Down Expand Up @@ -57,6 +57,7 @@ import Data.Array.Accelerate.TensorFlow.CodeGen.Foreign
import qualified Proto.Tensorflow.Core.Framework.NodeDef_Fields as TF
import qualified TensorFlow.Build as TF
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF
import qualified TensorFlow.Internal.FFI as Internal
Expand All @@ -73,6 +74,7 @@ import Text.Printf
import qualified Data.Text as T
import qualified Data.Vector.Storable as V

import qualified Data.Array.Accelerate.TensorFlow.CodeGen.Arithmetic as Ar

-- | Run a complete embedded program using the default TensorFlow backend
--
Expand Down Expand Up @@ -248,3 +250,32 @@ argMin, argMax :: (A.Ord a, Shape sh) => Acc (Array (sh :. Int) a) -> Acc (Array
argMin = argMinMax Min
argMax = argMinMax Max

-- Specialised instance, because naive translation through `generate` results in a 'select' over two generates that both include out-of-bounds indexing.
append :: (Shape sh, A.Elt e) => Acc (Array (sh :. Int) e) -> Acc (Array (sh :. Int) e) -> Acc (Array (sh :. Int) e)
append xs ys = A.foreignAcc
(ForeignAcc "append" tensorflowappend)
(\(T2 xs' ys') -> xs' A.++ ys')
(backpermuteToSameSize $ T2 xs ys)
where
tensorflowappend :: (((), Tensor (sh,Int) e), Tensor (sh,Int) e) -> Tensor (sh,Int) e
tensorflowappend (((), Tensor (R.ArrayR (ShapeRsnoc shR) eR) (sh,szl) l)
, Tensor (R.ArrayR _ _ ) (_ ,szr) r) =
Tensor (R.ArrayR (ShapeRsnoc shR) eR) (sh, szl+szr) (go eR l r)
where
go :: TypeR e -> TensorArrayData e -> TensorArrayData e -> TensorArrayData e
go TupRunit () () = ()
go (TupRpair l r) (l1,r1) (l2,r2) = (go l l1 l2, go r r1 r2)
go (TupRsingle t) x y = buildTypeDictsScalar t $ Sh.wrapConcat (fromIntegral $ rank shR) [x, y]
-- zipmin :: forall e. TypeR e -> TensorArrayData e -> TensorArrayData e -> TensorArrayData e
-- zipmin TupRunit () () = ()
-- zipmin (TupRpair l r) (l1,r1) (l2,r2) = (zipmin l l1 l2, zipmin r r1 r2)
-- zipmin (TupRsingle t) x y = buildTypeDictsScalar t $ Ar.min (singleType @e) (x, y)

backpermuteToSameSize :: (Shape sh, A.Elt e) => Acc (Array (sh:.Int) e, Array (sh:.Int) e) -> Acc (Array (sh:.Int) e, Array (sh:.Int) e)
backpermuteToSameSize (T2 xs ys) = T2 xs' ys'
where
xs' = A.backpermute (sh::.szx) id xs
ys' = A.backpermute (sh::.szy) id ys
shx ::. szx = A.shape xs
shy ::. szy = A.shape ys
sh = A.intersect shx shy
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
Expand Down Expand Up @@ -46,7 +48,7 @@ import qualified Data.Set as Set
import qualified Data.Vector.Storable as V


type family Tensors t where
type family Tensors ts = t | t -> ts where
Tensors () = ()
Tensors (Array sh e) = Tensor sh e
Tensors (a, b) = (Tensors a, Tensors b)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type TypeDictsFor t s =
(Storable t
,Typeable s
,Show s
,IsScalar t
,IsSingle t
,s ~ ScalarTensorDataR t
,TF.TensorType s
,TArrayDataR Sh.Tensor t ~ Sh.Tensor s
Expand Down

0 comments on commit 733c3f9

Please sign in to comment.