Commit f49ac2be authored by Jens Behrmann's avatar Jens Behrmann

Delete MSNeuralNetModel.m~

parent fc5b07af
classdef MSNeuralNetModel < MSClassificationModel
% Neural network classification model
properties (SetAccess = immutable)
numFeatures % Number of used features.
neuralNetFile
% Serialized Neural Net. Also contains class
% dictionaries that can translate between the class
% labels used by Python (0,..,N-1) and the class
% labels transmitted by MSClassifyLib (e.g. 1,3 and 4).
neuralNetArchitecture
% An external network architecture can be supplied.
% If neuralNetArchitecture is 'auto', the standard architecture
% defined in NNInterface.py is used.
dataFile
% File used for data exchange between Python and
% Matlab. Contains spectra (and class labels during
% training).
resultFile
% File used for data exchange between Python and
% Matlab. Contains predicted class labels and
% predicted class membership probabilites.
keepTemporaryFiles
% true if temporary files should be kept, else false
inputDimScaling
% scling of input dimension via standard deviation of each input
% dimension of the used training data. Scaling parameter is used to
% scale sensitivity output
end
methods
function obj = MSNeuralNetModel (numFeatures,...
inputDimScaling, ...
neuralNetFile,neuralNetArchitecture,...
dataFile,resultFile,keepTemporaryFiles,...
creator)
% Constructor
obj@MSClassificationModel(creator);
obj.numFeatures = numFeatures;
obj.inputDimScaling = inputDimScaling;
obj.neuralNetFile = neuralNetFile;
obj.neuralNetArchitecture = neuralNetArchitecture;
obj.dataFile = dataFile;
obj.resultFile = resultFile;
obj.keepTemporaryFiles = keepTemporaryFiles;
end
end
methods (Access = protected)
function [prediction, scores] = classify_impl (obj, msData, itemMask)
% prepare data file which is load by python scripts
prepareDataFile(obj, msData, itemMask, false)
% The following lines call a python program which performs the
% network's predicting.
% Modify the next line if you want to use a different
% python environment, e.g. pythonPrefix = '~/NeuralNets/env/bin/'
pythonPrefix = '';
pyCommand = strcat(pythonPrefix,...
{'python -B '},...
which('NNInterface.py'),...
{' -neuralNetFile '}, ...
obj.neuralNetFile,...
{' -dataFile '}, ...
obj.dataFile,...
{' -resultFile '}, ...
obj.resultFile,...
{' -mode '}, {'predict'});
system(pyCommand{1});
% Compute prediction for items in itemMask
if nargout > 1
scores = h5read(obj.resultFile,'/scores');
end
prediction = h5read(obj.resultFile,'/prediction');
%[~,prediction]=max(prediction_temp.probabilities,[],2);
if ~obj.keepTemporaryFiles
delete(obj.dataFile,...
obj.resultFile)
end
end
function [data] = prepareDataFile(obj, msData, itemMask, nll)
if isempty(obj.numFeatures)
N = msData.dataLength;
else
N = obj.numFeatures;
end
% get mask if itemMask was provided as label data
if isa(itemMask, 'MSLabelData')
itemMask = itemMask.data;
end
% reduce data to mask
data = msData.data(itemMask>0,1:N);
data(isnan(msData.data)) = 0;
if nll
% add labels if used for NLL sensitivty analysis
classes = itemMask(itemMask > 0, 1);
save(obj.dataFile, 'data', 'classes','-v7.3');
clearvars data classes
else
save(obj.dataFile, 'data','-v7.3');
clearvars data
end
end
end
methods (Access = public)
function [sensitivity] = sensitivityAnalysis(obj, msdata,...
itemMask, type)
% select type of sensitivity and prepare data file
if strcmp(type, 'NLL') % sensitivity of nonnegative log-likelihood
prepareDataFile(obj, msdata, itemMask, true)
else if strcmp(type, 'perClass') % sensitivity per class
prepareDataFile(obj, msdata, itemMask, false)
else % sensitivity of binary crossentropy
type = 'binary';
prepareDataFile(obj, msdata, itemMask, false)
end
end
% The following lines call a python program which performs the
% network's sensitivity analysis.
pythonPrefix = '';
pyCommand = strcat(pythonPrefix,...
{'python -B '},...
which('NNInterface.py'),...
{' -neuralNetFile '}, ...
obj.neuralNetFile,...
{' -dataFile '}, ...
obj.dataFile,...
{' -resultFile '}, ...
obj.resultFile,...
{' -mode '}, {'sensitivity'}, ...
{' -sensType '}, {type}); % Compute sensitivity for items in itemMask
% call python code from console
system(pyCommand{1});
% Load result file with sensitivity computations
sensitivity = squeeze(h5read(obj.resultFile,'/sensitivity'));
% rearange dimension
if strcmp(type,'perClass')
sensitivity = permute(sensitivity, [3,2,1]);
end
% scale sensitivity by property inputDimScaling
if length(size(sensitivity)) > 2 % if sensitivity per class
for i = 1: size(sensitivity, 1)
sensitivity(i, :, :) = bsxfun(@times,squeeze(sensitivity(i, :, :)), obj.inputDimScaling);
end
else
sensitivity = bsxfun(@times,sensitivity, obj.inputDimScaling);
end
% delete temporary result file if desired
if ~obj.keepTemporaryFiles
delete(obj.dataFile,...
obj.resultFile)
end
end
end
end
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment