/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;

/**
 * Logical-plan optimizer rule that merges adjacent projection-like nodes.
 *
 * <p>The rule is registered with {@code TransformDirection.UP} and handles three shapes:
 * <ul>
 *   <li>a {@link Project} whose child is a {@link Project}: the lower project is dropped after its
 *       aliases have been substituted into the upper projections;</li>
 *   <li>a {@link Project} whose child is an {@link Aggregate}: the projection is folded into the
 *       aggregate's output list, unless two upper projections unwrap to the same expression;</li>
 *   <li>an {@link Aggregate} whose child is a {@link Project}: the aggregate is re-pointed at the
 *       project's child, with the project's aliases resolved in both groupings and aggregates.</li>
 * </ul>
 */
public final class CombineProjections extends OptimizerRules.OptimizerRule<UnaryPlan> {

    /** Creates the rule, configured to transform the plan bottom-up. */
    public CombineProjections() {
        super(OptimizerRules.TransformDirection.UP);
    }

    /**
     * Rewrites a single unary node when it matches one of the supported shapes.
     *
     * @param plan the node being visited
     * @return the rewritten node, or {@code plan} itself when no rewrite applies
     * @throws EsqlIllegalArgumentException if an Aggregate over a Project has a grouping that is
     *     not a plain {@link Attribute}
     */
    @Override
    protected LogicalPlan rule(UnaryPlan plan) {
        LogicalPlan child = plan.child();
        if (plan instanceof Project) {
            Project project = (Project) plan;
            if (child instanceof Project) {
                Project lower = (Project) child;
                // Remove the lower project, substituting its aliases into the upper projections.
                project = lower.withProjections(combineProjections(project.projections(), lower.projections()));
                child = project.child();
                plan = project;
                // Deliberately no early return: the new child may be an Aggregate that the merged
                // projection can now be folded into.
            }
            if (child instanceof Aggregate) {
                Aggregate aggregate = (Aggregate) child;
                List<? extends NamedExpression> aggs = aggregate.aggregates();
                List<? extends NamedExpression> newAggs = projectAggregations(project.projections(), aggs);
                // null signals that the projection cannot be folded; keep it above the aggregate.
                if (newAggs != null) {
                    List<Expression> newGroups = replacePrunedAliasesUsedInGroupBy(aggregate.groupings(), aggs, newAggs);
                    plan = new Aggregate(aggregate.source(), aggregate.child(), aggregate.aggregateType(), newGroups, newAggs);
                }
            }
            return plan;
        }
        if (plan instanceof Aggregate && child instanceof Project) {
            Aggregate aggregate = (Aggregate) plan;
            Project lower = (Project) child;
            List<Expression> groupings = aggregate.groupings();
            List<Attribute> groupingAttrs = new ArrayList<>(groupings.size());
            for (Expression grouping : groupings) {
                // Only plain attributes are accepted as groupings here; presumably earlier rewrites
                // have already extracted any grouping expressions (confirm against rule ordering).
                if (!(grouping instanceof Attribute)) {
                    throw new EsqlIllegalArgumentException("Expected an Attribute, got {}", grouping);
                }
                groupingAttrs.add((Attribute) grouping);
            }
            plan = new Aggregate(
                aggregate.source(),
                lower.child(),
                aggregate.aggregateType(),
                combineUpperGroupingsAndLowerProjections(groupingAttrs, lower.projections()),
                combineProjections(aggregate.aggregates(), lower.projections())
            );
        }
        return plan;
    }

    /**
     * Folds an upper projection into the output list of the aggregate beneath it.
     *
     * <p>Gives up (returns {@code null}) when an upper projection, once its alias is stripped,
     * matches something an earlier upper projection already exposed (e.g. {@code c, c AS x}).
     * NOTE(review): presumably this avoids duplicating aggregate computations — confirm.
     *
     * @param upperProjection the projections of the upper {@link Project}
     * @param lowerAggregations the aggregates of the lower {@link Aggregate}
     * @return the combined aggregate list, or {@code null} if the projection cannot be folded
     */
    private static List<? extends NamedExpression> projectAggregations(
        List<? extends NamedExpression> upperProjection,
        List<? extends NamedExpression> lowerAggregations
    ) {
        AttributeSet seen = new AttributeSet();
        for (NamedExpression upper : upperProjection) {
            Expression unwrapped = Alias.unwrap(upper);
            if (seen.contains(unwrapped)) {
                return null;
            }
            seen.add(Expressions.attribute(unwrapped));
        }
        return combineProjections(upperProjection, lowerAggregations);
    }

    /**
     * Merges two stacked projection lists into one.
     *
     * <p>Only the upper expressions survive. Any attribute in them that refers to an alias defined
     * in {@code lower} is replaced by that alias's definition. Without that, a reference to an
     * alias that lived only in the lower list would dangle. Nested (non-top-level) aliases
     * introduced by the substitution are trimmed afterwards.
     *
     * @param upper the projections that should remain visible
     * @param lower the projections being eliminated
     * @return a new list, one entry per element of {@code upper}, in the same order
     */
    private static List<NamedExpression> combineProjections(List<? extends NamedExpression> upper, List<? extends NamedExpression> lower) {
        // Alias definitions from the lower list, keyed by the attribute they produce.
        AttributeMap<Expression> namedExpressions = new AttributeMap<>();
        // Every lower entry keyed by its attribute, mapped to its unwrapped definition; used to
        // resolve an alias child against entries defined before it in the same list.
        AttributeMap<Expression> aliases = new AttributeMap<>();
        for (NamedExpression ne : lower) {
            aliases.put(ne.toAttribute(), Alias.unwrap(ne));
            if (ne instanceof Alias) {
                Alias alias = (Alias) ne;
                Expression aliasChild = alias.child();
                namedExpressions.put(ne.toAttribute(), alias.replaceChild(aliases.resolve(aliasChild, aliasChild)));
            }
        }
        List<NamedExpression> replaced = new ArrayList<>(upper.size());
        for (NamedExpression ne : upper) {
            NamedExpression replacedExp = (NamedExpression) ne.transformUp(Attribute.class, a -> namedExpressions.resolve(a, a));
            replaced.add((NamedExpression) trimNonTopLevelAliases(replacedExp));
        }
        return replaced;
    }

    /**
     * Rewrites an aggregate's groupings so that they refer to the attributes underlying a lower
     * project that is being removed.
     *
     * <p>Each lower projection is unwrapped and cast to {@link Attribute}, so the lower entries
     * are expected to be attributes or aliases of attributes. Results are collected in an
     * {@link AttributeSet}, so groupings that resolve to the same attribute are kept only once.
     *
     * @param upperGroupings the aggregate's groupings, already checked to be attributes
     * @param lowerProjections the projections of the project being removed
     * @return the resolved groupings
     */
    private static List<Expression> combineUpperGroupingsAndLowerProjections(
        List<? extends Attribute> upperGroupings,
        List<? extends NamedExpression> lowerProjections
    ) {
        AttributeMap<Attribute> aliases = new AttributeMap<>();
        for (NamedExpression ne : lowerProjections) {
            aliases.put(ne.toAttribute(), (Attribute) Alias.unwrap(ne));
        }
        AttributeSet replaced = new AttributeSet();
        for (Attribute attr : upperGroupings) {
            replaced.add(aliases.resolve(attr, attr));
        }
        return new ArrayList<>(replaced);
    }

    /**
     * Re-inlines aliases that were dropped from the aggregate list but are still referenced by
     * the groupings.
     *
     * <p>An alias counts as "removed" when it appears in {@code oldAggs} but its attribute is
     * absent from {@code newAggs}. References to it in the groupings are replaced by the alias's
     * child. Groupings that become equal (as attributes) after substitution are kept only once.
     *
     * @param groupings the aggregate's current groupings
     * @param oldAggs the aggregate list before folding the projection
     * @param newAggs the aggregate list after folding the projection
     * @return {@code groupings} itself when nothing was removed, otherwise a new list
     */
    private static List<Expression> replacePrunedAliasesUsedInGroupBy(
        List<Expression> groupings,
        List<? extends NamedExpression> oldAggs,
        List<? extends NamedExpression> newAggs
    ) {
        AttributeMap<Expression> removedAliases = new AttributeMap<>();
        AttributeSet currentAliases = new AttributeSet(Expressions.asAttributes(newAggs));
        for (NamedExpression ne : oldAggs) {
            if (ne instanceof Alias) {
                Attribute attr = ne.toAttribute();
                if (!currentAliases.contains(attr)) {
                    removedAliases.put(attr, ((Alias) ne).child());
                }
            }
        }
        if (removedAliases.isEmpty()) {
            return groupings;
        }
        List<Expression> newGroupings = new ArrayList<>(groupings.size());
        for (Expression group : groupings) {
            Expression transformed = group.transformUp(Attribute.class, a -> removedAliases.resolve(a, a));
            if (!Expressions.anyMatch(newGroupings, g -> Expressions.equalsAsAttribute(g, transformed))) {
                newGroupings.add(transformed);
            }
        }
        return newGroupings;
    }

    /**
     * Removes every alias nested inside {@code e} while preserving a top-level alias, if any.
     *
     * @param e the expression to clean up
     * @return {@code e} with inner aliases replaced by their children
     */
    public static Expression trimNonTopLevelAliases(Expression e) {
        if (e instanceof Alias) {
            Alias alias = (Alias) e;
            return alias.replaceChild(trimAliases(alias.child()));
        }
        return trimAliases(e);
    }

    /** Replaces every {@link Alias} in {@code e}, top-down, with its child. */
    private static Expression trimAliases(Expression e) {
        return e.transformDown(Alias.class, Alias::child);
    }
}

