seqdata.get_torch_dataloader
- seqdata.get_torch_dataloader(sdata, sample_dims, variables, transforms=None, dtypes=torch.float32, *, return_tuples=False, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, pin_memory=False, drop_last=False, timeout=0.0, worker_init_fn=None, multiprocessing_context=None, generator=None, prefetch_factor=None, persistent_workers=False)
Get a PyTorch DataLoader for this SeqData.
- Parameters:
sample_dims (str or list[str]) – Sample dimensions that will be indexed over when fetching batches. For example, if sample_dims = ['_sequence', 'sample'] for a variable with dimensions ['_sequence', 'length', 'sample'], then a batch of data will have dimensions ['batch', 'length'].
variables (list[str]) – Which variables to sample from.
transforms (Dict[str | tuple[str], (ndarray | tuple[ndarray]) -> ndarray], optional) – Transforms to apply to each variable. Transforms are applied in order. A key that is a tuple of strings passes the corresponding variables to the transform in the order the variable names appear in the tuple. See examples for details.
dtypes (torch.dtype or Dict[str, torch.dtype]) – Data type to convert each variable to after applying all transforms.
For other parameters, see the documentation for DataLoader.
- Return type:
DataLoader that returns dictionaries or tuples of tensors.
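The effect of sample_dims on batch shape can be illustrated with a small pure-Python sketch. It does not call seqdata or PyTorch. The helpers flatten_sample_dims and iter_batches are hypothetical names introduced only for illustration, and the row-major ordering of the flattened (sequence, sample) pairs is an assumption, not documented behavior.

```python
from itertools import product

# Toy variable with dimensions ['_sequence', 'length', 'sample'],
# stored as nested lists: data[sequence][position][sample].
n_sequence, n_length, n_sample = 2, 3, 2
data = [
    [[f"s{s}_p{p}_k{k}" for k in range(n_sample)] for p in range(n_length)]
    for s in range(n_sequence)
]


def flatten_sample_dims(data, n_sequence, n_sample):
    """Hypothetical helper (not part of seqdata).

    Treats every (sequence, sample) pair as one indexable item, which
    mirrors sample_dims = ['_sequence', 'sample']. Each item keeps only
    the remaining 'length' dimension.
    """
    rows = []
    # Assumption: pairs are enumerated in row-major order.
    for s, k in product(range(n_sequence), range(n_sample)):
        rows.append([position[k] for position in data[s]])
    return rows


def iter_batches(rows, batch_size):
    """Hypothetical helper: group items into consecutive batches."""
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]


flat = flatten_sample_dims(data, n_sequence, n_sample)
batches = list(iter_batches(flat, batch_size=3))

# 2 sequences x 2 samples give 4 items, each of length 3,
# so a full batch has dimensions ['batch', 'length'] = [3, 3].
print(len(flat), len(flat[0]))
print([len(b) for b in batches])
```

With drop_last=False, which is the default in the signature above, the final partial batch is kept, as in this sketch.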