/*
 * Decompiled with CFR 0.152.
 */
package org.ojalgo.ann;

import java.util.Objects;
import java.util.function.DoubleUnaryOperator;
import org.ojalgo.ann.ArtificialNeuralNetwork;
import org.ojalgo.function.BinaryFunction;
import org.ojalgo.function.PrimitiveFunction;
import org.ojalgo.function.UnaryFunction;
import org.ojalgo.function.constant.PrimitiveMath;
import org.ojalgo.matrix.store.MatrixStore;
import org.ojalgo.matrix.store.PhysicalStore;
import org.ojalgo.random.Uniform;
import org.ojalgo.structure.Access2D;
import org.ojalgo.structure.Structure2D;

final class CalculationLayer {
    private ArtificialNeuralNetwork.Activator myActivator;
    private final PhysicalStore<Double> myBias;
    private final PhysicalStore<Double> myWeights;

    CalculationLayer(PhysicalStore.Factory<Double, ?> factory, int numberOfInputs, int numberOfOutputs, ArtificialNeuralNetwork.Activator activator) {
        this.myWeights = (PhysicalStore)factory.make(numberOfInputs, numberOfOutputs);
        this.myBias = (PhysicalStore)factory.make(1, numberOfOutputs);
        this.myActivator = activator;
    }

    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || !(obj instanceof CalculationLayer)) {
            return false;
        }
        CalculationLayer other = (CalculationLayer)obj;
        if (this.myActivator != other.myActivator) {
            return false;
        }
        if (this.myBias == null ? other.myBias != null : !this.myBias.equals(other.myBias)) {
            return false;
        }
        return !(this.myWeights == null ? other.myWeights != null : !this.myWeights.equals(other.myWeights));
    }

    public int hashCode() {
        int prime = 31;
        int result = 1;
        result = 31 * result + (this.myActivator == null ? 0 : this.myActivator.hashCode());
        result = 31 * result + (this.myBias == null ? 0 : this.myBias.hashCode());
        result = 31 * result + (this.myWeights == null ? 0 : this.myWeights.hashCode());
        return result;
    }

    public String toString() {
        StringBuilder tmpBuilder = new StringBuilder();
        tmpBuilder.append("CalculationLayer [Weights=");
        tmpBuilder.append(this.myWeights);
        tmpBuilder.append(", Bias=");
        tmpBuilder.append(this.myBias);
        tmpBuilder.append(", Activator=");
        tmpBuilder.append((Object)this.myActivator);
        tmpBuilder.append("]");
        return tmpBuilder.toString();
    }

    void adjust(PhysicalStore<Double> input, PhysicalStore<Double> output, PhysicalStore<Double> upstreamGradient, PhysicalStore<Double> downstreamGradient, double learningRate, double dropoutsFactor, DoubleUnaryOperator regularisation) {
        downstreamGradient.modifyMatching(PrimitiveMath.MULTIPLY, output.onAll((UnaryFunction)this.myActivator.getDerivativeInTermsOfOutput()).transpose());
        if (upstreamGradient != null) {
            this.myWeights.multiply(downstreamGradient, upstreamGradient);
        }
        if (regularisation != null) {
            PrimitiveFunction.Unary modifier = arg -> arg + learningRate * regularisation.applyAsDouble(arg);
            this.myWeights.modifyAll(modifier);
        }
        long nbOutput = this.myWeights.countColumns();
        for (long j = 0L; j < nbOutput; ++j) {
            long batchSize = input.countRows();
            for (long b = 0L; b < batchSize; ++b) {
                double gradient = downstreamGradient.doubleValue(j, b);
                double ratedGradient = learningRate * gradient;
                this.myBias.add(j, ratedGradient);
                long nbInput = this.myWeights.countRows();
                for (long i = 0L; i < nbInput; ++i) {
                    this.myWeights.add(i, j, ratedGradient * (input.doubleValue(b, i) / dropoutsFactor));
                }
            }
        }
    }

    int countInputNodes() {
        return Math.toIntExact(this.myWeights.countRows());
    }

    int countOutputNodes() {
        return Math.toIntExact(this.myWeights.countColumns());
    }

    ArtificialNeuralNetwork.Activator getActivator() {
        return this.myActivator;
    }

    double getBias(int output) {
        return this.myBias.doubleValue(output);
    }

    MatrixStore<Double> getLogicalWeights() {
        return this.myWeights.below((Access2D<Double>)this.myBias);
    }

    Structure2D getStructure() {
        return this.myWeights;
    }

    double getWeight(int input, int output) {
        return this.myWeights.doubleValue(input, output);
    }

    PhysicalStore<Double> invoke(PhysicalStore<Double> input, PhysicalStore<Double> output) {
        this.myWeights.premultiply(input).onColumns((BinaryFunction)PrimitiveMath.ADD, this.myBias).supplyTo(output);
        this.myActivator.activate(output);
        return output;
    }

    PhysicalStore<Double> invoke(PhysicalStore<Double> input, PhysicalStore<Double> output, double probabilityToKeep) {
        this.myWeights.premultiply(input).onColumns((BinaryFunction)PrimitiveMath.ADD, this.myBias).supplyTo(output);
        this.myActivator.activate(output, probabilityToKeep);
        return output;
    }

    void randomise() {
        double magnitude = PrimitiveMath.ONE / Math.sqrt(this.countInputNodes());
        Uniform randomiser = new Uniform(-magnitude, 2.0 * magnitude);
        this.myWeights.fillAll(randomiser);
        this.myBias.fillAll(randomiser);
    }

    void scale(double factor) {
        this.myWeights.modifyAll(PrimitiveMath.MULTIPLY.second(factor));
    }

    void setActivator(ArtificialNeuralNetwork.Activator activator) {
        this.myActivator = activator;
    }

    void setBias(int output, double bias) {
        this.myBias.set((long)output, bias);
    }

    void setWeight(int input, int output, double weight) {
        this.myWeights.set((long)input, (long)output, weight);
    }
}

