public class GradientBoostedTrees
extends java.lang.Object
| Constructor and Description |
|---|
| GradientBoostedTrees() |
| Modifier and Type | Method | Description |
|---|---|---|
| static scala.Tuple2<DecisionTreeRegressionModel[],double[]> | boost(RDD<LabeledPoint> input, RDD<LabeledPoint> validationInput, BoostingStrategy boostingStrategy, boolean validate, long seed) | Internal method for performing regression using trees as base learners. |
| static double | computeError(RDD<LabeledPoint> data, DecisionTreeRegressionModel[] trees, double[] treeWeights, Loss loss) | Method to calculate error of the base learner for the gradient boosting calculation. |
| static RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> | computeInitialPredictionAndError(RDD<LabeledPoint> data, double initTreeWeight, DecisionTreeRegressionModel initTree, Loss loss) | Compute the initial predictions and errors for a dataset for the first iteration of gradient boosting. |
| static double[] | evaluateEachIteration(RDD<LabeledPoint> data, DecisionTreeRegressionModel[] trees, double[] treeWeights, Loss loss, scala.Enumeration.Value algo) | Method to compute error or loss for every iteration of gradient boosting. |
| protected static void | initializeLogIfNecessary(boolean isInterpreter) | |
| protected static boolean | isTraceEnabled() | |
| protected static org.slf4j.Logger | log() | |
| protected static void | logDebug(scala.Function0<java.lang.String> msg) | |
| protected static void | logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) | |
| protected static void | logError(scala.Function0<java.lang.String> msg) | |
| protected static void | logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) | |
| protected static void | logInfo(scala.Function0<java.lang.String> msg) | |
| protected static void | logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) | |
| protected static java.lang.String | logName() | |
| protected static void | logTrace(scala.Function0<java.lang.String> msg) | |
| protected static void | logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) | |
| protected static void | logWarning(scala.Function0<java.lang.String> msg) | |
| protected static void | logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) | |
| static scala.Tuple2<DecisionTreeRegressionModel[],double[]> | run(RDD<LabeledPoint> input, BoostingStrategy boostingStrategy, long seed) | Method to train a gradient boosting model. |
| static scala.Tuple2<DecisionTreeRegressionModel[],double[]> | runWithValidation(RDD<LabeledPoint> input, RDD<LabeledPoint> validationInput, BoostingStrategy boostingStrategy, long seed) | Method to validate a gradient boosting model. |
| static double | updatePrediction(Vector features, double prediction, DecisionTreeRegressionModel tree, double weight) | Add prediction from a new boosting iteration to an existing prediction. |
| static RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> | updatePredictionError(RDD<LabeledPoint> data, RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> predictionAndError, double treeWeight, DecisionTreeRegressionModel tree, Loss loss) | Update a zipped predictionError RDD (as obtained with computeInitialPredictionAndError). |
public static scala.Tuple2<DecisionTreeRegressionModel[],double[]> run(RDD<LabeledPoint> input, BoostingStrategy boostingStrategy, long seed)
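To our understanding of the Spark sources, run treats binary classification as regression by mapping labels in {0, 1} to {-1, +1} before boosting (the form margin-based losses such as Spark's log loss expect), while regression labels pass through unchanged. The Java sketch below illustrates only that label preparation step. It is not Spark's code: RunLabelRemapSketch, the Algo enum, and the use of plain arrays instead of RDD<LabeledPoint> are hypothetical stand-ins.

```java
// Hypothetical sketch of the label preparation assumed to happen in run():
// double[] stands in for the labels of an RDD<LabeledPoint>, and Algo stands in
// for the Classification/Regression setting carried by BoostingStrategy.
final class RunLabelRemapSketch {
    enum Algo { CLASSIFICATION, REGRESSION }

    /** Returns the labels that the boosting loop would be trained on. */
    static double[] prepareLabels(double[] labels, Algo algo) {
        if (algo == Algo.REGRESSION) {
            return labels.clone(); // regression targets are used unchanged
        }
        double[] remapped = new double[labels.length];
        for (int i = 0; i < labels.length; i++) {
            // Binary classification: 0 -> -1 and 1 -> +1.
            remapped[i] = labels[i] * 2.0 - 1.0;
        }
        return remapped;
    }
}
```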
Parameters:
input - Training dataset: RDD of LabeledPoint.
boostingStrategy - Boosting parameters.
seed - Random seed.

public static scala.Tuple2<DecisionTreeRegressionModel[],double[]> runWithValidation(RDD<LabeledPoint> input, RDD<LabeledPoint> validationInput, BoostingStrategy boostingStrategy, long seed)
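To our understanding, runWithValidation uses the validation dataset for early stopping: boosting halts once adding trees no longer improves the validation error by enough, and only the best prefix of trees is kept. The Java sketch below shows one way to pick that prefix from a sequence of validation errors. The stopping rule (stop when the improvement over the best error so far is below validationTol * max(currentError, 0.01)) is an assumption modeled on our reading of recent Spark sources, and EarlyStoppingSketch and treesToKeep are hypothetical names.

```java
// Hypothetical early-stopping sketch. validationErrors[i] is assumed to be the
// validation error of the ensemble built from the first i + 1 trees.
final class EarlyStoppingSketch {
    /** Returns how many leading trees to keep (0 if no errors are given). */
    static int treesToKeep(double[] validationErrors, double validationTol) {
        if (validationErrors.length == 0) {
            return 0;
        }
        double bestError = validationErrors[0];
        int bestM = 1;
        for (int m = 1; m < validationErrors.length; m++) {
            double current = validationErrors[m];
            // Assumed relative tolerance rule; stop at the first too-small improvement.
            if (bestError - current < validationTol * Math.max(current, 0.01)) {
                break;
            }
            if (current < bestError) {
                bestError = current;
                bestM = m + 1;
            }
        }
        return bestM;
    }
}
```

Note that the rule stops at the first insufficient improvement, even if a later iteration would have improved again.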
Parameters:
input - Training dataset: RDD of LabeledPoint.
validationInput - Validation dataset. This dataset should be different from the training dataset, but it should follow the same distribution. For example, these two datasets could be created from an original dataset by using org.apache.spark.rdd.RDD.randomSplit().
boostingStrategy - Boosting parameters.
seed - Random seed.

public static RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> computeInitialPredictionAndError(RDD<LabeledPoint> data, double initTreeWeight, DecisionTreeRegressionModel initTree, Loss loss)
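computeInitialPredictionAndError seeds the per-point state that later iterations update: each point gets the weighted prediction of the first tree and the loss of that prediction. The Java sketch below shows that computation on an in-memory list. It is a hedged illustration, not Spark's code: List replaces RDD, Tree, Loss and Point are hypothetical stand-ins for DecisionTreeRegressionModel, Loss and LabeledPoint, and the RDD of (Object, Object) tuples is modeled as {prediction, error} pairs, which is our reading of its contents.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: lists stand in for RDDs; Tree, Loss and Point are stand-ins.
final class InitialPredictionSketch {
    interface Tree { double predict(double[] features); }
    interface Loss { double computeError(double prediction, double label); }

    static final class Point {
        final double label;
        final double[] features;
        Point(double label, double[] features) { this.label = label; this.features = features; }
    }

    /** Returns one {prediction, error} pair per point, in input order. */
    static List<double[]> computeInitialPredictionAndError(
            List<Point> data, double initTreeWeight, Tree initTree, Loss loss) {
        List<double[]> out = new ArrayList<>(data.size());
        for (Point p : data) {
            // Start from 0.0 and add the weighted output of the first tree.
            double pred = initTreeWeight * initTree.predict(p.features);
            double err = loss.computeError(pred, p.label);
            out.add(new double[] {pred, err});
        }
        return out;
    }
}
```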
Parameters:
data - Training data.
initTreeWeight - Learning rate assigned to the first tree.
initTree - First tree of the ensemble (a DecisionTreeRegressionModel).
loss - Evaluation metric.

public static RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> updatePredictionError(RDD<LabeledPoint> data, RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> predictionAndError, double treeWeight, DecisionTreeRegressionModel tree, Loss loss)
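updatePredictionError advances the per-point (prediction, error) state by one boosting iteration: the new tree's weighted output is added to each stored prediction, and the error is recomputed against the label. The Java sketch below is a hedged illustration under stated assumptions: lists replace RDDs, pairing elements by index stands in for zipping the two RDDs, and Tree, Loss and Point are hypothetical stand-ins for DecisionTreeRegressionModel, Loss and LabeledPoint.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: index alignment of two lists stands in for RDD.zip.
final class UpdatePredictionErrorSketch {
    interface Tree { double predict(double[] features); }
    interface Loss { double computeError(double prediction, double label); }

    static final class Point {
        final double label;
        final double[] features;
        Point(double label, double[] features) { this.label = label; this.features = features; }
    }

    /** predictionAndError.get(i) is the {prediction, error} pair for data.get(i). */
    static List<double[]> updatePredictionError(List<Point> data, List<double[]> predictionAndError,
                                                double treeWeight, Tree tree, Loss loss) {
        if (data.size() != predictionAndError.size()) {
            throw new IllegalArgumentException("data and predictionAndError must be aligned");
        }
        List<double[]> updated = new ArrayList<>(data.size());
        for (int i = 0; i < data.size(); i++) {
            Point p = data.get(i);
            // Add the new tree's weighted output to the running prediction...
            double newPred = predictionAndError.get(i)[0] + treeWeight * tree.predict(p.features);
            // ...and recompute the error from the label; the old error is discarded.
            updated.add(new double[] {newPred, loss.computeError(newPred, p.label)});
        }
        return updated;
    }
}
```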
Parameters:
data - Training data.
predictionAndError - predictionError RDD (as obtained with computeInitialPredictionAndError).
treeWeight - Learning rate.
tree - Tree used to update the predictions and errors.
loss - Evaluation metric.

public static double updatePrediction(Vector features, double prediction, DecisionTreeRegressionModel tree, double weight)
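updatePrediction is the per-point building block of the ensemble: it adds the new tree's output, scaled by the tree weight, to the existing prediction. The Java sketch below expresses that as prediction + treeOutput * weight. It is a sketch under assumptions: double[] stands in for Vector, the Tree interface is a hypothetical stand-in for DecisionTreeRegressionModel, and predictEnsemble is a hypothetical helper showing that folding the update over all (tree, weight) pairs, starting from 0.0, yields the ensemble's prediction.

```java
// Hypothetical sketch: double[] stands in for Vector, Tree for DecisionTreeRegressionModel.
final class UpdatePredictionSketch {
    interface Tree { double predict(double[] features); }

    /** Adds one weighted tree output to an existing prediction for a single point. */
    static double updatePrediction(double[] features, double prediction, Tree tree, double weight) {
        return prediction + tree.predict(features) * weight;
    }

    /** Folds updatePrediction over all trees, starting from 0.0. */
    static double predictEnsemble(double[] features, Tree[] trees, double[] weights) {
        double prediction = 0.0;
        for (int m = 0; m < trees.length; m++) {
            prediction = updatePrediction(features, prediction, trees[m], weights[m]);
        }
        return prediction;
    }
}
```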
Parameters:
features - Vector of features representing a single data point.
prediction - The existing prediction.
tree - New decision tree model.
weight - Tree weight.

public static double computeError(RDD<LabeledPoint> data, DecisionTreeRegressionModel[] trees, double[] treeWeights, Loss loss)
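computeError reduces a set of trees and their weights to a single number. Assuming it averages, over all points, the loss of the weighted sum of every tree's prediction (our reading of a signature that takes all trees and weights at once), it can be sketched in Java as follows. List, Tree, Loss and Point are hypothetical stand-ins for the RDD and Spark types, and returning NaN for empty data is a choice made for this sketch.

```java
import java.util.List;

// Hypothetical sketch of a mean-loss computation over a weighted tree ensemble.
final class ComputeErrorSketch {
    interface Tree { double predict(double[] features); }
    interface Loss { double computeError(double prediction, double label); }

    static final class Point {
        final double label;
        final double[] features;
        Point(double label, double[] features) { this.label = label; this.features = features; }
    }

    /** Mean loss of the weighted ensemble over all points (NaN for empty data). */
    static double computeError(List<Point> data, Tree[] trees, double[] treeWeights, Loss loss) {
        if (trees.length != treeWeights.length) {
            throw new IllegalArgumentException("one weight per tree is required");
        }
        if (data.isEmpty()) {
            return Double.NaN; // the mean of an empty set is undefined
        }
        double total = 0.0;
        for (Point p : data) {
            double prediction = 0.0;
            for (int m = 0; m < trees.length; m++) {
                prediction += treeWeights[m] * trees[m].predict(p.features);
            }
            total += loss.computeError(prediction, p.label);
        }
        return total / data.size();
    }
}
```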
Parameters:
data - Training dataset: RDD of LabeledPoint.
trees - Boosted decision tree models.
treeWeights - Learning rates at each boosting iteration.
loss - Evaluation metric.

public static double[] evaluateEachIteration(RDD<LabeledPoint> data, DecisionTreeRegressionModel[] trees, double[] treeWeights, Loss loss, scala.Enumeration.Value algo)
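evaluateEachIteration returns one error value per boosting iteration, which helps decide how many trees are worth keeping. The Java sketch below computes, for each m, the mean loss of the ensemble made of the first m + 1 trees, in a single pass over the data by accumulating each point's prediction tree by tree. The {0, 1} to {-1, +1} label remapping for classification is an assumption based on how Spark's boosting treats binary labels; the boolean flag replaces the algo argument, and List, Tree, Loss and Point are hypothetical stand-ins.

```java
import java.util.List;

// Hypothetical sketch: errors[m] is the mean loss using only trees 0..m.
final class EvaluateEachIterationSketch {
    interface Tree { double predict(double[] features); }
    interface Loss { double computeError(double prediction, double label); }

    static final class Point {
        final double label;
        final double[] features;
        Point(double label, double[] features) { this.label = label; this.features = features; }
    }

    static double[] evaluateEachIteration(List<Point> data, Tree[] trees, double[] treeWeights,
                                          Loss loss, boolean classification) {
        double[] errors = new double[trees.length];
        for (Point p : data) {
            // Assumed remapping of binary labels for classification: 0 -> -1, 1 -> +1.
            double label = classification ? p.label * 2.0 - 1.0 : p.label;
            double prediction = 0.0;
            for (int m = 0; m < trees.length; m++) {
                prediction += treeWeights[m] * trees[m].predict(p.features); // running prefix sum
                errors[m] += loss.computeError(prediction, label);
            }
        }
        for (int m = 0; m < errors.length; m++) {
            errors[m] /= data.size(); // sum of losses -> mean loss (NaN if data is empty)
        }
        return errors;
    }
}
```

For regression data, the last element equals the mean loss of the full ensemble, so it can double as a consistency check against computeError.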
Parameters:
data - RDD of LabeledPoint.
trees - Boosted decision tree models.
treeWeights - Learning rates at each boosting iteration.
loss - Evaluation metric.
algo - Algorithm for the ensemble, either Classification or Regression.

public static scala.Tuple2<DecisionTreeRegressionModel[],double[]> boost(RDD<LabeledPoint> input, RDD<LabeledPoint> validationInput, BoostingStrategy boostingStrategy, boolean validate, long seed)
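The core of boost can be pictured as the standard gradient boosting loop: fit a tree to the pseudo-residuals (negative gradients) of the current ensemble, add it with a weight, update the running predictions, and repeat. The Java sketch below is a self-contained illustration under loudly stated assumptions, not Spark's implementation: BoostSketch, Stump and Ensemble are hypothetical; the base learner is a one-feature regression stump rather than a DecisionTreeRegressionModel; the loss is half squared error 0.5 * (label - pred)^2 (Spark's SquaredError is (y - F(x))^2, whose negative gradient carries an extra factor of 2); the weighting (1.0 for the first tree, the learning rate afterwards) reflects our reading of the Spark sources; and validation-based early stopping is omitted.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical gradient boosting sketch; see the assumptions listed above.
final class BoostSketch {
    /** One-split regression tree on a single feature: stand-in for a decision tree. */
    static final class Stump {
        final double threshold, left, right;

        Stump(double threshold, double left, double right) {
            this.threshold = threshold;
            this.left = left;
            this.right = right;
        }

        double predict(double x) {
            return x <= threshold ? left : right;
        }

        /** Picks the threshold, among observed values, that minimizes squared error. */
        static Stump fit(double[] xs, double[] targets) {
            Stump best = null;
            double bestSse = Double.POSITIVE_INFINITY;
            for (double t : xs) {
                double sumL = 0.0, sumR = 0.0;
                int nL = 0, nR = 0;
                for (int i = 0; i < xs.length; i++) {
                    if (xs[i] <= t) { sumL += targets[i]; nL++; } else { sumR += targets[i]; nR++; }
                }
                double meanL = sumL / nL;                  // nL >= 1 because t is one of xs
                double meanR = nR > 0 ? sumR / nR : meanL; // empty right side: reuse left mean
                double sse = 0.0;
                for (int i = 0; i < xs.length; i++) {
                    double d = targets[i] - (xs[i] <= t ? meanL : meanR);
                    sse += d * d;
                }
                if (sse < bestSse) {
                    bestSse = sse;
                    best = new Stump(t, meanL, meanR);
                }
            }
            return best;
        }
    }

    /** Fitted trees and their weights, mirroring the (trees, weights) pair that boost returns. */
    static final class Ensemble {
        final List<Stump> trees = new ArrayList<>();
        final List<Double> weights = new ArrayList<>();

        double predict(double x) {
            double sum = 0.0;
            for (int m = 0; m < trees.size(); m++) {
                sum += weights.get(m) * trees.get(m).predict(x);
            }
            return sum;
        }
    }

    static Ensemble boost(double[] xs, double[] labels, int numIterations, double learningRate) {
        if (xs.length == 0 || xs.length != labels.length) {
            throw new IllegalArgumentException("need non-empty, aligned xs and labels");
        }
        Ensemble model = new Ensemble();
        double[] pred = new double[xs.length]; // running ensemble prediction, starts at 0.0
        for (int m = 0; m < numIterations; m++) {
            // Negative gradient of 0.5 * (label - pred)^2 is (label - pred).
            // On the first pass pred is 0, so the first tree is fit to the labels.
            double[] residuals = new double[xs.length];
            for (int i = 0; i < xs.length; i++) {
                residuals[i] = labels[i] - pred[i];
            }
            Stump tree = Stump.fit(xs, residuals);
            double weight = (m == 0) ? 1.0 : learningRate;
            model.trees.add(tree);
            model.weights.add(weight);
            // Per-point update in the spirit of updatePrediction.
            for (int i = 0; i < xs.length; i++) {
                pred[i] += weight * tree.predict(xs[i]);
            }
        }
        return model;
    }
}
```

Because each stump is a least-squares fit to the current residuals and every weight lies in (0, 1], the training error of this sketch never increases from one iteration to the next.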
Parameters:
input - Training dataset.
validationInput - Validation dataset, ignored if validate is set to false.
boostingStrategy - Boosting parameters.
validate - Whether or not to use the validation dataset.
seed - Random seed.

protected static java.lang.String logName()
protected static org.slf4j.Logger log()
protected static void logInfo(scala.Function0<java.lang.String> msg)
protected static void logDebug(scala.Function0<java.lang.String> msg)
protected static void logTrace(scala.Function0<java.lang.String> msg)
protected static void logWarning(scala.Function0<java.lang.String> msg)
protected static void logError(scala.Function0<java.lang.String> msg)
protected static void logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static boolean isTraceEnabled()
protected static void initializeLogIfNecessary(boolean isInterpreter)