Main Page | Class List | File List | Class Members

tree.h

Go to the documentation of this file.
00001 00008 #ifndef _TREE_H_ 00009 #define _TREE_H_ 00010 #include "librf/types.h" 00011 #include "librf/tree_node.h" 00012 #include <iostream> 00013 #include <vector> 00014 #include <set> 00015 #include <map> 00016 #include <math.h> 00017 using namespace std; 00018 00019 namespace librf { 00020 00021 class InstanceSet; 00022 class weight_list; 00023 class DiscreteDist; 00024 class Instance; 00025 00045 class Tree { 00046 public: 00048 Tree(istream& in); 00050 Tree(const InstanceSet& set, weight_list* weights, 00051 int K, uchar max_depth=16, int min_size = 1, 00052 float min_gain = 0, unsigned int seed =0); 00053 ~Tree(); // clean up 00055 int predict(const Instance& c) const; 00057 int predict(const InstanceSet& set, int instance_no) const; 00058 // void write_dot(const string& s) const; 00060 float training_accuracy() const; 00061 // predict all the instances in testset and return the accuracy 00062 float testing_accuracy(const InstanceSet& testset) const; 00063 float oob_accuracy() const; 00064 void oob_cases(weight_list* counts, weight_list* correct) const; 00065 00066 void variable_importance(map<int, float>* scores, unsigned int* seed) const; 00067 void print() const; 00068 // do all the work -- separated this from constructor to 00069 // facilitate threading 00070 void grow(); 00071 void write(ostream& o) const; 00072 void read(istream& i); 00073 private: 00074 void copy_instances(); 00075 void move_data(tree_node* n, uint16 split_attr, uint16 split_idx); 00076 void find_best_split(tree_node* n, 00077 const vector<int>& attrs, 00078 int* split_attr, int* split_idx, 00079 float* split_point, float* split_gain); 00080 void find_best_split_for_attr(tree_node* n, 00081 int attr, 00082 float prior, 00083 int* split_idx, 00084 float *split_point, 00085 float* best_gain); 00086 00087 // Node marking 00088 void mark_build(tree_node* n, uint16 start, uint16 size, uchar depth); 00089 void mark_terminal(tree_node* n); 00090 void mark_split(tree_node* n, uint16 split_attr, float split_point); 00091 00092 void build_tree(int min_size); 00093 int build_node(uint16 node_num, uint16 min_size); 00094 void print_node(int n) const; 00095 00096 // inline convenience functions for 00097 // traversing nodes 00098 uint32 parent(uint32 node_num) const { 00099 return uint32(floor((node_num - 1) / 2.0)); 00100 } 00101 uint32 left_child(uint32 node_num) const { 00102 return 2 * node_num + 1; 00103 } 00104 uint32 right_child(uint32 node_num) const { 00105 return 2 * node_num + 2; 00106 } 00107 00108 void permuteOOB(int m, double *x); 00109 00110 vector<tree_node> nodes_; 00111 vector<uint16> active_nodes_; 00112 set<uint16> vars_used_; 00113 uint16 terminal_nodes_; 00114 uint16 split_nodes_; 00115 // get sorted indices 00116 // Const reference to instance set -- we don't get to delete it 00117 const InstanceSet& set_; 00118 // array of instance nums sorted by attributes 00119 // this is the block array that stores which instances belong to 00120 // which node 00121 // ex: sorted_inum_[attr*stride + start] 00122 // uint16 * sorted_inum_; 00123 // 00124 // Turns out there is not much of a gain in batch allocating the 00125 // 2d array (perhaps because, we only access a single column at a time) 00126 uint16** sorted_inum_; 00127 // label population 00128 // uchar * sorted_labels_; necessary? 00129 // A single weight list for all of the instances 00130 weight_list* weight_list_; 00131 // Depth of current tree 00132 uint16 max_depth_; 00133 uint16 K_; 00134 uint16 min_size_; 00135 float min_gain_; 00136 // scratch space 00137 int* temp; 00138 uchar* move_left; 00139 uint16 num_instances_; 00140 uint16 num_attributes_; 00141 uint16 stride_; 00142 unsigned int rand_seed_; 00143 // Constants 00144 static const int kLeft; 00145 static const int kRight; 00146 }; 00147 00148 } // namespace 00149 #endif

Generated on Mon Jan 8 23:19:06 2007 for librf by doxygen 1.3.7