/*
 * Decompiled with CFR 0.152.
 */
package org.apache.drill.exec.expr;

import com.google.common.base.Optional;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.drill.common.exceptions.DrillRuntimeException;
import org.apache.drill.common.expression.BooleanOperator;
import org.apache.drill.common.expression.CastExpression;
import org.apache.drill.common.expression.ConvertExpression;
import org.apache.drill.common.expression.ErrorCollector;
import org.apache.drill.common.expression.ExpressionPosition;
import org.apache.drill.common.expression.FunctionCall;
import org.apache.drill.common.expression.FunctionHolderExpression;
import org.apache.drill.common.expression.IfExpression;
import org.apache.drill.common.expression.LogicalExpression;
import org.apache.drill.common.expression.NullExpression;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.common.expression.TypedNullConstant;
import org.apache.drill.common.expression.ValueExpressions;
import org.apache.drill.common.expression.fn.CastFunctions;
import org.apache.drill.common.expression.visitors.AbstractExprVisitor;
import org.apache.drill.common.expression.visitors.ConditionalExprOptimizer;
import org.apache.drill.common.expression.visitors.ExpressionValidator;
import org.apache.drill.common.types.TypeProtos;
import org.apache.drill.common.types.Types;
import org.apache.drill.exec.expr.ValueVectorReadExpression;
import org.apache.drill.exec.expr.annotations.FunctionTemplate;
import org.apache.drill.exec.expr.fn.AbstractFuncHolder;
import org.apache.drill.exec.expr.fn.DrillComplexWriterFuncHolder;
import org.apache.drill.exec.expr.fn.DrillFuncHolder;
import org.apache.drill.exec.expr.fn.FunctionImplementationRegistry;
import org.apache.drill.exec.record.TypedFieldId;
import org.apache.drill.exec.record.VectorAccessible;
import org.apache.drill.exec.resolver.FunctionResolver;
import org.apache.drill.exec.resolver.FunctionResolverFactory;
import org.apache.drill.exec.resolver.TypeCastRules;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Binds a logical expression tree to a concrete record batch and function registry.
 *
 * <p>Materialization does the following:
 * <ul>
 *   <li>resolves schema paths to value-vector reads;</li>
 *   <li>resolves function calls to registered implementations, inserting implicit casts where
 *       argument types do not match;</li>
 *   <li>rewrites CAST and CONVERT expressions into function calls;</li>
 *   <li>reconciles the output types of IF branches.</li>
 * </ul>
 *
 * <p>Most problems are reported through the supplied {@link ErrorCollector}. The exception is an
 * IF expression whose branch types cannot be reconciled, which raises a
 * {@link DrillRuntimeException}.
 */
public class ExpressionTreeMaterializer {
    static final Logger logger = LoggerFactory.getLogger(ExpressionTreeMaterializer.class);

    /** Length argument passed to implicit casts whose target type is not fixed-width. */
    private static final long IMPLICIT_VARLEN_CAST_WIDTH = 65536L;

    private ExpressionTreeMaterializer() {
    }

    /**
     * Materializes {@code expr} against {@code batch}. Complex writer functions are reported as errors.
     *
     * @param expr the unresolved expression tree
     * @param batch the batch whose vectors schema paths are resolved against
     * @param errorCollector receives materialization errors
     * @param registry the function registry used for resolution
     * @return the materialized expression
     */
    public static LogicalExpression materialize(LogicalExpression expr, VectorAccessible batch, ErrorCollector errorCollector, FunctionImplementationRegistry registry) {
        return ExpressionTreeMaterializer.materialize(expr, batch, errorCollector, registry, false);
    }

    /**
     * Materializes {@code expr} against {@code batch}.
     *
     * @param expr the unresolved expression tree
     * @param batch the batch whose vectors schema paths are resolved against
     * @param errorCollector receives materialization errors
     * @param registry the function registry used for resolution
     * @param allowComplexWriterExpr whether complex writer functions are permitted; when false,
     *     their use is reported to {@code errorCollector}
     * @return the materialized expression. A bare {@link NullExpression} result is replaced by a
     *     nullable INT {@link TypedNullConstant}.
     */
    public static LogicalExpression materialize(LogicalExpression expr, VectorAccessible batch, ErrorCollector errorCollector, FunctionImplementationRegistry registry, boolean allowComplexWriterExpr) {
        LogicalExpression out = expr.accept(new MaterializeVisitor(batch, errorCollector, allowComplexWriterExpr), registry);
        // Only run the conditional optimizer over a tree that materialized without errors.
        if (!errorCollector.hasErrors()) {
            out = out.accept(ConditionalExprOptimizer.INSTANCE, null);
        }
        if (out instanceof NullExpression) {
            return new TypedNullConstant(Types.optional(TypeProtos.MinorType.INT));
        }
        return out;
    }

    /**
     * Wraps {@code fromExpr} in a call to the registered cast function for {@code toType}.
     *
     * @param fromExpr the (already materialized) expression to cast
     * @param toType the target type; its minor type selects the cast function
     * @param registry the function registry used to resolve the cast
     * @param errorCollector receives an error if no cast implementation is found
     * @return the resolved cast expression, or {@link NullExpression#INSTANCE} if none was found
     */
    public static LogicalExpression addCastExpression(LogicalExpression fromExpr, TypeProtos.MajorType toType, FunctionImplementationRegistry registry, ErrorCollector errorCollector) {
        String castFuncName = CastFunctions.getCastFunc(toType.getMinorType());
        List<LogicalExpression> castArgs = Lists.newArrayList();
        castArgs.add(fromExpr);
        if (!Types.isFixedWidthType(toType)) {
            // Non-fixed-width targets get a fixed length argument for implicit casts.
            castArgs.add(new ValueExpressions.LongExpression(IMPLICIT_VARLEN_CAST_WIDTH, null));
        } else if (toType.getMinorType().name().startsWith("DECIMAL")) {
            // NOTE(review): precision/scale come from the *source* expression, not from toType.
            // Confirm this is intended when the source is not itself a decimal.
            TypeProtos.MajorType fromType = fromExpr.getMajorType();
            castArgs.add(new ValueExpressions.LongExpression(fromType.getPrecision(), null));
            castArgs.add(new ValueExpressions.LongExpression(fromType.getScale(), null));
        }
        FunctionCall castCall = new FunctionCall(castFuncName, castArgs, ExpressionPosition.UNKNOWN);
        FunctionResolver resolver = FunctionResolverFactory.getResolver(castCall);
        DrillFuncHolder matchedCastFuncHolder = registry.findDrillFunction(resolver, castCall);
        if (matchedCastFuncHolder == null) {
            ExpressionTreeMaterializer.logFunctionResolutionError(errorCollector, castCall);
            return NullExpression.INSTANCE;
        }
        return matchedCastFuncHolder.getExpr(castFuncName, castArgs, ExpressionPosition.UNKNOWN);
    }

    /**
     * Reports a "Missing function implementation" error for {@code call}.
     * The message lists each argument's minor type and data mode.
     */
    private static void logFunctionResolutionError(ErrorCollector errorCollector, FunctionCall call) {
        StringBuilder sb = new StringBuilder("Missing function implementation: [");
        sb.append(call.getName()).append('(');
        boolean first = true;
        for (LogicalExpression arg : call.args) {
            if (!first) {
                sb.append(", ");
            }
            first = false;
            TypeProtos.MajorType mt = arg.getMajorType();
            sb.append(mt.getMinorType().name()).append('-').append(mt.getMode().name());
        }
        sb.append(")]");
        errorCollector.addGeneralError(call.getPosition(), sb.toString());
    }

    /**
     * Visitor that rewrites an expression tree into its materialized form.
     * The registry is threaded through as the visitor's value argument.
     */
    private static class MaterializeVisitor
    extends AbstractExprVisitor<LogicalExpression, FunctionImplementationRegistry, RuntimeException> {
        private final ExpressionValidator validator = new ExpressionValidator();
        private final ErrorCollector errorCollector;
        private final VectorAccessible batch;
        private final boolean allowComplexWriter;

        public MaterializeVisitor(VectorAccessible batch, ErrorCollector errorCollector, boolean allowComplexWriter) {
            this.batch = batch;
            this.errorCollector = errorCollector;
            this.allowComplexWriter = allowComplexWriter;
        }

        /** Runs the validator over {@code newExpr}, reporting to the error collector, and returns it unchanged. */
        private LogicalExpression validateNewExpr(LogicalExpression newExpr) {
            newExpr.accept(this.validator, this.errorCollector);
            return newExpr;
        }

        /**
         * Materializes every element of {@code rawArgs}, in order.
         * Each element is cast to {@link LogicalExpression}.
         */
        private List<LogicalExpression> materializeArgs(List<?> rawArgs, FunctionImplementationRegistry registry) {
            List<LogicalExpression> args = Lists.newArrayList();
            for (Object rawArg : rawArgs) {
                LogicalExpression newExpr = ((LogicalExpression)rawArg).accept(this, registry);
                assert (newExpr != null) : String.format("Materialization of %s return a null expression.", rawArg);
                args.add(newExpr);
            }
            return args;
        }

        @Override
        public LogicalExpression visitUnknown(LogicalExpression e, FunctionImplementationRegistry registry) throws RuntimeException {
            return e;
        }

        @Override
        public LogicalExpression visitFunctionHolderExpression(FunctionHolderExpression holder, FunctionImplementationRegistry value) throws RuntimeException {
            // Already bound to an implementation; nothing to resolve.
            return holder;
        }

        @Override
        public LogicalExpression visitBooleanOperator(BooleanOperator op, FunctionImplementationRegistry registry) {
            return new BooleanOperator(op.getName(), this.materializeArgs(op.args, registry), op.getPosition());
        }

        /**
         * Materializes the arguments of {@code call}, then binds the call to a function implementation.
         *
         * <p>Drill functions are tried first, then non-Drill functions. Arguments whose types do not
         * match the chosen implementation are wrapped in implicit casts. If nothing matches, an error
         * is reported and {@link NullExpression#INSTANCE} is returned.
         */
        @Override
        public LogicalExpression visitFunctionCall(FunctionCall call, FunctionImplementationRegistry registry) {
            // Resolution must see the materialized argument types, so rebuild the call first.
            FunctionCall materialized = new FunctionCall(call.getName(), this.materializeArgs(call.args, registry), call.getPosition());
            FunctionResolver resolver = FunctionResolverFactory.getResolver(materialized);
            DrillFuncHolder matchedFuncHolder = registry.findDrillFunction(resolver, materialized);
            if (matchedFuncHolder instanceof DrillComplexWriterFuncHolder && !this.allowComplexWriter) {
                this.errorCollector.addGeneralError(materialized.getPosition(), "Only ProjectRecordBatch could have complex writer function. You are using complex writer function " + materialized.getName() + " in a non-project operation!");
            }
            if (matchedFuncHolder != null) {
                boolean nullIfNull = matchedFuncHolder.getNullHandling() == FunctionTemplate.NullHandling.NULL_IF_NULL;
                List<LogicalExpression> argsWithCast = Lists.newArrayList();
                for (int i = 0; i < materialized.args.size(); ++i) {
                    LogicalExpression currentArg = materialized.args.get(i);
                    TypeProtos.MajorType parmType = matchedFuncHolder.getParmMajorType(i);
                    if (currentArg.equals(NullExpression.INSTANCE) && (parmType.getMode().equals(TypeProtos.DataMode.OPTIONAL) || nullIfNull)) {
                        // Give an untyped NULL argument the parameter's type.
                        argsWithCast.add(new TypedNullConstant(parmType));
                    } else if (Types.softEquals(parmType, currentArg.getMajorType(), nullIfNull) || matchedFuncHolder.isFieldReader(i)) {
                        argsWithCast.add(currentArg);
                    } else {
                        argsWithCast.add(ExpressionTreeMaterializer.addCastExpression(currentArg, parmType, registry, this.errorCollector));
                    }
                }
                return matchedFuncHolder.getExpr(materialized.getName(), argsWithCast, materialized.getPosition());
            }
            AbstractFuncHolder matchedNonDrillFuncHolder = registry.findNonDrillFunction(materialized);
            if (matchedNonDrillFuncHolder != null) {
                List<LogicalExpression> extArgsWithCast = Lists.newArrayList();
                for (int i = 0; i < materialized.args.size(); ++i) {
                    LogicalExpression currentArg = materialized.args.get(i);
                    TypeProtos.MajorType parmType = matchedNonDrillFuncHolder.getParmMajorType(i);
                    if (Types.softEquals(parmType, currentArg.getMajorType(), true)) {
                        extArgsWithCast.add(currentArg);
                    } else {
                        extArgsWithCast.add(ExpressionTreeMaterializer.addCastExpression(currentArg, parmType, registry, this.errorCollector));
                    }
                }
                return matchedNonDrillFuncHolder.getExpr(materialized.getName(), extArgsWithCast, materialized.getPosition());
            }
            ExpressionTreeMaterializer.logFunctionResolutionError(this.errorCollector, materialized);
            return NullExpression.INSTANCE;
        }

        /**
         * Materializes an IF expression and reconciles the types of its THEN and ELSE branches.
         *
         * <p>When both branches have different non-NULL minor types, the branch that is not already
         * the least restrictive type is cast to the other branch's type. Untyped NULL branches are
         * given the other branch's type. If the resulting expression is nullable, both branches are
         * converted to nullable form.
         *
         * @throws DrillRuntimeException if the branch types have no common least-restrictive type
         */
        @Override
        public LogicalExpression visitIfExpression(IfExpression ifExpr, FunctionImplementationRegistry registry) {
            IfExpression.IfCondition conditions = ifExpr.ifCondition;
            LogicalExpression newElseExpr = ifExpr.elseExpression.accept(this, registry);
            LogicalExpression newCondition = conditions.condition.accept(this, registry);
            LogicalExpression newThenExpr = conditions.expression.accept(this, registry);
            conditions = new IfExpression.IfCondition(newCondition, newThenExpr);

            TypeProtos.MinorType thenType = conditions.expression.getMajorType().getMinorType();
            TypeProtos.MinorType elseType = newElseExpr.getMajorType().getMinorType();
            if (thenType != elseType && thenType != TypeProtos.MinorType.NULL && elseType != TypeProtos.MinorType.NULL) {
                TypeProtos.MinorType leastRestrictive = TypeCastRules.getLeastRestrictiveType(Arrays.asList(thenType, elseType));
                if (leastRestrictive == elseType) {
                    // ELSE is already the wider type: cast THEN up to it.
                    conditions = new IfExpression.IfCondition(newCondition, ExpressionTreeMaterializer.addCastExpression(conditions.expression, newElseExpr.getMajorType(), registry, this.errorCollector));
                } else if (leastRestrictive == thenType) {
                    // THEN is already the wider type: cast ELSE up to it.
                    newElseExpr = ExpressionTreeMaterializer.addCastExpression(newElseExpr, conditions.expression.getMajorType(), registry, this.errorCollector);
                } else {
                    // No common type (e.g. a null result from the cast rules). The previous comparisons
                    // made this branch unreachable and silently cast THEN instead.
                    throw new DrillRuntimeException("Case expression should have similar output type on all its branches");
                }
            }

            List<LogicalExpression> allExpressions = Lists.newArrayList();
            allExpressions.add(conditions.expression);
            allExpressions.add(newElseExpr);
            boolean containsNullExpr = Iterables.any(allExpressions, new Predicate<LogicalExpression>(){

                public boolean apply(LogicalExpression input) {
                    return input instanceof NullExpression;
                }
            });
            if (containsNullExpr) {
                Optional<LogicalExpression> nonNullExpr = Iterables.tryFind(allExpressions, new Predicate<LogicalExpression>(){

                    public boolean apply(LogicalExpression input) {
                        return !input.getMajorType().getMinorType().equals(TypeProtos.MinorType.NULL);
                    }
                });
                if (nonNullExpr.isPresent()) {
                    // Give untyped NULL branches the type of the first typed branch.
                    TypeProtos.MajorType type = nonNullExpr.get().getMajorType();
                    conditions = new IfExpression.IfCondition(conditions.condition, this.rewriteNullExpression(conditions.expression, type));
                    newElseExpr = this.rewriteNullExpression(newElseExpr, type);
                }
            }

            if (IfExpression.newBuilder().setElse(newElseExpr).setIfCondition(conditions).build().getMajorType().getMode() == TypeProtos.DataMode.OPTIONAL) {
                // The combined result is nullable, so every branch must be nullable as well.
                IfExpression.IfCondition condition = conditions;
                if (condition.expression.getMajorType().getMode() != TypeProtos.DataMode.OPTIONAL) {
                    conditions = new IfExpression.IfCondition(condition.condition, this.getConvertToNullableExpr(ImmutableList.of(condition.expression), condition.expression.getMajorType().getMinorType(), registry));
                }
                if (newElseExpr.getMajorType().getMode() != TypeProtos.DataMode.OPTIONAL) {
                    newElseExpr = this.getConvertToNullableExpr(ImmutableList.of(newElseExpr), newElseExpr.getMajorType().getMinorType(), registry);
                }
            }
            return this.validateNewExpr(IfExpression.newBuilder().setElse(newElseExpr).setIfCondition(conditions).build());
        }

        /**
         * Resolves the registered {@code convertToNullable<MINOR_TYPE>} function for {@code args}.
         * Returns {@link NullExpression#INSTANCE} and reports an error if the function is not found.
         */
        private LogicalExpression getConvertToNullableExpr(List<LogicalExpression> args, TypeProtos.MinorType minorType, FunctionImplementationRegistry registry) {
            String funcName = "convertToNullable" + minorType.toString();
            FunctionCall funcCall = new FunctionCall(funcName, args, ExpressionPosition.UNKNOWN);
            FunctionResolver resolver = FunctionResolverFactory.getResolver(funcCall);
            DrillFuncHolder matchedConvertToNullableFuncHolder = registry.findDrillFunction(resolver, funcCall);
            if (matchedConvertToNullableFuncHolder == null) {
                ExpressionTreeMaterializer.logFunctionResolutionError(this.errorCollector, funcCall);
                return NullExpression.INSTANCE;
            }
            return matchedConvertToNullableFuncHolder.getExpr(funcName, args, ExpressionPosition.UNKNOWN);
        }

        /** Replaces an untyped {@link NullExpression} with a null constant of {@code type}; other expressions pass through. */
        private LogicalExpression rewriteNullExpression(LogicalExpression expr, TypeProtos.MajorType type) {
            if (expr instanceof NullExpression) {
                return new TypedNullConstant(type);
            }
            return expr;
        }

        /**
         * Resolves {@code path} to a read of the matching value vector in the batch.
         * If the path is not found, logs a warning and returns {@link NullExpression#INSTANCE}.
         */
        @Override
        public LogicalExpression visitSchemaPath(SchemaPath path, FunctionImplementationRegistry value) {
            TypedFieldId tfId = this.batch.getValueVectorId(path);
            if (tfId == null) {
                logger.warn("Unable to find value vector of path {}, returning null instance.", path);
                return NullExpression.INSTANCE;
            }
            return new ValueVectorReadExpression(tfId);
        }

        @Override
        public LogicalExpression visitIntConstant(ValueExpressions.IntExpression intExpr, FunctionImplementationRegistry value) {
            return intExpr;
        }

        @Override
        public LogicalExpression visitLongConstant(ValueExpressions.LongExpression intExpr, FunctionImplementationRegistry registry) {
            return intExpr;
        }

        @Override
        public LogicalExpression visitDateConstant(ValueExpressions.DateExpression intExpr, FunctionImplementationRegistry registry) {
            return intExpr;
        }

        @Override
        public LogicalExpression visitTimeConstant(ValueExpressions.TimeExpression intExpr, FunctionImplementationRegistry registry) {
            return intExpr;
        }

        @Override
        public LogicalExpression visitTimeStampConstant(ValueExpressions.TimeStampExpression intExpr, FunctionImplementationRegistry registry) {
            return intExpr;
        }

        @Override
        public LogicalExpression visitNullConstant(TypedNullConstant nullConstant, FunctionImplementationRegistry value) throws RuntimeException {
            return nullConstant;
        }

        @Override
        public LogicalExpression visitIntervalYearConstant(ValueExpressions.IntervalYearExpression intExpr, FunctionImplementationRegistry registry) {
            return intExpr;
        }

        @Override
        public LogicalExpression visitIntervalDayConstant(ValueExpressions.IntervalDayExpression intExpr, FunctionImplementationRegistry registry) {
            return intExpr;
        }

        @Override
        public LogicalExpression visitDoubleConstant(ValueExpressions.DoubleExpression dExpr, FunctionImplementationRegistry registry) {
            return dExpr;
        }

        @Override
        public LogicalExpression visitBooleanConstant(ValueExpressions.BooleanExpression e, FunctionImplementationRegistry registry) {
            return e;
        }

        @Override
        public LogicalExpression visitQuotedStringConstant(ValueExpressions.QuotedString e, FunctionImplementationRegistry registry) {
            return e;
        }

        /** Rewrites a CONVERT expression into a call to {@code <convertFunction><encodingType>} and materializes it. */
        @Override
        public LogicalExpression visitConvertExpression(ConvertExpression e, FunctionImplementationRegistry value) {
            String convertFunctionName = e.getConvertFunction() + e.getEncodingType();
            List<LogicalExpression> newArgs = Lists.newArrayList();
            newArgs.add(e.getInput());
            FunctionCall fc = new FunctionCall(convertFunctionName, newArgs, e.getPosition());
            return fc.accept(this, value);
        }

        /**
         * Materializes a CAST expression.
         *
         * <ul>
         *   <li>A no-op cast is dropped.</li>
         *   <li>A cast of a LATE-typed input is kept as a {@link CastExpression}.</li>
         *   <li>A cast of NULL becomes a typed null constant.</li>
         *   <li>Anything else becomes a call to the registered cast function.</li>
         * </ul>
         */
        @Override
        public LogicalExpression visitCastExpression(CastExpression e, FunctionImplementationRegistry value) {
            LogicalExpression input = e.getInput().accept(this, value);
            TypeProtos.MajorType targetType = e.getMajorType();
            TypeProtos.MinorType inputMinor = input.getMajorType().getMinorType();
            if (this.castEqual(e.getPosition(), input.getMajorType(), targetType)) {
                return input;
            }
            if (inputMinor == TypeProtos.MinorType.LATE) {
                // Input type not yet bound; leave the cast in place.
                return new CastExpression(input, targetType, e.getPosition());
            }
            if (inputMinor == TypeProtos.MinorType.NULL) {
                return new TypedNullConstant(Types.optional(targetType.getMinorType()));
            }
            String castFuncWithType = CastFunctions.getCastFunc(targetType.getMinorType());
            List<LogicalExpression> newArgs = Lists.newArrayList();
            // Reuse the materialized input. Passing e.getInput() would materialize it a second time,
            // duplicating any warnings or errors it produces.
            newArgs.add(input);
            if (!Types.isFixedWidthType(targetType)) {
                newArgs.add(new ValueExpressions.LongExpression(targetType.getWidth(), null));
            } else if (targetType.getMinorType().name().startsWith("DECIMAL")) {
                newArgs.add(new ValueExpressions.LongExpression(targetType.getPrecision(), null));
                newArgs.add(new ValueExpressions.LongExpression(targetType.getScale(), null));
            }
            FunctionCall fc = new FunctionCall(castFuncWithType, newArgs, e.getPosition());
            return fc.accept(this, value);
        }

        /**
         * Returns true when casting {@code from} to {@code to} would not change the value, so the cast can be dropped.
         * Reports an error for fixed-width character/binary types and for types with no known casting rule.
         */
        private boolean castEqual(ExpressionPosition pos, TypeProtos.MajorType from, TypeProtos.MajorType to) {
            if (!from.getMinorType().equals(to.getMinorType())) {
                return false;
            }
            switch (from.getMinorType()) {
                case FLOAT4: 
                case FLOAT8: 
                case INT: 
                case BIGINT: 
                case BIT: 
                case TINYINT: 
                case SMALLINT: 
                case UINT1: 
                case UINT2: 
                case UINT4: 
                case UINT8: 
                case TIME: 
                case TIMESTAMP: 
                case TIMESTAMPTZ: 
                case DATE: 
                case INTERVAL: 
                case INTERVALDAY: 
                case INTERVALYEAR: {
                    return true;
                }
                case DECIMAL9: 
                case DECIMAL18: 
                case DECIMAL28DENSE: 
                case DECIMAL28SPARSE: 
                case DECIMAL38DENSE: 
                case DECIMAL38SPARSE: {
                    return to.getScale() == from.getScale() && to.getPrecision() == from.getPrecision();
                }
                case FIXED16CHAR: 
                case FIXEDBINARY: 
                case FIXEDCHAR: {
                    this.errorCollector.addGeneralError(pos, "Casting fixed width types are not yet supported..");
                    return false;
                }
                case VAR16CHAR: 
                case VARBINARY: 
                case VARCHAR: {
                    // No-op if the target is at least as wide as a known source width, or has width 0.
                    return (to.getWidth() >= from.getWidth() && from.getWidth() > 0) || to.getWidth() == 0;
                }
                default: {
                    this.errorCollector.addGeneralError(pos, String.format("Casting rules are unknown for type %s.", from));
                    return false;
                }
            }
        }
    }
}

