The current min/max code does both signed and unsigned comparisons against
the input argument "val", which is of type "u64", and it uses an explicit
type cast whenever the comparison is signed.

As we will need slightly more complex type casting once JMP32 is introduced,
it is better to hoist the signed type cast. This makes the code cleaner,
with negligible runtime overhead.

Also, the code for J*GE/GT/LT/LE and for JEQ/JNE is very similar, so this
patch combines them.

The main purpose of this refactor is to keep the min/max code readable, and
to keep code duplication to a minimum, after JMP32 is introduced.

Reviewed-by: Jakub Kicinski <jakub.kicin...@netronome.com>
Signed-off-by: Jiong Wang <jiong.w...@netronome.com>
---
 kernel/bpf/verifier.c | 172 +++++++++++++++++++++++++++++---------------------
 1 file changed, 99 insertions(+), 73 deletions(-)

diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index ce87198..53f5135 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -4033,9 +4033,13 @@ static void find_good_pkt_pointers(struct 
bpf_verifier_state *vstate,
  */
 static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode)
 {
+       s64 sval;
+
        if (__is_pointer_value(false, reg))
                return -1;
 
+       sval = (s64)val;
+
        switch (opcode) {
        case BPF_JEQ:
                if (tnum_is_const(reg->var_off))
@@ -4058,9 +4062,9 @@ static int is_branch_taken(struct bpf_reg_state *reg, u64 
val, u8 opcode)
                        return 0;
                break;
        case BPF_JSGT:
-               if (reg->smin_value > (s64)val)
+               if (reg->smin_value > sval)
                        return 1;
-               else if (reg->smax_value < (s64)val)
+               else if (reg->smax_value < sval)
                        return 0;
                break;
        case BPF_JLT:
@@ -4070,9 +4074,9 @@ static int is_branch_taken(struct bpf_reg_state *reg, u64 
val, u8 opcode)
                        return 0;
                break;
        case BPF_JSLT:
-               if (reg->smax_value < (s64)val)
+               if (reg->smax_value < sval)
                        return 1;
-               else if (reg->smin_value >= (s64)val)
+               else if (reg->smin_value >= sval)
                        return 0;
                break;
        case BPF_JGE:
@@ -4082,9 +4086,9 @@ static int is_branch_taken(struct bpf_reg_state *reg, u64 
val, u8 opcode)
                        return 0;
                break;
        case BPF_JSGE:
-               if (reg->smin_value >= (s64)val)
+               if (reg->smin_value >= sval)
                        return 1;
-               else if (reg->smax_value < (s64)val)
+               else if (reg->smax_value < sval)
                        return 0;
                break;
        case BPF_JLE:
@@ -4094,9 +4098,9 @@ static int is_branch_taken(struct bpf_reg_state *reg, u64 
val, u8 opcode)
                        return 0;
                break;
        case BPF_JSLE:
-               if (reg->smax_value <= (s64)val)
+               if (reg->smax_value <= sval)
                        return 1;
-               else if (reg->smin_value > (s64)val)
+               else if (reg->smin_value > sval)
                        return 0;
                break;
        }
@@ -4113,6 +4117,8 @@ static void reg_set_min_max(struct bpf_reg_state 
*true_reg,
                            struct bpf_reg_state *false_reg, u64 val,
                            u8 opcode)
 {
+       s64 sval;
+
        /* If the dst_reg is a pointer, we can't learn anything about its
         * variable offset from the compare (unless src_reg were a pointer into
         * the same object, but we don't bother with that.
@@ -4122,19 +4128,22 @@ static void reg_set_min_max(struct bpf_reg_state 
*true_reg,
        if (__is_pointer_value(false, false_reg))
                return;
 
+       sval = (s64)val;
+
        switch (opcode) {
        case BPF_JEQ:
-               /* If this is false then we know nothing Jon Snow, but if it is
-                * true then we know for sure.
-                */
-               __mark_reg_known(true_reg, val);
-               break;
        case BPF_JNE:
-               /* If this is true we know nothing Jon Snow, but if it is false
-                * we know the value for sure;
+       {
+               struct bpf_reg_state *reg =
+                       opcode == BPF_JEQ ? true_reg : false_reg;
+
+               /* For BPF_JEQ, if this is false we know nothing Jon Snow, but
+                * if it is true we know the value for sure. Likewise for
+                * BPF_JNE.
                 */
-               __mark_reg_known(false_reg, val);
+               __mark_reg_known(reg, val);
                break;
+       }
        case BPF_JSET:
                false_reg->var_off = tnum_and(false_reg->var_off,
                                              tnum_const(~val));
@@ -4142,38 +4151,46 @@ static void reg_set_min_max(struct bpf_reg_state 
*true_reg,
                        true_reg->var_off = tnum_or(true_reg->var_off,
                                                    tnum_const(val));
                break;
-       case BPF_JGT:
-               false_reg->umax_value = min(false_reg->umax_value, val);
-               true_reg->umin_value = max(true_reg->umin_value, val + 1);
-               break;
-       case BPF_JSGT:
-               false_reg->smax_value = min_t(s64, false_reg->smax_value, val);
-               true_reg->smin_value = max_t(s64, true_reg->smin_value, val + 
1);
-               break;
-       case BPF_JLT:
-               false_reg->umin_value = max(false_reg->umin_value, val);
-               true_reg->umax_value = min(true_reg->umax_value, val - 1);
-               break;
-       case BPF_JSLT:
-               false_reg->smin_value = max_t(s64, false_reg->smin_value, val);
-               true_reg->smax_value = min_t(s64, true_reg->smax_value, val - 
1);
-               break;
        case BPF_JGE:
-               false_reg->umax_value = min(false_reg->umax_value, val - 1);
-               true_reg->umin_value = max(true_reg->umin_value, val);
+       case BPF_JGT:
+       {
+               u64 false_umax = opcode == BPF_JGT ? val    : val - 1;
+               u64 true_umin = opcode == BPF_JGT ? val + 1 : val;
+
+               false_reg->umax_value = min(false_reg->umax_value, false_umax);
+               true_reg->umin_value = max(true_reg->umin_value, true_umin);
                break;
+       }
        case BPF_JSGE:
-               false_reg->smax_value = min_t(s64, false_reg->smax_value, val - 
1);
-               true_reg->smin_value = max_t(s64, true_reg->smin_value, val);
+       case BPF_JSGT:
+       {
+               s64 false_smax = opcode == BPF_JSGT ? sval    : sval - 1;
+               s64 true_smin = opcode == BPF_JSGT ? sval + 1 : sval;
+
+               false_reg->smax_value = min(false_reg->smax_value, false_smax);
+               true_reg->smin_value = max(true_reg->smin_value, true_smin);
                break;
+       }
        case BPF_JLE:
-               false_reg->umin_value = max(false_reg->umin_value, val + 1);
-               true_reg->umax_value = min(true_reg->umax_value, val);
+       case BPF_JLT:
+       {
+               u64 false_umin = opcode == BPF_JLT ? val    : val + 1;
+               u64 true_umax = opcode == BPF_JLT ? val - 1 : val;
+
+               false_reg->umin_value = max(false_reg->umin_value, false_umin);
+               true_reg->umax_value = min(true_reg->umax_value, true_umax);
                break;
+       }
        case BPF_JSLE:
-               false_reg->smin_value = max_t(s64, false_reg->smin_value, val + 
1);
-               true_reg->smax_value = min_t(s64, true_reg->smax_value, val);
+       case BPF_JSLT:
+       {
+               s64 false_smin = opcode == BPF_JSLT ? sval    : sval + 1;
+               s64 true_smax = opcode == BPF_JSLT ? sval - 1 : sval;
+
+               false_reg->smin_value = max(false_reg->smin_value, false_smin);
+               true_reg->smax_value = min(true_reg->smax_value, true_smax);
                break;
+       }
        default:
                break;
        }
@@ -4198,22 +4215,23 @@ static void reg_set_min_max_inv(struct bpf_reg_state 
*true_reg,
                                struct bpf_reg_state *false_reg, u64 val,
                                u8 opcode)
 {
+       s64 sval;
+
        if (__is_pointer_value(false, false_reg))
                return;
 
+       sval = (s64)val;
+
        switch (opcode) {
        case BPF_JEQ:
-               /* If this is false then we know nothing Jon Snow, but if it is
-                * true then we know for sure.
-                */
-               __mark_reg_known(true_reg, val);
-               break;
        case BPF_JNE:
-               /* If this is true we know nothing Jon Snow, but if it is false
-                * we know the value for sure;
-                */
-               __mark_reg_known(false_reg, val);
+       {
+               struct bpf_reg_state *reg =
+                       opcode == BPF_JEQ ? true_reg : false_reg;
+
+               __mark_reg_known(reg, val);
                break;
+       }
        case BPF_JSET:
                false_reg->var_off = tnum_and(false_reg->var_off,
                                              tnum_const(~val));
@@ -4221,38 +4239,46 @@ static void reg_set_min_max_inv(struct bpf_reg_state 
*true_reg,
                        true_reg->var_off = tnum_or(true_reg->var_off,
                                                    tnum_const(val));
                break;
-       case BPF_JGT:
-               true_reg->umax_value = min(true_reg->umax_value, val - 1);
-               false_reg->umin_value = max(false_reg->umin_value, val);
-               break;
-       case BPF_JSGT:
-               true_reg->smax_value = min_t(s64, true_reg->smax_value, val - 
1);
-               false_reg->smin_value = max_t(s64, false_reg->smin_value, val);
-               break;
-       case BPF_JLT:
-               true_reg->umin_value = max(true_reg->umin_value, val + 1);
-               false_reg->umax_value = min(false_reg->umax_value, val);
-               break;
-       case BPF_JSLT:
-               true_reg->smin_value = max_t(s64, true_reg->smin_value, val + 
1);
-               false_reg->smax_value = min_t(s64, false_reg->smax_value, val);
-               break;
        case BPF_JGE:
-               true_reg->umax_value = min(true_reg->umax_value, val);
-               false_reg->umin_value = max(false_reg->umin_value, val + 1);
+       case BPF_JGT:
+       {
+               u64 false_umin = opcode == BPF_JGT ? val    : val + 1;
+               u64 true_umax = opcode == BPF_JGT ? val - 1 : val;
+
+               false_reg->umin_value = max(false_reg->umin_value, false_umin);
+               true_reg->umax_value = min(true_reg->umax_value, true_umax);
                break;
+       }
        case BPF_JSGE:
-               true_reg->smax_value = min_t(s64, true_reg->smax_value, val);
-               false_reg->smin_value = max_t(s64, false_reg->smin_value, val + 
1);
+       case BPF_JSGT:
+       {
+               s64 false_smin = opcode == BPF_JSGT ? sval    : sval + 1;
+               s64 true_smax = opcode == BPF_JSGT ? sval - 1 : sval;
+
+               false_reg->smin_value = max(false_reg->smin_value, false_smin);
+               true_reg->smax_value = min(true_reg->smax_value, true_smax);
                break;
+       }
        case BPF_JLE:
-               true_reg->umin_value = max(true_reg->umin_value, val);
-               false_reg->umax_value = min(false_reg->umax_value, val - 1);
+       case BPF_JLT:
+       {
+               u64 false_umax = opcode == BPF_JLT ? val    : val - 1;
+               u64 true_umin = opcode == BPF_JLT ? val + 1 : val;
+
+               false_reg->umax_value = min(false_reg->umax_value, false_umax);
+               true_reg->umin_value = max(true_reg->umin_value, true_umin);
                break;
+       }
        case BPF_JSLE:
-               true_reg->smin_value = max_t(s64, true_reg->smin_value, val);
-               false_reg->smax_value = min_t(s64, false_reg->smax_value, val - 
1);
+       case BPF_JSLT:
+       {
+               s64 false_smax = opcode == BPF_JSLT ? sval    : sval - 1;
+               s64 true_smin = opcode == BPF_JSLT ? sval + 1 : sval;
+
+               false_reg->smax_value = min(false_reg->smax_value, false_smax);
+               true_reg->smin_value = max(true_reg->smin_value, true_smin);
                break;
+       }
        default:
                break;
        }
-- 
2.7.4

Reply via email to