Boostで抽象構文木を作ってllvm言語に変換してみた

Boost.Variant、Boost.Spirit.Qi、Boost.Spirit.Phoenixを使って簡単な中置記法の計算するための抽象構文木(AST)を作って、そのASTをllvm言語に変換するというのを作ってみました。コード書き殴り。 BoostでASTが割とすっきりと書け、そこからllvm言語に変換することも割と簡単に出来ます。

コードと出力結果はGistのほうに置いておきます。


雑ですが説明。

ASTの表現

Boost.VariantでASTを表します。

namespace ast {

    struct add;
    struct sub;
    struct mul;
    struct div;

    template <class Op>
    struct binary_op;

    using expr = boost::variant<
        int,
        boost::recursive_wrapper< binary_op< add > >,
        boost::recursive_wrapper< binary_op< sub > >,
        boost::recursive_wrapper< binary_op< mul > >,
        boost::recursive_wrapper< binary_op< div > >
    >;

    template <class Op>
    struct binary_op
    {
        expr lhs;
        expr rhs;

        binary_op(expr const& lhs_, expr const& rhs_) :
            lhs( lhs_ ), rhs( rhs_ )
        { }
    };

} // namespace ast

値または式を格納するためにboost::variantを使ってast::exprを定義しています。 ここでは値はintまたは二項演算を表すast::binary_opで四則演算のいずれかを格納できるようにしています。 2 * 3 + 1のような式の場合、以下の図のような木になります。この図のbinary_opやintは全てast::exprで表すことができます。

f:id:lnseab:20130408191145p:plain

ast::binary_opについて、演算子の左、右でそれぞれast::exprのlhs、rhsで表現しています。 しかし、ast::exprを定義するためにast::binary_opが必要で、ast::binary_opを定義するためにast::exprが必要という循環になってしまってます。なので、ast::binary_opをast::exprの前方で宣言しておいてboost::recursive_wrapperを使って解決しています。

パーサ

Boost.Spirit.Qi、Boost.Spirit.Phoenixを使ってパーサを作ります。

namespace parser {

    namespace qi = boost::spirit::qi;

    template <class Iterator>
    struct arith_grammar :
        qi::grammar< Iterator, ast::expr(), qi::ascii::space_type >
    {
        template <class T>
        using rule_t = qi::rule< Iterator, T, qi::ascii::space_type >;

        rule_t< ast::expr() > expr;
        rule_t< ast::expr() > term;
        rule_t< ast::expr() > factor;

        arith_grammar() :
            arith_grammar::base_type( expr )
        { 
            namespace phx = boost::phoenix;

            expr %= term[qi::_val = qi::_1] 
                >> *( 
                    ( '+' >> term[qi::_val = phx::construct< ast::binary_op< ast::add > >( qi::_val, qi::_1 )] ) 
                    | ( '-' >> term[qi::_val = phx::construct< ast::binary_op< ast::sub > >( qi::_val, qi::_1 )] )
                );

            term %= factor[qi::_val = qi::_1]
                >> *(
                    ( '*' >> factor[qi::_val = phx::construct< ast::binary_op< ast::mul > >( qi::_val, qi::_1 )] ) 
                    | ( '/' >> factor[qi::_val = phx::construct< ast::binary_op< ast::div > >( qi::_val, qi::_1 )] )
                );

            factor %= qi::int_ | ( '(' >> expr >> ')' )[qi::_val = qi::_1];
        }
    };

} // namespace parser

セマンティックアクションの中でast::binary_opを構築するときに、直接ast::binary_opと書くとセマンティックアション前に構築されてしまってエラーになるのでboost::phoenix::constructを使って構築を遅延しています。

llvm言語に変換

ASTが出来たらllvm言語に変換します。 llvm言語周りはまだわかってないので過程だけ説明します。

218行目で、まずLLVM側の準備、main関数の作成、printfを使えるようにしています。

auto& context = llvm::getGlobalContext();
std::unique_ptr< llvm::Module > module( new llvm::Module( "arith", context ) );
llvm::IRBuilder<> builder( context );

make_main_func( context, module, builder );

auto* printf_func = make_printf( module, builder );
auto* format = builder.CreateGlobalStringPtr( "%d\n" );

ASTをasemmblyを使ってllvm言語に変換します。

class asemmbly :
    public boost::static_visitor< llvm::Value* >
{
    llvm::IRBuilder<>& builder_;

public:
    asemmbly(llvm::IRBuilder<>& builder) :
        boost::static_visitor< llvm::Value* >(),
        builder_( builder )
    { }

    llvm::Value* operator()(int value)
    { 
        return llvm::ConstantInt::get( builder_.getInt32Ty(), value );
    }

    template <class Op>
    llvm::Value* operator()(ast::binary_op< Op > const& op) 
    { 
        llvm::Value* lhs = boost::apply_visitor( *this, op.lhs );
        llvm::Value* rhs = boost::apply_visitor( *this, op.rhs );

        return apply_op( op, lhs, rhs );
    }

private:
    llvm::Value* apply_op(ast::binary_op< ast::add > const&, llvm::Value* lhs, llvm::Value* rhs)
    {
        return builder_.CreateAdd( lhs, rhs );
    }

    llvm::Value* apply_op(ast::binary_op< ast::sub > const&, llvm::Value* lhs, llvm::Value* rhs)
    {
        return builder_.CreateSub( lhs, rhs );
    }

    llvm::Value* apply_op(ast::binary_op< ast::mul > const&, llvm::Value* lhs, llvm::Value* rhs)
    {
        return builder_.CreateMul( lhs, rhs );
    }

    llvm::Value* apply_op(ast::binary_op< ast::div > const&, llvm::Value* lhs, llvm::Value* rhs)
    {
        return builder_.CreateSDiv( lhs, rhs );
    }
};

返り値をllvm::Value*に合わせて再帰的にllvm言語に変換していきます。

227行目、std::vector< llvm::Value* > argsはprintfに渡す引数をとなっており、前述のprintf_funcと合わせてbuilder.CreateCallするとllvm言語で表されたprintfの呼び出しコードが生成されます。

asemmbly asm_obj( builder );
for( auto const& i : asts ) {
    std::vector< llvm::Value* > args = {
        format, boost::apply_visitor( asm_obj, i )
    };
    builder.CreateCall( printf_func, llvm::ArrayRef< llvm::Value* >( args ) );
}

235行目でmake_ret_mainを呼び出してmain関数の戻り値を書き込みます。

237行目でstd::cout << print_llvm_lang( module );として生成したllvm言語のコードを標準出力しています。

出力

入力を次のようにすると、

2 * ( 1 + 2 )
2 * 1 + 2

下記のように出力されます。

; ModuleID = 'arith'

@0 = private unnamed_addr constant [4 x i8] c"%d\0A\00"

define i32 @main() {
entry:
  %0 = call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([4 x i8]* @0, i32 0, i32 0), i32 6)
  %1 = call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([4 x i8]* @0, i32 0, i32 0), i32 4)
  ret i32 0
}

declare i32 @printf(i8*, ...)

これをとりあえず動かすにはlliを使うといいでしょう。

参考文献

Boost -- Spirit 2.5.2
LLVM Tutorial: Table of Contents
LLVM API Documentation
redboltzの日記 -- qiを使って構文ツリーを構築
stackoverflow -- Retrieving AST from boost::spirit parser
letsboost::spirit
letsboost::variant