/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.internal.vectorization;

import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.SuppressForbidden;

/**
 * Plain scalar implementation of {@link VectorUtilSupport}: every operation is a straightforward Java
 * loop over the input arrays, with no explicit SIMD code.
 *
 * <p>The methods do not check that paired arrays have compatible lengths. NOTE(review): length
 * validation presumably happens in the caller that dispatches to this class — confirm.
 */
final class DefaultVectorUtilSupport implements VectorUtilSupport {

    DefaultVectorUtilSupport() {
    }

    /**
     * Returns {@code a * b + c}. A fused multiply-add is used only when
     * {@link Constants#HAS_FAST_SCALAR_FMA} is set; otherwise the result is a plain multiply followed
     * by an add (rounded twice instead of once).
     */
    @SuppressForbidden(reason = "Uses FMA only where fast and carefully contained")
    private static float fma(float a, float b, float c) {
        if (Constants.HAS_FAST_SCALAR_FMA) {
            return Math.fma(a, b, c);
        }
        return a * b + c;
    }

    /**
     * Dot product of two float vectors.
     *
     * <p>Vectors longer than 32 elements are processed four at a time into independent accumulators;
     * the remaining tail (or the whole vector, when short) is accumulated one element at a time.
     */
    @Override
    public float dotProduct(float[] a, float[] b) {
        float res = 0.0f;
        // Initialized up front: the unrolled loop only runs for long vectors, but the tail loop
        // below always reads i, so it must be definitely assigned on every path.
        int i = 0;
        if (a.length > 32) {
            // Four separate accumulators so consecutive FMAs do not depend on each other's result.
            float acc1 = 0.0f;
            float acc2 = 0.0f;
            float acc3 = 0.0f;
            float acc4 = 0.0f;
            // Largest multiple of 4 not exceeding a.length.
            int upperBound = a.length & ~(4 - 1);
            for (; i < upperBound; i += 4) {
                acc1 = fma(a[i], b[i], acc1);
                acc2 = fma(a[i + 1], b[i + 1], acc2);
                acc3 = fma(a[i + 2], b[i + 2], acc3);
                acc4 = fma(a[i + 3], b[i + 3], acc4);
            }
            res += acc1 + acc2 + acc3 + acc4;
        }
        for (; i < a.length; i++) {
            res = fma(a[i], b[i], res);
        }
        return res;
    }

    /**
     * Cosine similarity of two float vectors: {@code dot(a, b) / sqrt(|a|^2 * |b|^2)}, with the final
     * division and square root done in double precision.
     *
     * <p>Vectors longer than 32 elements are processed two at a time into separate accumulators for
     * the dot product and both squared norms.
     */
    @Override
    public float cosine(float[] a, float[] b) {
        float sum = 0.0f;
        float norm1 = 0.0f;
        float norm2 = 0.0f;
        // Must be definitely assigned for the tail loop even when the unrolled path is skipped.
        int i = 0;
        if (a.length > 32) {
            float sum1 = 0.0f;
            float sum2 = 0.0f;
            float norm1_1 = 0.0f;
            float norm1_2 = 0.0f;
            float norm2_1 = 0.0f;
            float norm2_2 = 0.0f;
            // Largest even number not exceeding a.length.
            int upperBound = a.length & ~(2 - 1);
            for (; i < upperBound; i += 2) {
                sum1 = fma(a[i], b[i], sum1);
                norm1_1 = fma(a[i], a[i], norm1_1);
                norm2_1 = fma(b[i], b[i], norm2_1);
                sum2 = fma(a[i + 1], b[i + 1], sum2);
                norm1_2 = fma(a[i + 1], a[i + 1], norm1_2);
                norm2_2 = fma(b[i + 1], b[i + 1], norm2_2);
            }
            sum += sum1 + sum2;
            norm1 += norm1_1 + norm1_2;
            norm2 += norm2_1 + norm2_2;
        }
        for (; i < a.length; i++) {
            sum = fma(a[i], b[i], sum);
            norm1 = fma(a[i], a[i], norm1);
            norm2 = fma(b[i], b[i], norm2);
        }
        // Multiply the norms in double to limit overflow/precision loss before the square root.
        return (float) ((double) sum / Math.sqrt((double) norm1 * (double) norm2));
    }

    /**
     * Squared Euclidean distance between two float vectors, i.e. {@code sum((a[i] - b[i])^2)}.
     *
     * <p>Vectors longer than 32 elements are processed four at a time into independent accumulators.
     */
    @Override
    public float squareDistance(float[] a, float[] b) {
        float res = 0.0f;
        // Must be definitely assigned for the tail loop even when the unrolled path is skipped.
        int i = 0;
        if (a.length > 32) {
            float acc1 = 0.0f;
            float acc2 = 0.0f;
            float acc3 = 0.0f;
            float acc4 = 0.0f;
            // Largest multiple of 4 not exceeding a.length.
            int upperBound = a.length & ~(4 - 1);
            for (; i < upperBound; i += 4) {
                float diff1 = a[i] - b[i];
                acc1 = fma(diff1, diff1, acc1);
                float diff2 = a[i + 1] - b[i + 1];
                acc2 = fma(diff2, diff2, acc2);
                float diff3 = a[i + 2] - b[i + 2];
                acc3 = fma(diff3, diff3, acc3);
                float diff4 = a[i + 3] - b[i + 3];
                acc4 = fma(diff4, diff4, acc4);
            }
            res += acc1 + acc2 + acc3 + acc4;
        }
        for (; i < a.length; i++) {
            float diff = a[i] - b[i];
            res = fma(diff, diff, res);
        }
        return res;
    }

    /** Dot product of two signed byte vectors, accumulated in an {@code int}. */
    @Override
    public int dotProduct(byte[] a, byte[] b) {
        int total = 0;
        for (int i = 0; i < a.length; i++) {
            total += a[i] * b[i];
        }
        return total;
    }

    /**
     * Dot product of int4 vectors, where at most one operand may be packed two nibbles per byte.
     *
     * <p>For a packed byte at index {@code i}, its low nibble is multiplied by
     * {@code unpacked[i + packed.length]} and its high nibble by {@code unpacked[i]}. When neither
     * operand is packed this is the plain byte dot product.
     */
    @Override
    public int int4DotProduct(byte[] a, boolean apacked, byte[] b, boolean bpacked) {
        assert !(apacked && bpacked) : "at most one operand may be packed";
        if (apacked || bpacked) {
            byte[] packed = apacked ? a : b;
            byte[] unpacked = apacked ? b : a;
            int total = 0;
            for (int i = 0; i < packed.length; i++) {
                byte packedByte = packed[i];
                byte unpacked1 = unpacked[i];
                byte unpacked2 = unpacked[i + packed.length];
                total += (packedByte & 0x0F) * unpacked2;
                // Mask to 0xFF first so the shift is unsigned (no sign extension of the high nibble).
                total += ((packedByte & 0xFF) >> 4) * unpacked1;
            }
            return total;
        }
        return dotProduct(a, b);
    }

    /**
     * Cosine similarity of two signed byte vectors. The dot product and squared norms are accumulated
     * in {@code int}; the final division and square root are done in double precision.
     */
    @Override
    public float cosine(byte[] a, byte[] b) {
        int sum = 0;
        int norm1 = 0;
        int norm2 = 0;
        for (int i = 0; i < a.length; i++) {
            byte elem1 = a[i];
            byte elem2 = b[i];
            sum += elem1 * elem2;
            norm1 += elem1 * elem1;
            norm2 += elem2 * elem2;
        }
        return (float) ((double) sum / Math.sqrt((double) norm1 * (double) norm2));
    }

    /** Squared Euclidean distance between two signed byte vectors, accumulated in an {@code int}. */
    @Override
    public int squareDistance(byte[] a, byte[] b) {
        int squareSum = 0;
        for (int i = 0; i < a.length; i++) {
            int diff = a[i] - b[i];
            squareSum += diff * diff;
        }
        return squareSum;
    }

    /**
     * Linear scan of {@code buffer[from, to)} for the first index whose value is {@code >= target}.
     *
     * @return that index, or {@code to} if no such element exists
     */
    @Override
    public int findNextGEQ(int[] buffer, int target, int from, int to) {
        for (int i = from; i < to; i++) {
            if (buffer[i] >= target) {
                return i;
            }
        }
        return to;
    }

    @Override
    public long int4BitDotProduct(byte[] int4Quantized, byte[] binaryQuantized) {
        return int4BitDotProductImpl(int4Quantized, binaryQuantized);
    }

    /**
     * Dot product between a 4-bit query and a 1-bit document vector.
     *
     * <p>{@code q} is read as four consecutive segments of {@code d.length} bytes. For segment
     * {@code i}, the popcount of {@code segment_i AND d} is weighted by {@code 2^i} and summed.
     *
     * @param q query bytes; must be exactly {@code 4 * d.length} long
     * @param d document bit vector
     */
    public static long int4BitDotProductImpl(byte[] q, byte[] d) {
        assert q.length == d.length * 4;
        long ret = 0L;
        int size = d.length;
        // Largest multiple of Integer.BYTES not exceeding size: the prefix processed an int at a time.
        int upperBound = size & -Integer.BYTES;
        for (int i = 0; i < 4; i++) {
            int offset = i * size;
            long subRet = 0L;
            int r = 0;
            for (; r < upperBound; r += Integer.BYTES) {
                // VarHandle.get is signature-polymorphic: the (int) casts fix the call-site return
                // type. Byte order is irrelevant to the popcount since both reads use the same handle.
                subRet += Integer.bitCount(
                        (int) BitUtil.VH_NATIVE_INT.get(q, offset + r)
                                & (int) BitUtil.VH_NATIVE_INT.get(d, r));
            }
            for (; r < size; r++) {
                // Mask to the low 8 bits so sign extension of negative bytes does not inflate the count.
                subRet += Integer.bitCount((q[offset + r] & d[r]) & 0xFF);
            }
            ret += subRet << i;
        }
        return ret;
    }

    @Override
    public float minMaxScalarQuantize(float[] vector, byte[] dest, float scale, float alpha2, float minQuantile, float maxQuantile) {
        return new ScalarQuantizer(alpha2, scale, minQuantile, maxQuantile).quantize(vector, dest, 0);
    }

    @Override
    public float recalculateScalarQuantizationOffset(byte[] vector, float oldAlpha, float oldMinQuantile, float scale, float alpha2, float minQuantile, float maxQuantile) {
        return new ScalarQuantizer(alpha2, scale, minQuantile, maxQuantile).recalculateOffset(vector, 0, oldAlpha, oldMinQuantile);
    }

    /**
     * Compacts {@code docBuffer[0, upTo)} and {@code scoreBuffer[0, upTo)} in place, keeping only
     * entries whose score is {@code >= minScoreInclusive} (so NaN scores are dropped), and preserving
     * their relative order.
     *
     * @return the number of entries kept, which now occupy indices {@code [0, returnValue)}
     */
    @Override
    public int filterByScore(int[] docBuffer, double[] scoreBuffer, double minScoreInclusive, int upTo) {
        int newSize = 0;
        for (int i = 0; i < upTo; i++) {
            int doc = docBuffer[i];
            double score = scoreBuffer[i];
            // Always write, then advance the output cursor only for kept entries; a rejected entry
            // is simply overwritten by the next one.
            docBuffer[newSize] = doc;
            scoreBuffer[newSize] = score;
            if (score >= minScoreInclusive) {
                newSize++;
            }
        }
        return newSize;
    }

    /**
     * Min/max scalar quantizer. A float is clamped to {@code [minQuantile, maxQuantile]}, shifted by
     * {@code minQuantile}, multiplied by {@code scale} and rounded to produce the stored byte. Each
     * quantized value also yields a float correction term, and these are summed per vector.
     */
    static class ScalarQuantizer {
        private final float alpha;
        private final float scale;
        private final float minQuantile;
        private final float maxQuantile;

        ScalarQuantizer(float alpha, float scale, float minQuantile, float maxQuantile) {
            this.alpha = alpha;
            this.scale = scale;
            this.minQuantile = minQuantile;
            this.maxQuantile = maxQuantile;
        }

        /**
         * Quantizes {@code vector[start..]} into {@code dest} at the same indices.
         *
         * @return the sum of the per-element correction terms
         */
        float quantize(float[] vector, byte[] dest, int start) {
            assert vector.length == dest.length;
            float correction = 0.0f;
            for (int i = start; i < vector.length; i++) {
                correction += quantizeFloat(vector[i], dest, i);
            }
            return correction;
        }

        /**
         * Recomputes the summed correction term for an already-quantized vector under this
         * quantizer's parameters, without rewriting any bytes.
         */
        float recalculateOffset(byte[] vector, int start, float oldAlpha, float oldMinQuantile) {
            float correction = 0.0f;
            for (int i = start; i < vector.length; i++) {
                // Reconstruct an approximate float from the old quantization parameters.
                float v = oldAlpha * (float) vector[i] + oldMinQuantile;
                correction += quantizeFloat(v, null, 0);
            }
            return correction;
        }

        /**
         * Quantizes a single value, optionally storing the byte in {@code dest[destIndex]} (skipped
         * when {@code dest} is null), and returns its correction term.
         */
        private float quantizeFloat(float v, byte[] dest, int destIndex) {
            assert dest == null || destIndex < dest.length;
            // Unclamped offset from the lower quantile.
            float dx = v - minQuantile;
            // Offset after clamping v into [minQuantile, maxQuantile].
            float dxc = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
            int roundedDxs = Math.round(scale * dxc);
            // Quantized offset mapped back to the float domain via alpha.
            float dxq = (float) roundedDxs * alpha;
            if (dest != null) {
                dest[destIndex] = (byte) roundedDxs;
            }
            // Correction term: a minQuantile-dependent part plus (dx - dxq) * dxq for the rounding
            // error. NOTE(review): presumably consumed by score correction elsewhere — confirm.
            return minQuantile * (v - minQuantile / 2.0f) + (dx - dxq) * dxq;
        }
    }
}

