public class GeneralizedLinearRegression extends Predictor<Vector,GeneralizedLinearRegression,GeneralizedLinearRegressionModel>
Fit a Generalized Linear Model (https://en.wikipedia.org/wiki/Generalized_linear_model)
specified by giving a symbolic description of the linear predictor (link function) and
a description of the error distribution (family).
It supports "gaussian", "binomial", "poisson" and "gamma" as family.
Valid link functions for each family are listed below. The first link function of each family
is the default one.
- "gaussian" -> "identity", "log", "inverse"
- "binomial" -> "logit", "probit", "cloglog"
- "poisson" -> "log", "identity", "sqrt"
- "gamma" -> "inverse", "identity", "log"
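The family/link table above can be sketched as a small lookup, which may help when validating a configuration before calling setFamily and setLink. This is a minimal illustration only: the class GlmLinkTable and its methods defaultLink and isValid are hypothetical names that are not part of this API, and the sketch does not claim to reflect how the library performs its own validation.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical helper (not part of this API) mirroring the family/link table above. */
final class GlmLinkTable {
    // Family name -> supported link names; the first entry is the default link.
    private static final Map<String, List<String>> LINKS = new LinkedHashMap<>();
    static {
        LINKS.put("gaussian", Arrays.asList("identity", "log", "inverse"));
        LINKS.put("binomial", Arrays.asList("logit", "probit", "cloglog"));
        LINKS.put("poisson", Arrays.asList("log", "identity", "sqrt"));
        LINKS.put("gamma", Arrays.asList("inverse", "identity", "log"));
    }

    /** Returns the default link for a family, i.e. the first one listed for it. */
    static String defaultLink(String family) {
        List<String> links = LINKS.get(family);
        if (links == null) {
            throw new IllegalArgumentException("Unsupported family: " + family);
        }
        return links.get(0);
    }

    /** Returns true when the family is known and the link is listed for it. */
    static boolean isValid(String family, String link) {
        List<String> links = LINKS.get(family);
        return links != null && links.contains(link);
    }

    public static void main(String[] args) {
        System.out.println(defaultLink("poisson"));       // prints "log"
        System.out.println(isValid("binomial", "sqrt"));  // prints "false"
    }
}
```

A caller could run isValid on the chosen pair first and fall back to defaultLink(family) when no link is given.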
Modifier and Type | Class and Description |
---|---|
static class |
GeneralizedLinearRegression.Binomial$
Binomial exponential family distribution.
|
static class |
GeneralizedLinearRegression.CLogLog$ |
static class |
GeneralizedLinearRegression.Family$ |
static class |
GeneralizedLinearRegression.Gamma$
Gamma exponential family distribution.
|
static class |
GeneralizedLinearRegression.Gaussian$
Gaussian exponential family distribution.
|
static class |
GeneralizedLinearRegression.Identity$ |
static class |
GeneralizedLinearRegression.Inverse$ |
static class |
GeneralizedLinearRegression.Link$ |
static class |
GeneralizedLinearRegression.Log$ |
static class |
GeneralizedLinearRegression.Logit$ |
static class |
GeneralizedLinearRegression.Poisson$
Poisson exponential family distribution.
|
static class |
GeneralizedLinearRegression.Probit$ |
static class |
GeneralizedLinearRegression.Sqrt$ |
Constructor and Description |
---|
GeneralizedLinearRegression() |
GeneralizedLinearRegression(java.lang.String uid) |
Modifier and Type | Method and Description |
---|---|
protected static <T> T |
$(Param<T> param) |
static Params |
clear(Param<?> param) |
GeneralizedLinearRegression |
copy(ParamMap extra)
Creates a copy of this instance with the same UID and some extra params.
|
protected static <T extends Params> |
copyValues(T to,
ParamMap extra) |
protected static <T extends Params> |
copyValues$default$2() |
protected static <T extends Params> |
defaultCopy(ParamMap extra) |
static java.lang.String |
explainParam(Param<?> param) |
static java.lang.String |
explainParams() |
protected static RDD<LabeledPoint> |
extractLabeledPoints(Dataset<?> dataset) |
static ParamMap |
extractParamMap() |
static ParamMap |
extractParamMap(ParamMap extra) |
static Param<java.lang.String> |
family() |
Param<java.lang.String> |
family()
Param for the name of family which is a description of the error distribution
to be used in the model.
|
static Param<java.lang.String> |
featuresCol() |
Param<java.lang.String> |
featuresCol()
Param for features column name.
|
static M |
fit(Dataset<?> dataset) |
static M |
fit(Dataset<?> dataset,
ParamMap paramMap) |
static scala.collection.Seq<M> |
fit(Dataset<?> dataset,
ParamMap[] paramMaps) |
static M |
fit(Dataset<?> dataset,
ParamPair<?> firstParamPair,
ParamPair<?>... otherParamPairs) |
static M |
fit(Dataset<?> dataset,
ParamPair<?> firstParamPair,
scala.collection.Seq<ParamPair<?>> otherParamPairs) |
static BooleanParam |
fitIntercept() |
static <T> scala.Option<T> |
get(Param<T> param) |
static <T> scala.Option<T> |
getDefault(Param<T> param) |
static java.lang.String |
getFamily() |
java.lang.String |
getFamily() |
static java.lang.String |
getFeaturesCol() |
java.lang.String |
getFeaturesCol() |
static boolean |
getFitIntercept() |
static java.lang.String |
getLabelCol() |
java.lang.String |
getLabelCol() |
static java.lang.String |
getLink() |
java.lang.String |
getLink() |
static java.lang.String |
getLinkPredictionCol() |
java.lang.String |
getLinkPredictionCol() |
static int |
getMaxIter() |
static <T> T |
getOrDefault(Param<T> param) |
static Param<java.lang.Object> |
getParam(java.lang.String paramName) |
static java.lang.String |
getPredictionCol() |
java.lang.String |
getPredictionCol() |
static double |
getRegParam() |
static java.lang.String |
getSolver() |
static double |
getTol() |
static java.lang.String |
getWeightCol() |
static <T> boolean |
hasDefault(Param<T> param) |
static boolean |
hasParam(java.lang.String paramName) |
protected static void |
initializeLogIfNecessary(boolean isInterpreter) |
static boolean |
isDefined(Param<?> param) |
static boolean |
isSet(Param<?> param) |
protected static boolean |
isTraceEnabled() |
static Param<java.lang.String> |
labelCol() |
Param<java.lang.String> |
labelCol()
Param for label column name.
|
static Param<java.lang.String> |
link() |
Param<java.lang.String> |
link()
Param for the name of link function which provides the relationship
between the linear predictor and the mean of the distribution function.
|
static Param<java.lang.String> |
linkPredictionCol() |
Param<java.lang.String> |
linkPredictionCol()
Param for link prediction (linear predictor) column name.
|
static GeneralizedLinearRegression |
load(java.lang.String path) |
protected static org.slf4j.Logger |
log() |
protected static void |
logDebug(scala.Function0<java.lang.String> msg) |
protected static void |
logDebug(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static void |
logError(scala.Function0<java.lang.String> msg) |
protected static void |
logError(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static void |
logInfo(scala.Function0<java.lang.String> msg) |
protected static void |
logInfo(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static java.lang.String |
logName() |
protected static void |
logTrace(scala.Function0<java.lang.String> msg) |
protected static void |
logTrace(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static void |
logWarning(scala.Function0<java.lang.String> msg) |
protected static void |
logWarning(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
static IntParam |
maxIter() |
static Param<?>[] |
params() |
static Param<java.lang.String> |
predictionCol() |
Param<java.lang.String> |
predictionCol()
Param for prediction column name.
|
static DoubleParam |
regParam() |
static void |
save(java.lang.String path) |
static <T> Params |
set(Param<T> param,
T value) |
protected static Params |
set(ParamPair<?> paramPair) |
protected static Params |
set(java.lang.String param,
java.lang.Object value) |
protected static <T> Params |
setDefault(Param<T> param,
T value) |
protected static Params |
setDefault(scala.collection.Seq<ParamPair<?>> paramPairs) |
GeneralizedLinearRegression |
setFamily(java.lang.String value)
Sets the value of param family. |
static Learner |
setFeaturesCol(java.lang.String value) |
GeneralizedLinearRegression |
setFitIntercept(boolean value)
Sets if we should fit the intercept.
|
static Learner |
setLabelCol(java.lang.String value) |
GeneralizedLinearRegression |
setLink(java.lang.String value)
Sets the value of param link. |
GeneralizedLinearRegression |
setLinkPredictionCol(java.lang.String value)
Sets the link prediction (linear predictor) column name.
|
GeneralizedLinearRegression |
setMaxIter(int value)
Sets the maximum number of iterations (applicable for solver "irls").
|
static Learner |
setPredictionCol(java.lang.String value) |
GeneralizedLinearRegression |
setRegParam(double value)
Sets the regularization parameter for L2 regularization.
|
GeneralizedLinearRegression |
setSolver(java.lang.String value)
Sets the solver algorithm used for optimization.
|
GeneralizedLinearRegression |
setTol(double value)
Sets the convergence tolerance of iterations.
|
GeneralizedLinearRegression |
setWeightCol(java.lang.String value)
Sets the value of param weightCol. |
static Param<java.lang.String> |
solver() |
static DoubleParam |
tol() |
static java.lang.String |
toString() |
protected GeneralizedLinearRegressionModel |
train(Dataset<?> dataset)
Train a model using the given dataset and parameters.
|
static StructType |
transformSchema(StructType schema) |
protected static StructType |
transformSchema(StructType schema,
boolean logging) |
java.lang.String |
uid()
An immutable unique ID for the object and its derivatives.
|
static StructType |
validateAndTransformSchema(StructType schema,
boolean fitting,
DataType featuresDataType) |
StructType |
validateAndTransformSchema(StructType schema,
boolean fitting,
DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
|
static void |
validateParams() |
static Param<java.lang.String> |
weightCol() |
static MLWriter |
write() |
MLWriter |
write()
Returns an MLWriter instance for this ML instance. |
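Methods inherited from class Predictor: extractLabeledPoints, fit, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
Methods inherited from class PipelineStage: transformSchema
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface MLWritable: save
Methods inherited from interface Params: clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams
Methods inherited from interface Identifiable: toString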
public GeneralizedLinearRegression(java.lang.String uid)
public GeneralizedLinearRegression()
public static GeneralizedLinearRegression load(java.lang.String path)
public static java.lang.String toString()
public static Param<?>[] params()
public static void validateParams()
public static java.lang.String explainParam(Param<?> param)
public static java.lang.String explainParams()
public static final boolean isSet(Param<?> param)
public static final boolean isDefined(Param<?> param)
public static boolean hasParam(java.lang.String paramName)
public static Param<java.lang.Object> getParam(java.lang.String paramName)
protected static final Params set(java.lang.String param, java.lang.Object value)
public static final <T> scala.Option<T> get(Param<T> param)
public static final <T> T getOrDefault(Param<T> param)
protected static final <T> T $(Param<T> param)
public static final <T> scala.Option<T> getDefault(Param<T> param)
public static final <T> boolean hasDefault(Param<T> param)
public static final ParamMap extractParamMap()
protected static java.lang.String logName()
protected static org.slf4j.Logger log()
protected static void logInfo(scala.Function0<java.lang.String> msg)
protected static void logDebug(scala.Function0<java.lang.String> msg)
protected static void logTrace(scala.Function0<java.lang.String> msg)
protected static void logWarning(scala.Function0<java.lang.String> msg)
protected static void logError(scala.Function0<java.lang.String> msg)
protected static void logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static boolean isTraceEnabled()
protected static void initializeLogIfNecessary(boolean isInterpreter)
protected static StructType transformSchema(StructType schema, boolean logging)
public static M fit(Dataset<?> dataset, ParamPair<?> firstParamPair, scala.collection.Seq<ParamPair<?>> otherParamPairs)
public static M fit(Dataset<?> dataset, ParamPair<?> firstParamPair, ParamPair<?>... otherParamPairs)
public static final Param<java.lang.String> labelCol()
public static final java.lang.String getLabelCol()
public static final Param<java.lang.String> featuresCol()
public static final java.lang.String getFeaturesCol()
public static final Param<java.lang.String> predictionCol()
public static final java.lang.String getPredictionCol()
public static Learner setLabelCol(java.lang.String value)
public static Learner setFeaturesCol(java.lang.String value)
public static Learner setPredictionCol(java.lang.String value)
public static M fit(Dataset<?> dataset)
public static StructType transformSchema(StructType schema)
protected static RDD<LabeledPoint> extractLabeledPoints(Dataset<?> dataset)
public static final BooleanParam fitIntercept()
public static final boolean getFitIntercept()
public static final IntParam maxIter()
public static final int getMaxIter()
public static final DoubleParam tol()
public static final double getTol()
public static final DoubleParam regParam()
public static final double getRegParam()
public static final Param<java.lang.String> weightCol()
public static final java.lang.String getWeightCol()
public static final Param<java.lang.String> solver()
public static final java.lang.String getSolver()
public static final Param<java.lang.String> family()
public static java.lang.String getFamily()
public static final Param<java.lang.String> link()
public static java.lang.String getLink()
public static final Param<java.lang.String> linkPredictionCol()
public static java.lang.String getLinkPredictionCol()
public static StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
public static void save(java.lang.String path) throws java.io.IOException
Throws: java.io.IOException
public static MLWriter write()
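public java.lang.String uid()
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable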
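public GeneralizedLinearRegression setFamily(java.lang.String value)
Sets the value of param family.
Default is "gaussian".
Parameters:
value - (undocumented)
public GeneralizedLinearRegression setLink(java.lang.String value)
Sets the value of param link.
Parameters:
value - (undocumented)
public GeneralizedLinearRegression setFitIntercept(boolean value)
Sets if we should fit the intercept.
Parameters:
value - (undocumented)
public GeneralizedLinearRegression setMaxIter(int value)
Sets the maximum number of iterations (applicable for solver "irls").
Parameters:
value - (undocumented)
public GeneralizedLinearRegression setTol(double value)
Sets the convergence tolerance of iterations.
Parameters:
value - (undocumented)
public GeneralizedLinearRegression setRegParam(double value)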
Sets the regularization parameter for L2 regularization. The regularization term is
0.5 * regParam * L2norm(coefficients)^2
Default is 0.0.
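As a worked illustration of the regularization term above, the following sketch evaluates 0.5 * regParam * L2norm(coefficients)^2 for a given coefficient vector. The class L2Penalty and its method penalty are hypothetical names introduced here for illustration; the sketch only evaluates the formula and makes no claim about how the solver applies it internally.

```java
/** Hypothetical helper (not part of this API) that evaluates the L2 regularization term. */
final class L2Penalty {
    /** Computes 0.5 * regParam * (sum of squared coefficients). */
    static double penalty(double regParam, double[] coefficients) {
        double sumOfSquares = 0.0;
        for (double c : coefficients) {
            sumOfSquares += c * c; // the squared L2 norm is the sum of squares
        }
        return 0.5 * regParam * sumOfSquares;
    }

    public static void main(String[] args) {
        // Coefficients (3, 4) have squared L2 norm 25, so the term is 0.5 * 0.5 * 25.
        System.out.println(penalty(0.5, new double[] {3.0, 4.0})); // prints 6.25
        // With the default regParam of 0.0 the term vanishes.
        System.out.println(penalty(0.0, new double[] {3.0, 4.0})); // prints 0.0
    }
}
```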
Parameters:
value - (undocumented)
public GeneralizedLinearRegression setWeightCol(java.lang.String value)
Sets the value of param weightCol.
If this is not set or empty, we treat all instance weights as 1.0.
Default is empty, so all instances have weight one.
Parameters:
value - (undocumented)
public GeneralizedLinearRegression setSolver(java.lang.String value)
Sets the solver algorithm used for optimization.
Parameters:
value - (undocumented)
public GeneralizedLinearRegression setLinkPredictionCol(java.lang.String value)
Sets the link prediction (linear predictor) column name.
Parameters:
value - (undocumented)
protected GeneralizedLinearRegressionModel train(Dataset<?> dataset)
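Description copied from class: Predictor
Train a model using the given dataset and parameters.
Developers can implement this instead of fit() to avoid dealing with schema validation
and copying parameters into the model.
Specified by: train in class Predictor<Vector,GeneralizedLinearRegression,GeneralizedLinearRegressionModel>
Parameters:
dataset - Training dataset
public GeneralizedLinearRegression copy(ParamMap extra)
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params.
Specified by: copy in interface Params
Specified by: copy in class Predictor<Vector,GeneralizedLinearRegression,GeneralizedLinearRegressionModel>
Parameters:
extra - (undocumented)
See Also: defaultCopy()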
public Param<java.lang.String> family()
public java.lang.String getFamily()
public Param<java.lang.String> link()
public java.lang.String getLink()
public Param<java.lang.String> linkPredictionCol()
public java.lang.String getLinkPredictionCol()
public MLWriter write()
Description copied from interface: MLWritable
Returns an MLWriter instance for this ML instance.
Specified by: write in interface MLWritable
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether this is in fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.
public Param<java.lang.String> labelCol()
public java.lang.String getLabelCol()
public Param<java.lang.String> featuresCol()
public java.lang.String getFeaturesCol()
public Param<java.lang.String> predictionCol()
public java.lang.String getPredictionCol()