Implement simple conditionals

This commit is contained in:
Konstantin Nazarov 2024-08-15 01:18:05 +01:00
parent 3a117c9b1f
commit 81eb5e43ce
Signed by: knazarov
GPG key ID: 4CFE0A42FA409C22
7 changed files with 113 additions and 3 deletions

View file

@ -764,7 +764,10 @@ class Bool : public Object {
virtual Result<short> cmp(const Object& rhs) const final {
return -TRY(rhs.cmp(*this));
}
virtual Result<short> cmp(const Bool& rhs) const final { return 0; }
virtual Result<short> cmp(const Bool& rhs) const final {
return (_value->value > rhs._value->value) -
(_value->value < rhs._value->value);
}
virtual void move(Object* obj) final { new (obj) Bool(std::move(_value)); }

View file

@ -4,6 +4,7 @@
struct Context {
Context() {}
Context(Value&& env, Array&& constants, Dict&& constants_dict)
: env(std::move(env)),
constants(std::move(constants)),
@ -80,8 +81,9 @@ Result<Expression> Compiler::compile_expr(Context& context, Value& expr) {
return compile_list(context, *expr.to<Pair>());
case Tag::Int64:
return compile_int64(context, *expr.to<Int64>());
case Tag::Nil:
case Tag::Bool:
return compile_bool(context, *expr.to<Bool>());
case Tag::Nil:
case Tag::Float:
case Tag::String:
case Tag::Symbol:
@ -163,6 +165,70 @@ Result<Expression> Compiler::compile_primop(Context& context, Symbol& op,
return std::move(ex);
}
Result<Expression> Compiler::compile_if(Context& context, Symbol& op,
Pair& expr) {
Value first = TRY(expr.rest());
Expression ex = TRY(Expression::create());
if (first.is<Nil>() || !first.is<Pair>()) {
return ERROR(CompilationError);
}
Pair& first_pair = *first.to<Pair>();
Value second = TRY(first_pair.rest());
if (second.is<Nil>() || !second.is<Pair>()) {
return ERROR(CompilationError);
}
Pair& second_pair = *second.to<Pair>();
Value third = TRY(second_pair.rest());
if (third.is<Nil>() || !third.is<Pair>()) {
return ERROR(CompilationError);
}
Pair& third_pair = *third.to<Pair>();
auto condition = TRY(first_pair.first());
auto condition_comp = TRY(compile_expr(context, condition));
ex.add_code(condition_comp.code);
uint64_t firstreg = condition_comp.reg;
uint64_t reg = firstreg;
auto option1 = TRY(second_pair.first());
auto option1_comp = TRY(compile_expr(context, option1));
uint64_t option1_reg = option1_comp.reg;
context.maxreg = firstreg + 1;
auto option2 = TRY(third_pair.first());
auto option2_comp = TRY(compile_expr(context, option2));
int64_t true_const = TRY(context.add_const(TRY(Bool::create(true))));
TRY(ex.add_opcode(Oc::JumpNotEqual, {0, (int64_t)firstreg},
{1, (int64_t)true_const},
{0, (int64_t)option1_comp.code.size() + 2}));
ex.add_code(option1_comp.code);
TRY(ex.add_opcode(Oc::Jump, {0, (int64_t)option2_comp.code.size() + 1}));
ex.add_code(option2_comp.code);
uint64_t option2_reg = option2_comp.reg;
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)firstreg},
{0, (int64_t)option1_reg}));
context.maxreg = firstreg + 1;
ex.reg = firstreg;
return std::move(ex);
}
Result<Expression> Compiler::compile_list(Context& context, Pair& expr) {
auto first = TRY(expr.first());
@ -170,6 +236,8 @@ Result<Expression> Compiler::compile_list(Context& context, Pair& expr) {
Symbol& sym = *first.to<Symbol>();
if (TRY(is_primitive_op(sym))) {
return compile_primop(context, sym, expr);
} else if (TRY(sym.cmp("if")) == 0) {
return compile_if(context, sym, expr);
}
}
return ERROR(TypeMismatch);
@ -187,6 +255,18 @@ Result<Expression> Compiler::compile_int64(Context& context, Int64& value) {
return std::move(ex);
}
Result<Expression> Compiler::compile_bool(Context& context, Bool& value) {
Expression ex = TRY(Expression::create());
uint64_t reg = context.alloc_reg();
int64_t c = TRY(context.add_const(TRY(value.copy_value())));
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)reg}, {1, (int64_t)c}));
ex.reg = reg;
return std::move(ex);
}
Result<Value> compile(Value& expr) {
Compiler c = Compiler();
return c.compile(expr);

View file

@ -33,6 +33,8 @@ class Compiler {
Result<Expression> compile_list(Context& context, Pair& expr);
Result<Expression> compile_primop(Context& context, Symbol& op, Pair& expr);
Result<Expression> compile_int64(Context& context, Int64& value);
Result<Expression> compile_bool(Context& context, Bool& value);
Result<Expression> compile_if(Context& context, Symbol& op, Pair& expr);
};
Result<Value> compile(Value& expr);

View file

@ -13,6 +13,7 @@ enum class ErrorCode {
MalformedList,
KeyError,
EndOfProgram,
CompilationError,
};
void seterr(const char* err);

View file

@ -10,7 +10,8 @@
StaticArena<64 * 1024 * 1024> arena;
Result<void> run() {
auto code_str = TRY(String::create("(* (+ 1 2 3) (/ 4 2))"));
// auto code_str = TRY(String::create("(* (+ 1 2 3) (/ 4 2))"));
auto code_str = TRY(String::create("(if true 1 2)"));
auto reader = Reader(code_str);
auto parsed = TRY(reader.read_one());

View file

@ -57,6 +57,20 @@ Result<void> VM::vm_div(Opcode& oc) {
return Result<void>();
}
Result<void> VM::vm_jump_not_equal(Opcode& oc) {
Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg));
Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg));
if (TRY(val1.cmp(val2)) != 0) {
_pc += oc.arg3().arg - 1;
}
return Result<void>();
}
Result<void> VM::vm_jump(Opcode& oc) {
_pc += oc.arg1().arg - 1;
return Result<void>();
}
Result<void> VM::step() {
auto opcode = TRY(_code.get(_pc));
if (!opcode.is<Opcode>()) return ERROR(TypeMismatch);
@ -81,6 +95,12 @@ Result<void> VM::step() {
case Oc::Ret:
_res = TRY(getreg((uint64_t)oc.arg1().arg));
return ERROR(EndOfProgram);
case Oc::JumpNotEqual:
TRY(vm_jump_not_equal(oc));
break;
case Oc::Jump:
TRY(vm_jump(oc));
break;
default:
return ERROR(NotImplemented);
}

View file

@ -23,6 +23,9 @@ class VM {
Result<void> vm_sub(Opcode& oc);
Result<void> vm_div(Opcode& oc);
Result<void> vm_jump_not_equal(Opcode& oc);
Result<void> vm_jump(Opcode& oc);
Result<Value> get(bool is_const, uint64_t idx);
Result<Value> getconst(uint64_t idx);
Result<Value> getreg(uint64_t idx);