Implement early version of CodeGen

This commit is contained in:
Ethan Girouard 2024-12-11 16:11:05 -05:00
parent 5a63229e74
commit 3d17813eb4
Signed by: eta357
GPG Key ID: 7BCDC36DFD11C146

View File

@ -1,8 +1,351 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Windows12.CodeGen where module Windows12.CodeGen where
import Data.Text (Text) import Windows12.Ast (BinOp(..), UnOp(..), AssignOp(..), Type(..),
Bind(..), TLStruct(..), TLEnum(..))
import Windows12.TAst import Windows12.TAst
import LLVM.AST (Module)
import LLVM.AST hiding (ArrayType, VoidType, Call, function)
import LLVM.AST.Type (i32, i1, i8, double, ptr, void)
import qualified LLVM.AST.Constant as C
import LLVM.IRBuilder hiding (double, IRBuilder, ModuleBuilder)
import LLVM.AST.Typed (typeOf)
import LLVM.Prelude (ShortByteString)
import qualified LLVM.AST.IntegerPredicate as IP
import qualified LLVM.AST.FloatingPointPredicate as FP
import Control.Monad.State hiding (void)
import Data.Text (Text, unpack)
import Data.String.Conversions
import Data.String
-- | Global code-generation state, threaded through the module and
-- instruction builders via 'State'.
data Ctx = Ctx
  { operands :: [(Text, Operand)]
    -- ^ Named operands. Newer bindings are prepended, so 'lookup' finds the
    -- most recently registered operand for a name.
  , structs :: [TLStruct]
    -- ^ Struct definitions registered so far.
  , enums :: [TLEnum]
    -- ^ Enum definitions.
  , strings :: [(Text, Operand)]
    -- ^ String literals that already have a global, keyed by their contents.
  }
  deriving (Eq, Show)
-- | Module-level builder whose underlying monad carries the shared 'Ctx'.
-- Used in place of the alias of the same name that the LLVM.IRBuilder import hides.
type ModuleBuilder = ModuleBuilderT (State Ctx)
-- | Instruction-level builder layered on top of 'ModuleBuilder', so code
-- emitted inside a function body can still read and update the 'Ctx'.
type IRBuilder = IRBuilderT ModuleBuilder
-- | Lets 'cs' turn 'Text' into the 'ShortByteString' names that LLVM expects.
-- The conversion goes through 'String'.
instance ConvertibleStrings Text ShortByteString where
  convertString txt = Data.String.fromString (Data.Text.unpack txt)
-- | Register an operand in the context under the given name.
--
-- The binding is prepended, so it shadows any earlier operand registered
-- with the same name.
createOperand :: MonadState Ctx m => Text -> Operand -> m ()
createOperand name op =
  modify $ \ctx -> ctx { operands = (name, op) : operands ctx }
-- | Take a source file name and the typed AST, and return the LLVM IR module.
--
-- Starts from an empty 'Ctx', declares the variadic @printf@ extern and
-- registers it as an operand, then emits all struct type definitions followed
-- by all functions, in the order they appear in the program.
--
-- The program's enum definitions are not used here.
codegen :: Text -> TProgram -> Module
codegen filename (TProgram structDefs _enumDefs funcs) =
  evalState (buildModuleT (cs filename) emitProgram) (Ctx [] [] [] [])
  where
    emitProgram :: ModuleBuilder ()
    emitProgram = do
      printf <- externVarArgs (mkName "printf") [ptr i8] i32
      createOperand "printf" printf
      -- Structs come first so that function bodies can resolve struct types.
      mapM_ emitTypeDef structDefs
      mapM_ codegenFunc funcs
-- | Look up a struct by name in the context and return its fields.
--
-- Calls 'error' when no struct, or more than one struct, is registered under
-- that name.
getStructFields :: MonadState Ctx m => Text -> m [Bind]
getStructFields name = do
  known <- gets structs
  let matching = filter (\(Struct n _) -> n == name) known
      knownNames = map (\(Struct n _) -> n) known
  case matching of
    [Struct _ fields] -> return fields
    [] -> error $ "Struct " ++ show name ++ " not found. Valid structs: " ++ show knownNames
    _ -> error $ "Multiple structs with name " ++ show name
-- | Convert a Windows12 type to the LLVM type used to represent it.
--
-- Strings and arrays are lowered to pointers, enums to @i32@, and structs to
-- packed structure types built from their registered fields (which requires
-- the struct to be present in the context).
convertType :: MonadState Ctx m => Windows12.Ast.Type -> m LLVM.AST.Type
convertType ty = case ty of
  IntType -> return i32
  UIntType -> return i32
  FloatType -> return double
  StrType -> convertType (PtrType CharType)
  BoolType -> return i1
  CharType -> return i8
  PtrType inner -> ptr <$> convertType inner
  ArrayType inner -> convertType (PtrType inner)
  StructType name -> do
    fields <- getStructFields name
    -- True marks the structure as packed.
    StructureType True <$> mapM (convertType . bindType) fields
  EnumType _ -> return i32
  VoidType -> return void
-- | Get the size of a type in bytes.
--
-- Struct sizes are the plain sum of their field sizes, which matches the
-- packed struct layout produced by 'convertType'.
size :: MonadState Ctx m => Windows12.Ast.Type -> m Int
size ty = case ty of
  IntType -> return 4
  UIntType -> return 4
  FloatType -> return 8
  StrType -> size (PtrType CharType)
  BoolType -> return 1
  CharType -> return 1
  -- NOTE(review): 4 bytes assumes a 32-bit target; confirm against the
  -- data layout the module is compiled for.
  PtrType _ -> return 4
  ArrayType inner -> size (PtrType inner)
  StructType name -> do
    fields <- getStructFields name
    sum <$> mapM (size . bindType) fields
  -- Enums are lowered to i32 by 'convertType', so they occupy 4 bytes.
  EnumType _ -> return 4
  VoidType -> return 0
-- | Generate code for an lvalue and return an operand holding its address.
--
-- Supported forms:
--
-- * a named variable, resolved through the operands in the context;
-- * a field of a named struct variable, addressed with a GEP;
-- * a dereference, whose address is the value of the inner expression.
--
-- Any other lvalue, or an unknown variable name, calls 'error'.
codegenLVal :: TLVal -> IRBuilder Operand
codegenLVal (_, TId name) = do
  known <- gets operands
  case lookup name known of
    Just op -> return op
    Nothing -> error $ "Variable " ++ show name ++ " not found"
-- TODO support members of members
-- NOTE(review): the struct type is read from the lvalue's own annotation
-- here, whereas the TMember case of codegenExpr reads it from the inner
-- expression; confirm this matches what the type checker produces.
codegenLVal (StructType sType, LTMember (_, TId sName) field) = do
  known <- gets operands
  case lookup sName known of
    Just struct -> do
      fields <- getStructFields sType
      fieldIdx <- structFieldOffset (Struct sType fields) field
      -- The leading 0 steps through the pointer to the struct itself; the
      -- second index selects the field.
      gep struct [ConstantOperand (C.Int 32 0), ConstantOperand (C.Int 32 (fromIntegral fieldIdx))]
    Nothing -> error $ "Struct " ++ show sName ++ " not found"
codegenLVal (_, TDeref e) = codegenExpr e
codegenLVal _ = error "Unimplemented or invalid LValue"
-- | Given a struct and a field name, return the position of the field within
-- the struct, counted in fields (not bytes): the first field is 0.
--
-- Calls 'error' when the struct has no field with that name, rather than
-- returning an index one past the last field.
structFieldOffset :: MonadState Ctx m => TLStruct -> Text -> m Int
structFieldOffset (Struct name fields) field =
  case break (\(Bind n _) -> n == field) fields of
    (preceding, _ : _) -> return (length preceding)
    (_, []) -> error $ "Field " ++ show field ++ " not found in " ++ show name
-- | Generate code for an expression and return the operand holding its value.
--
-- Implemented: variables, int/uint/bool/char/string literals, binary
-- arithmetic and comparisons on 32-bit integers, function calls, field reads
-- on a named struct variable, and @sizeof@. Everything else calls 'error'
-- with a message naming the missing feature.
codegenExpr :: TExpr -> IRBuilder Operand
codegenExpr (t, TVar name) = flip load 0 =<< codegenLVal (t, TId name)
codegenExpr (_, TIntLit i) = return $ ConstantOperand (C.Int 32 (fromIntegral i))
codegenExpr (_, TUIntLit i) = return $ ConstantOperand (C.Int 32 (fromIntegral i))
codegenExpr (_, TFloatLit _) = error "codegenExpr: float literals are not implemented" -- TODO floats
codegenExpr (_, TStrLit s) = do
  strs <- gets strings
  case lookup s strs of
    -- If the string is already in the context, reuse its global
    Just str -> return str
    -- Otherwise, create a new global string and add it to the context
    Nothing -> do
      let strName = mkName ("str." <> show (length strs))
      strPtr <- globalStringPtr (cs s) strName
      let strOp = ConstantOperand strPtr
      modify $ \ctx -> ctx { strings = (s, strOp) : strs }
      return strOp
codegenExpr (_, TBoolLit b) = return $ ConstantOperand (C.Int 1 (if b then 1 else 0))
codegenExpr (_, TCharLit c) = return $ ConstantOperand (C.Int 8 (fromIntegral (fromEnum c)))
codegenExpr (_, TBinOp op lhs rhs) = do
  lhs' <- codegenExpr lhs
  rhs' <- codegenExpr rhs
  -- TODO pointers, floating points
  -- Emit the instruction only when both operands are i32; otherwise fail
  -- with a message naming the operator.
  let int32Only label instr = case (typeOf lhs', typeOf rhs') of
        (IntegerType 32, IntegerType 32) -> instr lhs' rhs'
        _ -> error ("Invalid types for " ++ label)
  case op of
    Windows12.Ast.Add -> int32Only "add" add
    Windows12.Ast.Sub -> int32Only "sub" sub
    Windows12.Ast.Mul -> int32Only "mul" mul
    Windows12.Ast.Div -> int32Only "div" sdiv
    Windows12.Ast.Mod -> int32Only "mod" srem
    Windows12.Ast.Eq -> int32Only "eq" (icmp IP.EQ)
    Windows12.Ast.Ne -> int32Only "ne" (icmp IP.NE)
    Windows12.Ast.Lt -> int32Only "lt" (icmp IP.SLT)
    Windows12.Ast.Gt -> int32Only "gt" (icmp IP.SGT)
    Windows12.Ast.Le -> int32Only "le" (icmp IP.SLE)
    Windows12.Ast.Ge -> int32Only "ge" (icmp IP.SGE)
    other -> error $ "Operator " ++ show other ++ " not implemented"
codegenExpr (_, TUnOp _ _) = error "codegenExpr: unary operators are not implemented" -- TODO handle unary operators
-- Function calls: look up the function in operands, then call it with the args
codegenExpr (_, TCall fname args) = do
  known <- gets operands
  callee <- case lookup fname known of
    Just op -> return op
    Nothing -> error $ "Function " ++ show fname ++ " not found"
  -- Each argument is passed with an empty parameter-attribute list.
  argOps <- mapM (fmap (, []) . codegenExpr) args
  call callee argOps
codegenExpr (_, TIndex _ _) = error "codegenExpr: array indexing is not implemented" -- TODO arrays
-- Get the address of the struct field and load it
codegenExpr (_, TMember (StructType sName, TVar sVarName) m) = do
  known <- gets operands
  case lookup sVarName known of
    Just struct -> do
      fields <- getStructFields sName
      fieldIdx <- structFieldOffset (Struct sName fields) m
      addr <- gep struct [ConstantOperand (C.Int 32 0), ConstantOperand (C.Int 32 (fromIntegral fieldIdx))]
      load addr 0
    Nothing -> error $ "Struct operand " ++ show sVarName ++ " not found"
-- Member access on anything other than a named struct variable
codegenExpr (_, TMember _ m) =
  error $ "codegenExpr: access to member " ++ show m ++ " of a non-variable expression is not implemented"
codegenExpr (_, TCast _ _) = error "codegenExpr: casts are not implemented" -- TODO casts
codegenExpr (_, TSizeof t) = ConstantOperand . C.Int 32 . fromIntegral <$> size t
-- | Run the given terminator-emitting action only if the current block does
-- not already have a terminator.
mkTerminator :: IRBuilder () -> IRBuilder ()
mkTerminator instr = hasTerminator >>= \terminated -> unless terminated instr
-- | Generate code for a statement.
--
-- Statements that are not handled by one of the equations below call 'error'.
codegenStmt :: TStmt -> IRBuilder ()
-- Expression statements are evaluated for their effects; the result is dropped
codegenStmt (TExprStmt e) = codegenExpr e >> return ()
codegenStmt (TReturn e) = codegenExpr e >>= ret
-- If statements: "then" and "else" blocks that both fall through to "merge"
-- unless they already ended in a terminator
codegenStmt (TIf cond thenStmts elseStmts) = mdo
  condOp <- codegenExpr cond
  condBr condOp thenBlock elseBlock
  thenBlock <- block `named` "then"
  codegenStmt (TBlock thenStmts)
  mkTerminator $ br mergeBlock
  elseBlock <- block `named` "else"
  -- A missing else branch is treated as an empty block
  codegenStmt (TBlock (maybe [] id elseStmts))
  mkTerminator $ br mergeBlock
  mergeBlock <- block `named` "merge"
  return ()
-- While loops: "cond" is re-evaluated before each iteration, "end" follows the loop
codegenStmt (TWhile cond body) = mdo
  br condBlock
  condBlock <- block `named` "cond"
  condOp <- codegenExpr cond
  condBr condOp loopBlock endBlock
  loopBlock <- block `named` "loop"
  codegenStmt (TBlock body)
  mkTerminator $ br condBlock
  endBlock <- block `named` "end"
  return ()
-- Plain assignment to a variable or to a field of a struct variable
codegenStmt (TAssign BaseAssign l@(_, TId _) e) = assignLVal l e
codegenStmt (TAssign BaseAssign l@(StructType _, LTMember (_, TId _) _) e) = assignLVal l e
-- Compound assignment is only supported on plain variables
codegenStmt (TAssign AddAssign l@(_, TId _) e) = compoundAssign add l e
codegenStmt (TAssign SubAssign l@(_, TId _) e) = compoundAssign sub l e
-- A block is just a list of statements
codegenStmt (TBlock stmts) = mapM_ codegenStmt stmts
-- Locals are already allocated by the function prologue, so a declaration
-- with an initializer is just an assignment
codegenStmt (TDeclVar name t (Just e)) = codegenStmt (TAssign BaseAssign (t, TId name) e)
-- A declaration without an initializer emits nothing, for the same reason
codegenStmt (TDeclVar _ _ Nothing) = return ()
codegenStmt s = error $ "Unimplemented or invalid statement " ++ show s

-- | Evaluate the expression, then store its value at the lvalue's address.
assignLVal :: TLVal -> TExpr -> IRBuilder ()
assignLVal l e = do
  value <- codegenExpr e
  addr <- codegenLVal l
  store addr 0 value

-- | Evaluate the expression, load the lvalue's current value, combine the two
-- with the given instruction (current value first), and store the result back.
compoundAssign :: (Operand -> Operand -> IRBuilder Operand) -> TLVal -> TExpr -> IRBuilder ()
compoundAssign combine l e = do
  value <- codegenExpr e
  addr <- codegenLVal l
  current <- load addr 0
  combine current value >>= store addr 0
-- | Generate code for a function.
--
-- The function's operand is registered under its name before the body is
-- generated (made possible by the recursive @mdo@ binding), so the body can
-- look the function up in the context. The body prologue allocates a stack
-- slot for every argument and every local before any statement is emitted.
--
-- Operands registered here are prepended to the shared context; nothing in
-- this function removes them afterwards.
codegenFunc :: TTLFunc -> ModuleBuilder ()
codegenFunc func@(TTLFunc name args retType body) = mdo
  createOperand name fnOp
  (fnOp, strsAfter) <- do
    paramDefs <- mapM mkParam args
    retTy <- convertType retType
    built <- function (mkName (cs name)) paramDefs retTy genBody
    seen <- gets strings
    return (built, seen)
  -- NOTE(review): the string table read right after building the function is
  -- written back unchanged here; confirm why this is needed.
  modify $ \ctx -> ctx { strings = strsAfter }
  where
    -- LLVM type and parameter name for one argument
    mkParam (Bind argName argTy) = do
      llvmTy <- convertType argTy
      return (llvmTy, ParameterName (cs argName))

    genBody :: [Operand] -> IRBuilder ()
    genBody argOps = do
      -- Copy each incoming argument into its own stack slot and register the slot
      forM_ (zip argOps args) $ \(argOp, Bind argName _) -> do
        slot <- alloca (typeOf argOp) Nothing 0
        store slot 0 argOp
        createOperand argName slot
      -- Allocate every local declared anywhere in the body up front
      forM_ (getLocals func) $ \(Bind localName localTy) -> do
        llvmTy <- convertType localTy
        slot <- alloca llvmTy Nothing 0
        createOperand localName slot
      codegenStmt (TBlock body)
-- | Collect every local variable declared in a function's body, so the
-- allocations can be emitted before the body itself. Arguments are not included.
getLocals :: TTLFunc -> [Bind]
getLocals (TTLFunc _ _ _ body) = concatMap stmtGetLocals body
-- | Collect the locals declared by a list of statements, in source order.
blockGetLocals :: [TStmt] -> [Bind]
blockGetLocals stmts = [local | stmt <- stmts, local <- stmtGetLocals stmt]
-- | Collect the locals declared by a single statement, descending into
-- blocks, both branches of an if, and while bodies. Other statements
-- declare nothing.
stmtGetLocals :: TStmt -> [Bind]
stmtGetLocals stmt = case stmt of
  TDeclVar n t _ -> [Bind n t]
  TBlock stmts -> blockGetLocals stmts
  TIf _ thenStmts elseStmts -> blockGetLocals thenStmts ++ maybe [] blockGetLocals elseStmts
  TWhile _ body -> blockGetLocals body
  _ -> []
-- | Register a struct in the context and emit its LLVM type definition,
-- named @struct.<name>@.
--
-- The struct is registered before its type is converted, because the
-- conversion looks the struct's fields up in the context. Any struct-typed
-- field must likewise already be registered.
emitTypeDef :: TLStruct -> ModuleBuilder LLVM.AST.Type
emitTypeDef def@(Struct name _) = do
  modify $ \ctx -> ctx { structs = def : structs ctx }
  llvmTy <- convertType (StructType name)
  typedef (mkName (cs ("struct." <> name))) (Just llvmTy)