/*
 * Decompiled with CFR 0.152.
 */
package dev.blaauwendraad.masker.json;

import dev.blaauwendraad.masker.json.DescriptiveValueMasker;
import dev.blaauwendraad.masker.json.ValueMasker;
import dev.blaauwendraad.masker.json.util.Utf8Util;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.function.Function;
import org.jspecify.annotations.Nullable;

public final class ValueMaskers {
    private ValueMaskers() {
    }

    /**
     * Wraps {@code delegate} so that it carries a human-readable {@code description}.
     *
     * @param description text describing what the masker does
     * @param delegate the masker that performs the actual masking
     * @param <T> the masker type, preserved so callers get back the same kind of masker
     * @return a {@link DescriptiveValueMasker} wrapping {@code delegate}, typed as {@code T}
     */
    @SuppressWarnings("unchecked")
    public static <T extends ValueMasker> T describe(String description, T delegate) {
        // Unchecked: the cast is erased, so it only holds if DescriptiveValueMasker implements every
        // ValueMasker sub-interface T can be; otherwise the caller's implicit cast would fail.
        // NOTE(review): confirm against DescriptiveValueMasker's declaration.
        return (T) new DescriptiveValueMasker<T>(description, delegate);
    }

    /**
     * Returns a masker that replaces the whole value with {@code value} as a JSON string.
     *
     * <p>The value is encoded with {@code Utf8Util.jsonEncode(value, true)}; the {@code true} flag
     * presumably adds the surrounding quotes (compare {@link #eachDigitWith(String)}, which encodes
     * with {@code false} and adds the quotes itself).
     *
     * @param value the replacement text
     * @return a masker whose description is the encoded replacement
     */
    public static ValueMasker.AnyValueMasker with(String value) {
        String replacement = Utf8Util.jsonEncode(value, true);
        return replaceWholeValue(replacement, replacement.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Returns a masker that replaces the whole value with the JSON number {@code value}.
     *
     * @param value the replacement number
     * @return a masker whose description is the decimal text of {@code value}
     */
    public static ValueMasker.AnyValueMasker with(int value) {
        String literal = String.valueOf(value);
        return replaceWholeValue(literal, literal.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Returns a masker that replaces the whole value with the JSON literal {@code true} or
     * {@code false}.
     *
     * @param value the replacement boolean
     * @return a masker whose description is {@code "true"} or {@code "false"}
     */
    public static ValueMasker.AnyValueMasker with(boolean value) {
        String literal = String.valueOf(value);
        return replaceWholeValue(literal, literal.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Returns a masker that replaces the whole value with the JSON literal {@code null}.
     *
     * @return a masker described as {@code "null (literal)"}
     */
    public static ValueMasker.AnyValueMasker withNull() {
        return replaceWholeValue("null (literal)", "null".getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Builds a described masker that overwrites the entire value with {@code replacementBytes}.
     */
    private static ValueMasker.AnyValueMasker replaceWholeValue(String description, byte[] replacementBytes) {
        // NOTE(review): the last replaceBytes argument looks like a repeat count for the
        // replacement bytes (see eachCharacterWith / eachDigitWith) — confirm against the context API.
        return describe(description, context -> context.replaceBytes(0, context.byteLength(), replacementBytes, 1));
    }

    /**
     * Returns a string masker that replaces every character between the quotes with {@code value}.
     *
     * @param value text written once per masked character; JSON-encoded without surrounding quotes
     * @return a masker described as {@code "every character as <encoded value>"}
     */
    public static ValueMasker.StringMasker eachCharacterWith(String value) {
        String replacement = Utf8Util.jsonEncode(value, false);
        byte[] replacementBytes = replacement.getBytes(StandardCharsets.UTF_8);
        return describe("every character as %s".formatted(replacement), context -> {
            // Leave the first and last bytes (the string's quotes) untouched.
            int stringValueStart = 1;
            int stringValueLength = context.byteLength() - 2;
            // NOTE(review): countNonVisibleCharacters is presumably the number of bytes that do not
            // start a visible character (UTF-8 continuation bytes, escape bodies), so this yields a
            // character count rather than a byte count — confirm against the context implementation.
            int nonVisibleCharacters = context.countNonVisibleCharacters(stringValueStart, stringValueLength);
            int maskLength = stringValueLength - nonVisibleCharacters;
            context.replaceBytes(stringValueStart, stringValueLength, replacementBytes, maskLength);
        });
    }

    /**
     * Returns a number masker that overwrites every byte of the number's text with {@code digit},
     * keeping the value a JSON number.
     *
     * <p>The whole token is replaced byte for byte, so a sign, decimal point or exponent is also
     * replaced by the digit.
     *
     * @param digit the masking digit, 1 through 9
     * @return a masker described as {@code "every digit as integer: <digit>"}
     * @throws IllegalArgumentException if {@code digit} is outside 1..9 (0 would create a leading zero)
     */
    public static ValueMasker.NumberMasker eachDigitWith(int digit) {
        if (digit < 1 || digit > 9) {
            throw new IllegalArgumentException("Masking digit must be between 1 and 9 to avoid leading zeroes which is invalid in JSON");
        }
        byte[] replacementBytes = String.valueOf(digit).getBytes(StandardCharsets.UTF_8);
        return describe("every digit as integer: %s".formatted(digit),
                context -> context.replaceBytes(0, context.byteLength(), replacementBytes, context.byteLength()));
    }

    /**
     * Returns a number masker that turns the number into a JSON string in which every byte of the
     * number's text is replaced by {@code value}.
     *
     * @param value text written once per byte of the original number; JSON-encoded without quotes
     * @return a masker described as {@code "every digit as string: <encoded value>"}
     */
    public static ValueMasker.NumberMasker eachDigitWith(String value) {
        String replacement = Utf8Util.jsonEncode(value, false);
        byte[] maskValueBytes = replacement.getBytes(StandardCharsets.UTF_8);
        return describe("every digit as string: %s".formatted(replacement), context -> {
            int totalMaskLength = context.byteLength() * maskValueBytes.length;
            byte[] mask = new byte[totalMaskLength + 2];
            mask[0] = '"';
            // When maskValueBytes is empty, totalMaskLength is 0 and the loop body never runs.
            for (int offset = 1; offset <= totalMaskLength; offset += maskValueBytes.length) {
                System.arraycopy(maskValueBytes, 0, mask, offset, maskValueBytes.length);
            }
            mask[totalMaskLength + 1] = '"';
            context.replaceBytes(0, context.byteLength(), mask, 1);
        });
    }

    /**
     * Returns a masker that leaves the value untouched.
     *
     * @return a masker described as {@code "<no masking>"}
     */
    public static ValueMasker.AnyValueMasker noop() {
        return describe("<no masking>", context -> {});
    }

    /**
     * Returns a string masker for e-mail addresses that keeps a prefix and a suffix and replaces
     * the bytes in between with a single occurrence of {@code mask}.
     *
     * <p>Lengths are measured in bytes of the encoded value. If {@code keepDomain} is set and the
     * value contains {@code '@'}, the first {@code '@'} and everything after it are kept, together
     * with {@code keepSuffixLength} bytes before it. Nothing is masked when the kept parts cover
     * the whole value.
     *
     * @param keepPrefixLength number of leading bytes (after the opening quote) to keep
     * @param keepSuffixLength number of trailing bytes (before the closing quote, or before
     *     {@code '@'} when the domain is kept) to keep
     * @param keepDomain whether to keep the {@code '@'} and the domain
     * @param mask replacement text, inserted once in UTF-8
     * @return a masker described with its configuration
     * @throws IllegalArgumentException if {@code keepPrefixLength} or {@code keepSuffixLength} is
     *     negative, since that would mask the surrounding quotes and produce invalid JSON
     */
    public static ValueMasker.StringMasker email(int keepPrefixLength, int keepSuffixLength, boolean keepDomain, String mask) {
        if (keepPrefixLength < 0 || keepSuffixLength < 0) {
            throw new IllegalArgumentException("keepPrefixLength and keepSuffixLength must not be negative, got %s and %s"
                    .formatted(keepPrefixLength, keepSuffixLength));
        }
        byte[] replacementBytes = mask.getBytes(StandardCharsets.UTF_8);
        return describe("email, keep prefix: %s, keep suffix: %s, keep domain: %s".formatted(keepPrefixLength, keepSuffixLength, keepDomain), context -> {
            int length = context.byteLength();
            int prefixLength = keepPrefixLength + 1; // + opening quote
            int suffixLength = keepSuffixLength + 1; // + closing quote
            if (keepDomain) {
                for (int i = 0; i < length; ++i) {
                    if (context.getByte(i) == '@') {
                        // Keep '@' through the closing quote, plus keepSuffixLength bytes before '@'.
                        suffixLength = length - i + keepSuffixLength;
                        break;
                    }
                }
            }
            int maskLength = length - prefixLength - suffixLength;
            if (maskLength > 0) {
                context.replaceBytes(prefixLength, maskLength, replacementBytes, 1);
            }
        });
    }

    /**
     * Returns a masker that passes the raw JSON text of the value (for strings, including the quotes
     * and any escapes) to {@code masker} and writes the result back verbatim.
     *
     * <p>No encoding is applied to the result, so {@code masker} is responsible for returning valid
     * JSON. A {@code null} result is written as the literal {@code null}.
     *
     * @param masker function from raw JSON text to replacement JSON text
     * @return a masker described with {@code masker}'s {@code toString()}
     */
    public static ValueMasker.AnyValueMasker withRawValueFunction(Function<String, @Nullable String> masker) {
        return describe("withRawValueFunction (%s)".formatted(masker), context -> {
            String maskedValue = masker.apply(context.asString(0, context.byteLength()));
            String replacement = maskedValue == null ? "null" : maskedValue;
            context.replaceBytes(0, context.byteLength(), replacement.getBytes(StandardCharsets.UTF_8), 1);
        });
    }

    /**
     * Returns a masker that passes the decoded text of the value to {@code masker} and replaces the
     * value with the result as a JSON string.
     *
     * <p>For string values the surrounding quotes are stripped and the JSON escapes are decoded:
     * the backslash escapes b, t, n, f, r, quote, slash and backslash, plus hex Unicode escapes,
     * including surrogate pairs. Values that do not start with a quote (numbers, booleans,
     * {@code null}) are passed as their raw text. A {@code null} result is written as the literal
     * {@code null}. Any other result is JSON-encoded, so even a non-string input becomes a string.
     *
     * @param masker function from decoded text to replacement text
     * @return a masker described with {@code masker}'s {@code toString()}
     */
    public static ValueMasker.AnyValueMasker withTextFunction(Function<String, @Nullable String> masker) {
        return describe("withTextFunction (%s)".formatted(masker), context -> {
            String decodedValue;
            if (context.getByte(0) != '"') {
                decodedValue = context.asString(0, context.byteLength());
            } else {
                int encodedIndex = 1; // skip the opening quote
                int valueEndIndex = context.byteLength() - 1; // index of the closing quote
                int decodedIndex = 0;
                // Every escape decodes to fewer bytes than it occupies, so the encoded length is an upper bound.
                byte[] decodedBytes = new byte[context.byteLength()];
                while (encodedIndex < valueEndIndex) {
                    byte originalByte = context.getByte(encodedIndex++);
                    if (originalByte != '\\') {
                        decodedBytes[decodedIndex++] = originalByte;
                        continue;
                    }
                    originalByte = context.getByte(encodedIndex++);
                    switch (originalByte) {
                        case 'b' -> decodedBytes[decodedIndex++] = '\b';
                        case 't' -> decodedBytes[decodedIndex++] = '\t';
                        case 'n' -> decodedBytes[decodedIndex++] = '\n';
                        case 'f' -> decodedBytes[decodedIndex++] = '\f';
                        case 'r' -> decodedBytes[decodedIndex++] = '\r';
                        case '"', '/', '\\' -> decodedBytes[decodedIndex++] = originalByte;
                        case 'u' -> {
                            int valueStartIndex = encodedIndex - 2; // position of the backslash
                            try {
                                char unicodeHexBytesAsChar = Utf8Util.unicodeHexToChar(
                                        context.getByte(encodedIndex++), context.getByte(encodedIndex++),
                                        context.getByte(encodedIndex++), context.getByte(encodedIndex++));
                                int codePoint = unicodeHexBytesAsChar;
                                if (Character.isSurrogate(unicodeHexBytesAsChar)) {
                                    codePoint = -1;
                                    // A high surrogate must be directly followed by a complete 6-byte
                                    // hex escape holding the low surrogate, before the closing quote.
                                    if (Character.isHighSurrogate(unicodeHexBytesAsChar)
                                            && encodedIndex < context.byteLength() - 6
                                            && context.getByte(encodedIndex) == '\\'
                                            && context.getByte(encodedIndex + 1) == 'u') {
                                        encodedIndex += 2;
                                        char lowSurrogate = Utf8Util.unicodeHexToChar(
                                                context.getByte(encodedIndex++), context.getByte(encodedIndex++),
                                                context.getByte(encodedIndex++), context.getByte(encodedIndex++));
                                        if (Character.isLowSurrogate(lowSurrogate)) {
                                            codePoint = Character.toCodePoint(unicodeHexBytesAsChar, lowSurrogate);
                                        }
                                    }
                                    if (codePoint < 0) {
                                        throw context.invalidJson("Invalid surrogate pair '%s'".formatted(context.asString(valueStartIndex, encodedIndex - valueStartIndex)), valueStartIndex);
                                    }
                                }
                                decodedIndex = appendUtf8(codePoint, decodedBytes, decodedIndex);
                            } catch (IllegalArgumentException e) {
                                throw context.invalidJson(Objects.requireNonNull(e.getMessage()), valueStartIndex);
                            }
                        }
                        default -> throw context.invalidJson("Unexpected character after '\\': '%s'".formatted((char) originalByte), encodedIndex);
                    }
                }
                decodedValue = new String(decodedBytes, 0, decodedIndex, StandardCharsets.UTF_8);
            }
            String maskedValue = masker.apply(decodedValue);
            String replacement = maskedValue == null ? "null" : Utf8Util.jsonEncode(maskedValue, true);
            context.replaceBytes(0, context.byteLength(), replacement.getBytes(StandardCharsets.UTF_8), 1);
        });
    }

    /**
     * Writes {@code codePoint} as UTF-8 into {@code target}, starting at {@code offset}.
     *
     * <p>The caller must pass a non-surrogate code point and make sure {@code target} has room for
     * up to 4 bytes.
     *
     * @return the index just past the last byte written
     */
    private static int appendUtf8(int codePoint, byte[] target, int offset) {
        if (codePoint < 0x80) {
            target[offset++] = (byte) codePoint;
        } else if (codePoint < 0x800) {
            target[offset++] = (byte) (0xC0 | codePoint >> 6);
            target[offset++] = (byte) (0x80 | codePoint & 0x3F);
        } else if (codePoint < 0x10000) {
            target[offset++] = (byte) (0xE0 | codePoint >> 12);
            target[offset++] = (byte) (0x80 | codePoint >> 6 & 0x3F);
            target[offset++] = (byte) (0x80 | codePoint & 0x3F);
        } else {
            target[offset++] = (byte) (0xF0 | codePoint >> 18);
            target[offset++] = (byte) (0x80 | codePoint >> 12 & 0x3F);
            target[offset++] = (byte) (0x80 | codePoint >> 6 & 0x3F);
            target[offset++] = (byte) (0x80 | codePoint & 0x3F);
        }
        return offset;
    }
}

