Hello all, HaskellDB is a nice combinator library, but it has two main disadvantages:
* HaskellDB uses the 'Trex' module, which is Hugs-specific; neither GHC nor NHC supports 'Trex'.
* HaskellDB can't execute stored SQL procedures and doesn't allow executing plain SQL statements.

I use HSQL (see the mail attachment) for data access. HSQL works through ODBC, but its user interface isn't ODBC-specific, so the module could be rewritten to use native drivers for specific databases (MySQL, PostgreSQL, Oracle, Sybase, ...). If anybody is interested in using HSQL, I will put it in CVS.

Krasimir Angelov
__________________________________________________
Do You Yahoo!?
Yahoo! - Official partner of 2002 FIFA World Cup
http://fifaworldcup.yahoo.com
-- | HSQL: a small ODBC binding. Connect to a data source, execute SQL text,
--   fetch rows into bound column buffers and read column values by name.
module HSQL
    ( SqlBind(..), SqlError(..), SqlType(..), Connection, Statement
    , catchSql          -- :: IO a -> (SqlError -> IO a) -> IO a
    , connect           -- :: String -> String -> String -> IO Connection
    , disconnect        -- :: Connection -> IO ()
    , execute           -- :: Connection -> String -> IO Statement
    , closeStatement    -- :: Statement -> IO ()
    , fetch             -- :: Statement -> IO Bool
    , inTransaction     -- :: Connection -> (Connection -> IO a) -> IO a
    , getFieldValue     -- :: SqlBind a => Statement -> String -> IO a
    , getFieldValueType -- :: Statement -> String -> (SqlType, Bool)
    , getFieldsTypes    -- :: Statement -> [(String, SqlType, Bool)]
    , forEachRow        -- :: (Statement -> s -> IO s) -> Statement -> s -> IO s
    , forEachRow'       -- :: (Statement -> IO ()) -> Statement -> IO ()
    , collectRows       -- :: (Statement -> IO s) -> Statement -> IO [s]
    ) where

import Word(Word32, Word16)
import Int(Int32, Int16)
import Foreign
import CString
import IORef
import Monad(when)
import Exception (throwDyn, catchDyn, Exception(..))
import Dynamic

-- Presumably declares the FIELD scratch record (hSTMT, fieldsCount,
-- fieldName, NameLength, DataType, ColumnSize, DecimalDigits, Nullable) and
-- FIELD_NAME_LENGTH that this module accesses via #ptr/#peek — the header
-- itself is not visible here.
#include <HSQLStructs.h>

-- Opaque ODBC handles.
type SQLHANDLE = Ptr ()
type HENV = SQLHANDLE
type HDBC = SQLHANDLE
type HSTMT = SQLHANDLE
-- The environment handle wrapped in a ForeignPtr so it is freed by a finalizer.
type HENVRef = ForeignPtr ()

-- Haskell counterparts of the ODBC C integer types.
type SQLSMALLINT = Int16
type SQLUSMALLINT = Word16
type SQLINTEGER = Int32
type SQLUINTEGER = Word32
type SQLRETURN = SQLSMALLINT
-- NOTE(review): 64-bit ODBC headers usually define SQLLEN/SQLULEN as 64-bit;
-- these 32-bit aliases are only right for 32-bit builds — confirm the target.
type SQLLEN = SQLINTEGER
type SQLULEN = SQLINTEGER

-- Raw ODBC entry points.
-- NOTE(review): stdcall is the Windows ODBC calling convention; other
-- platforms would presumably need ccall — confirm before porting.
foreign import stdcall "sqlext.h SQLAllocEnv" sqlAllocEnv :: Ptr HENV -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLFreeEnv" sqlFreeEnv :: HENV -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLAllocConnect" sqlAllocConnect :: HENV -> Ptr HDBC -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLFreeConnect" sqlFreeConnect :: HDBC -> IO SQLRETURN
-- NOTE(review): the C prototype takes SQLSMALLINT string lengths; they are
-- declared as Haskell Int here — confirm ABI compatibility on the target.
foreign import stdcall "sqlext.h SQLConnect" sqlConnect :: HDBC -> CString -> Int -> CString -> Int -> CString -> Int -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLDisconnect" sqlDisconnect :: HDBC -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLAllocStmt" sqlAllocStmt :: HDBC -> Ptr HSTMT -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLFreeStmt" sqlFreeStmt :: HSTMT -> SQLUSMALLINT -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLNumResultCols" sqlNumResultCols :: HSTMT -> Ptr SQLUSMALLINT -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLDescribeCol" sqlDescribeCol :: HSTMT -> SQLUSMALLINT -> CString -> SQLSMALLINT -> Ptr SQLSMALLINT -> Ptr SQLSMALLINT -> Ptr SQLULEN -> Ptr SQLSMALLINT -> Ptr SQLSMALLINT -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLBindCol" sqlBindCol :: HSTMT -> SQLUSMALLINT -> SQLSMALLINT -> Ptr a -> SQLLEN -> Ptr SQLINTEGER -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLFetch" sqlFetch :: HSTMT -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLGetDiagRec" sqlGetDiagRec :: SQLSMALLINT -> SQLHANDLE -> SQLSMALLINT -> CString -> Ptr SQLINTEGER -> CString -> SQLSMALLINT -> Ptr SQLSMALLINT -> IO SQLRETURN
-- NOTE(review): the C TextLength parameter is SQLINTEGER; declared as Int here.
foreign import stdcall "sqlext.h SQLExecDirect" sqlExecDirect :: HSTMT -> CString -> Int -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLSetConnectOption" sqlSetConnectOption :: HDBC -> SQLUSMALLINT -> SQLULEN -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLTransact" sqlTransact :: HENV -> HDBC -> SQLUSMALLINT -> IO SQLRETURN

-- | An open ODBC connection. It also holds a reference to the shared
--   environment so the environment's finalizer cannot run while the
--   connection is reachable.
data Connection = Connection
    { hDBC        :: HDBC
    , environment :: HENVRef
    }

-- | Description of one result column:
--   (column name, SQL type, nullable?, byte offset of its slot in the
--   statement's fetch buffer).
type FieldDef = (String, SqlType, Bool, Int)

-- | An executed statement together with the buffer its columns are bound to.
data Statement = Statement
    { hSTMT       :: HSTMT
    , connection  :: Connection
    , fields      :: [FieldDef]
    , fetchBuffer :: Ptr ()   -- column slots; each slot is laid out at the FieldDef offset
    }

-- | Column types this module can describe and bind. The Int arguments carry
--   the column size reported by the driver; SqlDecimal/SqlNumeric also carry
--   the decimal-digit count.
data SqlType
    = SqlChar Int
    | SqlVarChar Int
    | SqlLongVarChar Int
    | SqlDecimal Int Int
    | SqlNumeric Int Int
    | SqlSmallInt
    | SqlInteger
    | SqlReal
    | SqlDouble
    | SqlBit
    | SqlTinyInt
    | SqlBigInt
    | SqlBinary Int
    | SqlVarBinary Int
    | SqlLongVarBinary Int
    | SqlDate
    | SqlTime
    | SqlTimeStamp
    deriving (Eq, Show)

-- | Errors raised (as dynamic exceptions) by this module. 'SqlError' carries
--   the first diagnostic record (SQLSTATE, native code, message); the other
--   constructors mirror ODBC return codes that have no diagnostic record.
data SqlError
    = SqlError
        { seState       :: String
        , seNativeError :: Int
        , seErrorMsg    :: String
        }
    | SqlNoData
    | SqlInvalidHandle
    | SqlStillExecuting
    | SqlNeedData
    deriving Show

-----------------------------------------------------------------------------------------
-- routines for handling exceptions
-----------------------------------------------------------------------------------------

-- 'SqlError' travels as a dynamic exception (throwDyn/catchDyn), which needs
-- a Typeable instance; its type representation is built by hand here.
{-# NOINLINE sqlErrorTy #-}
sqlErrorTy = mkAppTy (mkTyCon "SqlError") []

instance Typeable SqlError where
    typeOf _ = sqlErrorTy

-- | Run an action, passing any 'SqlError' it raises to the handler.
catchSql :: IO a -> (SqlError -> IO a) -> IO a
catchSql = catchDyn

-- | True for the return codes this module treats as non-failures:
--   SQL_SUCCESS, SQL_SUCCESS_WITH_INFO and SQL_NO_DATA.
sqlSuccess :: SQLRETURN -> Bool
sqlSuccess res =
       res == (#const SQL_SUCCESS)
    || res == (#const SQL_SUCCESS_WITH_INFO)
    || res == (#const SQL_NO_DATA)

-- | Convert an ODBC return code into an exception.
--
--   Non-failure codes (see 'sqlSuccess') return normally. SQL_INVALID_HANDLE,
--   SQL_STILL_EXECUTING and SQL_NEED_DATA raise the matching 'SqlError'
--   constructor. SQL_ERROR raises an 'SqlError' record built from the first
--   diagnostic record of @handle@ (of kind @handleType@). Any other code is
--   reported with 'error'.
handleSqlResult :: SQLSMALLINT -> SQLHANDLE -> SQLRETURN -> IO ()
handleSqlResult handleType handle res
    | sqlSuccess res                      = return ()
    | res == (#const SQL_INVALID_HANDLE)  = throwDyn SqlInvalidHandle
    | res == (#const SQL_STILL_EXECUTING) = throwDyn SqlStillExecuting
    | res == (#const SQL_NEED_DATA)       = throwDyn SqlNeedData
    | res == (#const SQL_ERROR)           = readDiagnostics >>= throwDyn
    | otherwise                           = error ("HSQL: unexpected SQLRETURN code " ++ show res)
  where
    bufLen :: Int
    bufLen = 256

    -- Fetch diagnostic record 1. The buffers are zero-initialised first:
    -- if SQLGetDiagRec cannot produce a record (e.g. an invalid or null
    -- handle) it writes nothing, and reading uninitialised memory would
    -- yield garbage or run past the buffer. alloca* also guarantees the
    -- buffers are released.
    readDiagnostics :: IO SqlError
    readDiagnostics =
        allocaBytes bufLen $ \pState ->
          alloca $ \pNative ->
            allocaBytes bufLen $ \pMsg ->
              alloca $ \pTextLen -> do
                poke pState 0
                poke pNative 0
                poke pMsg 0
                sqlGetDiagRec handleType handle 1 pState pNative pMsg (fromIntegral bufLen) pTextLen
                state  <- peekCString pState
                native <- peek pNative
                msg    <- peekCString pMsg
                return (SqlError {seState=state, seNativeError=fromIntegral native, seErrorMsg=msg})

-----------------------------------------------------------------------------------------
-- keeper of HENV
-----------------------------------------------------------------------------------------

-- | The process-wide ODBC environment handle. It is created once (the CAF is
--   NOINLINE, so unsafePerformIO runs a single time) and freed by the
--   ForeignPtr finalizer.
{-# NOINLINE myEnvironment #-}
myEnvironment :: HENVRef
myEnvironment = unsafePerformIO $ do
    (phEnv :: Ptr HENV) <- malloc
    res  <- sqlAllocEnv phEnv
    hEnv <- peek phEnv
    free phEnv
    -- No usable handle exists for diagnostics yet, so a failure here is
    -- raised as an SqlError with empty state and message.
    handleSqlResult 0 nullPtr res
    newForeignPtr hEnv (closeEnvironment hEnv)
  where
    closeEnvironment :: HENV -> IO ()
    closeEnvironment hEnv =
        sqlFreeEnv hEnv >>= handleSqlResult (#const SQL_HANDLE_ENV) hEnv

-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------

-- | Open a connection to the ODBC data source @server@, logging in as @user@
--   with the password/authentication string @authentication@.
--
--   Raises 'SqlError' if the connection handle cannot be allocated or the
--   connection attempt fails. In the latter case the connection handle is
--   freed before the error is re-raised.
connect :: String -> String -> String -> IO Connection
connect server user authentication = withForeignPtr myEnvironment $ \hEnv -> do
    (phDBC :: Ptr HDBC) <- malloc
    res <- sqlAllocConnect hEnv phDBC
    hDBC <- peek phDBC
    free phDBC
    -- SQLAllocConnect reports its diagnostics on the environment handle.
    handleSqlResult (#const SQL_HANDLE_ENV) hEnv res
    pServer         <- newCString server
    pUser           <- newCString user
    pAuthentication <- newCString authentication
    -- SQL_NTS lets the driver measure the NUL-terminated strings itself;
    -- Haskell 'length' counts characters, not encoded bytes.
    res' <- sqlConnect hDBC pServer         (#const SQL_NTS)
                            pUser           (#const SQL_NTS)
                            pAuthentication (#const SQL_NTS)
    free pServer
    free pUser
    free pAuthentication
    -- SQLConnect reports its diagnostics on the connection handle. On failure
    -- release that handle so it does not leak, then re-raise.
    when (not (sqlSuccess res')) $
        catchSql (handleSqlResult (#const SQL_HANDLE_DBC) hDBC res')
                 (\err -> sqlFreeConnect hDBC >> throwDyn err)
    return (Connection {hDBC=hDBC, environment=myEnvironment})

-- | Disconnect and release the connection handle.
disconnect :: Connection -> IO ()
disconnect (Connection {hDBC=hDBC}) = do
    sqlDisconnect  hDBC >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC
    sqlFreeConnect hDBC >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC

-----------------------------------------------------------------------------------------
-- queries
-----------------------------------------------------------------------------------------

-- | Execute SQL text on the connection and bind every result column to a
--   freshly allocated fetch buffer. Use 'fetch' (or the row helpers) to step
--   through the rows and 'closeStatement' to release the statement.
--
--   On any 'SqlError' the statement handle and all buffers allocated so far
--   are released before the error is re-raised.
execute :: Connection -> String -> IO Statement
execute conn@(Connection {hDBC=hDBC}) query = do
    pFIELD <- mallocBytes (#const sizeof(FIELD))
    res <- sqlAllocStmt hDBC ((#ptr FIELD, hSTMT) pFIELD)
    when (not (sqlSuccess res)) (free pFIELD)
    handleSqlResult (#const SQL_HANDLE_DBC) hDBC res
    hSTMT <- (#peek FIELD, hSTMT) pFIELD
    let describeResult = do
            pQuery <- newCString query
            -- SQL_NTS: the driver measures the NUL-terminated query text.
            res' <- sqlExecDirect hSTMT pQuery (#const SQL_NTS)
            free pQuery
            handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res'
            sqlNumResultCols hSTMT ((#ptr FIELD, fieldsCount) pFIELD)
                >>= handleSqlResult (#const SQL_HANDLE_STMT) hSTMT
            count <- (#peek FIELD, fieldsCount) pFIELD
            createBindState hSTMT pFIELD 0 1 count
        -- handleSqlResult has already read the diagnostics by the time this
        -- runs, so the handle can be dropped. The drop status is ignored on
        -- purpose: the original error is what the caller needs to see.
        releaseAndRethrow err = do
            free pFIELD
            sqlFreeStmt hSTMT (#const SQL_DROP)
            throwDyn err
    (fields, offs) <- catchSql describeResult releaseAndRethrow
    free pFIELD
    buffer <- mallocBytes offs
    let statement = Statement {hSTMT=hSTMT, connection=conn, fields=fields, fetchBuffer=buffer}
    catchSql (bindFields hSTMT buffer 1 fields) (errHandler statement)
    return statement
  where
    errHandler statement err = do
        closeStatement statement
        throwDyn err

-- | Describe result columns @n@..@count@ (1-based) and lay out their slots in
--   the fetch buffer, starting at byte offset @offs@. Returns the field
--   descriptions and the total buffer size needed.
--
--   Each slot holds a SQLINTEGER length/indicator followed by the value
--   storage. @pFIELD@ is the scratch FIELD record that SQLDescribeCol writes
--   into.
createBindState :: HSTMT -> Ptr a -> Int -> SQLUSMALLINT -> SQLUSMALLINT -> IO ([FieldDef], Int)
createBindState hSTMT pFIELD offs n count
    | n > count = return ([], offs)
    | otherwise = do
        res <- sqlDescribeCol hSTMT n ((#ptr FIELD, fieldName) pFIELD) (#const FIELD_NAME_LENGTH) ((#ptr FIELD, NameLength) pFIELD) ((#ptr FIELD, DataType) pFIELD) ((#ptr FIELD, ColumnSize) pFIELD) ((#ptr FIELD, DecimalDigits) pFIELD) ((#ptr FIELD, Nullable) pFIELD)
        handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
        name          <- peekCString ((#ptr FIELD, fieldName) pFIELD)
        dataType      <- (#peek FIELD, DataType) pFIELD
        columnSize    <- (#peek FIELD, ColumnSize) pFIELD
        decimalDigits <- (#peek FIELD, DecimalDigits) pFIELD
        (nullable :: SQLSMALLINT) <- (#peek FIELD, Nullable) pFIELD
        let fieldOffs        = alignFieldOffset offs
            (sqlType, offs') = mkSqlType dataType columnSize decimalDigits (fieldOffs + (#const sizeof(SQLINTEGER)))
        (fields, offs'') <- createBindState hSTMT pFIELD offs' (n+1) count
        return ((name, sqlType, toBool nullable, fieldOffs) : fields, offs'')

-- | Round a slot's start offset up so that its value part (which follows the
--   SQLINTEGER indicator) is 8-byte aligned. The indicator then stays
--   SQLINTEGER-aligned too. Without this, doubles and indicators could land
--   on odd addresses, which strict-alignment CPUs reject.
alignFieldOffset :: Int -> Int
alignFieldOffset offs = ((offs + ind + align - 1) `div` align) * align - ind
  where
    ind   = (#const sizeof(SQLINTEGER))
    align = 8

-- | Bind each described column (1-based index @n@ onwards) to its slot in the
--   fetch buffer: the value goes after the SQLINTEGER indicator and the
--   driver writes the length/indicator at the start of the slot.
bindFields :: HSTMT -> Ptr () -> SQLUSMALLINT -> [FieldDef] -> IO ()
bindFields _ _ _ [] = return ()
bindFields hSTMT fetchBuffer n ((_, sqlType, _, offs) : fields) = do
    let buffer          = fetchBuffer `plusPtr` offs
        (cType, bufLen) = columnBinding sqlType
    res <- sqlBindCol hSTMT n cType (buffer `plusPtr` (#const sizeof(SQLINTEGER))) bufLen (castPtr buffer)
    handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
    bindFields hSTMT fetchBuffer (n+1) fields

-- | The ODBC C type and value-buffer length used to bind a column of the
--   given SQL type. These sizes must agree with the slot sizes in 'mkSqlType'.
columnBinding :: SqlType -> (SQLSMALLINT, SQLLEN)
columnBinding sqlType = case sqlType of
    SqlChar size          -> text size
    SqlVarChar size       -> text size
    SqlLongVarChar size   -> text size
    SqlDecimal _ _        -> double
    SqlNumeric _ _        -> double
    SqlReal               -> double
    SqlDouble             -> double
    SqlSmallInt           -> short
    SqlTinyInt            -> short
    SqlInteger            -> long
    SqlBit                -> long
    SqlBigInt             -> long
    SqlBinary size        -> binary size
    SqlVarBinary size     -> binary size
    SqlLongVarBinary size -> binary size
    SqlDate               -> ((#const SQL_C_TYPE_DATE),      (#const sizeof(SQL_DATE_STRUCT)))
    SqlTime               -> ((#const SQL_C_TYPE_TIME),      (#const sizeof(SQL_TIME_STRUCT)))
    SqlTimeStamp          -> ((#const SQL_C_TYPE_TIMESTAMP), (#const sizeof(SQL_TIMESTAMP_STRUCT)))
  where
    -- character data: room for @size@ characters plus the NUL terminator
    text size   = ((#const SQL_C_CHAR), fromIntegral (size+1) * (#const sizeof(SQLCHAR)))
    -- NOTE(review): binary columns are fetched as SQL_C_CHAR into @size@
    -- bytes. ODBC converts binary to character data as two hex digits per
    -- byte plus a NUL, so this may truncate — confirm the intended binding.
    binary size = ((#const SQL_C_CHAR), fromIntegral size * (#const sizeof(SQLCHAR)))
    double      = ((#const SQL_C_DOUBLE), (#const sizeof(SQLDOUBLE)))
    short       = ((#const SQL_C_SHORT),  (#const sizeof(SQLSMALLINT)))
    long        = ((#const SQL_C_LONG),   (#const sizeof(SQLINTEGER)))

-- | Map an ODBC SQL data-type code, column size and decimal digits to an
--   'SqlType'. Also advances @offs@ past the value storage that type needs.
--
--   SQL_FLOAT is read as a double. Both the ODBC 2 (SQL_DATE, ...) and
--   ODBC 3 (SQL_TYPE_DATE, ...) date/time codes are accepted. Any other code
--   raises an error naming it.
mkSqlType :: SQLSMALLINT -> SQLULEN -> SQLSMALLINT -> Int -> (SqlType, Int)
mkSqlType (#const SQL_CHAR)           size _    offs = (SqlChar (fromIntegral size),          offs + (#const sizeof(SQLCHAR))*(fromIntegral size+1))
mkSqlType (#const SQL_VARCHAR)        size _    offs = (SqlVarChar (fromIntegral size),       offs + (#const sizeof(SQLCHAR))*(fromIntegral size+1))
mkSqlType (#const SQL_LONGVARCHAR)    size _    offs = (SqlLongVarChar (fromIntegral size),   offs + (#const sizeof(SQLCHAR))*(fromIntegral size+1))
mkSqlType (#const SQL_DECIMAL)        size prec offs = (SqlDecimal (fromIntegral size) (fromIntegral prec), offs + (#const sizeof(SQLDOUBLE)))
mkSqlType (#const SQL_NUMERIC)        size prec offs = (SqlNumeric (fromIntegral size) (fromIntegral prec), offs + (#const sizeof(SQLDOUBLE)))
mkSqlType (#const SQL_SMALLINT)       _    _    offs = (SqlSmallInt,  offs + (#const sizeof(SQLSMALLINT)))
mkSqlType (#const SQL_INTEGER)        _    _    offs = (SqlInteger,   offs + (#const sizeof(SQLINTEGER)))
mkSqlType (#const SQL_REAL)           _    _    offs = (SqlReal,      offs + (#const sizeof(SQLDOUBLE)))
mkSqlType (#const SQL_FLOAT)          _    _    offs = (SqlDouble,    offs + (#const sizeof(SQLDOUBLE)))
mkSqlType (#const SQL_DOUBLE)         _    _    offs = (SqlDouble,    offs + (#const sizeof(SQLDOUBLE)))
mkSqlType (#const SQL_BIT)            _    _    offs = (SqlBit,       offs + (#const sizeof(SQLINTEGER)))
mkSqlType (#const SQL_TINYINT)        _    _    offs = (SqlTinyInt,   offs + (#const sizeof(SQLSMALLINT)))
mkSqlType (#const SQL_BIGINT)         _    _    offs = (SqlBigInt,    offs + (#const sizeof(SQLINTEGER)))
mkSqlType (#const SQL_BINARY)         size _    offs = (SqlBinary (fromIntegral size),        offs + (#const sizeof(SQLCHAR))*(fromIntegral size+1))
mkSqlType (#const SQL_VARBINARY)      size _    offs = (SqlVarBinary (fromIntegral size),     offs + (#const sizeof(SQLCHAR))*(fromIntegral size+1))
mkSqlType (#const SQL_LONGVARBINARY)  size _    offs = (SqlLongVarBinary (fromIntegral size), offs + (#const sizeof(SQLCHAR))*(fromIntegral size+1))
mkSqlType (#const SQL_DATE)           _    _    offs = (SqlDate,      offs + (#const sizeof(SQL_DATE_STRUCT)))
mkSqlType (#const SQL_TYPE_DATE)      _    _    offs = (SqlDate,      offs + (#const sizeof(SQL_DATE_STRUCT)))
mkSqlType (#const SQL_TIME)           _    _    offs = (SqlTime,      offs + (#const sizeof(SQL_TIME_STRUCT)))
mkSqlType (#const SQL_TYPE_TIME)      _    _    offs = (SqlTime,      offs + (#const sizeof(SQL_TIME_STRUCT)))
mkSqlType (#const SQL_TIMESTAMP)      _    _    offs = (SqlTimeStamp, offs + (#const sizeof(SQL_TIMESTAMP_STRUCT)))
mkSqlType (#const SQL_TYPE_TIMESTAMP) _    _    offs = (SqlTimeStamp, offs + (#const sizeof(SQL_TIMESTAMP_STRUCT)))
mkSqlType dataType                    _    _    _    = error ("HSQL: unsupported SQL data type code " ++ show dataType)

-- | Advance to the next row, filling the bound fetch buffer. Returns False
--   once SQLFetch reports SQL_NO_DATA.
{-# NOINLINE fetch #-}
fetch :: Statement -> IO Bool
fetch stmt = do
    res <- sqlFetch (hSTMT stmt)
    handleSqlResult (#const SQL_HANDLE_STMT) (hSTMT stmt) res
    return (res /= (#const SQL_NO_DATA))

-- | Release the statement handle and its fetch buffer.
closeStatement :: Statement -> IO ()
closeStatement stmt = do
    -- SQL_DROP frees the handle itself. SQL_CLOSE (0) would only close the
    -- cursor and leak the handle.
    sqlFreeStmt (hSTMT stmt) (#const SQL_DROP) >>= handleSqlResult (#const SQL_HANDLE_STMT) (hSTMT stmt)
    free (fetchBuffer stmt)

-----------------------------------------------------------------------------------------
-- transactions
-----------------------------------------------------------------------------------------

-- | Run @action@ as a single transaction.
--
--   Auto-commit is switched off first. If the action returns, the transaction
--   is committed. If the action or the commit raises an 'SqlError', the
--   transaction is rolled back and the error re-raised. Auto-commit is
--   switched back on in both cases.
--
--   NOTE(review): only 'SqlError' exceptions trigger the rollback. Other
--   exceptions leave auto-commit off.
inTransaction :: Connection -> (Connection -> IO a) -> IO a
inTransaction conn@(Connection {hDBC=hDBC, environment=envRef}) action =
    withForeignPtr envRef $ \hEnv -> do
        let check = handleSqlResult (#const SQL_HANDLE_DBC) hDBC
            setAutoCommit mode = sqlSetConnectOption hDBC (#const SQL_AUTOCOMMIT) mode
            body = do
                r <- action conn
                sqlTransact hEnv hDBC (#const SQL_COMMIT) >>= check
                return r
            -- Rollback and restore are best effort: their status codes are
            -- ignored so the caller sees the original error.
            abort err = do
                sqlTransact hEnv hDBC (#const SQL_ROLLBACK)
                setAutoCommit (#const SQL_AUTOCOMMIT_ON)
                throwDyn err
        setAutoCommit (#const SQL_AUTOCOMMIT_OFF) >>= check
        r <- catchSql body abort
        setAutoCommit (#const SQL_AUTOCOMMIT_ON) >>= check
        return r

-----------------------------------------------------------------------------------------
-- binding
-----------------------------------------------------------------------------------------

-- | Types that can be decoded from a bound column slot.
class SqlBind a where
    -- | Decode a value of the given column type. The pointer addresses the
    --   column's slot: a SQLINTEGER length/indicator followed by the value.
    getValue :: SqlType -> Ptr () -> IO a

instance SqlBind Int where
    getValue sqlType ptr = case sqlType of
        SqlInteger  -> readLong
        SqlBit      -> readLong
        SqlBigInt   -> readLong
        SqlSmallInt -> readShort
        SqlTinyInt  -> readShort
        other       -> error ("HSQL: cannot read a column of type " ++ show other ++ " as Int")
      where
        valuePtr = ptr `plusPtr` (#const sizeof(SQLINTEGER))
        -- Read exactly the width that was bound (SQL_C_LONG = 32 bits,
        -- SQL_C_SHORT = 16 bits); a Haskell Int may be wider.
        readLong  = fmap fromIntegral (peek (castPtr valuePtr) :: IO Int32)
        readShort = fmap fromIntegral (peek (castPtr valuePtr) :: IO Int16)

instance SqlBind String where
    getValue sqlType ptr = case sqlType of
        SqlChar size        -> readText size
        SqlVarChar size     -> readText size
        SqlLongVarChar size -> readText size
        other               -> error ("HSQL: cannot read a column of type " ++ show other ++ " as String")
      where
        readText size = do
            -- The indicator is a 32-bit SQLINTEGER, whatever the width of Int.
            len <- (peek (castPtr ptr) :: IO SQLINTEGER)
            -- The buffer holds at most @size@ characters. Clamp so a
            -- truncated, NULL or unknown-length value never reads past it.
            let stored
                  | len == (#const SQL_NO_TOTAL) = size
                  | len < 0                      = 0
                  | otherwise                    = min size (fromIntegral len)
            peekCStringLen (castPtr (ptr `plusPtr` (#const sizeof(SQLINTEGER))), stored)

instance SqlBind Double where
    getValue sqlType ptr = case sqlType of
        SqlDecimal _ _ -> readDouble
        SqlNumeric _ _ -> readDouble
        SqlDouble      -> readDouble
        SqlReal        -> readDouble
        other          -> error ("HSQL: cannot read a column of type " ++ show other ++ " as Double")
      where
        -- all four types are bound as SQL_C_DOUBLE
        readDouble = peek (castPtr (ptr `plusPtr` (#const sizeof(SQLINTEGER))))

-- | Decode the current row's value of the named column.
getFieldValue :: SqlBind a => Statement -> String -> IO a
getFieldValue stmt name = getValue sqlType (fetchBuffer stmt `plusPtr` offs)
  where
    (_, sqlType, _, offs) = findField name (fields stmt)

-- | The SQL type and nullability of the named column.
getFieldValueType :: Statement -> String -> (SqlType, Bool)
getFieldValueType stmt name = (sqlType, nullable)
  where
    (_, sqlType, nullable, _) = findField name (fields stmt)

-- | Name, SQL type and nullability of every result column, in column order.
getFieldsTypes :: Statement -> [(String, SqlType, Bool)]
getFieldsTypes stmt = [ (name, sqlType, nullable) | (name, sqlType, nullable, _) <- fields stmt ]

-- | Look up a column description by exact name; errors if it is absent.
findField :: String -> [FieldDef] -> FieldDef
findField name [] = error ("HSQL: unknown field " ++ show name)
findField name (fieldDef@(name', _, _, _) : rest)
    | name == name' = fieldDef
    | otherwise     = findField name rest

-----------------------------------------------------------------------------------------
-- helpers
-----------------------------------------------------------------------------------------

-- | Run an action over a statement's rows. If the action raises an
--   'SqlError', close the statement (releasing its handle and buffer) before
--   re-raising.
closeOnSqlError :: Statement -> IO a -> IO a
closeOnSqlError stmt act = catchSql act (\err -> closeStatement stmt >> throwDyn err)

-- | Fold over the remaining rows, threading a state through @f@, then close
--   the statement.
forEachRow :: (Statement -> s -> IO s) -> Statement -> s -> IO s
forEachRow f stmt s0 = do
    s <- closeOnSqlError stmt (loop s0)
    closeStatement stmt
    return s
  where
    loop s = do
        success <- fetch stmt
        if success then f stmt s >>= loop else return s

-- | Run @f@ on every remaining row, then close the statement.
forEachRow' :: (Statement -> IO ()) -> Statement -> IO ()
forEachRow' f stmt = do
    closeOnSqlError stmt loop
    closeStatement stmt
  where
    loop = do
        success <- fetch stmt
        when success (f stmt >> loop)

-- | Collect @f@'s result for every remaining row, in order, then close the
--   statement.
collectRows :: (Statement -> IO a) -> Statement -> IO [a]
collectRows f stmt = do
    rows <- closeOnSqlError stmt loop
    closeStatement stmt
    return rows
  where
    loop = do
        success <- fetch stmt
        if success
            then do
                x  <- f stmt
                xs <- loop
                return (x : xs)
            else return []