36 #ifndef VIGRA_RANDOM_FOREST_DT_HXX
37 #define VIGRA_RANDOM_FOREST_DT_HXX
42 #include "vigra/multi_array.hxx"
43 #include "vigra/mathutil.hxx"
44 #include "vigra/metaprogramming.hxx"
45 #include "vigra/array_vector.hxx"
46 #include "vigra/sized_int.hxx"
47 #include "vigra/matrix.hxx"
48 #include "vigra/random.hxx"
49 #include "vigra/functorexpression.hxx"
52 #include "rf_common.hxx"
53 #include "rf_visitors.hxx"
54 #include "rf_nodeproxy.hxx"
86 typedef Int32 TreeInt;
88 ArrayVector<TreeInt> topology_;
89 ArrayVector<double> parameters_;
91 ProblemSpec<> ext_param_;
92 unsigned int classCount_;
98 DecisionTree(ProblemSpec<T> ext_param)
100 ext_param_(ext_param),
101 classCount_(ext_param.class_count_)
106 void reset(
unsigned int classCount = 0)
109 classCount_ = classCount;
122 template <
class U,
class C,
129 void learn( MultiArrayView<2, U, C>
const & features,
130 MultiArrayView<2, U2, C2>
const & labels,
131 StackEntry_t
const & stack_entry,
136 template <
class U,
class C,
143 void continueLearn( MultiArrayView<2, U, C>
const & features,
144 MultiArrayView<2, U2, C2>
const & labels,
145 StackEntry_t
const & stack_entry,
151 int garbaged_child=-1);
154 inline bool isLeafNode(TreeInt in)
const
156 return (in & LeafNodeTag) == LeafNodeTag;
164 template<
class U,
class C,
class Visitor_t>
165 TreeInt getToLeaf(MultiArrayView<2, U, C>
const & features,
166 Visitor_t & visitor)
const
169 while(!isLeafNode(topology_[index]))
171 visitor.visit_internal_node(*
this, index, topology_[index],features);
172 switch(topology_[index])
174 case i_ThresholdNode:
176 Node<i_ThresholdNode>
177 node(topology_, parameters_, index);
178 index = node.next(features);
181 case i_HyperplaneNode:
183 Node<i_HyperplaneNode>
184 node(topology_, parameters_, index);
185 index = node.next(features);
188 case i_HypersphereNode:
190 Node<i_HypersphereNode>
191 node(topology_, parameters_, index);
192 index = node.next(features);
200 node(topology_, parameters, index);
201 index = node.next(features);
205 vigra_fail(
"DecisionTree::getToLeaf():"
206 "encountered unknown internal Node Type");
209 visitor.visit_external_node(*
this, index, topology_[index],features);
217 template<
class Visitor_t>
218 void traverse_mem_order(Visitor_t visitor)
const
221 while(index < topology_.size())
223 if(isLeafNode(topology_[index]))
226 .visit_external_node(*
this, index, topology_[index]);
231 ._internal_node(*
this, index, topology_[index]);
236 template<
class Visitor_t>
237 void traverse_post_order(Visitor_t visitor, TreeInt = 2)
const
239 typedef TinyVector<double, 2> Entry;
240 std::vector<Entry > stack;
241 std::vector<double> result_stack;
242 stack.push_back(Entry(2, 0));
244 while(!stack.empty())
246 addr = stack.back()[0];
247 NodeBase node(topology_, parameters_, stack.back()[0]);
248 if(stack.back()[1] == 1)
251 double leftRes = result_stack.back();
252 double rightRes = result_stack.back();
253 result_stack.pop_back();
254 result_stack.pop_back();
255 result_stack.push_back(rightRes+ leftRes);
256 visitor.visit_internal_node(*
this,
263 if(isLeafNode(node.typeID()))
265 visitor.visit_external_node(*
this,
270 result_stack.push_back(node.weights());
275 stack.push_back(Entry(node.child(0), 0));
276 stack.push_back(Entry(node.child(1), 0));
284 template<
class U,
class C>
285 TreeInt getToLeaf(MultiArrayView<2, U, C>
const & features)
const
288 return getToLeaf(features, stop);
292 template <
class U,
class C>
293 ArrayVector<double>::iterator
294 predict(MultiArrayView<2, U, C>
const & features)
const
296 TreeInt nodeindex = getToLeaf(features);
297 switch(topology_[nodeindex])
299 case e_ConstProbNode:
300 return Node<e_ConstProbNode>(topology_,
302 nodeindex).prob_begin();
306 case e_LogRegProbNode:
307 return Node<e_LogRegProbNode>(topology_,
309 nodeindex).prob_begin();
312 vigra_fail(
"DecisionTree::predict() :"
313 " encountered unknown external Node Type");
315 return ArrayVector<double>::iterator();
320 template <
class U,
class C>
321 Int32 predictLabel(MultiArrayView<2, U, C>
const & features)
const
323 ArrayVector<double>::const_iterator weights = predict(features);
324 return argMax(weights, weights+classCount_) - weights;
330 template <
class U,
class C,
337 void DecisionTree::learn( MultiArrayView<2, U, C>
const & features,
338 MultiArrayView<2, U2, C2>
const & labels,
339 StackEntry_t
const & stack_entry,
346 topology_.reserve(256);
347 parameters_.reserve(256);
348 topology_.push_back(features.shape(1));
349 topology_.push_back(classCount_);
350 continueLearn(features,labels,stack_entry,split,stop,visitor,randint);
353 template <
class U,
class C,
360 void DecisionTree::continueLearn( MultiArrayView<2, U, C>
const & features,
361 MultiArrayView<2, U2, C2>
const & labels,
362 StackEntry_t
const & stack_entry,
370 std::vector<StackEntry_t> stack;
372 ArrayVector<StackEntry_t> child_stack_entry(2, stack_entry);
373 stack.push_back(stack_entry);
374 size_t last_node_pos = 0;
375 StackEntry_t top=stack.back();
377 while(!stack.empty())
385 child_stack_entry[0].reset();
386 child_stack_entry[1].reset();
396 NodeID = split.makeTerminalNode(features,
403 NodeID = split.findBestSplit(features,
413 visitor.visit_after_split(*
this, split, top,
414 child_stack_entry[0],
415 child_stack_entry[1],
423 last_node_pos = topology_.size();
424 if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
428 top.leftParent).child(0) = last_node_pos;
430 else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
434 top.rightParent).child(1) = last_node_pos;
441 if(!isLeafNode(NodeID))
443 child_stack_entry[0].leftParent = topology_.size();
444 child_stack_entry[1].rightParent = topology_.size();
445 child_stack_entry[0].rightParent = -1;
446 child_stack_entry[1].leftParent = -1;
447 stack.push_back(child_stack_entry[0]);
448 stack.push_back(child_stack_entry[1]);
453 NodeBase node(split.createNode(), topology_, parameters_ );
454 ignore_argument(node);
456 if(garbaged_child!=-1)
458 Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).copy(Node<e_ConstProbNode>(topology_,parameters_,last_node_pos));
460 int last_parameter_size = Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).parameters_size();
461 topology_.resize(last_node_pos);
462 parameters_.resize(parameters_.size() - last_parameter_size);
464 if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
467 top.leftParent).child(0) = garbaged_child;
468 else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
471 top.rightParent).child(1) = garbaged_child;
479 #endif //VIGRA_RANDOM_FOREST_DT_HXX
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
detail::SelectIntegerType< 32, detail::UnsignedIntTypes >::type UInt32
32-bit unsigned int
Definition: sized_int.hxx:183
Definition: rf_visitors.hxx:234