public class LinearRegression extends Predictor<Vector,LinearRegression,LinearRegressionModel> implements Logging
The learning objective is to minimize the squared error, with regularization. The specific squared error loss function used is: $L = \frac{1}{2n} \lVert A \cdot \mathrm{coefficients} - y \rVert^2$
This supports multiple types of regularization:
- none (a.k.a. ordinary least squares)
- L2 (ridge regression)
- L1 (Lasso)
- L2 + L1 (elastic net)
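To make the objective concrete, here is a minimal sketch that evaluates the squared error loss above together with an elastic net penalty. The penalty form `regParam * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2)`, with `alpha` standing for the value given to `setElasticNetParam`, is an assumption for illustration and is not stated on this page. The class `ElasticNetLoss` and its method names are hypothetical and not part of the library.

```java
final class ElasticNetLoss {

    /** Squared error term: 1/(2n) * ||A w - y||^2, with one row of A per instance. */
    static double squaredErrorLoss(double[][] a, double[] w, double[] y) {
        int n = a.length;
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            double prediction = 0.0;
            for (int j = 0; j < w.length; j++) {
                prediction += a[i][j] * w[j];
            }
            double residual = prediction - y[i];
            sum += residual * residual;
        }
        return sum / (2.0 * n);
    }

    /**
     * Assumed elastic net penalty:
     * regParam * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2).
     * alpha = 0 gives a pure L2 (ridge) penalty, alpha = 1 a pure L1 (Lasso) penalty.
     */
    static double penalty(double[] w, double regParam, double alpha) {
        double l1 = 0.0;
        double l2Squared = 0.0;
        for (double c : w) {
            l1 += Math.abs(c);
            l2Squared += c * c;
        }
        return regParam * (alpha * l1 + (1.0 - alpha) / 2.0 * l2Squared);
    }

    /** Regularized objective: loss plus penalty. regParam = 0 is ordinary least squares. */
    static double objective(double[][] a, double[] w, double[] y,
                            double regParam, double alpha) {
        return squaredErrorLoss(a, w, y) + penalty(w, regParam, alpha);
    }

    public static void main(String[] args) {
        double[][] a = {{1, 0}, {0, 1}};
        double[] w = {1, 2};
        double[] y = {1, 1};
        System.out.println(objective(a, w, y, 0.1, 0.5));
    }
}
```

Under this assumed form, the four regularization types listed above correspond to `regParam = 0`, `alpha = 0`, `alpha = 1`, and `0 < alpha < 1` respectively.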
Constructor and Description |
---|
LinearRegression() |
LinearRegression(java.lang.String uid) |
Modifier and Type | Method and Description |
---|---|
LinearRegression | copy(ParamMap extra): Creates a copy of this instance with the same UID and some extra params. |
Param<java.lang.String> | featuresCol(): Param for features column name. |
java.lang.String | getFeaturesCol() |
java.lang.String | getLabelCol() |
java.lang.String | getPredictionCol() |
Param<java.lang.String> | labelCol(): Param for label column name. |
static LinearRegression | load(java.lang.String path) |
Param<java.lang.String> | predictionCol(): Param for prediction column name. |
LinearRegression | setElasticNetParam(double value): Set the ElasticNet mixing parameter. |
LinearRegression | setFitIntercept(boolean value): Set if we should fit the intercept. Default is true. |
LinearRegression | setMaxIter(int value): Set the maximum number of iterations. |
LinearRegression | setRegParam(double value): Set the regularization parameter. |
LinearRegression | setSolver(java.lang.String value): Set the solver algorithm used for optimization. |
LinearRegression | setStandardization(boolean value): Whether to standardize the training features before fitting the model. |
LinearRegression | setTol(double value): Set the convergence tolerance of iterations. |
LinearRegression | setWeightCol(java.lang.String value): Whether to over-/under-sample training instances according to the given weights in weightCol. |
protected LinearRegressionModel | train(DataFrame dataset): Train a model using the given dataset and parameters. |
java.lang.String | uid(): An immutable unique ID for the object and its derivatives. |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType): Validates and transforms the input schema with the provided param map. |
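The following sketch illustrates how a maximum iteration count, a convergence tolerance, and an intercept switch (the settings behind `setMaxIter`, `setTol`, and `setFitIntercept`) typically interact in an iterative fit. It uses plain gradient descent on a single feature purely for brevity; this is not the solver selected by `setSolver`, and the relative-change stopping rule, the `stepSize` parameter, and the class name `IterativeFitSketch` are all assumptions made for illustration.

```java
final class IterativeFitSketch {

    /** Squared error loss 1/(2n) * sum((slope * x + intercept - y)^2). */
    static double loss(double[] x, double[] y, double slope, double intercept) {
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            double residual = slope * x[i] + intercept - y[i];
            sum += residual * residual;
        }
        return sum / (2.0 * x.length);
    }

    /**
     * Fits y ~ slope * x + intercept by gradient descent.
     * Stops after maxIter steps, or earlier once the change in loss between
     * two steps is at most tol relative to the previous loss.
     * Returns {slope, intercept, iterationsUsed}.
     */
    static double[] fit(double[] x, double[] y, int maxIter, double tol,
                        boolean fitIntercept, double stepSize) {
        int n = x.length;
        double slope = 0.0;
        double intercept = 0.0;
        double prevLoss = loss(x, y, slope, intercept);
        int iter = 0;
        while (iter < maxIter) {
            double gradSlope = 0.0;
            double gradIntercept = 0.0;
            for (int i = 0; i < n; i++) {
                double residual = slope * x[i] + intercept - y[i];
                gradSlope += residual * x[i];
                gradIntercept += residual;
            }
            slope -= stepSize * gradSlope / n;
            if (fitIntercept) {
                // With fitIntercept = false the intercept stays fixed at zero.
                intercept -= stepSize * gradIntercept / n;
            }
            iter++;
            double curLoss = loss(x, y, slope, intercept);
            if (Math.abs(prevLoss - curLoss) <= tol * Math.max(1.0, Math.abs(prevLoss))) {
                break; // converged within tolerance
            }
            prevLoss = curLoss;
        }
        return new double[] {slope, intercept, iter};
    }

    public static void main(String[] args) {
        double[] x = {0, 1, 2, 3};
        double[] y = {1, 3, 5, 7}; // generated from y = 2x + 1
        double[] result = fit(x, y, 10000, 1e-12, true, 0.1);
        System.out.printf("slope=%.4f intercept=%.4f iterations=%d%n",
                result[0], result[1], (int) result[2]);
    }
}
```

A small `maxIter` caps the work even if the tolerance has not been reached, while a looser `tol` ends the loop earlier at the cost of a less precise fit.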
Methods inherited from class Predictor: extractLabeledPoints, fit, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
Methods inherited from class PipelineStage: transformSchema
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Logging: initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
Methods inherited from interface Params: clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams
Methods inherited from interface Identifiable: toString
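The summary for `setWeightCol` describes weights as over- or under-sampling training instances. One common reading, sketched below, is that each instance's squared error is multiplied by its weight, so an integer weight of 2 behaves like listing the row twice. Normalizing by the sum of weights is an assumption made here, not something this page states, and `WeightedSquaredError` is a hypothetical helper rather than a library class.

```java
final class WeightedSquaredError {

    /**
     * Weighted squared error: sum(w_i * (a_i . coef - y_i)^2) / (2 * sum(w_i)).
     * With all weights equal to 1 this reduces to 1/(2n) * ||A coef - y||^2.
     */
    static double weightedLoss(double[][] a, double[] y, double[] weights, double[] coef) {
        double weightSum = 0.0;
        double total = 0.0;
        for (int i = 0; i < a.length; i++) {
            double prediction = 0.0;
            for (int j = 0; j < coef.length; j++) {
                prediction += a[i][j] * coef[j];
            }
            double residual = prediction - y[i];
            total += weights[i] * residual * residual;
            weightSum += weights[i];
        }
        return total / (2.0 * weightSum);
    }

    public static void main(String[] args) {
        double[] coef = {1.5};
        // Two rows, the first weighted twice as heavily.
        double weighted = weightedLoss(
                new double[][] {{1}, {2}}, new double[] {1, 3}, new double[] {2, 1}, coef);
        // The same data with the first row duplicated and unit weights.
        double duplicated = weightedLoss(
                new double[][] {{1}, {1}, {2}}, new double[] {1, 1, 3}, new double[] {1, 1, 1}, coef);
        System.out.println(Math.abs(weighted - duplicated) < 1e-12);
    }
}
```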
public LinearRegression(java.lang.String uid)
public LinearRegression()
public static LinearRegression load(java.lang.String path)
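public java.lang.String uid()
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable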
public LinearRegression setRegParam(double value)
Parameters:
value - (undocumented)
public LinearRegression setFitIntercept(boolean value)
Parameters:
value - (undocumented)
public LinearRegression setStandardization(boolean value)
Parameters:
value - (undocumented)
public LinearRegression setElasticNetParam(double value)
Parameters:
value - (undocumented)
public LinearRegression setMaxIter(int value)
Parameters:
value - (undocumented)
public LinearRegression setTol(double value)
Parameters:
value - (undocumented)
public LinearRegression setWeightCol(java.lang.String value)
Parameters:
value - (undocumented)
public LinearRegression setSolver(java.lang.String value)
Parameters:
value - (undocumented)
protected LinearRegressionModel train(DataFrame dataset)
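Description copied from class: Predictor
Train a model using the given dataset and parameters. Developers can implement this instead of fit() to avoid dealing with schema validation and copying parameters into the model.
Specified by: train in class Predictor<Vector,LinearRegression,LinearRegressionModel>
Parameters:
dataset - Training dataset
public LinearRegression copy(ParamMap extra)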
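Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params.
Specified by: copy in interface Params
Specified by: copy in class Predictor<Vector,LinearRegression,LinearRegressionModel>
Parameters:
extra - (undocumented)
See also: defaultCopy()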
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Parameters:
schema - input schema
fitting - whether this is in fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.
public Param<java.lang.String> labelCol()
public java.lang.String getLabelCol()
public Param<java.lang.String> featuresCol()
public java.lang.String getFeaturesCol()
public Param<java.lang.String> predictionCol()
public java.lang.String getPredictionCol()