Implement early version of CodeGen
This commit is contained in:
parent
5a63229e74
commit
3d17813eb4
@ -1,8 +1,351 @@
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE TupleSections #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE RecursiveDo #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
|
||||
module Windows12.CodeGen where
|
||||
|
||||
import Data.Text (Text)
|
||||
import Windows12.Ast (BinOp(..), UnOp(..), AssignOp(..), Type(..),
|
||||
Bind(..), TLStruct(..), TLEnum(..))
|
||||
import Windows12.TAst
|
||||
import LLVM.AST (Module)
|
||||
|
||||
import LLVM.AST hiding (ArrayType, VoidType, Call, function)
|
||||
import LLVM.AST.Type (i32, i1, i8, double, ptr, void)
|
||||
import qualified LLVM.AST.Constant as C
|
||||
import LLVM.IRBuilder hiding (double, IRBuilder, ModuleBuilder)
|
||||
import LLVM.AST.Typed (typeOf)
|
||||
import LLVM.Prelude (ShortByteString)
|
||||
|
||||
import qualified LLVM.AST.IntegerPredicate as IP
|
||||
import qualified LLVM.AST.FloatingPointPredicate as FP
|
||||
|
||||
import Control.Monad.State hiding (void)
|
||||
|
||||
import Data.Text (Text, unpack)
|
||||
import Data.String.Conversions
|
||||
import Data.String
|
||||
|
||||
-- | Code generation state threaded through the module and IR builders.
data Ctx = Ctx
  { operands :: [(Text, Operand)]
    -- ^ Named operands registered through 'createOperand'.
  , structs :: [TLStruct]
    -- ^ Struct definitions registered so far.
  , enums :: [TLEnum]
    -- ^ Enum definitions (never populated by the code in this module).
  , strings :: [(Text, Operand)]
    -- ^ String literals that already have a global, keyed by their text.
  }
  deriving (Eq, Show)
|
||||
|
||||
-- | Module-level builder layered over the 'Ctx' state.
type ModuleBuilder = ModuleBuilderT (State Ctx)

-- | Instruction-level builder running on top of 'ModuleBuilder'.
type IRBuilder = IRBuilderT ModuleBuilder
|
||||
|
||||
-- | Conversion from 'Text' to 'ShortByteString' (through an intermediate
-- 'String'), so that 'cs' can be applied to 'Text' values where a
-- 'ShortByteString' is required.
instance ConvertibleStrings Text ShortByteString where
  convertString txt = Data.String.fromString (Data.Text.unpack txt)
|
||||
|
||||
-- | Register an operand under a name by prepending it to 'operands'.
-- Because entries are prepended, a 'lookup' finds the most recently
-- registered operand for a given name.
createOperand :: MonadState Ctx m => Text -> Operand -> m ()
createOperand name op =
  modify (\ctx -> ctx { operands = (name, op) : operands ctx })
|
||||
|
||||
-- | Build the LLVM IR module for a type-checked program.
--
-- The first argument is the source file name; it is used as the module name.
-- @printf@ is declared as an external variadic function and registered as an
-- operand, then the struct type definitions are emitted, and finally every
-- function is generated. Enums are not used during generation.
codegen :: Text -> TProgram -> Module
codegen filename (TProgram structDefs _enums funcs) =
  flip evalState (Ctx [] [] [] [])
    $ buildModuleT (cs filename)
    $ do
      printf <- externVarArgs (mkName "printf") [ptr i8] i32
      createOperand "printf" printf
      mapM_ emitTypeDef structDefs
      mapM_ codegenFunc funcs
|
||||
|
||||
-- | Look up a struct by name in the context and return its fields.
-- Raises an error when no struct, or more than one struct, has that name.
getStructFields :: MonadState Ctx m => Text -> m [Bind]
getStructFields name = do
  known <- gets structs
  case filter ((== name) . structName) known of
    [Struct _ fields] -> return fields
    [] -> error $ "Struct " ++ show name ++ " not found. Valid structs: " ++ show (map structName known)
    _ -> error $ "Multiple structs with name " ++ show name
  where
    structName (Struct n _) = n
|
||||
|
||||
-- | Translate a Windows12 source type into the LLVM type used to represent it.
-- Strings and arrays are represented as pointers to their element type, and
-- structs are looked up in the context to obtain their field types.
convertType :: MonadState Ctx m => Windows12.Ast.Type -> m LLVM.AST.Type
convertType ty = case ty of
  IntType -> pure i32
  UIntType -> pure i32
  FloatType -> pure double
  StrType -> convertType (PtrType CharType)
  BoolType -> pure i1
  CharType -> pure i8
  PtrType inner -> ptr <$> convertType inner
  ArrayType inner -> convertType (PtrType inner)
  StructType name -> do
    fields <- getStructFields name
    -- True selects a packed structure
    StructureType True <$> mapM (convertType . bindType) fields
  EnumType _ -> pure i32
  VoidType -> pure void
|
||||
|
||||
-- | Size of a source type in bytes, as reported by @sizeof@.
--
-- The sizes follow the LLVM representation chosen by 'convertType': structs
-- are packed, so a struct's size is the sum of its field sizes, and enums are
-- lowered to @i32@, so they occupy 4 bytes.
size :: MonadState Ctx m => Windows12.Ast.Type -> m Int
size IntType = return 4
size UIntType = return 4
size FloatType = return 8
size StrType = size (PtrType CharType)
size BoolType = return 1
size CharType = return 1
-- NOTE(review): 4 assumes 32-bit pointers; confirm against the intended
-- target, since no data layout is configured in this module.
size (PtrType _) = return 4
size (ArrayType t) = size (PtrType t)
size (StructType name) = do
  fields <- getStructFields name
  sizes <- mapM (size . bindType) fields
  return $ sum sizes
-- Enums are represented as i32 (see 'convertType'), hence 4 bytes.
size (EnumType _) = return 4
size VoidType = return 0
|
||||
|
||||
-- | Generate code for an lvalue and return the operand that addresses it.
--
-- Supported forms: a plain identifier, a field of a struct variable, and a
-- dereference. Any other lvalue raises an error.
codegenLVal :: TLVal -> IRBuilder Operand
-- Identifier: the operand registered under that name.
codegenLVal (_, TId name) = do
  ctx <- get
  case lookup name (operands ctx) of
    Just op -> return op
    Nothing -> error $ "Variable " ++ show name ++ " not found"

-- Field of a struct variable: index into the struct operand.
-- TODO support members of members
codegenLVal (StructType t, LTMember (_, TId sName) field) = do
  ctx <- get
  case lookup sName (operands ctx) of
    Just struct -> do
      fields <- getStructFields t
      offset <- structFieldOffset (Struct t fields) field
      gep struct [ConstantOperand (C.Int 32 0), ConstantOperand (C.Int 32 (fromIntegral offset))]
    Nothing -> error $ "Struct " ++ show sName ++ " not found"

-- Dereference: the value of the inner expression is the address.
codegenLVal (_, TDeref e) = codegenExpr e

codegenLVal _ = error "Unimplemented or invalid LValue"
|
||||
|
||||
-- | Position of a field within a struct, counted in fields (not bytes).
-- This is the index used when addressing the field with @gep@.
--
-- Raises an error when the struct has no field with the given name, rather
-- than returning an index past the last field.
structFieldOffset :: MonadState Ctx m => TLStruct -> Text -> m Int
structFieldOffset (Struct name fields) field =
  case break (\(Bind n _) -> n == field) fields of
    (_, []) -> error $ "Field " ++ show field ++ " not found in " ++ show name
    (preceding, _) -> return (length preceding)
|
||||
|
||||
-- | Generate code for an expression and return the operand holding its value.
--
-- Float literals, unary operators, indexing and casts are not implemented yet
-- and evaluate to 'undefined'.
codegenExpr :: TExpr -> IRBuilder Operand
-- Variable: compute its address, then load from it.
codegenExpr (ty, TVar name) = do
  addr <- codegenLVal (ty, TId name)
  load addr 0
codegenExpr (_, TIntLit i) = pure (ConstantOperand (C.Int 32 (fromIntegral i)))
codegenExpr (_, TUIntLit i) = pure (ConstantOperand (C.Int 32 (fromIntegral i)))
codegenExpr (_, TFloatLit _) = undefined -- TODO floats
-- String literal: reuse the global if this text was seen before, otherwise
-- emit a new global string and remember it in the context.
codegenExpr (_, TStrLit s) = do
  known <- gets strings
  case lookup s known of
    Just existing -> pure existing
    Nothing -> do
      let globalName = mkName ("str." <> show (length known))
      ptrConst <- globalStringPtr (cs s) globalName
      let op = ConstantOperand ptrConst
      modify (\ctx -> ctx { strings = (s, op) : known })
      pure op
codegenExpr (_, TBoolLit b) = pure (ConstantOperand (C.Int 1 (if b then 1 else 0)))
codegenExpr (_, TCharLit c) = pure (ConstantOperand (C.Int 8 (fromIntegral (fromEnum c))))

-- Binary operators: both operands are generated first (left, then right).
-- TODO pointers, floating points
codegenExpr (_, TBinOp op lhs rhs) = do
  lhs' <- codegenExpr lhs
  rhs' <- codegenExpr rhs
  -- Emit the instruction only when both operands are 32-bit integers.
  let int32Only label instr =
        case (typeOf lhs', typeOf rhs') of
          (IntegerType 32, IntegerType 32) -> instr lhs' rhs'
          _ -> error $ "Invalid types for " ++ label
  case op of
    Windows12.Ast.Add -> int32Only "add" add
    Windows12.Ast.Sub -> int32Only "sub" sub
    Windows12.Ast.Mul -> int32Only "mul" mul
    Windows12.Ast.Div -> int32Only "div" sdiv
    Windows12.Ast.Mod -> int32Only "mod" srem
    Windows12.Ast.Eq -> int32Only "eq" (icmp IP.EQ)
    Windows12.Ast.Ne -> int32Only "ne" (icmp IP.NE)
    Windows12.Ast.Lt -> int32Only "lt" (icmp IP.SLT)
    Windows12.Ast.Gt -> int32Only "gt" (icmp IP.SGT)
    Windows12.Ast.Le -> int32Only "le" (icmp IP.SLE)
    Windows12.Ast.Ge -> int32Only "ge" (icmp IP.SGE)
    other -> error $ "Operator " ++ show other ++ " not implemented"

codegenExpr (_, TUnOp _ _) = undefined -- TODO handle unary operators

-- Function call: find the callee among the operands, generate the arguments
-- in order, and call it with no parameter attributes.
codegenExpr (_, TCall fname argExprs) = do
  ctx <- get
  callee <- case lookup fname (operands ctx) of
    Just op -> pure op
    Nothing -> error $ "Function " ++ show fname ++ " not found"
  argOps <- mapM codegenExpr argExprs
  call callee [(argOp, []) | argOp <- argOps]

codegenExpr (_, TIndex _ _) = undefined -- TODO arrays

-- Field of a struct variable: compute the field address and load from it.
codegenExpr (_, TMember (StructType sName, TVar sVarName) m) = do
  ctx <- get
  case lookup sVarName (operands ctx) of
    Nothing -> error $ "Struct operand " ++ show sVarName ++ " not found"
    Just struct -> do
      fields <- getStructFields sName
      offset <- structFieldOffset (Struct sVarName fields) m
      addr <- gep struct [ConstantOperand (C.Int 32 0), ConstantOperand (C.Int 32 (fromIntegral offset))]
      load addr 0

codegenExpr (_, TCast _ _) = undefined -- TODO casts

-- sizeof: a 32-bit constant holding the result of 'size'.
codegenExpr (_, TSizeof t) = do
  bytes <- size t
  pure (ConstantOperand (C.Int 32 (fromIntegral bytes)))
|
||||
|
||||
-- | Run the given action only when the current block has no terminator yet.
-- Used to append a branch to a block unless it already ends in one.
mkTerminator :: IRBuilder () -> IRBuilder ()
mkTerminator instr = hasTerminator >>= \terminated -> unless terminated instr
|
||||
|
||||
-- | Generate code for a statement. Statement forms without a matching
-- equation raise an error.
codegenStmt :: TStmt -> IRBuilder ()

-- Expression statement: evaluate the expression and drop its value.
codegenStmt (TExprStmt e) = () <$ codegenExpr e

codegenStmt (TReturn e) = codegenExpr e >>= ret

-- If statement: then/else blocks that both fall through to a merge block
-- unless they already end in a terminator. A missing else branch is
-- generated as an empty block.
codegenStmt (TIf cond thenStmts elseStmts) = mdo
  condOp <- codegenExpr cond
  condBr condOp thenBlock elseBlock

  thenBlock <- block `named` "then"
  codegenStmt (TBlock thenStmts)
  mkTerminator (br mergeBlock)

  elseBlock <- block `named` "else"
  codegenStmt (maybe (TBlock []) TBlock elseStmts)
  mkTerminator (br mergeBlock)

  mergeBlock <- block `named` "merge"
  return ()

-- While loop: a condition block that branches to the loop body or to the
-- end block; the body jumps back to the condition.
codegenStmt (TWhile cond body) = mdo
  br condBlock
  condBlock <- block `named` "cond"
  condOp <- codegenExpr cond
  condBr condOp loopBlock endBlock

  loopBlock <- block `named` "loop"
  codegenStmt (TBlock body)
  mkTerminator (br condBlock)

  endBlock <- block `named` "end"
  return ()

-- Assignments are supported for identifiers; plain assignment is also
-- supported for a field of a struct variable.
codegenStmt (TAssign BaseAssign l@(_, TId _) e) = codegenStore l e
codegenStmt (TAssign BaseAssign l@(StructType _, LTMember (_, TId _) _) e) = codegenStore l e
codegenStmt (TAssign AddAssign l@(_, TId _) e) = codegenUpdate add l e
codegenStmt (TAssign SubAssign l@(_, TId _) e) = codegenUpdate sub l e

-- A block is its statements generated in order.
codegenStmt (TBlock stmts) = forM_ stmts codegenStmt

-- Declaration with initializer: storage is allocated by the function
-- prologue, so only the assignment remains.
codegenStmt (TDeclVar name t (Just e)) = codegenStmt (TAssign BaseAssign (t, TId name) e)

-- Declaration without initializer: nothing to emit, allocation is done
-- already.
codegenStmt (TDeclVar _ _ Nothing) = return ()

codegenStmt other = error ("Unimplemented or invalid statement " ++ show other)

-- | Evaluate the right-hand side, then the target address, and store the
-- value there.
codegenStore :: TLVal -> TExpr -> IRBuilder ()
codegenStore target rhs = do
  value <- codegenExpr rhs
  addr <- codegenLVal target
  store addr 0 value

-- | Evaluate the right-hand side, load the target's current value, combine
-- the two with the given instruction, and store the result back.
codegenUpdate :: (Operand -> Operand -> IRBuilder Operand) -> TLVal -> TExpr -> IRBuilder ()
codegenUpdate combine target rhs = do
  value <- codegenExpr rhs
  addr <- codegenLVal target
  current <- load addr 0
  store addr 0 =<< combine current value
|
||||
|
||||
-- | Generate an LLVM function for a top-level function definition.
--
-- The function operand is registered under its name before the body is
-- generated (the block is recursive, via @mdo@), so the body can look the
-- function up. The body first allocates and fills stack slots for the
-- parameters, then allocates slots for every local declared in the body, and
-- then generates the statements.
codegenFunc :: TTLFunc -> ModuleBuilder ()
codegenFunc func@(TTLFunc name args retType body) = mdo
  createOperand name f
  (f, strs) <- do
    llvmParams <- mapM mkParam args
    llvmRet <- convertType retType
    fn <- function (mkName (cs name)) llvmParams llvmRet genBody
    emitted <- gets strings
    return (fn, emitted)
  -- Write back the string table as read after the function was generated.
  modify (\ctx -> ctx { strings = strs })
  where
    -- LLVM parameter: converted type paired with the parameter's name.
    mkParam (Bind pName pType) =
      (\llvmTy -> (llvmTy, ParameterName (cs pName))) <$> convertType pType

    genBody :: [Operand] -> IRBuilder ()
    genBody paramOps = do
      -- Copy each incoming parameter into its own stack slot and register
      -- the slot under the parameter's name.
      forM_ (zip paramOps args) $ \(paramOp, Bind argName _) -> do
        slot <- alloca (typeOf paramOp) Nothing 0
        store slot 0 paramOp
        createOperand argName slot

      -- Allocate a slot for every local before generating any statement.
      forM_ (getLocals func) $ \(Bind localName localType) -> do
        llvmTy <- convertType localType
        slot <- alloca llvmTy Nothing 0
        createOperand localName slot

      codegenStmt (TBlock body)
|
||||
|
||||
-- | Collect every local variable declared anywhere in a function's body,
-- so that their storage can be allocated before the body is generated.
getLocals :: TTLFunc -> [Bind]
getLocals (TTLFunc _ _ _ stmts) = blockGetLocals stmts
|
||||
|
||||
-- | Locals declared by a list of statements, in statement order.
blockGetLocals :: [TStmt] -> [Bind]
blockGetLocals stmts = [local | stmt <- stmts, local <- stmtGetLocals stmt]
|
||||
|
||||
-- | Locals declared by a single statement, descending into nested blocks,
-- both branches of an if, and while bodies. Other statements declare none.
stmtGetLocals :: TStmt -> [Bind]
stmtGetLocals stmt = case stmt of
  TDeclVar n t _ -> [Bind n t]
  TBlock stmts -> blockGetLocals stmts
  TIf _ thenStmts elseStmts -> blockGetLocals thenStmts ++ foldMap blockGetLocals elseStmts
  TWhile _ body -> blockGetLocals body
  _ -> []
|
||||
|
||||
-- | Register a struct in the context and emit a named LLVM type definition
-- for it, called @struct.<name>@. The struct is registered first because
-- converting its type looks it up in the context.
emitTypeDef :: TLStruct -> ModuleBuilder LLVM.AST.Type
emitTypeDef (Struct name fields) = do
  modify (\ctx -> ctx { structs = Struct name fields : structs ctx })
  llvmStruct <- convertType (StructType name)
  let typeName = mkName (cs ("struct." <> name))
  typedef typeName (Just llvmStruct)
|
||||
|
Loading…
x
Reference in New Issue
Block a user