public class LogisticRegressionWithSGD extends GeneralizedLinearAlgorithm<LogisticRegressionModel> implements scala.Serializable
Train a classification model for binary logistic regression using stochastic gradient descent. By default, L2 regularization is used; this can be changed via LogisticRegressionWithSGD.optimizer.
NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1}
for a k-class classification problem.
Using LogisticRegressionWithLBFGS is recommended over this class.

Constructor and Description |
---|
LogisticRegressionWithSGD() Deprecated. Construct a LogisticRegressionWithSGD object with default parameters: {stepSize: 1.0, numIterations: 100, regParam: 0.01, miniBatchFraction: 1.0}. |
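A small single-machine sketch can make the default parameters above more concrete. It is not Spark's implementation. The class SgdLogisticSketch and its methods are hypothetical, and the data are plain arrays rather than an RDD. Every row is used in each iteration, as with miniBatchFraction: 1.0. The update takes a step of stepSize / sqrt(t) and combines it with an L2 shrinkage controlled by regParam. That update reflects our understanding of GradientDescent with the default squared-L2 updater.

```java
/**
 * Hypothetical single-machine sketch of full-batch SGD for binary logistic
 * regression with an L2 penalty. Not Spark's code; names are illustrative.
 */
final class SgdLogisticSketch {
    private SgdLogisticSketch() {}

    /** Logistic function 1 / (1 + e^-z). */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Dot product of weights and one feature row. */
    static double dot(double[] w, double[] x) {
        double s = 0.0;
        for (int j = 0; j < w.length; j++) s += w[j] * x[j];
        return s;
    }

    /** Trains on all rows each iteration (miniBatchFraction = 1.0); labels must be 0.0 or 1.0. */
    static double[] train(double[][] features, double[] labels,
                          double stepSize, int numIterations, double regParam) {
        int d = features[0].length;
        double[] w = new double[d];                      // start from zero weights
        for (int t = 1; t <= numIterations; t++) {
            double[] grad = new double[d];
            for (int i = 0; i < features.length; i++) {
                // Logistic-loss gradient for one example: (sigmoid(w.x) - y) * x
                double err = sigmoid(dot(w, features[i])) - labels[i];
                for (int j = 0; j < d; j++) grad[j] += err * features[i][j];
            }
            double eta = stepSize / Math.sqrt(t);        // assumed decaying step size
            for (int j = 0; j < d; j++) {
                double avgGrad = grad[j] / features.length;
                // L2 step: shrink the weight toward zero, then follow the averaged gradient
                w[j] = w[j] * (1.0 - eta * regParam) - eta * avgGrad;
            }
        }
        return w;
    }
}
```

For example, SgdLogisticSketch.train(x, y, 1.0, 100, 0.01) runs the defaults on arrays x and y whose labels are 0.0 or 1.0. Put a constant 1.0 column in x if you want a bias term.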
Modifier and Type | Method and Description |
---|---|
protected static void | addIntercept_$eq(boolean x$1) Deprecated. |
protected static boolean | addIntercept() Deprecated. |
protected LogisticRegressionModel | createModel(Vector weights, double intercept) Deprecated. Create a model given the weights and intercept. |
protected static Vector | generateInitialWeights(RDD<LabeledPoint> input) Deprecated. |
static int | getNumFeatures() Deprecated. |
protected static void | initializeLogIfNecessary(boolean isInterpreter) Deprecated. |
static boolean | isAddIntercept() Deprecated. |
protected static boolean | isTraceEnabled() Deprecated. |
protected static org.slf4j.Logger | log() Deprecated. |
protected static void | logDebug(scala.Function0<java.lang.String> msg) Deprecated. |
protected static void | logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) Deprecated. |
protected static void | logError(scala.Function0<java.lang.String> msg) Deprecated. |
protected static void | logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) Deprecated. |
protected static void | logInfo(scala.Function0<java.lang.String> msg) Deprecated. |
protected static void | logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) Deprecated. |
protected static java.lang.String | logName() Deprecated. |
protected static void | logTrace(scala.Function0<java.lang.String> msg) Deprecated. |
protected static void | logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) Deprecated. |
protected static void | logWarning(scala.Function0<java.lang.String> msg) Deprecated. |
protected static void | logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) Deprecated. |
protected static void | numFeatures_$eq(int x$1) Deprecated. |
protected static int | numFeatures() Deprecated. |
protected static void | numOfLinearPredictor_$eq(int x$1) Deprecated. |
protected static int | numOfLinearPredictor() Deprecated. |
GradientDescent | optimizer() Deprecated. The optimizer to solve the problem. |
static M | run(RDD<LabeledPoint> input) Deprecated. |
static M | run(RDD<LabeledPoint> input, Vector initialWeights) Deprecated. |
static GeneralizedLinearAlgorithm<M> | setIntercept(boolean addIntercept) Deprecated. |
static GeneralizedLinearAlgorithm<M> | setValidateData(boolean validateData) Deprecated. |
static LogisticRegressionModel | train(RDD<LabeledPoint> input, int numIterations) Deprecated. Train a logistic regression model given an RDD of (label, features) pairs. |
static LogisticRegressionModel | train(RDD<LabeledPoint> input, int numIterations, double stepSize) Deprecated. Train a logistic regression model given an RDD of (label, features) pairs. |
static LogisticRegressionModel | train(RDD<LabeledPoint> input, int numIterations, double stepSize, double miniBatchFraction) Deprecated. Train a logistic regression model given an RDD of (label, features) pairs. |
static LogisticRegressionModel | train(RDD<LabeledPoint> input, int numIterations, double stepSize, double miniBatchFraction, Vector initialWeights) Deprecated. Train a logistic regression model given an RDD of (label, features) pairs. |
protected static void | validateData_$eq(boolean x$1) Deprecated. |
protected static boolean | validateData() Deprecated. |
protected scala.collection.immutable.List<scala.Function1<RDD<LabeledPoint>,java.lang.Object>> | validators() Deprecated. |
Methods inherited from class GeneralizedLinearAlgorithm: addIntercept, generateInitialWeights, getNumFeatures, isAddIntercept, numFeatures, numOfLinearPredictor, run, run, setIntercept, setValidateData, validateData
public LogisticRegressionWithSGD()
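The train overloads that follow take a miniBatchFraction. One hedged way to picture it is that, in each iteration, every example is kept independently with probability miniBatchFraction. The batch size then varies around miniBatchFraction times n instead of being fixed. The sketch below assumes this kind of Bernoulli sampling, which is similar in spirit to RDD.sample without replacement. MiniBatchSketch and sampleIndices are hypothetical names, not Spark API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: picks one mini-batch by keeping each row with probability fraction. */
final class MiniBatchSketch {
    private MiniBatchSketch() {}

    static List<Integer> sampleIndices(int numRows, double fraction, long seed) {
        if (fraction < 0.0 || fraction > 1.0) {
            throw new IllegalArgumentException("fraction must be in [0, 1]");
        }
        Random rng = new Random(seed);                   // fixed seed gives a repeatable batch
        List<Integer> batch = new ArrayList<>();
        for (int i = 0; i < numRows; i++) {
            // Independent Bernoulli trial per row; nextDouble() is in [0, 1)
            if (rng.nextDouble() < fraction) batch.add(i);
        }
        return batch;
    }
}
```

With fraction 1.0 every row is kept, which corresponds to full-batch gradient descent.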
public static LogisticRegressionModel train(RDD<LabeledPoint> input, int numIterations, double stepSize, double miniBatchFraction, Vector initialWeights)
Train a logistic regression model given an RDD of (label, features) pairs. Each iteration of gradient descent uses a miniBatchFraction fraction of the data to calculate the gradient. The weights used in gradient descent are initialized using the initial weights provided.
NOTE: Labels used in Logistic Regression should be {0, 1}.
input
- RDD of (label, array of features) pairs.
numIterations
- Number of iterations of gradient descent to run.
stepSize
- Step size to be used for each iteration of gradient descent.
miniBatchFraction
- Fraction of data to be used per iteration.
initialWeights
- Initial set of weights to be used. Its size should equal the number of features in the data.
public static LogisticRegressionModel train(RDD<LabeledPoint> input, int numIterations, double stepSize, double miniBatchFraction)
Train a logistic regression model given an RDD of (label, features) pairs. Each iteration of gradient descent uses a miniBatchFraction fraction of the data to calculate the gradient.
NOTE: Labels used in Logistic Regression should be {0, 1}.
input
- RDD of (label, array of features) pairs.
numIterations
- Number of iterations of gradient descent to run.
stepSize
- Step size to be used for each iteration of gradient descent.
miniBatchFraction
- Fraction of data to be used per iteration.
public static LogisticRegressionModel train(RDD<LabeledPoint> input, int numIterations, double stepSize)
Train a logistic regression model given an RDD of (label, features) pairs.
input
- RDD of (label, array of features) pairs.
numIterations
- Number of iterations of gradient descent to run.
stepSize
- Step size to be used for each iteration of gradient descent.
public static LogisticRegressionModel train(RDD<LabeledPoint> input, int numIterations)
Train a logistic regression model given an RDD of (label, features) pairs.
input
- RDD of (label, array of features) pairs.
numIterations
- Number of iterations of gradient descent to run.
protected static java.lang.String logName()
protected static org.slf4j.Logger log()
protected static void logInfo(scala.Function0<java.lang.String> msg)
protected static void logDebug(scala.Function0<java.lang.String> msg)
protected static void logTrace(scala.Function0<java.lang.String> msg)
protected static void logWarning(scala.Function0<java.lang.String> msg)
protected static void logError(scala.Function0<java.lang.String> msg)
protected static void logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static boolean isTraceEnabled()
protected static void initializeLogIfNecessary(boolean isInterpreter)
protected static boolean addIntercept()
protected static void addIntercept_$eq(boolean x$1)
protected static boolean validateData()
protected static void validateData_$eq(boolean x$1)
protected static int numOfLinearPredictor()
protected static void numOfLinearPredictor_$eq(int x$1)
public static int getNumFeatures()
protected static int numFeatures()
protected static void numFeatures_$eq(int x$1)
public static boolean isAddIntercept()
public static GeneralizedLinearAlgorithm<M> setIntercept(boolean addIntercept)
public static GeneralizedLinearAlgorithm<M> setValidateData(boolean validateData)
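setIntercept(true) asks the algorithm to fit an intercept as well. A common technique is to append a constant 1.0 feature to every example, train one extra weight, and then split that weight off as the intercept. To our understanding, GeneralizedLinearAlgorithm takes this approach. The sketch shows only the bookkeeping. InterceptSketch and its methods are hypothetical.

```java
import java.util.Arrays;

/** Hypothetical sketch of intercept handling via a constant bias feature appended last. */
final class InterceptSketch {
    private InterceptSketch() {}

    /** Returns a copy of x with a constant 1.0 appended as the last feature. */
    static double[] appendBias(double[] x) {
        double[] out = Arrays.copyOf(x, x.length + 1);
        out[x.length] = 1.0;
        return out;
    }

    /** The weight learned for the bias feature is the intercept (last entry). */
    static double intercept(double[] weightsWithIntercept) {
        return weightsWithIntercept[weightsWithIntercept.length - 1];
    }

    /** The remaining entries are the per-feature weights. */
    static double[] weights(double[] weightsWithIntercept) {
        return Arrays.copyOf(weightsWithIntercept, weightsWithIntercept.length - 1);
    }
}
```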
protected static Vector generateInitialWeights(RDD<LabeledPoint> input)
public static M run(RDD<LabeledPoint> input)
public static M run(RDD<LabeledPoint> input, Vector initialWeights)
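run(input) and run(input, initialWeights) differ in where the weights start. The sketch is a hedged reading of generateInitialWeights for the binary case. It assumes that the default start is a vector of zeros. It also enforces the rule stated for train that caller-supplied initial weights must have one entry per feature. InitialWeightsSketch and its methods are hypothetical.

```java
/** Hypothetical sketch of choosing starting weights for the two run overloads. */
final class InitialWeightsSketch {
    private InitialWeightsSketch() {}

    /** Assumed default start for run(input): one zero per feature. */
    static double[] defaultWeights(int numFeatures) {
        return new double[numFeatures];
    }

    /** Start for run(input, initialWeights): size must match, and a copy is returned. */
    static double[] checkedWeights(double[] initialWeights, int numFeatures) {
        if (initialWeights.length != numFeatures) {
            throw new IllegalArgumentException("expected " + numFeatures
                + " initial weights but got " + initialWeights.length);
        }
        return initialWeights.clone();
    }
}
```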
public GradientDescent optimizer()
Description copied from class: GeneralizedLinearAlgorithm
The optimizer to solve the problem.
Specified by:
optimizer in class GeneralizedLinearAlgorithm<LogisticRegressionModel>
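The optimizer is a GradientDescent, and its updater decides how the regularization term enters each step. As a hedged illustration with invented names, the sketch contrasts two per-coordinate rules. The first is an L2 shrinkage step. The second is an L1 soft-thresholding step, which drives small weights to exactly zero. As we understand it, these match the behavior of SquaredL2Updater and L1Updater in org.apache.spark.mllib.optimization, but this is a simplified model and not their source.

```java
/** Hypothetical per-coordinate sketches of two regularized update rules. */
final class UpdaterSketch {
    private UpdaterSketch() {}

    /** L2: shrink the weight multiplicatively toward zero, then take the gradient step. */
    static double l2Step(double w, double grad, double eta, double regParam) {
        return w * (1.0 - eta * regParam) - eta * grad;
    }

    /** L1: take the gradient step, then soft-threshold by eta * regParam. */
    static double l1Step(double w, double grad, double eta, double regParam) {
        double v = w - eta * grad;
        double shrink = eta * regParam;
        return Math.signum(v) * Math.max(0.0, Math.abs(v) - shrink);
    }
}
```

In Spark itself, you would switch the regularization through the returned optimizer, for example optimizer().setUpdater(new L1Updater()).setRegParam(0.1). Treat that chain as an illustration and check it against the GradientDescent documentation for your version.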
protected scala.collection.immutable.List<scala.Function1<RDD<LabeledPoint>,java.lang.Object>> validators()
Overrides:
validators in class GeneralizedLinearAlgorithm<LogisticRegressionModel>
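validators() supplies the checks that run on the input before training when data validation is enabled (see setValidateData). The sketch is a hedged illustration that follows the {0, 1} label note on the train methods: it rejects any label other than 0.0 or 1.0. A common mistake it would catch is passing -1/+1 labels. LabelCheckSketch is a hypothetical name, not Spark's validator.

```java
/** Hypothetical stand-in for a binary-label validator: every label must be exactly 0.0 or 1.0. */
final class LabelCheckSketch {
    private LabelCheckSketch() {}

    static boolean allLabelsBinary(double[] labels) {
        for (double y : labels) {
            if (y != 0.0 && y != 1.0) return false;      // e.g. -1.0 or 2.0 fails
        }
        return true;
    }
}
```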
protected LogisticRegressionModel createModel(Vector weights, double intercept)
Description copied from class: GeneralizedLinearAlgorithm
Create a model given the weights and intercept.
Specified by:
createModel in class GeneralizedLinearAlgorithm<LogisticRegressionModel>
weights
- (undocumented)
intercept
- (undocumented)
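createModel wraps the trained weights and intercept in a LogisticRegressionModel. The sketch shows how such a model is commonly used for scoring. The margin w · x + intercept goes through the sigmoid, and the resulting probability is compared with a threshold. We assume a default threshold of 0.5, which is our understanding of LogisticRegressionModel's behavior. ModelSketch is hypothetical.

```java
/** Hypothetical sketch of scoring with weights, an intercept and a decision threshold. */
final class ModelSketch {
    private final double[] weights;
    private final double intercept;
    private final double threshold;

    ModelSketch(double[] weights, double intercept, double threshold) {
        this.weights = weights.clone();
        this.intercept = intercept;
        this.threshold = threshold;
    }

    /** Probability of class 1: sigmoid(w . x + intercept). */
    double probability(double[] x) {
        double margin = intercept;
        for (int j = 0; j < weights.length; j++) margin += weights[j] * x[j];
        return 1.0 / (1.0 + Math.exp(-margin));
    }

    /** Hard prediction: 1.0 only if the probability is strictly above the threshold. */
    double predict(double[] x) {
        return probability(x) > threshold ? 1.0 : 0.0;
    }
}
```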