diff --git a/src/common.cpp b/src/common.cpp index 87605d2..676a0d9 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -135,6 +135,26 @@ Result Function::copy_value() const { } Result Function::copy() const { return Function(TRY(_value.copy())); } +Result Function::cmp(const Function& rhs) const { + short res = (arity() > rhs.arity()) - (arity() < rhs.arity()); + if (res != 0) return res; + + auto lhs_name = TRY(name()); + auto rhs_name = TRY(rhs.name()); + res = TRY(lhs_name.cmp(rhs_name)); + if (res != 0) return res; + + auto lhs_code = TRY(code()); + auto rhs_code = TRY(rhs.code()); + res = TRY(lhs_code.cmp(rhs_code)); + if (res != 0) return res; + + auto lhs_constants = TRY(constants()); + auto rhs_constants = TRY(rhs.constants()); + res = TRY(lhs_constants.cmp(rhs_constants)); + return res; +} + Result Stack::copy_value() const { return Value(Stack(TRY(_value.copy()))); } @@ -161,6 +181,7 @@ Result Stack::settop(uint64_t idx) { for (uint64_t i = std::min(top, idx); i < std::max(top, idx); i++) { _value->data[i] = 0; } + _value->top = idx; return Result(); } diff --git a/src/common.hpp b/src/common.hpp index 1c300cf..22b3ed3 100644 --- a/src/common.hpp +++ b/src/common.hpp @@ -28,6 +28,7 @@ class ByteArray; class Writer; class Opcode; class Stack; +class Function; short cmp_tag(Tag lhs, Tag rhs); @@ -73,6 +74,9 @@ class Object { virtual Result cmp(const Opcode& rhs) const { return cmp_tag(tag(), Tag::Opcode); } + virtual Result cmp(const Function& rhs) const { + return cmp_tag(tag(), Tag::Opcode); + } virtual Result cmp(const char* rhs) const { return ERROR(NotImplemented); } @@ -852,7 +856,7 @@ class Function : public Object { virtual Result cmp(const Object& rhs) const final { return -TRY(rhs.cmp(*this)); } - virtual Result cmp(const Opcode& rhs) const final { return 0; } + virtual Result cmp(const Function& rhs) const final; virtual void move(Object* obj) final { new (obj) Function(std::move(_value)); @@ -865,7 +869,7 @@ class Function : public Object { static Result create(const Value& name, uint64_t arity, const Array& constants, const Array& code); - uint64_t arity() { return _value->arity; } + uint64_t arity() const { return _value->arity; } Result name() const; Result code() const; @@ -901,12 +905,14 @@ class Stack : public Object { Result get(uint64_t idx) const; Result set(uint64_t idx, const Value& val); Result settop(uint64_t idx); + uint64_t gettop() { return _value->top; } static Result create(uint64_t size) { auto pod = TRY(arena_alloc(sizeof(OffPtr) * size)); pod->header.tag = Tag::Stack; pod->size = size; + pod->top = 0; for (uint64_t i = 0; i < size; i++) { pod->data[i] = 0; diff --git a/src/compiler.cpp b/src/compiler.cpp index 4f458c0..77290b7 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -264,6 +264,7 @@ Result Compiler::compile_if(Context& context, Symbol& op, Result Compiler::compile_lambda(Context& context, Symbol& op, Pair& expr) { Context ctx = TRY(Context::create()); + ctx.maxreg = 1; // Reserve the slot for function itself auto first = TRY(expr.rest()); @@ -292,7 +293,6 @@ Result Compiler::compile_lambda(Context& context, Symbol& op, } int64_t reg = TRY(ctx.add_var(param_first)); - std::cout << "reg: " << reg << "\n"; param = TRY(param_pair.rest()); arity++; @@ -308,8 +308,13 @@ Result Compiler::compile_lambda(Context& context, Symbol& op, auto ex = TRY(compile_body(ctx, second_pair)); + TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)0}, {0, (int64_t)ex.reg})); + TRY(ex.add_opcode(Oc::Ret, {0, (int64_t)0})); + Value name = TRY(Nil::create()); - auto fun = TRY(Function::create(name, arity, context.constants, ex.code)); + auto fun = TRY(Function::create(name, arity, ctx.constants, ex.code)); + + // TRY(debug_print(TRY(ex.code.copy()))); Expression ex_res = TRY(Expression::create()); @@ -337,7 +342,7 @@ Result Compiler::compile_body(Context& context, Pair& expr) { Pair& cur_pair = *cur.to(); auto expr_val = TRY(cur_pair.first()); - debug_print(expr_val); + // debug_print(expr_val); auto expr = TRY(compile_expr(context, expr_val)); @@ -355,6 +360,37 @@ Result Compiler::compile_body(Context& context, Pair& expr) { return ex_res; } +Result Compiler::compile_function_call(Context& context, + Pair& expr) { + auto ex = TRY(Expression::create()); + + auto first = TRY(expr.first()); + auto param = TRY(expr.rest()); + + auto fun_ex = TRY(compile_expr(context, first)); + TRY(ex.add_code(fun_ex.code)); + + int64_t maxreg = context.maxreg; + while (!param.is()) { + if (!param.is()) { + return ERROR(CompilationError); + } + Pair& param_pair = *param.to(); + Value param_val = TRY(param_pair.first()); + + auto param_ex = TRY(compile_expr(context, param_val)); + TRY(ex.add_code(param_ex.code)); + + param = TRY(param_pair.rest()); + } + + TRY(ex.add_opcode(Oc::Call, {0, (int64_t)fun_ex.reg}, + {0, (int64_t)context.maxreg})); + + ex.reg = fun_ex.reg; + return ex; +} + Result Compiler::compile_list(Context& context, Pair& expr) { auto first = TRY(expr.first()); @@ -367,6 +403,8 @@ Result Compiler::compile_list(Context& context, Pair& expr) { } else if (TRY(sym.cmp("lambda")) == 0) { return compile_lambda(context, sym, expr); } + } else if (first.is()) { + return compile_function_call(context, expr); } return ERROR(TypeMismatch); } @@ -395,7 +433,7 @@ Result Compiler::compile_symbol(Context& context, Symbol& value) { auto var_reg = maybe_reg.value(); uint64_t reg = context.alloc_reg(); - TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)reg}, {1, (int64_t)var_reg})); + TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)reg}, {0, (int64_t)var_reg})); ex.reg = reg; return std::move(ex); diff --git a/src/compiler.hpp b/src/compiler.hpp index 2878a41..5cfeb59 100644 --- a/src/compiler.hpp +++ b/src/compiler.hpp @@ -38,6 +38,7 @@ class Compiler { Result compile_if(Context& context, Symbol& op, Pair& expr); Result compile_lambda(Context& context, Symbol& op, Pair& expr); Result compile_body(Context& context, Pair& expr); + Result compile_function_call(Context& context, Pair& expr); }; Result compile(Value& expr); diff --git a/src/error.hpp b/src/error.hpp index 02f9fb3..f88fa4e 100644 --- a/src/error.hpp +++ b/src/error.hpp @@ -14,6 +14,7 @@ enum class ErrorCode { KeyError, EndOfProgram, CompilationError, + ArgumentCountMismatch, }; void seterr(const char* err); diff --git a/src/opcode.cpp b/src/opcode.cpp index 5f1c6c6..ad51006 100644 --- a/src/opcode.cpp +++ b/src/opcode.cpp @@ -40,7 +40,7 @@ op_t get_op(Oc op) { case Oc::JumpNotEqual: return op_t{"jne", OpcodeType::Reg2I}; case Oc::Call: - return op_t{"call", OpcodeType::Reg1I}; + return op_t{"call", OpcodeType::Reg2}; case Oc::TailCall: return op_t{"tailcall", OpcodeType::Reg1I}; case Oc::Ret: diff --git a/src/vli.cpp b/src/vli.cpp index b95e124..3e76d11 100644 --- a/src/vli.cpp +++ b/src/vli.cpp @@ -11,7 +11,7 @@ StaticArena<64 * 1024 * 1024> arena; Result run() { // auto code_str = TRY(String::create("(* (+ 1 2 3) (/ 4 2))")); - auto code_str = TRY(String::create("(lambda (x) (* x x))")); + auto code_str = TRY(String::create("((lambda (x) (* x x)) 2)")); auto reader = Reader(code_str); auto parsed = TRY(reader.read_one()); diff --git a/src/vm.cpp b/src/vm.cpp index 01424e5..aa34128 100644 --- a/src/vm.cpp +++ b/src/vm.cpp @@ -18,6 +18,7 @@ Result VM::vm_mov(Opcode& oc) { uint64_t acc = (uint64_t)oc.arg1().arg; Value val = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg)); setreg(acc, val); + _pc++; return Result(); } @@ -27,6 +28,7 @@ Result VM::vm_add(Opcode& oc) { Value val2 = TRY(get(oc.arg3().is_const, (uint64_t)oc.arg3().arg)); Value res = TRY(val1.add(val2)); setreg(acc, res); + _pc++; return Result(); } @@ -36,6 +38,7 @@ Result VM::vm_mul(Opcode& oc) { Value val2 = TRY(get(oc.arg3().is_const, (uint64_t)oc.arg3().arg)); Value res = TRY(val1.mul(val2)); setreg(acc, res); + _pc++; return Result(); } @@ -45,6 +48,7 @@ Result VM::vm_sub(Opcode& oc) { Value val2 = TRY(get(oc.arg3().is_const, (uint64_t)oc.arg3().arg)); Value res = TRY(val1.sub(val2)); setreg(acc, res); + _pc++; return Result(); } @@ -54,6 +58,7 @@ Result VM::vm_div(Opcode& oc) { Value val2 = TRY(get(oc.arg3().is_const, (uint64_t)oc.arg3().arg)); Value res = TRY(val1.div(val2)); setreg(acc, res); + _pc++; return Result(); } @@ -61,13 +66,74 @@ Result VM::vm_jump_not_equal(Opcode& oc) { Value val1 = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg)); Value val2 = TRY(get(oc.arg2().is_const, (uint64_t)oc.arg2().arg)); if (TRY(val1.cmp(val2)) != 0) { - _pc += oc.arg3().arg - 1; + _pc += oc.arg3().arg; + } else { + _pc++; } return Result(); } +Result VM::vm_call(Opcode& oc) { + Value fun_value = TRY(get(oc.arg1().is_const, oc.arg1().arg)); + if (!fun_value.is()) return ERROR(TypeMismatch); + Function& fun = *fun_value.to(); + + uint64_t reg_start = (uint64_t)oc.arg1().arg; + uint64_t reg_end = (uint64_t)oc.arg2().arg; + + if (fun.arity() != (reg_end - reg_start - 1)) { + return ERROR(ArgumentCountMismatch); + } + + uint64_t old_base = _base; + + Value fun_val = TRY(fun.copy()); + Value oldbase_val = TRY(Int64::create(old_base)); + Value pc_val = TRY(Int64::create(_pc)); + _callstack.set(_callstack.gettop(), fun_val); + _callstack.set(_callstack.gettop(), oldbase_val); + _callstack.set(_callstack.gettop(), pc_val); + + _code = TRY(fun.code()); + _constants = TRY(fun.constants()); + _fun = std::move(fun); + _pc = 0; + _base = reg_start; + + return Result(); +} + +Result VM::vm_ret(Opcode& oc) { + if (_callstack.gettop() == 0) { + _res = TRY(getreg((uint64_t)oc.arg1().arg)); + return ERROR(EndOfProgram); + } + + Value fun_val = TRY(_callstack.get(_callstack.gettop() - 3)); + Value oldbase_val = TRY(_callstack.get(_callstack.gettop() - 2)); + Value pc_val = TRY(_callstack.get(_callstack.gettop() - 1)); + + if (!oldbase_val.is() || !pc_val.is() || + !fun_val.is()) + return ERROR(TypeMismatch); + + Function& fun = *fun_val.to(); + uint64_t oldbase = oldbase_val.to()->value(); + uint64_t pc = pc_val.to()->value(); + + _code = TRY(fun.code()); + _constants = TRY(fun.constants()); + _fun = std::move(fun); + _pc = pc + 1; + _base = oldbase; + + TRY(_callstack.settop(_callstack.gettop() - 3)); + + return Result(); +} + Result VM::vm_jump(Opcode& oc) { - _pc += oc.arg1().arg - 1; + _pc += oc.arg1().arg; return Result(); } @@ -93,20 +159,21 @@ Result VM::step() { TRY(vm_div(oc)); break; case Oc::Ret: - _res = TRY(getreg((uint64_t)oc.arg1().arg)); - return ERROR(EndOfProgram); + TRY(vm_ret(oc)); + break; case Oc::JumpNotEqual: TRY(vm_jump_not_equal(oc)); break; case Oc::Jump: TRY(vm_jump(oc)); break; + case Oc::Call: + TRY(vm_call(oc)); + break; default: return ERROR(NotImplemented); } - _pc++; - return Result(); } diff --git a/src/vm.hpp b/src/vm.hpp index 368a833..4d35f6e 100644 --- a/src/vm.hpp +++ b/src/vm.hpp @@ -5,13 +5,16 @@ class VM { public: VM() {} - VM(Stack&& stack) : _stack(std::move(stack)) {} - VM(VM&& vm) : _stack(std::move(vm._stack)) {} + VM(Stack&& stack, Stack&& callstack) + : _stack(std::move(stack)), _callstack(std::move(callstack)) {} + VM(VM&& vm) + : _stack(std::move(vm._stack)), _callstack(std::move(vm._callstack)) {} VM(const VM&) = delete; static Result create() { auto stack = TRY(Stack::create(16 * 1024)); - return VM(std::move(stack)); + auto callstack = TRY(Stack::create(16 * 1024)); + return VM(std::move(stack), std::move(callstack)); } Result run(const Function& fun); @@ -22,6 +25,8 @@ class VM { Result vm_mul(Opcode& oc); Result vm_sub(Opcode& oc); Result vm_div(Opcode& oc); + Result vm_call(Opcode& oc); + Result vm_ret(Opcode& oc); Result vm_jump_not_equal(Opcode& oc); Result vm_jump(Opcode& oc); @@ -33,6 +38,7 @@ class VM { private: Stack _stack; + Stack _callstack; Function _fun; Array _code; Array _constants;