package org.jpmml.manager;

import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import java.util.Iterator;
import java.util.List;
import org.dmg.pmml.ActivationFunctionType;
import org.dmg.pmml.Connection;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.NeuralInput;
import org.dmg.pmml.NeuralInputs;
import org.dmg.pmml.NeuralLayer;
import org.dmg.pmml.NeuralNetwork;
import org.dmg.pmml.NeuralOutput;
import org.dmg.pmml.NeuralOutputs;
import org.dmg.pmml.Neuron;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;

/* loaded from: input_file:org/jpmml/manager/NeuralNetworkManager.class */
public class NeuralNetworkManager extends ModelManager<NeuralNetwork> implements HasEntityRegistry<Entity> {

    /** The managed model; {@code null} until supplied by a constructor or by {@link #createModel}. */
    private NeuralNetwork neuralNetwork;

    /**
     * Creates a manager that does not yet hold a neural network model.
     * Use {@link #createModel} to create one.
     */
    public NeuralNetworkManager() {
        this.neuralNetwork = null;
    }

    /**
     * Creates a manager for the {@link NeuralNetwork} that {@code find} locates in the
     * content of {@code pmml}.
     *
     * @param pmml the document to search
     */
    public NeuralNetworkManager(PMML pmml) {
        this(pmml, (NeuralNetwork) find(pmml.getContent(), NeuralNetwork.class));
    }

    /**
     * Creates a manager for an explicitly supplied model.
     *
     * @param pmml          the document passed to the superclass
     * @param neuralNetwork the model to manage
     */
    public NeuralNetworkManager(PMML pmml, NeuralNetwork neuralNetwork) {
        super(pmml);
        this.neuralNetwork = neuralNetwork;
    }

    @Override
    public String getSummary() {
        return "Neural network";
    }

    /**
     * Returns the managed model.
     *
     * @return the neural network
     * @throws IllegalStateException if no model has been set or created
     */
    @Override
    public NeuralNetwork getModel() {
        Preconditions.checkState(this.neuralNetwork != null);
        return this.neuralNetwork;
    }

    /**
     * Creates a new, empty neural network (with an empty mining schema and empty
     * neural inputs) and appends it to the list returned by {@code getModels()}.
     *
     * @param miningFunctionType     the mining function of the new model
     * @param activationFunctionType the default activation function of the new model
     * @return the newly created model
     * @throws IllegalStateException if this manager already holds a model
     */
    public NeuralNetwork createModel(MiningFunctionType miningFunctionType, ActivationFunctionType activationFunctionType) {
        Preconditions.checkState(this.neuralNetwork == null);
        this.neuralNetwork = new NeuralNetwork(new MiningSchema(), new NeuralInputs(), miningFunctionType, activationFunctionType);
        getModels().add(this.neuralNetwork);
        return this.neuralNetwork;
    }

    /**
     * Returns the neural inputs of the managed model. This is the same list that
     * {@link #addNeuralInput} appends to.
     *
     * @throws IllegalStateException if no model has been set or created
     */
    public List<NeuralInput> getNeuralInputs() {
        return getModel().getNeuralInputs().getNeuralInputs();
    }

    /**
     * Appends a neural input whose value is a continuous {@code double} derived field
     * computed by the given normalization expression.
     *
     * @param id             the identifier of the new input
     * @param normContinuous the expression of the input's derived field
     * @return the added input
     */
    public NeuralInput addNeuralInput(String id, NormContinuous normContinuous) {
        DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE);
        derivedField.setExpression(normContinuous);

        NeuralInput neuralInput = new NeuralInput(derivedField, id);
        getNeuralInputs().add(neuralInput);
        return neuralInput;
    }

    /**
     * Returns the neural layers of the managed model. This is the same list that
     * {@link #addNeuralLayer} appends to.
     *
     * @throws IllegalStateException if no model has been set or created
     */
    public List<NeuralLayer> getNeuralLayers() {
        return getModel().getNeuralLayers();
    }

    /**
     * Appends a new, empty layer to the managed model.
     *
     * @return the added layer
     */
    public NeuralLayer addNeuralLayer() {
        NeuralLayer neuralLayer = new NeuralLayer();
        getNeuralLayers().add(neuralLayer);
        return neuralLayer;
    }

    /**
     * Builds a registry of every neural input and every neuron of every layer.
     * Each entity is registered through {@code EntityUtil.put}, which presumably
     * keys it by its id (confirm against {@code EntityUtil}).
     *
     * @return a newly built registry; later changes to the model are not reflected in it
     */
    @Override
    public BiMap<String, Entity> getEntityRegistry() {
        // Parameterized (not raw) so the map's key/value types are checked by the compiler.
        BiMap<String, Entity> registry = HashBiMap.create();

        for (NeuralInput neuralInput : getNeuralInputs()) {
            EntityUtil.put(neuralInput, registry);
        }

        for (NeuralLayer neuralLayer : getNeuralLayers()) {
            for (Neuron neuron : neuralLayer.getNeurons()) {
                EntityUtil.put(neuron, registry);
            }
        }

        return registry;
    }

    /**
     * Appends a new neuron to {@code neuralLayer}.
     *
     * @param neuralLayer the layer that receives the neuron
     * @param id          the identifier of the new neuron
     * @param bias        the bias of the neuron; passed to {@code setBias} as-is
     * @return the added neuron
     */
    public static Neuron addNeuron(NeuralLayer neuralLayer, String id, Double bias) {
        Neuron neuron = new Neuron(id);
        neuron.setBias(bias);
        neuralLayer.getNeurons().add(neuron);
        return neuron;
    }

    /**
     * Connects a neural input to a neuron. The connection is stored on the target
     * neuron and refers to the source by its id.
     *
     * @param from   the source input
     * @param to     the target neuron
     * @param weight the connection weight
     */
    public static void addConnection(NeuralInput from, Neuron to, double weight) {
        to.getConnections().add(new Connection(from.getId(), weight));
    }

    /**
     * Connects one neuron to another. The connection is stored on the target
     * neuron and refers to the source by its id.
     *
     * @param from   the source neuron
     * @param to     the target neuron
     * @param weight the connection weight
     */
    public static void addConnection(Neuron from, Neuron to, double weight) {
        to.getConnections().add(new Connection(from.getId(), weight));
    }

    /**
     * Returns the neural outputs of the managed model. If the model has no
     * {@code NeuralOutputs} element yet, an empty one is created and attached first.
     *
     * @return the list of neural outputs
     * @throws IllegalStateException if no model has been set or created
     */
    public List<NeuralOutput> getOrCreateNeuralOutputs() {
        NeuralNetwork model = getModel();

        NeuralOutputs neuralOutputs = model.getNeuralOutputs();
        if (neuralOutputs == null) {
            neuralOutputs = new NeuralOutputs();
            model.setNeuralOutputs(neuralOutputs);
        }

        return neuralOutputs.getNeuralOutputs();
    }

    /**
     * Appends a neural output that maps {@code neuron}'s value through a continuous
     * {@code double} derived field computed by the given normalization expression.
     *
     * @param neuron         the output neuron, referenced by its id
     * @param normContinuous the expression of the output's derived field
     * @return the added output
     */
    public NeuralOutput addNeuralOutput(Neuron neuron, NormContinuous normContinuous) {
        DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE);
        derivedField.setExpression(normContinuous);

        NeuralOutput neuralOutput = new NeuralOutput(derivedField, neuron.getId());
        getOrCreateNeuralOutputs().add(neuralOutput);
        return neuralOutput;
    }
}
