/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.udf.generic;

import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;

/**
 * Resolver for the {@code covariance} / {@code covar_pop} aggregate, which returns the
 * population covariance of a set of (x, y) pairs.
 *
 * <p>Both arguments must be primitive types of category BYTE, SHORT, INT, LONG, FLOAT,
 * DOUBLE or TIMESTAMP; any other argument list is rejected with a
 * {@link UDFArgumentTypeException}.
 */
@Description(name="covariance,covar_pop", value="_FUNC_(x,y) - Returns the population covariance of a set of number pairs", extended="The function takes as arguments any pair of numeric types and returns a double.\nAny pair with a NULL is ignored. If the function is applied to an empty set, NULL\nwill be returned. Otherwise, it computes the following:\n   (SUM(x*y)-SUM(x)*SUM(y)/COUNT(x,y))/COUNT(x,y)\nwhere neither x nor y is null.")
public class GenericUDAFCovariance
extends AbstractGenericUDAFResolver {
    static final Log LOG = LogFactory.getLog(GenericUDAFCovariance.class.getName());

    /**
     * Validates the argument types and returns the covariance evaluator.
     *
     * @param parameters type information for the call's arguments; exactly two are expected
     * @return a new {@link GenericUDAFCovarianceEvaluator}
     * @throws UDFArgumentTypeException if there are not exactly two arguments, or if either
     *     argument is non-primitive or of an unsupported primitive category
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        if (parameters.length != 2) {
            throw new UDFArgumentTypeException(parameters.length - 1, "Exactly two arguments are expected.");
        }
        if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0, "Only primitive type arguments are accepted but " + parameters[0].getTypeName() + " is passed.");
        }
        if (parameters[1].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(1, "Only primitive type arguments are accepted but " + parameters[1].getTypeName() + " is passed.");
        }
        // NOTE(review): these messages mention "string" although string arguments are
        // rejected; the text is kept unchanged in case existing tests match it.
        if (!isSupportedType(parameters[0])) {
            throw new UDFArgumentTypeException(0, "Only numeric or string type arguments are accepted but " + parameters[0].getTypeName() + " is passed.");
        }
        if (!isSupportedType(parameters[1])) {
            throw new UDFArgumentTypeException(1, "Only numeric or string type arguments are accepted but " + parameters[1].getTypeName() + " is passed.");
        }
        return new GenericUDAFCovarianceEvaluator();
    }

    /**
     * Returns whether a primitive argument type is accepted by this aggregate.
     *
     * @param type type information already known to be of the PRIMITIVE category
     * @return true for BYTE, SHORT, INT, LONG, FLOAT, DOUBLE and TIMESTAMP
     */
    private static boolean isSupportedType(TypeInfo type) {
        switch (((PrimitiveTypeInfo)type).getPrimitiveCategory()) {
            case BYTE:
            case SHORT:
            case INT:
            case LONG:
            case FLOAT:
            case DOUBLE:
            case TIMESTAMP:
                return true;
            default:
                return false;
        }
    }

    /**
     * Evaluator that computes the population covariance of (x, y) pairs in a single pass.
     *
     * <p>The aggregation state holds the number of non-null pairs, the running means of x
     * and y, and the running co-moment {@code sum((x - xavg) * (y - yavg))}. Partial results
     * are exchanged as a struct with fields {@code count} (long), {@code xavg}, {@code yavg}
     * and {@code covar} (double). {@link #terminate} divides the co-moment by the count.
     */
    public static class GenericUDAFCovarianceEvaluator
    extends GenericUDAFEvaluator {
        // Raw-argument inspectors; set in PARTIAL1 and COMPLETE modes.
        private PrimitiveObjectInspector xInputOI;
        private PrimitiveObjectInspector yInputOI;
        // Partial-aggregation struct inspector and its fields; set in all other modes.
        private StructObjectInspector soi;
        private StructField countField;
        private StructField xavgField;
        private StructField yavgField;
        private StructField covarField;
        private LongObjectInspector countFieldOI;
        private DoubleObjectInspector xavgFieldOI;
        private DoubleObjectInspector yavgFieldOI;
        private DoubleObjectInspector covarFieldOI;
        // Reusable output holders, allocated in init() and overwritten on each call.
        private Object[] partialResult;
        private DoubleWritable result;

        /**
         * Configures the evaluator for the given mode.
         *
         * @param m the aggregation mode
         * @param parameters two primitive inspectors (x, y) in PARTIAL1/COMPLETE mode,
         *     otherwise a single struct inspector for the partial state
         * @return a struct inspector (count, xavg, yavg, covar) in PARTIAL1/PARTIAL2 mode,
         *     otherwise a writable double inspector for the final covariance
         * @throws HiveException if the superclass initialisation fails
         */
        @Override
        public ObjectInspector init(GenericUDAFEvaluator.Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);
            if (this.mode == GenericUDAFEvaluator.Mode.PARTIAL1 || this.mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                assert (parameters.length == 2);
                this.xInputOI = (PrimitiveObjectInspector)parameters[0];
                this.yInputOI = (PrimitiveObjectInspector)parameters[1];
            } else {
                assert (parameters.length == 1);
                this.soi = (StructObjectInspector)parameters[0];
                this.countField = this.soi.getStructFieldRef("count");
                this.xavgField = this.soi.getStructFieldRef("xavg");
                this.yavgField = this.soi.getStructFieldRef("yavg");
                this.covarField = this.soi.getStructFieldRef("covar");
                this.countFieldOI = (LongObjectInspector)this.countField.getFieldObjectInspector();
                this.xavgFieldOI = (DoubleObjectInspector)this.xavgField.getFieldObjectInspector();
                this.yavgFieldOI = (DoubleObjectInspector)this.yavgField.getFieldObjectInspector();
                this.covarFieldOI = (DoubleObjectInspector)this.covarField.getFieldObjectInspector();
            }
            if (this.mode == GenericUDAFEvaluator.Mode.PARTIAL1 || this.mode == GenericUDAFEvaluator.Mode.PARTIAL2) {
                // Output is the partial state; field order must match partialResult below.
                ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
                foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
                foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                ArrayList<String> fname = new ArrayList<String>();
                fname.add("count");
                fname.add("xavg");
                fname.add("yavg");
                fname.add("covar");
                this.partialResult = new Object[4];
                this.partialResult[0] = new LongWritable(0L);
                this.partialResult[1] = new DoubleWritable(0.0);
                this.partialResult[2] = new DoubleWritable(0.0);
                this.partialResult[3] = new DoubleWritable(0.0);
                return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
            }
            this.setResult(new DoubleWritable(0.0));
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }

        /**
         * Allocates a fresh, zeroed aggregation buffer.
         *
         * @return a new {@link StdAgg} with all fields reset
         * @throws HiveException never thrown by this implementation
         */
        @Override
        public GenericUDAFEvaluator.AggregationBuffer getNewAggregationBuffer() throws HiveException {
            StdAgg result = new StdAgg();
            this.reset(result);
            return result;
        }

        /**
         * Resets the count, both means and the co-moment to zero.
         *
         * @param agg a buffer previously returned by {@link #getNewAggregationBuffer()}
         * @throws HiveException never thrown by this implementation
         */
        @Override
        public void reset(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            StdAgg myagg = (StdAgg)agg;
            myagg.count = 0L;
            myagg.xavg = 0.0;
            myagg.yavg = 0.0;
            myagg.covar = 0.0;
        }

        /**
         * Adds one (x, y) pair to the running state; pairs with a null x or y are skipped.
         *
         * @param agg the aggregation buffer to update
         * @param parameters the raw x and y values
         * @throws HiveException if a value cannot be read as a double
         */
        @Override
        public void iterate(GenericUDAFEvaluator.AggregationBuffer agg, Object[] parameters) throws HiveException {
            assert (parameters.length == 2);
            Object px = parameters[0];
            Object py = parameters[1];
            if (px != null && py != null) {
                StdAgg myagg = (StdAgg)agg;
                double vx = PrimitiveObjectInspectorUtils.getDouble(px, this.xInputOI);
                double vy = PrimitiveObjectInspectorUtils.getDouble(py, this.yInputOI);
                ++myagg.count;
                // yavg is updated before, and xavg after, the co-moment increment, so the
                // increment is (x - previous xavg) * (y - updated yavg): the one-pass
                // co-moment update. It is skipped for the first pair.
                myagg.yavg += (vy - myagg.yavg) / (double)myagg.count;
                if (myagg.count > 1L) {
                    myagg.covar += (vx - myagg.xavg) * (vy - myagg.yavg);
                }
                myagg.xavg += (vx - myagg.xavg) / (double)myagg.count;
            }
        }

        /**
         * Copies the current state into the reusable partial-result array.
         *
         * @param agg the aggregation buffer to serialise
         * @return the array {count, xavg, yavg, covar} as writables
         * @throws HiveException never thrown by this implementation
         */
        @Override
        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            StdAgg myagg = (StdAgg)agg;
            ((LongWritable)this.partialResult[0]).set(myagg.count);
            ((DoubleWritable)this.partialResult[1]).set(myagg.xavg);
            ((DoubleWritable)this.partialResult[2]).set(myagg.yavg);
            ((DoubleWritable)this.partialResult[3]).set(myagg.covar);
            return this.partialResult;
        }

        /**
         * Folds a partial state into {@code agg}.
         *
         * <p>Uses the pairwise combination
         * {@code C = C_A + C_B + (xavgA - xavgB) * (yavgA - yavgB) * nA * nB / (nA + nB)},
         * with means combined as count-weighted averages.
         *
         * @param agg the aggregation buffer to update
         * @param partial a partial state produced by {@link #terminatePartial}; null is ignored
         * @throws HiveException never thrown by this implementation
         */
        @Override
        public void merge(GenericUDAFEvaluator.AggregationBuffer agg, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            StdAgg myagg = (StdAgg)agg;
            Object partialCount = this.soi.getStructFieldData(partial, this.countField);
            Object partialXAvg = this.soi.getStructFieldData(partial, this.xavgField);
            Object partialYAvg = this.soi.getStructFieldData(partial, this.yavgField);
            Object partialCovar = this.soi.getStructFieldData(partial, this.covarField);
            long nA = myagg.count;
            long nB = this.countFieldOI.get(partialCount);
            if (nA == 0L) {
                // Nothing accumulated yet: adopt the partial state as-is.
                myagg.count = nB;
                myagg.xavg = this.xavgFieldOI.get(partialXAvg);
                myagg.yavg = this.yavgFieldOI.get(partialYAvg);
                myagg.covar = this.covarFieldOI.get(partialCovar);
            } else if (nB != 0L) {
                double xavgA = myagg.xavg;
                double yavgA = myagg.yavg;
                double xavgB = this.xavgFieldOI.get(partialXAvg);
                double yavgB = this.yavgFieldOI.get(partialYAvg);
                double covarB = this.covarFieldOI.get(partialCovar);
                myagg.count += nB;
                myagg.xavg = (xavgA * (double)nA + xavgB * (double)nB) / (double)myagg.count;
                myagg.yavg = (yavgA * (double)nA + yavgB * (double)nB) / (double)myagg.count;
                // Multiply the counts as doubles: the long product nA * nB silently overflows
                // once both counts reach ~3e9, which would corrupt the correction term.
                myagg.covar += covarB + (xavgA - xavgB) * (yavgA - yavgB) * ((double)nA * (double)nB / (double)myagg.count);
            }
        }

        /**
         * Produces the population covariance: the co-moment divided by the pair count.
         *
         * @param agg the final aggregation buffer
         * @return the reusable result writable, or null if no non-null pair was seen
         * @throws HiveException never thrown by this implementation
         */
        @Override
        public Object terminate(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            StdAgg myagg = (StdAgg)agg;
            if (myagg.count == 0L) {
                return null;
            }
            this.getResult().set(myagg.covar / (double)myagg.count);
            return this.getResult();
        }

        /**
         * Replaces the writable returned by {@link #terminate}.
         *
         * @param result the holder to reuse for final results
         */
        public void setResult(DoubleWritable result) {
            this.result = result;
        }

        /**
         * Returns the writable reused for final results.
         *
         * @return the current result holder
         */
        public DoubleWritable getResult() {
            return this.result;
        }

        /** Running covariance state for one group. */
        static class StdAgg
        implements GenericUDAFEvaluator.AggregationBuffer {
            // Number of pairs with neither x nor y null.
            long count;
            // Running mean of x.
            double xavg;
            // Running mean of y.
            double yavg;
            // Running co-moment sum((x - xavg) * (y - yavg)); not yet divided by count.
            double covar;

            StdAgg() {
            }
        }
    }
}

