public class LogisticRegressionModel extends ProbabilisticClassificationModel<Vector,LogisticRegressionModel> implements MLWritable
Model produced by LogisticRegression.
Modifier and Type | Method and Description |
---|---|
void | checkThresholdConsistency(): If threshold and thresholds are both set, ensures they are consistent. |
Vector | coefficients() |
LogisticRegressionModel | copy(ParamMap extra): Creates a copy of this instance with the same UID and some extra params. |
Param<java.lang.String> | featuresCol(): Param for features column name. |
java.lang.String | getFeaturesCol() |
java.lang.String | getLabelCol() |
java.lang.String | getPredictionCol() |
java.lang.String | getRawPredictionCol() |
double | getThreshold(): Get threshold for binary classification. |
double[] | getThresholds(): Get thresholds for binary or multiclass classification. |
boolean | hasSummary(): Indicates whether a training summary exists for this model instance. |
double | intercept() |
Param<java.lang.String> | labelCol(): Param for label column name. |
static LogisticRegressionModel | load(java.lang.String path) |
int | numClasses(): Number of classes (values which the label can take). |
int | numFeatures(): Returns the number of features the model was trained on. |
protected double | predict(Vector features): Predict label for the given feature vector. |
Param<java.lang.String> | predictionCol(): Param for prediction column name. |
protected Vector | predictRaw(Vector features): Raw prediction for each possible label. |
protected double | probability2prediction(Vector probability): Given a vector of class conditional probabilities, select the predicted label. |
protected double | raw2prediction(Vector rawPrediction): Given a vector of raw predictions, select the predicted label. |
protected Vector | raw2probabilityInPlace(Vector rawPrediction): Estimate the probability of each class given the raw prediction, doing the computation in-place. |
Param<java.lang.String> | rawPredictionCol(): Param for raw prediction (a.k.a. confidence) column name. |
static MLReader<LogisticRegressionModel> | read() |
LogisticRegressionModel | setThreshold(double value): Set threshold in binary classification, in range [0, 1]. |
LogisticRegressionModel | setThresholds(double[] value): Set thresholds in multiclass (or binary) classification to adjust the probability of predicting each class. |
LogisticRegressionTrainingSummary | summary(): Gets summary of model on training set. |
java.lang.String | uid(): An immutable unique ID for the object and its derivatives. |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType): Validates and transforms the input schema with the provided param map. |
void | validateParams() |
Vector | weights() |
MLWriter | write(): Returns an MLWriter instance for this ML instance. |
normalizeToProbabilitiesInPlace, predictProbability, raw2probability, setProbabilityCol, transform
setRawPredictionCol
featuresDataType, setFeaturesCol, setPredictionCol, transformImpl, transformSchema
transform, transform, transform
transformSchema
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
toString
save
initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
public static MLReader<LogisticRegressionModel> read()
public static LogisticRegressionModel load(java.lang.String path)
public java.lang.String uid()
An immutable unique ID for the object and its derivatives.
uid in interface Identifiable
public Vector coefficients()
public double intercept()
public Vector weights()
public LogisticRegressionModel setThreshold(double value)
If the estimated probability of class label 1 is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 more often; a low threshold encourages the model to predict 1 more often.
Note: Calling this with threshold p is equivalent to calling setThresholds(Array(1-p, p)). When setThreshold() is called, any user-set value for thresholds will be cleared. If both threshold and thresholds are set in a ParamMap, then they must be equivalent.
Default is 0.5.
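The decision rule and the stated equivalence with setThresholds can be illustrated with a small standalone sketch. This is plain Java, not Spark code; the class name BinaryThresholdSketch and both method names are hypothetical and exist only for illustration.

```java
import java.util.Arrays;

/** Standalone sketch of the binary threshold rule; not part of Spark. */
final class BinaryThresholdSketch {

    /** Predicts 1.0 when the estimated probability of class 1 exceeds the threshold, else 0.0. */
    static double predict(double probabilityOfClass1, double threshold) {
        return probabilityOfClass1 > threshold ? 1.0 : 0.0;
    }

    /** Builds the thresholds array that the note above describes as equivalent to threshold p. */
    static double[] equivalentThresholds(double p) {
        if (p < 0.0 || p > 1.0) {
            throw new IllegalArgumentException("threshold must be in [0, 1], got " + p);
        }
        return new double[] {1.0 - p, p};
    }

    public static void main(String[] args) {
        // Same probability, two thresholds: the higher threshold flips the prediction to 0.
        System.out.println(predict(0.6, 0.5));
        System.out.println(predict(0.6, 0.8));
        System.out.println(Arrays.toString(equivalentThresholds(0.25)));
    }
}
```

Because the comparison is strict, a probability exactly equal to the threshold yields 0 in this sketch.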
value - (undocumented)
public double getThreshold()
If threshold is set, returns that value. Otherwise, if thresholds is set with length 2 (i.e., binary classification), this returns the equivalent threshold: 1 / (1 + thresholds(0) / thresholds(1)). Otherwise, returns the threshold default value.
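As a worked illustration of the conversion formula, the following standalone sketch computes the equivalent binary threshold from a two-element thresholds array. It is plain Java rather than Spark code, and the names ThresholdFromThresholdsSketch and toThreshold are hypothetical. For thresholds of the form (1-p, p) the formula reduces to p, since 1 / (1 + (1-p)/p) = p.

```java
/** Standalone sketch of the thresholds-to-threshold conversion; not part of Spark. */
final class ThresholdFromThresholdsSketch {

    /** Returns 1 / (1 + thresholds[0] / thresholds[1]) for a two-element array. */
    static double toThreshold(double[] thresholds) {
        if (thresholds.length != 2) {
            throw new IllegalArgumentException(
                "expected thresholds of length 2, got " + thresholds.length);
        }
        return 1.0 / (1.0 + thresholds[0] / thresholds[1]);
    }

    public static void main(String[] args) {
        // (1-p, p) with p = 0.25 maps back to 0.25.
        System.out.println(toThreshold(new double[] {0.75, 0.25}));
    }
}
```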
IllegalArgumentException - if thresholds is set to an array of length other than 2.
public LogisticRegressionModel setThresholds(double[] value)
Note: When setThresholds() is called, any user-set value for threshold will be cleared. If both threshold and thresholds are set in a ParamMap, then they must be equivalent.
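One way to read "equivalent" is that the threshold implied by thresholds, 1 / (1 + thresholds(0) / thresholds(1)), matches the explicitly set threshold. The standalone sketch below checks that condition under stated assumptions: the comparison tolerance of 1e-5 is an assumption (this page does not specify one), a null argument stands in for "not set", and the names ThresholdConsistencySketch and checkConsistent are hypothetical. It is plain Java, not Spark code.

```java
/** Standalone sketch of a threshold/thresholds consistency check; not part of Spark. */
final class ThresholdConsistencySketch {

    // Assumed tolerance for comparing the two settings; the documentation does not give one.
    static final double TOLERANCE = 1e-5;

    /**
     * Throws IllegalArgumentException if both values are set (non-null) and do not
     * describe the same binary decision rule. A null argument means "not set".
     */
    static void checkConsistent(Double threshold, double[] thresholds) {
        if (threshold == null || thresholds == null) {
            return; // at most one is set, so there is nothing to compare
        }
        if (thresholds.length != 2) {
            throw new IllegalArgumentException(
                "thresholds must have length 2 when threshold is also set, got " + thresholds.length);
        }
        double implied = 1.0 / (1.0 + thresholds[0] / thresholds[1]);
        if (Math.abs(implied - threshold) > TOLERANCE) {
            throw new IllegalArgumentException(
                "threshold (" + threshold + ") and thresholds (implying " + implied + ") are not equivalent");
        }
    }

    public static void main(String[] args) {
        checkConsistent(0.25, new double[] {0.75, 0.25}); // consistent, returns normally
        System.out.println("consistent");
    }
}
```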
setThresholds in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
value - (undocumented)
public double[] getThresholds()
If thresholds is set, return its value. Otherwise, if threshold is set, return the equivalent thresholds for binary classification: (1-threshold, threshold). If neither is set, throw an exception.
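The fallback order described here can be sketched as a small standalone function. This is plain Java, not Spark code: null stands in for "not set", the names GetThresholdsSketch and resolveThresholds are hypothetical, and the exception type is an assumption, since the text above only says that an exception is thrown.

```java
import java.util.NoSuchElementException;

/** Standalone sketch of the getThresholds fallback order; not part of Spark. */
final class GetThresholdsSketch {

    /**
     * Returns thresholds if set; otherwise derives (1 - threshold, threshold) from
     * threshold if that is set; otherwise throws. A null argument means "not set".
     */
    static double[] resolveThresholds(double[] thresholds, Double threshold) {
        if (thresholds != null) {
            return thresholds.clone(); // explicit thresholds take precedence
        }
        if (threshold != null) {
            return new double[] {1.0 - threshold, threshold};
        }
        // Assumed exception type; the documentation does not name one.
        throw new NoSuchElementException("neither thresholds nor threshold is set");
    }

    public static void main(String[] args) {
        double[] derived = resolveThresholds(null, 0.25);
        System.out.println(derived[0] + ", " + derived[1]);
    }
}
```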
public int numFeatures()
Returns the number of features the model was trained on.
numFeatures in class PredictionModel<Vector,LogisticRegressionModel>
public int numClasses()
Number of classes (values which the label can take).
numClasses in class ClassificationModel<Vector,LogisticRegressionModel>
public LogisticRegressionTrainingSummary summary()
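Gets summary of model on training set. An exception is thrown if trainingSummary == None.
public boolean hasSummary()
Indicates whether a training summary exists for this model instance.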
protected double predict(Vector features)
Predict label for the given feature vector. The behavior of this can be adjusted using thresholds.
predict in class ClassificationModel<Vector,LogisticRegressionModel>
features - (undocumented)
protected Vector raw2probabilityInPlace(Vector rawPrediction)
Estimate the probability of each class given the raw prediction, doing the computation in-place. This internal method is used to implement transform() and output probabilityCol.
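This page does not give the formula used for the conversion. As a hedged illustration only, the sketch below assumes the usual logistic (sigmoid) link, mapping each raw score x to 1 / (1 + exp(-x)), and assumes a binary raw prediction of the form (-margin, margin). It is plain Java operating on a double[] rather than a Spark Vector, and the names RawToProbabilitySketch and sigmoidInPlace are hypothetical.

```java
/** Standalone sketch of an in-place raw-score-to-probability conversion; not part of Spark. */
final class RawToProbabilitySketch {

    /**
     * Overwrites each raw score with its logistic transform and returns the same array,
     * so no second buffer is allocated.
     */
    static double[] sigmoidInPlace(double[] raw) {
        for (int i = 0; i < raw.length; i++) {
            raw[i] = 1.0 / (1.0 + Math.exp(-raw[i]));
        }
        return raw;
    }

    public static void main(String[] args) {
        // Assumed binary layout: (-margin, margin) with margin = 2.0.
        double[] scores = {-2.0, 2.0};
        double[] probabilities = sigmoidInPlace(scores);
        System.out.println(probabilities[0] + ", " + probabilities[1]);
    }
}
```

Under the assumed (-margin, margin) layout the two outputs sum to 1, because sigmoid(-m) = 1 - sigmoid(m). The caller's array is modified, which is the point of an in-place variant.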
raw2probabilityInPlace in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
rawPrediction - (undocumented)
protected Vector predictRaw(Vector features)
Raw prediction for each possible label. This internal method is used to implement transform() and output rawPredictionCol.
predictRaw in class ClassificationModel<Vector,LogisticRegressionModel>
features - (undocumented)
public LogisticRegressionModel copy(ParamMap extra)
Creates a copy of this instance with the same UID and some extra params.
copy in interface Params
copy in class Model<LogisticRegressionModel>
extra - (undocumented)
See also: defaultCopy()
protected double raw2prediction(Vector rawPrediction)
Given a vector of raw predictions, select the predicted label.
raw2prediction in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
rawPrediction - (undocumented)
protected double probability2prediction(Vector probability)
Given a vector of class conditional probabilities, select the predicted label.
probability2prediction in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
probability - (undocumented)
public MLWriter write()
Returns an MLWriter instance for this ML instance.
For LogisticRegressionModel, this does NOT currently save the training summary. An option to save summary may be added in the future.
This also does not save the parent currently.
write in interface MLWritable
public void checkThresholdConsistency()
If threshold and thresholds are both set, ensures they are consistent.
java.lang.IllegalArgumentException - if threshold and thresholds are not equivalent
public void validateParams()
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
public Param<java.lang.String> rawPredictionCol()
public java.lang.String getRawPredictionCol()
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
schema - input schema
fitting - whether this is in fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.
public Param<java.lang.String> labelCol()
public java.lang.String getLabelCol()
public Param<java.lang.String> featuresCol()
public java.lang.String getFeaturesCol()
public Param<java.lang.String> predictionCol()
public java.lang.String getPredictionCol()