Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend Regression module to address first point in issue #67 #113

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
dist
**/.DS_Store
TAGS
cabal.sandbox.config
6 changes: 6 additions & 0 deletions Statistics/Function.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ module Statistics.Function
-- * Combinators
, for
, rfor
, for_
) where

#include "MachDeps.h"
Expand Down Expand Up @@ -141,6 +142,11 @@ rfor n0 !n f = loop n0
| otherwise = let i' = i-1 in f i' >> loop i'
{-# INLINE rfor #-}

for_ :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
for_ n0 !n f | n0 > n = rfor n0 n f
| otherwise = for n0 n f
{-# INLINE for_ #-}

unsafeModify :: M.MVector s Double -> Int -> (Double -> Double) -> ST s ()
unsafeModify v i f = do
k <- M.unsafeRead v i
Expand Down
13 changes: 13 additions & 0 deletions Statistics/Matrix.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE PatternGuards #-}

-- |
-- Module : Statistics.Matrix
-- Copyright : 2011 Aleksey Khudyakov, 2014 Bryan O'Sullivan
Expand All @@ -12,6 +13,8 @@
module Statistics.Matrix
( -- * Data types
Matrix(..)
, TMatrix(..)
, UpperLower(..)
, Vector
-- * Conversion from/to lists/vectors
, fromVector
Expand All @@ -29,6 +32,7 @@ module Statistics.Matrix
, generateSym
, ident
, diag
, diagOf
, dimension
, center
, multiply
Expand Down Expand Up @@ -169,6 +173,7 @@ generateSym n f = runST $ do
ident :: Int -> Matrix
ident n = diag $ U.replicate n 1.0


-- | Create a square matrix with given diagonal, other entries default to 0
diag :: Vector -> Matrix
diag v
Expand All @@ -180,6 +185,14 @@ diag v
where
n = U.length v

-- | Return diagonal of a square matrix
diagOf :: Matrix -> Vector
diagOf m
| rs /= cs = error $ "matrix is not square, dimension = " ++ show d
| otherwise = U.generate rs (\i -> unsafeIndex m i i)
where
d@(rs,cs) = dimension m

-- | Return the dimensions of this matrix, as a (row,column) pair.
dimension :: Matrix -> (Int, Int)
dimension (Matrix r c _ _) = (r, c)
Expand Down
25 changes: 25 additions & 0 deletions Statistics/Matrix/Algorithms.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ module Statistics.Matrix.Algorithms

import Control.Applicative ((<$>), (<*>))
import Control.Monad.ST (ST, runST)
import Control.Monad (when)
import Prelude hiding (sum, replicate)
import Statistics.Matrix (Matrix, column, dimension, for, norm)
import qualified Statistics.Matrix.Mutable as M
Expand All @@ -37,6 +38,30 @@ qr mat = runST $ do
M.unsafeModify a i jj $ subtract (p * aij)
(,) <$> M.unsafeFreeze a <*> M.unsafeFreeze r

-- | /O(n^3)/ Compute the Cholesky factorization and return
-- the lower triangular Cholesky factor (/L/). Note: does
-- not check whether matrix is positive definite or not.
-- Will fail if diagonal elements are non-positive.
chol :: Matrix -> Matrix
chol mat
| m /= n = error "Matrix must be square."
| otherwise = runST $ do
l <- M.thaw mat
for 0 n $ \j -> do
M.unsafeModify l j j sqrt
for (j+1) n $ \i -> do
ljj <- M.unsafeRead l j j
M.unsafeModify l i j (/ ljj)
M.unsafeWrite l j i 0
for (j+1) (i+1) $ \jj -> do
ljjj <- M.unsafeRead l jj j
lij <- M.unsafeRead l i j
M.unsafeModify l i jj (subtract $ lij*ljjj)
when (i /= jj) $
M.unsafeModify l jj i (subtract $ lij*ljjj)
M.unsafeFreeze l
where (m,n) = dimension mat

innerProduct :: M.MMatrix s -> Int -> Int -> ST s Double
innerProduct mmat j k = M.immutably mmat $ \mat ->
sum $ U.zipWith (*) (column mat j) (column mat k)
11 changes: 11 additions & 0 deletions Statistics/Matrix/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ module Statistics.Matrix.Types
, MVector
, Matrix(..)
, MMatrix(..)
, TMatrix(..)
, UpperLower(..)
, debug
) where

Expand Down Expand Up @@ -42,6 +44,15 @@ data MMatrix s = MMatrix
{-# UNPACK #-} !Int
!(MVector s)

-- | Data type records whether a triangular matrix is upper or lower triangular.
data UpperLower = Upper | Lower
deriving (Eq, Show)

-- | Triangular matrix, stored as a Matrix with indication of whether it is
-- upper or lower triangular.
data TMatrix = TMatrix Matrix UpperLower
deriving (Eq, Show)

-- The Show instance is useful only for debugging.
instance Show Matrix where
show = debug
Expand Down
Loading