From 81eb5e43cec1dff898e34f501cfa2a4e3bb5b61e Mon Sep 17 00:00:00 2001 From: Konstantin Nazarov Date: Thu, 15 Aug 2024 01:18:05 +0100 Subject: [PATCH] Implement simple conditionals --- src/common.hpp | 5 ++- src/compiler.cpp | 82 +++++++++++++++++++++++++++++++++++++++++++++++- src/compiler.hpp | 2 ++ src/error.hpp | 1 + src/vli.cpp | 3 +- src/vm.cpp | 20 ++++++++++++ src/vm.hpp | 3 ++ 7 files changed, 113 insertions(+), 3 deletions(-) diff --git a/src/common.hpp b/src/common.hpp index d116dda..1c300cf 100644 --- a/src/common.hpp +++ b/src/common.hpp @@ -764,7 +764,10 @@ class Bool : public Object { virtual Result cmp(const Object& rhs) const final { return -TRY(rhs.cmp(*this)); } - virtual Result cmp(const Bool& rhs) const final { return 0; } + virtual Result cmp(const Bool& rhs) const final { + return (_value->value > rhs._value->value) - + (_value->value < rhs._value->value); + } virtual void move(Object* obj) final { new (obj) Bool(std::move(_value)); } diff --git a/src/compiler.cpp b/src/compiler.cpp index 788ea41..e38a51c 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -4,6 +4,7 @@ struct Context { Context() {} + Context(Value&& env, Array&& constants, Dict&& constants_dict) : env(std::move(env)), constants(std::move(constants)), @@ -80,8 +81,9 @@ Result Compiler::compile_expr(Context& context, Value& expr) { return compile_list(context, *expr.to()); case Tag::Int64: return compile_int64(context, *expr.to()); - case Tag::Nil: case Tag::Bool: + return compile_bool(context, *expr.to()); + case Tag::Nil: case Tag::Float: case Tag::String: case Tag::Symbol: @@ -163,6 +165,70 @@ Result Compiler::compile_primop(Context& context, Symbol& op, return std::move(ex); } +Result Compiler::compile_if(Context& context, Symbol& op, + Pair& expr) { + Value first = TRY(expr.rest()); + Expression ex = TRY(Expression::create()); + if (first.is() || !first.is()) { + return ERROR(CompilationError); + } + + Pair& first_pair = *first.to(); + + Value second = TRY(first_pair.rest()); + + if (second.is() || !second.is()) { + return ERROR(CompilationError); + } + + Pair& second_pair = *second.to(); + + Value third = TRY(second_pair.rest()); + + if (third.is() || !third.is()) { + return ERROR(CompilationError); + } + + Pair& third_pair = *third.to(); + + auto condition = TRY(first_pair.first()); + + auto condition_comp = TRY(compile_expr(context, condition)); + + ex.add_code(condition_comp.code); + uint64_t firstreg = condition_comp.reg; + uint64_t reg = firstreg; + + auto option1 = TRY(second_pair.first()); + auto option1_comp = TRY(compile_expr(context, option1)); + + uint64_t option1_reg = option1_comp.reg; + + context.maxreg = firstreg + 1; + + auto option2 = TRY(third_pair.first()); + auto option2_comp = TRY(compile_expr(context, option2)); + + int64_t true_const = TRY(context.add_const(TRY(Bool::create(true)))); + + TRY(ex.add_opcode(Oc::JumpNotEqual, {0, (int64_t)firstreg}, + {1, (int64_t)true_const}, + {0, (int64_t)option1_comp.code.size() + 2})); + + ex.add_code(option1_comp.code); + TRY(ex.add_opcode(Oc::Jump, {0, (int64_t)option2_comp.code.size() + 1})); + ex.add_code(option2_comp.code); + + uint64_t option2_reg = option2_comp.reg; + + TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)firstreg}, + {0, (int64_t)option1_reg})); + + context.maxreg = firstreg + 1; + ex.reg = firstreg; + return std::move(ex); +} + Result Compiler::compile_list(Context& context, Pair& expr) { auto first = TRY(expr.first()); @@ -170,6 +236,8 @@ Result Compiler::compile_list(Context& context, Pair& expr) { Symbol& sym = *first.to(); if (TRY(is_primitive_op(sym))) { return compile_primop(context, sym, expr); + } else if (TRY(sym.cmp("if")) == 0) { + return compile_if(context, sym, expr); } } return ERROR(TypeMismatch); @@ -187,6 +255,18 @@ Result Compiler::compile_int64(Context& context, Int64& value) { return std::move(ex); } +Result Compiler::compile_bool(Context& context, Bool& value) { + Expression ex = TRY(Expression::create()); + uint64_t reg = context.alloc_reg(); + + int64_t c = TRY(context.add_const(TRY(value.copy_value()))); + + TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)reg}, {1, (int64_t)c})); + + ex.reg = reg; + return std::move(ex); +} + Result compile(Value& expr) { Compiler c = Compiler(); return c.compile(expr); diff --git a/src/compiler.hpp b/src/compiler.hpp index 2d0115a..d09d698 100644 --- a/src/compiler.hpp +++ b/src/compiler.hpp @@ -33,6 +33,8 @@ class Compiler { Result compile_list(Context& context, Pair& expr); Result compile_primop(Context& context, Symbol& op, Pair& expr); Result compile_int64(Context& context, Int64& value); + Result compile_bool(Context& context, Bool& value); + Result compile_if(Context& context, Symbol& op, Pair& expr); }; Result compile(Value& expr); diff --git a/src/error.hpp b/src/error.hpp index b8601b8..02f9fb3 100644 --- a/src/error.hpp +++ b/src/error.hpp @@ -13,6 +13,7 @@ enum class ErrorCode { MalformedList, KeyError, EndOfProgram, + CompilationError, }; void seterr(const char* err); diff --git a/src/vli.cpp b/src/vli.cpp index ae84e4d..c67310a 100644 --- a/src/vli.cpp +++ b/src/vli.cpp @@ -10,7 +10,8 @@ StaticArena<64 * 1024 * 1024> arena; Result run() { - auto code_str = TRY(String::create("(* (+ 1 2 3) (/ 4 2))")); + // auto code_str = TRY(String::create("(* (+ 1 2 3) (/ 4 2))")); + auto code_str = TRY(String::create("(if true 1 2)")); auto reader = Reader(code_str); auto parsed = TRY(reader.read_one()); diff --git a/src/vm.cpp b/src/vm.cpp index 29a6ecc..01424e5 100644 --- a/src/vm.cpp +++ b/src/vm.cpp @@ -57,6 +57,20 @@ Result VM::vm_div(Opcode& oc) { return Result(); } +Result VM::vm_jump_not_equal(Opcode& oc) { + Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg)); + Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg)); + if (TRY(val1.cmp(val2)) != 0) { + _pc += oc.arg3().arg - 1; + } + return Result(); +} + +Result VM::vm_jump(Opcode& oc) { + _pc += oc.arg1().arg - 1; + return Result(); +} + Result VM::step() { auto opcode = TRY(_code.get(_pc)); if (!opcode.is()) return ERROR(TypeMismatch); @@ -81,6 +95,12 @@ Result VM::step() { case Oc::Ret: _res = TRY(getreg((uint64_t)oc.arg1().arg)); return ERROR(EndOfProgram); + case Oc::JumpNotEqual: + TRY(vm_jump_not_equal(oc)); + break; + case Oc::Jump: + TRY(vm_jump(oc)); + break; default: return ERROR(NotImplemented); } diff --git a/src/vm.hpp b/src/vm.hpp index 6b66797..368a833 100644 --- a/src/vm.hpp +++ b/src/vm.hpp @@ -23,6 +23,9 @@ class VM { Result vm_sub(Opcode& oc); Result vm_div(Opcode& oc); + Result vm_jump_not_equal(Opcode& oc); + Result vm_jump(Opcode& oc); + Result get(bool is_const, uint64_t idx); Result getconst(uint64_t idx); Result getreg(uint64_t idx);