public abstract class ClassificationModel<FeaturesType,M extends ClassificationModel<FeaturesType,M>> extends PredictionModel<FeaturesType,M>
Model produced by a Classifier.
Classes are indexed {0, 1, ..., numClasses - 1}.
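Since labels are carried as double values, a label only identifies a class when it is a whole number between 0 and numClasses - 1. The sketch below expresses that check in plain Java. LabelIndexCheck and isValidLabel are hypothetical names chosen for illustration and are not part of the Spark API.

```java
// Hypothetical helper, not part of Spark: checks that a double-valued label
// names one of the classes {0, 1, ..., numClasses - 1}.
final class LabelIndexCheck {
    private LabelIndexCheck() {}

    static boolean isValidLabel(double label, int numClasses) {
        // Reject NaN, infinities, and fractional values such as 1.5.
        if (Double.isNaN(label) || Double.isInfinite(label) || label != Math.rint(label)) {
            return false;
        }
        // Whole numbers must also fall inside the class index range.
        return label >= 0 && label < numClasses;
    }
}
```

For a three-class model, isValidLabel(2.0, 3) is true, while 3.0, 1.5, and -1.0 are all rejected.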
Constructor and Description |
---|
ClassificationModel() |
Modifier and Type | Method and Description |
---|---|
Param<java.lang.String> | featuresCol(): Param for features column name. |
java.lang.String | getFeaturesCol() |
java.lang.String | getLabelCol() |
java.lang.String | getPredictionCol() |
java.lang.String | getRawPredictionCol() |
Param<java.lang.String> | labelCol(): Param for label column name. |
abstract int | numClasses(): Number of classes (values which the label can take). |
protected double | predict(FeaturesType features): Predict label for the given features. |
Param<java.lang.String> | predictionCol(): Param for prediction column name. |
protected abstract Vector | predictRaw(FeaturesType features): Raw prediction for each possible label. |
protected double | raw2prediction(Vector rawPrediction): Given a vector of raw predictions, select the predicted label. |
Param<java.lang.String> | rawPredictionCol(): Param for raw prediction (a.k.a. confidence) column name. |
M | setRawPredictionCol(java.lang.String value) |
DataFrame | transform(DataFrame dataset): Transforms dataset by reading from featuresCol and appending new columns as specified by parameters: predicted labels as predictionCol of type Double, and raw predictions (confidences) as rawPredictionCol of type Vector. |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType): Validates and transforms the input schema with the provided param map. |
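As a rough illustration of what validateAndTransformSchema is described as doing, the sketch below validates an input schema and returns a new one with the prediction and raw prediction columns appended. It assumes a simplified schema (an ordered map from column name to a type-name string instead of a StructType), and the particular checks shown (features column type, a double label column only when fitting, no overwriting of existing columns) are assumptions for illustration rather than a transcription of Spark's implementation. SchemaSketch and all of its parameters are hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in for StructType: column name -> type name, in column order.
final class SchemaSketch {
    private SchemaSketch() {}

    static Map<String, String> validateAndTransformSchema(
            Map<String, String> schema,
            boolean fitting,
            String featuresDataType,
            String featuresCol,
            String labelCol,
            String predictionCol,
            String rawPredictionCol) {
        // The features column must exist and have the expected type.
        String actual = schema.get(featuresCol);
        if (!featuresDataType.equals(actual)) {
            throw new IllegalArgumentException(
                "Column " + featuresCol + " must be of type " + featuresDataType + " but was " + actual);
        }
        // Assumed: only fitting (training) needs the label column.
        if (fitting && !"double".equals(schema.get(labelCol))) {
            throw new IllegalArgumentException("Column " + labelCol + " must be of type double");
        }
        // Copy the input so the caller's schema is left unchanged.
        Map<String, String> out = new LinkedHashMap<>(schema);
        appendColumn(out, predictionCol, "double");
        appendColumn(out, rawPredictionCol, "vector");
        return out;
    }

    // Appends a column, refusing to overwrite one that already exists.
    private static void appendColumn(Map<String, String> schema, String name, String type) {
        if (schema.containsKey(name)) {
            throw new IllegalArgumentException("Column " + name + " already exists.");
        }
        schema.put(name, type);
    }
}
```

Returning a copy rather than modifying the input keeps the sketch close to the StructType-in, StructType-out signature above.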
Methods inherited from class PredictionModel: featuresDataType, numFeatures, setFeaturesCol, setPredictionCol, transformImpl, transformSchema
Methods inherited from class Transformer: transform, transform, transform
Methods inherited from class PipelineStage: transformSchema
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Params: clear, copy, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams
Methods inherited from interface Identifiable: toString, uid
Methods inherited from interface Logging: initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
public M setRawPredictionCol(java.lang.String value)
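setRawPredictionCol returns M rather than ClassificationModel because of the recursive type parameter M extends ClassificationModel<FeaturesType,M> in the class declaration: each concrete model passes its own type as M, so chained setter calls keep the concrete type. The sketch below reproduces that pattern in plain Java. MiniModel and ThreeClassModel are hypothetical classes, and returning this through an unchecked cast to M is an assumed idiom, not Spark's actual implementation.

```java
// Hypothetical miniature of the recursive ("F-bounded") generic pattern.
abstract class MiniModel<F, M extends MiniModel<F, M>> {
    private String rawPredictionCol = "rawPrediction";
    private String predictionCol = "prediction";

    @SuppressWarnings("unchecked")
    private M self() {
        // Safe as long as every concrete subclass passes itself as M.
        return (M) this;
    }

    public M setRawPredictionCol(String value) {
        this.rawPredictionCol = value;
        return self();
    }

    public M setPredictionCol(String value) {
        this.predictionCol = value;
        return self();
    }

    public String getRawPredictionCol() { return rawPredictionCol; }

    public String getPredictionCol() { return predictionCol; }

    public abstract int numClasses();
}

// A concrete model names itself as M, so its setters return ThreeClassModel.
final class ThreeClassModel extends MiniModel<double[], ThreeClassModel> {
    @Override
    public int numClasses() { return 3; }
}
```

With this shape, new ThreeClassModel().setRawPredictionCol("scores").setPredictionCol("predicted") has static type ThreeClassModel, with no cast at the call site.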
public abstract int numClasses()
Number of classes (values which the label can take).
public DataFrame transform(DataFrame dataset)
Transforms dataset by reading from featuresCol, and appending new columns as specified by parameters:
- predicted labels as predictionCol of type Double
- raw predictions (confidences) as rawPredictionCol of type Vector.
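Row by row, the transformation described above computes a raw prediction from the features and derives the predicted label from that raw prediction. The sketch below models the dataset as a list of feature arrays instead of a DataFrame and always appends both outputs. RowScorer, ScoredRow, and TransformSketch are hypothetical, with RowScorer's two methods standing in for predictRaw and raw2prediction.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins for predictRaw and raw2prediction.
interface RowScorer {
    double[] predictRaw(double[] features);

    double raw2prediction(double[] rawPrediction);
}

// One output row: the input features plus the two appended values.
final class ScoredRow {
    final double[] features;
    final double[] rawPrediction; // would populate rawPredictionCol (Vector)
    final double prediction;      // would populate predictionCol (Double)

    ScoredRow(double[] features, double[] rawPrediction, double prediction) {
        this.features = features;
        this.rawPrediction = rawPrediction;
        this.prediction = prediction;
    }
}

final class TransformSketch {
    private TransformSketch() {}

    static List<ScoredRow> transform(List<double[]> dataset, RowScorer model) {
        List<ScoredRow> out = new ArrayList<>(dataset.size());
        for (double[] features : dataset) {
            // Compute the raw scores once, then derive the label from them.
            double[] raw = model.predictRaw(features);
            out.add(new ScoredRow(features, raw, model.raw2prediction(raw)));
        }
        return out;
    }
}
```

Deriving the label from the already computed raw scores avoids evaluating predictRaw twice per row; this is a design choice of the sketch.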
Overrides:
transform in class PredictionModel<FeaturesType,M extends ClassificationModel<FeaturesType,M>>
Parameters:
dataset - input dataset

protected double predict(FeaturesType features)
Predict label for the given features. This internal method is used to implement transform() and output predictionCol.
This default implementation for classification predicts the index of the maximum value from predictRaw().
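The default rule described above is an argmax over the raw predictions. The minimal sketch below assumes a plain double[] in place of Spark's Vector and resolves ties toward the lower class index, which is an assumed convention rather than documented behavior. Assuming raw2prediction implements the same rule, the default can be read as predict(features) = raw2prediction(predictRaw(features)). ArgmaxRule is a hypothetical name.

```java
// Hypothetical sketch of the default classification rule:
// prediction = index of the largest raw score, returned as a double label.
final class ArgmaxRule {
    private ArgmaxRule() {}

    static double raw2prediction(double[] rawPrediction) {
        if (rawPrediction.length == 0) {
            throw new IllegalArgumentException("rawPrediction must not be empty");
        }
        int best = 0;
        for (int i = 1; i < rawPrediction.length; i++) {
            // Strict '>' keeps the first (lowest) index on ties.
            if (rawPrediction[i] > rawPrediction[best]) {
                best = i;
            }
        }
        return best; // class indices run from 0 to numClasses - 1
    }
}
```

For example, raw scores {-1.2, 3.4, 0.5} yield the label 1.0.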
Overrides:
predict in class PredictionModel<FeaturesType,M extends ClassificationModel<FeaturesType,M>>
Parameters:
features - (undocumented)

protected abstract Vector predictRaw(FeaturesType features)
Raw prediction for each possible label. This internal method is used to implement transform() and output rawPredictionCol.
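Because predictRaw is abstract, each concrete model decides what its raw predictions are. As a hypothetical example, a linear multiclass scorer could return one score per class, rawPrediction[k] = w_k · x + b_k, where a larger value means more support for class k. LinearRawScorer, its weight matrix, and its intercepts are assumptions for illustration, not a Spark model.

```java
// Hypothetical linear scorer: one raw score per class,
// rawPrediction[k] = dot(weights[k], features) + intercepts[k].
final class LinearRawScorer {
    private final double[][] weights;  // numClasses x numFeatures
    private final double[] intercepts; // one per class

    LinearRawScorer(double[][] weights, double[] intercepts) {
        if (weights.length != intercepts.length) {
            throw new IllegalArgumentException("need one intercept per class");
        }
        this.weights = weights;
        this.intercepts = intercepts;
    }

    int numClasses() {
        return intercepts.length;
    }

    double[] predictRaw(double[] features) {
        double[] raw = new double[numClasses()];
        for (int k = 0; k < raw.length; k++) {
            if (weights[k].length != features.length) {
                throw new IllegalArgumentException("feature dimension mismatch");
            }
            // Start from the intercept and add the weighted features.
            double score = intercepts[k];
            for (int j = 0; j < features.length; j++) {
                score += weights[k][j] * features[j];
            }
            raw[k] = score;
        }
        return raw;
    }
}
```

With weights {{1, 0}, {0, 1}, {-1, -1}}, intercepts {0, 0, 0.5}, and features {2, 3}, predictRaw returns {2.0, 3.0, -4.5}, whose largest entry is at class 1.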
Parameters:
features - (undocumented)

protected double raw2prediction(Vector rawPrediction)
Given a vector of raw predictions, select the predicted label.
Parameters:
rawPrediction - (undocumented)
public Param<java.lang.String> rawPredictionCol()
Param for raw prediction (a.k.a. confidence) column name.
public java.lang.String getRawPredictionCol()
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether the schema is being validated during fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.

public Param<java.lang.String> labelCol()
Param for label column name.
public java.lang.String getLabelCol()
public Param<java.lang.String> featuresCol()
Param for features column name.
public java.lang.String getFeaturesCol()
public Param<java.lang.String> predictionCol()
Param for prediction column name.
public java.lang.String getPredictionCol()