/*
 * Decompiled with CFR 0.152.
 */
package oracle.pgx.config.mllib;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import oracle.pgx.common.types.PropertyType;
import oracle.pgx.common.util.ErrorMessages;
import oracle.pgx.config.internal.BatchGeneratorDeserializer;
import oracle.pgx.config.internal.GraphWiseBaseConvLayerConfigDeserializer;
import oracle.pgx.config.internal.LabelMapsDeserializer;
import oracle.pgx.config.internal.LossFunctionDeserializer;
import oracle.pgx.config.internal.categorymapping.CategoryMappingConfig;
import oracle.pgx.config.mllib.GraphWiseBaseConvLayerConfig;
import oracle.pgx.config.mllib.GraphWiseBaseModelConfig;
import oracle.pgx.config.mllib.GraphWiseModelConfig;
import oracle.pgx.config.mllib.GraphWisePredictionLayerConfig;
import oracle.pgx.config.mllib.GraphWiseValidationConfig;
import oracle.pgx.config.mllib.LabelMaps;
import oracle.pgx.config.mllib.batchgenerator.BatchGenerator;
import oracle.pgx.config.mllib.batchgenerator.BatchGenerators;
import oracle.pgx.config.mllib.inputconfig.InputPropertyConfig;
import oracle.pgx.config.mllib.loss.LossFunction;
import oracle.pgx.config.mllib.loss.LossFunctions;

public class SupervisedGraphWiseModelConfig
extends GraphWiseModelConfig {
    public static final EnumSet<PropertyType> SUPPORTED_LABEL_TYPES = EnumSet.of(PropertyType.INTEGER, PropertyType.STRING, PropertyType.BOOLEAN, PropertyType.LONG);
    public static final EnumSet<PropertyType> SUPPORTED_REGRESSION_TYPES = EnumSet.of(PropertyType.FLOAT, PropertyType.DOUBLE, PropertyType.INTEGER, PropertyType.LONG);
    public static final LossFunction DEFAULT_LOSS_FUNCTION_CLASS = LossFunctions.SOFTMAX_CROSS_ENTROPY_LOSS;
    public static final BatchGenerator DEFAULT_BATCH_GENERATOR = BatchGenerators.STANDARD;
    public static final GraphWisePredictionLayerConfig[] DEFAULT_PREDICTION_LAYER_CONFIGS = new GraphWisePredictionLayerConfig[]{new GraphWisePredictionLayerConfig()};
    public static final Map<?, Float> DEFAULT_CLASS_MAP = null;
    private LossFunction lossFunction = DEFAULT_LOSS_FUNCTION_CLASS;
    private BatchGenerator batchGenerator = DEFAULT_BATCH_GENERATOR;
    private GraphWisePredictionLayerConfig[] predictionLayerConfigs = DEFAULT_PREDICTION_LAYER_CONFIGS;
    private String vertexTargetPropertyName = null;
    @JsonDeserialize(using=LabelMapsDeserializer.class)
    private LabelMaps labelMaps;

    public SupervisedGraphWiseModelConfig() {
        this.labelMaps = new LabelMaps();
    }

    @JsonCreator
    public SupervisedGraphWiseModelConfig(@JsonProperty(required=true, value="batchSize") int batchSize, @JsonProperty(required=true, value="numEpochs") int numEpochs, @JsonProperty(required=true, value="learningRate") double learningRate, @JsonProperty(required=false, value="weightDecay") double weightDecay, @JsonProperty(required=true, value="embeddingDim") int embeddingDim, @JsonProperty(required=true, value="seed") Integer seed, @JsonDeserialize(contentUsing=GraphWiseBaseConvLayerConfigDeserializer.class) @JsonProperty(required=true, value="convLayerConfigs") GraphWiseBaseConvLayerConfig[] convLayerConfigs, @JsonProperty(required=true, value="standardize") boolean standardize, @JsonProperty(required=true, value="normalize") boolean normalize, @JsonProperty(required=true, value="shuffle") boolean shuffle, @JsonProperty(required=true, value="vertexInputPropertyNames") List<String> vertexInputPropertyNames, @JsonProperty(required=false, value="edgeInputPropertyNames") List<String> edgeInputPropertyNames, @JsonProperty(required=false, value="vertexInputPropertyConfigs") Map<String, InputPropertyConfig> vertexInputPropertyConfigs, @JsonProperty(required=false, value="edgeInputPropertyConfigs") Map<String, InputPropertyConfig> edgeInputPropertyConfigs, @JsonProperty(required=false, value="categoryMappingConfig") CategoryMappingConfig categoryMappingConfig, @JsonProperty(required=false, value="targetVertexLabelSets") List<Set<String>> targetVertexLabelSets, @JsonProperty(required=true, value="fitted") boolean fitted, @JsonProperty(required=true, value="trainingLoss") double trainingLoss, @JsonProperty(required=true, value="inputFeatureDim") int inputFeatureDim, @JsonProperty(required=false, value="edgeInputFeatureDim") int edgeInputFeatureDim, @JsonDeserialize(using=LossFunctionDeserializer.class) @JsonProperty(required=false, value="lossFunctionClass") LossFunction lossFunctionClass, @JsonDeserialize(using=BatchGeneratorDeserializer.class) @JsonProperty(required=false, value="batchGenerator") BatchGenerator batchGenerator, @JsonProperty(required=true, value="predictionLayerConfigs") GraphWisePredictionLayerConfig[] predictionLayerConfigs, @JsonProperty(required=true, value="vertexTargetPropertyName") String vertexTargetPropertyName, @JsonProperty(required=true, value="labelMaps") LabelMaps labelMaps, @JsonProperty(required=true, value="backend") GraphWiseBaseModelConfig.Backend backend, @JsonProperty(required=false, value="variant") GraphWiseModelConfig.GraphConvModelVariant variant, @JsonProperty(required=false, value="enableAccelerator") boolean enableAccelerator, @JsonProperty(required=false, value="validationConfig") GraphWiseValidationConfig validationConfig) {
        super(batchSize, numEpochs, learningRate, weightDecay, embeddingDim, seed, convLayerConfigs, standardize, normalize, shuffle, vertexInputPropertyNames, edgeInputPropertyNames, vertexInputPropertyConfigs, edgeInputPropertyConfigs, categoryMappingConfig, fitted, trainingLoss, inputFeatureDim, edgeInputFeatureDim, targetVertexLabelSets, backend, variant, enableAccelerator, validationConfig);
        if (lossFunctionClass != null) {
            this.lossFunction = lossFunctionClass;
        }
        if (this.lossFunction == null) {
            throw new IllegalArgumentException("Deserializable json files must include a lossFunctionClass");
        }
        this.batchGenerator = batchGenerator != null ? batchGenerator : DEFAULT_BATCH_GENERATOR;
        this.predictionLayerConfigs = predictionLayerConfigs;
        this.vertexTargetPropertyName = vertexTargetPropertyName;
        this.labelMaps = labelMaps;
    }

    public SupervisedGraphWiseModelConfig(SupervisedGraphWiseModelConfig source) {
        super(source);
        if (source.getLossFunctionClass() == null) {
            throw new IllegalArgumentException(ErrorMessages.getMessage((String)"COPY_MODEL_WITHOUT_LOSS", (Object[])new Object[0]));
        }
        this.setLossFunctionClass(source.getLossFunctionClass());
        this.setBatchGenerator(source.getBatchGenerator());
        this.labelMaps = new LabelMaps();
        if (source.getClassMap() != null) {
            this.setClassMap(new HashMap(source.getClassMap()));
        }
        GraphWisePredictionLayerConfig[] targetPredConfigs = source.getPredictionLayerConfigs();
        GraphWisePredictionLayerConfig[] predConfigs = new GraphWisePredictionLayerConfig[targetPredConfigs.length];
        for (int i = 0; i < targetPredConfigs.length; ++i) {
            predConfigs[i] = new GraphWisePredictionLayerConfig();
            predConfigs[i].setActivationFunction(targetPredConfigs[i].getActivationFunction());
            predConfigs[i].setWeightInitScheme(targetPredConfigs[i].getWeightInitScheme());
            predConfigs[i].setHiddenDimension(targetPredConfigs[i].getHiddenDimension());
            predConfigs[i].setDropoutRate(targetPredConfigs[i].getDropoutRate());
        }
        this.setPredictionLayerConfigs(predConfigs);
        this.setVertexTargetPropertyName(source.getVertexTargetPropertyName());
        this.setLabelMaps(source.getLabelMaps());
        if (source.getClassWeights() == null) {
            this.setClassWeights(null);
        } else {
            this.setClassWeights(new HashMap(source.getClassWeights()));
        }
        this.setLabelType(source.getLabelType());
    }

    public SupervisedGraphWiseModelConfig(SupervisedGraphWiseModelConfig source, CategoryMappingConfig categoryMappingConfig) {
        this(source);
        this.categoryMappingConfig = categoryMappingConfig;
    }

    @JsonIgnore
    public int getNumClasses() {
        if (this.labelMaps.getClassMap() == null) {
            throw new IllegalStateException(ErrorMessages.getMessage((String)"NOT_FITTED", (Object[])new Object[0]));
        }
        return this.labelMaps.getClassMap().size();
    }

    public String getVertexTargetPropertyName() {
        return this.vertexTargetPropertyName;
    }

    public final void setVertexTargetPropertyName(String vertexTargetPropertyName) {
        this.vertexTargetPropertyName = vertexTargetPropertyName;
    }

    public GraphWisePredictionLayerConfig[] getPredictionLayerConfigs() {
        return this.predictionLayerConfigs;
    }

    public final void setPredictionLayerConfigs(GraphWisePredictionLayerConfig ... predictionLayerConfigs) {
        this.predictionLayerConfigs = predictionLayerConfigs;
    }

    public LossFunction getLossFunctionClass() {
        return this.lossFunction;
    }

    public final void setLossFunctionClass(LossFunction lossFunction) {
        this.lossFunction = lossFunction;
    }

    public BatchGenerator getBatchGenerator() {
        return this.batchGenerator;
    }

    public final void setBatchGenerator(BatchGenerator batchGenerator) {
        this.batchGenerator = batchGenerator;
    }

    @JsonIgnore
    public Map<?, Integer> getClassMap() {
        return this.labelMaps.getClassMap();
    }

    @JsonIgnore
    public final void setClassMap(Map<?, Integer> classMap) {
        this.labelMaps.setClassMap(classMap);
    }

    @JsonIgnore
    public final void setClassWeights(Map<?, Float> classWeights) {
        this.labelMaps.setClassWeights(classWeights);
    }

    @JsonIgnore
    public Map<?, Float> getClassWeights() {
        return this.labelMaps.getClassWeights();
    }

    @JsonIgnore
    public PropertyType getLabelType() {
        return this.labelMaps.getLabelType();
    }

    @JsonIgnore
    public final void setLabelType(PropertyType labelType) {
        this.labelMaps.setLabelType(labelType);
    }

    public LabelMaps getLabelMaps() {
        return this.labelMaps;
    }

    public final void setLabelMaps(LabelMaps labelMaps) {
        this.labelMaps = labelMaps;
    }
}

