LPT26x-HSF-4MB-Hilink_14.2..../build/script/patch/patch_riscv.py

445 lines
17 KiB
Python
Raw Permalink Normal View History

2025-05-13 22:00:58 +08:00
#!/usr/bin/env python3
# encoding=utf-8
# ============================================================================
# @brief riscv ROM Patch File
# Copyright (c) HiSilicon (Shanghai) Technologies Co., Ltd. 2022-2023. All rights reserved.
# ============================================================================
import struct
import ctypes
import sys
import os
import shutil
import traceback
# Directory holding this script; get_dir() resolves intermediate files here.
dir_name = os.path.dirname(os.path.realpath(__file__))

# Keys that get_patch_info() accepts from the "key = value" part of the
# patch configuration file.
info_item = [
    "Device_Code_Version",
    "Patch_Cpu_Core",
    "Patch_File_Address",
    "Patch_TBL_Address",
    "Patch_TBL_Run_Address",
    "Table_Max_Size",
    "Table_Reg_Size",
    "ROM_Address",
    "ROM_Size",
    "CMP_Bin_File",
    "TBL_Bin_File",
    "RW_Bin_File",
    "TABLE_REG_CONUT",
]

# Number of header words subtracted from the CMP list length in
# get_cmp_content().
# NOTE(review): the original comment here read "The default value is 4. The
# value will be set based on long jump or short jump for linx131.", but the
# value is 3 and nothing in this file reassigns it - confirm.
CMP_HEAD_LEN = 3
# Default CMP table length in words: 128 compare entries + 3 header words.
# create_patch() overwrites this from TABLE_REG_CONUT.
g_cmp_total_len = 128 + 3
# Bytes per CMP table word when the table is merged into the RW image.
CMP_REG_SIZE = 4
# Index inside the CMP list where the Version1 patch count is stored.
PATCH_COUNT_REG_INDEX = 2
# Number of table registers skipped before the jump table in the RW image.
DATA_PATCH_COUNT = 2
# Process id, prefixed to the intermediate bin file names so that concurrent
# runs do not clobber each other's files.
pid = str(os.getpid())
# Directory conversion
def get_dir(dir_in):
    """Return *dir_in* joined onto the directory that contains this script."""
    resolved = os.path.join(dir_name, dir_in)
    return resolved
def remove_bin_file():
    """Delete this process's intermediate cmp/tbl/rw bin files if present.

    NOTE(review): the three names are hard-coded as "<pid>cmp.bin",
    "<pid>tbl.bin" and "<pid>rw.bin"; files whose names come from the
    configuration (see get_patch_info) are only cleaned up when the
    configuration uses exactly those names - confirm.
    """
    for suffix in ("cmp.bin", "tbl.bin", "rw.bin"):
        if os.path.exists(get_dir(pid + suffix)):
            os.remove(get_dir(pid + suffix))
# Copy a binary file
def copy_bin_file(str_dst, str_src):
    """Copy the binary file *str_src* to *str_dst*, replacing any old content.

    Args:
        str_dst: path of the file to create or overwrite.
        str_src: path of the file to read.

    On an I/O failure the path that actually failed is reported, the
    intermediate bin files are removed and the process exits with status 1.
    """
    try:
        file_src = open(str_src, "rb")
    except OSError:
        print("Error: %s Can't Open!" % str_src)
        remove_bin_file()
        sys.exit(1)

    with file_src:
        try:
            with open(str_dst, "wb") as file_dst:
                # Buffered chunk copy instead of one read/write per byte.
                shutil.copyfileobj(file_src, file_dst)
        except OSError:
            # Report the destination *path*: the file object does not exist
            # when the open itself is what failed.
            print("Error: %s Can't Open!" % str_dst)
            remove_bin_file()
            sys.exit(1)
# Read a generated bin file, zero-padded to a fixed size
def _read_padded_bin(path, size):
    """Return the bytes of *path*, right-padded with zeros to *size* bytes.

    Raises:
        OSError: if the file cannot be opened or read.
        ValueError: if the file holds more than *size* bytes.
    """
    with open(path, "rb") as file_in:
        data = file_in.read()
    if len(data) > size:
        raise ValueError(
            "%s holds %d bytes, expected at most %d" % (path, len(data), size))
    return data + bytes(size - len(data))


# Generate the output bin file
def merge_output_file(files):
    """Write the jump table and the CMP table into the RW bin file in place.

    Args:
        files: configuration dict as returned by get_patch_info().

    Both tables are positioned relative to
    ``Patch_TBL_Address - Patch_File_Address`` (hexadecimal config values):

    * the jump table (TBL_Bin_File, zero-padded to
      ``TABLE_REG_CONUT * Table_Reg_Size`` bytes) is written
      ``DATA_PATCH_COUNT * Table_Reg_Size`` bytes after that base;
    * the CMP table (CMP_Bin_File, zero-padded to
      ``g_cmp_total_len * CMP_REG_SIZE`` bytes) is written
      ``TABLE_REG_CONUT * Table_Max_Size`` bytes after that base.

    On any file error, or when a table file is larger than its slot, the
    offending path is reported, the intermediate bin files are removed and
    the process exits with status 1.
    """
    rw_path = get_dir(files['RW_Bin_File'])
    tbl_path = get_dir(files['TBL_Bin_File'])
    cmp_path = get_dir(files['CMP_Bin_File'])
    reg_size = int(files['Table_Reg_Size'])
    max_size = int(files['Table_Max_Size'])
    reg_count = int(files['TABLE_REG_CONUT'])
    base_offset = int(files['Patch_TBL_Address'], 16) - int(files['Patch_File_Address'], 16)

    # Tracks which file is being worked on so the error names the right path.
    failing = rw_path
    try:
        with open(rw_path, "rb+") as file_rw:
            failing = tbl_path
            table_data = _read_padded_bin(tbl_path, reg_count * reg_size)
            failing = rw_path
            file_rw.seek(base_offset + DATA_PATCH_COUNT * reg_size, 0)
            file_rw.write(table_data)

            failing = cmp_path
            cmp_data = _read_padded_bin(cmp_path, g_cmp_total_len * CMP_REG_SIZE)
            failing = rw_path
            file_rw.seek(base_offset + reg_count * max_size, 0)
            file_rw.write(cmp_data)
    except (OSError, ValueError):
        print("Error: %s Can't Open!" % failing)
        remove_bin_file()
        sys.exit(1)
# Table or Cmp
def creat_bin_file(file_name, contents, type_in):
    """Write *contents* to *file_name* as little-endian unsigned 32-bit words.

    Args:
        file_name: path of the bin file to create or overwrite.
        contents: sequence of integers, one per 32-bit word.
        type_in: unused; kept so existing callers keep working.

    The byte order is fixed to little-endian so the output matches the '<I'
    and '<B' readers elsewhere in this script regardless of the build host.
    If a value does not fit in an unsigned 32-bit word, or the file cannot be
    written, an error is printed, the intermediate bin files are removed and
    the process exits with status 1.
    """
    try:
        packed = struct.pack('<%dI' % len(contents), *contents)
    except struct.error:
        print("Error: %s contains a value that is not an unsigned 32-bit word!" % file_name)
        remove_bin_file()
        sys.exit(1)

    try:
        with open(file_name, "wb") as file_o:
            file_o.write(packed)
    except OSError:
        print("Error: %s Can't Open!" % file_name)
        remove_bin_file()
        sys.exit(1)
# Get the binary file that corresponds to a function
def get_func_code_in_file(files, func_addr, bt_rom_file_in):
    """Return ``[rom_base, rom_file]`` when *func_addr* lies inside the ROM.

    Args:
        files: configuration dict; 'ROM_Address' and 'ROM_Size' are read as
            hexadecimal strings.
        func_addr: integer address to look up.
        bt_rom_file_in: value returned as the second element on a hit.

    Returns:
        ``[ROM_Address, bt_rom_file_in]`` if
        ``ROM_Address <= func_addr < ROM_Address + ROM_Size``, otherwise the
        placeholder ``['', 0xffffffff]``.
    """
    rom_start = int(files['ROM_Address'], 16)
    rom_end = rom_start + int(files['ROM_Size'], 16)
    if rom_start <= func_addr < rom_end:
        return [rom_start, bt_rom_file_in]
    return ['', 0xffffffff]
# Get the binary code that corresponds to a function
def get_func_code_in_bin(rom_addr, file_name, func_addr, align):
    """Read four little-endian 32-bit words of a function from a bin file.

    Args:
        rom_addr: address that corresponds to offset 0 of *file_name*.
        file_name: path of the bin file to read.
        func_addr: address of the function.
        align: byte count subtracted from the computed file offset.

    Returns:
        A list of four integers read from offset
        ``func_addr - rom_addr - align``.

    On any failure (including a short read) an error is printed, the
    intermediate bin files are removed and the process exits with status 1.
    """
    start = func_addr - rom_addr - align
    try:
        with open(file_name, "rb") as rom_file:
            rom_file.seek(start, 0)
            words = [struct.unpack('<I', rom_file.read(4))[0] for _ in range(4)]
    except Exception:
        print("Error: %s Can't Open!" % file_name)
        remove_bin_file()
        sys.exit(1)
    return words
# Get the patch configuration
def get_patch_info(patchinfo):
    """Parse the "key = value" header of a patch configuration file.

    Args:
        patchinfo: path of the configuration file.

    Returns:
        A dict holding every key listed in ``info_item`` that appears before
        the ``[Function]`` marker. The three ``*_Bin_File`` values are
        prefixed with this process's id so concurrent runs use separate
        intermediate files.

    Only the first '=' on a line separates key from value, so a value may
    itself contain '='. If the file cannot be read, or one of the
    ``*_Bin_File`` keys is missing, an error is printed, the intermediate bin
    files are removed and the process exits with status 1.
    """
    all_name = {}
    try:
        with open(str(patchinfo), "r", encoding="utf-8") as file_o:
            for line in file_o:
                if '[Function]' in line:
                    break
                config_key, separator, config_value = line.partition('=')
                if not separator:
                    continue
                config_key = config_key.strip()
                if config_key in info_item:
                    all_name[config_key] = config_value.strip()
    except (OSError, UnicodeDecodeError):
        print("Error: %s Can't Open!" % patchinfo)
        remove_bin_file()
        sys.exit(1)

    # Prefix the bin file names with the process id to avoid clashes when
    # several processes run at once.
    for bin_key in ("CMP_Bin_File", "TBL_Bin_File", "RW_Bin_File"):
        if bin_key not in all_name:
            print("Error: %s does not define %s!" % (patchinfo, bin_key))
            remove_bin_file()
            sys.exit(1)
        all_name[bin_key] = pid + all_name[bin_key]
    return all_name
# Get the patch function names
def get_func_name(func_file_name, index):
    """Collect one column of the ``[Function]`` section of a patch config.

    Args:
        func_file_name: path of the configuration file.
        index: which whitespace-separated field of each entry to return
            (the callers in this file pass 0 and 1).

    Returns:
        A list with field *index* of every non-empty, non-'#' line that
        follows the ``[Function]`` marker.

    A line that does not have exactly two fields is reported. If it also
    lacks the requested field the run stops cleanly with status 1 instead of
    failing later with an IndexError that was reported as "Can't Open".
    Read and decode errors also remove the intermediate bin files and exit
    with status 1.
    """
    func_names = []
    try:
        with open(func_file_name, "r", encoding='utf-8') as file_o:
            in_function_section = False
            for line in file_o:
                line = line.strip()  # drop leading and trailing whitespace
                if '[Function]' in line:
                    in_function_section = True
                    continue
                if not in_function_section or not line or line.startswith('#'):
                    continue
                fields = line.split()
                if len(fields) != 2:
                    print(
                        "Error format file_o:%s,line:%s," %
                        (func_file_name, fields))
                if len(fields) <= index:
                    # The requested column does not exist on this line.
                    remove_bin_file()
                    sys.exit(1)
                func_names.append(fields[index])
    except UnicodeDecodeError:
        print("get_func_name catch UnicodeDecodeError: %s" % func_file_name)
        remove_bin_file()
        sys.exit(1)
    except OSError:
        print(traceback.format_exc())
        print("Error: %s Can't Open!" % func_file_name)
        remove_bin_file()
        sys.exit(1)
    return func_names
# For the GCC toolchain every line of the nm file is treated as valid
def get_nm_content(file_name):
    """Return the lines of an nm listing with their trailing newline removed.

    Args:
        file_name: path of the nm output file (read as UTF-8).

    Returns:
        A list of lines. A line is dropped only when it starts with " A " and
        contains neither " T " nor " t ".

    NOTE(review): the original third test was ``nm_line.find(" A ")`` with no
    ``!= -1``, which is truthy unless the match is at index 0. That exact
    behaviour is kept here in explicit form; confirm whether a real filter on
    the T/t/A symbol classes was intended before tightening it.

    On any failure an error is printed, the intermediate bin files are
    removed and the process exits with status 1.
    """
    try:
        with open(file_name, "r", encoding='utf-8') as nm_file:
            kept = [
                raw.strip('\n')
                for raw in nm_file.readlines()
                if " T " in raw or " t " in raw or not raw.startswith(" A ")
            ]
    except Exception:
        print("Error: %s Can't Open!" % file_name)
        remove_bin_file()
        sys.exit(1)
    return kept
def get_func_addr(func_names, m_contents, compiler_name):
    """Look up each function name in the '|'-separated symbol lines.

    Args:
        func_names: names to resolve, in order.
        m_contents: symbol lines as returned by get_nm_content(); field 0 is
            compared with the name and field 1 is returned.
        compiler_name: only "GCC" produces matches.

    Returns:
        A list with the stripped field 1 of the first matching line for every
        name, in the order of *func_names*.

    Exits with status 1 when a name is not found, or when field 1 parsed as
    hexadecimal is below 4; a value of exactly 4 only prints a warning.

    NOTE(review): field 1 is used by the callers as the function address, yet
    the messages below call it a "Length" and print field 3 as the name -
    confirm which column was meant.
    """
    resolved = []
    for wanted in func_names:
        for entry in m_contents:
            columns = entry.split('|')
            if compiler_name != "GCC" or wanted != columns[0].strip():
                continue
            resolved.append(columns[1].strip())
            value = int(columns[1], 16)
            if value < 4:
                print(
                    "Error: %s Length is %s , it can not be less than 4 "
                    "bytes!" % (columns[3], columns[1]))
                sys.exit(1)
            if value == 4:
                print(
                    "Warning: %s Length is %s bytes, it may be unsafe!" %
                    (columns[3], columns[1]))
            break
        else:
            # No line matched this name.
            print(
                "Error: %s Function Can't Find in Map File !" %
                wanted)
            sys.exit(1)
    return resolved
# Ctrl register + Remap register + CMP count + CMP table
def get_cmp_content(func_addrs, patch_tbl_addr, version):
    """Build the list of 32-bit words for the CMP (compare) table.

    Args:
        func_addrs: hexadecimal address strings of the functions to patch.
        patch_tbl_addr: hexadecimal string stored as the second header word
            for "Version1".
        version: device code version string from the configuration.

    Returns:
        For "Version1": three header words ``[0, patch_tbl_addr, count]``
        followed by each address with bit 0 set. For any other version: the
        addresses with bit 0 cleared and no header. For "Version1" and
        "Version2" the list is zero-padded.

    If the list is longer than ``g_cmp_total_len`` the error is reported and
    the process exits with status 1, the same way the jump-table generators
    handle overflow, instead of returning an oversized table.

    NOTE(review): for "Version2" no header is added but CMP_HEAD_LEN is still
    subtracted when counting, so the padded list ends up g_cmp_total_len
    words long - confirm this is intended.
    """
    is_version1 = version == "Version1"
    cmp_content = []
    if is_version1:
        cmp_content = [0, int(patch_tbl_addr, 16), 0]

    for addr_text in func_addrs:
        # The lowest address bit is an instruction flag bit; clear it.
        func_addr = int(addr_text, 16) & (~0x01)
        if is_version1:
            func_addr |= 0x1
        cmp_content.append(func_addr)

    cmp_count = len(cmp_content) - CMP_HEAD_LEN
    if version in ("Version1", "Version2"):
        if is_version1:
            cmp_content[PATCH_COUNT_REG_INDEX] = cmp_count
        if len(cmp_content) > g_cmp_total_len:  # compare entries + header words
            print("Error: CMP Packet is larger than CMP Reg Capacitance")
            sys.exit(1)
        padding = g_cmp_total_len - CMP_HEAD_LEN - cmp_count
        cmp_content.extend([0] * padding)
    return cmp_content
def get_table_content_for_short_jump(files, func_addrs, func_patch_addrs, version, bt_rom_file_in):
    """Build the jump table using one RISC-V JAL instruction per patch.

    Args:
        files: configuration dict; 'TABLE_REG_CONUT' gives the table size.
        func_addrs: hexadecimal address strings of the original functions.
        func_patch_addrs: hexadecimal address strings of the replacements,
            in the same order as *func_addrs*.
        version: unused; kept for interface compatibility.
        bt_rom_file_in: unused; kept for interface compatibility.

    Returns:
        A list of exactly TABLE_REG_CONUT words. Entry *i* is a JAL
        (opcode 0x6F, rd = x5) whose offset is
        ``func_patch_addrs[i] - func_addrs[i]``; unused entries are 0.

    Exits with status 1 when there are more patches than table entries, or
    when an offset does not fit the signed 21-bit JAL immediate. Such an
    offset was previously truncated silently into a wrong jump target.
    """
    jal_opcode = 0x6F
    jal_rd = 0x5 << 7
    jal_limit = 1 << 20  # JAL reaches [-2**20, 2**20) bytes from the instruction

    table_content = []
    for index, addr_text in enumerate(func_addrs):
        func_addr = int(addr_text, 16)
        func_patch_addr = int(func_patch_addrs[index], 16)
        off_addr = func_patch_addr - func_addr
        if not -jal_limit <= off_addr < jal_limit:
            print(
                "Error: jump from 0x%x to 0x%x is out of short jump range" %
                (func_addr, func_patch_addr))
            sys.exit(1)
        # Scatter the offset into the J-type immediate layout:
        # imm[20] -> bit 31, imm[10:1] -> bits 30:21,
        # imm[11] -> bit 20, imm[19:12] -> bits 19:12.
        off_bit1_to10 = (off_addr & 0x7fe) >> 1
        off_bit12_to19 = (off_addr & 0xff000) >> 12
        off_bit11 = (off_addr & 0x800) >> 11
        off_bit20 = (off_addr & 0x100000) >> 20
        table_content.append(
            jal_opcode + jal_rd + (off_bit12_to19 << 12) + (off_bit11 << 20)
            + (off_bit1_to10 << 21) + (off_bit20 << 31))

    table_limit = int(files['TABLE_REG_CONUT'])
    if len(table_content) > table_limit:
        print("Error: TABLE Packet is larger than CMP Reg Capacitance")
        sys.exit(1)
    table_content.extend([0] * (table_limit - len(table_content)))
    return table_content
def get_table_content_for_long_jump(files, func_addrs, func_patch_addrs, version, bt_rom_file_in):
    """Build the jump table using an AUIPC + JALR pair per patch.

    Args:
        files: configuration dict; 'TABLE_REG_CONUT' gives the number of
            table entries (each entry is two words).
        func_addrs: hexadecimal address strings of the original functions.
        func_patch_addrs: hexadecimal address strings of the replacements,
            in the same order as *func_addrs*.
        version: unused; kept for interface compatibility.
        bt_rom_file_in: unused; kept for interface compatibility.

    Returns:
        A list of exactly ``2 * TABLE_REG_CONUT`` words. Each patch adds
        ``auipc x6, hi`` followed by ``jalr x0, lo(x6)`` where hi/lo split
        ``func_patch_addrs[i] - func_addrs[i]``; unused words are 0.

    Exits with status 1 when there are more patches than table entries.
    """
    auipc_opt_bits = 0x17
    jalr_opt_bits = 0x67
    base_addr_bits = 0x6  # x6
    jalr_bit12_to14 = 0x0 << 12
    jalr_bit7_to11 = 0x0 << 7

    table_content = []
    for index, addr_text in enumerate(func_addrs):
        func_addr = int(addr_text, 16)
        func_patch_addr = int(func_patch_addrs[index], 16)
        off_addr = func_patch_addr - func_addr
        off_bit12_to31 = off_addr & 0xfffff000
        # JALR sign-extends its 12-bit immediate, so the low field keeps the
        # raw low 12 bits and the upper part is rounded up to compensate
        # whenever bit 11 is set.
        off_bit0_to11 = off_addr & 0xfff
        if off_bit0_to11 > 0x7FF:
            # Wrap to 32 bits: for small negative offsets the carry would
            # otherwise yield 0x100000000, which is not a valid 32-bit word.
            off_bit12_to31 = (off_bit12_to31 + 0x1000) & 0xffffffff
        auipc_bit_code = auipc_opt_bits + (base_addr_bits << 7) + off_bit12_to31
        table_content.append(auipc_bit_code)
        jalr_bit_code = (jalr_opt_bits + jalr_bit7_to11 + jalr_bit12_to14
                         + (base_addr_bits << 15) + (off_bit0_to11 << 20))
        table_content.append(jalr_bit_code)

    table_limit = 2 * int(files['TABLE_REG_CONUT'])
    if len(table_content) > table_limit:
        print("Error: TABLE Packet is larger than CMP Reg Capacitance")
        sys.exit(1)
    table_content.extend([0] * (table_limit - len(table_content)))
    return table_content
def create_patch(patch_info, nm_file_in, rom_bin_file):
    """Generate the intermediate CMP and jump-table bin files.

    Args:
        patch_info: path of the patch configuration file.
        nm_file_in: path of the nm symbol listing.
        rom_bin_file: ROM image path, passed through to the table generator.

    Side effects:
        Sets the module-level ``g_cmp_total_len`` to
        ``TABLE_REG_CONUT + CMP_HEAD_LEN`` and writes the files named by
        'CMP_Bin_File' and 'TBL_Bin_File' next to this script.

    A 'Table_Reg_Size' of 4 selects the short-jump table and 8 the long-jump
    table. Any other value is reported and stops the run with status 1
    before any file is written; previously execution continued into an
    undefined ``table_contents``.
    """
    global g_cmp_total_len
    file_all = get_patch_info(patch_info)
    g_cmp_total_len = int(file_all['TABLE_REG_CONUT']) + CMP_HEAD_LEN
    core = file_all['Patch_Cpu_Core']
    version = file_all['Device_Code_Version']

    funs = get_func_name(patch_info, 0)
    funs_patch = get_func_name(patch_info, 1)
    nm_contents = get_nm_content(nm_file_in)
    func_addrs = get_func_addr(funs, nm_contents, "GCC")
    func_patch_addrs = get_func_addr(funs_patch, nm_contents, "GCC")
    cmp_contents = get_cmp_content(func_addrs, file_all['Patch_TBL_Run_Address'], version)

    reg_size = int(file_all['Table_Reg_Size'])
    if reg_size == 4:
        table_contents = get_table_content_for_short_jump(
            file_all, func_addrs, func_patch_addrs, version, rom_bin_file)
    elif reg_size == 8:
        table_contents = get_table_content_for_long_jump(
            file_all, func_addrs, func_patch_addrs, version, rom_bin_file)
    else:
        print("ErrorCore %s for rom patch" % core)
        print("Error: unsupported Table_Reg_Size %s (expected 4 or 8)" % reg_size)
        remove_bin_file()
        sys.exit(1)

    creat_bin_file(get_dir(file_all['CMP_Bin_File']), cmp_contents, file_all['CMP_Bin_File'])
    creat_bin_file(get_dir(file_all['TBL_Bin_File']), table_contents, file_all['TBL_Bin_File'])
def output_bin_file(file_all, output_dir_in, ram_bin_file):
    """Install the patched RW image and delete the intermediate files.

    Args:
        file_all: configuration dict as returned by get_patch_info().
        output_dir_in: directory the result is placed in.
        ram_bin_file: file name of the final image, joined onto
            *output_dir_in*.

    An existing image with that name is first moved to "unpatch.bin" in the
    same directory and a stale "rw.bin" there is removed. The RW bin file is
    then moved into place and the CMP and TBL bin files are deleted.
    """
    patched_path = os.path.join(output_dir_in, ram_bin_file)
    if os.path.exists(patched_path):
        shutil.move(patched_path, os.path.join(output_dir_in, "unpatch.bin"))

    stale_rw = os.path.join(output_dir_in, "rw.bin")
    if os.path.exists(stale_rw):
        os.remove(stale_rw)

    shutil.move(get_dir(file_all['RW_Bin_File']), patched_path)
    for bin_key in ('CMP_Bin_File', 'TBL_Bin_File'):
        os.remove(get_dir(file_all[bin_key]))
    print("Generating %s..." % ram_bin_file)
def get_patch_addr(patch_info, ram_bin_file, output_dir_in):
    """Merge the generated tables into a copy of the RAM image and install it.

    Args:
        patch_info: path of the patch configuration file.
        ram_bin_file: path of the RAM image to patch.
        output_dir_in: directory that receives the patched image.

    The RAM image is copied to the intermediate RW bin file, the jump and CMP
    tables are merged into it, and the result is moved into *output_dir_in*.
    The function-name lists that used to be parsed here were never used, so
    the configuration is no longer parsed twice more for nothing.
    """
    file_all = get_patch_info(patch_info)
    copy_bin_file(get_dir(file_all['RW_Bin_File']), ram_bin_file)
    merge_output_file(file_all)
    output_bin_file(file_all, output_dir_in, ram_bin_file)
if __name__ == "__main__":
    # Exactly seven arguments are required:
    # <xx.bin> <xx_rom.bin> <xx.nm> <patch_config_dir> <core> <target_name> <output_dir>
    if len(sys.argv) != 8:
        print(
            "Usage: %s <xx.bin> <xx_rom.bin> <xx.nm> <patch_config_dir> <core> <target_name> "
            "<output_dir>" % os.path.basename(sys.argv[0]))
        sys.exit(1)

    ram_bin_file = sys.argv[1]
    rom_bin_file = sys.argv[2]
    nm_file = sys.argv[3]
    partch_config_dir = sys.argv[4]
    core = sys.argv[5]
    target_name = sys.argv[6]
    output_dir = sys.argv[7]

    # Prefer a target-specific configuration; fall back to the per-core one.
    patch_info = os.path.join(partch_config_dir, f'{target_name}.cfg')
    if not os.path.exists(patch_info):
        patch_info = os.path.join(partch_config_dir, f'{core}.cfg')

    create_patch(patch_info, nm_file, rom_bin_file)
    get_patch_addr(patch_info, ram_bin_file, output_dir)