From 82b75e14486f4cc35528d1e70b1a597b77c84f09 Mon Sep 17 00:00:00 2001 From: Konstantin Nazarov Date: Thu, 29 Aug 2024 23:14:33 +0100 Subject: [PATCH] Simplify compiler a bit by removing unneeded casts --- CMakeLists.txt | 1 + src/common.hpp | 6 +- src/compiler.cpp | 170 ++++++++++++++++++---------------------------- src/compiler.hpp | 36 ++++++---- test/function.vli | 13 ++++ 5 files changed, 104 insertions(+), 122 deletions(-) create mode 100644 test/function.vli diff --git a/CMakeLists.txt b/CMakeLists.txt index 5cb717f..bb2408e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,6 +64,7 @@ set(LISP_TESTS numeric logic dict + function ) diff --git a/src/common.hpp b/src/common.hpp index e021eb9..b956ab6 100644 --- a/src/common.hpp +++ b/src/common.hpp @@ -1142,9 +1142,9 @@ class Value { Value k = TRY(Int64::create(key)); return ((Object*)buf)->get(k); } - Result first() { return ((Object*)buf)->first(); } - Result second() { return ((Object*)buf)->second(); } - Result rest() { return ((Object*)buf)->rest(); } + Result first() const { return ((Object*)buf)->first(); } + Result second() const { return ((Object*)buf)->second(); } + Result rest() const { return ((Object*)buf)->rest(); } // TODO: cmp() probably doesn't need arena parameter // Result operator==(Value& rhs) { return TRY(cmp(rhs)) == 0; } diff --git a/src/compiler.cpp b/src/compiler.cpp index 95255d0..4f04abd 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -190,9 +190,7 @@ Result Compiler::compile(Value& expr) { return ERROR(CompilationError); } - Pair& expr_pair = *expr.to(); - - auto ex = TRY(compile_body(context, expr_pair)); + auto ex = TRY(compile_body(context, expr)); TRY(ex.add_opcode(Oc::Ret, {0, (int64_t)ex.reg})); Value name = TRY(Nil::create()); @@ -207,16 +205,16 @@ Result Compiler::compile(Value& expr) { return Value(std::move(mod)); } -Result Compiler::compile_expr(Context& context, Value& expr) { +Result Compiler::compile_expr(Context& context, const Value& expr) { switch (expr.tag()) { case Tag::Pair: - return compile_list(context, *expr.to()); + return compile_list(context, expr); case Tag::Int64: return compile_constant(context, expr); case Tag::Bool: - return compile_bool(context, *expr.to()); + return compile_constant(context, expr); case Tag::Symbol: - return compile_symbol(context, *expr.to()); + return compile_symbol(context, expr); case Tag::String: return compile_constant(context, expr); case Tag::Nil: @@ -238,7 +236,7 @@ Result Compiler::compile_expr(Context& context, Value& expr) { } Result Compiler::compile_primop(Context& context, Symbol& op, - Pair& expr) { + const Value& expr) { Value cur = TRY(expr.rest()); Expression ex = TRY(Expression::create()); @@ -264,8 +262,7 @@ Result Compiler::compile_primop(Context& context, Symbol& op, return ex; } - Pair& pair = *cur.to(); - auto subexpr = TRY(pair.first()); + auto subexpr = TRY(cur.first()); auto comp = TRY(compile_expr(context, subexpr)); @@ -273,17 +270,16 @@ Result Compiler::compile_primop(Context& context, Symbol& op, uint64_t firstreg = comp.reg; uint64_t reg = firstreg; - cur = TRY(pair.rest()); + cur = TRY(cur.rest()); while (!cur.is()) { - Pair& pair = *cur.to(); - auto subexpr = TRY(pair.first()); + auto subexpr = TRY(cur.first()); auto comp = TRY(compile_expr(context, subexpr)); ex.add_code(comp.code); - auto rest = TRY(pair.rest()); + auto rest = TRY(cur.rest()); uint64_t res = 0; if (rest.is()) @@ -304,7 +300,7 @@ Result Compiler::compile_primop(Context& context, Symbol& op, } Result Compiler::compile_comparison(Context& context, Symbol& op, - Pair& expr) { + const Value& expr) { Value cur = TRY(expr.rest()); Expression ex = TRY(Expression::create()); @@ -342,8 +338,7 @@ Result Compiler::compile_comparison(Context& context, Symbol& op, int64_t false_c = TRY(context.add_const(TRY(Bool::create(false)))); TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)result}, {1, true_c})); - Pair& pair = *cur.to(); - auto subexpr = TRY(pair.first()); + auto subexpr = TRY(cur.first()); auto comp = TRY(compile_expr(context, subexpr)); @@ -351,17 +346,16 @@ Result Compiler::compile_comparison(Context& context, Symbol& op, uint64_t firstreg = comp.reg; uint64_t reg = firstreg; - cur = TRY(pair.rest()); + cur = TRY(cur.rest()); while (!cur.is()) { - Pair& pair = *cur.to(); - auto subexpr = TRY(pair.first()); + auto subexpr = TRY(cur.first()); auto comp = TRY(compile_expr(context, subexpr)); ex.add_code(comp.code); - auto rest = TRY(pair.rest()); + auto rest = TRY(cur.rest()); TRY(ex.add_opcode(opcode, {0, (int64_t)reg}, {0, (int64_t)comp.reg}, {0, (int64_t)cmp_expected})); @@ -378,32 +372,26 @@ Result Compiler::compile_comparison(Context& context, Symbol& op, } Result Compiler::compile_if(Context& context, Symbol& op, - Pair& expr) { + const Value& expr) { Value first = TRY(expr.rest()); Expression ex = TRY(Expression::create()); - if (first.is() || !first.is()) { + if (!first.is()) { return ERROR(CompilationError); } - Pair& first_pair = *first.to(); + Value second = TRY(first.rest()); - Value second = TRY(first_pair.rest()); - - if (second.is() || !second.is()) { + if (!second.is()) { return ERROR(CompilationError); } - Pair& second_pair = *second.to(); + Value third = TRY(second.rest()); - Value third = TRY(second_pair.rest()); - - if (third.is() || !third.is()) { + if (!third.is()) { return ERROR(CompilationError); } - Pair& third_pair = *third.to(); - - auto condition = TRY(first_pair.first()); + auto condition = TRY(first.first()); auto condition_comp = TRY(compile_expr(context, condition)); @@ -411,14 +399,14 @@ Result Compiler::compile_if(Context& context, Symbol& op, uint64_t firstreg = condition_comp.reg; uint64_t reg = firstreg; - auto option1 = TRY(second_pair.first()); + auto option1 = TRY(second.first()); auto option1_comp = TRY(compile_expr(context, option1)); uint64_t option1_reg = option1_comp.reg; context.maxreg = firstreg + 1; - auto option2 = TRY(third_pair.first()); + auto option2 = TRY(third.first()); auto option2_comp = TRY(compile_expr(context, option2)); int64_t true_const = TRY(context.add_const(TRY(Bool::create(true)))); @@ -443,7 +431,7 @@ Result Compiler::compile_if(Context& context, Symbol& op, } Result Compiler::compile_and(Context& context, Symbol& op, - Pair& expr) { + const Value& expr) { Value param = TRY(expr.rest()); Expression ex = TRY(Expression::create()); if (param.is() || !param.is()) { @@ -463,9 +451,8 @@ Result Compiler::compile_and(Context& context, Symbol& op, return ERROR(CompilationError); } - Pair& param_pair = *param.to(); - auto rest = TRY(param_pair.rest()); - Value param_val = TRY(param_pair.first()); + auto rest = TRY(param.rest()); + Value param_val = TRY(param.first()); auto param_ex = TRY(compile_expr(context, param_val)); @@ -491,7 +478,7 @@ Result Compiler::compile_and(Context& context, Symbol& op, } Result Compiler::compile_or(Context& context, Symbol& op, - Pair& expr) { + const Value& expr) { Value param = TRY(expr.rest()); Expression ex = TRY(Expression::create()); if (param.is() || !param.is()) { @@ -511,9 +498,8 @@ Result Compiler::compile_or(Context& context, Symbol& op, return ERROR(CompilationError); } - Pair& param_pair = *param.to(); - auto rest = TRY(param_pair.rest()); - Value param_val = TRY(param_pair.first()); + auto rest = TRY(param.rest()); + Value param_val = TRY(param.first()); auto param_ex = TRY(compile_expr(context, param_val)); @@ -539,16 +525,14 @@ Result Compiler::compile_or(Context& context, Symbol& op, } Result Compiler::compile_not(Context& context, Symbol& op, - Pair& expr) { + const Value& expr) { Value first = TRY(expr.rest()); Expression ex = TRY(Expression::create()); if (first.is() || !first.is()) { return ERROR(CompilationError); } - Pair& first_pair = *first.to(); - - auto first_expr = TRY(first_pair.first()); + auto first_expr = TRY(first.first()); uint64_t result = context.alloc_reg(); int64_t true_c = TRY(context.add_const(TRY(Bool::create(true)))); @@ -568,64 +552,56 @@ Result Compiler::compile_not(Context& context, Symbol& op, } Result Compiler::compile_fn(Context& context, Symbol& op, - Pair& expr) { + const Value& expr) { Context ctx = TRY(Context::create(context)); ctx.maxreg = 1; // Reserve the slot for function itself auto first = TRY(expr.rest()); - if (first.is() || !first.is()) { + if (!first.is()) { return ERROR(CompilationError); } - Pair& first_pair = *first.to(); - Value name = TRY(Nil::create()); - auto maybe_name = TRY(first_pair.first()); + auto maybe_name = TRY(first.first()); if (maybe_name.is()) { - name = TRY(maybe_name.to()->copy_value()); + name = TRY(maybe_name.copy()); - first = TRY(first_pair.rest()); + first = TRY(first.rest()); if (first.is() || !first.is()) { return ERROR(CompilationError); } - - Pair& first_pair = *first.to(); } - auto param = TRY(first_pair.first()); + auto param = TRY(first.first()); uint64_t arity = 0; while (!param.is()) { if (!param.is()) { return ERROR(CompilationError); } - Pair& param_pair = *param.to(); + auto param_name = TRY(param.first()); - auto param_first = TRY(param_pair.first()); - - if (!param_first.is()) { + if (!param_name.is()) { return ERROR(CompilationError); } - int64_t reg = TRY(ctx.add_var(param_first)); + int64_t reg = TRY(ctx.add_var(param_name)); - param = TRY(param_pair.rest()); + param = TRY(param.rest()); arity++; } - Value second = TRY(first_pair.rest()); + Value second = TRY(first.rest()); if (second.is() || !second.is()) { return ERROR(CompilationError); } - Pair& second_pair = *second.to(); - - auto ex = TRY(compile_body(ctx, second_pair)); + auto ex = TRY(compile_body(ctx, second)); TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)0}, {0, (int64_t)ex.reg})); TRY(ex.add_opcode(Oc::Ret, {0, (int64_t)0})); @@ -642,10 +618,10 @@ Result Compiler::compile_fn(Context& context, Symbol& op, if (ctx.closures.size() == 0) { if (context.toplevel && !name.is()) { - context.add_global(name, TRY(fun.copy_value())); + context.add_global(name, TRY(fun.copy())); } - int64_t c = TRY(context.add_const(TRY(fun.copy_value()))); + int64_t c = TRY(context.add_const(TRY(fun.copy()))); uint64_t reg = context.alloc_reg(); TRY(ex_res.add_opcode(Oc::Mov, {0, (int64_t)reg}, {1, (int64_t)c})); @@ -656,7 +632,7 @@ Result Compiler::compile_fn(Context& context, Symbol& op, return ex_res; } - int64_t c = TRY(context.add_const(TRY(fun.copy_value()))); + int64_t c = TRY(context.add_const(TRY(fun.copy()))); uint64_t reg = context.alloc_reg(); TRY(ex_res.add_opcode(Oc::Mov, {0, (int64_t)reg}, {1, (int64_t)c})); @@ -678,7 +654,7 @@ Result Compiler::compile_fn(Context& context, Symbol& op, } Result Compiler::compile_let(Context& context, Symbol& op, - Pair& expr) { + const Value& expr) { // Save the variable bindings to restore it later Dict saved_vars = TRY(context.variables_dict.copy()); @@ -718,13 +694,11 @@ Result Compiler::compile_let(Context& context, Symbol& op, Value second = TRY(first.rest()); - if (second.is() || !second.is()) { + if (!second.is()) { return ERROR(CompilationError); } - Pair& second_pair = *second.to(); - - auto ex = TRY(compile_body(context, second_pair)); + auto ex = TRY(compile_body(context, second)); TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)maxreg}, {0, (int64_t)ex.reg})); ex_res.add_code(ex.code); @@ -736,8 +710,8 @@ Result Compiler::compile_let(Context& context, Symbol& op, return std::move(ex_res); } -Result Compiler::compile_body(Context& context, Pair& expr) { - auto cur = TRY(expr.copy_value()); +Result Compiler::compile_body(Context& context, const Value& expr) { + auto cur = TRY(expr.copy()); Expression ex_res = TRY(Expression::create()); @@ -747,8 +721,7 @@ Result Compiler::compile_body(Context& context, Pair& expr) { if (!cur.is()) { return ERROR(CompilationError); } - Pair& cur_pair = *cur.to(); - auto expr_val = TRY(cur_pair.first()); + auto expr_val = TRY(cur.first()); // debug_print(expr_val); @@ -756,7 +729,7 @@ Result Compiler::compile_body(Context& context, Pair& expr) { TRY(ex_res.add_code(expr.code)); - cur = TRY(cur_pair.rest()); + cur = TRY(cur.rest()); if (cur.is()) { ex_res.reg = expr.reg; @@ -769,7 +742,7 @@ Result Compiler::compile_body(Context& context, Pair& expr) { } Result Compiler::compile_function_call(Context& context, - Pair& expr) { + const Value& expr) { auto ex = TRY(Expression::create()); auto first = TRY(expr.first()); @@ -783,13 +756,12 @@ Result Compiler::compile_function_call(Context& context, if (!param.is()) { return ERROR(CompilationError); } - Pair& param_pair = *param.to(); - Value param_val = TRY(param_pair.first()); + Value param_val = TRY(param.first()); auto param_ex = TRY(compile_expr(context, param_val)); TRY(ex.add_code(param_ex.code)); - param = TRY(param_pair.rest()); + param = TRY(param.rest()); } TRY(ex.add_opcode(Oc::Call, {0, (int64_t)fun_ex.reg}, @@ -800,7 +772,7 @@ Result Compiler::compile_function_call(Context& context, return ex; } -Result Compiler::compile_list(Context& context, Pair& expr) { +Result Compiler::compile_list(Context& context, const Value& expr) { auto first = TRY(expr.first()); if (first.is()) { @@ -831,7 +803,8 @@ Result Compiler::compile_list(Context& context, Pair& expr) { return ERROR(TypeMismatch); } -Result Compiler::compile_constant(Context& context, Value& value) { +Result Compiler::compile_constant(Context& context, + const Value& value) { Expression ex = TRY(Expression::create()); uint64_t reg = context.alloc_reg(); @@ -843,10 +816,11 @@ Result Compiler::compile_constant(Context& context, Value& value) { return std::move(ex); } -Result Compiler::compile_symbol(Context& context, Symbol& value) { +Result Compiler::compile_symbol(Context& context, + const Value& value) { Expression ex = TRY(Expression::create()); - auto maybe_reg = context.get_var(TRY(value.copy_value())); + auto maybe_reg = context.get_var(TRY(value.copy())); if (!maybe_reg.has_error()) { auto var_reg = maybe_reg.value(); @@ -858,7 +832,7 @@ Result Compiler::compile_symbol(Context& context, Symbol& value) { return std::move(ex); } - auto maybe_closure = context.get_closure(TRY(value.copy_value())); + auto maybe_closure = context.get_closure(TRY(value.copy())); if (!maybe_closure.has_error()) { auto var_closure = maybe_closure.value(); @@ -871,7 +845,7 @@ Result Compiler::compile_symbol(Context& context, Symbol& value) { return std::move(ex); } - auto maybe_stdlib_fun = get_stdlib_function(value); + auto maybe_stdlib_fun = get_stdlib_function(*value.to()); if (!maybe_stdlib_fun.has_error()) { auto stdlib_fun = TRY(StdlibFunction::create(maybe_stdlib_fun.value())); @@ -887,7 +861,7 @@ Result Compiler::compile_symbol(Context& context, Symbol& value) { // Otherwise treat unknown symbol as a global and try to load it from the // global scope - int64_t c = TRY(context.add_const(TRY(value.copy_value()))); + int64_t c = TRY(context.add_const(TRY(value.copy()))); uint64_t reg = context.alloc_reg(); TRY(ex.add_opcode(Oc::GlobalLoad, {0, (int64_t)reg}, {1, (int64_t)c})); @@ -896,18 +870,6 @@ Result Compiler::compile_symbol(Context& context, Symbol& value) { return std::move(ex); } -Result Compiler::compile_bool(Context& context, Bool& value) { - Expression ex = TRY(Expression::create()); - uint64_t reg = context.alloc_reg(); - - int64_t c = TRY(context.add_const(TRY(value.copy_value()))); - - TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)reg}, {1, (int64_t)c})); - - ex.reg = reg; - return std::move(ex); -} - Result compile(Value& expr) { Compiler c = Compiler(); return c.compile(expr); diff --git a/src/compiler.hpp b/src/compiler.hpp index 9069db0..bf2dc42 100644 --- a/src/compiler.hpp +++ b/src/compiler.hpp @@ -29,22 +29,28 @@ class Compiler { Compiler() {} Result compile(Value& expr); - Result compile_expr(Context& context, Value& expr); - Result compile_list(Context& context, Pair& expr); - Result compile_primop(Context& context, Symbol& op, Pair& expr); + Result compile_expr(Context& context, const Value& expr); + Result compile_list(Context& context, const Value& expr); + Result compile_primop(Context& context, Symbol& op, + const Value& expr); Result compile_comparison(Context& context, Symbol& op, - Pair& expr); - Result compile_constant(Context& context, Value& value); - Result compile_symbol(Context& context, Symbol& value); - Result compile_bool(Context& context, Bool& value); - Result compile_if(Context& context, Symbol& op, Pair& expr); - Result compile_and(Context& context, Symbol& op, Pair& expr); - Result compile_or(Context& context, Symbol& op, Pair& expr); - Result compile_not(Context& context, Symbol& op, Pair& expr); - Result compile_fn(Context& context, Symbol& op, Pair& expr); - Result compile_let(Context& context, Symbol& op, Pair& expr); - Result compile_body(Context& context, Pair& expr); - Result compile_function_call(Context& context, Pair& expr); + const Value& expr); + Result compile_constant(Context& context, const Value& value); + Result compile_symbol(Context& context, const Value& value); + Result compile_if(Context& context, Symbol& op, + const Value& expr); + Result compile_and(Context& context, Symbol& op, + const Value& expr); + Result compile_or(Context& context, Symbol& op, + const Value& expr); + Result compile_not(Context& context, Symbol& op, + const Value& expr); + Result compile_fn(Context& context, Symbol& op, + const Value& expr); + Result compile_let(Context& context, Symbol& op, + const Value& expr); + Result compile_body(Context& context, const Value& expr); + Result compile_function_call(Context& context, const Value& expr); }; Result compile(Value& expr); diff --git a/test/function.vli b/test/function.vli new file mode 100644 index 0000000..bc0971a --- /dev/null +++ b/test/function.vli @@ -0,0 +1,13 @@ +;; -*- mode: lisp; -*- + +(fn fact (n) + (if (<= n 0) + 1 + (* n (fact (- n 1))))) + +(assert (= (fact 12) 479001600)) + + +(let ((square (fn (x) (* x x)))) + (assert (= (square 4) 16)) + )