From ecbdc17f2b48613ba94d0bb25d63bf754646bc58 Mon Sep 17 00:00:00 2001 From: Konstantin Nazarov Date: Fri, 23 Aug 2024 21:30:05 +0100 Subject: [PATCH] Compile comparison operations --- src/compiler.cpp | 90 ++++++++++++++++++++++++++++++++++++++++++++++-- src/compiler.hpp | 2 ++ src/opcode.cpp | 16 ++++----- src/opcode.hpp | 7 ++-- src/vm.cpp | 43 +++++++++++++++++++---- src/vm.hpp | 4 ++- 6 files changed, 139 insertions(+), 23 deletions(-) diff --git a/src/compiler.cpp b/src/compiler.cpp index 58c28ca..6495f96 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -134,6 +134,13 @@ Result is_primitive_op(Symbol& sym) { return false; } +Result is_comparison_op(Symbol& sym) { + return TRY(sym.cmp("<")) == 0 || TRY(sym.cmp("<=")) == 0 || + TRY(sym.cmp(">")) == 0 || TRY(sym.cmp(">=")) == 0 || + TRY(sym.cmp("=")) == 0 || TRY(sym.cmp("!=")) == 0; + return false; +} + Result Compiler::compile(Value& expr) { auto context = TRY(Context::create()); @@ -246,6 +253,80 @@ Result Compiler::compile_primop(Context& context, Symbol& op, return std::move(ex); } +Result Compiler::compile_comparison(Context& context, Symbol& op, + Pair& expr) { + Value cur = TRY(expr.rest()); + Expression ex = TRY(Expression::create()); + + Oc opcode = Oc::Unknown; + int64_t cmp_expected = 0; + + if (TRY(op.cmp("<")) == 0) { + opcode = Oc::Less; + cmp_expected = 0; + } else if (TRY(op.cmp("<=")) == 0) { + opcode = Oc::LessEqual; + cmp_expected = 0; + } else if (TRY(op.cmp(">")) == 0) { + opcode = Oc::LessEqual; + cmp_expected = 1; + } else if (TRY(op.cmp(">=")) == 0) { + opcode = Oc::Less; + cmp_expected = 1; + } else if (TRY(op.cmp("=")) == 0) { + opcode = Oc::Equal; + cmp_expected = 0; + } else if (TRY(op.cmp("!=")) == 0) { + opcode = Oc::Equal; + cmp_expected = 1; + } else { + return ERROR(NotImplemented); + } + + if (cur.is()) { + return ERROR(CompilationError); + } + + uint64_t result = context.alloc_reg(); + int64_t true_c = TRY(context.add_const(TRY(Bool::create(true)))); + int64_t false_c = TRY(context.add_const(TRY(Bool::create(false)))); + TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)result}, {1, true_c})); + + Pair& pair = *cur.to(); + auto subexpr = TRY(pair.first()); + + auto comp = TRY(compile_expr(context, subexpr)); + + ex.add_code(comp.code); + uint64_t firstreg = comp.reg; + uint64_t reg = firstreg; + + cur = TRY(pair.rest()); + + while (!cur.is()) { + Pair& pair = *cur.to(); + auto subexpr = TRY(pair.first()); + + auto comp = TRY(compile_expr(context, subexpr)); + + ex.add_code(comp.code); + + auto rest = TRY(pair.rest()); + + TRY(ex.add_opcode(opcode, {0, (int64_t)reg}, {0, (int64_t)comp.reg}, + {0, (int64_t)cmp_expected})); + + TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)result}, {1, false_c})); + + reg = comp.reg; + cur = std::move(rest); + } + + context.maxreg = result + 1; + ex.reg = result; + return std::move(ex); +} + Result Compiler::compile_if(Context& context, Symbol& op, Pair& expr) { Value first = TRY(expr.rest()); @@ -292,9 +373,10 @@ Result Compiler::compile_if(Context& context, Symbol& op, int64_t true_const = TRY(context.add_const(TRY(Bool::create(true)))); - TRY(ex.add_opcode(Oc::JumpNotEqual, {0, (int64_t)firstreg}, - {1, (int64_t)true_const}, - {0, (int64_t)option1_comp.code.size() + 2})); + TRY(ex.add_opcode(Oc::Equal, {0, (int64_t)firstreg}, {1, (int64_t)true_const}, + {0, (int64_t)0})); + + TRY(ex.add_opcode(Oc::Jump, {0, (int64_t)option1_comp.code.size() + 2})); ex.add_code(option1_comp.code); TRY(ex.add_opcode(Oc::Jump, {0, (int64_t)option2_comp.code.size() + 1})); @@ -471,6 +553,8 @@ Result Compiler::compile_list(Context& context, Pair& expr) { Symbol& sym = *first.to(); if (TRY(is_primitive_op(sym))) { return compile_primop(context, sym, expr); + } else if (TRY(is_comparison_op(sym))) { + return compile_comparison(context, sym, expr); } else if (TRY(sym.cmp("if")) == 0) { return compile_if(context, sym, expr); } else if (TRY(sym.cmp("lambda")) == 0) { diff --git a/src/compiler.hpp b/src/compiler.hpp index 5cfeb59..0ef0527 100644 --- a/src/compiler.hpp +++ b/src/compiler.hpp @@ -32,6 +32,8 @@ class Compiler { Result compile_expr(Context& context, Value& expr); Result compile_list(Context& context, Pair& expr); Result compile_primop(Context& context, Symbol& op, Pair& expr); + Result compile_comparison(Context& context, Symbol& op, + Pair& expr); Result compile_int64(Context& context, Int64& value); Result compile_symbol(Context& context, Symbol& value); Result compile_bool(Context& context, Bool& value); diff --git a/src/opcode.cpp b/src/opcode.cpp index 62289f3..1757c9b 100644 --- a/src/opcode.cpp +++ b/src/opcode.cpp @@ -30,15 +30,13 @@ op_t get_op(Oc op) { case Oc::LoadStack: return op_t{"lfi", OpcodeType::Reg1I}; case Oc::Jump: - return op_t{"jmp", OpcodeType::Reg0I}; - case Oc::JumpEqual: - return op_t{"jeq", OpcodeType::Reg2I}; - case Oc::JumpLess: - return op_t{"jlt", OpcodeType::Reg2I}; - case Oc::JumpLessEqual: - return op_t{"jle", OpcodeType::Reg2I}; - case Oc::JumpNotEqual: - return op_t{"jne", OpcodeType::Reg2I}; + return op_t{"jump", OpcodeType::Reg0I}; + case Oc::Equal: + return op_t{"equal", OpcodeType::Reg2I}; + case Oc::Less: + return op_t{"less", OpcodeType::Reg2I}; + case Oc::LessEqual: + return op_t{"less-equal", OpcodeType::Reg2I}; case Oc::Call: return op_t{"call", OpcodeType::Reg2}; case Oc::TailCall: diff --git a/src/opcode.hpp b/src/opcode.hpp index e5c6b9f..96a3c41 100644 --- a/src/opcode.hpp +++ b/src/opcode.hpp @@ -28,10 +28,9 @@ enum class Oc : uint8_t { LoadStack, // Jumps Jump, - JumpEqual, - JumpLess, - JumpLessEqual, - JumpNotEqual, + Equal, + Less, + LessEqual, // Function calls Call, TailCall, diff --git a/src/vm.cpp b/src/vm.cpp index 09166a4..3233285 100644 --- a/src/vm.cpp +++ b/src/vm.cpp @@ -62,13 +62,38 @@ Result VM::vm_div(Opcode& oc) { return Result(); } -Result VM::vm_jump_not_equal(Opcode& oc) { +Result VM::vm_equal(Opcode& oc) { Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg)); Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg)); - if (TRY(val1.cmp(val2)) != 0) { - _pc += oc.arg3().arg; + int64_t expected = oc.arg3().arg; + if ((TRY(val1.cmp(val2)) == 0) != expected) { + _pc += 2; } else { - _pc++; + _pc += 1; + } + return Result(); +} + +Result VM::vm_less(Opcode& oc) { + Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg)); + Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg)); + int64_t expected = oc.arg3().arg; + if ((TRY(val1.cmp(val2)) == -1) != expected) { + _pc += 2; + } else { + _pc += 1; + } + return Result(); +} + +Result VM::vm_less_equal(Opcode& oc) { + Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg)); + Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg)); + int64_t expected = oc.arg3().arg; + if ((TRY(val1.cmp(val2)) <= 0) != expected) { + _pc += 2; + } else { + _pc += 1; } return Result(); } @@ -191,8 +216,14 @@ Result VM::step() { case Oc::Ret: TRY(vm_ret(oc)); break; - case Oc::JumpNotEqual: - TRY(vm_jump_not_equal(oc)); + case Oc::Equal: + TRY(vm_equal(oc)); + break; + case Oc::Less: + TRY(vm_less(oc)); + break; + case Oc::LessEqual: + TRY(vm_less_equal(oc)); break; case Oc::Jump: TRY(vm_jump(oc)); diff --git a/src/vm.hpp b/src/vm.hpp index 1cd82f4..db17ea4 100644 --- a/src/vm.hpp +++ b/src/vm.hpp @@ -28,7 +28,9 @@ class VM { Result vm_call(Opcode& oc); Result vm_ret(Opcode& oc); - Result vm_jump_not_equal(Opcode& oc); + Result vm_equal(Opcode& oc); + Result vm_less(Opcode& oc); + Result vm_less_equal(Opcode& oc); Result vm_jump(Opcode& oc); Result vm_make_closure(Opcode& oc);