// codecvt facet for UTF-8 multibyte code, UTF-16 wide-character code

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once
#ifndef _CVT_UTF8_UTF16_
#define _CVT_UTF8_UTF16_
#include <yvals_core.h>
#if _STL_COMPILER_PREPROCESSOR
#include <codecvt>
#include <cwchar>
#include <locale>

#pragma pack(push, _CRT_PACKING)
#pragma warning(push, _STL_WARNING_LEVEL)
#pragma warning(disable : _STL_DISABLED_WARNINGS)
_STL_DISABLE_CLANG_WARNINGS
#pragma push_macro("new")
#undef new

namespace stdext {
    namespace cvt {

        // Conversion-state type used as the default _Statype argument of the facet declared in this header.
        using _Mbstatype = _CSTD mbstate_t;

        _STL_DISABLE_DEPRECATED_WARNING
        // codecvt_utf8_utf16: codecvt facet translating between UTF-8 byte sequences (extern_type, char)
        // and UTF-16 code units stored in _Elem (intern_type).
        //
        // Template parameters:
        //   _Elem    - element type that receives (do_in) or supplies (do_out) UTF-16 code units
        //   _Maxcode - largest value do_in will deliver; the limit actually applied is min(_Maxcode, 0x10ffff).
        //              do_out does not consult _Maxcode.
        //   _Mode    - codecvt_mode flags; only consume_header (do_in) and generate_header (do_out) are tested here
        //   _Statype - conversion state type; must be at least as large as unsigned short (see static_assert)
        //
        // Conversion state: the storage of the state object is accessed through an unsigned short* and
        // that one word is interpreted as follows:
        //   0     - no character processed yet, so the header (U+FEFF) may still be consumed/generated
        //   1     - header decision already made, no surrogate pair in progress
        //   > 1   - a surrogate pair is half processed:
        //             do_in  stores the pending low surrogate, 0xdc00 | bits, with its low 6 bits still zero
        //             do_out stores the partially built code point shifted right by 10
        template <class _Elem, unsigned long _Maxcode = 0x10ffff, _STD codecvt_mode _Mode = _STD codecvt_mode{},
            class _Statype = _Mbstatype>
        class codecvt_utf8_utf16 : public _STD codecvt<_Elem, char, _Statype> {
            // facet for converting between UTF-16 _Elem and UTF-8 byte sequences
        public:
            using _Mybase     = _STD codecvt<_Elem, char, _Statype>;
            using result      = typename _Mybase::result;
            using _Byte       = char;
            using intern_type = _Elem;
            using extern_type = _Byte;
            using state_type  = _Statype;

            // the state object is reinterpreted as an unsigned short below, so it must be able to hold one
            static_assert(sizeof(unsigned short) <= sizeof(state_type), "state_type too small");

            // _Refs is forwarded unchanged to the base codecvt constructor
            explicit codecvt_utf8_utf16(size_t _Refs = 0) : _Mybase(_Refs) {}

            ~codecvt_utf8_utf16() noexcept override {}

        protected:
            // Decode UTF-8 bytes [_First1, _Last1) into UTF-16 code units written to [_First2, _Last2).
            //
            // On return _Mid1 / _Mid2 give the positions reached in the input / output ranges.
            // Returns:
            //   error   - a byte that must be a continuation byte (0x80-0xbf) is not one, a continuation byte
            //             appears where a lead byte is expected, or the decoded value exceeds
            //             min(_Maxcode, 0x10ffff)
            //   partial - _Mid1 is still at the roll-back point of the last sequence examined (this includes
            //             empty input, a full output buffer on entry, and a truncated trailing sequence)
            //   ok      - otherwise
            //
            // A code point above 0xffff is delivered as two code units. The high surrogate is written after
            // reading all but the last byte of the 4-byte sequence; the low surrogate is written once that
            // last byte has been read, which may happen in a later call because the pending value lives in
            // _State.
            result do_in(_Statype& _State, const _Byte* _First1, const _Byte* _Last1, const _Byte*& _Mid1,
                _Elem* _First2, _Elem* _Last2, _Elem*& _Mid2) const override {
                // convert bytes [_First1, _Last1) to [_First2, _Last2)
                unsigned short* _Pstate = reinterpret_cast<unsigned short*>(&_State);
                _Mid1                   = _First1;
                _Mid2                   = _First2;

                while (_Mid1 != _Last1 && _Mid2 != _Last2) { // convert a multibyte sequence
                    unsigned long _By = static_cast<unsigned char>(*_Mid1);
                    unsigned long _Ch;
                    int _Nextra; // number of continuation bytes that follow the lead byte
                    int _Nskip; // 1 if the final continuation byte is left unread for the second word

                    if (*_Pstate > 1) { // leftover word
                        // *_Pstate holds a low surrogate whose low 6 bits are missing; the current byte
                        // must be the final continuation byte that supplies them
                        if (_By < 0x80 || 0xc0 <= _By) {
                            return _Mybase::error; // not continuation byte
                        }

                        // deliver second half of two-word value
                        ++_Mid1;
                        *_Mid2++ = static_cast<_Elem>(*_Pstate | (_By & 0x3f));
                        *_Pstate = 1;
                        continue;
                    }

                    // classify the lead byte: extract its payload bits and the continuation byte count
                    if (_By < 0x80) {
                        _Ch     = _By;
                        _Nextra = 0;
                    } else if (_By < 0xc0) { // 0x80-0xbf not first byte
                        ++_Mid1; // _Mid1 is left pointing past the offending byte
                        return _Mybase::error;
                    } else if (_By < 0xe0) {
                        _Ch     = _By & 0x1f;
                        _Nextra = 1;
                    } else if (_By < 0xf0) {
                        _Ch     = _By & 0x0f;
                        _Nextra = 2;
                    } else if (_By < 0xf8) {
                        _Ch     = _By & 0x07;
                        _Nextra = 3;
                    } else {
                        // lead bytes 0xf8-0xff: 4 continuation bytes below 0xfc, 5 otherwise
                        _Ch     = _By & 0x03;
                        _Nextra = _By < 0xfc ? 4 : 5;
                    }

                    _Nskip  = _Nextra < 3 ? 0 : 1; // leave a byte for 2nd word
                    _First1 = _Mid1; // roll back point (the parameter is reused; the final return tests it)

                    if (_Nextra == 0) {
                        ++_Mid1;
                    } else if (_Last1 - _Mid1 < _Nextra + 1 - _Nskip) {
                        // need the lead byte plus every continuation byte except a skipped one
                        break; // not enough input
                    } else {
                        // accumulate 6 payload bits from each continuation byte read now
                        for (++_Mid1; _Nskip < _Nextra; --_Nextra, ++_Mid1) {
                            if ((_By = static_cast<unsigned char>(*_Mid1)) < 0x80 || 0xc0 <= _By) {
                                return _Mybase::error; // not continuation byte
                            } else {
                                _Ch = _Ch << 6 | (_By & 0x3f);
                            }
                        }
                    }

                    if (0 < _Nskip) {
                        // make room for the 6 bits of the unread byte so _Ch has its final magnitude
                        _Ch <<= 6; // get last byte on next call
                    }

                    if ((_Maxcode < 0x10ffff ? _Maxcode : 0x10ffff) < _Ch) {
                        return _Mybase::error; // value too large
                    } else if (0xffff < _Ch) { // deliver first half of two-word value, save second word
                        // '-' binds tighter than '|': this is 0xd800 | ((_Ch >> 10) - 0x0040)
                        unsigned short _Ch0 = static_cast<unsigned short>(0xd800 | (_Ch >> 10) - 0x0040);

                        *_Mid2++ = static_cast<_Elem>(_Ch0);
                        // low surrogate with its low 6 bits still zero; the leftover-word branch above
                        // fills them in from the byte that was left unread
                        *_Pstate = static_cast<unsigned short>(0xdc00 | (_Ch & 0x03ff));
                        continue;
                    }

                    if (_Nskip != 0) {
                        // a long sequence whose value fits in one word: the byte left unread is needed now
                        if (_Mid1 == _Last1) { // not enough bytes, noncanonical value
                            _Mid1 = _First1;
                            break;
                        }

                        if ((_By = static_cast<unsigned char>(*_Mid1++)) < 0x80 || 0xc0 <= _By) {
                            return _Mybase::error; // not continuation byte
                        }

                        _Ch |= _By & 0x3f; // complete noncanonical value
                    }

                    if (*_Pstate == 0) { // first time, maybe look for and consume header
                        *_Pstate = 1;

                        if constexpr ((_Mode & _STD consume_header) != 0) {
                            if (_Ch == 0xfeff) { // drop header and retry
                                // convert the remainder starting just past the header; the current value of
                                // _Mid1 is passed as the new start and _Mid1 itself as the output position
                                result _Ans = do_in(_State, _Mid1, _Last1, _Mid1, _First2, _Last2, _Mid2);

                                if (_Ans == _Mybase::partial) { // roll back header determination
                                    // undo consumption of the header so it is seen again on the next call
                                    *_Pstate = 0;
                                    _Mid1    = _First1;
                                }
                                return _Ans;
                            }
                        }
                    }

                    *_Mid2++ = static_cast<_Elem>(_Ch);
                }

                // partial when _Mid1 sits on the roll-back point, i.e. the last sequence consumed nothing
                return _First1 == _Mid1 ? _Mybase::partial : _Mybase::ok;
            }

            // Encode UTF-16 code units [_First1, _Last1) as UTF-8 bytes written to [_First2, _Last2).
            //
            // On return _Mid1 / _Mid2 give the positions reached in the input / output ranges.
            // Returns:
            //   error   - a high surrogate is pending in _State and the next unit is not a low surrogate
            //   partial - _Mid1 == _First1 on exit (nothing was consumed)
            //   ok      - otherwise, including the case where the output ran out of room after at least
            //             one unit was consumed
            //
            // A surrogate pair becomes one 4-byte sequence: the lead byte is written when the high surrogate
            // is consumed and the three continuation bytes when the low surrogate is consumed.
            // A low surrogate that arrives with no high surrogate pending is encoded like any other unit.
            result do_out(_Statype& _State, const _Elem* _First1, const _Elem* _Last1, const _Elem*& _Mid1,
                _Byte* _First2, _Byte* _Last2, _Byte*& _Mid2) const override {
                // convert [_First1, _Last1) to bytes [_First2, _Last2)
                unsigned short* _Pstate = reinterpret_cast<unsigned short*>(&_State);
                _Mid1                   = _First1;
                _Mid2                   = _First2;

                while (_Mid1 != _Last1 && _Mid2 != _Last2) { // convert and put a wide char
                    unsigned long _Ch;
                    unsigned short _Ch1 = static_cast<unsigned short>(*_Mid1);
                    bool _Save          = false; // true when this unit is a high surrogate to be remembered

                    if (1 < *_Pstate) { // get saved MS 11 bits from *_Pstate
                        if (_Ch1 < 0xdc00 || 0xe000 <= _Ch1) {
                            return _Mybase::error; // bad second word
                        }
                        // saved high bits become bits 10 and up; the low surrogate supplies the low 10 bits
                        _Ch = static_cast<unsigned long>((*_Pstate << 10) | (_Ch1 - 0xdc00));
                    } else if (0xd800 <= _Ch1 && _Ch1 < 0xdc00) { // get new first word
                        // code point with its low 10 bits still zero (0x0040 << 10 == 0x10000)
                        _Ch   = static_cast<unsigned long>((_Ch1 - 0xd800 + 0x0040) << 10);
                        _Save = true; // put only first byte, rest with second word
                    } else {
                        _Ch = _Ch1; // not first word, just put it
                    }

                    _Byte _By; // lead byte of the UTF-8 sequence
                    int _Nextra; // number of continuation bytes in the full sequence

                    if (_Ch < 0x0080) {
                        _By     = static_cast<_Byte>(_Ch);
                        _Nextra = 0;
                    } else if (_Ch < 0x0800) {
                        _By     = static_cast<_Byte>(0xc0 | _Ch >> 6);
                        _Nextra = 1;
                    } else if (_Ch < 0x10000) {
                        _By     = static_cast<_Byte>(0xe0 | _Ch >> 12);
                        _Nextra = 2;
                    } else {
                        _By     = static_cast<_Byte>(0xf0 | _Ch >> 18);
                        _Nextra = 3;
                    }

                    // bytes to write for this unit: the whole sequence for 1- to 3-byte forms; for a 4-byte
                    // form, only the lead byte with the high surrogate, or the 3 continuation bytes with
                    // the low surrogate
                    int _Nput = _Nextra < 3 ? _Nextra + 1 : _Save ? 1 : 3;

                    if (_Last2 - _Mid2 < _Nput) {
                        break; // not enough room, even without header
                    }

                    if constexpr ((_Mode & _STD generate_header) != 0) { // maybe header to put
                        if (*_Pstate == 0) { // header to put
                            if (_Last2 - _Mid2 < 3 + _Nput) {
                                break; // not enough room for both
                            }

                            // prepend header
                            *_Mid2++ = '\xef';
                            *_Mid2++ = '\xbb';
                            *_Mid2++ = '\xbf';
                        }
                    }

                    ++_Mid1;
                    if (_Save || _Nextra < 3) { // put first byte of sequence, if not already put
                        *_Mid2++ = _By;
                        --_Nput;
                    }

                    // emit the remaining continuation bytes, most significant 6-bit group first
                    for (; 0 < _Nput; --_Nput) {
                        *_Mid2++ = static_cast<_Byte>((_Ch >> 6 * --_Nextra & 0x3f) | 0x80);
                    }

                    // remember the high bits of a pending pair, otherwise mark "header decision made"
                    *_Pstate = static_cast<unsigned short>(_Save ? _Ch >> 10 : 1);
                }

                return _First1 == _Mid1 ? _Mybase::partial : _Mybase::ok;
            }

            // Writes no bytes: _Mid2 is set to _First2. Reports error when the state word is greater
            // than 1 (half of a surrogate pair is still pending), ok otherwise. _State is not modified.
            result do_unshift(_Statype& _State, _Byte* _First2, _Byte*, _Byte*& _Mid2) const override {
                // generate bytes to return to default shift state
                unsigned short* _Pstate = reinterpret_cast<unsigned short*>(&_State);
                _Mid2                   = _First2;

                return 1 < *_Pstate ? _Mybase::error : _Mybase::ok; // fail if trailing first word
            }

            // Counts how many _Elem values do_in would produce from [_First1, _Last1), stopping at _Count.
            // Works on a copy of _State, so the caller's state object is not modified by this function.
            // Conversion is attempted one output element at a time into a local one-element buffer;
            // counting stops at the first result other than ok/noconv.
            int do_length(
                _Statype& _State, const _Byte* _First1, const _Byte* _Last1, size_t _Count) const noexcept override {
                // return min(_Count, converted length of bytes [_First1, _Last1))
                size_t _Wchars    = 0;
                _Statype _Mystate = _State;

                while (_Wchars < _Count && _First1 != _Last1) { // convert another wide character
                    const _Byte* _Mid1;
                    _Elem* _Mid2;
                    _Elem _Ch; // one-element output buffer; its contents are never read

                    // test result of single wide-char conversion
                    switch (do_in(_Mystate, _First1, _Last1, _Mid1, &_Ch, &_Ch + 1, _Mid2)) {
                    case _Mybase::noconv:
                        // no conversion performed: every remaining byte counts as one element
                        return static_cast<int>(_Wchars + (_Last1 - _First1));

                    case _Mybase::ok:
                        if (_Mid2 == &_Ch + 1) {
                            ++_Wchars; // replacement do_in might not convert one
                        }

                        _First1 = _Mid1;
                        break;

                    default:
                        return static_cast<int>(_Wchars); // error or partial
                    }
                }

                return static_cast<int>(_Wchars);
            }

            // This facet always reports that conversion may change the input.
            bool do_always_noconv() const noexcept override { // return true if conversions never change input
                return false;
            }

            // Upper bound reported for a single conversion; 3 extra bytes are added for the header when
            // consume_header or generate_header is set in _Mode (consume_header takes precedence).
            int do_max_length() const noexcept override { // return maximum length required for a conversion
                if constexpr ((_Mode & _STD consume_header) != 0) {
                    return 9; // header + max input
                } else if constexpr ((_Mode & _STD generate_header) != 0) {
                    return 7; // header + max output
                } else {
                    return 6; // 6-byte max input sequence, no 3-byte header
                }
            }

            int do_encoding() const noexcept override { // return length of code sequence (from codecvt)
                return 0; // 0 => varying length
            }
        };
        _STL_RESTORE_DEPRECATED_WARNING
    } // namespace cvt
} // namespace stdext

#pragma pop_macro("new")
_STL_RESTORE_CLANG_WARNINGS
#pragma warning(pop)
#pragma pack(pop)

#endif // _STL_COMPILER_PREPROCESSOR
#endif // _CVT_UTF8_UTF16_
