{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Internal.Sparse(
GMatrix(..), CSR(..), mkCSR, fromCSR,
mkSparse, mkDiagR, mkDense,
AssocMatrix,
toDense,
gmXv, (!#>)
)where
import Internal.Vector
import Internal.Matrix
import Internal.Numeric
import qualified Data.Vector.Storable as V
import Data.Function(on)
import Control.Arrow((***))
import Control.Monad(when)
import Data.List(groupBy, sort)
import Foreign.C.Types(CInt(..))
import Internal.Devel
import System.IO.Unsafe(unsafePerformIO)
import Foreign(Ptr)
import Text.Printf(printf)
infixl 0 ~!~

-- | Abort with the given message when the condition holds; otherwise do
-- nothing.  Written infix so that a size check reads as
-- @condition ~!~ message@.
(~!~) :: Applicative f => Bool -> String -> f ()
failed ~!~ message
    | failed    = error message
    | otherwise = pure ()
-- | Association-list description of a matrix: each element pairs an
-- index tuple with the 'Double' stored there.  'mkCSR' groups entries by the
-- first index component (rows) and stores the second (columns); 'toDense'
-- sizes its result from the largest value of each component.
type AssocMatrix = [((Int,Int),Double)]
-- | Compressed sparse row storage, as produced by 'mkCSR'.
data CSR = CSR
    { csrVals  :: Vector Double
      -- ^ stored values, ordered by row and then by column
    , csrCols  :: Vector CInt
      -- ^ column of each stored value, counted from 1
    , csrRows  :: Vector CInt
      -- ^ running offsets into 'csrVals', starting at 1, one per row
      --   plus a final end marker
    , csrNRows :: Int
      -- ^ number of rows
    , csrNCols :: Int
      -- ^ number of columns
    }
    deriving (Show)
-- | Compressed sparse column storage.  In the code visible in this module a
-- 'CSC' is only ever obtained by transposing a 'CSR', which reuses the same
-- three arrays with the roles of rows and columns exchanged.
data CSC = CSC
    { cscVals  :: Vector Double
      -- ^ stored values
    , cscRows  :: Vector CInt
      -- ^ row of each stored value
    , cscCols  :: Vector CInt
      -- ^ running offsets, one per column plus a final end marker
    , cscNRows :: Int
      -- ^ number of rows
    , cscNCols :: Int
      -- ^ number of columns
    }
    deriving (Show)
-- | Build compressed-sparse-row storage from an association list.
--
-- Entries are sorted by @(row, column)@, so the input may come in any order;
-- duplicate positions are all kept.  Column indices and row offsets are
-- stored counting from 1.
--
-- The matrix has @maxRow + 1@ rows and @maxCol + 1@ columns, the same
-- convention as 'toDense'.  Rows without any entry get an empty range in
-- 'csrRows' instead of being dropped, so later rows keep their position:
--
-- > csrRows (mkCSR [((0,0),1),((2,1),5)])  ==  fromList [1,2,2,3]
--
-- An empty list yields a matrix with no rows and no columns.  A negative row
-- or column index raises an 'error' when the corresponding field is forced.
mkCSR :: AssocMatrix -> CSR
mkCSR sm' = CSR{..}
  where
    sm = sort sm'

    csrVals = fromList (map snd sm)
    csrCols = fromList (map (storedCol . snd . fst) sm)
    csrRows = fromList (scanl (+) 1 rowSizes)
    csrNRows = dim csrRows - 1
    csrNCols
        | null sm   = 0
        | otherwise = fromIntegral (V.maximum csrCols)

    -- Columns are stored 1-based; a negative index would produce a
    -- non-positive stored column, so it is rejected here.
    storedCol :: Int -> CInt
    storedCol j
        | j < 0     = error (printf "mkCSR: negative column index %d" j)
        | otherwise = succ (fi j)

    -- (row index, number of stored entries) for every row that has entries.
    populated :: [(Int, Int)]
    populated =
        [ (i, length grp)
        | grp@(((i, _), _) : _) <- groupBy ((==) `on` (fst . fst)) sm
        ]

    -- Entry count of every row from 0 up to the largest row index.
    rowSizes :: [CInt]
    rowSizes = fillGaps 0 populated

    -- Walk the populated rows in order, emitting a zero for each missing
    -- row.  Because the input is sorted, @i < next@ can only happen for a
    -- negative row index at the very start.
    fillGaps :: Int -> [(Int, Int)] -> [CInt]
    fillGaps _ [] = []
    fillGaps next rs@((i, k) : rest)
        | i < next  = error (printf "mkCSR: negative row index %d" i)
        | i == next = fi k : fillGaps (next + 1) rest
        | otherwise = 0 : fillGaps (next + 1) rs
-- | General matrix with several internal representations.  Every
-- constructor carries its logical size in 'nRows' and 'nCols'; the remaining
-- field is specific to the constructor, so its selector is partial.
data GMatrix
    = -- | sparse, compressed by rows
      SparseR
        { gmCSR :: CSR
        , nRows :: Int
        , nCols :: Int
        }
    | -- | sparse, compressed by columns
      SparseC
        { gmCSC :: CSC
        , nRows :: Int
        , nCols :: Int
        }
    | -- | rectangular diagonal matrix given by its leading diagonal entries
      Diag
        { diagVals :: Vector Double
        , nRows    :: Int
        , nCols    :: Int
        }
    | -- | ordinary dense matrix
      Dense
        { gmDense :: Matrix Double
        , nRows   :: Int
        , nCols   :: Int
        }
    deriving (Show)
-- | Wrap a dense matrix as a 'GMatrix', recording its row and column counts.
mkDense :: Matrix Double -> GMatrix
mkDense m =
    Dense
        { gmDense = m
        , nRows   = rows m
        , nCols   = cols m
        }
-- | Build a row-compressed sparse 'GMatrix' from an association list
-- (the result of 'mkCSR' wrapped by 'fromCSR').
mkSparse :: AssocMatrix -> GMatrix
mkSparse entries = fromCSR (mkCSR entries)
-- | Wrap row-compressed storage as a 'GMatrix', taking the logical size
-- from the 'csrNRows' and 'csrNCols' fields of the argument.
--
-- The fields are set explicitly rather than through an as-pattern with
-- spaces around @\@@, which GHC 9.0 and later no longer accept.
fromCSR :: CSR -> GMatrix
fromCSR csr =
    SparseR
        { gmCSR = csr
        , nRows = csrNRows csr
        , nCols = csrNCols csr
        }
-- | Build a rectangular diagonal 'GMatrix' with @r@ rows and @c@ columns
-- from the given diagonal values.  Calls 'error' when more values are
-- supplied than @min r c@.
mkDiagR :: Int -> Int -> Vector Double -> GMatrix
mkDiagR r c v
    | len > min r c =
        error (printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c len)
    | otherwise =
        Diag
            { diagVals = v
            , nRows    = r
            , nCols    = c
            }
  where
    len = dim v
-- Helper aliases that spell out the Haskell-side type of the foreign sparse
-- kernels.  Each vector argument occupies two parameters: a 'CInt' followed
-- by a pointer (presumably element count then data -- confirm against the
-- marshalling helpers in Internal.Devel).

-- | A 'CInt' vector argument followed by the rest of the signature @t@.
type IV t = CInt -> Ptr CInt -> t
-- | A 'Double' vector argument followed by the rest of the signature @t@.
type V t = CInt -> Ptr Double -> t
-- | Signature shared by 'c_smXv' and 'c_smTXv': five vector arguments of
-- element types Double, CInt, CInt, Double, Double, returning a status code
-- in 'IO'.  'gmXv' supplies them as values, two index vectors, the input
-- vector and the result vector.
type SMxV = V (IV (IV (V (V (IO CInt)))))
-- | Product of a general matrix and a dense vector.
--
-- The two sparse representations call the foreign kernels 'c_smXv' and
-- 'c_smTXv' on a freshly created result vector of length 'nRows'; the
-- diagonal representation multiplies the leading elements of the vector by
-- 'diagVals' and pads the result with zeros up to 'nRows'; the dense
-- representation uses 'mXv'.
--
-- In every case 'error' is called when the length of the vector differs
-- from 'nCols' of the matrix.
gmXv :: GMatrix -> Vector Double -> Vector Double
gmXv gm v = case gm of
    SparseR { gmCSR = CSR{..}, .. } -> unsafePerformIO $ do
        when (len /= nCols) $
            error $ printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d"
                           nRows nCols len
        out <- createVector nRows
        (csrVals # csrCols # csrRows # v #! out) c_smXv #| "CSRXv"
        return out

    SparseC { gmCSC = CSC{..}, .. } -> unsafePerformIO $ do
        when (len /= nCols) $
            error $ printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d"
                           nRows nCols len
        out <- createVector nRows
        (cscVals # cscRows # cscCols # v #! out) c_smTXv #| "CSCXv"
        return out

    Diag{..}
        | len == nCols ->
            vjoin
                [ mul (subVector 0 (dim diagVals) v) diagVals
                , konst 0 (nRows - dim diagVals)
                ]
        | otherwise ->
            error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d"
                           nRows nCols (dim diagVals) len

    Dense{..}
        | len == nCols -> mXv gmDense v
        | otherwise ->
            error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d"
                           nRows nCols len
  where
    -- length of the vector operand, compared against nCols in every case
    len = dim v
infixr 8 !#>

-- | Infix form of 'gmXv': general matrix times dense vector.
(!#>) :: GMatrix -> Vector Double -> Vector Double
m !#> v = gmXv m v
-- | C routine @smXv@; 'gmXv' calls it for the row-compressed case.
foreign import ccall unsafe "smXv" c_smXv :: SMxV

-- | C routine @smTXv@; 'gmXv' calls it for the column-compressed case.
foreign import ccall unsafe "smTXv" c_smTXv :: SMxV
-- | Dense matrix described by an association list.
--
-- The result has one more row than the largest row index and one more
-- column than the largest column index; the entries are handed to 'assoc'
-- with 0 as the default element.  Because the size is found with
-- 'maximum', an empty list raises an error.
toDense :: AssocMatrix -> Matrix Double
toDense asm = assoc (lastRow + 1, lastCol + 1) 0 asm
  where
    positions = map fst asm
    lastRow   = maximum (map fst positions)
    lastCol   = maximum (map snd positions)
-- | Transposing row-compressed storage keeps the three arrays as they are
-- and reinterprets them as column-compressed storage with the row and
-- column counts exchanged.
instance Transposable CSR CSC
  where
    tr CSR{..} =
        CSC
            { cscVals  = csrVals
            , cscRows  = csrCols
            , cscCols  = csrRows
            , cscNRows = csrNCols
            , cscNCols = csrNRows
            }
    tr' = tr
-- | Transposing column-compressed storage keeps the three arrays as they
-- are and reinterprets them as row-compressed storage with the row and
-- column counts exchanged.
instance Transposable CSC CSR
  where
    tr CSC{..} =
        CSR
            { csrVals  = cscVals
            , csrCols  = cscRows
            , csrRows  = cscCols
            , csrNRows = cscNCols
            , csrNCols = cscNRows
            }
    tr' = tr
-- | Transposition of a general matrix: the payload is transposed where it
-- has an orientation ('SparseR' and 'SparseC' turn into each other, 'Dense'
-- transposes its matrix, 'Diag' keeps its values) and the recorded row and
-- column counts are exchanged.
instance Transposable GMatrix GMatrix
  where
    tr gm = case gm of
        SparseR { gmCSR = csr, nRows = r, nCols = c } ->
            SparseC { gmCSC = tr csr, nRows = c, nCols = r }
        SparseC { gmCSC = csc, nRows = r, nCols = c } ->
            SparseR { gmCSR = tr csc, nRows = c, nCols = r }
        Diag { diagVals = d, nRows = r, nCols = c } ->
            Diag { diagVals = d, nRows = c, nCols = r }
        Dense { gmDense = a, nRows = r, nCols = c } ->
            Dense { gmDense = tr a, nRows = c, nCols = r }
    tr' = tr