{-# OPTIONS -fglasgow-exts #-}

-----------------------------------------------------------------------------------------
{-| Module      :  Database.HSQL.MSI
    Copyright   :  (c) Krasimir Angelov 2005
    License     :  BSD-style

    Maintainer  :  kr.angelov@gmail.com
    Stability   :  provisional
    Portability :  non-portable (Windows only: stdcall FFI to the Windows Installer API)

    This module provides an interface to Microsoft Windows Installer (MSI) databases.
-}
-----------------------------------------------------------------------------------------

module Database.HSQL.MSI(connect, module Database.HSQL) where

import Control.Concurrent.MVar
import Control.Exception (finally, throwDyn)
import Control.Monad
import Data.Char (isUpper, toLower)
import Data.IORef
import Data.Word
import Foreign
import Foreign.C

import Database.HSQL
import Database.HSQL.Types

#include <windows.h>

-- | Handle to a Windows Installer object (a database, view or record).
-- This module uses 0 to mean "no handle": an empty current record, or
-- "no parameter record" when executing a view.
-- NOTE(review): the C MSIHANDLE is an unsigned 32-bit value; CInt has the
-- same size so handles round-trip, but Word32/CULong would be more exact.
type MSIHANDLE = CInt

-- | Open a Windows Installer database and wrap it in an HSQL 'Connection'.
--
-- Both arguments go to @MsiOpenDatabaseA@: @source@ is the database path
-- and @dest@ is passed as its second (@szPersist@) argument.  Failures are
-- reported through 'checkResult'.
--
-- Transactions are not supported by this driver: the transaction
-- operations throw 'SqlUnsupportedOperation'.  Pending changes are
-- committed when the connection is disconnected.
connect :: String -> String -> IO Connection
connect source dest =
  withCString source $ \csource ->
  withCString dest   $ \cdest   ->
  alloca             $ \phandle -> do
    msiOpenDatabase csource cdest phandle >>= checkResult
    hDatabase <- peek phandle
    refClosed <- newMVar False
    let connection = Connection
          { connDisconnect          = disconnect hDatabase
          , connExecute             = execute hDatabase
          , connQuery               = query connection hDatabase
          , connTables              = tables hDatabase
          , connDescribe            = describe hDatabase
          , connBeginTransaction    = beginTransaction hDatabase
          , connCommitTransaction   = commitTransaction hDatabase
          , connRollbackTransaction = rollbackTransaction hDatabase
          , connClosed              = refClosed
          }
    return connection
  where
    ------------------------------------------------------------------------
    -- Constants

    -- Initial size in bytes of the buffer used to read string fields.
    -- 'withRecordString' retries with a bigger buffer when this is too small.
    colBufferSize :: Int
    colBufferSize = 1024

    -- Status codes that signal a normal condition rather than a failure.
    errorNoMoreItems, errorMoreData :: Word32
    errorNoMoreItems = (#const ERROR_NO_MORE_ITEMS)
    errorMoreData    = (#const ERROR_MORE_DATA)

    ------------------------------------------------------------------------
    -- Connection operations

    -- Commit pending changes, then close the database handle.  The handle
    -- is closed even if the commit fails.
    disconnect :: MSIHANDLE -> IO ()
    disconnect hDatabase =
      (msiDatabaseCommit hDatabase >>= checkResult)
        `finally` closeHandle hDatabase

    -- Run a statement whose result rows (if any) are ignored.  The view is
    -- always closed afterwards.
    execute :: MSIHANDLE -> String -> IO ()
    execute hDatabase sql = withView hDatabase sql (\_ -> return ())

    -- Run a query and return an HSQL 'Statement' over its rows.  The view
    -- stays open until 'closeStatement' runs.  If setting up the statement
    -- fails, the view is closed before the exception propagates.
    query :: Connection -> MSIHANDLE -> String -> IO Statement
    query connection hDatabase sql = do
      hView <- openView hDatabase sql
      cleanupOnError (closeHandle hView) $ do
        msiViewExecute hView 0 >>= checkResult
        fields    <- getFields hView
        refClosed <- newMVar False
        refRecord <- newIORef 0
        return Statement
          { stmtConn   = connection
          , stmtClose  = closeStatement hView refRecord
          , stmtFetch  = fetch hView refRecord
          , stmtGetCol = getColValue refRecord
          , stmtFields = fields
          , stmtClosed = refClosed
          }

    -- Names of all tables, read from the _Tables system table.
    tables :: MSIHANDLE -> IO [String]
    tables hDatabase = firstColumn hDatabase "select Name from _Tables"

    -- Column names of @tableName@, read from the _Columns system table.
    -- Every column is reported as non-nullable 'SqlText'; the real column
    -- types are not looked up here.
    describe :: MSIHANDLE -> String -> IO [FieldDef]
    describe hDatabase tableName = do
      names <- firstColumn hDatabase
                 ("select Name from _Columns where `Table`=" ++ toSqlValue tableName)
      return [ (name, SqlText, False) | name <- names ]

    -- Transactions are not supported by this driver.
    beginTransaction, commitTransaction, rollbackTransaction :: MSIHANDLE -> IO ()
    beginTransaction    _ = throwDyn SqlUnsupportedOperation
    commitTransaction   _ = throwDyn SqlUnsupportedOperation
    rollbackTransaction _ = throwDyn SqlUnsupportedOperation

    ------------------------------------------------------------------------
    -- Statement operations

    -- Advance to the next row.  The previous row's record is released first.
    -- Returns False, with no current record, once the view is exhausted.
    fetch :: MSIHANDLE -> IORef MSIHANDLE -> IO Bool
    fetch hView refRecord = do
      releaseRecord refRecord
      next <- fetchRecord hView
      case next of
        Nothing      -> return False
        Just hRecord -> writeIORef refRecord hRecord >> return True

    -- Read column @colNumber@ (0-based; MSI fields are 1-based) of the
    -- current row and hand the raw buffer and its length to @f@.
    getColValue :: IORef MSIHANDLE -> Int -> FieldDef -> (SqlType -> CString -> Int -> IO (Maybe a)) -> IO (Maybe a)
    getColValue refRecord colNumber (_name, sqlType, _nullable) f = do
      hRecord <- readIORef refRecord
      withRecordString hRecord (fromIntegral colNumber + 1) (f sqlType)

    -- Release the current record (if any) and close the view.  The view is
    -- closed even if releasing the record fails.
    closeStatement :: MSIHANDLE -> IORef MSIHANDLE -> IO ()
    closeStatement hView refRecord =
      releaseRecord refRecord `finally` closeHandle hView

    -- Close the record stored in @refRecord@ (0 means none).  The ref is
    -- reset to 0 before closing so that a handle is never closed twice.
    releaseRecord :: IORef MSIHANDLE -> IO ()
    releaseRecord refRecord = do
      hRecord <- readIORef refRecord
      writeIORef refRecord 0
      unless (hRecord == 0) $ closeHandle hRecord

    ------------------------------------------------------------------------
    -- Column metadata

    -- Field definitions of an executed view.  They are built from the
    -- column-names record (info 0, MSICOLINFO_NAMES) and the column-types
    -- record (info 1, MSICOLINFO_TYPES); both records are closed afterwards.
    getFields :: MSIHANDLE -> IO [FieldDef]
    getFields hView =
      withColumnInfo 0 $ \hNames ->
      withColumnInfo 1 $ \hTypes -> do
        count <- msiRecordGetFieldCount hNames
        mapM (\n -> liftM2 mkFieldDef (recordString hNames n) (recordString hTypes n))
             [1 .. count]
      where
        withColumnInfo info = withRecordFrom (msiViewGetColumnInfo hView info)

    -- Translate an MSI column type string such as "s72", "I2" or "v0".
    -- The first letter selects the column kind (an upper-case letter means
    -- nullable) and the remaining digits give the width.  Unknown kinds,
    -- widths, or an empty string fall back to 'SqlText' rather than crash.
    -- NOTE(review): i2 -> SqlInteger and i4 -> SqlBigInt keep the original
    -- mapping; SqlSmallInt/SqlInteger would match the MSI widths more closely.
    mkFieldDef :: String -> String -> FieldDef
    mkFieldDef name []     = (name, SqlText, False)
    mkFieldDef name (c:cs) = (name, sqlType, isUpper c)
      where
        kind = toLower c

        width :: Int
        width = case reads cs of
                  [(w, "")] -> w
                  _         -> 0

        sqlType
          | kind `elem` "slg" = if width == 0 then SqlText else SqlVarChar width
          | kind `elem` "ij"  = if width == 2 then SqlInteger else SqlBigInt
          | kind `elem` "vo"  = SqlBLOB
          | otherwise         = SqlText

    ------------------------------------------------------------------------
    -- Low-level helpers

    -- Close an MSI handle, reporting failures through 'checkResult'.
    closeHandle :: MSIHANDLE -> IO ()
    closeHandle h = msiCloseHandle h >>= checkResult

    -- Open, but do not execute, a view for @sql@.
    openView :: MSIHANDLE -> String -> IO MSIHANDLE
    openView hDatabase sql =
      withCString sql $ \csql ->
      alloca          $ \phView -> do
        msiDatabaseOpenView hDatabase csql phView >>= checkResult
        peek phView

    -- Open and execute a view for @sql@, with no parameter record, then run
    -- @action@ on it.  The view is closed afterwards even if execution or
    -- @action@ throws.
    withView :: MSIHANDLE -> String -> (MSIHANDLE -> IO a) -> IO a
    withView hDatabase sql action = do
      hView <- openView hDatabase sql
      (msiViewExecute hView 0 >>= checkResult >> action hView)
        `finally` closeHandle hView

    -- Fetch the next row record of an executed view.  Returns 'Nothing'
    -- when the view reports ERROR_NO_MORE_ITEMS.
    fetchRecord :: MSIHANDLE -> IO (Maybe MSIHANDLE)
    fetchRecord hView =
      alloca $ \phRecord -> do
        res <- msiViewFetch hView phRecord
        if res == errorNoMoreItems
          then return Nothing
          else do checkResult res
                  fmap Just (peek phRecord)

    -- Execute @sql@ and return field 1 of every row as a string.  Each row
    -- record and the view itself are closed even on failure.
    firstColumn :: MSIHANDLE -> String -> IO [String]
    firstColumn hDatabase sql =
      withView hDatabase sql $ \hView ->
        let collect acc = do
              next <- fetchRecord hView
              case next of
                Nothing      -> return (reverse acc)
                Just hRecord -> do
                  value <- recordString hRecord 1 `finally` closeHandle hRecord
                  collect (value : acc)
        in collect []

    -- Obtain a record handle through @open@, run @action@ on it, and close
    -- the handle afterwards even if @action@ throws.
    withRecordFrom :: (Ptr MSIHANDLE -> IO Word32) -> (MSIHANDLE -> IO a) -> IO a
    withRecordFrom open action =
      alloca $ \phRecord -> do
        open phRecord >>= checkResult
        hRecord <- peek phRecord
        action hRecord `finally` closeHandle hRecord

    -- Read string field @field@ of @hRecord@ and pass the buffer and the
    -- string length to @k@.  The buffer is valid only inside @k@.  It starts
    -- at 'colBufferSize' bytes; if MsiRecordGetString answers
    -- ERROR_MORE_DATA, the read is retried with the size it asked for.
    withRecordString :: MSIHANDLE -> Word32 -> (CString -> Int -> IO a) -> IO a
    withRecordString hRecord field k = attempt colBufferSize
      where
        attempt size = do
          outcome <- allocaBytes size $ \buffer ->
                     alloca           $ \plen -> do
                       poke plen (fromIntegral size)
                       res <- msiRecordGetString hRecord field buffer plen
                       len <- fmap fromIntegral (peek plen)
                       -- On ERROR_MORE_DATA, len is the required length
                       -- without the terminating NUL.  The size check stops
                       -- a misbehaving call from looping forever.
                       if res == errorMoreData && len + 1 > size
                         then return (Left (len + 1))
                         else do checkResult res
                                 fmap Right (k buffer len)
          either attempt return outcome

    -- Read string field @field@ of @hRecord@ into a Haskell 'String'.
    recordString :: MSIHANDLE -> Word32 -> IO String
    recordString hRecord field =
      withRecordString hRecord field (\buffer len -> peekCStringLen (buffer, len))

    -- Run @action@.  If it throws, run @cleanup@ and let the exception
    -- propagate; on success, @cleanup@ is not run.  This is built from
    -- 'finally' and a success flag.
    cleanupOnError :: IO () -> IO a -> IO a
    cleanupOnError cleanup action = do
      succeeded <- newIORef False
      (action >>= \r -> writeIORef succeeded True >> return r)
        `finally` (readIORef succeeded >>= \ok -> unless ok cleanup)

-- Windows Installer (MSI) API bindings, all using the stdcall convention.
-- Functions that take or return strings bind the ANSI ("...A") entry points.
-- Most return a status code (0 = success) that callers in this module pass
-- to 'checkResult'.  The exceptions are MsiRecordGetFieldCount (a field
-- count), MsiGetLastErrorRecord (a record handle) and FormatMessageA (a
-- character count).
foreign import stdcall "MsiOpenDatabaseA" msiOpenDatabase :: CString -> CString -> Ptr MSIHANDLE -> IO Word32
foreign import stdcall "MsiDatabaseCommit" msiDatabaseCommit :: MSIHANDLE -> IO Word32
foreign import stdcall "MsiCloseHandle"  msiCloseHandle :: MSIHANDLE -> IO Word32
foreign import stdcall "MsiDatabaseOpenViewA" msiDatabaseOpenView :: MSIHANDLE -> CString -> Ptr MSIHANDLE -> IO Word32
-- The second argument is a parameter record; this module always passes 0.
foreign import stdcall "MsiViewExecute" msiViewExecute :: MSIHANDLE -> MSIHANDLE -> IO Word32
-- NOTE(review): msiGetLastErrorRecord and msiFormatRecord are not used in the
-- code visible here; they could supply richer error text to 'checkResult'.
foreign import stdcall "MsiGetLastErrorRecord" msiGetLastErrorRecord :: IO MSIHANDLE
foreign import stdcall "MsiFormatRecordA" msiFormatRecord :: MSIHANDLE -> MSIHANDLE -> CString -> Ptr CInt -> IO Word32
foreign import stdcall "MsiViewGetColumnInfo" msiViewGetColumnInfo :: MSIHANDLE -> CInt -> Ptr MSIHANDLE -> IO Word32
foreign import stdcall "MsiRecordGetFieldCount" msiRecordGetFieldCount :: MSIHANDLE -> IO Word32
-- The last argument is in/out: buffer size on entry, string length on exit.
foreign import stdcall "MsiRecordGetStringA" msiRecordGetString :: MSIHANDLE -> Word32 -> CString -> Ptr Word32 -> IO Word32
foreign import stdcall "MsiViewFetch" msiViewFetch :: MSIHANDLE -> Ptr MSIHANDLE -> IO Word32
-- Used by 'checkResult' to turn a status code into a readable message.
foreign import stdcall "FormatMessageA" formatMessage :: Word32 -> Ptr () -> Word32 -> Word32 -> CString -> Word32 -> Ptr () -> IO Word32

-- | Convert a status code returned by a Windows Installer call into an HSQL
-- error.
--
-- A code of 0 means success, and the function returns normally.  Any other
-- code is described with @FormatMessageA@ from the system message table and
-- thrown via 'throwDyn' as a 'SqlError' whose native error is the code
-- itself.  If no system description is available, a generic message that
-- contains the code is used instead.
checkResult :: Word32 -> IO ()
checkResult err
  | err == 0  = return ()
  | otherwise = do
      msg <- allocaBytes bufSize $ \cmsg -> do
               written <- formatMessage flags nullPtr err 0 cmsg
                                        (fromIntegral bufSize) nullPtr
               -- FormatMessage returns 0 on failure and leaves the buffer
               -- uninitialised, so it must not be read in that case.
               if written == 0
                 then return ("error code " ++ show err)
                 else fmap trimLineEnd (peekCStringLen (cmsg, fromIntegral written))
      throwDyn (SqlError "" (fromIntegral err) msg)
  where
    bufSize :: Int
    bufSize = 1024

    -- FROM_SYSTEM: look the code up in the system message table.
    -- IGNORE_INSERTS: no argument array is passed, so %n inserts in the
    -- message must be left as-is rather than expanded.
    flags :: Word32
    flags = (#const FORMAT_MESSAGE_FROM_SYSTEM) .|. (#const FORMAT_MESSAGE_IGNORE_INSERTS)

    -- System messages typically end with CR/LF; drop trailing whitespace.
    trimLineEnd :: String -> String
    trimLineEnd = reverse . dropWhile (`elem` " \r\n") . reverse
