36 #ifndef VIGRA_RANDOM_FOREST_DEPREC_HXX
37 #define VIGRA_RANDOM_FOREST_DEPREC_HXX
45 #include "vigra/mathutil.hxx"
46 #include "vigra/array_vector.hxx"
47 #include "vigra/sized_int.hxx"
48 #include "vigra/matrix.hxx"
49 #include "vigra/random.hxx"
50 #include "vigra/functorexpression.hxx"
// Comparison functor over row indices: orders two rows by their entries in
// one column of the referenced data matrix (see the return statement below).
63 template<
class DataMatrix>
64 class RandomForestDeprecFeatureSorter
66 DataMatrix
const & data_;
// Constructor: takes the data matrix by const reference and records the
// initial column to compare on.
71 RandomForestDeprecFeatureSorter(DataMatrix
const & data,
MultiArrayIndex sortColumn)
73 sortColumn_(sortColumn)
// Switches the comparison to a different column.
78 sortColumn_ = sortColumn;
// True when row l has a smaller value than row r in the selected column.
83 return data_(l, sortColumn_) < data_(r, sortColumn_);
// Comparison functor over indices: orders two indices by the labels stored
// at those positions in the referenced label array.
87 template<
class LabelArray>
88 class RandomForestDeprecLabelSorter
90 LabelArray
const & labels_;
94 RandomForestDeprecLabelSorter(LabelArray
const & labels)
// True when the label at index l is smaller than the label at index r.
100 return labels_[l] < labels_[r];
// Functor that tallies class occurrences: for an example index l it
// increments the counter of the class given by labels_[l].
104 template <
class CountArray>
105 class RandomForestDeprecClassCounter
107 ArrayVector<int>
const & labels_;
// Held by non-const reference, so increments land in the caller's array.
108 CountArray & counts_;
112 RandomForestDeprecClassCounter(ArrayVector<int>
const & labels, CountArray & counts)
// Bump the count of the class that example l belongs to.
126 ++counts_[labels_[l]];
// Binary functor taking a running value and one further value, both double.
// findBestSplit() passes it to std::accumulate over the class counts and
// compares the result with 1.0 to decide whether a side is pure.
// NOTE(review): the body is not part of this listing; presumably it adds one
// to `old` for every non-zero `other` -- confirm against the full source.
130 struct DecisionTreeDeprecCountNonzeroFunctor
132 double operator()(
double old,
double other)
const
// Node record with a threshold index and a split column.
// NOTE(review): the constructor signature is not visible in this listing; the
// initializer list below shows that it receives `t` and `bestColumn`.
140 struct DecisionTreeDeprecNode
143 : thresholdIndex(t), splitColumn(bestColumn)
// Lightweight view onto one node stored inside a flat ArrayVector<INT> tree.
// NOTE(review): the template header and the accessor bodies are missing from
// this listing, so the exact node layout cannot be confirmed from here.
152 struct DecisionTreeDeprecNodeProxy
// Points `node` at element n of `tree`. The const_cast drops the constness of
// the array so that the accessors below can hand out mutable references.
154 DecisionTreeDeprecNodeProxy(ArrayVector<INT>
const & tree, INT n)
155 : node(const_cast<ArrayVector<INT> &>(tree).begin()+n)
// Mutable reference to the entry for child l of this node.
158 INT & child(INT l)
const
// Mutable reference to this node's index into the weights array (callers add
// it to terminalWeights_.begin()).
163 INT & decisionWeightsIndex()
const
// Iterator to this node's decision column entry.
168 typename ArrayVector<INT>::iterator decisionColumns()
const
// Declared mutable so that the const accessors above can use it.
173 mutable typename ArrayVector<INT>::iterator node;
// Split functor for axis-aligned threshold tests: decideAtNode() compares one
// feature column of a sample against a stored threshold. The struct also holds
// the working state of the split search. In the two-element arrays, index 0
// and index 1 denote the two sides of a split.
176 struct DecisionTreeDeprecAxisSplitFunctor
// Candidate feature columns; resized to the number of columns in init().
178 ArrayVector<Int32> splitColumns;
179 ArrayVector<double> classCounts, currentCounts[2], bestCounts[2], classWeights;
181 double totalCounts[2], bestTotalCounts[2];
182 int mtry, classCount, bestSplitColumn;
183 bool pure[2], isWeighted;
// Sizes the column list for `cols` columns and every per-class array for
// `classCount` classes. Class weights count as present when `weights` is
// non-empty.
185 void init(
int mtry,
int cols,
int classCount, ArrayVector<double>
const & weights)
188 splitColumns.resize(cols);
189 for(
int k=0; k<cols; ++k)
192 this->classCount = classCount;
193 classCounts.resize(classCount);
194 currentCounts[0].resize(classCount);
195 currentCounts[1].resize(classCount);
196 bestCounts[0].resize(classCount);
197 bestCounts[1].resize(classCount);
199 isWeighted = weights.size() > 0;
// NOTE(review): the next two statements are presumably the two arms of a
// branch on isWeighted (copy the given weights, otherwise fill with 1.0); the
// branch keywords are missing from this listing -- confirm.
201 classWeights = weights;
203 classWeights.resize(classCount, 1.0);
// Whether side k of the best split is pure (body not visible here).
206 bool isPure(
int k)
const
// Total (weighted) count on side k of the best split, truncated to an
// unsigned integer by the cast.
211 unsigned int totalCount(
int k)
const
213 return (
unsigned int)bestTotalCounts[k];
// Number of tree-array entries per interior node.
216 int sizeofNode()
const {
return 4; }
// Appends the current threshold to `terminalWeights` and the node's weight
// index and split column to `tree`. Returns the size `tree` had when the node
// index was taken, i.e. the position of the new node.
218 int writeSplitParameters(ArrayVector<Int32> & tree,
219 ArrayVector<double> &terminalWeights)
221 int currentWeightIndex = terminalWeights.size();
222 terminalWeights.push_back(threshold);
224 int currentNodeIndex = tree.size();
227 tree.push_back(currentWeightIndex);
228 tree.push_back(bestSplitColumn);
230 return currentNodeIndex;
// Appends one value per class for side l of the best split. The visible arm
// (isWeighted false) stores the class count divided by totalCount(l); the
// other arm is missing from this listing.
233 void writeWeights(
int l, ArrayVector<double> &terminalWeights)
235 for(
int k=0; k<classCount; ++k)
236 terminalWeights.push_back(isWeighted
238 : bestCounts[l][k] / totalCount(l));
// Node test: true when the value of row 0 of `features` in column *a is
// strictly below the threshold *w.
241 template <
class U,
class C,
class AxesIterator,
class WeightIterator>
242 bool decideAtNode(MultiArrayView<2, U, C>
const & features,
243 AxesIterator a, WeightIterator w)
const
245 return (features(0, *a) < *w);
// Declaration of the split search; the definition follows the struct.
248 template <
class U,
class C,
class IndexIterator,
class Random>
249 IndexIterator findBestSplit(MultiArrayView<2, U, C>
const & features,
250 ArrayVector<int>
const & labels,
251 IndexIterator indices,
int exampleCount,
// Searches the best threshold split of the `exampleCount` examples referenced
// by `indices`, trying `mtry` randomly chosen feature columns and scoring each
// candidate position with a Gini-style impurity. On return the index range is
// sorted by the winning column, and `threshold`, `bestSplitColumn`,
// `bestCounts`, `bestTotalCounts` and `pure` describe the chosen split.
// NOTE(review): several statements (return, totals bookkeeping, comparison
// against minGini) are missing from this listing.
257 template <
class U,
class C,
class IndexIterator,
class Random>
259 DecisionTreeDeprecAxisSplitFunctor::findBestSplit(MultiArrayView<2, U, C>
const & features,
260 ArrayVector<int>
const & labels,
261 IndexIterator indices,
int exampleCount,
// Move mtry randomly selected columns to the front of splitColumns by
// swapping position k with a random position at or after k (assumes
// randint(n) yields a value in [0, n) -- TODO confirm).
265 for(
int k=0; k<mtry; ++k)
266 std::swap(splitColumns[k], splitColumns[k+randint(
columnCount(features)-k)]);
268 RandomForestDeprecFeatureSorter<MultiArrayView<2, U, C> > sorter(features, 0);
// Tally the class membership of all examples in this node into classCounts.
269 RandomForestDeprecClassCounter<ArrayVector<double> > counter(labels, classCounts);
270 std::for_each(indices, indices+exampleCount, counter);
273 double minGini = NumericTraits<double>::max();
274 IndexIterator bestSplit = indices;
275 for(
int k=0; k<mtry; ++k)
// Order the examples by the candidate column.
277 sorter.setColumn(splitColumns[k]);
278 std::sort(indices, indices+exampleCount, sorter);
// Start with every example on side 1: side 0 counts are zero, side 1 counts
// are the class counts scaled by the class weights.
280 currentCounts[0].init(0);
281 std::transform(classCounts.begin(), classCounts.end(), classWeights.begin(),
282 currentCounts[1].begin(), std::multiplies<double>());
284 totalCounts[1] = std::accumulate(currentCounts[1].begin(), currentCounts[1].end(), 0.0);
// Sweep over the sorted examples, moving example m from side 1 to side 0.
285 for(
int m = 0; m < exampleCount-1; ++m)
287 int label = labels[indices[m]];
288 double w = classWeights[label];
289 currentCounts[0][label] += w;
291 currentCounts[1][label] -= w;
// Condition: the next example has the same feature value, so no threshold
// separates the two. NOTE(review): the statement controlled by this `if` is
// not visible; presumably the position is skipped -- confirm.
294 if (m < exampleCount-2 &&
295 features(indices[m], splitColumns[k]) == features(indices[m+1], splitColumns[k]))
// Impurity written out for classes 0 and 1 only (the guarding condition is
// not visible in this listing).
301 gini = currentCounts[0][0]*currentCounts[0][1] / totalCounts[0] +
302 currentCounts[1][0]*currentCounts[1][1] / totalCounts[1];
// Impurity summed over all classes.
306 for(
int l=0; l<classCount; ++l)
307 gini += currentCounts[0][l]*(1.0 - currentCounts[0][l] / totalCounts[0]) +
308 currentCounts[1][l]*(1.0 - currentCounts[1][l] / totalCounts[1]);
// Remember this split position, its column and its per-side class counts.
313 bestSplit = indices+m;
314 bestSplitColumn = splitColumns[k];
315 bestCounts[0] = currentCounts[0];
316 bestCounts[1] = currentCounts[1];
// Restore the ordering of the winning column, since later candidate columns
// re-sorted the index range.
325 sorter.setColumn(bestSplitColumn);
326 std::sort(indices, indices+exampleCount, sorter);
328 for(
int k=0; k<2; ++k)
330 bestTotalCounts[k] = std::accumulate(bestCounts[k].begin(), bestCounts[k].end(), 0.0);
// Threshold: midpoint between the feature values of the two examples at
// bestSplit[0] and bestSplit[1].
333 threshold = (features(bestSplit[0], bestSplitColumn) + features(bestSplit[1], bestSplitColumn)) / 2.0;
// Purity of each side: recount the classes on that side and compare the
// accumulated functor result with 1.0.
// NOTE(review): classCounts must be reset before each recount; those
// statements are not visible in this listing.
337 std::for_each(indices, bestSplit, counter);
338 pure[0] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeDeprecCountNonzeroFunctor());
340 std::for_each(bestSplit, indices+exampleCount, counter);
341 pure[1] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeDeprecCountNonzeroFunctor());
// Sentinel meaning "no parent node". It is the default for the parent
// arguments of DecisionTreeDeprecStackEntry and is tested in
// DecisionTreeDeprec::learn() before a node is linked to its parent.
346 enum { DecisionTreeDeprecNoParent = -1 };
// Work item for the explicit stack used while growing a tree: a range of
// example indices (start iterator plus count) and the parent node indices the
// resulting node has to be attached to.
348 template <
class Iterator>
349 struct DecisionTreeDeprecStackEntry
// Both parent arguments default to DecisionTreeDeprecNoParent.
351 DecisionTreeDeprecStackEntry(Iterator i,
int c,
352 int lp = DecisionTreeDeprecNoParent,
int rp = DecisionTreeDeprecNoParent)
353 : indices(i), exampleCount(c),
354 leftParent(lp), rightParent(rp)
358 int exampleCount, leftParent, rightParent;
// A single decision tree stored in two flat arrays: `tree_` holds the interior
// nodes as integers, `terminalWeights_` holds thresholds and leaf weights.
// A child entry that refers to a leaf is stored as the negated offset into
// `terminalWeights_` (see the return statement of predict()).
361 class DecisionTreeDeprec
364 typedef Int32 TreeInt;
365 ArrayVector<TreeInt> tree_;
366 ArrayVector<double> terminalWeights_;
367 unsigned int classCount_;
368 DecisionTreeDeprecAxisSplitFunctor split;
373 DecisionTreeDeprec(
unsigned int classCount)
374 : classCount_(classCount)
// Sets a new class count and clears the stored weights.
377 void reset(
unsigned int classCount = 0)
380 classCount_ = classCount;
382 terminalWeights_.clear();
// Declaration of the training routine; the definition follows the class.
385 template <
class U,
class C,
class Iterator,
class Options,
class Random>
386 void learn(MultiArrayView<2, U, C>
const & features,
387 ArrayVector<int>
const & labels,
388 Iterator indices,
int exampleCount,
389 Options
const & options,
// Descends the tree for the sample in `features` and returns an iterator to
// the weights of the leaf that was reached.
392 template <
class U,
class C>
393 ArrayVector<double>::const_iterator
394 predict(MultiArrayView<2, U, C>
const & features)
const
399 DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, nodeindex);
400 nodeindex = split.decideAtNode(features, node.decisionColumns(),
401 terminalWeights_.begin() + node.decisionWeightsIndex())
// nodeindex is the negated offset of the leaf weights at this point.
405 return terminalWeights_.begin() + (-nodeindex);
// Index of the class with the largest weight in the leaf reached by predict().
409 template <
class U,
class C>
411 predictLabel(MultiArrayView<2, U, C>
const & features)
const
413 ArrayVector<double>::const_iterator weights = predict(features);
414 return argMax(weights, weights+classCount_) - weights;
// Same descent as predict(); the returned value is not visible in this
// listing.
417 template <
class U,
class C>
419 leafID(MultiArrayView<2, U, C>
const & features)
const
424 DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, nodeindex);
425 nodeindex = split.decideAtNode(features, node.decisionColumns(),
426 terminalWeights_.begin() + node.decisionWeightsIndex())
// Recursively visits the subtree rooted at node k (current depth d) and
// updates the three output counters passed by reference.
434 void depth(
int & maxDep,
int & interiorCount,
int & leafCount,
int k = 0,
int d = 1)
const
436 DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, k);
439 for(
int l=0; l<2; ++l)
441 int child = node.child(l);
443 depth(maxDep, interiorCount, leafCount, child, d);
// Writes a one-line summary (interior nodes, terminal nodes, depth) to `o`.
453 void printStatistics(std::ostream & o)
const
455 int maxDep = 0, interiorCount = 0, leafCount = 0;
456 depth(maxDep, interiorCount, leafCount);
458 o <<
"interior nodes: " << interiorCount <<
459 ", terminal nodes: " << leafCount <<
460 ", depth: " << maxDep <<
"\n";
// Recursively dumps the subtree rooted at node k to `o`, using `s` as the
// line prefix: first the node's split column and threshold, then its children.
463 void print(std::ostream & o,
int k = 0, std::string s =
"")
const
465 DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, k);
466 o << s << (*node.decisionColumns()) <<
" " << terminalWeights_[node.decisionWeightsIndex()] <<
"\n";
468 for(
int l=0; l<2; ++l)
470 int child = node.child(l);
// Leaf child: only the first two stored weights are printed, whatever the
// class count is.
472 o << s <<
" weights " << terminalWeights_[-child] <<
" "
473 << terminalWeights_[-child+1] <<
"\n";
475 print(o, child, s+
" ");
// Grows this tree from the examples referenced by `indices`.
//
// features      feature matrix, one example per row
// labels        integer class label of every example
// indices       start of the range of example indices to train on
// exampleCount  number of entries in that range
// options       supplies class_weights, mtry and min_split_node_size
//
// The tree is built iteratively with an explicit stack of index ranges rather
// than by recursion. Fails the precondition when class weights are given but
// their number differs from the class count.
481 template <
class U,
class C,
class Iterator,
class Options,
class Random>
482 void DecisionTreeDeprec::learn(MultiArrayView<2, U, C>
const & features,
483 ArrayVector<int>
const & labels,
484 Iterator indices,
int exampleCount,
485 Options
const & options,
488 ArrayVector<double>
const & classLoss = options.class_weights;
// An empty weight array means "unweighted"; otherwise one weight per class.
490 vigra_precondition(classLoss.size() == 0 || classLoss.size() == classCount_,
491 "DecisionTreeDeprec::learn(): class weights array has wrong size.");
495 unsigned int mtry = options.mtry;
498 split.init(mtry, cols, classCount_, classLoss);
500 typedef DecisionTreeDeprecStackEntry<Iterator> Entry;
501 ArrayVector<Entry> stack;
// The root covers all examples and has no parent.
502 stack.push_back(Entry(indices, exampleCount));
504 while(!stack.empty())
// Take the next pending index range and the parent slots it belongs to.
507 indices = stack.back().indices;
508 exampleCount = stack.back().exampleCount;
509 int leftParent = stack.back().leftParent,
510 rightParent = stack.back().rightParent;
514 Iterator bestSplit = split.findBestSplit(features, labels, indices, exampleCount, randint);
517 int currentNode = split.writeSplitParameters(tree_, terminalWeights_);
// Link the freshly written node into whichever parent slot requested it.
519 if(leftParent != DecisionTreeDeprecNoParent)
520 DecisionTreeDeprecNodeProxy<TreeInt>(tree_, leftParent).child(0) = currentNode;
521 if(rightParent != DecisionTreeDeprecNoParent)
522 DecisionTreeDeprecNodeProxy<TreeInt>(tree_, rightParent).child(1) = currentNode;
523 leftParent = currentNode;
524 rightParent = DecisionTreeDeprecNoParent;
526 for(
int l=0; l<2; ++l)
// A side that is impure and large enough is split further later on ...
529 if(!split.isPure(l) && split.totalCount(l) >= options.min_split_node_size)
532 stack.push_back(Entry(indices, split.totalCount(l), leftParent, rightParent));
// ... otherwise it becomes a leaf: the child entry stores the negated offset
// at which the leaf's weights are appended.
536 DecisionTreeDeprecNodeProxy<TreeInt>(tree_, currentNode).child(l) = -(TreeInt)terminalWeights_.size();
538 split.writeWeights(l, terminalWeights_);
// After the swap the second pass refers to the current node as right parent.
540 std::swap(leftParent, rightParent);
// Option set for RandomForestDeprec. The setters return *this by reference
// (see their return types), so calls can be chained.
549 class RandomForestOptionsDeprec
// Defaults visible in this listing: proportion 1.0, minimum split node size 1,
// absolute training set size 0, sampling with replacement on, per-class
// sampling off.
554 RandomForestOptionsDeprec()
555 : training_set_proportion(1.0),
557 min_split_node_size(1),
558 training_set_size(0),
559 sample_with_replacement(true),
560 sample_classes_individually(false),
572 RandomForestOptionsDeprec & featuresPerNode(
unsigned int n)
585 RandomForestOptionsDeprec & sampleWithReplacement(
bool r)
587 sample_with_replacement = r;
591 RandomForestOptionsDeprec & setTreeCount(
unsigned int cnt)
// Requires p in [0, 1]. The proportion is stored only while no absolute
// training set size has been set.
607 RandomForestOptionsDeprec & trainingSetSizeProportional(
double p)
609 vigra_precondition(p >= 0.0 && p <= 1.0,
610 "RandomForestOptionsDeprec::trainingSetSizeProportional(): proportion must be in [0, 1].");
611 if(training_set_size == 0)
612 training_set_proportion = p;
// Stores the absolute size and resets the proportion to 0.
624 RandomForestOptionsDeprec & trainingSetSizeAbsolute(
unsigned int s)
626 training_set_size = s;
628 training_set_proportion = 0.0;
642 RandomForestOptionsDeprec & sampleClassesIndividually(
bool s)
644 sample_classes_individually = s;
656 RandomForestOptionsDeprec & minSplitNodeSize(
unsigned int n)
660 min_split_node_size = n;
// Replaces the stored class weights.
// NOTE(review): insert() is called with (weights, classCount); verify that
// ArrayVector provides an overload with that meaning, i.e. that this copies
// classCount values starting at `weights`.
671 template <
class WeightIterator>
672 RandomForestOptionsDeprec & weights(WeightIterator weights,
unsigned int classCount)
674 class_weights.clear();
676 class_weights.insert(weights, classCount);
680 RandomForestOptionsDeprec & oobData(MultiArrayView<2, UInt8>& data)
// Written by RandomForestDeprec::learn() when its data pointer is non-null.
686 MultiArrayView<2, UInt8> oob_data;
687 ArrayVector<double> class_weights;
688 double training_set_proportion;
689 unsigned int mtry, min_split_node_size, training_set_size;
690 bool sample_with_replacement, sample_classes_individually;
691 unsigned int treeCount;
// Random forest classifier: a list of class labels, one DecisionTreeDeprec
// per tree and the options used for training.
700 template <
class ClassLabelType>
701 class RandomForestDeprec
704 ArrayVector<ClassLabelType> classes_;
705 ArrayVector<detail::DecisionTreeDeprec> trees_;
707 RandomForestOptionsDeprec options_;
// Builds an untrained forest for the class labels in [cl, cend) with
// `treeCount` trees (default 255). Preconditions: absolute and proportional
// training set size are not both set, there are at least two classes, and
// class weights, if given, match the number of classes.
713 template<
class ClassLabelIterator>
714 RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
715 unsigned int treeCount = 255,
716 RandomForestOptionsDeprec
const & options = RandomForestOptionsDeprec())
717 : classes_(cl, cend),
718 trees_(treeCount, detail::DecisionTreeDeprec(classes_.size())),
722 vigra_precondition(options.training_set_proportion == 0.0 ||
723 options.training_set_size == 0,
724 "RandomForestOptionsDeprec: absolute and proportional training set sizes "
725 "cannot be specified at the same time.");
726 vigra_precondition(classes_.size() > 1,
727 "RandomForestOptionsDeprec::weights(): need at least two classes.");
728 vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
729 "RandomForestOptionsDeprec::weights(): wrong number of classes.");
// Two-class convenience constructor for the labels c1 and c2.
732 RandomForestDeprec(ClassLabelType
const & c1, ClassLabelType
const & c2,
733 unsigned int treeCount = 255,
734 RandomForestOptionsDeprec
const & options = RandomForestOptionsDeprec())
736 trees_(treeCount, detail::DecisionTreeDeprec(2)),
740 vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == 2,
741 "RandomForestOptionsDeprec::weights(): wrong number of classes.");
// Like the first constructor, but the number of trees is taken from
// options.treeCount.
746 template<
class ClassLabelIterator>
747 RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
748 RandomForestOptionsDeprec
const & options )
749 : classes_(cl, cend),
750 trees_(options.treeCount , detail::DecisionTreeDeprec(classes_.size())),
755 vigra_precondition(options.training_set_proportion == 0.0 ||
756 options.training_set_size == 0,
757 "RandomForestOptionsDeprec: absolute and proportional training set sizes "
758 "cannot be specified at the same time.");
759 vigra_precondition(classes_.size() > 1,
760 "RandomForestOptionsDeprec::weights(): need at least two classes.");
761 vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
762 "RandomForestOptionsDeprec::weights(): wrong number of classes.");
// Rebuilds a forest from existing data: copies one tree array and one weight
// array per tree from the `trees` and `weights` iterators.
767 template<
class ClassLabelIterator,
class TreeIterator,
class WeightIterator>
768 RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
770 TreeIterator trees, WeightIterator weights)
771 : classes_(cl, cend),
772 trees_(treeCount, detail::DecisionTreeDeprec(classes_.size())),
773 columnCount_(columnCount)
775 for(
unsigned int k=0; k<treeCount; ++k, ++trees, ++weights)
777 trees_[k].tree_ = *trees;
778 trees_[k].terminalWeights_ = *weights;
// Fails the precondition while columnCount_ is not positive, which the message
// reports as "not trained yet".
782 int featureCount()
const
784 vigra_precondition(columnCount_ > 0,
785 "RandomForestDeprec::featureCount(): Random forest has not been trained yet.");
// Number of class labels.
789 int labelCount()
const
791 return classes_.size();
// Number of trees.
794 int treeCount()
const
796 return trees_.size();
// Training with a caller-supplied random number generator; defined after the
// class.
800 template <
class U,
class C,
class Array,
class Random>
801 double learn(MultiArrayView<2, U, C>
const & features, Array
const & labels,
802 Random
const& random);
// Training with a generator constructed from RandomSeed; forwards to the
// overload above and returns its result.
804 template <
class U,
class C,
class Array>
805 double learn(MultiArrayView<2, U, C>
const & features, Array
const & labels)
807 RandomNumberGenerator<> generator(RandomSeed);
808 return learn(features, labels, generator);
811 template <
class U,
class C>
812 ClassLabelType predictLabel(MultiArrayView<2, U, C>
const & features)
const;
// Predicts a label for every row of `features` and stores it in column 0 of
// `labels`. Both arrays must have the same number of rows.
814 template <
class U,
class C1,
class T,
class C2>
815 void predictLabels(MultiArrayView<2, U, C1>
const & features,
816 MultiArrayView<2, T, C2> & labels)
const
818 vigra_precondition(features.shape(0) == labels.shape(0),
819 "RandomForestDeprec::predictLabels(): Label array has wrong size.");
820 for(
int k=0; k<features.shape(0); ++k)
821 labels(k,0) = predictLabel(
rowVector(features, k));
824 template <
class U,
class C,
class Iterator>
825 ClassLabelType predictLabel(MultiArrayView<2, U, C>
const & features,
826 Iterator priors)
const;
828 template <
class U,
class C1,
class T,
class C2>
829 void predictProbabilities(MultiArrayView<2, U, C1>
const & features,
830 MultiArrayView<2, T, C2> & prob)
const;
832 template <
class U,
class C1,
class T,
class C2>
833 void predictNodes(MultiArrayView<2, U, C1>
const & features,
834 MultiArrayView<2, T, C2> & NodeIDs)
const;
// Trains every tree of the forest on a random sample of the rows of `features`
// and returns an out-of-bag error estimate (oobError / totalOobCount).
//
// features  feature matrix, one example per row
// labels    one class label per row; every label must be one of classes_
// random    random number generator used for all sampling
//
// NOTE(review): a number of statements and conditions are missing from this
// listing; comments below describe only what is visible.
837 template <
class ClassLabelType>
838 template <
class U,
class C1,
class Array,
class Random>
840 RandomForestDeprec<ClassLabelType>::learn(MultiArrayView<2, U, C1>
const & features,
841 Array
const & labels,
842 Random
const& random)
844 unsigned int classCount = classes_.size();
845 unsigned int m =
rowCount(features);
847 vigra_precondition((
unsigned int)(m) == (
unsigned int)labels.size(),
848 "RandomForestDeprec::learn(): Label array has wrong size.");
// Without replacement one cannot draw more examples than exist.
850 vigra_precondition(options_.training_set_size <= m || options_.sample_with_replacement,
851 "RandomForestDeprec::learn(): Requested training set size exceeds total number of examples.");
858 "RandomForestDeprec::learn(): mtry must be less than number of features.");
// With per-class sampling the requested size is divided among the classes,
// rounded up.
861 if(options_.sample_classes_individually)
862 msamples = int(
std::ceil(
double(msamples) / classCount));
864 ArrayVector<int> intLabels(m), classExampleCounts(classCount);
// Map each class label to its position in classes_.
869 typedef std::map<ClassLabelType, int > LabelChecker;
870 typedef typename LabelChecker::iterator LabelCheckerIterator;
871 LabelChecker labelChecker;
872 for(
unsigned int k=0; k<classCount; ++k)
873 labelChecker[classes_[k]] = k;
// Translate every training label to its class index, rejecting labels that
// are not in classes_, and count the examples per class.
875 for(
unsigned int k=0; k<m; ++k)
877 LabelCheckerIterator found = labelChecker.find(labels[k]);
878 vigra_precondition(found != labelChecker.end(),
879 "RandomForestDeprec::learn(): Unknown class label encountered.");
880 intLabels[k] = found->second;
881 ++classExampleCounts[intLabels[k]];
// Every class must occur at least once.
883 minClassCount = *
argMin(classExampleCounts.begin(), classExampleCounts.end());
884 vigra_precondition(minClassCount > 0,
885 "RandomForestDeprec::learn(): At least one class is missing in the training set.");
886 if(msamples > 0 && options_.sample_classes_individually &&
887 !options_.sample_with_replacement)
889 vigra_precondition(msamples <= minClassCount,
890 "RandomForestDeprec::learn(): Too few examples in smallest class to reach "
891 "requested training set size.");
895 ArrayVector<int> indices(m);
896 for(
unsigned int k=0; k<m; ++k)
// For per-class sampling, order the example indices by class index.
899 if(options_.sample_classes_individually)
901 detail::RandomForestDeprecLabelSorter<ArrayVector<int> > sorter(intLabels);
902 std::sort(indices.begin(), indices.end(), sorter);
905 ArrayVector<int> usedIndices(m), oobCount(m), oobErrorCount(m);
907 UniformIntRandomFunctor<Random> randint(0, m-1, random);
// One pass per tree: draw a training set, train, then evaluate.
909 for(
unsigned int k=0; k<trees_.size(); ++k)
913 ArrayVector<int> trainingSet;
916 if(options_.sample_classes_individually)
919 for(
unsigned int l=0; l<classCount; ++l)
921 int lc = classExampleCounts[l];
// With no absolute size, the per-class sample size is the proportion of this
// class's example count, rounded up (the other arm is not visible).
922 int lsamples = (msamples == 0)
923 ?
int(
std::ceil(options_.training_set_proportion*lc))
926 if(options_.sample_with_replacement)
// With replacement: independent draws among this class's lc examples.
928 for(
int ll=0; ll<lsamples; ++ll)
930 trainingSet.push_back(indices[first+randint(lc)]);
931 ++usedIndices[trainingSet.back()];
// Without replacement: swap a random not-yet-chosen example into position ll.
936 for(
int ll=0; ll<lsamples; ++ll)
938 std::swap(indices[first+ll], indices[first+ll+randint(lc-ll)]);
939 trainingSet.push_back(indices[first+ll]);
940 ++usedIndices[trainingSet.back()];
// Sampling from all examples at once; sample size from the proportion of m.
950 msamples = int(
std::ceil(options_.training_set_proportion*m));
952 if(options_.sample_with_replacement)
954 for(
int l=0; l<msamples; ++l)
956 trainingSet.push_back(indices[randint(m)]);
957 ++usedIndices[trainingSet.back()];
962 for(
int l=0; l<msamples; ++l)
964 std::swap(indices[l], indices[l+randint(m-l)]);
965 trainingSet.push_back(indices[l]);
966 ++usedIndices[trainingSet.back()];
973 trees_[k].learn(features, intLabels,
974 trainingSet.begin(), trainingSet.size(),
975 options_.featuresPerNode(mtry), randint);
// Evaluate tree k on the examples. When an oob_data array was supplied, entry
// (l, k) is set to 2 for a wrong prediction and to 1 in the other visible
// branch. NOTE(review): the test restricting this to examples the tree was not
// trained on is not visible in this listing.
986 for(
unsigned int l=0; l<m; ++l)
991 if(trees_[k].predictLabel(
rowVector(features, l)) != intLabels[l])
994 if(options_.oob_data.data() != 0)
995 options_.oob_data(l, k) = 2;
997 else if(options_.oob_data.data() != 0)
999 options_.oob_data(l, k) = 1;
1008 #ifdef VIGRA_RF_VERBOSE
1009 trees_[k].printStatistics(std::cerr);
// Sum the per-example error rates (errors divided by evaluations) and divide
// by totalOobCount.
1012 double oobError = 0.0;
1013 int totalOobCount = 0;
1014 for(
unsigned int l=0; l<m; ++l)
1017 oobError += double(oobErrorCount[l]) / oobCount[l];
1020 return oobError / totalOobCount;
// Predicts the class label of a single sample: computes the class
// probabilities and returns the label with the largest one.
// Preconditions: `features` has exactly one row and at least featureCount()
// columns.
1023 template <
class ClassLabelType>
1024 template <
class U,
class C>
1026 RandomForestDeprec<ClassLabelType>::predictLabel(MultiArrayView<2, U, C>
const & features)
const
1028 vigra_precondition(
columnCount(features) >= featureCount(),
1029 "RandomForestDeprec::predictLabel(): Too few columns in feature matrix.");
1030 vigra_precondition(
rowCount(features) == 1,
1031 "RandomForestDeprec::predictLabel(): Feature matrix must have a single row.");
// One row with one probability per class.
1032 Matrix<double> prob(1, classes_.size());
1033 predictProbabilities(features, prob);
1034 return classes_[
argMax(prob)];
// Predicts the class label of a single sample using class priors: each class
// probability is multiplied by the corresponding value read from `priors`
// before the largest one is selected.
// Preconditions: `features` has exactly one row and at least featureCount()
// columns. `priors` is read once per class probability.
1039 template <
class ClassLabelType>
1040 template <
class U,
class C,
class Iterator>
1042 RandomForestDeprec<ClassLabelType>::predictLabel(MultiArrayView<2, U, C>
const & features,
1043 Iterator priors)
const
// Brings Arg1() and Arg2() into scope for the element-wise product below.
1045 using namespace functor;
1046 vigra_precondition(
columnCount(features) >= featureCount(),
1047 "RandomForestDeprec::predictLabel(): Too few columns in feature matrix.");
1048 vigra_precondition(
rowCount(features) == 1,
1049 "RandomForestDeprec::predictLabel(): Feature matrix must have a single row.");
1050 Matrix<double> prob(1,classes_.size());
1051 predictProbabilities(features, prob);
// In place: prob[i] = prob[i] * priors[i].
1052 std::transform(prob.begin(), prob.end(), priors, prob.begin(), Arg1()*Arg2());
1053 return classes_[
argMax(prob)];
// Computes class probabilities for every row of `features` and writes them to
// the matching row of `prob`: the leaf weights of all trees are summed per
// class and then divided by the total weight.
// NOTE(review): two precondition calls are only partly visible in this
// listing (their conditions are missing).
1056 template <
class ClassLabelType>
1057 template <
class U,
class C1,
class T,
class C2>
1059 RandomForestDeprec<ClassLabelType>::predictProbabilities(MultiArrayView<2, U, C1>
const & features,
1060 MultiArrayView<2, T, C2> & prob)
const
1067 "RandomForestDeprec::predictProbabilities(): Feature matrix and probability matrix size mismatch.");
1071 vigra_precondition(
columnCount(features) >= featureCount(),
1072 "RandomForestDeprec::predictProbabilities(): Too few columns in feature matrix.");
1074 "RandomForestDeprec::predictProbabilities(): Probability matrix must have as many columns as there are classes.");
1077 for(
int row=0; row <
rowCount(features); ++row)
1082 ArrayVector<double>::const_iterator weights;
1085 double totalWeight = 0.0;
// NOTE(review): the body of this loop is not visible; presumably it clears
// prob(row, l) before accumulation -- confirm.
1089 for(
unsigned int l=0; l<classes_.size(); ++l)
// Accumulate the leaf weights that each tree assigns to this row.
1093 for(
unsigned int k=0; k<trees_.size(); ++k)
1096 weights = trees_[k].predict(
rowVector(features, row));
1099 for(
unsigned int l=0; l<classes_.size(); ++l)
1101 prob(row, l) += detail::RequiresExplicitCast<T>::cast(weights[l]);
1103 totalWeight += weights[l];
// Normalise the row by the accumulated total weight.
1108 for(
unsigned int l=0; l<classes_.size(); ++l)
1109 prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
// For every row of `features` and every tree k, stores the value returned by
// that tree's leafID() in NodeIDs(row, k).
//
// features  feature matrix, one sample per row, at least featureCount()
//           columns
// NodeIDs   output matrix with at least treeCount() columns; a further
//           precondition (condition not visible in this listing) reports too
//           few rows
//
// The precondition messages name this function (predictNodes), in line with
// the other messages of this class.
1114 template <
class ClassLabelType>
1115 template <
class U,
class C1,
class T,
class C2>
1117 RandomForestDeprec<ClassLabelType>::predictNodes(MultiArrayView<2, U, C1>
const & features,
1118 MultiArrayView<2, T, C2> & NodeIDs)
const
1120 vigra_precondition(
columnCount(features) >= featureCount(),
1121 "RandomForestDeprec::predictNodes(): Too few columns in feature matrix.");
1123 "RandomForestDeprec::predictNodes(): Too few rows in NodeIds matrix");
1124 vigra_precondition(
columnCount(NodeIDs) >= treeCount(),
1125 "RandomForestDeprec::predictNodes(): Too few columns in NodeIds matrix.");
// Trees in the outer loop, samples in the inner loop.
1127 for(
unsigned int k=0; k<trees_.size(); ++k)
1129 for(
int row=0; row <
rowCount(features); ++row)
1131 NodeIDs(row,k) = trees_[k].leafID(
rowVector(features, row));
1141 #endif // VIGRA_RANDOM_FOREST_DEPREC_HXX
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:671
std::ptrdiff_t MultiArrayIndex
Definition: multi_fwd.hxx:60
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:697
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:684
Iterator argMin(Iterator first, Iterator last)
Find the minimum element in a sequence.
Definition: algorithm.hxx:68
int ceil(FixedPoint< IntBits, FracBits > v)
rounding up.
Definition: fixedpoint.hxx:675
int floor(FixedPoint< IntBits, FracBits > v)
rounding down.
Definition: fixedpoint.hxx:667
SquareRootTraits< FixedPoint< IntBits, FracBits > >::SquareRootResult sqrt(FixedPoint< IntBits, FracBits > v)
square root.
Definition: fixedpoint.hxx:616