#!/usr/bin/env python3
# encoding=utf-8
# ============================================================================
# @brief riscv ROM Patch File
# Copyright (c) HiSilicon (Shanghai) Technologies Co., Ltd. 2022-2023. All rights reserved.
# ============================================================================
#
# Builds the compare (CMP) and jump-table (TBL) bin files for a RISC-V ROM
# patch, then merges them into a copy of the RAM image and moves the result
# into the output directory.  Run as a script; see the __main__ block at the
# bottom for the positional arguments.
#
# NOTE(review): this copy of the file is damaged.  A contiguous stretch of
# source inside merge_output_file() is missing (marked below).  The helpers
# create_patch() calls -- get_patch_info, get_func_name, get_nm_content,
# get_func_addr, get_cmp_content, creat_bin_file -- are not visible and were
# presumably in that stretch.  Restore the file from version control before
# relying on it.

import struct
import ctypes
import sys
import os
import shutil
import traceback

# Directory containing this script; get_dir() resolves working file names
# relative to it.
dir_name = os.path.dirname(os.path.realpath(__file__))

# Patch-config keys.  The visible code reads these as files[...] /
# file_all[...]; presumably get_patch_info() (not visible here) uses this list
# to parse the .cfg file -- TODO confirm.  "TABLE_REG_CONUT" (sic) is the key
# spelling the code relies on, so it must not be "fixed" in isolation.
info_item = [
    "Device_Code_Version",
    "Patch_Cpu_Core",
    "Patch_File_Address",
    "Patch_TBL_Address",
    "Patch_TBL_Run_Address",
    "Table_Max_Size",
    "Table_Reg_Size",
    "ROM_Address",
    "ROM_Size",
    "CMP_Bin_File",
    "TBL_Bin_File",
    "RW_Bin_File",
    "TABLE_REG_CONUT",
]

# The default value is 4.
# The value will be set based on long jump or short jump for linx131.
# NOTE(review): the two lines above say "4" but CMP_HEAD_LEN is 3; they may
# describe CMP_REG_SIZE instead (create_patch() picks short vs. long jump from
# Table_Reg_Size 4 vs. 8) -- confirm.
CMP_HEAD_LEN = 3
# Compare-table length: 128 compare-table entries + 3 header entries.
# create_patch() overwrites this global with TABLE_REG_CONUT + 3.
g_cmp_total_len = 131
CMP_REG_SIZE = 4
PATCH_COUNT_REG_INDEX = 2
DATA_PATCH_COUNT = 2
# Process id as a string; remove_bin_file() uses it as a file-name prefix.
pid = str(os.getpid())


# Directory conversion: resolve a name relative to the script directory.
def get_dir(dir_in):
    """Return dir_in joined onto this script's directory (dir_name).

    As with os.path.join, an absolute dir_in is returned unchanged.
    """
    return os.path.join(dir_name, dir_in)


def remove_bin_file():
    """Delete <pid>cmp.bin, <pid>tbl.bin and <pid>rw.bin from the script
    directory, skipping any that do not exist.

    Called as error-path cleanup by copy_bin_file() before it exits.

    NOTE(review): the files written elsewhere in this script are named by the
    CMP_Bin_File / TBL_Bin_File / RW_Bin_File config values, and the visible
    code adds no pid prefix to them.  Unless those config values are
    themselves pid-prefixed (get_patch_info is not visible), this cleanup
    targets different names -- confirm.  The conditional expressions below
    are evaluated only for their side effect; an ``if`` statement (or
    contextlib.suppress(FileNotFoundError)) would read more plainly.
    """
    os.remove(get_dir(pid + "cmp.bin")) if os.path.exists(get_dir(pid + "cmp.bin")) else None
    os.remove(get_dir(pid + "tbl.bin")) if os.path.exists(get_dir(pid + "tbl.bin")) else None
    os.remove(get_dir(pid + "rw.bin")) if os.path.exists(get_dir(pid + "rw.bin")) else None


# Copy the input image into a working bin file (original comment: "convert to bin file").
def copy_bin_file(str_dst, str_src):
    """Copy the file str_src to str_dst, one byte at a time.

    If either file cannot be opened, or the copy fails, prints an error,
    calls remove_bin_file() and exits the process with status 1.

    NOTE(review): if opening str_dst fails, file_dst is still unbound when the
    inner handler formats its message, so that handler itself raises
    NameError.  The outer handler then catches it and reports str_src as the
    file that can't be opened.  The byte-wise loop is equivalent to
    shutil.copyfile(str_src, str_dst).  The explicit close() calls are
    redundant inside ``with``.
    """
    try:
        with open(str_src, "rb")as file_src:
            try:
                with open(str_dst, "wb+")as file_dst:
                    byte = file_src.read(1)
                    while byte:
                        file_dst.write(byte)
                        byte = file_src.read(1)
                    file_dst.close()
            except Exception as e:
                print("Error: %s Can't Open!" % file_dst)
                remove_bin_file()
                # SystemExit is not an Exception subclass, so the outer
                # handler below does not intercept this exit.
                sys.exit(1)
            file_src.close()
    except Exception as e:
        print("Error: %s Can't Open!" % str_src)
        remove_bin_file()
        sys.exit(1)


# Generate the bin file.
def merge_output_file(files):
    """Merge the TBL bin into the working RW bin (body incomplete -- see NOTE).

    Visible part: opens the RW_Bin_File and TBL_Bin_File working files (in
    the script directory, mode "rb+").  It allocates data_table with
    TABLE_REG_CONUT * Table_Reg_Size zero entries and starts reading
    TBL_Bin_File one byte at a time into it.  What is done with data_table
    and the RW file afterwards is in the missing source.
    """
    try:
        reg_size = int(files['Table_Reg_Size'])
        max_size = int(files['Table_Max_Size'])
        with open(get_dir(files['RW_Bin_File']), "rb+") as file_rw:
            try:
                with open(get_dir(files['TBL_Bin_File']), "rb+")as file_table:
                    data_num = int(files['TABLE_REG_CONUT']) * reg_size
                    data_table = []
                    for num in range(data_num):
                        data_table.append(0)
                    buff = file_table.read(1)
                    j = 0
                    while buff:
                        # NOTE(review): SOURCE IS CORRUPTED HERE.
                        # The text between "struct.unpack('" and
                        # " g_cmp_total_len:" was lost.  It looks as if
                        # everything from a '<' (e.g. a struct format such as
                        # '<B') up to the next '>' was stripped.
                        # Missing: the rest of merge_output_file, presumably
                        # the helper functions listed at the top of the file,
                        # and the start of get_cmp_content (the fragment's
                        # "return cmp_content" is presumably its tail).
                        # The surviving fragment is kept verbatim on the next
                        # line.  Its embedded Chinese comment reads
                        # "128 compare tables + header info".
                        data_table[j] = struct.unpack(' g_cmp_total_len: # 128个比较表 + 头部信息 print("Error: CMP Packet is larger than CMP Reg Capacitance") while cmp_count < g_cmp_total_len - 3: cmp_content.append(0) cmp_count += 1 return cmp_content


def get_table_content_for_short_jump(files, func_addrs, func_patch_addrs, version, bt_rom_file_in):
    """Build the jump-table words for 4-byte (short-jump) table registers.

    For each pair (func_addrs[i], func_patch_addrs[i]) -- hex address strings
    -- emits one 32-bit word.  The word has opcode bits 0x6F and bits 7-11 = 5.
    The offset func_patch_addr - func_addr is scattered as
    imm[20|10:1|11|19:12] into bits 31|30:21|20|19:12.  This matches the
    RISC-V JAL encoding (``jal x5, offset``).  Because the offset is taken
    relative to func_addr, the word is presumably meant to execute in place
    of the instruction at func_addr -- confirm.

    The result is zero-padded to files['TABLE_REG_CONUT'] words.  Prints an
    error and exits with status 1 if there are more pairs than that.

    Args:
        files: parsed patch config; only 'TABLE_REG_CONUT' is read here.
        func_addrs: hex strings of the ROM function addresses to patch.
        func_patch_addrs: hex strings of the replacement addresses, same order
            (indexed by func_addrs' length, so it must be at least as long).
        version, bt_rom_file_in: unused here.

    Returns:
        list of int instruction words.

    NOTE(review): there is no check that the offset fits JAL's signed 21-bit
    (+/-1 MiB) range.  The masks keep only offset bits 1-20, so out-of-range
    offsets are silently truncated into a wrong jump.
    """
    table_content = []
    bit0_to6 = 0x6F  # opcode field (JAL)
    bit7_to11 = 0x5 << 7  # rd field = x5
    func_num = len(func_addrs)
    index = 0
    while index < func_num:
        func_addr = int(func_addrs[index], 16)
        func_patch_addr = int(func_patch_addrs[index], 16)
        off_addr = func_patch_addr - func_addr
        # Slice the offset into the J-type immediate pieces; bit 0 is implied 0.
        off_bit1_to10 = (off_addr & 0x7fe) >> 1
        off_bit12_to19 = (off_addr & 0xff000) >> 12
        off_bit11 = (off_addr & 0x800) >> 11
        off_bit20 = (off_addr & 0x100000) >> 20
        bit_code = bit0_to6 + bit7_to11 + (off_bit12_to19 << 12) + (off_bit11 << 20) + (off_bit1_to10 << 21) + (off_bit20 << 31)
        table_content.append(bit_code)
        index += 1
    table_count = len(table_content)
    if table_count > int(files['TABLE_REG_CONUT']): # original comment: "128 compare tables"; the limit is TABLE_REG_CONUT
        print("Error: TABLE Packet is larger than CMP Reg Capacitance")
        sys.exit(1)
    # Zero-fill the unused table registers.
    while table_count < int(files['TABLE_REG_CONUT']):
        table_content.append(0)
        table_count += 1
    return table_content


def get_table_content_for_long_jump(files, func_addrs, func_patch_addrs, version, bt_rom_file_in):
    """Build the jump-table words for 8-byte (long-jump) table registers.

    For each pair (func_addrs[i], func_patch_addrs[i]) -- hex address strings
    -- emits two 32-bit words for the offset func_patch_addr - func_addr:
      1. AUIPC (opcode 0x17), rd = x6, immediate = offset bits 12-31;
      2. JALR (opcode 0x67), rd = x0 (no link written), funct3 = 0,
         rs1 = x6, immediate = offset bits 0-11.
    JALR sign-extends its 12-bit immediate.  So when bit 11 of the offset is
    set, the upper part is bumped by 0x1000 to compensate.

    The result is zero-padded to 2 * files['TABLE_REG_CONUT'] words.  Prints
    an error and exits with status 1 if there are more words than that.
    Arguments are as for get_table_content_for_short_jump; version and
    bt_rom_file_in are unused here.

    Returns:
        list of int instruction words (AUIPC/JALR pairs).

    NOTE(review): ``off_bit12_to31 + 0x1000`` is not re-masked.  For offsets
    whose upper part is 0xfffff000 (e.g. -0x800..-1) the sum is 0x100000000,
    so the AUIPC word no longer fits in 32 bits.  Masking with 0xfffff000
    after the add would fix it.
    """
    table_content = []
    auipc_opt_bits = 0x17
    jalr_opt_bits = 0x67
    base_addr_bits = 0x6 # register x6: AUIPC rd and JALR rs1
    jalr_bit12_to14 = 0x0 << 12  # JALR funct3
    jalr_bit7_to11 = 0x0 << 7  # JALR rd = x0
    func_num = len(func_addrs)
    index = 0
    while index < func_num:
        func_addr = int(func_addrs[index], 16)
        func_patch_addr = int(func_patch_addrs[index], 16)
        off_addr = func_patch_addr - func_addr
        off_bit12_to31 = off_addr & 0xfffff000
        off_bit0_to11 = off_addr & 0xfff
        if off_bit0_to11 > 0x7FF:
            # JALR will see these 12 bits as negative; carry 1 into the
            # upper part.  The next two lines negate twice, so the low 12
            # bits end up unchanged.
            off_bit12_to31 = off_bit12_to31 + 0x1000
            off_bit0_to11 = 0x1000 - off_bit0_to11
            off_bit0_to11 = (~off_bit0_to11 + 1) & 0xfff
        auipc_bit_code = auipc_opt_bits + (base_addr_bits << 7) + off_bit12_to31
        table_content.append(auipc_bit_code)
        jalr_bit_code = jalr_opt_bits + jalr_bit7_to11 + jalr_bit12_to14 + (base_addr_bits << 15) + (off_bit0_to11 << 20)
        table_content.append(jalr_bit_code)
        index += 1
    table_count = len(table_content)
    if table_count > 2 * int(files['TABLE_REG_CONUT']): # original comment: "128 compare tables"; two words per table register
        print("Error: TABLE Packet is larger than CMP Reg Capacitance")
        sys.exit(1)
    # Zero-fill unused registers two words at a time (table_count is always even).
    while table_count < 2 * int(files['TABLE_REG_CONUT']):
        table_content.append(0)
        table_content.append(0)
        table_count += 2
    return table_content


def create_patch(patch_info, nm_file_in, rom_bin_file):
    """Write the CMP and TBL bin files for the patch config at patch_info.

    Steps:
      1. Reads the config.
      2. Resets the global g_cmp_total_len to TABLE_REG_CONUT + 3.
      3. Looks up the original and replacement function addresses from the
         nm output.
      4. Builds the compare table and the jump table: short jump when
         Table_Reg_Size is 4, long jump when it is 8.
      5. Writes both via creat_bin_file() into the script directory.

    NOTE(review): for any other Table_Reg_Size this only prints a message.
    table_contents is then unbound, so the second creat_bin_file() call
    raises NameError; an explicit sys.exit(1) is presumably intended.  The
    message also reports Patch_Cpu_Core although the dispatch is on
    Table_Reg_Size.
    """
    global g_cmp_total_len
    file_all = get_patch_info(patch_info)
    g_cmp_total_len = int(file_all['TABLE_REG_CONUT']) + 3
    core = file_all['Patch_Cpu_Core']
    funs = get_func_name(patch_info, 0)
    funs_patch = get_func_name(patch_info, 1)
    nm_contents = get_nm_content(nm_file_in)
    func_addrs = get_func_addr(funs, nm_contents, "GCC")
    func_patch_addrs = get_func_addr(funs_patch, nm_contents, "GCC")
    cmp_contents = get_cmp_content(func_addrs, file_all['Patch_TBL_Run_Address'], file_all['Device_Code_Version'])
    reg_size = int(file_all['Table_Reg_Size'])
    if reg_size == 4:
        table_contents = get_table_content_for_short_jump(file_all, func_addrs, func_patch_addrs, file_all['Device_Code_Version'], rom_bin_file)
    elif reg_size == 8:
        table_contents = get_table_content_for_long_jump(file_all, func_addrs, func_patch_addrs, file_all['Device_Code_Version'], rom_bin_file)
    else:
        print("ErrorCore %s for rom patch" % core)
    creat_bin_file(get_dir(file_all['CMP_Bin_File']), cmp_contents, file_all['CMP_Bin_File'])
    creat_bin_file(get_dir(file_all['TBL_Bin_File']), table_contents, file_all['TBL_Bin_File'])


def output_bin_file(file_all, output_dir_in, ram_bin_file):
    """Publish the merged RW bin as output_dir_in/ram_bin_file and clean up.

    Steps:
      1. An existing output_dir_in/ram_bin_file is first renamed to
         output_dir_in/unpatch.bin.
      2. A leftover output_dir_in/rw.bin is deleted.
      3. The working RW bin is moved into place.
      4. The working CMP and TBL bins in the script directory are removed.

    As with os.path.join, an absolute ram_bin_file ignores output_dir_in.
    """
    if os.path.exists(os.path.join(output_dir_in, ram_bin_file)):
        shutil.move(os.path.join(output_dir_in, ram_bin_file), os.path.join(output_dir_in, "unpatch.bin"))
    if os.path.exists(os.path.join(output_dir_in, "rw.bin")):
        os.remove(os.path.join(output_dir_in, "rw.bin"))
    shutil.move(get_dir(file_all['RW_Bin_File']), os.path.join(output_dir_in, ram_bin_file))
    os.remove(get_dir(file_all['CMP_Bin_File']))
    os.remove(get_dir(file_all['TBL_Bin_File']))
    print("Generating %s..." % ram_bin_file)


def get_patch_addr(patch_info, ram_bin_file, output_dir_in):
    """Apply the generated tables to ram_bin_file and write the output image.

    Steps:
      1. Copies ram_bin_file to the working RW bin.
      2. Merges the TBL bin into it (merge_output_file).
      3. Moves the result into output_dir_in (output_bin_file).

    NOTE(review): funs / funs_patch are computed but never used.
    """
    file_all = get_patch_info(patch_info)
    funs = get_func_name(patch_info, 0)
    funs_patch = get_func_name(patch_info, 1)
    copy_bin_file(get_dir(file_all['RW_Bin_File']), ram_bin_file)
    merge_output_file(file_all)
    output_bin_file(file_all, output_dir_in, ram_bin_file)


if __name__ == "__main__":
    # Positional arguments: ram_bin_file rom_bin_file nm_file
    # patch_config_dir core target_name output_dir.
    if(len(sys.argv) == 8):
        ram_bin_file = sys.argv[1]
        rom_bin_file = sys.argv[2]
        nm_file = sys.argv[3]
        partch_config_dir = sys.argv[4]
        core = sys.argv[5]
        target_name = sys.argv[6]
        output_dir = sys.argv[7]
        # Prefer a target-specific config; fall back to the per-core one.
        if os.path.exists(os.path.join(partch_config_dir, f'{target_name}.cfg')):
            patch_info = os.path.join(partch_config_dir, f'{target_name}.cfg')
        else:
            patch_info = os.path.join(partch_config_dir, f'{core}.cfg')
        create_patch(patch_info, nm_file, rom_bin_file)
        get_patch_addr(patch_info, ram_bin_file, output_dir)
    else:
        # NOTE(review): the usage text appears truncated.  The argument
        # placeholders were presumably lost with the same '<...>' stripping
        # that damaged merge_output_file.
        print(
            "Usage: %s "
            "" % os.path.basename(sys.argv[0]))
        sys.exit(1)