public class GaussianMixtureModel extends java.lang.Object implements scala.Serializable, Saveable
param: weights Weights for each Gaussian distribution in the mixture, where weights(i) is the weight for Gaussian i, and weights.sum == 1
param: gaussians Array of MultivariateGaussian where gaussians(i) represents the Multivariate Gaussian (Normal) Distribution for Gaussian i
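The two constructor arguments are parallel arrays, so it can help to check their invariants before building a model. The following is a minimal sketch and not part of the Spark API: `MixtureWeights`, `isValid`, and the `tolerance` parameter are hypothetical names introduced here, the non-negativity check is an assumption (mixture weights are normally probabilities), and the number of Gaussians is passed as a plain `int` so the example runs without Spark on the classpath.

```java
// Hypothetical helper, not part of Spark: checks the invariants that the
// parameter description states for a mixture's weights.
final class MixtureWeights {

    // Returns true when there is exactly one weight per Gaussian, no weight
    // is negative (an assumption), and the weights sum to 1 within tolerance.
    static boolean isValid(double[] weights, int numGaussians, double tolerance) {
        if (weights == null || weights.length != numGaussians) {
            return false; // weights(i) must pair up with gaussians(i)
        }
        double sum = 0.0;
        for (double w : weights) {
            if (w < 0.0) {
                return false;
            }
            sum += w;
        }
        // Compare with a tolerance: floating-point sums are rarely exactly 1.
        return Math.abs(sum - 1.0) <= tolerance;
    }

    public static void main(String[] args) {
        System.out.println(isValid(new double[] {0.3, 0.7}, 2, 1e-9)); // true
        System.out.println(isValid(new double[] {0.5, 0.6}, 2, 1e-9)); // false
    }
}
```

A tolerance is used instead of an exact `== 1` comparison because weights produced by an iterative fit will usually carry rounding error.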
Constructor and Description |
---|
GaussianMixtureModel(double[] weights, MultivariateGaussian[] gaussians) |
Modifier and Type | Method and Description |
---|---|
protected java.lang.String | formatVersion() - Current version of model save/load format. |
MultivariateGaussian[] | gaussians() |
int | k() - Number of gaussians in mixture |
static GaussianMixtureModel | load(SparkContext sc, java.lang.String path) |
JavaRDD<java.lang.Integer> | predict(JavaRDD<Vector> points) - Java-friendly version of predict() |
RDD<java.lang.Object> | predict(RDD<Vector> points) - Maps given points to their cluster indices. |
int | predict(Vector point) - Maps given point to its cluster index. |
RDD<double[]> | predictSoft(RDD<Vector> points) - Given the input vectors, return the membership value of each vector to all mixture components. |
double[] | predictSoft(Vector point) - Given the input vector, return the membership values to all mixture components. |
void | save(SparkContext sc, java.lang.String path) - Save this model to the given path. |
double[] | weights() |
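The relationship between the `predict` and `predictSoft` entries can be sketched without Spark. This assumes, as is usual for mixture models, that the cluster index is the component with the largest membership value; the class `HardAssignment` and the method `argMax` are hypothetical names used for illustration, not Spark API.

```java
// Hypothetical illustration, not part of Spark: turns soft membership
// values into a single cluster index.
final class HardAssignment {

    // Returns the index of the largest membership value.
    // On ties, the lowest index wins.
    static int argMax(double[] memberships) {
        if (memberships == null || memberships.length == 0) {
            throw new IllegalArgumentException("memberships must be non-empty");
        }
        int best = 0;
        for (int i = 1; i < memberships.length; i++) {
            if (memberships[i] > memberships[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // A point that belongs mostly to component 1.
        System.out.println(argMax(new double[] {0.1, 0.7, 0.2})); // 1
    }
}
```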
public GaussianMixtureModel(double[] weights, MultivariateGaussian[] gaussians)
public static GaussianMixtureModel load(SparkContext sc, java.lang.String path)
public double[] weights()
public MultivariateGaussian[] gaussians()
protected java.lang.String formatVersion()
Specified by: formatVersion in interface Saveable
public void save(SparkContext sc, java.lang.String path)
Description copied from interface: Saveable
This saves:
- human-readable (JSON) model metadata to path/metadata/
- Parquet formatted data to path/data/
The model may be loaded using Loader.load.
public int k()
public RDD<java.lang.Object> predict(RDD<Vector> points)
points - (undocumented)
public int predict(Vector point)
point - (undocumented)
public JavaRDD<java.lang.Integer> predict(JavaRDD<Vector> points)
Java-friendly version of predict()
points - (undocumented)
public RDD<double[]> predictSoft(RDD<Vector> points)
points - (undocumented)
public double[] predictSoft(Vector point)
point - (undocumented)
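Because the parameters above are undocumented, here is a hedged sketch of what "membership values" conventionally mean for a Gaussian mixture: the weighted density of each component at the point, normalized so the values sum to 1. To stay runnable without Spark it uses univariate Gaussians in place of `MultivariateGaussian`; `SoftMembership`, `pdf`, and `memberships` are hypothetical names, and the numerical details of Spark's own implementation may differ.

```java
// Hypothetical illustration, not part of Spark: soft membership values for
// a mixture of univariate Gaussians (a simplification of the multivariate case).
final class SoftMembership {

    // Density of a univariate normal distribution N(mean, variance) at x.
    static double pdf(double x, double mean, double variance) {
        double d = x - mean;
        return Math.exp(-d * d / (2.0 * variance))
                / Math.sqrt(2.0 * Math.PI * variance);
    }

    // result[i] = weights[i] * pdf_i(x) / sum_j (weights[j] * pdf_j(x)).
    // Assumes at least one component has non-zero density at x; otherwise
    // the normalization divides by zero and yields NaN values.
    static double[] memberships(double x, double[] weights,
                                double[] means, double[] variances) {
        double[] result = new double[weights.length];
        double total = 0.0;
        for (int i = 0; i < weights.length; i++) {
            result[i] = weights[i] * pdf(x, means[i], variances[i]);
            total += result[i];
        }
        for (int i = 0; i < result.length; i++) {
            result[i] /= total; // normalize so the memberships sum to 1
        }
        return result;
    }

    public static void main(String[] args) {
        double[] weights = {0.5, 0.5};
        double[] means = {0.0, 10.0};
        double[] variances = {1.0, 1.0};
        // A point at 0 should belong almost entirely to the first component.
        double[] m = memberships(0.0, weights, means, variances);
        System.out.println(m[0] + " " + m[1]);
    }
}
```

With equal weights and variances, a point halfway between the two means (here `x = 5.0`) gets a membership of 0.5 for each component.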