Skip to content

Commit 8fcd9b1

Browse files
authored
Merge pull request #8442 from diffblue/zero_extend
zero extension expression
2 parents 20a1ecf + 5420b97 commit 8fcd9b1

13 files changed

+137
-17
lines changed

src/solvers/flattening/boolbv.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ bvt boolbvt::convert_bitvector(const exprt &expr)
165165
return convert_replication(to_replication_expr(expr));
166166
else if(expr.id()==ID_extractbits)
167167
return convert_extractbits(to_extractbits_expr(expr));
168+
else if(expr.id() == ID_zero_extend)
169+
return convert_bitvector(to_zero_extend_expr(expr).lower());
168170
else if(expr.id()==ID_bitnot || expr.id()==ID_bitand ||
169171
expr.id()==ID_bitor || expr.id()==ID_bitxor ||
170172
expr.id()==ID_bitxnor || expr.id()==ID_bitnor ||

src/solvers/floatbv/float_bv.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -692,8 +692,10 @@ exprt float_bvt::mul(
692692

693693
// zero-extend the fractions (unpacked fraction has the hidden bit)
694694
typet new_fraction_type=unsignedbv_typet((spec.f+1)*2);
695-
const exprt fraction1=typecast_exprt(unpacked1.fraction, new_fraction_type);
696-
const exprt fraction2=typecast_exprt(unpacked2.fraction, new_fraction_type);
695+
const exprt fraction1 =
696+
zero_extend_exprt{unpacked1.fraction, new_fraction_type};
697+
const exprt fraction2 =
698+
zero_extend_exprt{unpacked2.fraction, new_fraction_type};
697699

698700
// multiply the fractions
699701
unbiased_floatt result;
@@ -750,7 +752,7 @@ exprt float_bvt::div(
750752
unsignedbv_typet(div_width));
751753

752754
// zero-extend fraction2 to match fraction1
753-
const typecast_exprt fraction2(unpacked2.fraction, fraction1.type());
755+
const zero_extend_exprt fraction2{unpacked2.fraction, fraction1.type()};
754756

755757
// divide fractions
756758
unbiased_floatt result;

src/solvers/smt2/smt2_conv.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -2456,6 +2456,10 @@ void smt2_convt::convert_expr(const exprt &expr)
24562456
{
24572457
convert_expr(simplify_expr(to_bitreverse_expr(expr).lower(), ns));
24582458
}
2459+
else if(expr.id() == ID_zero_extend)
2460+
{
2461+
convert_expr(to_zero_extend_expr(expr).lower());
2462+
}
24592463
else if(expr.id() == ID_function_application)
24602464
{
24612465
const auto &function_application_expr = to_function_application_expr(expr);

src/solvers/smt2_incremental/convert_expr_to_smt.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -1469,6 +1469,15 @@ static smt_termt convert_expr_to_smt(
14691469
count_trailing_zeros.pretty());
14701470
}
14711471

1472+
static smt_termt convert_expr_to_smt(
1473+
const zero_extend_exprt &zero_extend,
1474+
const sub_expression_mapt &converted)
1475+
{
1476+
UNREACHABLE_BECAUSE(
1477+
"zero_extend expression should have been lowered by the decision "
1478+
"procedure before conversion to smt terms");
1479+
}
1480+
14721481
static smt_termt convert_expr_to_smt(
14731482
const prophecy_r_or_w_ok_exprt &prophecy_r_or_w_ok,
14741483
const sub_expression_mapt &converted)
@@ -1822,6 +1831,10 @@ static smt_termt dispatch_expr_to_smt_conversion(
18221831
{
18231832
return convert_expr_to_smt(*count_trailing_zeros, converted);
18241833
}
1834+
if(const auto zero_extend = expr_try_dynamic_cast<zero_extend_exprt>(expr))
1835+
{
1836+
return convert_expr_to_smt(*zero_extend, converted);
1837+
}
18251838
if(
18261839
const auto prophecy_r_or_w_ok =
18271840
expr_try_dynamic_cast<prophecy_r_or_w_ok_exprt>(expr))

src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp

+16-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "smt2_incremental_decision_procedure.h"
44

55
#include <util/arith_tools.h>
6+
#include <util/bitvector_expr.h>
67
#include <util/byte_operators.h>
78
#include <util/c_types.h>
89
#include <util/range.h>
@@ -296,6 +297,17 @@ static exprt lower_rw_ok_pointer_in_range(exprt expr, const namespacet &ns)
296297
return expr;
297298
}
298299

300+
static exprt lower_zero_extend(exprt expr, const namespacet &ns)
301+
{
302+
expr.visit_pre([](exprt &expr) {
303+
if(auto zero_extend = expr_try_dynamic_cast<zero_extend_exprt>(expr))
304+
{
305+
expr = zero_extend->lower();
306+
}
307+
});
308+
return expr;
309+
}
310+
299311
void smt2_incremental_decision_proceduret::ensure_handle_for_expr_defined(
300312
const exprt &in_expr)
301313
{
@@ -677,8 +689,10 @@ void smt2_incremental_decision_proceduret::define_object_properties()
677689

678690
exprt smt2_incremental_decision_proceduret::lower(exprt expression) const
679691
{
680-
const exprt lowered = struct_encoding.encode(lower_enum(
681-
lower_byte_operators(lower_rw_ok_pointer_in_range(expression, ns), ns),
692+
const exprt lowered = struct_encoding.encode(lower_zero_extend(
693+
lower_enum(
694+
lower_byte_operators(lower_rw_ok_pointer_in_range(expression, ns), ns),
695+
ns),
682696
ns));
683697
log.conditional_output(log.debug(), [&](messaget::mstreamt &debug) {
684698
if(lowered != expression)

src/util/bitvector_expr.cpp

+18-3
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ exprt update_bit_exprt::lower() const
5454
typecast_exprt(src(), src_bv_type), bitnot_exprt(mask_shifted));
5555

5656
// zero-extend the replacement bit to match src
57-
auto new_value_casted = typecast_exprt(
58-
typecast_exprt(new_value(), unsignedbv_typet(width)), src_bv_type);
57+
auto new_value_casted = zero_extend_exprt{new_value(), src_bv_type};
5958

6059
// shift the replacement bits
6160
auto new_value_shifted = shl_exprt(new_value_casted, index());
@@ -85,7 +84,7 @@ exprt update_bits_exprt::lower() const
8584
bitand_exprt(typecast_exprt(src(), src_bv_type), mask_shifted);
8685

8786
// zero-extend or shrink the replacement bits to match src
88-
auto new_value_casted = typecast_exprt(new_value(), src_bv_type);
87+
auto new_value_casted = zero_extend_exprt{new_value(), src_bv_type};
8988

9089
// shift the replacement bits
9190
auto new_value_shifted = shl_exprt(new_value_casted, index());
@@ -279,3 +278,19 @@ exprt find_first_set_exprt::lower() const
279278

280279
return typecast_exprt::conditional_cast(result, type());
281280
}
281+
282+
exprt zero_extend_exprt::lower() const
283+
{
284+
const auto old_width = to_bitvector_type(op().type()).get_width();
285+
const auto new_width = to_bitvector_type(type()).get_width();
286+
287+
if(new_width > old_width)
288+
{
289+
return concatenation_exprt{
290+
bv_typet{new_width - old_width}.all_zeros_expr(), op(), type()};
291+
}
292+
else // new_width <= old_width
293+
{
294+
return extractbits_exprt{op(), 0, type()};
295+
}
296+
}

src/util/bitvector_expr.h

+44
Original file line numberDiff line numberDiff line change
@@ -1663,4 +1663,48 @@ inline find_first_set_exprt &to_find_first_set_expr(exprt &expr)
16631663
return ret;
16641664
}
16651665

1666+
/// \brief zero extension
1667+
/// The operand is converted to the given type by either
1668+
/// a) truncating if the new type is shorter, or
1669+
/// b) padding with most-significant zero bits if the new type is larger, or
1670+
/// c) reinterprets the operand as the given type if their widths match.
1671+
class zero_extend_exprt : public unary_exprt
1672+
{
1673+
public:
1674+
zero_extend_exprt(exprt _op, typet _type)
1675+
: unary_exprt(ID_zero_extend, std::move(_op), std::move(_type))
1676+
{
1677+
}
1678+
1679+
// a lowering to extraction or concatenation
1680+
exprt lower() const;
1681+
};
1682+
1683+
template <>
1684+
inline bool can_cast_expr<zero_extend_exprt>(const exprt &base)
1685+
{
1686+
return base.id() == ID_zero_extend;
1687+
}
1688+
1689+
/// \brief Cast an exprt to a \ref zero_extend_exprt
1690+
///
1691+
/// \a expr must be known to be \ref zero_extend_exprt.
1692+
///
1693+
/// \param expr: Source expression
1694+
/// \return Object of type \ref zero_extend_exprt
1695+
inline const zero_extend_exprt &to_zero_extend_expr(const exprt &expr)
1696+
{
1697+
PRECONDITION(expr.id() == ID_zero_extend);
1698+
zero_extend_exprt::check(expr);
1699+
return static_cast<const zero_extend_exprt &>(expr);
1700+
}
1701+
1702+
/// \copydoc to_zero_extend_expr(const exprt &)
1703+
inline zero_extend_exprt &to_zero_extend_expr(exprt &expr)
1704+
{
1705+
PRECONDITION(expr.id() == ID_zero_extend);
1706+
zero_extend_exprt::check(expr);
1707+
return static_cast<zero_extend_exprt &>(expr);
1708+
}
1709+
16661710
#endif // CPROVER_UTIL_BITVECTOR_EXPR_H

src/util/format_expr.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,12 @@ void format_expr_configt::setup()
376376
<< format(expr.type()) << ')';
377377
};
378378

379+
expr_map[ID_zero_extend] =
380+
[](std::ostream &os, const exprt &expr) -> std::ostream & {
381+
return os << "zero_extend(" << format(to_zero_extend_expr(expr).op())
382+
<< ", " << format(expr.type()) << ')';
383+
};
384+
379385
expr_map[ID_floatbv_typecast] =
380386
[](std::ostream &os, const exprt &expr) -> std::ostream & {
381387
const auto &floatbv_typecast_expr = to_floatbv_typecast_expr(expr);

src/util/irep_ids.def

+1
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ IREP_ID_ONE(extractbit)
188188
IREP_ID_ONE(extractbits)
189189
IREP_ID_ONE(update_bit)
190190
IREP_ID_ONE(update_bits)
191+
IREP_ID_ONE(zero_extend)
191192
IREP_ID_TWO(C_reference, #reference)
192193
IREP_ID_TWO(C_rvalue_reference, #rvalue_reference)
193194
IREP_ID_ONE(true)

src/util/lower_byte_operators.cpp

+10-9
Original file line numberDiff line numberDiff line change
@@ -2491,15 +2491,16 @@ static exprt lower_byte_update(
24912491
exprt zero_extended;
24922492
if(bit_width > update_size_bits)
24932493
{
2494-
zero_extended = concatenation_exprt{
2495-
bv_typet{bit_width - update_size_bits}.all_zeros_expr(),
2496-
value,
2497-
bv_typet{bit_width}};
2498-
2499-
if(!is_little_endian)
2500-
to_concatenation_expr(zero_extended)
2501-
.op0()
2502-
.swap(to_concatenation_expr(zero_extended).op1());
2494+
if(is_little_endian)
2495+
zero_extended = zero_extend_exprt{value, bv_typet{bit_width}};
2496+
else
2497+
{
2498+
// Big endian -- the zero is added as LSB.
2499+
zero_extended = concatenation_exprt{
2500+
value,
2501+
bv_typet{bit_width - update_size_bits}.all_zeros_expr(),
2502+
bv_typet{bit_width}};
2503+
}
25032504
}
25042505
else
25052506
zero_extended = value;

src/util/simplify_expr.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -3028,6 +3028,10 @@ simplify_exprt::resultt<> simplify_exprt::simplify_node(const exprt &node)
30283028
{
30293029
r = simplify_extractbits(to_extractbits_expr(expr));
30303030
}
3031+
else if(expr.id() == ID_zero_extend)
3032+
{
3033+
r = simplify_zero_extend(to_zero_extend_expr(expr));
3034+
}
30313035
else if(expr.id()==ID_ieee_float_equal ||
30323036
expr.id()==ID_ieee_float_notequal)
30333037
{

src/util/simplify_expr_class.h

+2
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class unary_overflow_exprt;
7676
class unary_plus_exprt;
7777
class update_exprt;
7878
class with_exprt;
79+
class zero_extend_exprt;
7980

8081
class simplify_exprt
8182
{
@@ -152,6 +153,7 @@ class simplify_exprt
152153
[[nodiscard]] resultt<> simplify_extractbit(const extractbit_exprt &);
153154
[[nodiscard]] resultt<> simplify_extractbits(const extractbits_exprt &);
154155
[[nodiscard]] resultt<> simplify_concatenation(const concatenation_exprt &);
156+
[[nodiscard]] resultt<> simplify_zero_extend(const zero_extend_exprt &);
155157
[[nodiscard]] resultt<> simplify_mult(const mult_exprt &);
156158
[[nodiscard]] resultt<> simplify_div(const div_exprt &);
157159
[[nodiscard]] resultt<> simplify_mod(const mod_exprt &);

src/util/simplify_expr_int.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,18 @@ simplify_exprt::simplify_concatenation(const concatenation_exprt &expr)
997997
return std::move(new_expr);
998998
}
999999

1000+
simplify_exprt::resultt<>
1001+
simplify_exprt::simplify_zero_extend(const zero_extend_exprt &expr)
1002+
{
1003+
if(!can_cast_type<bitvector_typet>(expr.type()))
1004+
return unchanged(expr);
1005+
1006+
if(!can_cast_type<bitvector_typet>(expr.op().type()))
1007+
return unchanged(expr);
1008+
1009+
return changed(simplify_node(expr.lower()));
1010+
}
1011+
10001012
simplify_exprt::resultt<>
10011013
simplify_exprt::simplify_shifts(const shift_exprt &expr)
10021014
{

0 commit comments

Comments
 (0)