diff --git a/runtime/strings/bytes.cpp b/runtime/strings/bytes.cpp index ada8f46d8..5433325c2 100644 --- a/runtime/strings/bytes.cpp +++ b/runtime/strings/bytes.cpp @@ -182,6 +182,22 @@ SortBytes hook_BYTES_replaceAt(SortBytes b, SortInt start, SortBytes b2) { return b; } +SortBytes hook_BYTES_memset(SortBytes b, SortInt start, SortInt count, SortInt value) { + uint64_t ustart = get_ui(start); + uint64_t ucount = get_ui(count); + uint64_t uend = ustart + ucount; + uint64_t input_len = len(b); + if (uend > input_len) { + KLLVM_HOOK_INVALID_ARGUMENT( + "Buffer overflow on memset: start {} plus count {} is greater " + "than buffer length {}", + ustart, ucount, input_len); + } + int v = mpz_get_si(value); + memset(b->data + ustart, v, ucount); + return b; +} + SortInt hook_BYTES_length(SortBytes a) { mpz_t result; mpz_init_set_ui(result, len(a)); diff --git a/unittests/runtime-strings/bytestest.cpp b/unittests/runtime-strings/bytestest.cpp index 9c1d1aae0..fd4da14b9 100644 --- a/unittests/runtime-strings/bytestest.cpp +++ b/unittests/runtime-strings/bytestest.cpp @@ -23,6 +23,7 @@ string *hook_BYTES_string2bytes(string *s); string *hook_BYTES_substr(string *b, mpz_t start, mpz_t end); string *hook_BYTES_replaceAt(string *b, mpz_t start, string *b2); string *hook_BYTES_update(string *b, mpz_t off, mpz_t val); +string *hook_BYTES_memset(string *b, mpz_t start, mpz_t cnt, mpz_t val); mpz_ptr hook_BYTES_get(string *b, mpz_t off); mpz_ptr hook_BYTES_length(string *b); string *hook_BYTES_padRight(string *b, mpz_t len, mpz_t v); @@ -227,6 +228,28 @@ BOOST_AUTO_TEST_CASE(update) { BOOST_CHECK_EQUAL(0, memcmp(res->data, "1204", 4)); } +BOOST_AUTO_TEST_CASE(memset) { + auto _12345 = makeString("12345"); + mpz_t _0, _1, _3; + mpz_init_set_si(_0, '0'); + mpz_init_set_ui(_1, 1); + mpz_init_set_ui(_3, 3); + + auto res = hook_BYTES_memset(_12345, _1, _3, _0); + BOOST_CHECK_EQUAL(_12345, res); + BOOST_CHECK_EQUAL(5, len(res)); + BOOST_CHECK_EQUAL(0, memcmp(res->data, "10005", 5)); + + mpz_t neg1; + mpz_init_set_si(neg1, -1); + res = hook_BYTES_memset(_12345, _1, _1, neg1); + BOOST_CHECK_EQUAL(_12345, res); + BOOST_CHECK_EQUAL(5, len(res)); + BOOST_CHECK_EQUAL((unsigned char)_12345->data[1], 255); + + BOOST_CHECK_THROW(hook_BYTES_memset(_12345, _3, _3, _0), std::invalid_argument); +} + BOOST_AUTO_TEST_CASE(get) { auto _1234 = makeString("1234"); mpz_t _0;