36 #ifndef VIGRA_RANDOM_FOREST_NP_HXX
37 #define VIGRA_RANDOM_FOREST_NP_HXX
42 #include "vigra/mathutil.hxx"
43 #include "vigra/array_vector.hxx"
44 #include "vigra/sized_int.hxx"
45 #include "vigra/matrix.hxx"
46 #include "vigra/random.hxx"
47 #include "vigra/functorexpression.hxx"
58 AllColumns = 0x00000000,
59 ToBePrunedTag = 0x80000000,
60 LeafNodeTag = 0x40000000,
64 i_HypersphereNode = 2,
65 e_ConstProbNode = 0 | LeafNodeTag,
66 e_LogRegProbNode = 1 | LeafNodeTag
93 typedef T_Container_type::iterator Topology_type;
94 typedef P_Container_type::iterator Parameter_type;
97 mutable Topology_type topology_;
100 mutable Parameter_type parameters_;
101 int parameter_size_ ;
141 INT
const &
typeID()
const
161 return topology_ + 4 ;
177 return featureCount_;
197 Topology_type topology_end()
const
201 int topology_size()
const
203 return topology_size_;
211 Parameter_type parameters_end()
const
216 int parameters_size()
const
218 return parameter_size_;
243 vigra_precondition(topology_size_==o.topology_size_,
"Cannot copy nodes of different sizes");
244 vigra_precondition(featureCount_==o.featureCount_,
"Cannot copy nodes with different feature count");
245 vigra_precondition(classCount_==o.classCount_,
"Cannot copy nodes with different class counts");
246 vigra_precondition(parameters_size() ==o.parameters_size(),
"Cannot copy nodes with different parameter sizes");
258 topology_ (const_cast<Topology_type>(topology.begin()+ n)),
260 parameters_ (const_cast<Parameter_type>(parameter.begin() +
parameter_addr())),
262 featureCount_(topology[0]),
263 classCount_(topology[1]),
278 topology_ (const_cast<Topology_type>(topology.begin()+ n)),
279 topology_size_(tLen),
280 parameters_ (const_cast<Parameter_type>(parameter.begin() +
parameter_addr())),
281 parameter_size_(pLen),
282 featureCount_(topology[0]),
283 classCount_(topology[1]),
296 topology_ (node.topology_),
297 topology_size_(tLen),
298 parameters_ (node.parameters_),
299 parameter_size_(pLen),
300 featureCount_(node.featureCount_),
301 classCount_(node.classCount_),
321 topology_size_(tLen),
322 parameter_size_(pLen),
323 featureCount_(topology[0]),
324 classCount_(topology[1]),
330 size_t n = topology.
size();
331 for(
int ii = 0; ii < tLen; ++ii)
332 topology.push_back(0);
335 topology_ = topology.
begin()+ n;
341 for(
int ii = 0; ii < pLen; ++ii)
342 parameter.push_back(0);
360 topology_size_(toCopy.topology_size()),
361 parameter_size_(toCopy.parameters_size()),
362 featureCount_(topology[0]),
363 classCount_(topology[1]),
369 size_t n = topology.
size();
370 for(
int ii = 0; ii < toCopy.topology_size(); ++ii)
373 topology_ = topology.
begin()+ n;
375 for(
int ii = 0; ii < toCopy.parameters_size(); ++ii)
383 template<NodeTags NodeType>
387 class Node<i_ThresholdNode>
397 Node( BT::T_Container_type & topology,
398 BT::P_Container_type & param)
399 : BT(5,2,topology, param)
401 BT::typeID() = i_ThresholdNode;
404 Node( BT::T_Container_type
const & topology,
405 BT::P_Container_type
const & param,
407 : BT(5,2,topology, param, n)
416 return BT::parameters_begin()[1];
419 double const & threshold()
const
421 return BT::parameters_begin()[1];
426 return BT::column_data()[0];
428 BT::INT
const & column()
const
430 return BT::column_data()[0];
433 template<
class U,
class C>
434 BT::INT next(MultiArrayView<2,U,C>
const & feature)
const
436 return (feature(0, column()) < threshold())? child(0):child(1);
442 class Node<i_HyperplaneNode>
452 BT::T_Container_type & topology,
453 BT::P_Container_type & split_param)
454 : BT(nCol + 5,nCol + 2,topology, split_param)
456 BT::typeID() = i_HyperplaneNode;
459 Node( BT::T_Container_type
const & topology,
460 BT::P_Container_type
const & split_param,
462 : NodeBase(5 , 2,topology, split_param, n)
465 BT::topology_size_ += BT::column_data()[0]== AllColumns ?
467 : BT::column_data()[0];
468 BT::parameter_size_ += BT::columns_size();
475 BT::topology_size_ += BT::column_data()[0]== AllColumns ?
477 : BT::column_data()[0];
478 BT::parameter_size_ += BT::columns_size();
482 double const & intercept()
const
484 return BT::parameters_begin()[1];
488 return BT::parameters_begin()[1];
491 BT::Parameter_type weights()
const
493 return BT::parameters_begin()+2;
496 BT::Parameter_type weights()
498 return BT::parameters_begin()+2;
502 template<
class U,
class C>
503 BT::INT next(MultiArrayView<2,U,C>
const & feature)
const
505 double result = -1 * intercept();
506 if(*(BT::column_data()) == AllColumns)
508 for(
int ii = 0; ii < BT::columns_size(); ++ii)
510 result +=feature[ii] * weights()[ii];
515 for(
int ii = 0; ii < BT::columns_size(); ++ii)
517 result +=feature[BT::columns_begin()[ii]] * weights()[ii];
520 return result < 0 ? BT::child(0)
528 class Node<i_HypersphereNode>
538 BT::T_Container_type & topology,
539 BT::P_Container_type & param)
540 : NodeBase(nCol + 5,nCol + 1,topology, param)
542 BT::typeID() = i_HypersphereNode;
545 Node( BT::T_Container_type
const & topology,
546 BT::P_Container_type
const & param,
548 : NodeBase(5, 1,topology, param, n)
550 BT::topology_size_ += BT::column_data()[0]== AllColumns ?
552 : BT::column_data()[0];
553 BT::parameter_size_ += BT::columns_size();
559 BT::topology_size_ += BT::column_data()[0]== AllColumns ?
561 : BT::column_data()[0];
562 BT::parameter_size_ += BT::columns_size();
566 double const & squaredRadius()
const
568 return BT::parameters_begin()[1];
571 double& squaredRadius()
573 return BT::parameters_begin()[1];
576 BT::Parameter_type center()
const
578 return BT::parameters_begin()+2;
581 BT::Parameter_type center()
583 return BT::parameters_begin()+2;
586 template<
class U,
class C>
587 BT::INT next(MultiArrayView<2,U,C>
const & feature)
const
589 double result = -1 * squaredRadius();
590 if(*(BT::column_data()) == AllColumns)
592 for(
int ii = 0; ii < BT::columns_size(); ++ii)
594 result += (feature[ii] - center()[ii])*
595 (feature[ii] - center()[ii]);
600 for(
int ii = 0; ii < BT::columns_size(); ++ii)
602 result += (feature[BT::columns_begin()[ii]] - center()[ii])*
603 (feature[BT::columns_begin()[ii]] - center()[ii]);
606 return result < 0 ? BT::child(0)
626 class Node<e_ConstProbNode>
636 BT(2,topology[1]+1, topology, param)
639 BT::typeID() = e_ConstProbNode;
646 :
BT(2, topology[1]+1,topology, param, n)
651 :
BT(2, node_.classCount_ +1, node_)
653 BT::Parameter_type prob_begin()
const
655 return BT::parameters_begin()+1;
657 BT::Parameter_type prob_end()
const
659 return prob_begin() + prob_size();
661 int prob_size()
const
663 return BT::classCount_;
668 class Node<e_LogRegProbNode>;
672 #endif //RF_nodeproxy
Topology_type column_data() const
Definition: rf_nodeproxy.hxx:159
const_iterator begin() const
Definition: array_vector.hxx:223
NodeBase()
Definition: rf_nodeproxy.hxx:237
NodeBase(T_Container_type const &topology, P_Container_type const &parameter, INT n)
Definition: rf_nodeproxy.hxx:254
Topology_type columns_begin() const
Definition: rf_nodeproxy.hxx:167
INT & child(Int32 l)
Definition: rf_nodeproxy.hxx:224
int columns_size() const
Definition: rf_nodeproxy.hxx:174
NodeBase(int tLen, int pLen, NodeBase &node)
Definition: rf_nodeproxy.hxx:292
Topology_type columns_end() const
Definition: rf_nodeproxy.hxx:184
NodeBase(int tLen, int pLen, T_Container_type const &topology, P_Container_type const &parameter, INT n)
Definition: rf_nodeproxy.hxx:272
Definition: rf_nodeproxy.hxx:87
bool data() const
Definition: rf_nodeproxy.hxx:128
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
Parameter_type parameters_begin() const
Definition: rf_nodeproxy.hxx:207
INT & typeID()
Definition: rf_nodeproxy.hxx:136
NodeBase(int tLen, int pLen, T_Container_type &topology, P_Container_type &parameter)
Definition: rf_nodeproxy.hxx:316
INT & parameter_addr()
Definition: rf_nodeproxy.hxx:148
size_type size() const
Definition: array_vector.hxx:358
INT const & child(Int32 l) const
Definition: rf_nodeproxy.hxx:231
Topology_type topology_begin() const
Definition: rf_nodeproxy.hxx:193
double & weights()
Definition: rf_nodeproxy.hxx:115
NodeBase(NodeBase const &toCopy, T_Container_type &topology, P_Container_type &parameter)
Definition: rf_nodeproxy.hxx:356