// The project header declares eval(); it is guarded so this translation unit
// can also be compiled on its own (e.g. by an isolated test harness).
#if __has_include("vm.hpp")
#include "vm.hpp"
#endif

#include <cstddef>
#include <cstdint>
#include <stdexcept>

const int NUM_REGISTERS = 32; // Standard RISC-V has 32 registers

/// @brief Sign-extends the low `bits` bits of `value` to a full int32_t.
///
/// Bits at positions >= `bits` are ignored, so callers may pass unmasked fields.
/// @param value Raw field value.
/// @param bits  Width of the field in bits; must be in [1, 32].
/// @return The field interpreted as a two's-complement number.
inline int32_t sign_extend(int32_t value, int bits) {
    const uint32_t field_mask = (bits >= 32) ? 0xFFFFFFFFu : ((1u << bits) - 1u);
    const uint32_t sign_bit = 1u << (bits - 1);
    const uint32_t field = static_cast<uint32_t>(value) & field_mask;
    return static_cast<int32_t>((field ^ sign_bit) - sign_bit);
}

namespace {

/// Register/function fields that sit at fixed positions in every RV32 format.
struct Fields {
    uint32_t rd;
    uint32_t funct3;
    uint32_t rs1;
    uint32_t rs2;
    uint32_t funct7;
};

/// Extracts the fixed-position fields of `instr`.
Fields decode_fields(uint32_t instr) {
    return Fields{(instr >> 7) & 0x1F, (instr >> 12) & 0x7, (instr >> 15) & 0x1F,
                  (instr >> 20) & 0x1F, instr >> 25};
}

/// Throws unless [addr, addr + width) lies inside memory. Written as a
/// subtraction so a large `addr` cannot wrap around and pass the check.
void check_access(uint32_t addr, size_t width, size_t memory_size) {
    if (addr > memory_size || memory_size - addr < width) {
        throw std::runtime_error("Memory access out of bounds");
    }
}

/// Reads `width` (1..4) bytes at `addr` as a little-endian value, independent
/// of host byte order. The caller must have bounds-checked the access.
uint32_t load_le(const uint8_t* memory, uint32_t addr, size_t width) {
    uint32_t value = 0;
    for (size_t i = 0; i < width; ++i) {
        value |= static_cast<uint32_t>(memory[addr + i]) << (8 * i);
    }
    return value;
}

/// Writes the low `width` (1..4) bytes of `value` at `addr`, little-endian.
/// The caller must have bounds-checked the access.
void store_le(uint8_t* memory, uint32_t addr, uint32_t value, size_t width) {
    for (size_t i = 0; i < width; ++i) {
        memory[addr + i] = static_cast<uint8_t>(value >> (8 * i));
    }
}

/// Executes an OP (0x33) instruction: RV32I register-register ALU ops
/// (funct7 0x00 / 0x20) and M-extension multiply/divide ops (funct7 0x01).
void execute_op(const Fields& f, uint32_t* regs) {
    const uint32_t a = regs[f.rs1];
    const uint32_t b = regs[f.rs2];
    const int32_t sa = static_cast<int32_t>(a);
    const int32_t sb = static_cast<int32_t>(b);
    // RV32 shifts use only the low 5 bits of rs2; shifting a 32-bit value by
    // 32 or more would also be undefined behavior in C++.
    const uint32_t shamt = b & 0x1F;
    uint32_t result = 0;

    if (f.funct7 == 0x00) {
        switch (f.funct3) {  // funct3 is 3 bits wide, so every value is covered
            case 0x0: result = a + b; break;                // ADD
            case 0x1: result = a << shamt; break;           // SLL
            case 0x2: result = (sa < sb) ? 1u : 0u; break;  // SLT
            case 0x3: result = (a < b) ? 1u : 0u; break;    // SLTU
            case 0x4: result = a ^ b; break;                // XOR
            case 0x5: result = a >> shamt; break;           // SRL
            case 0x6: result = a | b; break;                // OR
            case 0x7: result = a & b; break;                // AND
        }
    } else if (f.funct7 == 0x20) {
        switch (f.funct3) {
            case 0x0: result = a - b; break;                               // SUB
            case 0x5: result = static_cast<uint32_t>(sa >> shamt); break;  // SRA
            default: throw std::runtime_error("Unknown R-type instruction");
        }
    } else if (f.funct7 == 0x01) {
        switch (f.funct3) {  // funct3 is 3 bits wide, so every value is covered
            case 0x0:  // MUL: the low 32 bits do not depend on signedness
                result = static_cast<uint32_t>(static_cast<uint64_t>(a) * b);
                break;
            case 0x1:  // MULH: signed x signed, upper 32 bits
                result = static_cast<uint32_t>(
                    static_cast<uint64_t>(static_cast<int64_t>(sa) * sb) >> 32);
                break;
            case 0x2:  // MULHSU: signed rs1 x unsigned rs2, upper 32 bits
                result = static_cast<uint32_t>(
                    static_cast<uint64_t>(static_cast<int64_t>(sa) * static_cast<int64_t>(b)) >> 32);
                break;
            case 0x3:  // MULHU: unsigned x unsigned, upper 32 bits
                result = static_cast<uint32_t>((static_cast<uint64_t>(a) * b) >> 32);
                break;
            case 0x4:  // DIV
                if (sb == 0) {
                    result = UINT32_MAX;  // division by zero yields -1
                } else if (sa == INT32_MIN && sb == -1) {
                    result = a;  // overflow: quotient is the dividend
                } else {
                    result = static_cast<uint32_t>(sa / sb);
                }
                break;
            case 0x5:  // DIVU
                result = (b == 0) ? UINT32_MAX : a / b;
                break;
            case 0x6:  // REM
                if (sb == 0) {
                    result = a;  // remainder with zero divisor is the dividend
                } else if (sa == INT32_MIN && sb == -1) {
                    result = 0;  // overflow case
                } else {
                    result = static_cast<uint32_t>(sa % sb);
                }
                break;
            case 0x7:  // REMU
                result = (b == 0) ? a : a % b;
                break;
        }
    } else {
        throw std::runtime_error("Unknown R-type instruction");
    }
    regs[f.rd] = result;
}

/// Executes an OP-IMM (0x13) instruction with a sign-extended 12-bit immediate.
void execute_op_imm(uint32_t instr, const Fields& f, uint32_t* regs) {
    const int32_t imm = sign_extend(static_cast<int32_t>(instr >> 20), 12);
    const uint32_t uimm = static_cast<uint32_t>(imm);
    const uint32_t a = regs[f.rs1];
    const uint32_t shamt = f.rs2;  // shift immediates occupy the rs2 field (instr[24:20])
    uint32_t result = 0;

    switch (f.funct3) {  // funct3 is 3 bits wide, so every value is covered
        case 0x0: result = a + uimm; break;                                      // ADDI
        case 0x2: result = (static_cast<int32_t>(a) < imm) ? 1u : 0u; break;     // SLTI
        case 0x3: result = (a < uimm) ? 1u : 0u; break;                          // SLTIU
        case 0x4: result = a ^ uimm; break;                                      // XORI
        case 0x6: result = a | uimm; break;                                      // ORI
        case 0x7: result = a & uimm; break;                                      // ANDI
        case 0x1:                                                                // SLLI
            if (f.funct7 != 0x00) {
                throw std::runtime_error("Unknown I-type instruction");
            }
            result = a << shamt;
            break;
        case 0x5:
            if (f.funct7 == 0x00) {                                              // SRLI
                result = a >> shamt;
            } else if (f.funct7 == 0x20) {                                       // SRAI
                result = static_cast<uint32_t>(static_cast<int32_t>(a) >> shamt);
            } else {
                throw std::runtime_error("Unknown I-type instruction");
            }
            break;
    }
    regs[f.rd] = result;
}

/// Executes a BRANCH (0x63) instruction and returns the next pc.
/// `next_pc` is the address after the branch; offsets are relative to the
/// branch itself, i.e. to `next_pc - 4`.
uint32_t execute_branch(uint32_t instr, const Fields& f, const uint32_t* regs, uint32_t next_pc) {
    const uint32_t raw = ((instr >> 7) & 0x1E)     // imm[4:1]  <- instr[11:8]
                       | ((instr >> 20) & 0x7E0)   // imm[10:5] <- instr[30:25]
                       | ((instr << 4) & 0x800)    // imm[11]   <- instr[7]
                       | ((instr >> 31) << 12);    // imm[12]   <- instr[31]
    const int32_t imm = sign_extend(static_cast<int32_t>(raw), 13);
    const uint32_t a = regs[f.rs1];
    const uint32_t b = regs[f.rs2];

    bool taken = false;
    switch (f.funct3) {
        case 0x0: taken = (a == b); break;                                              // BEQ
        case 0x1: taken = (a != b); break;                                              // BNE
        case 0x4: taken = static_cast<int32_t>(a) < static_cast<int32_t>(b); break;     // BLT
        case 0x5: taken = static_cast<int32_t>(a) >= static_cast<int32_t>(b); break;    // BGE
        case 0x6: taken = (a < b); break;                                               // BLTU
        case 0x7: taken = (a >= b); break;                                              // BGEU
        default: throw std::runtime_error("Unknown B-type instruction");
    }
    return taken ? (next_pc - 4) + static_cast<uint32_t>(imm) : next_pc;
}

/// Executes a LOAD (0x03) instruction: LB, LH, LW, LBU, LHU.
void execute_load(uint32_t instr, const Fields& f, uint32_t* regs,
                  const uint8_t* memory, size_t memory_size) {
    const int32_t imm = sign_extend(static_cast<int32_t>(instr >> 20), 12);
    size_t width = 0;
    bool is_signed = false;
    switch (f.funct3) {
        case 0x0: width = 1; is_signed = true; break;   // LB
        case 0x1: width = 2; is_signed = true; break;   // LH
        case 0x2: width = 4; break;                     // LW
        case 0x4: width = 1; break;                     // LBU
        case 0x5: width = 2; break;                     // LHU
        default: throw std::runtime_error("Unknown load instruction");
    }
    // Compute the address before writing rd so that rd == rs1 works.
    const uint32_t addr = regs[f.rs1] + static_cast<uint32_t>(imm);
    check_access(addr, width, memory_size);

    uint32_t value = load_le(memory, addr, width);
    if (is_signed) {
        value = static_cast<uint32_t>(
            sign_extend(static_cast<int32_t>(value), static_cast<int>(8 * width)));
    }
    regs[f.rd] = value;
}

/// Executes a STORE (0x23) instruction: SB, SH, SW.
void execute_store(uint32_t instr, const Fields& f, const uint32_t* regs,
                   uint8_t* memory, size_t memory_size) {
    const uint32_t raw = ((instr >> 7) & 0x1F) | ((instr >> 25) << 5);
    const int32_t imm = sign_extend(static_cast<int32_t>(raw), 12);
    size_t width = 0;
    switch (f.funct3) {
        case 0x0: width = 1; break;  // SB
        case 0x1: width = 2; break;  // SH
        case 0x2: width = 4; break;  // SW
        default: throw std::runtime_error("Unknown store instruction");
    }
    const uint32_t addr = regs[f.rs1] + static_cast<uint32_t>(imm);
    check_access(addr, width, memory_size);
    store_le(memory, addr, regs[f.rs2], width);
}

}  // namespace

/// @brief Runs a program on a minimal RV32IM-subset interpreter.
///
/// Execution starts at address 0 with all registers zeroed. It proceeds until
/// an all-zero instruction word is fetched or the pc reaches `memory_size`.
/// `memory` holds both code and data and is accessed in little-endian byte
/// order regardless of the host. Register contents are discarded on return,
/// so programs must store results to memory to make them observable.
///
/// Supported: OP (RV32I + M extension), OP-IMM, BRANCH,
/// LOAD (LB/LH/LW/LBU/LHU) and STORE (SB/SH/SW). Any other opcode
/// (e.g. jumps, LUI/AUIPC) is rejected.
///
/// @param memory      Pointer to `memory_size` bytes of program/data memory.
/// @param memory_size Size of `memory` in bytes.
/// @throws std::runtime_error on an unknown instruction, an out-of-bounds
///         load/store, or an instruction word extending past the end of memory.
void eval(uint8_t* memory, size_t memory_size) {
    uint32_t registers[NUM_REGISTERS] = {0};
    uint32_t pc = 0;

    while (pc < memory_size) {
        if (memory_size - pc < 4) {
            throw std::runtime_error("Instruction fetch out of bounds");
        }
        const uint32_t instr = load_le(memory, pc, 4);
        if (instr == 0) break;  // an all-zero word halts the machine
        pc += 4;                // pc now holds the address of the next instruction

        const Fields f = decode_fields(instr);
        switch (instr & 0x7F) {
            case 0x33: execute_op(f, registers); break;
            case 0x13: execute_op_imm(instr, f, registers); break;
            case 0x63: pc = execute_branch(instr, f, registers, pc); break;
            case 0x03: execute_load(instr, f, registers, memory, memory_size); break;
            case 0x23: execute_store(instr, f, registers, memory, memory_size); break;
            default: throw std::runtime_error("Unknown opcode");
        }
        registers[0] = 0;  // x0 is hardwired to zero: discard any write to it
    }
}