/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup.mapping;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.BitSet;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.colgroup.AMapToDataGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToInt;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToZero;
import org.apache.sysds.utils.MemoryEstimates;

/**
 * Index map that stores one bit per row, so every row maps to dictionary entry 0 or 1.
 *
 * <p>The bits live in a {@link BitSet}. {@code _size} is the logical number of rows; it can be larger than the
 * highest set bit, because {@link BitSet#toLongArray()} is trimmed after the last set bit.
 */
public class MapToBit
extends AMapToData {
    private static final long serialVersionUID = -8065234231282619923L;
    /** One bit per row: bit i set means row i maps to entry 1. */
    private final BitSet _data;
    /** Logical number of rows in this map. */
    private final int _size;

    /**
     * Create an all-zero bit map with two unique values.
     *
     * @param size number of rows
     */
    protected MapToBit(int size) {
        this(2, size);
    }

    /**
     * Create an all-zero bit map.
     *
     * @param unique number of unique values; values above 2 are capped at 2
     * @param size   number of rows
     */
    public MapToBit(int unique, int size) {
        super(Math.min(unique, 2));
        this._data = new BitSet(size);
        this._size = size;
    }

    /**
     * Wrap an existing BitSet without copying it.
     *
     * @param unique number of unique values
     * @param d      bits to wrap (not copied)
     * @param size   number of rows
     * @throws DMLRuntimeException if {@code d} has no set bits; such maps should be a MapToZero instead
     */
    private MapToBit(int unique, BitSet d, int size) {
        super(unique);
        this._data = d;
        this._size = size;
        if (this._data.isEmpty()) {
            throw new DMLRuntimeException("Empty BitSet should not happen it should return MapToZero");
        }
    }

    /** @return the backing BitSet (not a copy) */
    protected BitSet getData() {
        return this._data;
    }

    @Override
    public MapToFactory.MAP_TYPE getType() {
        return MapToFactory.MAP_TYPE.BIT;
    }

    /** @return 1 if bit {@code n} is set, otherwise 0 */
    @Override
    public int getIndex(int n) {
        return this._data.get(n) ? 1 : 0;
    }

    /**
     * Set every row to {@code v}. As in {@link #set(int, int)}, only {@code v == 1} produces set bits; any other
     * value clears all rows.
     */
    @Override
    public void fill(int v) {
        this._data.set(0, this._size, v == 1);
    }

    @Override
    public long getInMemorySize() {
        return MapToBit.getInMemorySize(this._data.size() - 1);
    }

    /**
     * Estimate the in-memory size of a bit map.
     *
     * @param dataLength length forwarded to {@code MemoryEstimates.bitSetCost}
     * @return a fixed 28 bytes (presumably object header and fields) plus the estimated BitSet cost
     */
    public static long getInMemorySize(int dataLength) {
        long size = 28L;
        size = (long)((double)size + MemoryEstimates.bitSetCost(dataLength));
        return size;
    }

    /** Set row {@code n}: the bit is set only when {@code v == 1}. */
    @Override
    public void set(int n, int v) {
        this._data.set(n, v == 1);
    }

    /**
     * Set row {@code n} like {@link #set(int, int)} and return the value actually stored.
     *
     * @return 1 if {@code v == 1}, otherwise 0
     */
    @Override
    public int setAndGet(int n, int v) {
        final boolean bit = v == 1;
        this._data.set(n, bit);
        return bit ? 1 : 0;
    }

    @Override
    public int size() {
        return this._size;
    }

    /**
     * Replace every occurrence of {@code v} by {@code r}. Because a bit map only holds 0 and 1, this assumes
     * {@code r} is the other representable value.
     */
    @Override
    public void replace(int v, int r) {
        if (v == r) {
            return; // replacing a value by itself changes nothing
        }
        if (v == 0) {
            // all zeros become one and existing ones stay one
            this._data.set(0, this.size(), true);
        } else if (v == 1) {
            this._data.clear();
        }
        // any other v is never stored in a bit map, so there is nothing to replace
    }

    /** @return bytes written by {@link #write(DataOutput)}: 1 type byte + 4 size + 4 count + 8 per long */
    @Override
    public long getExactSizeOnDisk() {
        return 9L + 8L * this._data.toLongArray().length;
    }

    /** Serialize as: type byte, int size, int number of longs, then the longs of the BitSet. */
    @Override
    public void write(DataOutput out) throws IOException {
        final long[] internals = this._data.toLongArray();
        out.writeByte(MapToFactory.MAP_TYPE.BIT.ordinal());
        out.writeInt(this._size);
        out.writeInt(internals.length);
        for (long word : internals) {
            out.writeLong(word);
        }
    }

    /**
     * Read the fields written by {@link #write(DataOutput)} after the type byte.
     *
     * @throws DMLRuntimeException if the stored bits are empty
     */
    protected static MapToBit readFields(DataInput in) throws IOException {
        final int size = in.readInt();
        final long[] internalLong = new long[in.readInt()];
        for (int i = 0; i < internalLong.length; ++i) {
            internalLong[i] = in.readLong();
        }
        return new MapToBit(2, BitSet.valueOf(internalLong), size);
    }

    @Override
    public int getUpperBoundValue() {
        return 1;
    }

    /**
     * Count rows per value.
     *
     * @param ret array of at least length 2 that receives the count of zeros at [0] and ones at [1]
     * @return {@code ret}
     */
    @Override
    public int[] getCounts(int[] ret) {
        ret[1] = this._data.cardinality();
        ret[0] = this.size() - ret[1];
        return ret;
    }

    @Override
    public void preAggregateDDC_DDCSingleCol(AMapToData tm, double[] td, double[] v) {
        if (tm instanceof MapToBit) {
            this.preAggregateDDCSingleColBitBit((MapToBit)tm, td, v);
        } else {
            super.preAggregateDDC_DDCSingleCol(tm, td, v);
        }
    }

    /** Bit/bit fast path: v[this row value] += td[other row value], using joint bit counts. */
    private void preAggregateDDCSingleColBitBit(MapToBit tmb, double[] td, double[] v) {
        final JoinBitSets j = new JoinBitSets(tmb._data, this._data, this._size);
        v[1] += td[1] * (double)j.tt;
        v[0] += td[1] * (double)j.ft;
        v[1] += td[0] * (double)j.tf;
        v[0] += td[0] * (double)j.ff;
    }

    @Override
    public void preAggregateDDC_DDCMultiCol(AMapToData tm, ADictionary td, double[] v, int nCol) {
        if (tm instanceof MapToBit) {
            this.preAggregateDDCMultiColBitBit((MapToBit)tm, td, v, nCol);
        } else {
            super.preAggregateDDC_DDCMultiCol(tm, td, v, nCol);
        }
    }

    /**
     * Bit/bit fast path for {@code nCol} columns. Row-major layout: entry k of column c sits at
     * {@code k * nCol + c} in both {@code v} and the dictionary values.
     */
    private void preAggregateDDCMultiColBitBit(MapToBit tmb, ADictionary td, double[] v, int nCol) {
        final JoinBitSets j = new JoinBitSets(tmb._data, this._data, this._size);
        final double[] tv = td.getValues();
        for (int c = 0; c < nCol; ++c) {
            final int off = nCol + c;
            // Same accumulation order as before to keep floating-point results identical.
            v[c] += tv[c] * (double)j.ff;
            v[off] += tv[c] * (double)j.tf;
            v[off] += tv[off] * (double)j.tt;
            v[c] += tv[off] * (double)j.ft;
        }
    }

    /** @return true if no row maps to 1 */
    public boolean isEmpty() {
        return this._data.isEmpty();
    }

    /** Overwrite this map with the content of {@code d}; rows with a non-zero index become 1. */
    @Override
    public void copy(AMapToData d) {
        if (d instanceof MapToBit) {
            this.copyBit((MapToBit)d);
        } else if (d instanceof MapToInt) {
            this.copyInt((MapToInt)d);
        } else {
            this._data.clear(); // drop stale bits so the result is an exact copy, as copyBit does
            final int sz = this.size();
            for (int i = 0; i < sz; ++i) {
                if (d.getIndex(i) != 0) {
                    this._data.set(i);
                }
            }
        }
    }

    /** Overwrite this map from an int array; non-zero entries become 1. */
    @Override
    public void copyInt(int[] d) {
        this._data.clear(); // drop stale bits so the result is an exact copy
        // Iterating from the end lets the BitSet reach its final length on the first set.
        for (int i = d.length - 1; i >= 0; --i) {
            if (d[i] != 0) {
                this._data.set(i);
            }
        }
    }

    /** Overwrite this map with a copy of the given bits. */
    @Override
    public void copyBit(BitSet d) {
        this._data.clear();
        this._data.or(d);
    }

    /** @return a MapToZero of the same size when {@code unique <= 1}, otherwise this map */
    @Override
    public AMapToData resize(int unique) {
        if (unique <= 1) {
            return new MapToZero(this.size());
        }
        return this;
    }

    /**
     * Count maximal runs of equal consecutive values among the first {@code _size} rows.
     *
     * <p>Runs = 1 + the number of positions i in [1, _size) where bit i differs from bit i-1. Words past the end of
     * {@code toLongArray()} are treated as zero, because that array is trimmed after the last set bit.
     */
    @Override
    public int countRuns() {
        final long[] words = this._data.toLongArray();
        final int nWords = (this._size + 63) >>> 6;
        // Seed with bit 0 so position 0 never counts as a transition.
        long carry = words.length > 0 ? words[0] & 1L : 0L;
        int transitions = 0;
        for (int w = 0; w < nWords; ++w) {
            final long cur = w < words.length ? words[w] : 0L;
            // Bit p of diff marks a change between global bits (64w + p - 1) and (64w + p).
            long diff = cur ^ (cur << 1 | carry);
            if (w == nWords - 1) {
                final int valid = this._size - (w << 6); // 1..64 rows in the last word
                if (valid < 64) {
                    diff &= (1L << valid) - 1; // ignore transitions past _size
                }
            }
            transitions += Long.bitCount(diff);
            carry = cur >>> 63;
        }
        return 1 + transitions;
    }

    /**
     * @return rows [l, u) as a new map; a MapToZero when no bit in the range is set
     */
    @Override
    public AMapToData slice(int l, int u) {
        final BitSet s = this._data.get(l, u);
        if (s.isEmpty()) {
            return new MapToZero(u - l);
        }
        return new MapToBit(this.getUnique(), s, u - l);
    }

    /**
     * Concatenate {@code t} after this map.
     *
     * @throws NotImplementedException if {@code t} is not a MapToBit
     */
    @Override
    public AMapToData append(AMapToData t) {
        if (!(t instanceof MapToBit)) {
            throw new NotImplementedException("Not implemented append on Bit map different type");
        }
        final BitSet other = ((MapToBit)t)._data;
        final int newSize = this._size + t.size();
        final BitSet ret = new BitSet(newSize);
        ret.or(this._data);
        for (int i = other.nextSetBit(0); i >= 0; i = other.nextSetBit(i + 1)) {
            ret.set(i + this._size);
        }
        return new MapToBit(2, ret, newSize);
    }

    /**
     * Concatenate the maps of all groups. Assumes {@code d[0]} holds this map and every other group holds a
     * MapToBit; other map types fail with a ClassCastException.
     */
    @Override
    public AMapToData appendN(AMapToDataGroup[] d) {
        int p = 0;
        for (AMapToDataGroup gd : d) {
            p += gd.getMapToData().size();
        }
        final long[] ret = new long[(p - 1) / 64 + 1];
        long[] or = this._data.toLongArray();
        System.arraycopy(or, 0, ret, 0, or.length);
        p = this.size();
        for (int i = 1; i < d.length; ++i) {
            final MapToBit mm = (MapToBit)d[i].getMapToData();
            final int ms = mm.size();
            or = mm._data.toLongArray();
            if (or.length > 0) { // an all-zero map contributes no words
                final int remainder = p % 64;
                int retLp = p / 64;
                if (remainder == 0) {
                    System.arraycopy(or, 0, ret, retLp, or.length);
                } else {
                    // Shift each word into place; its high bits spill into the next output word.
                    for (int j = 0; j < or.length - 1; ++j) {
                        final long v = or[j];
                        ret[retLp] ^= v << remainder;
                        ret[++retLp] = v >>> (64 - remainder);
                    }
                    final long v = or[or.length - 1];
                    ret[retLp] ^= v << remainder;
                    if (++retLp < ret.length) {
                        ret[retLp] = v >>> (64 - remainder);
                    }
                }
            }
            p += ms;
        }
        return new MapToBit(this.getUnique(), BitSet.valueOf(ret), p);
    }

    /**
     * Joint counts of two bit vectors over {@code size} positions: {@code tt} = both set, {@code ft} = only the
     * first set, {@code tf} = only the second set, {@code ff} = neither set.
     */
    private static class JoinBitSets {
        int tt = 0;
        int ft = 0;
        int tf = 0;
        int ff = 0;

        protected JoinBitSets(BitSet t_data, BitSet o_data, int size) {
            final long[] tLongs = t_data.toLongArray();
            final long[] oLongs = o_data.toLongArray();
            final int common = Math.min(tLongs.length, oLongs.length);
            for (int i = 0; i < common; ++i) {
                final long t = tLongs[i];
                final long o = oLongs[i];
                this.tt += Long.bitCount(t & o);
                this.ft += Long.bitCount(t & ~o);
                this.tf += Long.bitCount(~t & o);
                this.ff += Long.bitCount(~t & ~o);
            }
            // Past the shorter array the other vector is all zeros (at most one of these loops runs).
            for (int i = common; i < tLongs.length; ++i) {
                final int c = Long.bitCount(tLongs[i]);
                this.ft += c;
                this.ff += 64 - c;
            }
            for (int i = common; i < oLongs.length; ++i) {
                final int c = Long.bitCount(oLongs[i]);
                this.tf += c;
                this.ff += 64 - c;
            }
            // Each processed word counted 64 positions. Adjust ff so exactly `size` positions are covered:
            // add zero positions past the longest array, or remove padding counted beyond size.
            final int longest = Math.max(tLongs.length, oLongs.length);
            this.ff += size - longest * 64;
        }
    }
}

