Commit fc5b07af authored by Jens Behrmann's avatar Jens Behrmann

Delete MSNeuralNetClassifier.m~

parent 0a54b0f9
classdef MSNeuralNetClassifier < MSClassifier
% Neural network classifier
properties
neuralNetFile
% Serialized Neural Net. Also contains class
% dictionaries that can translate between the class
% labels used by Python (0,..,N-1) and the class
% labels transmitted by MSClassifyLib (e.g. 1,3 and 4).
neuralNetArchitecture
% An external network architecture can be supplied.
% If neuralNetArchitecture is 'auto', the standard architecture
% defined in NNInterface.py is used.
dataFile
% File used for data exchange between Python and
% Matlab. Contains spectra (and class labels during
% training).
% Default: 'neuralNetTemp'
resultFile
% File used for data exchange between Python and
% Matlab. Contains predicted class labels and
% predicted class membership probabilites.
% Default: 'neuralNetTemp'
continueTraining
% false if the neuralNetFile should be created anew
% every time the neural network is trained,
% otherwise true.
% Default: false
keepTemporaryFiles
% true if temporary files should be kept, else false
useValidationSet
% If true, 10% of the training data is used for validation.
hyperparameters
% Struct of hyperparameters.
useTimeStamp
% If 1, a timestamp is added to the neuralNetFile-name upon the
% beginning of the training. If for example an
% MSClassificationStudy is performed with cross-validation, the
% file name would be identical for all cross-validation folds, so
% that each neuralNetFile would be overwritten.
saveBestValidModel
% If 1, the best model determined by the lowest valid loss during
% training is saved and the final returned model is this best model
% default: 1
end
methods
function obj = MSNeuralNetClassifier (varargin)
obj@MSClassifier;
parser = inputParser;
% @isstr does not work for input parsers.
checkString = @(x) validateattributes(x,{'char'},{'nonempty'});
addParameter(parser,'neuralNetFile','neuralNetFile.nnet',checkString);
addParameter(parser,'dataFile','dataFile.mat',checkString);
addParameter(parser,'resultFile','resultFile.mat',checkString);
addParameter(parser,'neuralNetArchitecture','auto',checkString);
addParameter(parser,'continueTraining',0,@isnumeric);
addParameter(parser,'keepTemporaryFiles',0,@isnumeric);
addParameter(parser,'useValidationSet',0,@isnumeric);
addParameter(parser,'useTimeStamp',0,@isnumeric);
addParameter(parser,'saveBestValidModel',1,@isnumeric);
isPositiveNumeric = @(x) isnumeric(x) && x>0;
isNonnegativeNumeric = @(x) isnumeric(x) && x>=0;
isRatio = @(x) isnumeric(x) && x>=0 && x<=1;
addParameter(parser,'batchSize',128,isPositiveNumeric);
addParameter(parser,'epochs',100,isPositiveNumeric);
addParameter(parser,'learningRate',0.001,isPositiveNumeric);
addParameter(parser,'l1',0.,isNonnegativeNumeric);
addParameter(parser,'l2',0.,isNonnegativeNumeric);
addParameter(parser,'validationSetRatio',0.1,isRatio);
parse(parser,varargin{:});
obj.neuralNetFile = parser.Results.neuralNetFile;
obj.dataFile = parser.Results.dataFile;
obj.neuralNetArchitecture = parser.Results.neuralNetArchitecture;
obj.resultFile = parser.Results.resultFile;
obj.continueTraining = (parser.Results.continueTraining~=0);
obj.keepTemporaryFiles = (parser.Results.keepTemporaryFiles~=0);
obj.useValidationSet = (parser.Results.useValidationSet~=0);
obj.useTimeStamp = (parser.Results.useTimeStamp~=0);
obj.saveBestValidModel = (parser.Results.saveBestValidModel~=0);
obj.hyperparameters = struct();
obj.hyperparameters.batchSize = ...
int32(ceil(parser.Results.batchSize));
obj.hyperparameters.epochs = ...
int32(ceil(parser.Results.epochs));
obj.hyperparameters.learningRate = ...
parser.Results.learningRate;
obj.hyperparameters.l1 = ...
parser.Results.l1;
obj.hyperparameters.l2 = ...
parser.Results.l2;
obj.hyperparameters.validationSetRatio = ...
parser.Results.validationSetRatio;
end
end
methods (Access = protected)
function model = trainModel_impl (obj, msData, labels, numFeatures)
% Number of features defaults to length of data items
N = numFeatures;
if isempty(N)
N = msData.dataLength;
end
if N < msData.dataLength
disp(['Used only the ',num2str(N),' first features for training!'])
end
nonZeroLabels = (labels > 0);
% If useTimeStamp is true, add a timestamp to the file name.
if obj.useTimeStamp
[pathstr,name,ext] = fileparts(obj.neuralNetFile);
[name,~] = strsplit(name,'@');
name = name{1};
timestamp = ...
char(datetime('now','Format','yyyyMMdd_HHmmss'));
obj.neuralNetFile = ...
fullfile(pathstr,[strcat(name,'@',timestamp) ext]);
end
% Use only labeled data for transfer
data = msData.data(nonZeroLabels, 1:N);
data(isnan(msData.data)) = 0;
% compute standard deviation per input dimension and store as
% property of created NNModel
inputDimScaling = std(data,1);
classes = labels(nonZeroLabels,1);
save(obj.dataFile, 'data', 'classes','-v7.3');
clearvars data classes;
% The following lines call a python program which performs the
% network's training.
% -B prevents the creation of the bytecode-file (.pyc)
pythonPrefix = '';
pyCommand = strcat(pythonPrefix,...
{'python -B '},...
which('NNInterface.py'),...
{' -neuralNetFile '}, obj.neuralNetFile,...
{' -dataFile '}, obj.dataFile,...
{' -resultFile '}, obj.resultFile,...
{' -neuralNetArchitecture '}, obj.neuralNetArchitecture,...
{' -mode '}, 'train');
if ~obj.continueTraining
pyCommand = strcat(pyCommand, {' -overwrite '});
end
if obj.useValidationSet
pyCommand = strcat(pyCommand, {' -useValidationSet'});
end
if obj.saveBestValidModel
pyCommand = strcat(pyCommand, {' -saveBestValidModel'});
end
hyperparameterKeys = fieldnames(obj.hyperparameters)';
for key = hyperparameterKeys
pyCommand = strcat(pyCommand, {[' -',char(key)]});
pyCommand = strcat(pyCommand, ...
{[' ',num2str(getfield(obj.hyperparameters,char(key)))]});
end
system(pyCommand{1});
model = MSNeuralNetModel(numFeatures,...
inputDimScaling, ....
obj.neuralNetFile,obj.neuralNetArchitecture,...
obj.dataFile,obj.resultFile,obj.keepTemporaryFiles,class(obj));
if ~obj.keepTemporaryFiles
delete(obj.dataFile)
end
end
end
end
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment