Compile comparison operations

This commit is contained in:
Konstantin Nazarov 2024-08-23 21:30:05 +01:00
parent 5948bfa973
commit ecbdc17f2b
Signed by: knazarov
GPG key ID: 4CFE0A42FA409C22
6 changed files with 139 additions and 23 deletions

View file

@ -134,6 +134,13 @@ Result<bool> is_primitive_op(Symbol& sym) {
return false; return false;
} }
Result<bool> is_comparison_op(Symbol& sym) {
return TRY(sym.cmp("<")) == 0 || TRY(sym.cmp("<=")) == 0 ||
TRY(sym.cmp(">")) == 0 || TRY(sym.cmp(">=")) == 0 ||
TRY(sym.cmp("=")) == 0 || TRY(sym.cmp("!=")) == 0;
return false;
}
Result<Value> Compiler::compile(Value& expr) { Result<Value> Compiler::compile(Value& expr) {
auto context = TRY(Context::create()); auto context = TRY(Context::create());
@ -246,6 +253,80 @@ Result<Expression> Compiler::compile_primop(Context& context, Symbol& op,
return std::move(ex); return std::move(ex);
} }
Result<Expression> Compiler::compile_comparison(Context& context, Symbol& op,
Pair& expr) {
Value cur = TRY(expr.rest());
Expression ex = TRY(Expression::create());
Oc opcode = Oc::Unknown;
int64_t cmp_expected = 0;
if (TRY(op.cmp("<")) == 0) {
opcode = Oc::Less;
cmp_expected = 0;
} else if (TRY(op.cmp("<=")) == 0) {
opcode = Oc::LessEqual;
cmp_expected = 0;
} else if (TRY(op.cmp(">")) == 0) {
opcode = Oc::LessEqual;
cmp_expected = 1;
} else if (TRY(op.cmp(">=")) == 0) {
opcode = Oc::Less;
cmp_expected = 1;
} else if (TRY(op.cmp("=")) == 0) {
opcode = Oc::Equal;
cmp_expected = 0;
} else if (TRY(op.cmp("!=")) == 0) {
opcode = Oc::Equal;
cmp_expected = 1;
} else {
return ERROR(NotImplemented);
}
if (cur.is<Nil>()) {
return ERROR(CompilationError);
}
uint64_t result = context.alloc_reg();
int64_t true_c = TRY(context.add_const(TRY(Bool::create(true))));
int64_t false_c = TRY(context.add_const(TRY(Bool::create(false))));
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)result}, {1, true_c}));
Pair& pair = *cur.to<Pair>();
auto subexpr = TRY(pair.first());
auto comp = TRY(compile_expr(context, subexpr));
ex.add_code(comp.code);
uint64_t firstreg = comp.reg;
uint64_t reg = firstreg;
cur = TRY(pair.rest());
while (!cur.is<Nil>()) {
Pair& pair = *cur.to<Pair>();
auto subexpr = TRY(pair.first());
auto comp = TRY(compile_expr(context, subexpr));
ex.add_code(comp.code);
auto rest = TRY(pair.rest());
TRY(ex.add_opcode(opcode, {0, (int64_t)reg}, {0, (int64_t)comp.reg},
{0, (int64_t)cmp_expected}));
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)result}, {1, false_c}));
reg = comp.reg;
cur = std::move(rest);
}
context.maxreg = result + 1;
ex.reg = result;
return std::move(ex);
}
Result<Expression> Compiler::compile_if(Context& context, Symbol& op, Result<Expression> Compiler::compile_if(Context& context, Symbol& op,
Pair& expr) { Pair& expr) {
Value first = TRY(expr.rest()); Value first = TRY(expr.rest());
@ -292,9 +373,10 @@ Result<Expression> Compiler::compile_if(Context& context, Symbol& op,
int64_t true_const = TRY(context.add_const(TRY(Bool::create(true)))); int64_t true_const = TRY(context.add_const(TRY(Bool::create(true))));
TRY(ex.add_opcode(Oc::JumpNotEqual, {0, (int64_t)firstreg}, TRY(ex.add_opcode(Oc::Equal, {0, (int64_t)firstreg}, {1, (int64_t)true_const},
{1, (int64_t)true_const}, {0, (int64_t)0}));
{0, (int64_t)option1_comp.code.size() + 2}));
TRY(ex.add_opcode(Oc::Jump, {0, (int64_t)option1_comp.code.size() + 2}));
ex.add_code(option1_comp.code); ex.add_code(option1_comp.code);
TRY(ex.add_opcode(Oc::Jump, {0, (int64_t)option2_comp.code.size() + 1})); TRY(ex.add_opcode(Oc::Jump, {0, (int64_t)option2_comp.code.size() + 1}));
@ -471,6 +553,8 @@ Result<Expression> Compiler::compile_list(Context& context, Pair& expr) {
Symbol& sym = *first.to<Symbol>(); Symbol& sym = *first.to<Symbol>();
if (TRY(is_primitive_op(sym))) { if (TRY(is_primitive_op(sym))) {
return compile_primop(context, sym, expr); return compile_primop(context, sym, expr);
} else if (TRY(is_comparison_op(sym))) {
return compile_comparison(context, sym, expr);
} else if (TRY(sym.cmp("if")) == 0) { } else if (TRY(sym.cmp("if")) == 0) {
return compile_if(context, sym, expr); return compile_if(context, sym, expr);
} else if (TRY(sym.cmp("lambda")) == 0) { } else if (TRY(sym.cmp("lambda")) == 0) {

View file

@ -32,6 +32,8 @@ class Compiler {
Result<Expression> compile_expr(Context& context, Value& expr); Result<Expression> compile_expr(Context& context, Value& expr);
Result<Expression> compile_list(Context& context, Pair& expr); Result<Expression> compile_list(Context& context, Pair& expr);
Result<Expression> compile_primop(Context& context, Symbol& op, Pair& expr); Result<Expression> compile_primop(Context& context, Symbol& op, Pair& expr);
Result<Expression> compile_comparison(Context& context, Symbol& op,
Pair& expr);
Result<Expression> compile_int64(Context& context, Int64& value); Result<Expression> compile_int64(Context& context, Int64& value);
Result<Expression> compile_symbol(Context& context, Symbol& value); Result<Expression> compile_symbol(Context& context, Symbol& value);
Result<Expression> compile_bool(Context& context, Bool& value); Result<Expression> compile_bool(Context& context, Bool& value);

View file

@ -30,15 +30,13 @@ op_t get_op(Oc op) {
case Oc::LoadStack: case Oc::LoadStack:
return op_t{"lfi", OpcodeType::Reg1I}; return op_t{"lfi", OpcodeType::Reg1I};
case Oc::Jump: case Oc::Jump:
return op_t{"jmp", OpcodeType::Reg0I}; return op_t{"jump", OpcodeType::Reg0I};
case Oc::JumpEqual: case Oc::Equal:
return op_t{"jeq", OpcodeType::Reg2I}; return op_t{"equal", OpcodeType::Reg2I};
case Oc::JumpLess: case Oc::Less:
return op_t{"jlt", OpcodeType::Reg2I}; return op_t{"less", OpcodeType::Reg2I};
case Oc::JumpLessEqual: case Oc::LessEqual:
return op_t{"jle", OpcodeType::Reg2I}; return op_t{"less-equal", OpcodeType::Reg2I};
case Oc::JumpNotEqual:
return op_t{"jne", OpcodeType::Reg2I};
case Oc::Call: case Oc::Call:
return op_t{"call", OpcodeType::Reg2}; return op_t{"call", OpcodeType::Reg2};
case Oc::TailCall: case Oc::TailCall:

View file

@ -28,10 +28,9 @@ enum class Oc : uint8_t {
LoadStack, LoadStack,
// Jumps // Jumps
Jump, Jump,
JumpEqual, Equal,
JumpLess, Less,
JumpLessEqual, LessEqual,
JumpNotEqual,
// Function calls // Function calls
Call, Call,
TailCall, TailCall,

View file

@ -62,13 +62,38 @@ Result<void> VM::vm_div(Opcode& oc) {
return Result<void>(); return Result<void>();
} }
Result<void> VM::vm_jump_not_equal(Opcode& oc) { Result<void> VM::vm_equal(Opcode& oc) {
Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg)); Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg));
Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg)); Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg));
if (TRY(val1.cmp(val2)) != 0) { int64_t expected = oc.arg3().arg;
_pc += oc.arg3().arg; if ((TRY(val1.cmp(val2)) == 0) != expected) {
_pc += 2;
} else { } else {
_pc++; _pc += 1;
}
return Result<void>();
}
Result<void> VM::vm_less(Opcode& oc) {
Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg));
Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg));
int64_t expected = oc.arg3().arg;
if ((TRY(val1.cmp(val2)) == -1) != expected) {
_pc += 2;
} else {
_pc += 1;
}
return Result<void>();
}
Result<void> VM::vm_less_equal(Opcode& oc) {
Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg));
Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg));
int64_t expected = oc.arg3().arg;
if ((TRY(val1.cmp(val2)) <= 0) != expected) {
_pc += 2;
} else {
_pc += 1;
} }
return Result<void>(); return Result<void>();
} }
@ -191,8 +216,14 @@ Result<void> VM::step() {
case Oc::Ret: case Oc::Ret:
TRY(vm_ret(oc)); TRY(vm_ret(oc));
break; break;
case Oc::JumpNotEqual: case Oc::Equal:
TRY(vm_jump_not_equal(oc)); TRY(vm_equal(oc));
break;
case Oc::Less:
TRY(vm_less(oc));
break;
case Oc::LessEqual:
TRY(vm_less_equal(oc));
break; break;
case Oc::Jump: case Oc::Jump:
TRY(vm_jump(oc)); TRY(vm_jump(oc));

View file

@ -28,7 +28,9 @@ class VM {
Result<void> vm_call(Opcode& oc); Result<void> vm_call(Opcode& oc);
Result<void> vm_ret(Opcode& oc); Result<void> vm_ret(Opcode& oc);
Result<void> vm_jump_not_equal(Opcode& oc); Result<void> vm_equal(Opcode& oc);
Result<void> vm_less(Opcode& oc);
Result<void> vm_less_equal(Opcode& oc);
Result<void> vm_jump(Opcode& oc); Result<void> vm_jump(Opcode& oc);
Result<void> vm_make_closure(Opcode& oc); Result<void> vm_make_closure(Opcode& oc);