Compile comparison operations

This commit is contained in:
Konstantin Nazarov 2024-08-23 21:30:05 +01:00
parent 5948bfa973
commit ecbdc17f2b
Signed by: knazarov
GPG key ID: 4CFE0A42FA409C22
6 changed files with 139 additions and 23 deletions

View file

@ -134,6 +134,13 @@ Result<bool> is_primitive_op(Symbol& sym) {
return false;
}
Result<bool> is_comparison_op(Symbol& sym) {
return TRY(sym.cmp("<")) == 0 || TRY(sym.cmp("<=")) == 0 ||
TRY(sym.cmp(">")) == 0 || TRY(sym.cmp(">=")) == 0 ||
TRY(sym.cmp("=")) == 0 || TRY(sym.cmp("!=")) == 0;
return false;
}
Result<Value> Compiler::compile(Value& expr) {
auto context = TRY(Context::create());
@ -246,6 +253,80 @@ Result<Expression> Compiler::compile_primop(Context& context, Symbol& op,
return std::move(ex);
}
Result<Expression> Compiler::compile_comparison(Context& context, Symbol& op,
Pair& expr) {
Value cur = TRY(expr.rest());
Expression ex = TRY(Expression::create());
Oc opcode = Oc::Unknown;
int64_t cmp_expected = 0;
if (TRY(op.cmp("<")) == 0) {
opcode = Oc::Less;
cmp_expected = 0;
} else if (TRY(op.cmp("<=")) == 0) {
opcode = Oc::LessEqual;
cmp_expected = 0;
} else if (TRY(op.cmp(">")) == 0) {
opcode = Oc::LessEqual;
cmp_expected = 1;
} else if (TRY(op.cmp(">=")) == 0) {
opcode = Oc::Less;
cmp_expected = 1;
} else if (TRY(op.cmp("=")) == 0) {
opcode = Oc::Equal;
cmp_expected = 0;
} else if (TRY(op.cmp("!=")) == 0) {
opcode = Oc::Equal;
cmp_expected = 1;
} else {
return ERROR(NotImplemented);
}
if (cur.is<Nil>()) {
return ERROR(CompilationError);
}
uint64_t result = context.alloc_reg();
int64_t true_c = TRY(context.add_const(TRY(Bool::create(true))));
int64_t false_c = TRY(context.add_const(TRY(Bool::create(false))));
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)result}, {1, true_c}));
Pair& pair = *cur.to<Pair>();
auto subexpr = TRY(pair.first());
auto comp = TRY(compile_expr(context, subexpr));
ex.add_code(comp.code);
uint64_t firstreg = comp.reg;
uint64_t reg = firstreg;
cur = TRY(pair.rest());
while (!cur.is<Nil>()) {
Pair& pair = *cur.to<Pair>();
auto subexpr = TRY(pair.first());
auto comp = TRY(compile_expr(context, subexpr));
ex.add_code(comp.code);
auto rest = TRY(pair.rest());
TRY(ex.add_opcode(opcode, {0, (int64_t)reg}, {0, (int64_t)comp.reg},
{0, (int64_t)cmp_expected}));
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)result}, {1, false_c}));
reg = comp.reg;
cur = std::move(rest);
}
context.maxreg = result + 1;
ex.reg = result;
return std::move(ex);
}
Result<Expression> Compiler::compile_if(Context& context, Symbol& op,
Pair& expr) {
Value first = TRY(expr.rest());
@ -292,9 +373,10 @@ Result<Expression> Compiler::compile_if(Context& context, Symbol& op,
int64_t true_const = TRY(context.add_const(TRY(Bool::create(true))));
TRY(ex.add_opcode(Oc::JumpNotEqual, {0, (int64_t)firstreg},
{1, (int64_t)true_const},
{0, (int64_t)option1_comp.code.size() + 2}));
TRY(ex.add_opcode(Oc::Equal, {0, (int64_t)firstreg}, {1, (int64_t)true_const},
{0, (int64_t)0}));
TRY(ex.add_opcode(Oc::Jump, {0, (int64_t)option1_comp.code.size() + 2}));
ex.add_code(option1_comp.code);
TRY(ex.add_opcode(Oc::Jump, {0, (int64_t)option2_comp.code.size() + 1}));
@ -471,6 +553,8 @@ Result<Expression> Compiler::compile_list(Context& context, Pair& expr) {
Symbol& sym = *first.to<Symbol>();
if (TRY(is_primitive_op(sym))) {
return compile_primop(context, sym, expr);
} else if (TRY(is_comparison_op(sym))) {
return compile_comparison(context, sym, expr);
} else if (TRY(sym.cmp("if")) == 0) {
return compile_if(context, sym, expr);
} else if (TRY(sym.cmp("lambda")) == 0) {

View file

@ -32,6 +32,8 @@ class Compiler {
Result<Expression> compile_expr(Context& context, Value& expr);
Result<Expression> compile_list(Context& context, Pair& expr);
Result<Expression> compile_primop(Context& context, Symbol& op, Pair& expr);
Result<Expression> compile_comparison(Context& context, Symbol& op,
Pair& expr);
Result<Expression> compile_int64(Context& context, Int64& value);
Result<Expression> compile_symbol(Context& context, Symbol& value);
Result<Expression> compile_bool(Context& context, Bool& value);

View file

@ -30,15 +30,13 @@ op_t get_op(Oc op) {
case Oc::LoadStack:
return op_t{"lfi", OpcodeType::Reg1I};
case Oc::Jump:
return op_t{"jmp", OpcodeType::Reg0I};
case Oc::JumpEqual:
return op_t{"jeq", OpcodeType::Reg2I};
case Oc::JumpLess:
return op_t{"jlt", OpcodeType::Reg2I};
case Oc::JumpLessEqual:
return op_t{"jle", OpcodeType::Reg2I};
case Oc::JumpNotEqual:
return op_t{"jne", OpcodeType::Reg2I};
return op_t{"jump", OpcodeType::Reg0I};
case Oc::Equal:
return op_t{"equal", OpcodeType::Reg2I};
case Oc::Less:
return op_t{"less", OpcodeType::Reg2I};
case Oc::LessEqual:
return op_t{"less-equal", OpcodeType::Reg2I};
case Oc::Call:
return op_t{"call", OpcodeType::Reg2};
case Oc::TailCall:

View file

@ -28,10 +28,9 @@ enum class Oc : uint8_t {
LoadStack,
// Jumps
Jump,
JumpEqual,
JumpLess,
JumpLessEqual,
JumpNotEqual,
Equal,
Less,
LessEqual,
// Function calls
Call,
TailCall,

View file

@ -62,13 +62,38 @@ Result<void> VM::vm_div(Opcode& oc) {
return Result<void>();
}
Result<void> VM::vm_jump_not_equal(Opcode& oc) {
Result<void> VM::vm_equal(Opcode& oc) {
Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg));
Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg));
if (TRY(val1.cmp(val2)) != 0) {
_pc += oc.arg3().arg;
int64_t expected = oc.arg3().arg;
if ((TRY(val1.cmp(val2)) == 0) != expected) {
_pc += 2;
} else {
_pc++;
_pc += 1;
}
return Result<void>();
}
Result<void> VM::vm_less(Opcode& oc) {
Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg));
Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg));
int64_t expected = oc.arg3().arg;
if ((TRY(val1.cmp(val2)) == -1) != expected) {
_pc += 2;
} else {
_pc += 1;
}
return Result<void>();
}
Result<void> VM::vm_less_equal(Opcode& oc) {
Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg));
Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg));
int64_t expected = oc.arg3().arg;
if ((TRY(val1.cmp(val2)) <= 0) != expected) {
_pc += 2;
} else {
_pc += 1;
}
return Result<void>();
}
@ -191,8 +216,14 @@ Result<void> VM::step() {
case Oc::Ret:
TRY(vm_ret(oc));
break;
case Oc::JumpNotEqual:
TRY(vm_jump_not_equal(oc));
case Oc::Equal:
TRY(vm_equal(oc));
break;
case Oc::Less:
TRY(vm_less(oc));
break;
case Oc::LessEqual:
TRY(vm_less_equal(oc));
break;
case Oc::Jump:
TRY(vm_jump(oc));

View file

@ -28,7 +28,9 @@ class VM {
Result<void> vm_call(Opcode& oc);
Result<void> vm_ret(Opcode& oc);
Result<void> vm_jump_not_equal(Opcode& oc);
Result<void> vm_equal(Opcode& oc);
Result<void> vm_less(Opcode& oc);
Result<void> vm_less_equal(Opcode& oc);
Result<void> vm_jump(Opcode& oc);
Result<void> vm_make_closure(Opcode& oc);