// xone_byte -- convert one-byte code to 16-bit wide

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once
#ifndef _CVT_XONE_BYTE_
#define _CVT_XONE_BYTE_
#include <yvals_core.h>
#if _STL_COMPILER_PREPROCESSOR
#include <cwchar>
#include <locale>

#pragma pack(push, _CRT_PACKING)
#pragma warning(push, _STL_WARNING_LEVEL)
#pragma warning(disable : _STL_DISABLED_WARNINGS)
_STL_DISABLE_CLANG_WARNINGS
#pragma push_macro("new")
#undef new

namespace stdext {
    namespace cvt {

        // conversion state type passed to every facet in this namespace; alias for the C library's mbstate_t
        using _Statype = _CSTD mbstate_t;

        // _Cvt_one_byte -- table-driven codecvt facet mapping single bytes to/from elements of type _Elem
        //
        // _Elem    : internal (wide) element type
        // _Table   : mapping tables; the code below reads these static members:
        //              _Nlow   -- byte values below this map to themselves
        //              _Btw    -- indexed by (byte - _Nlow), gives the wide code for a byte; 0 marks an invalid byte
        //              _Wvalid -- wide codes that have a non-identity byte mapping; searched by bisection,
        //                         so it is assumed to be sorted in ascending order
        //              _Wtb    -- parallel to _Wvalid, gives the byte for the wide code at the same index
        // _Maxcode : largest element value accepted or produced; anything larger is a conversion error
        template <class _Elem, class _Table, unsigned long _Maxcode = 0xffff>
        class _Cvt_one_byte
            : public _STD codecvt<_Elem, char, _Statype> { // facet for converting between _Elem and one-byte sequences
        public:
            using _Mybase     = _STD codecvt<_Elem, char, _Statype>;
            using result      = typename _Mybase::result;
            using _Byte       = char;
            using intern_type = _Elem;
            using extern_type = _Byte;
            using state_type  = _Statype;

            explicit _Cvt_one_byte(size_t _Refs = 0) : _Mybase(_Refs) {}

            ~_Cvt_one_byte() noexcept override {}

        protected:
            result do_in(_Statype&, const _Byte* _First1, const _Byte* _Last1, const _Byte*& _Mid1, _Elem* _First2,
                _Elem* _Last2, _Elem*& _Mid2) const override {
                // convert bytes [_First1, _Last1) to [_First2, _Last2)
                // on return _Mid1/_Mid2 point just past the last byte consumed/element written
                // returns error at the first unmappable byte (with _Mid1 left pointing at it),
                // partial if no byte was consumed, otherwise ok
                _Mid1 = _First1;
                _Mid2 = _First2;
                for (; _Mid1 != _Last1 && _Mid2 != _Last2; ++_Mid1) { // get a byte and map it
                    // go through unsigned char so bytes >= 0x80 are not sign-extended
                    unsigned short _Bytecode = static_cast<unsigned char>(*_Mid1);

                    // bytes below _Nlow map to themselves; others are translated through _Btw, where 0 means
                    // "no mapping"; either way the result must not exceed _Maxcode
                    if ((_Table::_Nlow <= _Bytecode && (_Bytecode = _Table::_Btw[_Bytecode - _Table::_Nlow]) == 0)
                        || _Maxcode < _Bytecode) {
                        return _Mybase::error;
                    }

                    *_Mid2++ = static_cast<_Elem>(_Bytecode);
                }

                return _First1 == _Mid1 ? _Mybase::partial : _Mybase::ok;
            }

            result do_out(_Statype&, const _Elem* _First1, const _Elem* _Last1, const _Elem*& _Mid1, _Byte* _First2,
                _Byte* _Last2, _Byte*& _Mid2) const override {
                // convert [_First1, _Last1) to bytes [_First2, _Last2)
                // on return _Mid1/_Mid2 point just past the last element consumed/byte written
                // returns error at the first unmappable element (with _Mid1 left pointing at it),
                // partial if no element was consumed, otherwise ok
                _Mid1 = _First1;
                _Mid2 = _First2;
                for (; _Mid1 != _Last1 && _Mid2 != _Last2; ++_Mid1) { // get an element and map it
                    unsigned long _Hugecode  = static_cast<unsigned long>(*_Mid1);
                    unsigned short _Widecode = static_cast<unsigned short>(_Hugecode);

                    // range-check the full-width value, not the 16-bit truncation: otherwise an element such as
                    // 0x10041 would wrap to 0x41 and be emitted as if it were a valid code
                    // NOTE(review): if _Maxcode is instantiated above 0xffff, values in (0xffff, _Maxcode] still
                    // truncate below -- confirm whether those should be rejected as well
                    if (_Maxcode < _Hugecode) {
                        return _Mybase::error; // value too large
                    }

                    // codes below _Nlow, and codes below 0x100 that _Btw maps to themselves, are emitted unchanged;
                    // everything else needs a reverse lookup
                    if (_Widecode >= _Table::_Nlow
                        && (_Widecode >= 0x100
                            || _Table::_Btw[_Widecode - _Table::_Nlow] != _Widecode)) { // search for a valid mapping
                        unsigned long _Lo = 0;
                        unsigned long _Hi = static_cast<unsigned long>(_STD size(_Table::_Wvalid));

                        while (_Lo < _Hi) { // test midpoint of remaining interval
                            unsigned long _Mid = (_Hi + _Lo) / 2;

                            if (_Widecode < _Table::_Wvalid[_Mid]) {
                                _Hi = _Mid;
                            } else if (_Widecode == _Table::_Wvalid[_Mid]) { // found the code, map it
                                _Widecode = _Table::_Wtb[_Mid];
                                break;
                            } else {
                                _Lo = _Mid + 1;
                            }
                        }

                        // the interval is empty only if the loop ended without hitting the break above
                        if (_Hi <= _Lo) {
                            return _Mybase::error;
                        }
                    }

                    *_Mid2++ = static_cast<_Byte>(_Widecode);
                }

                return _First1 == _Mid1 ? _Mybase::partial : _Mybase::ok;
            }

            result do_unshift(_Statype&, _Byte* _First2, _Byte*, _Byte*& _Mid2) const override {
                // generate bytes to return to default shift state
                // nothing is written: _Mid2 is set to _First2 and ok is returned
                _Mid2 = _First2;
                return _Mybase::ok;
            }

            int do_length(_Statype&, const _Byte* _First1, const _Byte* _Last1, size_t _Count) const noexcept override {
                // return min(_Count, converted length of bytes [_First1, _Last1))
                // each byte converts to one element, so the converted length is just the byte count
                const auto _Wchars = static_cast<size_t>(_Last1 - _First1);

                return static_cast<int>(_Wchars < _Count ? _Wchars : _Count);
            }

            bool do_always_noconv() const noexcept override { // return true if conversions never change input
                return false;
            }

            int do_max_length() const noexcept override { // return maximum length required for a conversion
                return 1; // length of sequence
            }

            int do_encoding() const noexcept override { // return length of code sequence (from codecvt)
                return 1; // >0 => fixed sequence length
            }
        };
    } // namespace cvt
} // namespace stdext

#pragma pop_macro("new")
_STL_RESTORE_CLANG_WARNINGS
#pragma warning(pop)
#pragma pack(pop)

#endif // _STL_COMPILER_PREPROCESSOR
#endif // _CVT_XONE_BYTE_
