diff --git a/solc_json_parser/ast_shared.py b/solc_json_parser/ast_shared.py
index ed096ab..8f46fb6 100644
--- a/solc_json_parser/ast_shared.py
+++ b/solc_json_parser/ast_shared.py
@@ -17,17 +17,87 @@
 # Installable version supported on linux-amd64
 INSTALLABLE_VERSION = [
-    "0.4.10", "0.4.11", "0.4.12", "0.4.13", "0.4.14", "0.4.15", "0.4.16", "0.4.17", "0.4.18", "0.4.19", "0.4.20", "0.4.21", "0.4.22", "0.4.23", "0.4.24", "0.4.25", "0.4.26",
-    "0.5.0", "0.5.1", "0.5.2", "0.5.3", "0.5.4", "0.5.5", "0.5.6", "0.5.7", "0.5.8", "0.5.9", "0.5.10", "0.5.11", "0.5.12", "0.5.13", "0.5.14", "0.5.15", "0.5.16", "0.5.17",
-    "0.6.0", "0.6.1", "0.6.2", "0.6.3", "0.6.4", "0.6.5", "0.6.6", "0.6.7", "0.6.8", "0.6.9", "0.6.10", "0.6.11", "0.6.12",
-    "0.7.0", "0.7.1", "0.7.2", "0.7.3", "0.7.4", "0.7.5", "0.7.6",
-    "0.8.0", "0.8.1", "0.8.2", "0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10", "0.8.11", "0.8.12", "0.8.13", "0.8.14", "0.8.15", "0.8.16", "0.8.17", "0.8.18", "0.8.19"
+    "0.4.10",
+    "0.4.11",
+    "0.4.12",
+    "0.4.13",
+    "0.4.14",
+    "0.4.15",
+    "0.4.16",
+    "0.4.17",
+    "0.4.18",
+    "0.4.19",
+    "0.4.20",
+    "0.4.21",
+    "0.4.22",
+    "0.4.23",
+    "0.4.24",
+    "0.4.25",
+    "0.4.26",
+    "0.5.0",
+    "0.5.1",
+    "0.5.2",
+    "0.5.3",
+    "0.5.4",
+    "0.5.5",
+    "0.5.6",
+    "0.5.7",
+    "0.5.8",
+    "0.5.9",
+    "0.5.10",
+    "0.5.11",
+    "0.5.12",
+    "0.5.13",
+    "0.5.14",
+    "0.5.15",
+    "0.5.16",
+    "0.5.17",
+    "0.6.0",
+    "0.6.1",
+    "0.6.2",
+    "0.6.3",
+    "0.6.4",
+    "0.6.5",
+    "0.6.6",
+    "0.6.7",
+    "0.6.8",
+    "0.6.9",
+    "0.6.10",
+    "0.6.11",
+    "0.6.12",
+    "0.7.0",
+    "0.7.1",
+    "0.7.2",
+    "0.7.3",
+    "0.7.4",
+    "0.7.5",
+    "0.7.6",
+    "0.8.0",
+    "0.8.1",
+    "0.8.2",
+    "0.8.3",
+    "0.8.4",
+    "0.8.5",
+    "0.8.6",
+    "0.8.7",
+    "0.8.8",
+    "0.8.9",
+    "0.8.10",
+    "0.8.11",
+    "0.8.12",
+    "0.8.13",
+    "0.8.14",
+    "0.8.15",
+    "0.8.16",
+    "0.8.17",
+    "0.8.18",
+    "0.8.19",
 ]
 INSTALLABLE_VERSION = sorted([Version(v) for v in INSTALLABLE_VERSION])

-INTERFACE_OR_LIB_KIND = set(['interface', 'library'])
+INTERFACE_OR_LIB_KIND = set(["interface", "library"])

 DEPLOY_START_OPCODES = [
     # For solidity 0.4.23 and above
@@ -48,21 +118,23 @@
     ],
 ]

+
 def keccak256(s: str) -> str:
     k = keccak.new(digest_bits=256)
     k.update(s.encode())
     return k.hexdigest()

+
 def get_by_index(lst: Union[List, Tuple], idx: int):
-    '''Get by index from a list, returns None if the index is out of range '''
+    """Get by index from a list, returns None if the index is out of range"""
     if len(lst) > idx:
         return lst[idx]
     return None


 def get_in(d, key: Any, *nkeys) -> Any:
-    '''Get in nested datastructure by keys. Only dictionary, tuple and
-    list are supported'''
+    """Get in nested datastructure by keys. Only dictionary, tuple and
+    list are supported"""
     try:
         nd = d.get(key)
     except Exception:
         return None
@@ -74,6 +146,7 @@ def get_in(d, key: Any, *nkeys) -> Any:
         return get_in(nd, *nkeys)
     return nd

+
 def assoc_in(d, keys, value):
     """Associates a value with a sequence of keys in a nested dictionary"""
     key = keys[0]
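The two helpers above drive most nested lookups in the parser. A minimal sketch of their behaviour (sample data invented, and assuming `assoc_in` creates missing intermediate dicts, as its use in `override_settings` below suggests):

```python
from solc_json_parser.ast_shared import get_in, assoc_in

d = {"contracts": {"A.sol": {"A": {"evm": {"bytecode": "0x"}}}}}

# get_in walks nested dicts and returns None instead of raising
assert get_in(d, "contracts", "A.sol", "A", "evm", "bytecode") == "0x"
assert get_in(d, "contracts", "missing.sol", "A") is None

# assoc_in writes through nested keys (assumption: missing levels are created)
assoc_in(d, ["settings", "optimizer", "enabled"], False)
assert d["settings"]["optimizer"]["enabled"] is False
```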
@@ -87,36 +160,40 @@ def assoc_in(d, keys, value):

 def get_all_installable_versions():
-    '''
+    """
     Returns a cached list of solc versions available for install,
     version list is sorted in ascending order
-    '''
+    """
     return INSTALLABLE_VERSION

+
 def version_str_from_line(line) -> Optional[str]:
-    '''
+    """
     Extract solc version string from input line
-    '''
-    if line.strip().startswith('pragma') and 'solidity' in line:
-        ver = line.strip().split(maxsplit=2)[-1].split(';', maxsplit=1)[0]
-        if 'solidity' in ver:
-            ver = ver.split('solidity', maxsplit=1)[-1]
-        ver = re.sub(r'([\^>=<~]+)\s+', r'\1', ver)
-        return re.sub(r'(\.0+)', '.0', ver)
+    """
+    if line.strip().startswith("pragma") and "solidity" in line:
+        ver = line.strip().split(maxsplit=2)[-1].split(";", maxsplit=1)[0]
+        if "solidity" in ver:
+            ver = ver.split("solidity", maxsplit=1)[-1]
+        ver = re.sub(r"([\^>=<~]+)\s+", r"\1", ver)
+        return re.sub(r"(\.0+)", ".0", ver)
     return None


 def version_str_from_source(source_or_source_file: str) -> Optional[str]:
-    inputs = source_or_source_file.split('\n') if '\n' in source_or_source_file else open(source_or_source_file, 'r')
+    inputs = source_or_source_file.split("\n") if "\n" in source_or_source_file else open(source_or_source_file, "r")
     # Get version part from `pragma solidity ***;` lines
-    versions = [version_str_from_line(line) for line in inputs if line.strip().startswith('pragma') and 'solidity' in line]
+    versions = [
+        version_str_from_line(line) for line in inputs if line.strip().startswith("pragma") and "solidity" in line
+    ]
     if not versions:
-        logging.warning('No pragma directive found in source code')
+        logging.warning("No pragma directive found in source code")
         return None
-    return ' '.join(set(versions))
+    return " ".join(set(versions))
+

 def get_solc_candidates(source_or_source_file: str) -> List[str]:
     merged_version = version_str_from_source(source_or_source_file)
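Expected normalization behaviour of `version_str_from_line`, derived from the regexes above (examples invented):

```python
from solc_json_parser.ast_shared import version_str_from_line

assert version_str_from_line("pragma solidity ^0.8.0;") == "^0.8.0"
assert version_str_from_line("pragma solidity >= 0.7.0;") == ">=0.7.0"  # space after the operator is stripped
assert version_str_from_line("contract A {}") is None
```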
@@ -127,29 +204,30 @@ def get_solc_candidates(source_or_source_file: str) -> List[str]:
     spec = semantic_version.NpmSpec(merged_version)
     return [str(v) for v in spec.filter(get_all_installable_versions())]

+
 def detect_solc_version(source_or_source_file: str) -> Optional[str]:
-    '''
+    """
     Detect solc version from a flattened source. Input can be a single file
     or a source code string
-    '''
+    """
     versions = get_solc_candidates(source_or_source_file)
     return versions[-1] if versions else None


 def symbols_to_ids_from_ast_v8(ast: dict) -> Dict[str, int]:
-    syms = [c['ast']['exportedSymbols'] for c in ast.values()]
+    syms = [c["ast"]["exportedSymbols"] for c in ast.values()]
     return {k: v[0] for m in syms for k, v in m.items()}


 def symbols_to_ids_from_ast_v7(ast: Dict[Any, Any]) -> Dict[str, int]:
-    syms = [c['ast']['attributes']['exportedSymbols'] for c in ast.values()]
+    syms = [c["ast"]["attributes"]["exportedSymbols"] for c in ast.values()]
     return {k: v[0] for m in syms for k, v in m.items()}


 def find_next_version_in_candidates(current_version: str, solc_candidates: List[str]) -> Tuple[str, List[str]]:
     """Try to get the next version"""
     ver = Version(current_version)
-    try_next_version = Version(major=ver.major, minor= ver.minor - 1, patch=0)
-    print(f'try_next_version: {try_next_version} solc_candidates: {solc_candidates}')
+    try_next_version = Version(major=ver.major, minor=ver.minor - 1, patch=0)
+    print(f"try_next_version: {try_next_version} solc_candidates: {solc_candidates}")
     version = None
     # print(f'try_next_version: {try_next_version} solc_candidates: {solc_candidates}')
     if str(try_next_version) in solc_candidates:
@@ -160,12 +238,13 @@ def find_next_version_in_candidates(current_version: str, solc_candidates: List[
         version = str(solc_candidates[-1])
         solc_candidates = solc_candidates[:-1]
     if not version:
-        raise ValueError(f'No next solc version available for {current_version}')
+        raise ValueError(f"No next solc version available for {current_version}")
     return version, solc_candidates

+
 def skip_deploys(opcodes, deploy_sig_idx=0):
     if deploy_sig_idx >= len(DEPLOY_START_OPCODES):
-        raise SolidityAstError(f'Code deploy sequence not found in opcodes: {opcodes}')
+        raise SolidityAstError(f"Code deploy sequence not found in opcodes: {opcodes}")
     offset = 1
     match_idx = 0
     deploy_start_sequence = DEPLOY_START_OPCODES[deploy_sig_idx]
@@ -180,8 +259,8 @@ def skip_deploys(opcodes, deploy_sig_idx=0):
         offset += 1

     if offset < len(opcodes):
-        return opcodes[offset - len(deploy_start_sequence) + 1:]
-    return skip_deploys(opcodes, deploy_sig_idx+1)
+        return opcodes[offset - len(deploy_start_sequence) + 1 :]
+    return skip_deploys(opcodes, deploy_sig_idx + 1)


 def parse_src_mapping(srcmap: str):
@@ -189,23 +268,23 @@
     def _reduce_fn(accumulator, current_value):
         last, *tlist = accumulator
         return [
             {
-                's': int(current_value['s'] or last['s']),
-                'l': int(current_value['l'] or last['l']),
-                'f': int(current_value['f'] or last['f']),
+                "s": int(current_value["s"] or last["s"]),
+                "l": int(current_value["l"] or last["l"]),
+                "f": int(current_value["f"] or last["f"]),
             },
             last,
-            *tlist
+            *tlist,
         ]

     parsed = srcmap.split(";")
-    parsed = [l.split(':') for l in parsed]
+    parsed = [l.split(":") for l in parsed]
     t = []
     for l in parsed:
         if len(l) >= 3:
             t.append(l[:3])
         else:
             t.append(l + [None] * (3 - len(l)))
-    parsed = [{'s': s if s != "" else None, 'l': l, 'f': f} for s, l, f in t]
+    parsed = [{"s": s if s != "" else None, "l": l, "f": f} for s, l, f in t]
     parsed = reduce(_reduce_fn, parsed, [{}])
     parsed = list(reversed(parsed[:-1]))
     return parsed
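`parse_src_mapping` decompresses solc's `s:l:f` source-map encoding, where an empty field inherits its value from the previous entry. A small worked example (input invented):

```python
from solc_json_parser.ast_shared import parse_src_mapping

parse_src_mapping("0:100:0;;10:5:")
# -> [{'s': 0, 'l': 100, 'f': 0},   # first entry, fully specified
#     {'s': 0, 'l': 100, 'f': 0},   # empty entry inherits everything
#     {'s': 10, 'l': 5, 'f': 0}]    # trailing empty field inherits f
```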
literals["other"].add(literal_node) except AttributeError: pass literals = dict(number=set(), string=set(), address=set(), other=set()) for literal in literals_nodes: try: - if literal.sub_type is None and literal.token_type == 'number': + if literal.sub_type is None and literal.token_type == "number": if only_value and literal.str_value.isdecimal(): - literals['number'].add(int(literal.str_value)) + literals["number"].add(int(literal.str_value)) else: - literals['number'].add(literal) + literals["number"].add(literal) elif literal.sub_type.startswith("address"): if only_value: - literals['address'].add(literal.str_value) + literals["address"].add(literal.str_value) else: - literals['address'].add(literal) + literals["address"].add(literal) elif literal.sub_type.startswith("int"): if only_value: - if literal.str_value.startswith('0x'): - literals['number'].add(int(literal.str_value, 16)) + if literal.str_value.startswith("0x"): + literals["number"].add(int(literal.str_value, 16)) elif literal.sub_type.split()[1].isdecimal(): - literals['number'].add(int(literal.sub_type.split()[1])) + literals["number"].add(int(literal.sub_type.split()[1])) else: - literals['number'].add(int(literal.str_value)) + literals["number"].add(int(literal.str_value)) else: - literals['number'].add(literal) + literals["number"].add(literal) # check if string in token_type, ignore case elif literal.sub_type.startswith("literal_string"): if only_value: - literals['string'].add(literal.str_value) + literals["string"].add(literal.str_value) else: - literals['string'].add(literal) + literals["string"].add(literal) elif literal.sub_type.startswith("bool"): continue else: @@ -260,29 +339,34 @@ def _process_other_literal_node(literal_node, literals, only_value): return literals -def record_jumps(opcode: str, code: list[Dict[str, Any]], idx: int, pc: int, seen_targets: set[int], pc2opcode: Dict[int, str]) -> set[int]: +def record_jumps( + opcode: str, code: list[Dict[str, Any]], idx: int, pc: int, seen_targets: set[int], pc2opcode: Dict[int, str] +) -> set[int]: pc2opcode[pc] = opcode - if opcode == 'JUMPI': - seen_targets.add(int(code[idx-1].get('value'))) + if opcode == "JUMPI": + seen_targets.add(int(code[idx - 1].get("value"))) seen_targets.add(int(pc + 1)) return seen_targets + def solc_bin(ver: str): - ''' + """ Get solc bin full path by version. By default it checks the solcx installion path. 
@@ -260,29 +339,34 @@ def _process_other_literal_node(literal_node, literals, only_value):
     return literals


-def record_jumps(opcode: str, code: list[Dict[str, Any]], idx: int, pc: int, seen_targets: set[int], pc2opcode: Dict[int, str]) -> set[int]:
+def record_jumps(
+    opcode: str, code: list[Dict[str, Any]], idx: int, pc: int, seen_targets: set[int], pc2opcode: Dict[int, str]
+) -> set[int]:
     pc2opcode[pc] = opcode
-    if opcode == 'JUMPI':
-        seen_targets.add(int(code[idx-1].get('value')))
+    if opcode == "JUMPI":
+        seen_targets.add(int(code[idx - 1].get("value")))
         seen_targets.add(int(pc + 1))
     return seen_targets

+
 def solc_bin(ver: str):
-    '''
+    """
     Get solc bin full path by version. By default it checks the solcx
     installation path. You can also override this function to use solc from
     https://github.com/ethereum/solc-bin/tree/gh-pages/linux-amd64
-    '''
-    return os.path.expanduser(f'~/.solcx/solc-v{ver}')
+    """
+    return os.path.expanduser(f"~/.solcx/solc-v{ver}")
+
+
+version_pattern = r"v(\d+\.\d+\.\d+)"

-version_pattern = r'v(\d+\.\d+\.\d+)'
+
 def simplify_version(s):
-    '''
+    """
     Convert a version with sha to a simple version
     Example: v0.8.13+commit.abaa5c0e -> 0.8.13
-    '''
-    match = re.search(version_pattern, s or '')
+    """
+    match = re.search(version_pattern, s or "")
     if match:
         extracted_version = match.group(1)
         return extracted_version
diff --git a/solc_json_parser/standard_json_parser.py b/solc_json_parser/standard_json_parser.py
index 3709993..8b4dea4 100644
--- a/solc_json_parser/standard_json_parser.py
+++ b/solc_json_parser/standard_json_parser.py
@@ -1,6 +1,7 @@
 import subprocess
 import json
 import os
+import re
 from typing import Tuple, Callable, List, Union, Optional, Dict
 from functools import cached_property, cache

@@ -12,56 +13,130 @@ from .fields import Function
 import sys

+
+def _preprocess_etherscan_json(etherscan_json_path: str) -> dict:
+    """
+    Preprocess an etherscan JSON file and return a standard JSON dict.
+
+    Args:
+        etherscan_json_path: Path to the etherscan JSON file
+
+    Returns:
+        dict: Standard JSON input dict with optimizer settings
+    """
+    with open(etherscan_json_path, "r", encoding="utf-8") as f:
+        etherscan_data = json.load(f)
+
+    source_code = etherscan_data["SourceCode"]
+    OUTPUT_SELECT_ALL = {"*": {"*": ["*"], "": ["ast"]}}
+    optimizer_enabled = etherscan_data.get("OptimizationUsed", "0") == "1"
+    optimize_runs = int(etherscan_data.get("Runs", 200))
+
+    # Handle different etherscan formats
+    try:
+        # Try parsing as standard JSON (with double braces {{ }})
+        if source_code.startswith("{{") and source_code.endswith("}}"):
+            dict_source_code = json.loads(source_code[1:-1])
+        else:
+            # Try parsing as standard JSON (with single braces { })
+            dict_source_code = json.loads(source_code)
+
+        # If it's a multi-file project in standard JSON format
+        if isinstance(dict_source_code, dict) and "sources" in dict_source_code:
+            # Create standard JSON input for the multi-file project
+            solc_input = dict_source_code.copy()
+            solc_input["settings"] = {"outputSelection": OUTPUT_SELECT_ALL}
+            solc_input["settings"]["optimizer"] = {"enabled": optimizer_enabled, "runs": optimize_runs}
+            return solc_input
+        else:
+            # Single source code parsed as JSON, but not in standard format.
+            # Wrap it in a standard JSON structure
+            contract_name = etherscan_data.get("ContractName", "Contract")
+            filename = f"{contract_name}.sol"
+            solc_input = {
+                "language": "Solidity",
+                "sources": {filename: {"content": source_code}},
+                "settings": {
+                    "outputSelection": OUTPUT_SELECT_ALL,
+                    "optimizer": {"enabled": optimizer_enabled, "runs": optimize_runs},
+                },
+            }
+            return solc_input
+
+    except json.JSONDecodeError:
+        # If JSON parsing fails, treat the input as a single source file
+        contract_name = etherscan_data.get("ContractName", "Contract")
+        filename = f"{contract_name}.sol"
+        solc_input = {
+            "language": "Solidity",
+            "sources": {filename: {"content": source_code}},
+            "settings": {
+                "outputSelection": OUTPUT_SELECT_ALL,
+                "optimizer": {"enabled": optimizer_enabled, "runs": optimize_runs},
+            },
+        }
+        return solc_input
+
+
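A hypothetical use of the new helper; `token.json` stands for a file saved from etherscan's `getsourcecode` API:

```python
from solc_json_parser.standard_json_parser import _preprocess_etherscan_json

solc_input = _preprocess_etherscan_json("token.json")  # hypothetical path
assert solc_input["settings"]["outputSelection"] == {"*": {"*": ["*"], "": ["ast"]}}
assert "optimizer" in solc_input["settings"]
```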
 def node_contains(src_str: str, pc_source: dict) -> bool:
     """
     Check if the source code contains the given pc_source
     """
     if not src_str:
         return False
-    offset, length, _fidx = list(map(int, src_str.split(':')))
-    return offset <= pc_source['begin'] and offset + length >= pc_source['end']
+    offset, length, _fidx = list(map(int, src_str.split(":")))
+    return offset <= pc_source["begin"] and offset + length >= pc_source["end"]
+

-def compile_standard(version: str, input_json: dict, solc_bin_resolver: Callable[[str], str] = solc_bin, cwd: Optional[str]=None):
-    '''
+def compile_standard(
+    version: str, input_json: dict, solc_bin_resolver: Callable[[str], str] = solc_bin, cwd: Optional[str] = None
+):
+    """
     Compile standard input json and parse output as json.

     Parameters:

     version: solc version. Example: 0.8.13

     input_json: standard json input

     solc_bin_resolver: a function that takes a solc version string and returns the full path to the solc executable
-    '''
-    print(f'Compiling with solc version: {version}')
+    """
+    print(f"Compiling with solc version: {version}")
     solc = solc_bin_resolver(version)

     if not os.path.exists(solc):
-        raise Exception(f'solc not found at: {solc}, please download all solc binaries first or provide your `solc_bin_resolver` function')
-
+        raise Exception(
+            f"solc not found at: {solc}, please download all solc binaries first or provide your `solc_bin_resolver` function"
+        )

     solc_output = subprocess.check_output(
-        [solc, "--standard-json",],
+        [
+            solc,
+            "--standard-json",
+        ],
         input=json.dumps(input_json),
         text=True,
         stderr=subprocess.PIPE,
-        cwd=cwd
+        cwd=cwd,
     )

     return json.loads(solc_output)

+
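A minimal end-to-end sketch of `compile_standard` (contract and version invented; the matching solc binary must already exist under `~/.solcx`):

```python
from solc_json_parser.standard_json_parser import compile_standard

input_json = {
    "language": "Solidity",
    "sources": {"a.sol": {"content": "contract A {}"}},
    "settings": {"outputSelection": {"*": {"*": ["*"], "": ["ast"]}}},
}
out = compile_standard("0.8.19", input_json)
evm = out["contracts"]["a.sol"]["A"]["evm"]  # bytecode, deployedBytecode, legacyAssembly, ...
```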
 def build_pc2idx(evm: dict, deploy: bool = False) -> Tuple[list, dict, dict]:
-    '''
+    """
     Build pc2idx map from one evm dictionary. If deploy is True, build it using deployment code.

     Returns a tuple: (code, pc2idx, pc2opcode)
-    '''
-    evm_key = 'bytecode' if deploy else 'deployedBytecode'
+    """
+    evm_key = "bytecode" if deploy else "deployedBytecode"
     # opcodes list (including operand datasize information for the opcode)
     # Example path in standard json: '.contracts."FILE_PATH.SOL"."CONTRACT_NAME".evm.deployedBytecode.opcodes'
-    opcodes = evm[evm_key]['opcodes'].split()
+    opcodes = evm[evm_key]["opcodes"].split()
+
     # source code mapping blocks
     # Example path in standard json: '.contracts."FILE_PATH.SOL"."CONTRACT_NAME".evm.legacyAssembly.".data"."0".".code"'
-    code = evm['legacyAssembly']['.code'] if deploy else evm['legacyAssembly']['.data']['0']['.code']
+    code = evm["legacyAssembly"][".code"] if deploy else evm["legacyAssembly"][".data"]["0"][".code"]

     offset = 0  # program counter: byte offset
-    idx = 0 # index of source code mapping blocks
-    idx2pc = {} # dict: index -> pc
+    idx = 0  # index of source code mapping blocks
+    idx2pc = {}  # dict: index -> pc
     op_idx = 0  # idx value in contract opcodes list

     i = 0
@@ -73,28 +148,27 @@ def build_pc2idx(evm: dict, deploy: bool = False) -> Tuple[list, dict, dict]:
         size = 2  # opcode size: one byte as hex takes two chars
         datasize = 0
-        opcode = c.get('name').split()[0]
+        opcode = c.get("name").split()[0]
         pc2opcode[offset] = opcode

-
-        if opcode == 'PUSHDEPLOYADDRESS':
+        if opcode == "PUSHDEPLOYADDRESS":
             i += 2
             continue

-        if (not opcode.isupper()):
+        if not opcode.isupper():
             idx += 1
             continue

-        if opcode.startswith('PUSH'):
+        if opcode.startswith("PUSH"):
             op = opcodes[op_idx]
             try:
                 datasize = int(op[4:]) * 2 if len(op) > 4 else 2
             except Exception as e:
-                print(f'error: {e}')
+                print(f"error: {e}")
                 continue
-            op_idx += 1
+            if datasize != 0:
+                op_idx += 1
             size += datasize

-        # print(f'PC {offset:4} IDX: {idx:4} datasize: {datasize:2} {c}')
         idx += 1
         offset += int(size / 2)
         op_idx += 1
@@ -102,26 +176,32 @@
     pc2idx = {v: k for k, v in idx2pc.items()}
     return code, pc2idx, pc2opcode

+
 def source_content_by_file_key(input_json: dict, filename: str):
-    '''
+    """
     Get source code content by unique filename
-    '''
-    return s.get_in(input_json, 'sources', filename, 'content')
+    """
+    return s.get_in(input_json, "sources", filename, "content")
+

 def filename_by_fid(output_json: dict, fid: int) -> str:
-    for k, source in output_json['sources'].items():
-        if fid == source['id']:
+    for k, source in output_json["sources"].items():
+        if fid == source["id"]:
             filename = k
             break
     return filename

+
 def source_content_by_fid(input_json: dict, output_json: dict, fid: int):
     filename = filename_by_fid(output_json, fid)
     return source_content_by_file_key(input_json, filename)

-def source_by_pc(code, pc2idx, input_json: dict, output_json: dict, pc: int, resolve_yul_block: Optional[Callable]=None):
+
+def source_by_pc(
+    code, pc2idx, input_json: dict, output_json: dict, pc: int, resolve_yul_block: Optional[Callable] = None
+):
     # code, pc2idx, *_ = build_pc2idx(evm, deploy)
     code_len = len(code)
@@ -129,7 +209,7 @@ def source_by_pc(code, pc2idx, input_json: dict, output_json: dict, pc: int, res
     for k in range(pc, -1, -1):
         idx = pc2idx.get(k, None)
         if idx is not None:
-            if idx >= code_len: # code index is outside code list
+            if idx >= code_len:  # code index is outside code list
                 continue
             block = code[idx]
             break
@@ -137,54 +217,63 @@ def source_by_pc(code, pc2idx, input_json: dict, output_json: dict, pc: int, res
     if block is None:
         return None

-    fid = block.get('source', 0) # some times there is no `source` field.
-    begin = block.get('begin')
-    end = block.get('end')
+    fid = block.get("source", 0)  # sometimes there is no `source` field.
+    begin = block.get("begin")
+    end = block.get("end")
     # name = block.get('name')

     file_key = None

-    for k, source in output_json['sources'].items():
-        if fid == source['id']:
+    for k, source in output_json["sources"].items():
+        if fid == source["id"]:
             file_key = k
             break

     if not file_key and resolve_yul_block is not None:
         r = resolve_yul_block(block)
         if r:
-            r['pc'] = pc
+            r["pc"] = pc
             return r
         return None

     content = source_content_by_file_key(input_json, file_key)
     highlight = content.encode()[begin:end].decode()
-    line_start = content.encode()[:begin].decode().count('\n') + 1
-    line_end = content.encode()[:end].decode().count('\n') + 1
-    return dict(pc=pc, linenums = [line_start, line_end], fragment=highlight, fid=file_key, begin=begin, end=end, source_idx = fid, source_path = file_key)
+    line_start = content.encode()[:begin].decode().count("\n") + 1
+    line_end = content.encode()[:end].decode().count("\n") + 1
+    return dict(
+        pc=pc,
+        linenums=[line_start, line_end],
+        fragment=highlight,
+        fid=file_key,
+        begin=begin,
+        end=end,
+        source_idx=fid,
+        source_path=file_key,
+    )
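The dict returned above has this shape (all values invented for illustration):

```python
example = dict(
    pc=1234,                        # program counter queried
    linenums=[12, 14],              # 1-indexed start/end lines
    fragment="a = b + 1;",          # source text the instruction maps to
    fid="contracts/A.sol",          # file key (duplicated as source_path)
    begin=301,
    end=352,                        # byte offsets into the file
    source_idx=0,                   # numeric source id assigned by solc
    source_path="contracts/A.sol",
)
```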


 def evms_by_contract_name(output_json: dict, contract_name: str) -> List[Tuple[str, dict]]:
-    '''
+    """
     Get evm json by contract name, returns a list of dicts. Each dict is an evm json.
     A list is returned because there may be multiple contracts with the same name.
-    '''
+    """
     result = []
-    for filename, v in output_json['contracts'].items():
+    for filename, v in output_json["contracts"].items():
         for name, c in v.items():
             if name == contract_name:
-                result.append((filename, c.get('evm')))
+                result.append((filename, c.get("evm")))
     return result


 def has_compilation_error(output_json: dict) -> bool:
-    errors_t = {t.get('type') for t in output_json.get('errors', [])}
+    errors_t = {t.get("type") for t in output_json.get("errors", [])}
     for e in errors_t:
-        if 'Error' in e:
+        if "Error" in e:
             return True
     return False


-def override_settings(input_json):
+def override_settings(input_json, etherscan: bool = False):
     """
     Override settings:

     - Disable optimization which could confuse source mapping

@@ -192,31 +281,62 @@
     https://docs.soliditylang.org/en/latest/using-the-compiler.html#input-description
     """
-    s.assoc_in(input_json, ['settings', 'optimizer', 'enabled'], False)
-    s.assoc_in(input_json, ['settings', 'outputSelection'], {'*': {'*': [ '*' ], '': ['ast']}})
-    s.assoc_in(input_json, ['settings', 'metadata'], {'bytecodeHash': 'none'}) # equiv. of solc --metadata=none
+    if not etherscan:
+        s.assoc_in(input_json, ["settings", "optimizer", "enabled"], False)
+    s.assoc_in(input_json, ["settings", "outputSelection"], {"*": {"*": ["*"], "": ["ast"]}})
+    if "metadata" in input_json.get("settings", {}):
+        s.assoc_in(input_json, ["settings", "metadata"], {"bytecodeHash": "none"})  # equiv. of solc --metadata=none

-    input_json['language']= input_json.get('language', 'Solidity')
+    input_json["language"] = input_json.get("language", "Solidity")
     return input_json
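The effect of `override_settings` on a bare input in the non-etherscan path (assuming `s.assoc_in` creates the missing `settings` levels):

```python
from solc_json_parser.standard_json_parser import override_settings

cfg = override_settings({"sources": {"a.sol": {"content": "contract A {}"}}})
assert cfg["settings"]["optimizer"]["enabled"] is False
assert cfg["settings"]["outputSelection"] == {"*": {"*": ["*"], "": ["ast"]}}
assert cfg["language"] == "Solidity"
```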


 class StandardJsonParser(BaseParser):
-    def __init__(self, input_json: Union[dict, str], version: str, solc_bin_resolver: Callable[[str], str] = solc_bin, cwd: Optional[str] = None,
-                 retry_num: Optional[int]=0,
-                 try_install_solc: Optional[bool]=False,
-                 solc_options: Optional[Dict] = {}):
+    def __init__(
+        self,
+        input_json: Union[dict, str],
+        version: str = None,
+        solc_bin_resolver: Callable[[str], str] = solc_bin,
+        cwd: Optional[str] = None,
+        retry_num: Optional[int] = 0,
+        try_install_solc: Optional[bool] = False,
+        solc_options: Optional[Dict] = {},
+        etherscan: bool = False,
+    ):
         if retry_num is not None and retry_num > 0:
-            raise Exception('StandardJsonParser does not support retry')
+            raise Exception("StandardJsonParser does not support retry")
         if try_install_solc:
-            print('StandardJsonParser does not support try_install_solc, option will be ignored', file=sys.stderr)
+            print("StandardJsonParser does not support try_install_solc, option will be ignored", file=sys.stderr)
         if solc_options:
-            print('StandardJsonParser does not support solc_options, please set extra parameters to input_json instead', file=sys.stderr)
+            print(
+                "StandardJsonParser does not support solc_options, please set extra parameters to input_json instead",
+                file=sys.stderr,
+            )

         super().__init__()
         self.file_path = None
-        self.solc_version: str = version
+
+        # Handle etherscan JSON preprocessing
+        if etherscan and isinstance(input_json, str) and not input_json.startswith("{"):
+            # input_json is a file path to an etherscan JSON
+            with open(input_json, "r", encoding="utf-8") as f:
+                etherscan_data = json.load(f)
+
+            # Extract the version from the etherscan data if not provided
+            if version is None:
+                version = etherscan_data.get("CompilerVersion", "0.8.0")
+                # Clean up the version string (remove commit info)
+                version_match = re.search(r"(\d+\.\d+\.\d+)", version)
+                if version_match:
+                    version = version_match.group(1)
+
+            # Preprocess the etherscan JSON
+            input_json = _preprocess_etherscan_json(input_json)
+
+        self.solc_version: str = version or "0.8.0"
+
         try:
             # try parse as json
             self.input_json: dict = input_json if isinstance(input_json, dict) else json.loads(input_json)
@@ -224,12 +344,11 @@ def __init__(self, input_json: Union[dict, str], version: str, solc_bin_resolver
             # try use input as a plain source file
             self.input_json = StandardJsonParser.__prepare_standard_input(input_json)

-
-        self.input_json = override_settings(self.input_json)
+        self.input_json = override_settings(self.input_json, etherscan)

         # https://soliditylang.org/blog/2023/02/01/solidity-0.8.18-release-announcement
-        support_cbor = Version(version) >= Version('0.8.18')
+        support_cbor = Version(version) >= Version("0.8.18")
         if support_cbor:
-            s.assoc_in(self.input_json, ['settings', 'metadata', 'appendCBOR'], False)
+            s.assoc_in(self.input_json, ["settings", "metadata", "appendCBOR"], False)

         self.solc_json_ast: Dict[int, dict] = {}
         self.is_standard_json = True
@@ -239,44 +358,33 @@ def __init__(self, input_json: Union[dict, str], version: str, solc_bin_resolver
         self.output_json = compile_standard(version, self.input_json, solc_bin_resolver, cwd)

         if has_compilation_error(self.output_json):
-            raise SolidityAstError(f"Compile failed: {self.output_json.get('errors')}" )
+            raise SolidityAstError(f"Compile failed: {self.output_json.get('errors')}")

         self.post_configure_compatible_fields()

     @staticmethod
     def __prepare_standard_input(source: str) -> Dict:
-        if '\n' not in source:
-            with open(source, 'r') as f:
+        if "\n" not in source:
+            with open(source, "r") as f:
                 source = f.read()

         input_json = {
-            'language': 'Solidity',
-            'sources': {
-                'source.sol': {
-                    'content': source
-                }
-            },
-            'settings': {
-                'optimizer': {
-                    'enabled': False,
+            "language": "Solidity",
+            "sources": {"source.sol": {"content": source}},
+            "settings": {
+                "optimizer": {
+                    "enabled": False,
                 },
                 # 'evmVersion': 'istanbul',
-                'outputSelection': {
-                    '*': {
-                        '*': [ '*' ],
-                        '': ['ast']
-                    }
-                }
-            }
+                "outputSelection": {"*": {"*": ["*"], "": ["ast"]}},
+            },
         }

         return input_json

-
     def prepare_by_version(self):
         super().prepare_by_version()
         # NOTE the whole v_keys seems unnecessary when using standard json input; all formats follow the v8 version of combined json outputs
-        self.keys = v_keys['v8']
-
+        self.keys = v_keys["v8"]

     def pre_configure_compatible_fields(self):
         """
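The constructor now accepts three kinds of input; a sketch with invented file names:

```python
from solc_json_parser.standard_json_parser import StandardJsonParser

# plain source, wrapped internally as source.sol
p1 = StandardJsonParser("contract A {}", version="0.8.19")
# a standard JSON dict (settings are overridden internally)
p2 = StandardJsonParser({"language": "Solidity", "sources": {"a.sol": {"content": "contract A {}"}}}, version="0.8.19")
# an etherscan dump file; version extracted from its CompilerVersion field
p3 = StandardJsonParser("token_etherscan.json", etherscan=True)
```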
node.get("src").split(":"))) line_number_range, _ = self.get_line_number_range_and_source(line_number_range_raw) - contract_id = node.get('id') + contract_id = node.get("id") # assert node.get('name') is not None # assert node.get('abstract') is not None # assert node.get('baseContracts') is not None - contract_kind = node.get('contractKind') + contract_kind = node.get("contractKind") - is_abstract = node.get('abstract') + is_abstract = node.get("abstract") - if node.get('baseContracts') is not None: - base_contracts = self._get_base_contracts(node.get('baseContracts')) + if node.get("baseContracts") is not None: + base_contracts = self._get_base_contracts(node.get("baseContracts")) else: - base_contracts = node.get('contractDependencies') - contract_name = node.get('name') + base_contracts = node.get("contractDependencies") + contract_name = node.get("name") return contract_id, contract_kind, is_abstract, contract_name, base_contracts, line_number_range - @cached_property def exported_symbols(self) -> Dict[str, int]: return s.symbols_to_ids_from_ast_v8(self.solc_json_ast) - def post_configure_compatible_fields(self): """ Configure the fields to maintain backward compatibility with the CombinedJsonParser, called after compilation @@ -344,23 +448,21 @@ def source_by_yul_block(self, block: Dict): """ Get source code by Yul block """ - fid = block.get('source') - begin = block.get('begin') - end = block.get('end') - pred = lambda node: node and node.get('language') == 'Yul' and node.get('id') == fid + fid = block.get("source") + begin = block.get("begin") + end = block.get("end") + pred = lambda node: node and node.get("language") == "Yul" and node.get("id") == fid # this does not consider deployment code or not, might be a bug - yul_source = self.extract_node(pred, self.output_json['contracts'], first_only=True)[0] + yul_source = self.extract_node(pred, self.output_json["contracts"], first_only=True)[0] if not yul_source: return None - source_as_bytes = yul_source['contents'].encode() + source_as_bytes = yul_source["contents"].encode() fragment = source_as_bytes[begin:end].decode() - linenums = (source_as_bytes[:begin].decode().count('\n') + 1, - source_as_bytes[:end].decode().count('\n') + 1) - - return dict(fragment=fragment, begin=begin, end=end, linenums=linenums, fid=fid, source_path=yul_source['name']) + linenums = (source_as_bytes[:begin].decode().count("\n") + 1, source_as_bytes[:end].decode().count("\n") + 1) + return dict(fragment=fragment, begin=begin, end=end, linenums=linenums, fid=fid, source_path=yul_source["name"]) def source_by_pc(self, contract_name: str, pc: int, deploy=False) -> Optional[dict]: """ @@ -372,7 +474,10 @@ def source_by_pc(self, contract_name: str, pc: int, deploy=False) -> Optional[di evms = evms_by_contract_name(self.output_json, contract_name) for _, evm in evms: code, pc2idx, *_ = self.__build_pc2idx(evm, deploy) - result = source_by_pc(code, pc2idx, self.input_json, self.output_json, pc, resolve_yul_block=self.source_by_yul_block) + + result = source_by_pc( + code, pc2idx, self.input_json, self.output_json, pc, resolve_yul_block=self.source_by_yul_block + ) if result: return result return None @@ -404,21 +509,27 @@ def extract_node(self, pred: Callable, root_node: List[Dict], first_only=True) - return found - def ast_units_by_pc(self, contract_name: str, pc: int, node_type: Optional[str], deploy=False, first_only=False) -> List[Dict]: + def ast_units_by_pc( + self, contract_name: str, pc: int, node_type: Optional[str], deploy=False, first_only=False + ) 
@@ -404,21 +509,27 @@ def extract_node(self, pred: Callable, root_node: List[Dict], first_only=True) -

         return found

-    def ast_units_by_pc(self, contract_name: str, pc: int, node_type: Optional[str], deploy=False, first_only=False) -> List[Dict]:
+    def ast_units_by_pc(
+        self, contract_name: str, pc: int, node_type: Optional[str], deploy=False, first_only=False
+    ) -> List[Dict]:
         """
         Get all AST units by PC
         """
         pc_source = self.source_by_pc(contract_name, pc, deploy)
         if not pc_source:
             return []
-        pred = lambda node: node and (node_type is None or node.get('nodeType') == node_type) and node_contains(node.get('src'), pc_source)
-        return self.extract_node(pred, self.output_json['sources'][pc_source['fid']]['ast'], first_only=first_only)
+        pred = (
+            lambda node: node
+            and (node_type is None or node.get("nodeType") == node_type)
+            and node_contains(node.get("src"), pc_source)
+        )
+        return self.extract_node(pred, self.output_json["sources"][pc_source["fid"]]["ast"], first_only=first_only)

     def function_unit_by_pc(self, contract_name: str, pc: int, deploy=False) -> Optional[Dict]:
         """
         Get the function AST unit containing the PC
         """
-        units = self.ast_units_by_pc(contract_name, pc, 'FunctionDefinition', deploy, first_only=True)
+        units = self.ast_units_by_pc(contract_name, pc, "FunctionDefinition", deploy, first_only=True)
         return units[0] if units else None

     def ast_unit_by_pc(self, contract_name: str, pc: int, deploy=False) -> Optional[Dict]:
@@ -428,7 +539,6 @@ def ast_unit_by_pc(self, contract_name: str, pc: int, deploy=False) -> Optional[
         units = self.ast_units_by_pc(contract_name, pc, node_type=None, deploy=deploy, first_only=False)
         return units[-1] if units else None

-
     def all_pcs(self, contract: str, deploy: Optional[bool] = False) -> List[int]:
         """
         Returns a list of PCs inside the contract
@@ -444,7 +554,7 @@ def __build_pc2idx(self, evm: dict, deploy: bool = False) -> Tuple[list, dict, d
     @cache
     def pc2opcode_by_contract(self, contract_name: str, deploy: bool) -> Dict[int, str]:
         evms = evms_by_contract_name(self.output_json, contract_name)
-        for _, evm in evms: # if same contract existsin in multiple files, there could be a problem
+        for _, evm in evms:  # if the same contract exists in multiple files, there could be a problem
             _, _, pc2opcode = self.__build_pc2idx(evm, deploy)
             return pc2opcode
         return {}
@@ -452,19 +562,18 @@ def pc2opcode_by_contract(self, contract_name: str, deploy: bool) -> Dict[int, s
     def function_by_name(self, contract_name: str, function_name: str) -> Function:
         """Return a function for a given contract name and function name"""
         contract = self.contract_by_name(contract_name)
-        funcs = self.functions_in_contract(contract) 
+        funcs = self.functions_in_contract(contract)
         return next(fn for fn in funcs if fn.name == function_name)

-
     def __get_binary(self, contract_name: str, filename: Optional[str], deploy=False) -> List[Tuple[str, str, str]]:
         """
         Returns a list of tuples, each tuple is: `(filename, contract_name, binary)`
         """
         bins = []
         evms = evms_by_contract_name(self.output_json, contract_name)
-        bytecode_key = 'bytecode' if deploy else 'deployedBytecode'
+        bytecode_key = "bytecode" if deploy else "deployedBytecode"
         for _filename, evm in evms:
-            bin = evm.get(bytecode_key, {}).get('object')
+            bin = evm.get(bytecode_key, {}).get("object")
             if bin and ((not filename) or _filename == filename):
                 bins.append((filename, contract_name, bin))
         return bins
@@ -481,25 +590,24 @@ def get_deployment_binary(self, contract_name: str) -> List[Tuple[str, str, str]
         """
         return self.__get_binary(contract_name, None, deploy=True)

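A sketch combining the lookups above (reusing the hypothetical `parser` and pc from the earlier example):

```python
fn_node = parser.function_unit_by_pc("MyToken", 1234)  # FunctionDefinition dict or None
if fn_node:
    print(fn_node["name"])
bins = parser.get_deployment_binary("MyToken")         # [(filename, contract_name, hex_bytecode), ...]
```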
-
-    def qualified_name_from_hash(self, hsh: str)->Optional[Tuple[str, str]]:
-        '''Get fully qualified contract name from 34 character hash'''
-        for filename, m_contract in self.output_json.get('contracts').items():
+    def qualified_name_from_hash(self, hsh: str) -> Optional[Tuple[str, str]]:
+        """Get fully qualified contract name from a 34 character hash"""
+        for filename, m_contract in self.output_json.get("contracts").items():
             for contract_name, contract in m_contract.items():
-                full_name = f'{filename}:{contract_name}'
+                full_name = f"{filename}:{contract_name}"
                 if hsh == s.keccak256(full_name)[:34]:
                     return (filename, contract_name)
         return None

     def get_deploy_bin_by_hash(self, hsh: str) -> Optional[str]:
-        '''Get deployment binary by hash of a fully qualified contract / library name'''
+        """Get deployment binary by hash of a fully qualified contract / library name"""
         r = self.qualified_name_from_hash(hsh)
         if not r:
             return None
         filename, contract_name = r
         return self.__get_binary(contract_name, filename, deploy=True)[0][2]

-
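The 34-character hash is the prefix solc embeds in unlinked-library placeholders (`__$<hash>$__` in unlinked bytecode); `qualified_name_from_hash` inverts it (names invented, `s` being the shared `ast_shared` module):

```python
import solc_json_parser.ast_shared as s

hsh = s.keccak256("contracts/SafeMath.sol:SafeMath")[:34]
parser.qualified_name_from_hash(hsh)  # -> ("contracts/SafeMath.sol", "SafeMath")
```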
""" - pred = lambda node: node and node.get('nodeType') == 'ContractDefinition' and node.get('name') == contract_name - contracts = self.extract_node(pred, self.output_json['sources'], first_only=False) - return [c['source_id'] for c in contracts] if contracts else [] + pred = lambda node: node and node.get("nodeType") == "ContractDefinition" and node.get("name") == contract_name + contracts = self.extract_node(pred, self.output_json["sources"], first_only=False) + return [c["source_id"] for c in contracts] if contracts else [] def source_by_lines(self, contract_name: str, line_start: int, line_end: int) -> str: """ Get source code by contract name and line numbers, line numbers are zero indexed """ source_path = self.source_path_by_contract(contract_name) - content = self.input_json['sources'][source_path]['content'] - lines = content.split('\n')[line_start:line_end] - return '\n'.join(lines) - + content = self.input_json["sources"][source_path]["content"] + lines = content.split("\n")[line_start:line_end] + return "\n".join(lines) def all_available_modifiers_by_contract_name(self) -> Dict[str, list]: """Return all available modifiers by contract name.""" @@ -567,36 +674,34 @@ def all_available_modifiers_by_contract_name(self) -> Dict[str, list]: def source_by_fid(self, fid: int) -> Tuple[Optional[str], Optional[str]]: """Get source code by file id. Returns error message and source code.""" - pred = lambda node: node and node.get('id') == fid - source = self.extract_node(pred, self.output_json, first_only=True) + pred = lambda node: node and node.get("id") == fid + source = self.extract_node(pred, self.output_json, first_only=True) if not source: - return 'no source found', None + return "no source found", None source = source[0] file_key = None - for k, source in self.output_json['sources'].items(): - if fid == source['id']: + for k, source in self.output_json["sources"].items(): + if fid == source["id"]: file_key = k break if not file_key: - return 'no file_key', None - return None, self.input_json['sources'][file_key]['content'] - - + return "no file_key", None + return None, self.input_json["sources"][file_key]["content"] def source_by_pred(self, pred: Callable) -> Tuple[Optional[str], Optional[str]]: """Get source code by unit name. Returns error message and source code.""" unit = self.extract_node(pred, self.output_json, first_only=True) if not unit: - return 'no unit found', None + return "no unit found", None unit = unit[0] - (start, size, fid) = [int(i) for i in unit['src'].split(':')] + (start, size, fid) = [int(i) for i in unit["src"].split(":")] err, content = self.source_by_fid(fid) if err: return err, None - content = content.encode() - return None, content[start:start+size].decode() + content = content.encode() + return None, content[start : start + size].decode()