Implement simple conditionals

This commit is contained in:
Konstantin Nazarov 2024-08-15 01:18:05 +01:00
parent 3a117c9b1f
commit 81eb5e43ce
Signed by: knazarov
GPG key ID: 4CFE0A42FA409C22
7 changed files with 113 additions and 3 deletions

View file

@ -764,7 +764,10 @@ class Bool : public Object {
virtual Result<short> cmp(const Object& rhs) const final { virtual Result<short> cmp(const Object& rhs) const final {
return -TRY(rhs.cmp(*this)); return -TRY(rhs.cmp(*this));
} }
virtual Result<short> cmp(const Bool& rhs) const final { return 0; } virtual Result<short> cmp(const Bool& rhs) const final {
return (_value->value > rhs._value->value) -
(_value->value < rhs._value->value);
}
virtual void move(Object* obj) final { new (obj) Bool(std::move(_value)); } virtual void move(Object* obj) final { new (obj) Bool(std::move(_value)); }

View file

@ -4,6 +4,7 @@
struct Context { struct Context {
Context() {} Context() {}
Context(Value&& env, Array&& constants, Dict&& constants_dict) Context(Value&& env, Array&& constants, Dict&& constants_dict)
: env(std::move(env)), : env(std::move(env)),
constants(std::move(constants)), constants(std::move(constants)),
@ -80,8 +81,9 @@ Result<Expression> Compiler::compile_expr(Context& context, Value& expr) {
return compile_list(context, *expr.to<Pair>()); return compile_list(context, *expr.to<Pair>());
case Tag::Int64: case Tag::Int64:
return compile_int64(context, *expr.to<Int64>()); return compile_int64(context, *expr.to<Int64>());
case Tag::Nil:
case Tag::Bool: case Tag::Bool:
return compile_bool(context, *expr.to<Bool>());
case Tag::Nil:
case Tag::Float: case Tag::Float:
case Tag::String: case Tag::String:
case Tag::Symbol: case Tag::Symbol:
@ -163,6 +165,70 @@ Result<Expression> Compiler::compile_primop(Context& context, Symbol& op,
return std::move(ex); return std::move(ex);
} }
Result<Expression> Compiler::compile_if(Context& context, Symbol& op,
Pair& expr) {
Value first = TRY(expr.rest());
Expression ex = TRY(Expression::create());
if (first.is<Nil>() || !first.is<Pair>()) {
return ERROR(CompilationError);
}
Pair& first_pair = *first.to<Pair>();
Value second = TRY(first_pair.rest());
if (second.is<Nil>() || !second.is<Pair>()) {
return ERROR(CompilationError);
}
Pair& second_pair = *second.to<Pair>();
Value third = TRY(second_pair.rest());
if (third.is<Nil>() || !third.is<Pair>()) {
return ERROR(CompilationError);
}
Pair& third_pair = *third.to<Pair>();
auto condition = TRY(first_pair.first());
auto condition_comp = TRY(compile_expr(context, condition));
ex.add_code(condition_comp.code);
uint64_t firstreg = condition_comp.reg;
uint64_t reg = firstreg;
auto option1 = TRY(second_pair.first());
auto option1_comp = TRY(compile_expr(context, option1));
uint64_t option1_reg = option1_comp.reg;
context.maxreg = firstreg + 1;
auto option2 = TRY(third_pair.first());
auto option2_comp = TRY(compile_expr(context, option2));
int64_t true_const = TRY(context.add_const(TRY(Bool::create(true))));
TRY(ex.add_opcode(Oc::JumpNotEqual, {0, (int64_t)firstreg},
{1, (int64_t)true_const},
{0, (int64_t)option1_comp.code.size() + 2}));
ex.add_code(option1_comp.code);
TRY(ex.add_opcode(Oc::Jump, {0, (int64_t)option2_comp.code.size() + 1}));
ex.add_code(option2_comp.code);
uint64_t option2_reg = option2_comp.reg;
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)firstreg},
{0, (int64_t)option1_reg}));
context.maxreg = firstreg + 1;
ex.reg = firstreg;
return std::move(ex);
}
Result<Expression> Compiler::compile_list(Context& context, Pair& expr) { Result<Expression> Compiler::compile_list(Context& context, Pair& expr) {
auto first = TRY(expr.first()); auto first = TRY(expr.first());
@ -170,6 +236,8 @@ Result<Expression> Compiler::compile_list(Context& context, Pair& expr) {
Symbol& sym = *first.to<Symbol>(); Symbol& sym = *first.to<Symbol>();
if (TRY(is_primitive_op(sym))) { if (TRY(is_primitive_op(sym))) {
return compile_primop(context, sym, expr); return compile_primop(context, sym, expr);
} else if (TRY(sym.cmp("if")) == 0) {
return compile_if(context, sym, expr);
} }
} }
return ERROR(TypeMismatch); return ERROR(TypeMismatch);
@ -187,6 +255,18 @@ Result<Expression> Compiler::compile_int64(Context& context, Int64& value) {
return std::move(ex); return std::move(ex);
} }
Result<Expression> Compiler::compile_bool(Context& context, Bool& value) {
Expression ex = TRY(Expression::create());
uint64_t reg = context.alloc_reg();
int64_t c = TRY(context.add_const(TRY(value.copy_value())));
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)reg}, {1, (int64_t)c}));
ex.reg = reg;
return std::move(ex);
}
Result<Value> compile(Value& expr) { Result<Value> compile(Value& expr) {
Compiler c = Compiler(); Compiler c = Compiler();
return c.compile(expr); return c.compile(expr);

View file

@ -33,6 +33,8 @@ class Compiler {
Result<Expression> compile_list(Context& context, Pair& expr); Result<Expression> compile_list(Context& context, Pair& expr);
Result<Expression> compile_primop(Context& context, Symbol& op, Pair& expr); Result<Expression> compile_primop(Context& context, Symbol& op, Pair& expr);
Result<Expression> compile_int64(Context& context, Int64& value); Result<Expression> compile_int64(Context& context, Int64& value);
Result<Expression> compile_bool(Context& context, Bool& value);
Result<Expression> compile_if(Context& context, Symbol& op, Pair& expr);
}; };
Result<Value> compile(Value& expr); Result<Value> compile(Value& expr);

View file

@ -13,6 +13,7 @@ enum class ErrorCode {
MalformedList, MalformedList,
KeyError, KeyError,
EndOfProgram, EndOfProgram,
CompilationError,
}; };
void seterr(const char* err); void seterr(const char* err);

View file

@ -10,7 +10,8 @@
StaticArena<64 * 1024 * 1024> arena; StaticArena<64 * 1024 * 1024> arena;
Result<void> run() { Result<void> run() {
auto code_str = TRY(String::create("(* (+ 1 2 3) (/ 4 2))")); // auto code_str = TRY(String::create("(* (+ 1 2 3) (/ 4 2))"));
auto code_str = TRY(String::create("(if true 1 2)"));
auto reader = Reader(code_str); auto reader = Reader(code_str);
auto parsed = TRY(reader.read_one()); auto parsed = TRY(reader.read_one());

View file

@ -57,6 +57,20 @@ Result<void> VM::vm_div(Opcode& oc) {
return Result<void>(); return Result<void>();
} }
Result<void> VM::vm_jump_not_equal(Opcode& oc) {
Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg));
Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg));
if (TRY(val1.cmp(val2)) != 0) {
_pc += oc.arg3().arg - 1;
}
return Result<void>();
}
Result<void> VM::vm_jump(Opcode& oc) {
_pc += oc.arg1().arg - 1;
return Result<void>();
}
Result<void> VM::step() { Result<void> VM::step() {
auto opcode = TRY(_code.get(_pc)); auto opcode = TRY(_code.get(_pc));
if (!opcode.is<Opcode>()) return ERROR(TypeMismatch); if (!opcode.is<Opcode>()) return ERROR(TypeMismatch);
@ -81,6 +95,12 @@ Result<void> VM::step() {
case Oc::Ret: case Oc::Ret:
_res = TRY(getreg((uint64_t)oc.arg1().arg)); _res = TRY(getreg((uint64_t)oc.arg1().arg));
return ERROR(EndOfProgram); return ERROR(EndOfProgram);
case Oc::JumpNotEqual:
TRY(vm_jump_not_equal(oc));
break;
case Oc::Jump:
TRY(vm_jump(oc));
break;
default: default:
return ERROR(NotImplemented); return ERROR(NotImplemented);
} }

View file

@ -23,6 +23,9 @@ class VM {
Result<void> vm_sub(Opcode& oc); Result<void> vm_sub(Opcode& oc);
Result<void> vm_div(Opcode& oc); Result<void> vm_div(Opcode& oc);
Result<void> vm_jump_not_equal(Opcode& oc);
Result<void> vm_jump(Opcode& oc);
Result<Value> get(bool is_const, uint64_t idx); Result<Value> get(bool is_const, uint64_t idx);
Result<Value> getconst(uint64_t idx); Result<Value> getconst(uint64_t idx);
Result<Value> getreg(uint64_t idx); Result<Value> getreg(uint64_t idx);