/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.UninitializedParameterException;
import ai.djl.training.GradientCollector;
import ai.djl.training.ParameterServer;
import ai.djl.training.ParameterStore;
import ai.djl.training.TrainingConfig;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.listener.EpochTrainingListener;
import ai.djl.training.listener.EvaluatorTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Drives training of a {@link Model}.
 *
 * <p>A trainer owns a child {@link NDManager} (named {@code "trainer"}) and a {@link ParameterStore}
 * backed by the engine's {@link ParameterServer}. It runs forward passes through the model's block,
 * applies parameter updates in {@link #step()}, optionally records timings into {@link Metrics},
 * and notifies the configured {@link TrainingListener}s when training begins and ends.
 */
public class Trainer
implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(Trainer.class);
    private final Model model;
    private final NDManager manager;
    private Metrics metrics;
    private final List<TrainingListener> listeners;
    private final Device[] devices;
    private final ParameterStore parameterStore;
    private final List<Evaluator> evaluators;
    private final Loss loss;
    private final ExecutorService executorService;
    /** Set once the first {@link #step()} has verified that gradients were actually populated. */
    private boolean gradientsChecked;

    /**
     * Creates a trainer for {@code model} configured by {@code trainingConfig}.
     *
     * <p>The configured loss is appended to the list of evaluators, so it is reported alongside
     * them. All listeners receive {@code onTrainingBegin} before the constructor returns.
     *
     * @param model the model to train
     * @param trainingConfig supplies devices, loss, evaluators, executor, optimizer and listeners
     * @throws NullPointerException if {@code trainingConfig} provides no loss function
     */
    public Trainer(Model model, TrainingConfig trainingConfig) {
        this.model = model;
        this.manager = model.getNDManager().newSubManager();
        this.manager.setName("trainer");
        this.devices = trainingConfig.getDevices();
        this.loss = trainingConfig.getLossFunction();
        Objects.requireNonNull(this.loss, "You must specify a loss for the trainer");
        this.evaluators = new ArrayList<>(trainingConfig.getEvaluators());
        this.evaluators.add(this.loss);
        this.executorService = trainingConfig.getExecutorService();
        ParameterServer parameterServer =
                this.manager.getEngine().newParameterServer(trainingConfig.getOptimizer());
        this.parameterStore = new ParameterStore(this.manager, false);
        this.parameterStore.setParameterServer(parameterServer, this.devices);
        this.listeners = trainingConfig.getTrainingListeners();
        this.notifyListeners(listener -> listener.onTrainingBegin(this));
    }

    /**
     * Initializes the model's block for the given input shapes and materializes every parameter
     * on every training device through the parameter store.
     *
     * @param shapes the input shapes used to initialize the block
     * @throws IllegalStateException if some parameter is still uninitialized after block
     *     initialization (typically a child block that was never initialized)
     */
    public void initialize(Shape... shapes) {
        this.model.getBlock().initialize(this.model.getNDManager(), this.model.getDataType(), shapes);
        this.model.getBlock().getParameters().forEach(pair -> {
            for (Device device : this.devices) {
                try {
                    this.parameterStore.getValue(pair.getValue(), device, true);
                } catch (UninitializedParameterException e) {
                    throw new IllegalStateException("Failed to initialize parameter: " + pair.getKey() + ".\nIf you are defining a Block extending AbstractBlock, check that you are initializing all child blocks as part of the overload for AbstractBlock.initializeChildBlocks().", e);
                }
            }
        });
    }

    /**
     * Returns an iterable over the batches of {@code dataset}, allocated on this trainer's manager
     * and using the configured executor service (which may be {@code null}).
     *
     * @param dataset the dataset to iterate
     * @return the dataset's batches
     * @throws IOException if the dataset fails to read its data
     * @throws TranslateException if the dataset fails to translate its data
     */
    public Iterable<Batch> iterateDataset(Dataset dataset) throws IOException, TranslateException {
        return dataset.getData(this.getManager(), this.executorService);
    }

    /**
     * Creates a new gradient collector from the engine of this trainer's manager.
     *
     * @return a new {@link GradientCollector}
     */
    public GradientCollector newGradientCollector() {
        return this.manager.getEngine().newGradientCollector();
    }

    /**
     * Runs a forward pass of the model's block with the training flag set, recording the elapsed
     * time under the {@code "forward"} metric even if the pass throws.
     *
     * @param input the block inputs
     * @return the block outputs
     */
    public NDList forward(NDList input) {
        long begin = System.nanoTime();
        try {
            return this.model.getBlock().forward(this.parameterStore, input, true);
        } finally {
            this.addMetric("forward", begin);
        }
    }

    /**
     * Runs a forward pass of the model's block with both data and labels, recording the elapsed
     * time under the {@code "forward"} metric even if the pass throws.
     *
     * @param data the block inputs
     * @param labels the labels passed to the block
     * @return the block outputs
     */
    public NDList forward(NDList data, NDList labels) {
        long begin = System.nanoTime();
        try {
            return this.model.getBlock().forward(this.parameterStore, data, labels, null);
        } finally {
            this.addMetric("forward", begin);
        }
    }

    /**
     * Runs a forward pass of the model's block with the training flag cleared. No metric is
     * recorded.
     *
     * @param input the block inputs
     * @return the block outputs
     */
    public NDList evaluate(NDList input) {
        return this.model.getBlock().forward(this.parameterStore, input, false, null);
    }

    /**
     * Applies one optimizer update to all parameters, timed under the {@code "step"} metric.
     *
     * <p>Until it first succeeds, a sanity check verifies that gradients were populated (i.e.
     * that {@code backward()} was called).
     *
     * @throws IllegalStateException if every gradient of the trainable parameters is zero
     */
    public void step() {
        if (!this.gradientsChecked) {
            this.checkGradients();
        }
        long begin = System.nanoTime();
        this.parameterStore.updateAllParameters();
        this.addMetric("step", begin);
    }

    /**
     * Returns the metrics sink, or {@code null} if none was set.
     *
     * @return the metrics, possibly {@code null}
     */
    public Metrics getMetrics() {
        return this.metrics;
    }

    /**
     * Sets the metrics sink used by {@link #addMetric(String, long)}; {@code null} disables
     * recording.
     *
     * @param metrics the metrics sink, possibly {@code null}
     */
    public void setMetrics(Metrics metrics) {
        this.metrics = metrics;
    }

    /**
     * Returns the training devices (the internal array, not a copy).
     *
     * @return the devices
     */
    public Device[] getDevices() {
        return this.devices;
    }

    /**
     * Returns the loss function from the training config.
     *
     * @return the loss
     */
    public Loss getLoss() {
        return this.loss;
    }

    /**
     * Returns the model being trained.
     *
     * @return the model
     */
    public Model getModel() {
        return this.model;
    }

    /**
     * Returns the executor service from the training config, if any.
     *
     * @return the executor service, or empty if none was configured
     */
    public Optional<ExecutorService> getExecutorService() {
        return Optional.ofNullable(this.executorService);
    }

    /**
     * Returns the evaluators: those from the training config followed by the loss. The returned
     * list is the internal, mutable list.
     *
     * @return the evaluators
     */
    public List<Evaluator> getEvaluators() {
        return this.evaluators;
    }

    /**
     * Applies {@code listenerConsumer} to every configured listener, in order.
     *
     * @param listenerConsumer the action to run for each listener
     */
    public final void notifyListeners(Consumer<TrainingListener> listenerConsumer) {
        this.listeners.forEach(listenerConsumer);
    }

    /**
     * Builds a {@link TrainingResult} from the listeners.
     *
     * <p>The epoch count comes from an {@link EpochTrainingListener} and the evaluations come
     * from an {@link EvaluatorTrainingListener}. If several listeners of one kind are present,
     * the last one wins.
     *
     * @return the training result
     */
    public TrainingResult getTrainingResult() {
        TrainingResult result = new TrainingResult();
        for (TrainingListener listener : this.listeners) {
            if (listener instanceof EpochTrainingListener) {
                result.setEpoch(((EpochTrainingListener) listener).getNumEpochs());
            } else if (listener instanceof EvaluatorTrainingListener) {
                result.setEvaluations(((EvaluatorTrainingListener) listener).getLatestEvaluations());
            }
        }
        return result;
    }

    /**
     * Returns this trainer's child manager.
     *
     * @return the manager
     */
    public NDManager getManager() {
        return this.manager;
    }

    /** Closes the trainer if it is still open when garbage-collected. */
    @SuppressWarnings("deprecation")
    @Override
    protected void finalize() throws Throwable {
        if (this.manager.isOpen()) {
            // NOTE(review): the warning is only emitted when DEBUG is enabled, although it is
            // logged at WARN level; confirm this quiet-by-default behavior is intended.
            if (logger.isDebugEnabled()) {
                logger.warn("Trainer for {} was not closed explicitly.", this.model.getName());
            }
            this.close();
        }
        super.finalize();
    }

    /**
     * Notifies listeners that training ended, syncs the parameter store, and closes the manager.
     * This method does not check whether the trainer was already closed.
     */
    @Override
    public void close() {
        this.notifyListeners(listener -> listener.onTrainingEnd(this));
        this.parameterStore.sync();
        this.manager.close();
    }

    /**
     * Verifies that the gradients of the trainable parameters (on the first device) are not all
     * zero, which would indicate that {@code backward()} was never called.
     *
     * @throws IllegalStateException if every gradient element is zero
     */
    private void checkGradients() {
        List<NDArray> grads = new ArrayList<>();
        for (Parameter param : this.model.getBlock().getParameters().values()) {
            if (param.requiresGradient()) {
                grads.add(this.parameterStore.getValue(param, this.devices[0], true).getGradient());
            }
        }
        if (grads.isEmpty()) {
            // No trainable parameters: nothing to verify, and stacking an empty list would fail.
            this.gradientsChecked = true;
            return;
        }
        try (NDManager scoped = this.manager.newSubManager()) {
            // Intermediate arrays created below are released when the scoped manager closes.
            scoped.tempAttachAll(new NDList(grads));
            // Sum magnitudes rather than signed values so gradients of opposite sign cannot
            // cancel out and be mistaken for "all zeros".
            NDList magnitudes = new NDList(
                    grads.stream().map(g -> g.abs().sum()).toArray(NDArray[]::new));
            float gradMagnitude = NDArrays.stack(magnitudes).sum().getFloat();
            if (gradMagnitude == 0.0f) {
                throw new IllegalStateException("Gradient values are all zeros, please call gradientCollector.backward() on your target NDArray (usually loss), before calling step() ");
            }
            this.gradientsChecked = true;
        }
    }

    /**
     * Records the nanoseconds elapsed since {@code begin} under {@code metricName}. It does
     * nothing if no metrics sink is set or {@code begin} is not positive.
     *
     * @param metricName the metric name
     * @param begin the start time from {@link System#nanoTime()}
     */
    public void addMetric(String metricName, long begin) {
        if (this.metrics != null && begin > 0L) {
            this.metrics.addMetric(metricName, System.nanoTime() - begin);
        }
    }
}

