[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

random_forest_deprec.hxx VIGRA

1 /************************************************************************/
2 /* */
3 /* Copyright 2008 by Ullrich Koethe */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 
36 #ifndef VIGRA_RANDOM_FOREST_DEPREC_HXX
37 #define VIGRA_RANDOM_FOREST_DEPREC_HXX
38 
39 #include <algorithm>
40 #include <map>
41 #include <numeric>
42 #include <iostream>
43 #include <ctime>
44 #include <cstdlib>
45 #include "vigra/mathutil.hxx"
46 #include "vigra/array_vector.hxx"
47 #include "vigra/sized_int.hxx"
48 #include "vigra/matrix.hxx"
49 #include "vigra/random.hxx"
50 #include "vigra/functorexpression.hxx"
51 
52 
53 namespace vigra
54 {
55 
56 /** \addtogroup MachineLearning
57 **/
58 //@{
59 
60 namespace detail
61 {
62 
63 template<class DataMatrix>
64 class RandomForestDeprecFeatureSorter
65 {
66  DataMatrix const & data_;
67  MultiArrayIndex sortColumn_;
68 
69  public:
70 
71  RandomForestDeprecFeatureSorter(DataMatrix const & data, MultiArrayIndex sortColumn)
72  : data_(data),
73  sortColumn_(sortColumn)
74  {}
75 
76  void setColumn(MultiArrayIndex sortColumn)
77  {
78  sortColumn_ = sortColumn;
79  }
80 
81  bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
82  {
83  return data_(l, sortColumn_) < data_(r, sortColumn_);
84  }
85 };
86 
87 template<class LabelArray>
88 class RandomForestDeprecLabelSorter
89 {
90  LabelArray const & labels_;
91 
92  public:
93 
94  RandomForestDeprecLabelSorter(LabelArray const & labels)
95  : labels_(labels)
96  {}
97 
98  bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
99  {
100  return labels_[l] < labels_[r];
101  }
102 };
103 
104 template <class CountArray>
105 class RandomForestDeprecClassCounter
106 {
107  ArrayVector<int> const & labels_;
108  CountArray & counts_;
109 
110  public:
111 
112  RandomForestDeprecClassCounter(ArrayVector<int> const & labels, CountArray & counts)
113  : labels_(labels),
114  counts_(counts)
115  {
116  reset();
117  }
118 
119  void reset()
120  {
121  counts_.init(0);
122  }
123 
124  void operator()(MultiArrayIndex l) const
125  {
126  ++counts_[labels_[l]];
127  }
128 };
129 
/** Reduction functor for std::accumulate that counts nonzero entries.

    Given the running count \a old and the next element \a other, returns
    the count incremented by one when \a other is nonzero. Accumulating a
    class histogram with it yields the number of classes present.
*/
struct DecisionTreeDeprecCountNonzeroFunctor
{
    double operator()(double old, double other) const
    {
        return (other != 0.0)
                   ? old + 1.0
                   : old;
    }
};
139 
/** A single interior node of the deprecated decision tree.

    NOTE(review): 'children' is deliberately left uninitialized by the
    constructor — presumably the tree builder fills both links after the
    node exists (the builder code in this file writes child links through
    DecisionTreeDeprecNodeProxy); confirm before relying on it.
    The layout (2 child links, threshold index, split column) appears to
    mirror the flat 4-integer node encoding used by
    DecisionTreeDeprecAxisSplitFunctor::writeSplitParameters().
*/
struct DecisionTreeDeprecNode
{
    // 'bestColumn' is narrowed from MultiArrayIndex to Int32 here.
    DecisionTreeDeprecNode(int t, MultiArrayIndex bestColumn)
    : thresholdIndex(t), splitColumn(bestColumn)
    {}

    int children[2];     // child links; not set by the constructor
    int thresholdIndex;  // index of the split threshold in the weight array
    Int32 splitColumn;   // feature column tested at this node
};
150 
/** Lightweight accessor for one node stored inside the flat tree array.

    A node occupies four consecutive integers starting at offset n:
        node[0], node[1]  -- child links: an interior node index if > 0,
                             otherwise the negated offset of a leaf's
                             weight block in the terminal weight array
        node[2]           -- index of this node's split threshold in the
                             terminal weight array
        node[3]           -- feature column(s) tested at this node
*/
template <class INT>
struct DecisionTreeDeprecNodeProxy
{
    // The const_cast lets a proxy built from a const tree hand out
    // mutable references; the tree builder relies on this to patch
    // child links after the fact.
    DecisionTreeDeprecNodeProxy(ArrayVector<INT> const & tree, INT n)
    : node(const_cast<ArrayVector<INT> &>(tree).begin()+n)
    {}

    /** Mutable access to child link \a l (0 = left, 1 = right). */
    INT & child(INT l) const
    {
        return node[l];
    }

    /** Mutable access to the node's index into the terminal weight array. */
    INT & decisionWeightsIndex() const
    {
        return node[2];
    }

    /** Iterator to the feature column(s) examined at this node. */
    typename ArrayVector<INT>::iterator decisionColumns() const
    {
        return node+3;
    }

    mutable typename ArrayVector<INT>::iterator node;   // start of the node's 4-int record
};
175 
/** Split functor implementing axis-parallel (single-feature, single
    threshold) splits chosen by the Gini impurity criterion.

    Owns all scratch state needed while growing a tree, serializes the
    chosen split into the flat tree/weight arrays, and evaluates a node
    during prediction.
*/
struct DecisionTreeDeprecAxisSplitFunctor
{
    ArrayVector<Int32> splitColumns;    // permutation of all feature indices; first mtry are tried per node
    ArrayVector<double> classCounts, currentCounts[2], bestCounts[2], classWeights;
    double threshold;                   // feature value separating the two sides of the best split
    double totalCounts[2], bestTotalCounts[2];   // weighted example counts per side
    int mtry, classCount, bestSplitColumn;
    bool pure[2], isWeighted;           // pure[k]: side k contains a single class

    /** Allocate scratch arrays for \a cols features and \a classCount
        classes, trying \a mtry random features per node. An empty
        \a weights array selects the unweighted variant (all weights 1).
    */
    void init(int mtry, int cols, int classCount, ArrayVector<double> const & weights)
    {
        this->mtry = mtry;
        splitColumns.resize(cols);
        for(int k=0; k<cols; ++k)
            splitColumns[k] = k;

        this->classCount = classCount;
        classCounts.resize(classCount);
        currentCounts[0].resize(classCount);
        currentCounts[1].resize(classCount);
        bestCounts[0].resize(classCount);
        bestCounts[1].resize(classCount);

        isWeighted = weights.size() > 0;
        if(isWeighted)
            classWeights = weights;
        else
            classWeights.resize(classCount, 1.0);   // uniform weights
    }

    /** Did side \a k (0 = left, 1 = right) of the last split contain
        only one class?
    */
    bool isPure(int k) const
    {
        return pure[k];
    }

    /** Weighted example count of side \a k of the last split, truncated
        to an unsigned integer.
    */
    unsigned int totalCount(int k) const
    {
        return (unsigned int)bestTotalCounts[k];
    }

    /** Number of integers one serialized node occupies in the tree array. */
    int sizeofNode() const { return 4; }

    /** Append the current best split as a new node: the threshold goes to
        \a terminalWeights, the 4-int node record (two unresolved child
        links, threshold index, split column) goes to \a tree.
        Returns the new node's index into \a tree.
    */
    int writeSplitParameters(ArrayVector<Int32> & tree,
                             ArrayVector<double> &terminalWeights)
    {
        int currentWeightIndex = terminalWeights.size();
        terminalWeights.push_back(threshold);

        int currentNodeIndex = tree.size();
        tree.push_back(-1);  // left child
        tree.push_back(-1);  // right child
        tree.push_back(currentWeightIndex);
        tree.push_back(bestSplitColumn);

        return currentNodeIndex;
    }

    /** Append the class histogram of side \a l as a leaf weight block:
        raw weighted counts when class weights are in use, otherwise
        counts normalized to class frequencies.
    */
    void writeWeights(int l, ArrayVector<double> &terminalWeights)
    {
        for(int k=0; k<classCount; ++k)
            terminalWeights.push_back(isWeighted
                                         ? bestCounts[l][k]
                                         : bestCounts[l][k] / totalCount(l));
    }

    /** Evaluate a serialized node for the (single-row) feature matrix:
        true => descend to the left child, false => right.
        \a a points at the node's split column, \a w at its threshold.
    */
    template <class U, class C, class AxesIterator, class WeightIterator>
    bool decideAtNode(MultiArrayView<2, U, C> const & features,
                      AxesIterator a, WeightIterator w) const
    {
        return (features(0, *a) < *w);
    }

    /** Find the best axis-parallel split of the examples referenced by
        [indices, indices+exampleCount); defined out-of-line below.
        Returns an iterator to the first example of the right side.
    */
    template <class U, class C, class IndexIterator, class Random>
    IndexIterator findBestSplit(MultiArrayView<2, U, C> const & features,
                                ArrayVector<int> const & labels,
                                IndexIterator indices, int exampleCount,
                                Random & randint);

};
255 
256 
/** Exhaustively scan mtry randomly chosen feature columns for the split
    position minimizing the Gini impurity, reorder the example indices by
    the winning feature, and record threshold / histograms / purity flags
    in the functor's state. Returns an iterator to the first example of
    the right-hand side.
*/
template <class U, class C, class IndexIterator, class Random>
IndexIterator
DecisionTreeDeprecAxisSplitFunctor::findBestSplit(MultiArrayView<2, U, C> const & features,
                                                  ArrayVector<int> const & labels,
                                                  IndexIterator indices, int exampleCount,
                                                  Random & randint)
{
    // select columns to be tried for split
    // (partial Fisher-Yates shuffle: the first mtry entries of
    // splitColumns become a random sample of all columns)
    for(int k=0; k<mtry; ++k)
        std::swap(splitColumns[k], splitColumns[k+randint(columnCount(features)-k)]);

    RandomForestDeprecFeatureSorter<MultiArrayView<2, U, C> > sorter(features, 0);
    RandomForestDeprecClassCounter<ArrayVector<double> > counter(labels, classCounts);
    // class histogram of all examples reaching this node
    std::for_each(indices, indices+exampleCount, counter);

    // find the best gini index
    double minGini = NumericTraits<double>::max();
    IndexIterator bestSplit = indices;
    for(int k=0; k<mtry; ++k)
    {
        // order the examples by the candidate feature
        sorter.setColumn(splitColumns[k]);
        std::sort(indices, indices+exampleCount, sorter);

        // currentCounts[0]/[1]: weighted class histograms of the examples
        // left/right of the candidate split position; initially everything
        // is on the right.
        currentCounts[0].init(0);
        std::transform(classCounts.begin(), classCounts.end(), classWeights.begin(),
                       currentCounts[1].begin(), std::multiplies<double>());
        totalCounts[0] = 0;
        totalCounts[1] = std::accumulate(currentCounts[1].begin(), currentCounts[1].end(), 0.0);
        for(int m = 0; m < exampleCount-1; ++m)
        {
            // move example m from the right to the left partition
            int label = labels[indices[m]];
            double w = classWeights[label];
            currentCounts[0][label] += w;
            totalCounts[0] += w;
            currentCounts[1][label] -= w;
            totalCounts[1] -= w;

            // only evaluate split positions between distinct feature values
            if (m < exampleCount-2 &&
                features(indices[m], splitColumns[k]) == features(indices[m+1], splitColumns[k]))
                continue ;

            double gini = 0.0;
            if(classCount == 2)   // closed form for the two-class case
            {
                gini = currentCounts[0][0]*currentCounts[0][1] / totalCounts[0] +
                       currentCounts[1][0]*currentCounts[1][1] / totalCounts[1];
            }
            else
            {
                for(int l=0; l<classCount; ++l)
                    gini += currentCounts[0][l]*(1.0 - currentCounts[0][l] / totalCounts[0]) +
                            currentCounts[1][l]*(1.0 - currentCounts[1][l] / totalCounts[1]);
            }
            if(gini < minGini)   // remember the best split seen so far
            {
                minGini = gini;
                bestSplit = indices+m;
                bestSplitColumn = splitColumns[k];
                bestCounts[0] = currentCounts[0];
                bestCounts[1] = currentCounts[1];
            }
        }



    }
    //std::cerr << minGini << " " << bestSplitColumn << std::endl;
    // split using the best feature
    sorter.setColumn(bestSplitColumn);
    std::sort(indices, indices+exampleCount, sorter);

    for(int k=0; k<2; ++k)
    {
        bestTotalCounts[k] = std::accumulate(bestCounts[k].begin(), bestCounts[k].end(), 0.0);
    }

    // threshold = midpoint between the last feature value on the left and
    // the first feature value on the right side of the split
    threshold = (features(bestSplit[0], bestSplitColumn) + features(bestSplit[1], bestSplitColumn)) / 2.0;
    ++bestSplit;

    // a side is pure iff exactly one class has a nonzero count in it
    counter.reset();
    std::for_each(indices, bestSplit, counter);
    pure[0] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeDeprecCountNonzeroFunctor());
    counter.reset();
    std::for_each(bestSplit, indices+exampleCount, counter);
    pure[1] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeDeprecCountNonzeroFunctor());

    return bestSplit;
}
345 
// Sentinel used as a parent link when a pending subtree has no parent
// (i.e. it is the root, or the link was already consumed).
enum { DecisionTreeDeprecNoParent = -1 };

/** Work item for the iterative (explicit-stack) tree construction:
    a range of example indices plus the parent node whose left or right
    child link must be patched once this subtree's root node is created.
*/
template <class Iterator>
struct DecisionTreeDeprecStackEntry
{
    DecisionTreeDeprecStackEntry(Iterator i, int c,
                                 int lp = DecisionTreeDeprecNoParent, int rp = DecisionTreeDeprecNoParent)
    : indices(i), exampleCount(c),
      leftParent(lp), rightParent(rp)
    {}

    Iterator indices;                           // first example index of this subtree's range
    int exampleCount, leftParent, rightParent;  // range length; at most one parent link is set
};
360 
/** A single decision tree of the deprecated random forest.

    The tree is stored as a flat integer array (tree_) of 4-int node
    records plus a parallel array of doubles (terminalWeights_) holding
    split thresholds and leaf class-weight blocks. Child links > 0 are
    node indices into tree_; links <= 0 encode leaves as the negated
    offset of their weight block in terminalWeights_.
*/
class DecisionTreeDeprec
{
  public:
    typedef Int32 TreeInt;
    ArrayVector<TreeInt> tree_;             // flat node records
    ArrayVector<double> terminalWeights_;   // thresholds + leaf weight blocks
    unsigned int classCount_;               // number of distinct class labels
    DecisionTreeDeprecAxisSplitFunctor split;   // split strategy + scratch state

  public:


    DecisionTreeDeprec(unsigned int classCount)
    : classCount_(classCount)
    {}

    /** Discard the learned tree; optionally change the class count
        (a zero argument keeps the current one).
    */
    void reset(unsigned int classCount = 0)
    {
        if(classCount)
            classCount_ = classCount;
        tree_.clear();
        terminalWeights_.clear();
    }

    /** Grow the tree from the examples referenced by
        [indices, indices+exampleCount); defined out-of-line below.
    */
    template <class U, class C, class Iterator, class Options, class Random>
    void learn(MultiArrayView<2, U, C> const & features,
               ArrayVector<int> const & labels,
               Iterator indices, int exampleCount,
               Options const & options,
               Random & randint);

    /** Descend the tree for a single example (one-row feature matrix) and
        return an iterator to the leaf's block of classCount_ weights.
    */
    template <class U, class C>
    ArrayVector<double>::const_iterator
    predict(MultiArrayView<2, U, C> const & features) const
    {
        int nodeindex = 0;   // root lives at offset 0
        for(;;)
        {
            DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, nodeindex);
            nodeindex = split.decideAtNode(features, node.decisionColumns(),
                                   terminalWeights_.begin() + node.decisionWeightsIndex())
                                ? node.child(0)
                                : node.child(1);
            // links <= 0 are leaves: -nodeindex is the weight-block offset
            if(nodeindex <= 0)
                return terminalWeights_.begin() + (-nodeindex);
        }
    }

    /** Predict the internal (integer) label of a single example:
        the index of the largest leaf weight.
    */
    template <class U, class C>
    int
    predictLabel(MultiArrayView<2, U, C> const & features) const
    {
        ArrayVector<double>::const_iterator weights = predict(features);
        return argMax(weights, weights+classCount_) - weights;
    }

    /** Return the leaf a single example falls into, identified by its
        weight-block offset in terminalWeights_.
    */
    template <class U, class C>
    int
    leafID(MultiArrayView<2, U, C> const & features) const
    {
        int nodeindex = 0;
        for(;;)
        {
            DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, nodeindex);
            nodeindex = split.decideAtNode(features, node.decisionColumns(),
                                   terminalWeights_.begin() + node.decisionWeightsIndex())
                                ? node.child(0)
                                : node.child(1);
            if(nodeindex <= 0)
                return -nodeindex;
        }
    }

    /** Recursively accumulate tree statistics: maximum depth, number of
        interior nodes, number of leaves. \a k is the current node,
        \a d the current depth.
    */
    void depth(int & maxDep, int & interiorCount, int & leafCount, int k = 0, int d = 1) const
    {
        DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, k);
        ++interiorCount;
        ++d;
        for(int l=0; l<2; ++l)
        {
            int child = node.child(l);
            if(child > 0)
                depth(maxDep, interiorCount, leafCount, child, d);
            else
            {
                ++leafCount;
                if(maxDep < d)
                    maxDep = d;
            }
        }
    }

    /** Print node counts and depth of the tree to \a o. */
    void printStatistics(std::ostream & o) const
    {
        int maxDep = 0, interiorCount = 0, leafCount = 0;
        depth(maxDep, interiorCount, leafCount);

        o << "interior nodes: " << interiorCount <<
             ", terminal nodes: " << leafCount <<
             ", depth: " << maxDep << "\n";
    }

    /** Recursively dump the tree structure (split column, threshold, and
        the first two leaf weights) to \a o; \a s is the indent prefix.
        NOTE(review): only two weights per leaf are printed, so the output
        is complete only for two-class trees.
    */
    void print(std::ostream & o, int k = 0, std::string s = "") const
    {
        DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, k);
        o << s << (*node.decisionColumns()) << " " << terminalWeights_[node.decisionWeightsIndex()] << "\n";

        for(int l=0; l<2; ++l)
        {
            int child = node.child(l);
            if(child <= 0)
                o << s << " weights " << terminalWeights_[-child] << " "
                  << terminalWeights_[-child+1] << "\n";
            else
                print(o, child, s+" ");
        }
    }
};
479 
480 
/** Grow the tree with an explicit work stack instead of recursion.

    Each stack entry carries a contiguous range of example indices and the
    pending parent link to patch. For every popped range the split functor
    partitions the examples in place; each side either becomes a new stack
    entry (to be split further) or a leaf whose class weights are appended
    to terminalWeights_.
*/
template <class U, class C, class Iterator, class Options, class Random>
void DecisionTreeDeprec::learn(MultiArrayView<2, U, C> const & features,
                               ArrayVector<int> const & labels,
                               Iterator indices, int exampleCount,
                               Options const & options,
                               Random & randint)
{
    ArrayVector<double> const & classLoss = options.class_weights;

    vigra_precondition(classLoss.size() == 0 || classLoss.size() == classCount_,
        "DecisionTreeDeprec2::learn(): class weights array has wrong size.");

    reset();

    unsigned int mtry = options.mtry;
    MultiArrayIndex cols = columnCount(features);

    split.init(mtry, cols, classCount_, classLoss);

    typedef DecisionTreeDeprecStackEntry<Iterator> Entry;
    ArrayVector<Entry> stack;
    stack.push_back(Entry(indices, exampleCount));

    while(!stack.empty())
    {
//        std::cerr << "*";
        indices = stack.back().indices;
        exampleCount = stack.back().exampleCount;
        int leftParent  = stack.back().leftParent,
            rightParent = stack.back().rightParent;

        stack.pop_back();

        // partitions [indices, indices+exampleCount) in place and returns
        // the first example of the right-hand side
        Iterator bestSplit = split.findBestSplit(features, labels, indices, exampleCount, randint);

        int currentNode = split.writeSplitParameters(tree_, terminalWeights_);

        // patch the pending parent link (at most one is set)
        if(leftParent != DecisionTreeDeprecNoParent)
            DecisionTreeDeprecNodeProxy<TreeInt>(tree_, leftParent).child(0) = currentNode;
        if(rightParent != DecisionTreeDeprecNoParent)
            DecisionTreeDeprecNodeProxy<TreeInt>(tree_, rightParent).child(1) = currentNode;
        leftParent = currentNode;
        rightParent = DecisionTreeDeprecNoParent;

        // process left side (l == 0), then right side (l == 1); the
        // std::swap/indices updates at the loop end switch the pending
        // parent link and the index range over to the right side
        for(int l=0; l<2; ++l)
        {

            if(!split.isPure(l) && split.totalCount(l) >= options.min_split_node_size)
            {
                // sample is still large enough and not yet perfectly separated => split
                stack.push_back(Entry(indices, split.totalCount(l), leftParent, rightParent));
            }
            else
            {
                // leaf: store the negated weight-block offset as child link
                DecisionTreeDeprecNodeProxy<TreeInt>(tree_, currentNode).child(l) = -(TreeInt)terminalWeights_.size();

                split.writeWeights(l, terminalWeights_);
            }
            std::swap(leftParent, rightParent);
            indices = bestSplit;
        }
    }
//    std::cerr << "\n";
}
546 
547 } // namespace detail
548 
549 class RandomForestOptionsDeprec
550 {
551  public:
552  /** Initialize all options with default values.
553  */
554  RandomForestOptionsDeprec()
555  : training_set_proportion(1.0),
556  mtry(0),
557  min_split_node_size(1),
558  training_set_size(0),
559  sample_with_replacement(true),
560  sample_classes_individually(false),
561  treeCount(255)
562  {}
563 
564  /** Number of features considered in each node.
565 
566  If \a n is 0 (the default), the number of features tried in every node
567  is determined by the square root of the total number of features.
568  According to Breiman, this quantity should always be optimized by means
569  of the out-of-bag error.<br>
570  Default: 0 (use <tt>sqrt(columnCount(featureMatrix))</tt>)
571  */
572  RandomForestOptionsDeprec & featuresPerNode(unsigned int n)
573  {
574  mtry = n;
575  return *this;
576  }
577 
578  /** How to sample the subset of the training data for each tree.
579 
580  Each tree is only trained with a subset of the entire training data.
581  If \a r is <tt>true</tt>, this subset is sampled from the entire training set with
582  replacement.<br>
583  Default: <tt>true</tt> (use sampling with replacement))
584  */
585  RandomForestOptionsDeprec & sampleWithReplacement(bool r)
586  {
587  sample_with_replacement = r;
588  return *this;
589  }
590 
591  RandomForestOptionsDeprec & setTreeCount(unsigned int cnt)
592  {
593  treeCount = cnt;
594  return *this;
595  }
596  /** Proportion of training examples used for each tree.
597 
598  If \a p is 1.0 (the default), and samples are drawn with replacement,
599  the training set of each tree will contain as many examples as the entire
600  training set, but some are drawn multiply and others not at all. On average,
601  each tree is actually trained on about 65% of the examples in the full
602  training set. Changing the proportion makes mainly sense when
603  sampleWithReplacement() is set to <tt>false</tt>. trainingSetSizeProportional() gets
604  overridden by trainingSetSizeAbsolute().<br>
605  Default: 1.0
606  */
607  RandomForestOptionsDeprec & trainingSetSizeProportional(double p)
608  {
609  vigra_precondition(p >= 0.0 && p <= 1.0,
610  "RandomForestOptionsDeprec::trainingSetSizeProportional(): proportion must be in [0, 1].");
611  if(training_set_size == 0) // otherwise, absolute size gets priority
612  training_set_proportion = p;
613  return *this;
614  }
615 
616  /** Size of the training set for each tree.
617 
618  If this option is set, it overrides the proportion set by
619  trainingSetSizeProportional(). When classes are sampled individually,
620  the number of examples is divided by the number of classes (rounded upwards)
621  to determine the number of examples drawn from every class.<br>
622  Default: <tt>0</tt> (determine size by proportion)
623  */
624  RandomForestOptionsDeprec & trainingSetSizeAbsolute(unsigned int s)
625  {
626  training_set_size = s;
627  if(s > 0)
628  training_set_proportion = 0.0;
629  return *this;
630  }
631 
632  /** Are the classes sampled individually?
633 
634  If \a s is <tt>false</tt> (the default), the training set for each tree is sampled
635  without considering class labels. Otherwise, samples are drawn from each
636  class independently. The latter is especially useful in connection
637  with the specification of an absolute training set size: then, the same number of
638  examples is drawn from every class. This can be used as a counter-measure when the
639  classes are very unbalanced in size.<br>
640  Default: <tt>false</tt>
641  */
642  RandomForestOptionsDeprec & sampleClassesIndividually(bool s)
643  {
644  sample_classes_individually = s;
645  return *this;
646  }
647 
648  /** Number of examples required for a node to be split.
649 
650  When the number of examples in a node is below this number, the node is not
651  split even if class separation is not yet perfect. Instead, the node returns
652  the proportion of each class (among the remaining examples) during the
653  prediction phase.<br>
654  Default: 1 (complete growing)
655  */
656  RandomForestOptionsDeprec & minSplitNodeSize(unsigned int n)
657  {
658  if(n == 0)
659  n = 1;
660  min_split_node_size = n;
661  return *this;
662  }
663 
664  /** Use a weighted random forest.
665 
666  This is usually used to penalize the errors for the minority class.
667  Weights must be convertible to <tt>double</tt>, and the array of weights
668  must contain as many entries as there are classes.<br>
669  Default: do not use weights
670  */
671  template <class WeightIterator>
672  RandomForestOptionsDeprec & weights(WeightIterator weights, unsigned int classCount)
673  {
674  class_weights.clear();
675  if(weights != 0)
676  class_weights.insert(weights, classCount);
677  return *this;
678  }
679 
680  RandomForestOptionsDeprec & oobData(MultiArrayView<2, UInt8>& data)
681  {
682  oob_data =data;
683  return *this;
684  }
685 
686  MultiArrayView<2, UInt8> oob_data;
687  ArrayVector<double> class_weights;
688  double training_set_proportion;
689  unsigned int mtry, min_split_node_size, training_set_size;
690  bool sample_with_replacement, sample_classes_individually;
691  unsigned int treeCount;
692 };
693 
694 /*****************************************************************/
695 /* */
696 /* RandomForestDeprec */
697 /* */
698 /*****************************************************************/
699 
/** Deprecated random forest classifier: an ensemble of
    detail::DecisionTreeDeprec trees voting on class probabilities.
*/
template <class ClassLabelType>
class RandomForestDeprec
{
  public:
    ArrayVector<ClassLabelType> classes_;             // distinct class labels, in internal index order
    ArrayVector<detail::DecisionTreeDeprec> trees_;   // the ensemble
    MultiArrayIndex columnCount_;                     // feature count; 0 until trained (or given explicitly)
    RandomForestOptionsDeprec options_;

  public:

    // The first two constructors are straightforward: they take either an
    // iterator range of class labels or the two label values directly.

    /** Construct from a class-label range, tree count, and options. */
    template<class ClassLabelIterator>
    RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
                       unsigned int treeCount = 255,
                       RandomForestOptionsDeprec const & options = RandomForestOptionsDeprec())
    : classes_(cl, cend),
      trees_(treeCount, detail::DecisionTreeDeprec(classes_.size())),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.training_set_proportion == 0.0 ||
                           options.training_set_size == 0,
            "RandomForestOptionsDeprec: absolute and proportional training set sizes "
            "cannot be specified at the same time.");
        vigra_precondition(classes_.size() > 1,
            "RandomForestOptionsDeprec::weights(): need at least two classes.");
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
            "RandomForestOptionsDeprec::weights(): wrong number of classes.");
    }

    /** Convenience constructor for the two-class case. */
    RandomForestDeprec(ClassLabelType const & c1, ClassLabelType const & c2,
                       unsigned int treeCount = 255,
                       RandomForestOptionsDeprec const & options = RandomForestOptionsDeprec())
    : classes_(2),
      trees_(treeCount, detail::DecisionTreeDeprec(2)),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == 2,
            "RandomForestOptionsDeprec::weights(): wrong number of classes.");
        classes_[0] = c1;
        classes_[1] = c2;
    }

    /** Construct taking the tree count from the options object
        (used especially by the CrossValidator class).
    */
    template<class ClassLabelIterator>
    RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
                       RandomForestOptionsDeprec const & options )
    : classes_(cl, cend),
      trees_(options.treeCount , detail::DecisionTreeDeprec(classes_.size())),
      columnCount_(0),
      options_(options)
    {

        vigra_precondition(options.training_set_proportion == 0.0 ||
                           options.training_set_size == 0,
            "RandomForestOptionsDeprec: absolute and proportional training set sizes "
            "cannot be specified at the same time.");
        vigra_precondition(classes_.size() > 1,
            "RandomForestOptionsDeprec::weights(): need at least two classes.");
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
            "RandomForestOptionsDeprec::weights(): wrong number of classes.");
    }

    /** Reconstruct an already-trained forest from serialized per-tree
        node arrays and weight arrays. Uses \a columnCount instead of an
        options object; no preconditions are checked here.
    */
    template<class ClassLabelIterator, class TreeIterator, class WeightIterator>
    RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
                       unsigned int treeCount, unsigned int columnCount,
                       TreeIterator trees, WeightIterator weights)
    : classes_(cl, cend),
      trees_(treeCount, detail::DecisionTreeDeprec(classes_.size())),
      columnCount_(columnCount)
    {
        for(unsigned int k=0; k<treeCount; ++k, ++trees, ++weights)
        {
            trees_[k].tree_ = *trees;
            trees_[k].terminalWeights_ = *weights;
        }
    }

    /** Number of features the forest was trained on.
        Fails if the forest has not been trained yet.
    */
    int featureCount() const
    {
        vigra_precondition(columnCount_ > 0,
           "RandomForestDeprec::featureCount(): Random forest has not been trained yet.");
        return columnCount_;
    }

    /** Number of distinct class labels. */
    int labelCount() const
    {
        return classes_.size();
    }

    /** Number of trees in the ensemble. */
    int treeCount() const
    {
        return trees_.size();
    }

    // loss == 0.0 means unweighted random forest

    /** Train the forest; returns the out-of-bag error estimate.
        Defined out-of-line below.
    */
    template <class U, class C, class Array, class Random>
    double learn(MultiArrayView<2, U, C> const & features, Array const & labels,
                 Random const& random);

    /** Train with a default-seeded random number generator. */
    template <class U, class C, class Array>
    double learn(MultiArrayView<2, U, C> const & features, Array const & labels)
    {
        RandomNumberGenerator<> generator(RandomSeed);
        return learn(features, labels, generator);
    }

    /** Predict the label of a single example (one-row feature matrix);
        defined out-of-line below.
    */
    template <class U, class C>
    ClassLabelType predictLabel(MultiArrayView<2, U, C> const & features) const;

    /** Predict one label per row of \a features into column 0 of
        \a labels.
    */
    template <class U, class C1, class T, class C2>
    void predictLabels(MultiArrayView<2, U, C1> const & features,
                       MultiArrayView<2, T, C2> & labels) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForestDeprec::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
            labels(k,0) = predictLabel(rowVector(features, k));
    }

    /** Predict the label of a single example using per-class priors;
        defined out-of-line below.
    */
    template <class U, class C, class Iterator>
    ClassLabelType predictLabel(MultiArrayView<2, U, C> const & features,
                                Iterator priors) const;

    /** Fill \a prob with per-class probabilities, one row per example;
        defined out-of-line elsewhere.
    */
    template <class U, class C1, class T, class C2>
    void predictProbabilities(MultiArrayView<2, U, C1> const & features,
                              MultiArrayView<2, T, C2> & prob) const;

    /** Fill \a NodeIDs with the leaf each example reaches in each tree;
        defined out-of-line elsewhere.
    */
    template <class U, class C1, class T, class C2>
    void predictNodes(MultiArrayView<2, U, C1> const & features,
                      MultiArrayView<2, T, C2> & NodeIDs) const;
};
836 
/** Train the forest: validate labels, draw a (possibly per-class,
    possibly with-replacement) training sample for every tree, grow the
    tree, record out-of-bag results, and return the out-of-bag error
    averaged over all examples that were out-of-bag at least once.
*/
template <class ClassLabelType>
template <class U, class C1, class Array, class Random>
double
RandomForestDeprec<ClassLabelType>::learn(MultiArrayView<2, U, C1> const & features,
                                          Array const & labels,
                                          Random const& random)
{
    unsigned int classCount = classes_.size();
    unsigned int m = rowCount(features);      // number of examples
    unsigned int n = columnCount(features);   // number of features
    vigra_precondition((unsigned int)(m) == (unsigned int)labels.size(),
      "RandomForestDeprec::learn(): Label array has wrong size.");

    vigra_precondition(options_.training_set_size <= m || options_.sample_with_replacement,
       "RandomForestDeprec::learn(): Requested training set size exceeds total number of examples.");

    // default mtry: round(sqrt(n))
    MultiArrayIndex mtry = (options_.mtry == 0)
                                ? int(std::floor(std::sqrt(double(n)) + 0.5))
                                : options_.mtry;

    vigra_precondition(mtry <= (MultiArrayIndex)n,
       "RandomForestDeprec::learn(): mtry must be less than number of features.");

    MultiArrayIndex msamples = options_.training_set_size;
    if(options_.sample_classes_individually)
        msamples = int(std::ceil(double(msamples) / classCount));   // per-class quota

    ArrayVector<int> intLabels(m), classExampleCounts(classCount);

    // verify the input labels
    // (map each external label to its internal index and count examples
    // per class)
    int minClassCount;
    {
        typedef std::map<ClassLabelType, int > LabelChecker;
        typedef typename LabelChecker::iterator LabelCheckerIterator;
        LabelChecker labelChecker;
        for(unsigned int k=0; k<classCount; ++k)
            labelChecker[classes_[k]] = k;

        for(unsigned int k=0; k<m; ++k)
        {
            LabelCheckerIterator found = labelChecker.find(labels[k]);
            vigra_precondition(found != labelChecker.end(),
                "RandomForestDeprec::learn(): Unknown class label encountered.");
            intLabels[k] = found->second;
            ++classExampleCounts[intLabels[k]];
        }
        minClassCount = *argMin(classExampleCounts.begin(), classExampleCounts.end());
        vigra_precondition(minClassCount > 0,
             "RandomForestDeprec::learn(): At least one class is missing in the training set.");
        if(msamples > 0 && options_.sample_classes_individually &&
                           !options_.sample_with_replacement)
        {
            vigra_precondition(msamples <= minClassCount,
                "RandomForestDeprec::learn(): Too few examples in smallest class to reach "
                "requested training set size.");
        }
    }
    columnCount_ = n;
    ArrayVector<int> indices(m);   // permutation of all example indices
    for(unsigned int k=0; k<m; ++k)
        indices[k] = k;

    // per-class sampling needs the indices grouped by class
    if(options_.sample_classes_individually)
    {
        detail::RandomForestDeprecLabelSorter<ArrayVector<int> > sorter(intLabels);
        std::sort(indices.begin(), indices.end(), sorter);
    }

    // usedIndices: in-bag counts for the current tree;
    // oobCount/oobErrorCount: per-example out-of-bag totals over all trees
    ArrayVector<int> usedIndices(m), oobCount(m), oobErrorCount(m);

    UniformIntRandomFunctor<Random> randint(0, m-1, random);
    //std::cerr << "Learning a RF \n";
    for(unsigned int k=0; k<trees_.size(); ++k)
    {
        //std::cerr << "Learning tree " << k << " ...\n";

        ArrayVector<int> trainingSet;
        usedIndices.init(0);

        if(options_.sample_classes_individually)
        {
            // draw from each class's contiguous segment of 'indices'
            int first = 0;
            for(unsigned int l=0; l<classCount; ++l)
            {
                int lc = classExampleCounts[l];
                int lsamples = (msamples == 0)
                                   ? int(std::ceil(options_.training_set_proportion*lc))
                                   : msamples;

                if(options_.sample_with_replacement)
                {
                    for(int ll=0; ll<lsamples; ++ll)
                    {
                        trainingSet.push_back(indices[first+randint(lc)]);
                        ++usedIndices[trainingSet.back()];
                    }
                }
                else
                {
                    // partial Fisher-Yates shuffle within the class segment
                    for(int ll=0; ll<lsamples; ++ll)
                    {
                        std::swap(indices[first+ll], indices[first+ll+randint(lc-ll)]);
                        trainingSet.push_back(indices[first+ll]);
                        ++usedIndices[trainingSet.back()];
                    }
                    //std::sort(indices.begin(), indices.begin()+lsamples);
                }
                first += lc;
            }
        }
        else
        {
            if(msamples == 0)
                msamples = int(std::ceil(options_.training_set_proportion*m));

            if(options_.sample_with_replacement)
            {
                for(int l=0; l<msamples; ++l)
                {
                    trainingSet.push_back(indices[randint(m)]);
                    ++usedIndices[trainingSet.back()];
                }
            }
            else
            {
                // partial Fisher-Yates shuffle over all examples
                for(int l=0; l<msamples; ++l)
                {
                    std::swap(indices[l], indices[l+randint(m-l)/*oikas*/]);
                    trainingSet.push_back(indices[l]);
                    ++usedIndices[trainingSet.back()];
                }


            }

        }
        trees_[k].learn(features, intLabels,
                        trainingSet.begin(), trainingSet.size(),
                        options_.featuresPerNode(mtry), randint);
//        for(unsigned int l=0; l<m; ++l)
//        {
//            if(!usedIndices[l])
//            {
//                ++oobCount[l];
//                if(trees_[k].predictLabel(rowVector(features, l)) != intLabels[l])
//                    ++oobErrorCount[l];
//            }
//        }

        // evaluate this tree on its out-of-bag examples; optionally record
        // the outcome in oob_data (1 = correct, 2 = wrong)
        for(unsigned int l=0; l<m; ++l)
        {
            if(!usedIndices[l])
            {
                ++oobCount[l];
                if(trees_[k].predictLabel(rowVector(features, l)) != intLabels[l])
                {
                    ++oobErrorCount[l];
                    if(options_.oob_data.data() != 0)
                        options_.oob_data(l, k) = 2;
                }
                else if(options_.oob_data.data() != 0)
                {
                    options_.oob_data(l, k) = 1;
                }
            }
        }
        // TODO: default value for oob_data
        // TODO: implement variable importance
        //if(!options_.sample_with_replacement){
        //std::cerr << "done\n";
        //trees_[k].print(std::cerr);
        #ifdef VIGRA_RF_VERBOSE
        trees_[k].printStatistics(std::cerr);
        #endif
    }
    // average the per-example error rates over all examples that were
    // out-of-bag at least once
    double oobError = 0.0;
    int totalOobCount = 0;
    for(unsigned int l=0; l<m; ++l)
        if(oobCount[l])
        {
            oobError += double(oobErrorCount[l]) / oobCount[l];
            ++totalOobCount;
        }
    return oobError / totalOobCount;
}
1022 
1023 template <class ClassLabelType>
1024 template <class U, class C>
1025 ClassLabelType
1026 RandomForestDeprec<ClassLabelType>::predictLabel(MultiArrayView<2, U, C> const & features) const
1027 {
1028  vigra_precondition(columnCount(features) >= featureCount(),
1029  "RandomForestDeprec::predictLabel(): Too few columns in feature matrix.");
1030  vigra_precondition(rowCount(features) == 1,
1031  "RandomForestDeprec::predictLabel(): Feature matrix must have a single row.");
1032  Matrix<double> prob(1, classes_.size());
1033  predictProbabilities(features, prob);
1034  return classes_[argMax(prob)];
1035 }
1036 
1037 
1038 //Same thing as above with priors for each label !!!
1039 template <class ClassLabelType>
1040 template <class U, class C, class Iterator>
1041 ClassLabelType
1042 RandomForestDeprec<ClassLabelType>::predictLabel(MultiArrayView<2, U, C> const & features,
1043  Iterator priors) const
1044 {
1045  using namespace functor;
1046  vigra_precondition(columnCount(features) >= featureCount(),
1047  "RandomForestDeprec::predictLabel(): Too few columns in feature matrix.");
1048  vigra_precondition(rowCount(features) == 1,
1049  "RandomForestDeprec::predictLabel(): Feature matrix must have a single row.");
1050  Matrix<double> prob(1,classes_.size());
1051  predictProbabilities(features, prob);
1052  std::transform(prob.begin(), prob.end(), priors, prob.begin(), Arg1()*Arg2());
1053  return classes_[argMax(prob)];
1054 }
1055 
1056 template <class ClassLabelType>
1057 template <class U, class C1, class T, class C2>
1058 void
1059 RandomForestDeprec<ClassLabelType>::predictProbabilities(MultiArrayView<2, U, C1> const & features,
1060  MultiArrayView<2, T, C2> & prob) const
1061 {
1062 
1063  //Features are n xp
1064  //prob is n x NumOfLabel probability for each feature in each class
1065 
1066  vigra_precondition(rowCount(features) == rowCount(prob),
1067  "RandomForestDeprec::predictProbabilities(): Feature matrix and probability matrix size mismatch.");
1068 
1069  // num of features must be bigger than num of features in Random forest training
1070  // but why bigger?
1071  vigra_precondition(columnCount(features) >= featureCount(),
1072  "RandomForestDeprec::predictProbabilities(): Too few columns in feature matrix.");
1073  vigra_precondition(columnCount(prob) == (MultiArrayIndex)labelCount(),
1074  "RandomForestDeprec::predictProbabilities(): Probability matrix must have as many columns as there are classes.");
1075 
1076  //Classify for each row.
1077  for(int row=0; row < rowCount(features); ++row)
1078  {
1079  //contains the weights returned by a single tree???
1080  //thought that one tree has only one vote???
1081  //Pruning???
1082  ArrayVector<double>::const_iterator weights;
1083 
1084  //totalWeight == totalVoteCount!
1085  double totalWeight = 0.0;
1086 
1087  //Set each VoteCount = 0 - prob(row,l) contains vote counts until
1088  //further normalisation
1089  for(unsigned int l=0; l<classes_.size(); ++l)
1090  prob(row, l) = 0.0;
1091 
1092  //Let each tree classify...
1093  for(unsigned int k=0; k<trees_.size(); ++k)
1094  {
1095  //get weights predicted by single tree
1096  weights = trees_[k].predict(rowVector(features, row));
1097 
1098  //update votecount.
1099  for(unsigned int l=0; l<classes_.size(); ++l)
1100  {
1101  prob(row, l) += detail::RequiresExplicitCast<T>::cast(weights[l]);
1102  //every weight in totalWeight.
1103  totalWeight += weights[l];
1104  }
1105  }
1106 
1107  //Normalise votes in each row by total VoteCount (totalWeight
1108  for(unsigned int l=0; l<classes_.size(); ++l)
1109  prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1110  }
1111 }
1112 
1113 
1114 template <class ClassLabelType>
1115 template <class U, class C1, class T, class C2>
1116 void
1117 RandomForestDeprec<ClassLabelType>::predictNodes(MultiArrayView<2, U, C1> const & features,
1118  MultiArrayView<2, T, C2> & NodeIDs) const
1119 {
1120  vigra_precondition(columnCount(features) >= featureCount(),
1121  "RandomForestDeprec::getNodesRF(): Too few columns in feature matrix.");
1122  vigra_precondition(rowCount(features) <= rowCount(NodeIDs),
1123  "RandomForestDeprec::getNodesRF(): Too few rows in NodeIds matrix");
1124  vigra_precondition(columnCount(NodeIDs) >= treeCount(),
1125  "RandomForestDeprec::getNodesRF(): Too few columns in NodeIds matrix.");
1126  NodeIDs.init(0);
1127  for(unsigned int k=0; k<trees_.size(); ++k)
1128  {
1129  for(int row=0; row < rowCount(features); ++row)
1130  {
1131  NodeIDs(row,k) = trees_[k].leafID(rowVector(features, row));
1132  }
1133  }
1134 }
1135 
1136 //@}
1137 
1138 } // namespace vigra
1139 
1140 
#endif // VIGRA_RANDOM_FOREST_DEPREC_HXX
1142 
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:671
std::ptrdiff_t MultiArrayIndex
Definition: multi_fwd.hxx:60
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:697
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:684
Iterator argMin(Iterator first, Iterator last)
Find the minimum element in a sequence.
Definition: algorithm.hxx:68
int ceil(FixedPoint< IntBits, FracBits > v)
rounding up.
Definition: fixedpoint.hxx:675
int floor(FixedPoint< IntBits, FracBits > v)
rounding down.
Definition: fixedpoint.hxx:667
SquareRootTraits< FixedPoint< IntBits, FracBits > >::SquareRootResult sqrt(FixedPoint< IntBits, FracBits > v)
square root.
Definition: fixedpoint.hxx:616

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.11.1 (Fri May 19 2017)