public class DecisionTreeRegressionModel extends PredictionModel<Vector,DecisionTreeRegressionModel> implements DecisionTreeModel, DecisionTreeRegressorParams, MLWritable, scala.Serializable
Modifier and Type | Method and Description |
---|---|
DecisionTreeRegressionModel | copy(ParamMap extra): Creates a copy of this instance with the same UID and some extra params. |
Vector | featureImportances(): Estimate of the importance of each feature. |
static DecisionTreeRegressionModel | load(String path) |
int | numFeatures(): Returns the number of features the model was trained on. |
double | predict(Vector features): Predict label for the given features. |
static MLReader<DecisionTreeRegressionModel> | read() |
Node | rootNode(): Root of the decision tree. |
DecisionTreeRegressionModel | setVarianceCol(String value) |
String | toString(): Summary of the model. |
Dataset<Row> | transform(Dataset<?> dataset): Transforms dataset by reading from featuresCol, calling predict, and storing the predictions as a new column predictionCol. |
String | uid(): An immutable unique ID for the object and its derivatives. |
MLWriter | write(): Returns an MLWriter instance for this ML instance. |
setFeaturesCol, setPredictionCol, transformSchema
transform, transform, transform
depth, maxSplitFeatureIndex, numNodes, toDebugString
validateAndTransformSchema
cacheNodeIds, getCacheNodeIds, getMaxBins, getMaxDepth, getMaxMemoryInMB, getMinInfoGain, getMinInstancesPerNode, getOldStrategy, maxBins, maxDepth, maxMemoryInMB, minInfoGain, minInstancesPerNode, setCacheNodeIds, setCheckpointInterval, setMaxBins, setMaxDepth, setMaxMemoryInMB, setMinInfoGain, setMinInstancesPerNode, setSeed
getLabelCol, labelCol
featuresCol, getFeaturesCol
getPredictionCol, predictionCol
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
checkpointInterval, getCheckpointInterval
getImpurity, getOldImpurity, impurity, setImpurity
getVarianceCol, varianceCol
save
initializeLogging, initializeLogIfNecessary, initializeLogIfNecessary, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
public static MLReader<DecisionTreeRegressionModel> read()
public static DecisionTreeRegressionModel load(String path)
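public String uid()
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable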
public Node rootNode()
Description copied from interface: DecisionTreeModel
Root of the decision tree.
Specified by: rootNode in interface DecisionTreeModel
public int numFeatures()
Description copied from class: PredictionModel
Returns the number of features the model was trained on.
Overrides: numFeatures in class PredictionModel<Vector,DecisionTreeRegressionModel>
public DecisionTreeRegressionModel setVarianceCol(String value)
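public double predict(Vector features)
Description copied from class: PredictionModel
Predict label for the given features. This method is used to implement transform() and output predictionCol.
Specified by: predict in class PredictionModel<Vector,DecisionTreeRegressionModel>
Parameters: features - (undocumented)
public Dataset<Row> transform(Dataset<?> dataset)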
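Description copied from class: PredictionModel
Transforms dataset by reading from featuresCol, calling predict, and storing the predictions as a new column predictionCol.
Overrides: transform in class PredictionModel<Vector,DecisionTreeRegressionModel>
Parameters: dataset - input dataset
Returns: transformed dataset with predictionCol of type Double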
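public DecisionTreeRegressionModel copy(ParamMap extra)
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. See defaultCopy().
Specified by: copy in interface Params
Specified by: copy in class Model<DecisionTreeRegressionModel>
Parameters: extra - (undocumented)
public String toString()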
Description copied from interface: DecisionTreeModel
Summary of the model.
Specified by: toString in interface DecisionTreeModel
Specified by: toString in interface Identifiable
Overrides: toString in class Object
public Vector featureImportances()
Estimate of the importance of each feature.
This generalizes the idea of "Gini" importance to other losses, following the explanation of Gini importance from the "Random Forests" documentation by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
This feature importance is calculated as follows:
- importance(feature j) = sum (over nodes which split on feature j) of the gain, where the gain is scaled by the number of instances passing through the node
- Normalize the importances for the tree to sum to 1.
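The two steps above can be sketched in plain Java without any Spark dependency. This is a minimal illustration, not Spark's implementation: `SketchNode`, `FeatureImportanceSketch`, and `exampleTree()` are hypothetical names introduced here, and the per-node `gain` and `count` values are made up for the example.

```java
// Hypothetical, simplified node type for illustration only;
// this is NOT Spark's org.apache.spark.ml.tree.Node.
final class SketchNode {
    final int feature;      // index of the feature split on, or -1 for a leaf
    final double gain;      // impurity gain of the split (ignored for leaves)
    final long count;       // number of training instances reaching this node
    final SketchNode left;
    final SketchNode right;

    private SketchNode(int feature, double gain, long count,
                       SketchNode left, SketchNode right) {
        this.feature = feature;
        this.gain = gain;
        this.count = count;
        this.left = left;
        this.right = right;
    }

    static SketchNode leaf(long count) {
        return new SketchNode(-1, 0.0, count, null, null);
    }

    // An internal node sees every instance that reaches either child.
    static SketchNode split(int feature, double gain,
                            SketchNode left, SketchNode right) {
        return new SketchNode(feature, gain, left.count + right.count, left, right);
    }
}

final class FeatureImportanceSketch {
    // Step 1: sum gain * count over the nodes splitting on each feature.
    // Step 2: normalize so the importances sum to 1.
    static double[] importances(SketchNode root, int numFeatures) {
        double[] importance = new double[numFeatures];
        accumulate(root, importance);
        double total = 0.0;
        for (double v : importance) {
            total += v;
        }
        if (total > 0.0) {
            for (int j = 0; j < importance.length; j++) {
                importance[j] /= total;
            }
        }
        return importance;
    }

    private static void accumulate(SketchNode node, double[] importance) {
        if (node == null || node.feature < 0) {
            return; // leaves contribute nothing
        }
        importance[node.feature] += node.gain * node.count;
        accumulate(node.left, importance);
        accumulate(node.right, importance);
    }

    // Made-up tree: the root splits on feature 0, its left child on feature 1.
    static SketchNode exampleTree() {
        SketchNode inner = SketchNode.split(1, 0.25,
                SketchNode.leaf(25), SketchNode.leaf(15));
        return SketchNode.split(0, 0.5, inner, SketchNode.leaf(60));
    }

    public static void main(String[] args) {
        double[] importance = importances(exampleTree(), 3);
        System.out.println(java.util.Arrays.toString(importance));
    }
}
```

In this made-up tree the root contributes 0.5 × 100 = 50 to feature 0 and the inner node contributes 0.25 × 40 = 10 to feature 1, so after normalization feature 0 receives 50/60, feature 1 receives 10/60, and the unused feature 2 receives 0.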
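Note: Feature importance for single decision trees can have high variance due to correlated predictor variables. Consider using a RandomForestRegressor to determine feature importance instead.
public MLWriter write()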
Description copied from interface: MLWritable
Returns an MLWriter instance for this ML instance.
Specified by: write in interface MLWritable