diff --git a/test/scan.h b/test/scan.h index 016c3b42..dd6751d8 100644 --- a/test/scan.h +++ b/test/scan.h @@ -114,19 +114,20 @@ class scan_buffer { auto base() const -> const char* { return buf_->ptr_; } friend auto to_contiguous(iterator it) -> maybe_contiguous_range; - friend void advance(iterator& it, size_t n); + friend auto advance(iterator it, size_t n) -> iterator; }; friend auto to_contiguous(iterator it) -> maybe_contiguous_range { if (it.buf_->is_contiguous()) return {it.buf_->ptr_, it.buf_->end_}; return {nullptr, nullptr}; } - friend void advance(iterator& it, size_t n) { + friend auto advance(iterator it, size_t n) -> iterator { FMT_ASSERT(it.buf_->is_contiguous(), ""); const char*& ptr = it.buf_->ptr_; ptr += n; it.value_ = *ptr; if (ptr == it.buf_->end_) it.ptr_ = iterator::sentinel(); + return it; } auto begin() -> iterator { return this; } @@ -320,10 +321,10 @@ struct custom_scan_arg { template class basic_scan_arg { private: using scan_type = detail::scan_type; + scan_type type_; public: // TODO: make private - scan_type type; union { int* int_value; unsigned* uint_value; @@ -336,46 +337,48 @@ template class basic_scan_arg { }; FMT_CONSTEXPR basic_scan_arg() - : type(scan_type::none_type), int_value(nullptr) {} + : type_(scan_type::none_type), int_value(nullptr) {} FMT_CONSTEXPR basic_scan_arg(int& value) - : type(scan_type::int_type), int_value(&value) {} + : type_(scan_type::int_type), int_value(&value) {} FMT_CONSTEXPR basic_scan_arg(unsigned& value) - : type(scan_type::uint_type), uint_value(&value) {} + : type_(scan_type::uint_type), uint_value(&value) {} FMT_CONSTEXPR basic_scan_arg(long long& value) - : type(scan_type::long_long_type), long_long_value(&value) {} + : type_(scan_type::long_long_type), long_long_value(&value) {} FMT_CONSTEXPR basic_scan_arg(unsigned long long& value) - : type(scan_type::ulong_long_type), ulong_long_value(&value) {} + : type_(scan_type::ulong_long_type), ulong_long_value(&value) {} FMT_CONSTEXPR basic_scan_arg(std::string& value) - : type(scan_type::string_type), string(&value) {} + : type_(scan_type::string_type), string(&value) {} FMT_CONSTEXPR basic_scan_arg(fmt::string_view& value) - : type(scan_type::string_view_type), string_view(&value) {} + : type_(scan_type::string_view_type), string_view(&value) {} template - FMT_CONSTEXPR basic_scan_arg(T& value) : type(scan_type::custom_type) { + FMT_CONSTEXPR basic_scan_arg(T& value) : type_(scan_type::custom_type) { custom.value = &value; custom.scan = scan_custom_arg; } constexpr explicit operator bool() const noexcept { - return type != scan_type::none_type; + return type_ != scan_type::none_type; } + auto type() const -> detail::scan_type { return type_; } + template - auto visit(Visitor&& vis) -> decltype(vis(std::declval())) { - switch (type) { + auto visit(Visitor&& vis) -> decltype(vis(monostate())) { + switch (type_) { case scan_type::none_type: break; case scan_type::int_type: - return vis(int_value); + return vis(*int_value); case scan_type::uint_type: - return vis(uint_value); + return vis(*uint_value); case scan_type::long_long_type: - return vis(long_long_value); + return vis(*long_long_value); case scan_type::ulong_long_type: - return vis(ulong_long_value); + return vis(*ulong_long_value); case scan_type::string_type: - return vis(string); + return vis(*string); case scan_type::string_view_type: - return vis(string_view); + return vis(*string_view); case scan_type::custom_type: // TODO: implement break; @@ -445,6 +448,91 @@ const char* parse_scan_specs(const char* begin, const char* end, return begin; } +struct default_arg_scanner { + using iterator = scan_buffer::iterator; + iterator begin; + iterator end; + + template + auto read_uint(iterator it, T& value) -> iterator { + if (it == end) return it; + char c = *it; + if (c < '0' || c > '9') throw_format_error("invalid input"); + + int num_digits = 0; + T n = 0, prev = 0; + char prev_digit = c; + do { + prev = n; + n = n * 10 + static_cast(c - '0'); + prev_digit = c; + c = *++it; + ++num_digits; + if (c < '0' || c > '9') break; + } while (it != end); + + // Check overflow. + if (num_digits <= std::numeric_limits::digits10) { + value = n; + return it; + } + unsigned max = to_unsigned((std::numeric_limits::max)()); + if (num_digits == std::numeric_limits::digits10 + 1 && + prev * 10ull + unsigned(prev_digit - '0') <= max) { + value = n; + } else { + throw_format_error("number is too big"); + } + return it; + } + + template + auto read_int(iterator it, T& value) -> iterator { + bool negative = it != end && *it == '-'; + if (negative) { + ++it; + if (it == end) throw_format_error("invalid input"); + } + using unsigned_type = typename std::make_unsigned::type; + unsigned_type abs_value = 0; + it = read_uint(it, abs_value); + auto n = static_cast(abs_value); + value = negative ? -n : n; + return it; + } + + auto operator()(int& value) -> iterator { + return read_int(begin, value); + } + auto operator()(unsigned& value) -> iterator { + return read_uint(begin, value); + } + auto operator()(long long& value) -> iterator { + return read_int(begin, value); + } + auto operator()(unsigned long long& value) -> iterator { + return read_uint(begin, value); + } + auto operator()(std::string& value) -> iterator { + iterator it = begin; + while (it != end && *it != ' ') value.push_back(*it++); + return it; + } + auto operator()(fmt::string_view& value) -> iterator { + auto range = to_contiguous(begin); + // This could also be checked at compile time in scan. + if (!range) throw_format_error("string_view requires contiguous input"); + auto p = range.begin; + while (p != range.end && *p != ' ') ++p; + size_t size = to_unsigned(p - range.begin); + value = {range.begin, size}; + return advance(begin, size); + } + auto operator()(monostate) -> iterator { + return begin; + } +}; + struct arg_scanner { using iterator = scan_buffer::iterator; @@ -465,48 +553,6 @@ struct scan_handler : error_handler { scan_context scan_ctx_; int next_arg_id_; - using iterator = scan_buffer::iterator; - - template auto read_uint(iterator& it) -> optional { - auto end = scan_ctx_.end(); - if (it == end) return {}; - char c = *it; - if (c < '0' || c > '9') on_error("invalid input"); - - int num_digits = 0; - T value = 0, prev = 0; - char prev_digit = c; - do { - prev = value; - value = value * 10 + static_cast(c - '0'); - prev_digit = c; - c = *++it; - ++num_digits; - if (c < '0' || c > '9') break; - } while (it != end); - - // Check overflow. - if (num_digits <= std::numeric_limits::digits10) return value; - const unsigned max = to_unsigned((std::numeric_limits::max)()); - if (num_digits == std::numeric_limits::digits10 + 1 && - prev * 10ull + unsigned(prev_digit - '0') <= max) { - return value; - } - throw format_error("number is too big"); - } - - template auto read_int(iterator& it) -> optional { - auto end = scan_ctx_.end(); - bool negative = it != end && *it == '-'; - if (negative) ++it; - if (auto abs_value = read_uint::type>(it)) { - auto value = static_cast(*abs_value); - return negative ? -value : value; - } - if (negative) on_error("invalid input"); - return {}; - } - public: FMT_CONSTEXPR scan_handler(string_view format, scan_buffer& buf, scan_args args) @@ -537,51 +583,20 @@ struct scan_handler : error_handler { scan_arg arg = scan_ctx_.arg(arg_id); auto it = scan_ctx_.begin(), end = scan_ctx_.end(); while (it != end && is_whitespace(*it)) ++it; - switch (arg.type) { - case scan_type::int_type: - if (auto value = read_int(it)) *arg.int_value = *value; - break; - case scan_type::uint_type: - if (auto value = read_uint(it)) *arg.uint_value = *value; - break; - case scan_type::long_long_type: - if (auto value = read_int(it)) *arg.long_long_value = *value; - break; - case scan_type::ulong_long_type: - if (auto value = read_uint(it)) - *arg.ulong_long_value = *value; - break; - case scan_type::string_type: - while (it != end && *it != ' ') arg.string->push_back(*it++); - break; - case scan_type::string_view_type: { - auto range = to_contiguous(it); - // This could also be checked at compile time in scan. - if (!range) on_error("string_view requires contiguous input"); - auto p = range.begin; - while (p != range.end && *p != ' ') ++p; - size_t size = to_unsigned(p - range.begin); - *arg.string_view = {range.begin, size}; - advance(it, size); - break; - } - case scan_type::none_type: - case scan_type::custom_type: - assert(false); - } + arg.visit(default_arg_scanner{it, end}); scan_ctx_.advance_to(it); } auto on_format_specs(int arg_id, const char* begin, const char* end) -> const char* { scan_arg arg = scan_ctx_.arg(arg_id); - if (arg.type == scan_type::custom_type) { + if (arg.type() == scan_type::custom_type) { parse_ctx_.advance_to(begin); arg.custom.scan(arg.custom.value, parse_ctx_, scan_ctx_); return parse_ctx_.begin(); } auto specs = format_specs<>(); - begin = parse_scan_specs(begin, end, specs, arg.type); + begin = parse_scan_specs(begin, end, specs, arg.type()); if (begin == end || *begin != '}') on_error("missing '}' in format string"); auto s = arg_scanner{scan_ctx_.begin(), scan_ctx_.end(), specs}; // TODO: scan argument according to specs