Implement compiling lambda functions and function calls

This commit is contained in:
Konstantin Nazarov 2024-08-17 23:22:21 +01:00
parent 9a0f6d2264
commit fce8b84276
Signed by: knazarov
GPG key ID: 4CFE0A42FA409C22
9 changed files with 157 additions and 17 deletions

View file

@ -135,6 +135,26 @@ Result<Value> Function::copy_value() const {
}
Result<Function> Function::copy() const { return Function(TRY(_value.copy())); }
Result<short> Function::cmp(const Function& rhs) const {
short res = (arity() > rhs.arity()) - (arity() < rhs.arity());
if (res != 0) return res;
auto lhs_name = TRY(name());
auto rhs_name = TRY(rhs.name());
res = TRY(lhs_name.cmp(rhs_name));
if (res != 0) return res;
auto lhs_code = TRY(code());
auto rhs_code = TRY(rhs.code());
res = TRY(lhs_code.cmp(rhs_code));
if (res != 0) return res;
auto lhs_constants = TRY(constants());
auto rhs_constants = TRY(rhs.constants());
res = TRY(lhs_constants.cmp(rhs_constants));
return res;
}
Result<Value> Stack::copy_value() const {
return Value(Stack(TRY(_value.copy())));
}
@ -161,6 +181,7 @@ Result<void> Stack::settop(uint64_t idx) {
for (uint64_t i = std::min(top, idx); i < std::max(top, idx); i++) {
_value->data[i] = 0;
}
_value->top = idx;
return Result<void>();
}

View file

@ -28,6 +28,7 @@ class ByteArray;
class Writer;
class Opcode;
class Stack;
class Function;
short cmp_tag(Tag lhs, Tag rhs);
@ -73,6 +74,9 @@ class Object {
virtual Result<short> cmp(const Opcode& rhs) const {
return cmp_tag(tag(), Tag::Opcode);
}
virtual Result<short> cmp(const Function& rhs) const {
return cmp_tag(tag(), Tag::Opcode);
}
virtual Result<short> cmp(const char* rhs) const {
return ERROR(NotImplemented);
}
@ -852,7 +856,7 @@ class Function : public Object {
virtual Result<short> cmp(const Object& rhs) const final {
return -TRY(rhs.cmp(*this));
}
virtual Result<short> cmp(const Opcode& rhs) const final { return 0; }
virtual Result<short> cmp(const Function& rhs) const final;
virtual void move(Object* obj) final {
new (obj) Function(std::move(_value));
@ -865,7 +869,7 @@ class Function : public Object {
static Result<Function> create(const Value& name, uint64_t arity,
const Array& constants, const Array& code);
uint64_t arity() { return _value->arity; }
uint64_t arity() const { return _value->arity; }
Result<Value> name() const;
Result<Array> code() const;
@ -901,12 +905,14 @@ class Stack : public Object {
Result<Value> get(uint64_t idx) const;
Result<void> set(uint64_t idx, const Value& val);
Result<void> settop(uint64_t idx);
uint64_t gettop() { return _value->top; }
static Result<Stack> create(uint64_t size) {
auto pod = TRY(arena_alloc<PodStack>(sizeof(OffPtr<PodObject>) * size));
pod->header.tag = Tag::Stack;
pod->size = size;
pod->top = 0;
for (uint64_t i = 0; i < size; i++) {
pod->data[i] = 0;

View file

@ -264,6 +264,7 @@ Result<Expression> Compiler::compile_if(Context& context, Symbol& op,
Result<Expression> Compiler::compile_lambda(Context& context, Symbol& op,
Pair& expr) {
Context ctx = TRY(Context::create());
ctx.maxreg = 1; // Reserve the slot for function itself
auto first = TRY(expr.rest());
@ -292,7 +293,6 @@ Result<Expression> Compiler::compile_lambda(Context& context, Symbol& op,
}
int64_t reg = TRY(ctx.add_var(param_first));
std::cout << "reg: " << reg << "\n";
param = TRY(param_pair.rest());
arity++;
@ -308,8 +308,13 @@ Result<Expression> Compiler::compile_lambda(Context& context, Symbol& op,
auto ex = TRY(compile_body(ctx, second_pair));
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)0}, {0, (int64_t)ex.reg}));
TRY(ex.add_opcode(Oc::Ret, {0, (int64_t)0}));
Value name = TRY(Nil::create());
auto fun = TRY(Function::create(name, arity, context.constants, ex.code));
auto fun = TRY(Function::create(name, arity, ctx.constants, ex.code));
// TRY(debug_print(TRY(ex.code.copy())));
Expression ex_res = TRY(Expression::create());
@ -337,7 +342,7 @@ Result<Expression> Compiler::compile_body(Context& context, Pair& expr) {
Pair& cur_pair = *cur.to<Pair>();
auto expr_val = TRY(cur_pair.first());
debug_print(expr_val);
// debug_print(expr_val);
auto expr = TRY(compile_expr(context, expr_val));
@ -355,6 +360,37 @@ Result<Expression> Compiler::compile_body(Context& context, Pair& expr) {
return ex_res;
}
Result<Expression> Compiler::compile_function_call(Context& context,
Pair& expr) {
auto ex = TRY(Expression::create());
auto first = TRY(expr.first());
auto param = TRY(expr.rest());
auto fun_ex = TRY(compile_expr(context, first));
TRY(ex.add_code(fun_ex.code));
int64_t maxreg = context.maxreg;
while (!param.is<Nil>()) {
if (!param.is<Pair>()) {
return ERROR(CompilationError);
}
Pair& param_pair = *param.to<Pair>();
Value param_val = TRY(param_pair.first());
auto param_ex = TRY(compile_expr(context, param_val));
TRY(ex.add_code(param_ex.code));
param = TRY(param_pair.rest());
}
TRY(ex.add_opcode(Oc::Call, {0, (int64_t)fun_ex.reg},
{0, (int64_t)context.maxreg}));
ex.reg = fun_ex.reg;
return ex;
}
Result<Expression> Compiler::compile_list(Context& context, Pair& expr) {
auto first = TRY(expr.first());
@ -367,6 +403,8 @@ Result<Expression> Compiler::compile_list(Context& context, Pair& expr) {
} else if (TRY(sym.cmp("lambda")) == 0) {
return compile_lambda(context, sym, expr);
}
} else if (first.is<Pair>()) {
return compile_function_call(context, expr);
}
return ERROR(TypeMismatch);
}
@ -395,7 +433,7 @@ Result<Expression> Compiler::compile_symbol(Context& context, Symbol& value) {
auto var_reg = maybe_reg.value();
uint64_t reg = context.alloc_reg();
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)reg}, {1, (int64_t)var_reg}));
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)reg}, {0, (int64_t)var_reg}));
ex.reg = reg;
return std::move(ex);

View file

@ -38,6 +38,7 @@ class Compiler {
Result<Expression> compile_if(Context& context, Symbol& op, Pair& expr);
Result<Expression> compile_lambda(Context& context, Symbol& op, Pair& expr);
Result<Expression> compile_body(Context& context, Pair& expr);
Result<Expression> compile_function_call(Context& context, Pair& expr);
};
Result<Value> compile(Value& expr);

View file

@ -14,6 +14,7 @@ enum class ErrorCode {
KeyError,
EndOfProgram,
CompilationError,
ArgumentCountMismatch,
};
void seterr(const char* err);

View file

@ -40,7 +40,7 @@ op_t get_op(Oc op) {
case Oc::JumpNotEqual:
return op_t{"jne", OpcodeType::Reg2I};
case Oc::Call:
return op_t{"call", OpcodeType::Reg1I};
return op_t{"call", OpcodeType::Reg2};
case Oc::TailCall:
return op_t{"tailcall", OpcodeType::Reg1I};
case Oc::Ret:

View file

@ -11,7 +11,7 @@ StaticArena<64 * 1024 * 1024> arena;
Result<void> run() {
// auto code_str = TRY(String::create("(* (+ 1 2 3) (/ 4 2))"));
auto code_str = TRY(String::create("(lambda (x) (* x x))"));
auto code_str = TRY(String::create("((lambda (x) (* x x)) 2)"));
auto reader = Reader(code_str);
auto parsed = TRY(reader.read_one());

View file

@ -18,6 +18,7 @@ Result<void> VM::vm_mov(Opcode& oc) {
uint64_t acc = (uint64_t)oc.arg1().arg;
Value val = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg));
setreg(acc, val);
_pc++;
return Result<void>();
}
@ -27,6 +28,7 @@ Result<void> VM::vm_add(Opcode& oc) {
Value val2 = TRY(get(oc.arg3().is_const, (uint64_t)oc.arg3().arg));
Value res = TRY(val1.add(val2));
setreg(acc, res);
_pc++;
return Result<void>();
}
@ -36,6 +38,7 @@ Result<void> VM::vm_mul(Opcode& oc) {
Value val2 = TRY(get(oc.arg3().is_const, (uint64_t)oc.arg3().arg));
Value res = TRY(val1.mul(val2));
setreg(acc, res);
_pc++;
return Result<void>();
}
@ -45,6 +48,7 @@ Result<void> VM::vm_sub(Opcode& oc) {
Value val2 = TRY(get(oc.arg3().is_const, (uint64_t)oc.arg3().arg));
Value res = TRY(val1.sub(val2));
setreg(acc, res);
_pc++;
return Result<void>();
}
@ -54,6 +58,7 @@ Result<void> VM::vm_div(Opcode& oc) {
Value val2 = TRY(get(oc.arg3().is_const, (uint64_t)oc.arg3().arg));
Value res = TRY(val1.div(val2));
setreg(acc, res);
_pc++;
return Result<void>();
}
@ -61,13 +66,74 @@ Result<void> VM::vm_jump_not_equal(Opcode& oc) {
Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg));
Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg));
if (TRY(val1.cmp(val2)) != 0) {
_pc += oc.arg3().arg - 1;
_pc += oc.arg3().arg;
} else {
_pc++;
}
return Result<void>();
}
Result<void> VM::vm_call(Opcode& oc) {
Value fun_value = TRY(get(oc.arg1().is_const, oc.arg1().arg));
if (!fun_value.is<Function>()) return ERROR(TypeMismatch);
Function& fun = *fun_value.to<Function>();
uint64_t reg_start = (uint64_t)oc.arg1().arg;
uint64_t reg_end = (uint64_t)oc.arg2().arg;
if (fun.arity() != (reg_end - reg_start - 1)) {
return ERROR(ArgumentCountMismatch);
}
uint64_t old_base = _base;
Value fun_val = TRY(fun.copy());
Value oldbase_val = TRY(Int64::create(old_base));
Value pc_val = TRY(Int64::create(_pc));
_callstack.set(_callstack.gettop(), fun_val);
_callstack.set(_callstack.gettop(), oldbase_val);
_callstack.set(_callstack.gettop(), pc_val);
_code = TRY(fun.code());
_constants = TRY(fun.constants());
_fun = std::move(fun);
_pc = 0;
_base = reg_start;
return Result<void>();
}
Result<void> VM::vm_ret(Opcode& oc) {
if (_callstack.gettop() == 0) {
_res = TRY(getreg((uint64_t)oc.arg1().arg));
return ERROR(EndOfProgram);
}
Value fun_val = TRY(_callstack.get(_callstack.gettop() - 3));
Value oldbase_val = TRY(_callstack.get(_callstack.gettop() - 2));
Value pc_val = TRY(_callstack.get(_callstack.gettop() - 1));
if (!oldbase_val.is<Int64>() || !pc_val.is<Int64>() ||
!fun_val.is<Function>())
return ERROR(TypeMismatch);
Function& fun = *fun_val.to<Function>();
uint64_t oldbase = oldbase_val.to<Int64>()->value();
uint64_t pc = pc_val.to<Int64>()->value();
_code = TRY(fun.code());
_constants = TRY(fun.constants());
_fun = std::move(fun);
_pc = pc + 1;
_base = oldbase;
TRY(_callstack.settop(_callstack.gettop() - 3));
return Result<void>();
}
Result<void> VM::vm_jump(Opcode& oc) {
_pc += oc.arg1().arg - 1;
_pc += oc.arg1().arg;
return Result<void>();
}
@ -93,20 +159,21 @@ Result<void> VM::step() {
TRY(vm_div(oc));
break;
case Oc::Ret:
_res = TRY(getreg((uint64_t)oc.arg1().arg));
return ERROR(EndOfProgram);
TRY(vm_ret(oc));
break;
case Oc::JumpNotEqual:
TRY(vm_jump_not_equal(oc));
break;
case Oc::Jump:
TRY(vm_jump(oc));
break;
case Oc::Call:
TRY(vm_call(oc));
break;
default:
return ERROR(NotImplemented);
}
_pc++;
return Result<void>();
}

View file

@ -5,13 +5,16 @@
class VM {
public:
VM() {}
VM(Stack&& stack) : _stack(std::move(stack)) {}
VM(VM&& vm) : _stack(std::move(vm._stack)) {}
VM(Stack&& stack, Stack&& callstack)
: _stack(std::move(stack)), _callstack(std::move(callstack)) {}
VM(VM&& vm)
: _stack(std::move(vm._stack)), _callstack(std::move(vm._callstack)) {}
VM(const VM&) = delete;
static Result<VM> create() {
auto stack = TRY(Stack::create(16 * 1024));
return VM(std::move(stack));
auto callstack = TRY(Stack::create(16 * 1024));
return VM(std::move(stack), std::move(callstack));
}
Result<Value> run(const Function& fun);
@ -22,6 +25,8 @@ class VM {
Result<void> vm_mul(Opcode& oc);
Result<void> vm_sub(Opcode& oc);
Result<void> vm_div(Opcode& oc);
Result<void> vm_call(Opcode& oc);
Result<void> vm_ret(Opcode& oc);
Result<void> vm_jump_not_equal(Opcode& oc);
Result<void> vm_jump(Opcode& oc);
@ -33,6 +38,7 @@ class VM {
private:
Stack _stack;
Stack _callstack;
Function _fun;
Array _code;
Array _constants;