diff --git a/include/linux/tnum.h b/include/linux/tnum.h
index c52b862dad45b..e5d87853892fa 100644
--- a/include/linux/tnum.h
+++ b/include/linux/tnum.h
@@ -125,5 +125,6 @@ static inline bool tnum_subreg_is_const(struct tnum a)
 {
 	return !(tnum_subreg(a)).mask;
 }
-
+/* Returns the smallest member of t greater than z, or the largest member of t if none exists */
+u64 tnum_step(struct tnum t, u64 z);
 #endif /* _LINUX_TNUM_H */
diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c
index f8e70e9c3998d..cca2b5675a1d4 100644
--- a/kernel/bpf/tnum.c
+++ b/kernel/bpf/tnum.c
@@ -253,3 +253,41 @@ struct tnum tnum_const_subreg(struct tnum a, u32 value)
 {
 	return tnum_with_subreg(a, tnum_const(value));
 }
+
+u64 tnum_step(struct tnum t, u64 z)
+{
+	u64 tmax, j, p, q, r, s, v, u, w, res;
+	u8 k;
+
+	tmax = t.value | t.mask;
+
+	/* if z >= largest member of t, return largest member of t */
+	if (z >= tmax)
+		return tmax;
+
+	/* keep t's known bits, and match all unknown bits to z */
+	j = t.value | (z & t.mask);
+
+	if (j > z) {
+		p = ~z & t.value & ~t.mask;
+		k = fls64(p);		/* k is the most-significant 0-to-1 flip */
+		q = (k < 64) ? (U64_MAX << k) : 0;	/* a shift by 64 would be undefined */
+		r = q & z;		/* positions > k matched to z */
+		s = ~q & t.value;	/* positions <= k matched to t.value */
+		v = r | s;
+		res = v;
+	} else {
+		p = z & ~t.value & ~t.mask;
+		k = fls64(p);		/* k is the most-significant 1-to-0 flip */
+		q = U64_MAX << k;	/* k <= 63 here: a set bit 63 in p would imply z > tmax */
+		r = q & t.mask & z;	/* unknown positions > k, matched to z */
+		s = q & ~t.mask;	/* known positions > k, set to 1 */
+		v = r | s;
+		/* add 1 to unknown positions > k to make value greater than z */
+		u = v + (1ULL << k);
+		/* extract bits in unknown positions > k from u, rest from t.value */
+		w = u & (t.mask | t.value);
+		res = w;
+	}
+	return res;
+}
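To make the bit manipulation above easier to check, here is a standalone userspace sketch (not part of the patch) that mirrors tnum_step(), with the kernel's fls64() approximated by a __builtin_clzll()-based shim and struct tnum re-declared locally:

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

struct tnum { uint64_t value; uint64_t mask; };

/* 1-based index of the most-significant set bit, 0 if none (like the kernel's fls64()) */
static int fls64(uint64_t x)
{
	return x ? 64 - __builtin_clzll(x) : 0;
}

static uint64_t tnum_step(struct tnum t, uint64_t z)
{
	uint64_t tmax = t.value | t.mask, j, p, q, r, s, v, u;
	int k;

	if (z >= tmax)
		return tmax;
	j = t.value | (z & t.mask);
	if (j > z) {
		p = ~z & t.value & ~t.mask;
		k = fls64(p);
		q = (k < 64) ? (UINT64_MAX << k) : 0;
		return (q & z) | (~q & t.value);
	}
	p = z & ~t.value & ~t.mask;
	k = fls64(p);
	q = UINT64_MAX << k;
	r = q & t.mask & z;
	s = q & ~t.mask;
	v = r | s;
	u = v + (1ULL << k);
	return u & (t.mask | t.value);
}

int main(void)
{
	/* t = {value = 0b1000, mask = 0b0110} represents {8, 10, 12, 14} */
	struct tnum t = { .value = 0x8, .mask = 0x6 };

	assert(tnum_step(t, 3) == 8);	/* smallest member greater than 3 */
	assert(tnum_step(t, 9) == 10);	/* smallest member greater than 9 */
	assert(tnum_step(t, 10) == 12);	/* 10 is a member; the next one up is 12 */
	assert(tnum_step(t, 20) == 14);	/* no member above 20: largest member */
	printf("tnum_step examples ok\n");
	return 0;
}
```

For t = {value = 0b1000, mask = 0b0110}, i.e. the set {8, 10, 12, 14}, the smallest member greater than 9 is 10, and once z reaches the largest member the function saturates at 14.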
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index d6b8a77fbe3bf..3d983fa49836c 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -16059,11 +16059,159 @@ static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
 	}));
 }
 
+static bool intersection_u64_s64(u64 umin, u64 umax, s64 smin, s64 smax)
+{
+	if ((u64)smin <= (u64)smax)
+		return !(((u64)smax < umin) || (umax < (u64)smin));
+	else
+		return !(((u64)smin > umax) && ((u64)smax < umin));
+}
+
+static bool intersection_u64_tnum(u64 umin, u64 umax, struct tnum t)
+{
+	u64 tmin = t.value;
+	u64 tmax = t.value | t.mask;
+
+	return !((tmin > umax) || (tmax < umin) ||
+		 ((t.value != (umin & ~t.mask)) && (tnum_step(t, umin) > umax)));
+}
+
+static bool intersection_s64_tnum(s64 smin, s64 smax, struct tnum t)
+{
+	if ((u64)smin <= (u64)smax)
+		return intersection_u64_tnum((u64)smin, (u64)smax, t);
+
+	return (intersection_u64_tnum((u64)smin, U64_MAX, t) ||
+		intersection_u64_tnum(0, (u64)smax, t));
+}
+
+static bool intersection_u32_s32(u32 u32_min, u32 u32_max, s32 s32_min, s32 s32_max)
+{
+	if ((u32)s32_min <= (u32)s32_max)
+		return !(((u32)s32_max < u32_min) || (u32_max < (u32)s32_min));
+	else
+		return !(((u32)s32_min > u32_max) && ((u32)s32_max < u32_min));
+}
+
+static bool intersection_u32_tnum(u32 u32_min, u32 u32_max, struct tnum t)
+{
+	struct tnum t32 = tnum_subreg(t);
+	u32 t32_min = t32.value;
+	u32 t32_max = t32.value | t32.mask;
+
+	return !((t32_min > u32_max) ||
+		 (t32_max < u32_min) ||
+		 ((t32.value != (u32_min & ~t32.mask)) &&
+		  (tnum_step(t32, u32_min) > u32_max)));
+}
+
+static bool intersection_s32_tnum(s32 s32_min, s32 s32_max, struct tnum t)
+{
+	if ((u32)s32_min <= (u32)s32_max)
+		return intersection_u32_tnum((u32)s32_min, (u32)s32_max, t);
+
+	return (intersection_u32_tnum((u32)s32_min, U32_MAX, t) ||
+		intersection_u32_tnum(0, (u32)s32_max, t));
+}
+
+static bool check_intersection_all(struct bpf_verifier_env *env,
+				   struct bpf_reg_state *reg,
+				   const char *ctx)
+{
+	bool intersection_exists;
+
+	if (reg->umin_value > reg->umax_value ||
+	    reg->smin_value > reg->smax_value ||
+	    reg->u32_min_value > reg->u32_max_value ||
+	    reg->s32_min_value > reg->s32_max_value) {
+		intersection_exists = false;
+	} else if ((reg->var_off.value & reg->var_off.mask) != 0) { /* malformed tnum */
+		intersection_exists = false;
+	} else if (!intersection_u64_s64(reg->umin_value, reg->umax_value,
+					 reg->smin_value, reg->smax_value)) {
+		intersection_exists = false;
+	} else if (!intersection_u64_tnum(reg->umin_value, reg->umax_value, reg->var_off)) {
+		intersection_exists = false;
+	} else if (!intersection_s64_tnum(reg->smin_value, reg->smax_value, reg->var_off)) {
+		intersection_exists = false;
+	} else if (!intersection_u32_s32(reg->u32_min_value, reg->u32_max_value,
+					 reg->s32_min_value, reg->s32_max_value)) {
+		intersection_exists = false;
+	} else if (!intersection_u32_tnum(reg->u32_min_value, reg->u32_max_value, reg->var_off)) {
+		intersection_exists = false;
+	} else if (!intersection_s32_tnum(reg->s32_min_value, reg->s32_max_value, reg->var_off)) {
+		intersection_exists = false;
+	} else {
+		intersection_exists = true;
+	}
+
+	return intersection_exists;
+}
+
+static void regs_refine_cond_op(struct bpf_reg_state *reg1,
+				struct bpf_reg_state *reg2,
+				u8 opcode, bool is_jmp32);
+static u8 rev_opcode(u8 opcode);
+
+static int simulate_both_branches_taken(struct bpf_verifier_env *env,
+					struct bpf_reg_state *false_reg1,
+					struct bpf_reg_state *false_reg2,
+					u8 opcode, bool is_jmp32)
+{
+	struct bpf_reg_state false_reg1_c, false_reg2_c, true_reg1, true_reg2;
+	bool t1, t2, f1, f2;
+
+	memcpy(&false_reg1_c, false_reg1, sizeof(struct bpf_reg_state));
+	memcpy(&false_reg2_c, false_reg2, sizeof(struct bpf_reg_state));
+	memcpy(&true_reg1, false_reg1, sizeof(struct bpf_reg_state));
+	memcpy(&true_reg2, false_reg2, sizeof(struct bpf_reg_state));
+
+	/* fallthrough (FALSE) branch; only the after-sync (AS) checks decide */
+	check_intersection_all(env, &false_reg1_c, "BR_false_reg1_c");
+	check_intersection_all(env, &false_reg2_c, "BR_false_reg2_c");
+	regs_refine_cond_op(&false_reg1_c, &false_reg2_c, rev_opcode(opcode), is_jmp32);
+	check_intersection_all(env, &false_reg1_c, "BS_false_reg1_c");
+	check_intersection_all(env, &false_reg2_c, "BS_false_reg2_c");
+	reg_bounds_sync(&false_reg1_c);
+	reg_bounds_sync(&false_reg2_c);
+	f1 = check_intersection_all(env, &false_reg1_c, "AS_false_reg1_c");
+	f2 = check_intersection_all(env, &false_reg2_c, "AS_false_reg2_c");
+
+	/* jump (TRUE) branch */
+	check_intersection_all(env, &true_reg1, "BR_true_reg1");
+	check_intersection_all(env, &true_reg2, "BR_true_reg2");
+	regs_refine_cond_op(&true_reg1, &true_reg2, opcode, is_jmp32);
+	check_intersection_all(env, &true_reg1, "BS_true_reg1");
+	check_intersection_all(env, &true_reg2, "BS_true_reg2");
+	reg_bounds_sync(&true_reg1);
+	reg_bounds_sync(&true_reg2);
+	t1 = check_intersection_all(env, &true_reg1, "AS_true_reg1");
+	t2 = check_intersection_all(env, &true_reg2, "AS_true_reg2");
+
+	if (!f1 || !f2) {
+		/* If any pair of abstract values of either reg_state in the FALSE
+		 * branch (false_reg1_c, false_reg2_c) has an empty intersection,
+		 * the FALSE branch must be dead; only the TRUE branch can be taken.
+		 */
+		return 1;
+	} else if (!t1 || !t2) {
+		/* If any pair of abstract values of either reg_state in the TRUE
+		 * branch (true_reg1, true_reg2) has an empty intersection,
+		 * the TRUE branch must be dead; only the FALSE branch can be taken.
+		 */
+		return 0;
+	}
+
+	return -1;
+}
+
 /*
  * <reg1> <op> <reg2>, currently assuming reg2 is a constant
  */
-static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
-				  u8 opcode, bool is_jmp32)
+static int is_scalar_branch_taken(struct bpf_verifier_env *env,
+				  struct bpf_reg_state *reg1,
+				  struct bpf_reg_state *reg2,
+				  u8 opcode, bool is_jmp32)
 {
 	struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
 	struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off;
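The tnum_step() term in intersection_u64_tnum() above is what catches ranges that sit entirely between two members of a tnum: tmin and tmax can straddle [umin, umax] while no actual member falls inside it. A userspace sketch of that case (again not part of the patch, reusing struct tnum and tnum_step() from the sketch above):

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

static bool intersection_u64_tnum(uint64_t umin, uint64_t umax, struct tnum t)
{
	uint64_t tmin = t.value;
	uint64_t tmax = t.value | t.mask;

	return !((tmin > umax) || (tmax < umin) ||
		 ((t.value != (umin & ~t.mask)) && (tnum_step(t, umin) > umax)));
}

int main(void)
{
	/* t = {value = 0, mask = 0b11000} represents {0, 8, 16, 24} */
	struct tnum t = { .value = 0x0, .mask = 0x18 };

	/* tmin = 0 <= 15 and tmax = 24 >= 9, yet no member lies in [9, 15]:
	 * umin = 9 is not itself a member, and tnum_step(t, 9) == 16 > 15
	 */
	assert(!intersection_u64_tnum(9, 15, t));
	/* widening the range to [9, 16] picks up the member 16 */
	assert(intersection_u64_tnum(9, 16, t));
	return 0;
}
```

The `t.value != (umin & ~t.mask)` test asks whether umin itself is a member of t; only when it is not does the step to the next member need to stay within umax.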
@@ -16215,7 +16369,7 @@ static int is_scalar_branch_taken(struct bpf_reg_sta
 		break;
 	}
 
-	return -1;
+	return simulate_both_branches_taken(env, reg1, reg2, opcode, is_jmp32);
 }
 
 static int flip_opcode(u32 opcode)
@@ -16286,8 +16440,9 @@ static int is_pkt_ptr_branch_taken(struct bpf_reg_state *dst_reg,
  * -1 - unknown. Example: "if (reg1 < 5)" is unknown when register value
  * range [0,10]
  */
-static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
-			   u8 opcode, bool is_jmp32)
+static int is_branch_taken(struct bpf_verifier_env *env,
+			   struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
+			   u8 opcode, bool is_jmp32)
 {
 	if (reg_is_pkt_pointer_any(reg1) && reg_is_pkt_pointer_any(reg2) && !is_jmp32)
 		return is_pkt_ptr_branch_taken(reg1, reg2, opcode);
@@ -16325,7 +16480,7 @@ static int is_branch_taken(struct bpf_reg_state *reg
 	}
 
 	/* now deal with two scalars, but not necessarily constants */
-	return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
+	return is_scalar_branch_taken(env, reg1, reg2, opcode, is_jmp32);
 }
 
 /* Opcode that corresponds to a *false* branch condition.
@@ -16933,7 +17088,7 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
 	}
 
 	is_jmp32 = BPF_CLASS(insn->code) == BPF_JMP32;
-	pred = is_branch_taken(dst_reg, src_reg, opcode, is_jmp32);
+	pred = is_branch_taken(env, dst_reg, src_reg, opcode, is_jmp32);
 	if (pred >= 0) {
 		/* If we get here with a dst_reg pointer type it is because
 		 * above is_branch_taken() special cased the 0 comparison.
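Finally, a brute-force cross-check for the tnum_step() arithmetic (a test sketch under the same userspace assumptions as the first example, not something the patch adds): enumerate all small tnums, reusing struct tnum and tnum_step() from the first sketch, and compare against a naive scan over each tnum's members.

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

int main(void)
{
	for (uint64_t mask = 0; mask < 64; mask++) {
		for (uint64_t value = 0; value < 64; value++) {
			/* a well-formed tnum has no bit both known and unknown */
			struct tnum t = { .value = value & ~mask, .mask = mask };
			uint64_t tmax = t.value | t.mask;

			for (uint64_t z = 0; z <= tmax + 1; z++) {
				/* naive reference: smallest member > z,
				 * falling back to the largest member
				 */
				uint64_t want = tmax;

				for (uint64_t x = 0; x <= tmax; x++) {
					if ((x & ~t.mask) == t.value && x > z) {
						want = x;
						break;
					}
				}
				assert(tnum_step(t, z) == want);
			}
		}
	}
	printf("tnum_step: brute-force check passed\n");
	return 0;
}
```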