package org.qsardb.conversion.regression;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.xpath.compiler.PsuedoNames;
import org.dmg.pmml.Apply;
import org.dmg.pmml.Constant;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.FieldUsageType;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.NumericPredictor;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PredictorTerm;
import org.dmg.pmml.RegressionModel;
import org.dmg.pmml.RegressionTable;
import org.dmg.pmml.TransformationDictionary;
import org.qsardb.cargo.pmml.FieldNameUtil;
import org.qsardb.conversion.regression.Equation;
import org.qsardb.model.Descriptor;
import org.qsardb.model.Parameter;
import org.qsardb.model.Property;
import org.qsardb.model.Qdb;

/* loaded from: input_file:org/qsardb/conversion/regression/RegressionUtil.class */
/**
 * Converts between regression {@link Equation}s and PMML {@link RegressionModel}s.
 *
 * <p>{@link #parse(Qdb, String)} and {@link #parse(Qdb, Equation)} build a PMML document from an
 * equation; {@link #format(Qdb, PMML)} performs the reverse conversion. Property and descriptor
 * identifiers are mapped to PMML field names through {@link FieldNameUtil}.
 */
public final class RegressionUtil {

    /**
     * Apply function names that {@link #formatFunction(Apply)} passes through unchanged.
     * {@code "log10"} is not listed because it is handled separately (rendered as {@code "log"}).
     */
    private static final String[] PASS_THROUGH_FUNCTIONS = {"ln", "sqrt", "abs", "exp", "pow", "+", "-", "*", "/"};

    private RegressionUtil() {
    }

    /**
     * Parses an equation string and converts it into a PMML regression model.
     *
     * @param qdb the database used to resolve property and descriptor identifiers
     * @param str the equation text
     * @return a PMML document holding a single regression model
     * @throws ParseException if {@code str} cannot be parsed by {@link EquationParser}
     * @throws IllegalArgumentException if a referenced property or descriptor is not found
     */
    public static PMML parse(Qdb qdb, String str) throws ParseException {
        return parse(qdb, new EquationParser().parseEquation(str));
    }

    /**
     * Converts a parsed equation into a PMML regression model.
     *
     * <p>The equation identifier is resolved as the predicted property; every non-intercept term
     * is resolved as a descriptor and becomes a numeric predictor. Each distinct field is
     * declared only once in the data dictionary and mining schema, even if the same descriptor
     * occurs in several terms.
     *
     * @param qdb the database used to resolve property and descriptor identifiers
     * @param equation the equation to convert
     * @return a PMML (version "4.1") document holding a single regression model
     * @throws IllegalArgumentException if the property or a descriptor is not found in {@code qdb}
     */
    public static PMML parse(Qdb qdb, Equation equation) {
        DataDictionary dataDictionary = new DataDictionary();
        MiningSchema miningSchema = new MiningSchema();
        RegressionModel regressionModel = new RegressionModel(miningSchema, MiningFunctionType.REGRESSION);

        Property property = qdb.getProperty(equation.getIdentifier());
        if (property == null) {
            throw new IllegalArgumentException("Property '" + equation.getIdentifier() + "' not found");
        }
        addDataField(dataDictionary, property);
        addMiningField(miningSchema, property);
        regressionModel.setTargetFieldName(getFieldName(property));

        RegressionTable regressionTable = new RegressionTable(Double.NaN);
        regressionModel.getRegressionTables().add(regressionTable);

        for (Equation.Term term : equation.getTerms()) {
            double coefficient = Double.valueOf(term.getCoefficient());

            if (term.isIntercept()) {
                regressionTable.setIntercept(coefficient);
                continue;
            }

            Descriptor descriptor = qdb.getDescriptor(term.getIdentifier());
            if (descriptor == null) {
                throw new IllegalArgumentException("Descriptor '" + term.getIdentifier() + "' not found");
            }

            FieldName descriptorName = getFieldName(descriptor);
            // PMML requires unique field names; a descriptor used in several terms is declared once.
            if (!containsDataField(dataDictionary, descriptorName)) {
                addDataField(dataDictionary, descriptor);
                addMiningField(miningSchema, descriptor);
            }
            regressionTable.getNumericPredictors().add(new NumericPredictor(descriptorName, coefficient));
        }

        return new PMML(null, dataDictionary, "4.1").withModels(regressionModel);
    }

    /** Returns true if the data dictionary already declares a field with the given name. */
    private static boolean containsDataField(DataDictionary dataDictionary, FieldName name) {
        for (DataField dataField : dataDictionary.getDataFields()) {
            if (name.equals(dataField.getName())) {
                return true;
            }
        }
        return false;
    }

    /** Declares a continuous, double-typed data field for the parameter, using its name as display name. */
    private static void addDataField(DataDictionary dataDictionary, Parameter parameter) {
        DataField dataField = new DataField(getFieldName(parameter), OpType.CONTINUOUS, DataType.DOUBLE);
        dataField.setDisplayName(parameter.getName());
        dataDictionary.getDataFields().add(dataField);
    }

    /** Adds a mining field: PREDICTED for a property, ACTIVE for any other parameter. */
    private static void addMiningField(MiningSchema miningSchema, Parameter parameter) {
        MiningField miningField = new MiningField(getFieldName(parameter));
        miningField.setUsageType(parameter instanceof Property ? FieldUsageType.PREDICTED : FieldUsageType.ACTIVE);
        miningSchema.getMiningFields().add(miningField);
    }

    /**
     * Encodes a property or descriptor as a PMML field name.
     *
     * @throws IllegalArgumentException if the parameter is neither a Property nor a Descriptor
     */
    private static FieldName getFieldName(Parameter parameter) {
        if (parameter instanceof Property) {
            return FieldNameUtil.encodeProperty((Property) parameter);
        }
        if (parameter instanceof Descriptor) {
            return FieldNameUtil.encodeDescriptor((Descriptor) parameter);
        }
        throw new IllegalArgumentException();
    }

    /**
     * Converts the first model of a PMML document (which must be a regression model) back into
     * an equation.
     *
     * <p>Only the first regression table is used. Numeric predictors come first, then predictor
     * terms, and the intercept is always appended as the last term.
     *
     * @param qdb the database used to resolve property and descriptor field names
     * @param pmml the PMML document
     * @return the equation
     * @throws IllegalArgumentException if the document has no regression model or regression
     *     table, has no predicted field, or references an unknown property/descriptor or an
     *     unsupported expression
     */
    public static Equation format(Qdb qdb, PMML pmml) {
        Equation equation = new Equation();

        RegressionModel regressionModel = getRegressionModel(pmml);

        FieldName propertyName = getPropertyName(regressionModel);
        Property property = FieldNameUtil.decodeProperty(qdb, propertyName);
        if (property == null) {
            throw new IllegalArgumentException("Property '" + propertyName.getValue() + "' not found");
        }
        equation.setIdentifier(property.getId());

        List<RegressionTable> regressionTables = regressionModel.getRegressionTables();
        if (regressionTables.isEmpty()) {
            throw new IllegalArgumentException("RegressionModel without regression tables");
        }
        RegressionTable regressionTable = regressionTables.get(0);

        List<Equation.Term> terms = new ArrayList<Equation.Term>();
        for (NumericPredictor numericPredictor : regressionTable.getNumericPredictors()) {
            terms.add(formatNumericPredictor(qdb, pmml, regressionModel, numericPredictor));
        }
        for (PredictorTerm predictorTerm : regressionTable.getPredictorTerms()) {
            terms.add(formatPredictorTerm(qdb, pmml, regressionModel, predictorTerm));
        }

        Equation.Term intercept = new Equation.Term();
        intercept.setCoefficient(Double.toString(regressionTable.getIntercept()));
        terms.add(intercept);

        equation.setTerms(terms);
        return equation;
    }

    /**
     * Returns the first model of the document.
     *
     * @throws IllegalArgumentException if there is no model or the first one is not a RegressionModel
     */
    private static RegressionModel getRegressionModel(PMML pmml) {
        if (pmml.getModels().isEmpty()) {
            throw new IllegalArgumentException("PMML without models");
        }
        Object model = pmml.getModels().get(0);
        if (!(model instanceof RegressionModel)) {
            throw new IllegalArgumentException("Expected a RegressionModel, got "
                    + (model != null ? model.getClass().getName() : "null"));
        }
        return (RegressionModel) model;
    }

    /**
     * Returns the name of the first mining field whose usage type is PREDICTED.
     *
     * @throws IllegalArgumentException if the mining schema has no predicted field
     */
    private static FieldName getPropertyName(RegressionModel regressionModel) {
        for (MiningField miningField : regressionModel.getMiningSchema().getMiningFields()) {
            // Enum identity comparison is null-safe, unlike calling equals() on the getter result.
            if (miningField.getUsageType() == FieldUsageType.PREDICTED) {
                return miningField.getName();
            }
        }
        throw new IllegalArgumentException("MiningSchema without predicted field");
    }

    /** Converts a numeric predictor into a term carrying its coefficient and exponent. */
    private static Equation.Term formatNumericPredictor(Qdb qdb, PMML pmml, RegressionModel regressionModel, NumericPredictor numericPredictor) {
        Equation.Term term = createTerm(qdb, pmml, regressionModel, numericPredictor.getName());
        term.setCoefficient(Double.toString(numericPredictor.getCoefficient()));
        term.setExponent(Integer.toString(numericPredictor.getExponent()));
        return term;
    }

    /** Converts a predictor term into a "*" term whose arguments are the referenced fields. */
    private static Equation.Term formatPredictorTerm(Qdb qdb, PMML pmml, RegressionModel regressionModel, PredictorTerm predictorTerm) {
        Equation.Term term = new Equation.Term();
        term.setCoefficient(Double.toString(predictorTerm.getCoefficient()));

        List<Equation.Term> factors = new ArrayList<Equation.Term>();
        for (FieldRef fieldRef : predictorTerm.getFieldRefs()) {
            factors.add(createTerm(qdb, pmml, regressionModel, fieldRef.getField()));
        }
        term.setFunction("*");
        term.setArguments(factors);
        return term;
    }

    /**
     * Builds a term for a field reference.
     *
     * <p>Fields derived by an Apply expression become function terms. Fields derived by a
     * NormContinuous expression become "norm" terms. Any other field is resolved as a descriptor.
     *
     * @throws IllegalArgumentException if a referenced descriptor is not found
     */
    private static Equation.Term createTerm(Qdb qdb, PMML pmml, RegressionModel regressionModel, FieldName fieldName) {
        Expression expression = findExpression(pmml, regressionModel, fieldName);

        if (expression instanceof Apply) {
            return createTerm(qdb, pmml, regressionModel, (Apply) expression);
        }

        if (expression instanceof NormContinuous) {
            FieldName normalizedName = ((NormContinuous) expression).getField();
            Equation.Term normalized = createTerm(qdb, pmml, regressionModel, normalizedName);

            Expression normalizedExpression = findExpression(pmml, regressionModel, normalizedName);
            if (normalizedExpression instanceof Apply || normalizedExpression instanceof NormContinuous) {
                // The inner term already carries its own function; wrap it instead of overwriting
                // that function with "norm", which would silently drop the inner transformation.
                List<Equation.Term> arguments = new ArrayList<Equation.Term>();
                arguments.add(normalized);

                Equation.Term term = new Equation.Term();
                term.setFunction("norm");
                term.setArguments(arguments);
                return term;
            }

            normalized.setFunction("norm");
            return normalized;
        }

        Equation.Term term = new Equation.Term();
        term.setIdentifier(resolveDescriptor(qdb, fieldName).getId());
        return term;
    }

    /**
     * Builds a function term from an Apply expression, converting its arguments recursively.
     *
     * @throws IllegalArgumentException if the function is unsupported or an argument is neither
     *     a FieldRef, a Constant nor a nested Apply
     */
    private static Equation.Term createTerm(Qdb qdb, PMML pmml, RegressionModel regressionModel, Apply apply) {
        Equation.Term term = new Equation.Term();
        term.setFunction(formatFunction(apply));

        List<Equation.Term> arguments = new ArrayList<Equation.Term>();
        for (Expression expression : apply.getExpressions()) {
            if (expression instanceof FieldRef) {
                arguments.add(createTerm(qdb, pmml, regressionModel, ((FieldRef) expression).getField()));
            } else if (expression instanceof Constant) {
                Equation.Term constant = new Equation.Term();
                constant.setCoefficient(((Constant) expression).getValue());
                arguments.add(constant);
            } else if (expression instanceof Apply) {
                arguments.add(createTerm(qdb, pmml, regressionModel, (Apply) expression));
            } else {
                throw new IllegalArgumentException(expression.toString());
            }
        }
        term.setArguments(arguments);
        return term;
    }

    /** Decodes a field name into a descriptor, failing if it is not present in {@code qdb}. */
    private static Descriptor resolveDescriptor(Qdb qdb, FieldName fieldName) {
        Descriptor descriptor = FieldNameUtil.decodeDescriptor(qdb, fieldName);
        if (descriptor == null) {
            throw new IllegalArgumentException("Descriptor '" + fieldName.getValue() + "' not found");
        }
        return descriptor;
    }

    /** Returns the expression of the derived field with the given name, or null if there is none. */
    private static Expression findExpression(PMML pmml, RegressionModel regressionModel, FieldName fieldName) {
        DerivedField derivedField = findDerivedField(regressionModel.getLocalTransformations(), pmml.getTransformationDictionary(), fieldName);
        return (derivedField != null ? derivedField.getExpression() : null);
    }

    /**
     * Looks up a derived field, first in the model's local transformations and then in the
     * document's transformation dictionary. Either source may be null.
     *
     * @return the derived field, or null if neither source defines it
     */
    private static DerivedField findDerivedField(LocalTransformations localTransformations, TransformationDictionary transformationDictionary, FieldName fieldName) {
        DerivedField derivedField = null;
        if (localTransformations != null) {
            derivedField = findDerivedField(localTransformations.getDerivedFields(), fieldName);
        }
        if (derivedField == null && transformationDictionary != null) {
            derivedField = findDerivedField(transformationDictionary.getDerivedFields(), fieldName);
        }
        return derivedField;
    }

    /** Returns the first derived field in the list with the given name, or null. */
    private static DerivedField findDerivedField(List<DerivedField> list, FieldName fieldName) {
        for (DerivedField derivedField : list) {
            if (fieldName.equals(derivedField.getName())) {
                return derivedField;
            }
        }
        return null;
    }

    /**
     * Maps a PMML Apply function name to its equation spelling.
     *
     * <p>{@code "log10"} is rendered as {@code "log"}. The functions in
     * {@link #PASS_THROUGH_FUNCTIONS} are returned unchanged.
     *
     * @throws IllegalArgumentException for any other function name
     */
    private static String formatFunction(Apply apply) {
        String function = apply.getFunction();
        if ("log10".equals(function)) {
            return "log";
        }
        for (String supported : PASS_THROUGH_FUNCTIONS) {
            if (supported.equals(function)) {
                return supported;
            }
        }
        throw new IllegalArgumentException("Apply function: " + function);
    }
}
