
rf_decisionTree.hxx

/************************************************************************/
/*                                                                      */
/*        Copyright 2008-2009 by Ullrich Koethe and Rahul Nair          */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/

#ifndef VIGRA_RANDOM_FOREST_DT_HXX
#define VIGRA_RANDOM_FOREST_DT_HXX

#include <algorithm>
#include <map>
#include <numeric>
#include <vector>
#include "vigra/multi_array.hxx"
#include "vigra/mathutil.hxx"
#include "vigra/metaprogramming.hxx"
#include "vigra/array_vector.hxx"
#include "vigra/sized_int.hxx"
#include "vigra/matrix.hxx"
#include "vigra/random.hxx"
#include "vigra/functorexpression.hxx"

#include "rf_common.hxx"
#include "rf_visitors.hxx"
#include "rf_nodeproxy.hxx"
namespace vigra
{

namespace detail
{

// todo: finally decide whether to use camelCase or underscores
/* Decision tree classifier.
 *
 * This class is meant to be used in conjunction with the random forest
 * classifier. Prefer constructing a RandomForest with the following
 * parameters instead of using this class directly (preprocessing,
 * default values etc. are handled there):
 *
 * \code
 * RandomForest decisionTree(RF_Traits::Options_t()
 *                               .features_per_node(RF_ALL)
 *                               .tree_count(1));
 * \endcode
 *
 * \todo remove the classCount and featureCount from the topology
 *       array. Pass ext_param_ to the nodes!
 * \todo Use relative addressing of nodes?
 */
class DecisionTree
{
    /* \todo make private? */
  public:

    /* value type of container array. use whenever referencing it
     */
    typedef Int32 TreeInt;

    ArrayVector<TreeInt> topology_;
    ArrayVector<double>  parameters_;

    ProblemSpec<> ext_param_;
    unsigned int  classCount_;


  public:
    /* \brief Create tree with parameters */
    template<class T>
    DecisionTree(ProblemSpec<T> ext_param)
    :
        ext_param_(ext_param),
        classCount_(ext_param.class_count_)
    {}

    /* clears all memory used.
     */
    void reset(unsigned int classCount = 0)
    {
        if(classCount)
            classCount_ = classCount;
        topology_.clear();
        parameters_.clear();
    }


    /* learn a tree
     *
     * \tparam StackEntry_t The stack entry type containing the node
     *         information used during learning. Each split functor has a
     *         stack entry type associated with it (Split_t::StackEntry_t).
     * \sa RandomForest::learn()
     */
    template < class U, class C,
               class U2, class C2,
               class StackEntry_t,
               class Stop_t,
               class Split_t,
               class Visitor_t,
               class Random_t >
    void learn( MultiArrayView<2, U, C> const & features,
                MultiArrayView<2, U2, C2> const & labels,
                StackEntry_t const & stack_entry,
                Split_t split,
                Stop_t stop,
                Visitor_t & visitor,
                Random_t & randint);

    template < class U, class C,
               class U2, class C2,
               class StackEntry_t,
               class Stop_t,
               class Split_t,
               class Visitor_t,
               class Random_t >
    void continueLearn( MultiArrayView<2, U, C> const & features,
                        MultiArrayView<2, U2, C2> const & labels,
                        StackEntry_t const & stack_entry,
                        Split_t split,
                        Stop_t stop,
                        Visitor_t & visitor,
                        Random_t & randint,
                        // an index to which the last created exterior node will
                        // be moved (because it is not used anymore)
                        int garbaged_child = -1);

    /* is a node a Leaf Node? */
    inline bool isLeafNode(TreeInt in) const
    {
        return (in & LeafNodeTag) == LeafNodeTag;
    }

    /* data driven traversal from root to leaf
     *
     * traverse through tree with data given in features. Use visitors to
     * collect statistics along the way.
     */
    template<class U, class C, class Visitor_t>
    TreeInt getToLeaf(MultiArrayView<2, U, C> const & features,
                      Visitor_t & visitor) const
    {
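        // The root node lives at index 2: topology_[0] and topology_[1]
        // hold the feature count and the class count (see learn() below and
        // the class-level \todo about removing them from the topology array).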
        TreeInt index = 2;
        while(!isLeafNode(topology_[index]))
        {
            visitor.visit_internal_node(*this, index, topology_[index], features);
            switch(topology_[index])
            {
                case i_ThresholdNode:
                {
                    Node<i_ThresholdNode>
                        node(topology_, parameters_, index);
                    index = node.next(features);
                    break;
                }
                case i_HyperplaneNode:
                {
                    Node<i_HyperplaneNode>
                        node(topology_, parameters_, index);
                    index = node.next(features);
                    break;
                }
                case i_HypersphereNode:
                {
                    Node<i_HypersphereNode>
                        node(topology_, parameters_, index);
                    index = node.next(features);
                    break;
                }
#if 0
                // for quick prototyping! has to be implemented.
                case i_VirtualNode:
                {
                    Node<i_VirtualNode>
                        node(topology_, parameters, index);
                    index = node.next(features);
                }
#endif
                default:
                    vigra_fail("DecisionTree::getToLeaf():"
                               "encountered unknown internal Node Type");
            }
        }
        visitor.visit_external_node(*this, index, topology_[index], features);
        return index;
    }
    /* traverse tree to get statistics
     *
     * The tree is traversed in the order the nodes are laid out in memory
     * (i.e. if no relearning/pruning scheme is used this is pre-order).
     */
    template<class Visitor_t>
    void traverse_mem_order(Visitor_t visitor) const
    {
        UInt32 index = 2;
        while(index < topology_.size())
        {
            if(isLeafNode(topology_[index]))
            {
                visitor
                    .visit_external_node(*this, index, topology_[index]);
            }
            else
            {
                visitor
                    .visit_internal_node(*this, index, topology_[index]);
            }
        }
    }

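    /* post-order traversal
     *
     * Visits children before their parent. The result_stack accumulates the
     * leaf weights bottom-up, so each internal node is reported together with
     * the summed weight of its subtree.
     */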
    template<class Visitor_t>
    void traverse_post_order(Visitor_t visitor, TreeInt /*start*/ = 2) const
    {
        typedef TinyVector<double, 2> Entry;
        std::vector<Entry>  stack;
        std::vector<double> result_stack;
        stack.push_back(Entry(2, 0));
        int addr;
        while(!stack.empty())
        {
            addr = stack.back()[0];
            NodeBase node(topology_, parameters_, stack.back()[0]);
            if(stack.back()[1] == 1)
            {
                stack.pop_back();
                // both children have been processed - combine their results
                double leftRes = result_stack.back();
                result_stack.pop_back();
                double rightRes = result_stack.back();
                result_stack.pop_back();
                result_stack.push_back(rightRes + leftRes);
                visitor.visit_internal_node(*this,
                                            addr,
                                            node.typeID(),
                                            rightRes + leftRes);
            }
            else
            {
                if(isLeafNode(node.typeID()))
                {
                    visitor.visit_external_node(*this,
                                                addr,
                                                node.typeID(),
                                                node.weights());
                    stack.pop_back();
                    result_stack.push_back(node.weights());
                }
                else
                {
                    stack.back()[1] = 1;
                    stack.push_back(Entry(node.child(0), 0));
                    stack.push_back(Entry(node.child(1), 0));
                }
            }
        }
    }


    /* same thing as above, without any visitors */
    template<class U, class C>
    TreeInt getToLeaf(MultiArrayView<2, U, C> const & features) const
    {
        rf::visitors::StopVisiting stop;
        return getToLeaf(features, stop);
    }

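    /* predict class probabilities
     *
     * Returns an iterator to the classCount_ class probabilities stored in
     * the leaf node that the given feature row ends up in.
     */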
    template <class U, class C>
    ArrayVector<double>::iterator
    predict(MultiArrayView<2, U, C> const & features) const
    {
        TreeInt nodeindex = getToLeaf(features);
        switch(topology_[nodeindex])
        {
            case e_ConstProbNode:
                return Node<e_ConstProbNode>(topology_,
                                             parameters_,
                                             nodeindex).prob_begin();
                break;
#if 0
            // first make the Logistic regression stuff...
            case e_LogRegProbNode:
                return Node<e_LogRegProbNode>(topology_,
                                              parameters_,
                                              nodeindex).prob_begin();
#endif
            default:
                vigra_fail("DecisionTree::predict() :"
                           " encountered unknown external Node Type");
        }
        return ArrayVector<double>::iterator();
    }

    template <class U, class C>
    Int32 predictLabel(MultiArrayView<2, U, C> const & features) const
    {
        ArrayVector<double>::const_iterator weights = predict(features);
        return argMax(weights, weights + classCount_) - weights;
    }

};
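
// A minimal usage sketch of the approach recommended in the class comment
// above: let RandomForest grow a single tree rather than driving DecisionTree
// (and its split/stop functors) by hand. This assumes only the public
// RandomForest interface from <vigra/random_forest.hxx>; `features`, `labels`
// and the array sizes are placeholders.
//
// \code
// #include <vigra/multi_array.hxx>
// #include <vigra/random_forest.hxx>
//
// using namespace vigra;
//
// int const num_samples = 100, num_features = 10;
// MultiArray<2, double> features(Shape2(num_samples, num_features));
// MultiArray<2, int>    labels(Shape2(num_samples, 1));
// // ... fill features and labels ...
//
// // one fully grown tree, all features considered at every node
// RandomForest<int> tree(RandomForestOptions()
//                            .features_per_node(RF_ALL)
//                            .tree_count(1));
// tree.learn(features, labels);
//
// MultiArray<2, int> predicted(Shape2(num_samples, 1));
// tree.predictLabels(features, predicted);
// \endcode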

template < class U, class C,
           class U2, class C2,
           class StackEntry_t,
           class Stop_t,
           class Split_t,
           class Visitor_t,
           class Random_t >
void DecisionTree::learn( MultiArrayView<2, U, C> const & features,
                          MultiArrayView<2, U2, C2> const & labels,
                          StackEntry_t const & stack_entry,
                          Split_t split,
                          Stop_t stop,
                          Visitor_t & visitor,
                          Random_t & randint)
{
    this->reset();
    topology_.reserve(256);
    parameters_.reserve(256);
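    // the first two entries of topology_ store the feature count and the
    // class count (see the class-level \todo about removing them from the
    // topology array)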
    topology_.push_back(features.shape(1));
    topology_.push_back(classCount_);
    continueLearn(features, labels, stack_entry, split, stop, visitor, randint);
}

template < class U, class C,
           class U2, class C2,
           class StackEntry_t,
           class Stop_t,
           class Split_t,
           class Visitor_t,
           class Random_t >
void DecisionTree::continueLearn( MultiArrayView<2, U, C> const & features,
                                  MultiArrayView<2, U2, C2> const & labels,
                                  StackEntry_t const & stack_entry,
                                  Split_t split,
                                  Stop_t stop,
                                  Visitor_t & visitor,
                                  Random_t & randint,
                                  // an index to which the last created exterior node
                                  // will be moved (because it is not used anymore)
                                  int garbaged_child)
{
    std::vector<StackEntry_t> stack;
    stack.reserve(128);
    ArrayVector<StackEntry_t> child_stack_entry(2, stack_entry);
    stack.push_back(stack_entry);
    size_t last_node_pos = 0;
    StackEntry_t top = stack.back();

    while(!stack.empty())
    {
        // take the next entry off the stack
        top = stack.back();
        stack.pop_back();

        // make sure no data from the last round has remained in the pipeline
        child_stack_entry[0].reset();
        child_stack_entry[1].reset();
        split.reset();

        // Either the stopping criterion decides that the split should
        // produce a terminal node, or the split functor itself decides
        // what kind of node to make.
        TreeInt NodeID;

        if(stop(top))
            NodeID = split.makeTerminalNode(features,
                                            labels,
                                            top,
                                            randint);
        else
        {
            NodeID = split.findBestSplit(features,
                                         labels,
                                         top,
                                         child_stack_entry,
                                         randint);
        }

        // let the visitors inspect the new split
        visitor.visit_after_split(*this, split, top,
                                  child_stack_entry[0],
                                  child_stack_entry[1],
                                  features,
                                  labels);

        // Update the child entry of the parent.
        // NodeBase is used because the exact parameter form is not needed.
        last_node_pos = topology_.size();
        if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
        {
            NodeBase(topology_,
                     parameters_,
                     top.leftParent).child(0) = last_node_pos;
        }
        else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
        {
            NodeBase(topology_,
                     parameters_,
                     top.rightParent).child(1) = last_node_pos;
        }

        // Supply the split functor with the node type it requires:
        // set the address to which the children of this node should point
        // and push the children onto the stack.
        if(!isLeafNode(NodeID))
        {
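            // topology_.size() is the address at which the current node is
            // about to be appended; the children remember it as their parent
            // address. When a child is popped from the stack later, it patches
            // the corresponding child pointer of this node (see the parent
            // update above).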
            child_stack_entry[0].leftParent = topology_.size();
            child_stack_entry[1].rightParent = topology_.size();
            child_stack_entry[0].rightParent = -1;
            child_stack_entry[1].leftParent = -1;
            stack.push_back(child_stack_entry[0]);
            stack.push_back(child_stack_entry[1]);
        }

        // copy the newly created node from the split functor to the
        // decision tree.
        NodeBase node(split.createNode(), topology_, parameters_);
        ignore_argument(node);
    }
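
    // If requested, recycle the slot of an exterior node that is no longer
    // needed: the last node that was appended is copied into the garbaged
    // slot, the arrays are shrunk, and the parent pointer is redirected.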
    if(garbaged_child != -1)
    {
        Node<e_ConstProbNode>(topology_, parameters_, garbaged_child)
            .copy(Node<e_ConstProbNode>(topology_, parameters_, last_node_pos));

        int last_parameter_size = Node<e_ConstProbNode>(topology_,
                                                        parameters_,
                                                        garbaged_child).parameters_size();
        topology_.resize(last_node_pos);
        parameters_.resize(parameters_.size() - last_parameter_size);

        if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
            NodeBase(topology_,
                     parameters_,
                     top.leftParent).child(0) = garbaged_child;
        else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
            NodeBase(topology_,
                     parameters_,
                     top.rightParent).child(1) = garbaged_child;
    }
}

} //namespace detail

} //namespace vigra

#endif //VIGRA_RANDOM_FOREST_DT_HXX
