public class GradientBoostedTreesModel extends Object implements Saveable
Parameters:
- algo: algorithm for the ensemble model, either Classification or Regression.
- trees: the trees in the ensemble.
- treeWeights: the weight of each tree in the ensemble.
Constructor and Description |
---|
GradientBoostedTreesModel(scala.Enumeration.Value algo, DecisionTreeModel[] trees, double[] treeWeights) |
Modifier and Type | Method and Description |
---|---|
scala.Enumeration.Value | algo() |
static RDD<scala.Tuple2<Object,Object>> | computeInitialPredictionAndError(RDD<LabeledPoint> data, double initTreeWeight, DecisionTreeModel initTree, Loss loss) :: DeveloperApi :: Compute the initial predictions and errors for a dataset for the first iteration of gradient boosting. |
double[] | evaluateEachIteration(RDD<LabeledPoint> data, Loss loss) Compute the error or loss for every iteration of gradient boosting. |
static GradientBoostedTreesModel | load(SparkContext sc, String path) |
static int | numTrees() |
int | numTrees() Get the number of trees in the ensemble. |
static JavaRDD<Double> | predict(JavaRDD<Vector> features) |
JavaRDD<Double> | predict(JavaRDD<Vector> features) Java-friendly version of org.apache.spark.mllib.tree.model.TreeEnsembleModel.predict. |
static RDD<Object> | predict(RDD<Vector> features) |
RDD<Object> | predict(RDD<Vector> features) Predict values for the given data set. |
static double | predict(Vector features) |
double | predict(Vector features) Predict the value for a single data point using the trained model. |
void | save(SparkContext sc, String path) Save this model to the given path. |
static String | toDebugString() |
String | toDebugString() Print the full model to a string. |
static String | toString() |
String | toString() Print a summary of the model. |
static int | totalNumNodes() |
int | totalNumNodes() Get the total number of nodes, summed over all trees in the ensemble. |
DecisionTreeModel[] | trees() |
double[] | treeWeights() |
static RDD<scala.Tuple2<Object,Object>> | updatePredictionError(RDD<LabeledPoint> data, RDD<scala.Tuple2<Object,Object>> predictionAndError, double treeWeight, DecisionTreeModel tree, Loss loss) :: DeveloperApi :: Update a zipped predictionError RDD (as obtained with computeInitialPredictionAndError). |
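The two :: DeveloperApi :: helpers, computeInitialPredictionAndError and updatePredictionError, maintain a running (prediction, error) pair for every training point as trees are added. The following is a minimal plain-Java sketch of that bookkeeping, not Spark's implementation: it assumes arrays in place of RDDs, a hypothetical BoostingBookkeeping class, trees modelled as ToDoubleFunction<double[]>, and squared error, (label - prediction)^2, as the loss.

```java
import java.util.function.ToDoubleFunction;

// Plain-Java sketch of the (prediction, error) bookkeeping described for
// computeInitialPredictionAndError / updatePredictionError. Not Spark code:
// arrays stand in for RDDs, a "tree" is any ToDoubleFunction<double[]>,
// and squared error is assumed as the loss.
final class BoostingBookkeeping {

    // Assumed loss: squared error, (label - prediction)^2.
    static double squaredError(double prediction, double label) {
        double err = label - prediction;
        return err * err;
    }

    // Initial state: prediction = initTreeWeight * initTree(x).
    // Row i holds {prediction, error} for sample i.
    static double[][] initial(double[][] features, double[] labels,
                              double initTreeWeight,
                              ToDoubleFunction<double[]> initTree) {
        double[][] predAndErr = new double[features.length][2];
        for (int i = 0; i < features.length; i++) {
            double pred = initTreeWeight * initTree.applyAsDouble(features[i]);
            predAndErr[i][0] = pred;
            predAndErr[i][1] = squaredError(pred, labels[i]);
        }
        return predAndErr;
    }

    // One boosting step: add treeWeight * tree(x) to each running prediction,
    // then recompute the error against the label. Returns a new array.
    static double[][] update(double[][] features, double[] labels,
                             double[][] predAndErr, double treeWeight,
                             ToDoubleFunction<double[]> tree) {
        double[][] next = new double[features.length][2];
        for (int i = 0; i < features.length; i++) {
            double pred = predAndErr[i][0] + treeWeight * tree.applyAsDouble(features[i]);
            next[i][0] = pred;
            next[i][1] = squaredError(pred, labels[i]);
        }
        return next;
    }

    public static void main(String[] args) {
        double[][] x = {{1.0}, {2.0}};
        double[] y = {3.0, 5.0};
        // Hypothetical trees: the first predicts 2 * x[0], the second a constant 1.
        double[][] state = initial(x, y, 1.0, f -> 2.0 * f[0]);
        state = update(x, y, state, 0.5, f -> 1.0);
        for (double[] pe : state) {
            System.out.println("prediction=" + pe[0] + " error=" + pe[1]);
        }
    }
}
```

Running main prints prediction=2.5 error=0.25 and prediction=4.5 error=0.25: after the first tree both errors are 1.0, and the second tree, weighted by 0.5, moves each prediction 0.5 closer to its label. Keeping the running prediction next to the error means each new tree costs one pass over the data instead of re-evaluating every earlier tree.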
public GradientBoostedTreesModel(scala.Enumeration.Value algo, DecisionTreeModel[] trees, double[] treeWeights)

public static RDD<scala.Tuple2<Object,Object>> computeInitialPredictionAndError(RDD<LabeledPoint> data, double initTreeWeight, DecisionTreeModel initTree, Loss loss)
- data: training data.
- initTreeWeight: learning rate assigned to the first tree.
- initTree: first DecisionTreeModel.
- loss: evaluation metric.

public static RDD<scala.Tuple2<Object,Object>> updatePredictionError(RDD<LabeledPoint> data, RDD<scala.Tuple2<Object,Object>> predictionAndError, double treeWeight, DecisionTreeModel tree, Loss loss)
- data: training data.
- predictionAndError: predictionError RDD.
- treeWeight: learning rate.
- tree: tree used to update the prediction and error.
- loss: evaluation metric.

public static GradientBoostedTreesModel load(SparkContext sc, String path)
- sc: Spark context used for loading model files.
- path: path specifying the directory to which the model was saved.

public static double predict(Vector features)
public static String toString()
public static String toDebugString()
public static int numTrees()
public static int totalNumNodes()
public scala.Enumeration.Value algo()
public DecisionTreeModel[] trees()
public double[] treeWeights()
public void save(SparkContext sc, String path)
Save this model to the given path (description copied from interface Saveable).

This saves:
- human-readable (JSON) model metadata to path/metadata/
- Parquet-formatted data to path/data/

The model may be loaded using Loader.load.
public double[] evaluateEachIteration(RDD<LabeledPoint> data, Loss loss)
- data: RDD of LabeledPoint.
- loss: evaluation metric.

public double predict(Vector features)
- features: vector representing a single data point.

public RDD<Object> predict(RDD<Vector> features)
- features: RDD representing data points to be predicted.

public JavaRDD<Double> predict(JavaRDD<Vector> features)
Java-friendly version of org.apache.spark.mllib.tree.model.TreeEnsembleModel.predict.
- features: (undocumented)

public String toString()
Overrides: toString in class Object
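The predict overloads listed above reduce the trees' outputs to one value per data point, but this section does not spell out how. A minimal sketch, assuming (as we understand MLlib's summing strategy for boosted trees) that the ensemble output is the weighted sum of the tree predictions, returned as-is for Regression and thresholded at 0 into a 0.0/1.0 label for Classification. EnsemblePredictSketch, its Algo enum, and the function-typed trees are hypothetical stand-ins for the Spark types.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.ToDoubleFunction;

// Sketch of how a boosted ensemble turns per-tree outputs into one prediction.
// Algo mirrors the Classification/Regression choice; trees are plain functions.
final class EnsemblePredictSketch {

    enum Algo { CLASSIFICATION, REGRESSION }

    // Weighted sum of tree predictions: sum over i of weights[i] * trees[i](x).
    static double weightedSum(List<ToDoubleFunction<double[]>> trees,
                              double[] weights, double[] x) {
        double sum = 0.0;
        for (int i = 0; i < trees.size(); i++) {
            sum += weights[i] * trees.get(i).applyAsDouble(x);
        }
        return sum;
    }

    // Regression returns the sum itself; binary classification maps a
    // positive sum to label 1.0 and anything else to 0.0 (assumed threshold).
    static double predict(Algo algo, List<ToDoubleFunction<double[]>> trees,
                          double[] weights, double[] x) {
        double margin = weightedSum(trees, weights, x);
        if (algo == Algo.REGRESSION) {
            return margin;
        }
        return margin > 0.0 ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        List<ToDoubleFunction<double[]>> trees = new ArrayList<>();
        trees.add(f -> f[0]);  // hypothetical tree 1: returns the first feature
        trees.add(f -> -1.0);  // hypothetical tree 2: constant -1
        double[] weights = {1.0, 0.5};
        double[] x = {0.25};
        System.out.println(predict(Algo.REGRESSION, trees, weights, x));     // -0.25
        System.out.println(predict(Algo.CLASSIFICATION, trees, weights, x)); // 0.0
    }
}
```

For x = {0.25} the weighted sum is 0.25 - 0.5 = -0.25, so the regression output is -0.25 and the classification output is 0.0.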
public String toDebugString()
public int numTrees()
public int totalNumNodes()
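evaluateEachIteration reports the error or loss after every boosting iteration. A minimal sketch, assuming (this section does not state it) that entry m of the result is the mean loss over the data set of the ensemble truncated to its first m + 1 trees, with squared error as the loss. EachIterationSketch and its function-typed trees are hypothetical; the Spark method works on an RDD<LabeledPoint> and a Loss instead.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.ToDoubleFunction;

// Sketch of per-iteration evaluation: entry m is the mean loss of the ensemble
// truncated to its first m + 1 trees. Squared error is assumed as the loss.
final class EachIterationSketch {

    static double squaredError(double prediction, double label) {
        double err = label - prediction;
        return err * err;
    }

    static double[] evaluateEachIteration(List<ToDoubleFunction<double[]>> trees,
                                          double[] weights,
                                          double[][] features, double[] labels) {
        double[] totals = new double[trees.size()];
        for (int n = 0; n < features.length; n++) {
            double running = 0.0; // weighted sum of the trees seen so far
            for (int m = 0; m < trees.size(); m++) {
                running += weights[m] * trees.get(m).applyAsDouble(features[n]);
                totals[m] += squaredError(running, labels[n]);
            }
        }
        // Average each iteration's summed loss over the data set.
        for (int m = 0; m < totals.length; m++) {
            totals[m] /= features.length;
        }
        return totals;
    }

    public static void main(String[] args) {
        List<ToDoubleFunction<double[]>> trees = new ArrayList<>();
        trees.add(f -> 2.0 * f[0]); // hypothetical first tree
        trees.add(f -> 1.0);        // hypothetical second tree
        double[] losses = evaluateEachIteration(trees, new double[] {1.0, 0.5},
                new double[][] {{1.0}, {2.0}}, new double[] {3.0, 5.0});
        System.out.println(Arrays.toString(losses)); // [1.0, 0.25]
    }
}
```

Accumulating the weighted tree outputs in a running sum lets a single pass over each data point score every prefix of the ensemble, rather than re-predicting from scratch for each iteration count.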