// codecvt facet for Shift JIS multibyte code, JIS-X0208 wide-character code

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once
#ifndef _CVT_SJIS_0208_
#define _CVT_SJIS_0208_
#include <yvals_core.h>
#if _STL_COMPILER_PREPROCESSOR
#include <cwchar>
#include <locale>

#pragma pack(push, _CRT_PACKING)
#pragma warning(push, _STL_WARNING_LEVEL)
#pragma warning(disable : _STL_DISABLED_WARNINGS)
_STL_DISABLE_CLANG_WARNINGS
#pragma push_macro("new")
#undef new

namespace stdext {
    namespace cvt {

        // conversion-state type handed to every facet member below (an alias for mbstate_t)
        using _Statype = _CSTD mbstate_t;

        // Facet for converting between JIS X 0208 wide characters and Shift JIS bytes.
        //
        // A two-byte character is held in _Elem as (row << 8 | cell), with row and cell each in [0x21, 0x7e].
        // Bytes in [00, 80], [a0, df] and [f0, ff] are single-byte codes and are copied through unchanged.
        // _Maxcode is the largest wide-character value accepted in either direction; larger values are
        // reported as errors.
        template <class _Elem, unsigned long _Maxcode = 0x7e7e>
        class codecvt_sjis_0208 : public _STD codecvt<_Elem, char, _Statype> {
        public:
            using _Mybase     = _STD codecvt<_Elem, char, _Statype>;
            using result      = typename _Mybase::result;
            using _Byte       = char;
            using intern_type = _Elem;
            using extern_type = _Byte;
            using state_type  = _Statype;

            explicit codecvt_sjis_0208(size_t _Refs = 0) : _Mybase(_Refs) {}

            ~codecvt_sjis_0208() noexcept override {}

        protected:
            result do_in(_Statype&, const _Byte* _First1, const _Byte* _Last1, const _Byte*& _Mid1, _Elem* _First2,
                _Elem* _Last2, _Elem*& _Mid2) const override {
                // convert bytes [_First1, _Last1) to [_First2, _Last2)
                // On return _Mid1/_Mid2 designate the first unconsumed byte / first unwritten element.
                // Returns error for a bad trail byte or a code above _Maxcode (with _Mid1 left at the start of
                // the offending sequence), partial if no input was consumed, otherwise ok.
                _Mid1 = _First1;
                _Mid2 = _First2;
                while (_Mid1 != _Last1 && _Mid2 != _Last2) { // convert a multibyte sequence
                    const unsigned char _By = static_cast<unsigned char>(*_Mid1);
                    unsigned short _Ch      = _By;
                    int _Nbytes             = 1; // number of input bytes making up this character

                    if (_By > 0x80 && (0xa0 > _By || _By > 0xdf) && 0xf0 > _By) { // lead byte: [81, 9f] or [e0, ef]
                        if (_Last1 - _Mid1 < 2) {
                            break; // trail byte not available yet
                        }

                        // convert 2-byte code; only peek at the trail byte so that _Mid1 still designates the
                        // start of the sequence if it has to be rejected
                        const unsigned char _By2 = static_cast<unsigned char>(_Mid1[1]);

                        if (_By2 < 0x40 || _By2 == 0x7f || _By2 > 0xfc) {
                            return _Mybase::error; // bad second byte
                        }

                        if (0xe0 <= _Ch) {
                            _Ch -= 0xb0; // MS [e0, ef] => [30, 3f]
                        } else {
                            _Ch -= 0x70; // MS [81, 9f] => [11, 2f]
                        }

                        // each lead byte covers two rows: shifting by 9 doubles the value and places it in the
                        // high byte, yielding the even row of the pair
                        _Ch <<= 9; // MS [11, 3f] => [22, 7e] in place
                        if (0x9f <= _By2) {
                            _Ch = static_cast<unsigned short>(_Ch + (_By2 - 0x7e)); // LS [9f, fc] => [21, 7e]
                        } else if (_By2 <= 0x7e) {
                            // LS [40, 7e] => [21, 5f], minus 0x100 to select the odd row
                            _Ch = static_cast<unsigned short>(_Ch + (_By2 - 0x11f));
                        } else {
                            // LS [80, 9e] => [60, 7e], minus 0x100 to select the odd row
                            _Ch = static_cast<unsigned short>(_Ch + (_By2 - 0x120));
                        }

                        _Nbytes = 2;
                    }

                    if (_Maxcode < _Ch) {
                        return _Mybase::error; // code too large
                    }

                    _Mid1 += _Nbytes;
                    *_Mid2++ = static_cast<_Elem>(_Ch);
                }

                return _First1 == _Mid1 ? _Mybase::partial : _Mybase::ok;
            }

            result do_out(_Statype&, const _Elem* _First1, const _Elem* _Last1, const _Elem*& _Mid1, _Byte* _First2,
                _Byte* _Last2, _Byte*& _Mid2) const override {
                // convert [_First1, _Last1) to bytes [_First2, _Last2)
                // On return _Mid1/_Mid2 designate the first unconsumed element / first unwritten byte.
                // Returns error for a code above _Maxcode or outside JIS X 0208 (with _Mid1 left at the
                // offending element), partial if no input was consumed, otherwise ok.
                _Mid1 = _First1;
                _Mid2 = _First2;
                while (_Mid1 != _Last1 && _Mid2 != _Last2) { // convert and put a wide char
                    const unsigned long _Ch = static_cast<unsigned long>(*_Mid1);

                    if (_Maxcode < _Ch) {
                        return _Mybase::error;
                    }

                    if (_Ch <= 0x80 || (0xa0 <= _Ch && _Ch <= 0xdf) || (0xf0 <= _Ch && _Ch <= 0xff)) {
                        *_Mid2++ = static_cast<_Byte>(_Ch); // single-byte code, copied through
                    } else if (_Last2 - _Mid2 < 2) {
                        break; // no room for a 2-byte sequence
                    } else { // convert 2-byte code
                        // validate the untruncated row so that codes above 0xffff (possible when _Maxcode is
                        // raised) are rejected instead of aliasing onto a valid character
                        const unsigned long _Row  = _Ch >> 8;
                        const unsigned long _Cell = _Ch & 0xff;
                        if (_Row < 0x21 || 0x7e < _Row || _Cell < 0x21 || 0x7e < _Cell) {
                            return _Mybase::error;
                        }

                        // rows [21, 5e] => MS [81, 9f], rows [5f, 7e] => MS [e0, ef]
                        *_Mid2++ = static_cast<_Byte>((_Row + 1) / 2 + (_Row < 0x5f ? 0x70 : 0xb0));
                        if ((_Row & 1) != 0) {
                            // odd row: cells [21, 5f] => LS [40, 7e], cells [60, 7e] => LS [80, 9e];
                            // the split skips 0x7f, which is not a valid trail byte
                            *_Mid2++ = static_cast<_Byte>(_Cell + (0x60 <= _Cell ? 0x20 : 0x1f));
                        } else {
                            *_Mid2++ = static_cast<_Byte>(_Cell + 0x7e); // even row: cells [21, 7e] => LS [9f, fc]
                        }
                    }

                    ++_Mid1;
                }

                return _First1 == _Mid1 ? _Mybase::partial : _Mybase::ok;
            }

            result do_unshift(_Statype&, _Byte* _First2, _Byte*, _Byte*& _Mid2) const override {
                // generate bytes to return to default shift state
                // this facet keeps no shift state, so nothing is written
                _Mid2 = _First2;

                return _Mybase::ok;
            }

            int do_length(
                _Statype& _State, const _Byte* _First1, const _Byte* _Last1, size_t _Count) const noexcept override {
                // return min(_Count, converted length of bytes [_First1, _Last1))
                // NOTE(review): the value returned is a count of wide characters, whereas
                // [locale.codecvt.virtuals] describes do_length in terms of external bytes consumed;
                // left unchanged here -- confirm what callers expect before altering it.
                size_t _Wchars    = 0;
                _Statype _Mystate = _State; // work on a copy so the caller's state is not modified

                while (_Wchars < _Count && _First1 != _Last1) { // convert another wide character
                    const _Byte* _Mid1;
                    _Elem* _Mid2;
                    _Elem _Ch;

                    // test result of single wide-char conversion
                    switch (do_in(_Mystate, _First1, _Last1, _Mid1, &_Ch, &_Ch + 1, _Mid2)) {
                    case _Mybase::noconv:
                        return static_cast<int>(_Wchars + (_Last1 - _First1));

                    case _Mybase::ok:
                        if (_Mid2 == &_Ch + 1) {
                            ++_Wchars; // replacement do_in might not convert one
                        }

                        _First1 = _Mid1;
                        break;

                    default:
                        return static_cast<int>(_Wchars); // error or partial
                    }
                }

                return static_cast<int>(_Wchars);
            }

            bool do_always_noconv() const noexcept override { // return true if conversions never change input
                return false;
            }

            int do_max_length() const noexcept override { // return maximum length required for a conversion
                return 2; // a Shift JIS character occupies at most a lead byte plus a trail byte
            }

            int do_encoding() const noexcept override { // return length of code sequence (from codecvt)
                return 0; // 0 => varying length
            }
        };
    } // namespace cvt
} // namespace stdext

#pragma pop_macro("new")
_STL_RESTORE_CLANG_WARNINGS
#pragma warning(pop)
#pragma pack(pop)

#endif // _STL_COMPILER_PREPROCESSOR
#endif // _CVT_SJIS_0208_
