public class DecisionTree extends Object implements scala.Serializable, Logging
Constructor and Description |
---|
DecisionTree(Strategy strategy) |
Modifier and Type | Method and Description |
---|---|
static scala.collection.immutable.List<Object> | extractMultiClassCategories(int input, int maxFeatureValue): Nested method to extract list of eligible categories given an index. |
static void | findBestSplits(RDD<BaggedPoint<TreePoint>> input, DecisionTreeMetadata metadata, Node[] topNodes, scala.collection.immutable.Map<Object,Node[]> nodesForGroup, scala.collection.immutable.Map<Object,scala.collection.immutable.Map<Object,RandomForest.NodeIndexInfo>> treeToNodeToIndexInfo, Split[][] splits, Bin[][] bins, scala.collection.mutable.Queue<scala.Tuple2<Object,Node>> nodeQueue, TimeTracker timer, scala.Option<NodeIdCache> nodeIdCache): Given a group of nodes, this finds the best split for each node. |
static double[] | findSplitsForContinuousFeature(double[] featureSamples, DecisionTreeMetadata metadata, int featureIndex): Find splits for a continuous feature. NOTE: Returned number of splits is set based on featureSamples and could be different from the specified numSplits. |
DecisionTreeModel | run(RDD<LabeledPoint> input): Method to train a decision tree model over an RDD. |
DecisionTreeModel | train(RDD<LabeledPoint> input): Trains a decision tree model over an RDD. |
static DecisionTreeModel | trainClassifier(JavaRDD<LabeledPoint> input, int numClasses, java.util.Map<Integer,Integer> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins): Java-friendly API for DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, int, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int) |
static DecisionTreeModel | trainClassifier(RDD<LabeledPoint> input, int numClasses, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins): Method to train a decision tree model for binary or multiclass classification. |
static DecisionTreeModel | trainRegressor(JavaRDD<LabeledPoint> input, java.util.Map<Integer,Integer> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins): Java-friendly API for DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int) |
static DecisionTreeModel | trainRegressor(RDD<LabeledPoint> input, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins): Method to train a decision tree model for regression. |
Methods inherited from class Object: equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Logging: initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
public DecisionTree(Strategy strategy)
public static DecisionTreeModel trainClassifier(RDD<LabeledPoint> input, int numClasses, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)
input - Training dataset: RDD of LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
numClasses - Number of classes for classification.
categoricalFeaturesInfo - Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
impurity - Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy".
maxDepth - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (suggested value: 5)
maxBins - Maximum number of bins used for splitting features. (suggested value: 32)
public static DecisionTreeModel trainClassifier(JavaRDD<LabeledPoint> input, int numClasses, java.util.Map<Integer,Integer> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)
Java-friendly API for DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, int, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)
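The impurity parameter of trainClassifier names the criterion ("gini" or "entropy") used for the information gain calculation. The following is a minimal standalone sketch of those two criteria computed from per-class label counts, plus the gain of a binary split. It is not Spark's implementation: the class name ClassificationImpuritySketch, its method names, and the choice of log base 2 for entropy are assumptions made for illustration.

```java
import java.util.Arrays;

/** Illustrative sketch only; the names here are hypothetical and not part of Spark. */
final class ClassificationImpuritySketch {

    /** Gini impurity: 1 - sum_i p_i^2, where p_i is the fraction of labels in class i. */
    static double gini(double[] classCounts) {
        double total = Arrays.stream(classCounts).sum();
        if (total == 0.0) {
            return 0.0; // an empty node is treated as pure
        }
        double sumOfSquares = 0.0;
        for (double count : classCounts) {
            double p = count / total;
            sumOfSquares += p * p;
        }
        return 1.0 - sumOfSquares;
    }

    /** Entropy: -sum_i p_i * log2(p_i). Log base 2 is an assumption of this sketch. */
    static double entropy(double[] classCounts) {
        double total = Arrays.stream(classCounts).sum();
        if (total == 0.0) {
            return 0.0;
        }
        double result = 0.0;
        for (double count : classCounts) {
            if (count > 0.0) {
                double p = count / total;
                result -= p * Math.log(p) / Math.log(2.0);
            }
        }
        return result;
    }

    /** Gini-based gain of splitting parent counts into left counts and (parent - left) counts. */
    static double giniGain(double[] parentCounts, double[] leftCounts) {
        double[] rightCounts = new double[parentCounts.length];
        for (int i = 0; i < parentCounts.length; i++) {
            rightCounts[i] = parentCounts[i] - leftCounts[i];
        }
        double total = Arrays.stream(parentCounts).sum();
        if (total == 0.0) {
            return 0.0;
        }
        double leftWeight = Arrays.stream(leftCounts).sum() / total;
        double rightWeight = 1.0 - leftWeight;
        return gini(parentCounts) - leftWeight * gini(leftCounts) - rightWeight * gini(rightCounts);
    }

    public static void main(String[] args) {
        double[] counts = {5, 5}; // a node with 5 labels of class 0 and 5 of class 1
        System.out.println("gini = " + gini(counts));
        System.out.println("entropy = " + entropy(counts));
        System.out.println("gain of a perfect split = " + giniGain(counts, new double[] {5, 0}));
    }
}
```

Both criteria are 0 for a node that contains a single class and largest when the classes are evenly mixed, which is why a split that separates the classes yields a positive gain.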
public static DecisionTreeModel trainRegressor(RDD<LabeledPoint> input, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)
input - Training dataset: RDD of LabeledPoint. Labels are real numbers.
categoricalFeaturesInfo - Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
impurity - Criterion used for information gain calculation. Supported values: "variance".
maxDepth - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (suggested value: 5)
maxBins - Maximum number of bins used for splitting features. (suggested value: 32)
public static DecisionTreeModel trainRegressor(JavaRDD<LabeledPoint> input, java.util.Map<Integer,Integer> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)
Java-friendly API for DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)
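For trainRegressor the only supported impurity is "variance". As a rough illustration of what that criterion measures, the sketch below computes the population variance of the real-valued labels in a node from their count, sum, and sum of squares. The class name VarianceImpuritySketch and its method are hypothetical, and the use of population (rather than sample) variance is an assumption of this sketch.

```java
/** Illustrative sketch only; the names here are hypothetical and not part of Spark. */
final class VarianceImpuritySketch {

    /** Population variance of the labels: E[y^2] - (E[y])^2. Returns 0 for an empty node. */
    static double variance(double[] labels) {
        if (labels.length == 0) {
            return 0.0;
        }
        double sum = 0.0;
        double sumOfSquares = 0.0;
        for (double y : labels) {
            sum += y;
            sumOfSquares += y * y;
        }
        double n = labels.length;
        double mean = sum / n;
        return sumOfSquares / n - mean * mean;
    }

    public static void main(String[] args) {
        System.out.println("variance = " + variance(new double[] {1.0, 2.0, 3.0, 4.0}));
    }
}
```

A node whose labels are all equal has variance 0, so a regression split is useful to the extent that it groups similar label values together.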
public static void findBestSplits(RDD<BaggedPoint<TreePoint>> input, DecisionTreeMetadata metadata, Node[] topNodes, scala.collection.immutable.Map<Object,Node[]> nodesForGroup, scala.collection.immutable.Map<Object,scala.collection.immutable.Map<Object,RandomForest.NodeIndexInfo>> treeToNodeToIndexInfo, Split[][] splits, Bin[][] bins, scala.collection.mutable.Queue<scala.Tuple2<Object,Node>> nodeQueue, TimeTracker timer, scala.Option<NodeIdCache> nodeIdCache)
input - Training data: RDD of TreePoint.
metadata - Learning and dataset metadata.
topNodes - Root node for each tree. Used for matching instances with nodes.
nodesForGroup - Mapping: treeIndex --> nodes to be split in tree.
treeToNodeToIndexInfo - Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, where nodeIndexInfo stores the index in the group and the feature subsets (if using feature subsets).
splits - Possible splits for all features, indexed (numFeatures)(numSplits).
bins - Possible bins for all features, indexed (numFeatures)(numBins).
nodeQueue - Queue of nodes to split, with values (treeIndex, node). Updated with new non-leaf nodes which are created.
nodeIdCache - Node Id cache containing an RDD of Array[Int] where each value in the array is the data point's node Id for a corresponding tree. This is used to prevent the need to pass the entire tree to the executors during the node stat aggregation phase.
public static scala.collection.immutable.List<Object> extractMultiClassCategories(int input, int maxFeatureValue)
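The summary describes extractMultiClassCategories as extracting a list of eligible categories given an index. One plausible reading, assumed here and not confirmed by this page, is that the index is treated as a bitmask in which bit j marks category j as included. The sketch below shows that decoding; the class and method names are hypothetical, and the ascending order and Integer element type are choices of this sketch.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative sketch only; assumes the index is a bitmask over category values. */
final class CategoryBitmaskSketch {

    /** Returns every category j in [0, maxFeatureValue] whose bit j is set in input. */
    static List<Integer> categoriesFromBitmask(int input, int maxFeatureValue) {
        List<Integer> categories = new ArrayList<>();
        // An int has 32 bits, so stop at bit 31 even if maxFeatureValue is larger.
        for (int j = 0; j <= maxFeatureValue && j < 32; j++) {
            if (((input >> j) & 1) == 1) {
                categories.add(j);
            }
        }
        return categories;
    }

    public static void main(String[] args) {
        // 5 is binary 101, so bits 0 and 2 are set.
        System.out.println(categoriesFromBitmask(5, 2));
    }
}
```

Under this reading, enumerating the index over a range of integers enumerates subsets of categories, which is one way to list candidate splits for an unordered categorical feature.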
public static double[] findSplitsForContinuousFeature(double[] featureSamples, DecisionTreeMetadata metadata, int featureIndex)
Find splits for a continuous feature. NOTE: Returned number of splits is set based on featureSamples and could be different from the specified numSplits. The numSplits attribute in the DecisionTreeMetadata class will be set accordingly.
featureSamples - Feature values of each sample.
metadata - Decision tree metadata. NOTE: metadata.numbins will be changed accordingly if there are not enough splits to be found.
featureIndex - Feature index to find splits.
public DecisionTreeModel run(RDD<LabeledPoint> input)
input - Training data: RDD of LabeledPoint.
public DecisionTreeModel train(RDD<LabeledPoint> input)
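The note on findSplitsForContinuousFeature above says the number of returned splits depends on featureSamples and may differ from the requested numSplits. The sketch below illustrates why that can happen: when a feature has fewer distinct values than requested splits, there are simply not enough candidate thresholds. This is a simplified stand-in, not Spark's algorithm; the class ContinuousSplitSketch, its method, and the evenly spaced selection rule are all assumptions made for illustration.

```java
import java.util.Arrays;

/** Illustrative sketch only; the selection rule is a simplification, not Spark's. */
final class ContinuousSplitSketch {

    /** Returns at most numSplits candidate thresholds taken from the distinct sample values. */
    static double[] candidateSplits(double[] featureSamples, int numSplits) {
        if (numSplits <= 0) {
            throw new IllegalArgumentException("numSplits must be positive");
        }
        double[] distinct = Arrays.stream(featureSamples).sorted().distinct().toArray();
        if (distinct.length <= numSplits) {
            // Not enough distinct values: fewer splits than requested are returned.
            return distinct;
        }
        double[] splits = new double[numSplits];
        for (int i = 0; i < numSplits; i++) {
            // Pick evenly spaced positions in the sorted distinct values.
            int index = (int) ((long) (i + 1) * distinct.length / (numSplits + 1));
            splits[i] = distinct[index];
        }
        return splits;
    }

    public static void main(String[] args) {
        double[] fewValues = {1, 1, 2, 2, 3};
        // Five splits are requested, but only three distinct values exist.
        System.out.println(Arrays.toString(candidateSplits(fewValues, 5)));
    }
}
```

A caller that sizes other structures from the requested split count would need to use the length of the returned array instead, which is consistent with the documentation's remark that the metadata is updated accordingly.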