TorchDistributor
- class pyspark.ml.torch.distributor.TorchDistributor(num_processes=1, local_mode=True, use_gpu=True, _ssl_conf='pytorch.spark.distributor.ignoreSsl')
A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
New in version 3.4.0.
Changed in version 3.5.0: Supports Spark Connect.
- Parameters
- num_processes : int, optional
An integer that determines how many concurrent tasks are allowed. For GPU-enabled training, each task is expected to use exactly one GPU (spark.task.resource.gpu.amount = 1). Defaults to 1, so multiple cores or GPUs are never used unless requested explicitly.
- local_mode : bool, optional
A boolean that determines whether training runs on the driver node. Defaults to True, so executors are not used unless local_mode=False is set explicitly.
- use_gpu : bool, optional
A boolean that indicates whether training runs on GPUs. Defaults to True. Note that GPU-enabled training code differs from CPU-only code; for example, GPU training typically initializes the process group with the "nccl" backend, while CPU training uses "gloo".
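The hypothetical helper below offers a rough mental model of how the three parameters interact. `describe_placement` and `Placement` are illustrative names and not part of PySpark. The helper maps a parameter combination to where the processes would run and to the `torch.distributed` backend a training function would conventionally pick. It is a sketch under those assumptions, not the distributor's internal logic.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Placement:
    """Hypothetical summary of where training processes would run."""
    location: str   # "driver" or "executors"
    processes: int  # number of concurrent training processes
    backend: str    # conventional torch.distributed backend for the device type


def describe_placement(num_processes: int = 1,
                       local_mode: bool = True,
                       use_gpu: bool = True) -> Placement:
    """Sketch only: the defaults mirror the TorchDistributor signature."""
    if num_processes < 1:
        raise ValueError("num_processes must be a positive integer")
    # local_mode=True keeps every process on the driver node;
    # local_mode=False spreads the processes across Spark executor tasks.
    location = "driver" if local_mode else "executors"
    # NCCL is the usual backend for GPU training, Gloo for CPU-only training.
    backend = "nccl" if use_gpu else "gloo"
    return Placement(location, num_processes, backend)


print(describe_placement())
print(describe_placement(num_processes=4, local_mode=False, use_gpu=False))
```

The helper gives the defaults a single source of truth. With these defaults, a bare call describes one GPU process on the driver, which matches the "explicit mention" intent of the parameter descriptions.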
Examples
Run PyTorch Training locally on GPU (using a PyTorch native function)
>>> def train(learning_rate):
...     import torch.distributed
...     torch.distributed.init_process_group(backend="nccl")
...     # ...
...     torch.distributed.destroy_process_group()
...     return model  # or anything else
...
>>> distributor = TorchDistributor(
...     num_processes=2,
...     local_mode=True,
...     use_gpu=True)
>>> model = distributor.run(train, 1e-3)
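`run` returns only the output of the rank 0 task (see Returns below). For that reason, a training function often checks its own rank before returning or saving anything. The sketch below assumes the processes are started with torchrun-style environment variables `RANK` and `WORLD_SIZE`. These are the variables that `torch.distributed` reads under its default `env://` initialization. `read_rank_info` is a hypothetical helper. The torch calls are left as a comment so the sketch runs without PyTorch installed.

```python
import os
from typing import Optional, Tuple


def read_rank_info() -> Tuple[int, int]:
    """Read this process's rank and world size from torchrun-style env vars.

    Assumption: the launcher sets RANK and WORLD_SIZE. When they are absent,
    fall back to a single-process run (rank 0, world size 1).
    """
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return rank, world_size


def train(learning_rate: float) -> Optional[dict]:
    rank, world_size = read_rank_info()
    # ... init_process_group, training loop, destroy_process_group ...
    result = {"learning_rate": learning_rate, "world_size": world_size}
    # Only the rank 0 return value is surfaced to the caller,
    # so the other ranks can simply return None.
    return result if rank == 0 else None
```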
Run PyTorch Training on GPU (using a file with PyTorch code)
>>> distributor = TorchDistributor(
...     num_processes=2,
...     local_mode=False,
...     use_gpu=True)
>>> distributor.run("/path/to/train.py", "--learning-rate=1e-3")
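When `run` receives a file path, the script itself must parse its command-line arguments. Below is a minimal sketch of what `/path/to/train.py` might contain. The flag names `--learning-rate` and `--batch-size` come from the examples on this page. The default values and the script body are assumptions made for illustration.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Hypothetical training script")
    # Every value arrives as a string; `type=` converts it to a number.
    parser.add_argument("--learning-rate", type=float, default=1e-3)
    parser.add_argument("--batch-size", type=int, default=64)
    return parser.parse_args(argv)


def main(argv: Optional[List[str]] = None) -> None:
    args = parse_args(argv)
    # argparse exposes "--learning-rate" as the attribute `learning_rate`.
    print(f"lr={args.learning_rate} batch_size={args.batch_size}")
    # ... init_process_group, training loop, destroy_process_group ...


if __name__ == "__main__":
    main()
```

Because `parse_args` accepts an explicit list, the parsing logic can be checked without launching the script.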
Run PyTorch Lightning Training on GPU
>>> num_proc = 2
>>> def train():
...     from pytorch_lightning import Trainer
...     # ... (build `model`, a LightningModule)
...     # required to set devices = 1 and num_nodes = num_processes for multi node
...     # required to set devices = num_processes and num_nodes = 1 for single node multi GPU
...     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
...     trainer.fit(model)
...     # ...
...     return trainer
...
>>> distributor = TorchDistributor(
...     num_processes=num_proc,
...     local_mode=False,
...     use_gpu=True)
>>> trainer = distributor.run(train)
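The two comments inside `train` encode a rule for choosing Lightning's `devices` and `num_nodes`. The hypothetical helper below restates that rule as a function. `lightning_topology` is not part of PySpark or PyTorch Lightning. The helper assumes that "multi node" corresponds to `local_mode=False` and that "single node multi GPU" corresponds to `local_mode=True`.

```python
from typing import Dict


def lightning_topology(num_processes: int, local_mode: bool) -> Dict[str, int]:
    """Return Trainer(devices=..., num_nodes=...) values per the rule above."""
    if num_processes < 1:
        raise ValueError("num_processes must be a positive integer")
    if local_mode:
        # Single node, multiple GPUs: one node that owns every device.
        return {"devices": num_processes, "num_nodes": 1}
    # Multi node: each Spark task acts as its own node with one device.
    return {"devices": 1, "num_nodes": num_processes}
```

Inside `train`, `Trainer(accelerator="gpu", strategy="ddp", **lightning_topology(num_proc, local_mode=False))` would reproduce the `devices=1, num_nodes=num_proc` settings used in the example.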
Methods
run(train_object, *args, **kwargs): Runs distributed training.
Methods Documentation
- run(train_object, *args, **kwargs)
Runs distributed training.
- Parameters
- train_object : callable or str
Either a PyTorch function, a PyTorch Lightning function, or the path to a Python file that launches distributed training.
- args
If train_object is a Python function, args are the positional arguments passed to that function, for example
>>> model = distributor.run(train, 1e-3, 64)
where train is a function and 1e-3 and 64 are ordinary numeric arguments.
If train_object is a path to a Python file, args are the command-line arguments for that file, and all of them must be strings, for example
>>> distributor.run("/path/to/train.py", "--learning-rate=1e-3", "--batch-size=64")
Because the input is a file path, every argument is a string that the script can parse with argparse.
- kwargs
If train_object is a Python function, kwargs are the keyword arguments passed to that function, for example
>>> model = distributor.run(train, tol=1e-3, max_iter=64)
where train is a function that accepts the keyword arguments tol and max_iter.
If train_object is a path to a Python file, do not pass kwargs.
- Returns
Returns the output of train_object, called with args and kwargs inside the Spark rank 0 task, if train_object is a callable with an expected output. Returns None if train_object is a file.
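The argument and return rules above can be summarized in a small stand-in for `run`. `mock_run` is hypothetical and is not the real implementation. It executes the function in-process instead of launching distributed tasks. It forwards `*args` and `**kwargs` to a callable. For a file path, it requires string-only positional arguments, rejects keyword arguments, and returns `None`. The specific exception types are assumptions of this sketch.

```python
from typing import Any, Callable, Union


def mock_run(train_object: Union[Callable[..., Any], str],
             *args: Any, **kwargs: Any) -> Any:
    """Illustrative stand-in for TorchDistributor.run argument handling."""
    if isinstance(train_object, str):
        # File path: args are command-line strings; kwargs are not allowed.
        if kwargs:
            raise ValueError("kwargs are not supported when train_object is a file path")
        if not all(isinstance(a, str) for a in args):
            raise TypeError("command-line arguments must be strings")
        # A real launcher would start the script here; nothing is returned.
        return None
    if callable(train_object):
        # Function: args/kwargs pass straight through. The real run returns the
        # rank 0 task's result; here there is only one local call.
        return train_object(*args, **kwargs)
    raise TypeError("train_object must be a callable or a path string")


def train(tol: float = 1e-4, max_iter: int = 100) -> dict:
    return {"tol": tol, "max_iter": max_iter}


print(mock_run(train, tol=1e-3, max_iter=64))
print(mock_run("/path/to/train.py", "--learning-rate=1e-3", "--batch-size=64"))
```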