From 85fc6526fc5c89f9511bc4526c85c1fba541ffe0 Mon Sep 17 00:00:00 2001 From: "Bernhard J. Berger" <bernhard.berger@uni-bremen.de> Date: Sat, 11 Feb 2023 02:47:13 +0100 Subject: [PATCH] RMSE and RRSE --- .../main/gof/rmse/RMSECalculator.java | 108 ++++++++++++++++++ .../main/gof/rmse/RRSECalculator.java | 103 +++++++++++++++++ .../src/main/java/module-info.java | 1 + .../META-INF/specifications.surrogate/api.dl | 0 .../META-INF/specifications/ea/api.dl | 0 5 files changed, 212 insertions(+) create mode 100644 src/core/de.evoal.surrogate.api/src/main/java/de/evoal/surrogate/main/gof/rmse/RMSECalculator.java create mode 100644 src/core/de.evoal.surrogate.api/src/main/java/de/evoal/surrogate/main/gof/rmse/RRSECalculator.java create mode 100644 src/core/de.evoal.surrogate.api/src/main/resources/META-INF/specifications.surrogate/api.dl create mode 100644 src/core/de.evoal.surrogate.api/src/main/resources/META-INF/specifications/ea/api.dl diff --git a/src/core/de.evoal.surrogate.api/src/main/java/de/evoal/surrogate/main/gof/rmse/RMSECalculator.java b/src/core/de.evoal.surrogate.api/src/main/java/de/evoal/surrogate/main/gof/rmse/RMSECalculator.java new file mode 100644 index 00000000..420d6063 --- /dev/null +++ b/src/core/de.evoal.surrogate.api/src/main/java/de/evoal/surrogate/main/gof/rmse/RMSECalculator.java @@ -0,0 +1,108 @@ +package de.evoal.surrogate.main.gof.rmse; + +import de.evoal.core.api.properties.Properties; +import de.evoal.core.api.properties.stream.PropertiesBasedPropertiesPairStreamSupplier; +import de.evoal.core.api.properties.stream.PropertiesPairStreamSupplier; +import de.evoal.core.api.properties.stream.PropertiesStreamSupplier; +import de.evoal.core.api.utils.Requirements; +import de.evoal.surrogate.api.SurrogateInformationCalculator; +import de.evoal.surrogate.api.configuration.Parameter; +import de.evoal.surrogate.api.configuration.SurrogateConfiguration; +import de.evoal.surrogate.api.function.SurrogateFunction; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.math3.stat.descriptive.moment.Mean; +import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation; +import org.apache.commons.math3.util.Pair; + +import javax.enterprise.context.Dependent; +import javax.inject.Named; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; +import java.util.stream.DoubleStream; + +/** + * Calculates the root mean squared error. + */ +@Dependent +@Named("rmse") +@Slf4j +public class RMSECalculator implements SurrogateInformationCalculator { + + /** + * Surrogate configuration for attaching the calculated r² value. + */ + private SurrogateConfiguration config; + + /** + * The actual surrogate function. + */ + private SurrogateFunction function; + + /** + * Supplier for the trainings data. + */ + private PropertiesStreamSupplier trainingData; + + @Override + public void execute() { + log.info("calculating rmse of surrogate function."); + + final PropertiesPairStreamSupplier pairStream = new PropertiesBasedPropertiesPairStreamSupplier(trainingData, function.getInputSpecification(), function.getOutputSpecification()); + + final List<List<Double>> data = new ArrayList<>(); + for(int i = 0; i < function.getOutputSpecification().size(); ++i) { + data.add(new ArrayList<>()); + } + + pairStream.get() + .forEach(p -> { + final Properties source = p.getFirst(); + final Properties expected = p.getSecond(); + + final Properties actual = function.apply(source); + + for(int i = 0; i < expected.size(); ++i) { + data.get(i).add(Math.pow(expected.getAsDouble(i) - actual.getAsDouble(i), 2.0)); + } + }); + + + + for(int i = 0; i < data.size(); ++i) { + final double average = data.get(i) + .stream() + .mapToDouble(Double.class::cast) + .summaryStatistics() + .getAverage(); + + final double rmse = Math.pow(average, 0.5); + + log.info("RMSE of {}: '{}'", function.getOutputSpecification().get(i), rmse); + + final Parameter goodnessOfFit = Parameter.builder() + .name("rmse") + .value(rmse) + .build(); + + config.addOutputParameter(function.getOutputSpecification().get(i).name(), goodnessOfFit); + } + } + + @Override + public String toString() { + return "rmse"; + } + + @Override + public void configure(final SurrogateFunction function, final SurrogateConfiguration config, final List<Object> parameters, final PropertiesStreamSupplier trainingData) { + Requirements.requireEmpty(parameters); + Requirements.requireNotNull(function); + Requirements.requireNotNull(config); + + this.function = function; + this.config = config; + this.trainingData = trainingData; + } +} diff --git a/src/core/de.evoal.surrogate.api/src/main/java/de/evoal/surrogate/main/gof/rmse/RRSECalculator.java b/src/core/de.evoal.surrogate.api/src/main/java/de/evoal/surrogate/main/gof/rmse/RRSECalculator.java new file mode 100644 index 00000000..9df7d648 --- /dev/null +++ b/src/core/de.evoal.surrogate.api/src/main/java/de/evoal/surrogate/main/gof/rmse/RRSECalculator.java @@ -0,0 +1,103 @@ +package de.evoal.surrogate.main.gof.rmse; + +import de.evoal.core.api.properties.Properties; +import de.evoal.core.api.properties.stream.PropertiesBasedPropertiesPairStreamSupplier; +import de.evoal.core.api.properties.stream.PropertiesPairStreamSupplier; +import de.evoal.core.api.properties.stream.PropertiesStreamSupplier; +import de.evoal.core.api.utils.Requirements; +import de.evoal.surrogate.api.SurrogateInformationCalculator; +import de.evoal.surrogate.api.configuration.Parameter; +import de.evoal.surrogate.api.configuration.SurrogateConfiguration; +import de.evoal.surrogate.api.function.SurrogateFunction; +import lombok.extern.slf4j.Slf4j; + +import javax.enterprise.context.Dependent; +import javax.inject.Named; +import java.util.ArrayList; +import java.util.List; + +/** + * Calculates the root mean squared error. + */ +@Dependent +@Named("rrse") +@Slf4j +public class RRSECalculator implements SurrogateInformationCalculator { + + /** + * Surrogate configuration for attaching the calculated r² value. + */ + private SurrogateConfiguration config; + + /** + * The actual surrogate function. + */ + private SurrogateFunction function; + + /** + * Supplier for the trainings data. + */ + private PropertiesStreamSupplier trainingData; + + @Override + public void execute() { + log.info("calculating rmse of surrogate function."); + + final PropertiesPairStreamSupplier pairStream = new PropertiesBasedPropertiesPairStreamSupplier(trainingData, function.getInputSpecification(), function.getOutputSpecification()); + + final List<List<Double>> data = new ArrayList<>(); + for(int i = 0; i < function.getOutputSpecification().size(); ++i) { + data.add(new ArrayList<>()); + } + + pairStream.get() + .forEach(p -> { + final Properties source = p.getFirst(); + final Properties expected = p.getSecond(); + + final Properties actual = function.apply(source); + + for(int i = 0; i < expected.size(); ++i) { + System.err.println(" " + expected.getAsDouble(i) + " -- " + actual.getAsDouble(i)); + data.get(i).add(Math.pow((expected.getAsDouble(i) - actual.getAsDouble(i)) / expected.getAsDouble(i), 2.0)); + } + }); + + + + for(int i = 0; i < data.size(); ++i) { + final double average = data.get(i) + .stream() + .mapToDouble(Double.class::cast) + .summaryStatistics() + .getAverage(); + + final double rrse = Math.pow(average, 0.5); + + log.info("RRSE of {}: '{}'", function.getOutputSpecification().get(i), rrse); + + final Parameter goodnessOfFit = Parameter.builder() + .name("rrse") + .value(rrse) + .build(); + + config.addOutputParameter(function.getOutputSpecification().get(i).name(), goodnessOfFit); + } + } + + @Override + public String toString() { + return "rrse"; + } + + @Override + public void configure(final SurrogateFunction function, final SurrogateConfiguration config, final List<Object> parameters, final PropertiesStreamSupplier trainingData) { + Requirements.requireEmpty(parameters); + Requirements.requireNotNull(function); + Requirements.requireNotNull(config); + + this.function = function; + this.config = config; + this.trainingData = trainingData; + } +} diff --git a/src/core/de.evoal.surrogate.api/src/main/java/module-info.java b/src/core/de.evoal.surrogate.api/src/main/java/module-info.java index 2cf1213f..a4a1a87c 100644 --- a/src/core/de.evoal.surrogate.api/src/main/java/module-info.java +++ b/src/core/de.evoal.surrogate.api/src/main/java/module-info.java @@ -43,6 +43,7 @@ module de.evoal.surrogate.api { opens de.evoal.surrogate.main.internal to weld.core.impl; opens de.evoal.surrogate.main.jackson to weld.core.impl, com.fasterxml.jackson.databind; opens de.evoal.surrogate.main.gof.cross to weld.core.impl; + opens de.evoal.surrogate.main.gof.rmse to weld.core.impl; opens de.evoal.surrogate.main.gof.rsquare to weld.core.impl; opens de.evoal.surrogate.main.statistics.constraint to weld.core.impl; opens de.evoal.surrogate.main.statistics.correlated to weld.core.impl; diff --git a/src/core/de.evoal.surrogate.api/src/main/resources/META-INF/specifications.surrogate/api.dl b/src/core/de.evoal.surrogate.api/src/main/resources/META-INF/specifications.surrogate/api.dl new file mode 100644 index 00000000..e69de29b diff --git a/src/core/de.evoal.surrogate.api/src/main/resources/META-INF/specifications/ea/api.dl b/src/core/de.evoal.surrogate.api/src/main/resources/META-INF/specifications/ea/api.dl new file mode 100644 index 00000000..e69de29b -- GitLab