
Split a Dataset into Train and Test Datasets in Python

Scikit-learn Compatible Packages

sklearn.model_selection.train_test_split is the recommended way to split a dataset into train and test subsets for scikit-learn compatible packages (scikit-learn, XGBoost, LightGBM, etc.). It supports splitting array-like objects (numpy arrays, lists, pandas Series) as well as pandas DataFrames, and the returned pieces have the same type as the input: lists are split into lists, numpy arrays into numpy arrays, and pandas DataFrames into pandas DataFrames.
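
As a quick illustration of the array-like case, here is a minimal sketch (the toy X and y below are made up for illustration). Multiple array-likes of the same length can be passed in a single call, and they are split consistently so that features and labels stay aligned.

In [ ]:
import numpy as np
from sklearn.model_selection import train_test_split

# toy feature matrix and label vector (made up for illustration)
X = np.arange(20).reshape(10, 2)
y = np.arange(10)

# the pieces come back in the order X_train, X_test, y_train, y_test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=119
)
X_train.shape, X_test.shape, y_train.shape, y_test.shape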

In [5]:
import pandas as pd
In [8]:
df = pd.read_csv("http://www.legendu.net/media/data/iris.csv")
df.head()
Out[8]:
id sepal_length_cm sepal_width_cm petal_length_cm petal_width_cm species
0 1 5.1 3.5 1.4 0.2 Iris-setosa
1 2 4.9 3.0 1.4 0.2 Iris-setosa
2 3 4.7 3.2 1.3 0.2 Iris-setosa
3 4 4.6 3.1 1.5 0.2 Iris-setosa
4 5 5.0 3.6 1.4 0.2 Iris-setosa
In [9]:
df.shape
Out[9]:
(150, 6)
In [10]:
from sklearn.model_selection import train_test_split
In [11]:
train, test = train_test_split(df, test_size=0.2, random_state=119)

Notice that an integer value 119 is passed to the parameter random_state. This is strongly suggested, as it enables you to reproduce your work later. It is generally a good idea to set a seed for the random number generator when you build a model.
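
As a quick sanity check, calling train_test_split twice with the same random_state yields identical splits:

In [ ]:
# splitting twice with the same seed produces the same train/test partitions
train1, test1 = train_test_split(df, test_size=0.2, random_state=119)
train2, test2 = train_test_split(df, test_size=0.2, random_state=119)
train1.equals(train2) and test1.equals(test2)  # True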

In [12]:
train.head()
Out[12]:
id sepal_length_cm sepal_width_cm petal_length_cm petal_width_cm species
122 123 7.7 2.8 6.7 2.0 Iris-virginica
109 110 7.2 3.6 6.1 2.5 Iris-virginica
34 35 4.9 3.1 1.5 0.1 Iris-setosa
123 124 6.3 2.7 4.9 1.8 Iris-virginica
69 70 5.6 2.5 3.9 1.1 Iris-versicolor
In [13]:
train.shape
Out[13]:
(120, 6)
In [14]:
test.head()
Out[14]:
id sepal_length_cm sepal_width_cm petal_length_cm petal_width_cm species
37 38 4.9 3.1 1.5 0.1 Iris-setosa
63 64 6.1 2.9 4.7 1.4 Iris-versicolor
31 32 5.4 3.4 1.5 0.4 Iris-setosa
102 103 7.1 3.0 5.9 2.1 Iris-virginica
126 127 6.2 2.8 4.8 1.8 Iris-virginica
In [15]:
test.shape
Out[15]:
(30, 6)

More Flexible Splitting of Arrays and DataFrames

If you are not building a model and simply want to split a pandas DataFrame into multiple pieces, numpy.array_split comes in very handy. For example, the code below splits a pandas DataFrame into 4 parts. Unlike numpy.split, numpy.array_split does not require the pieces to be of equal size, which matters here since 150 rows do not divide evenly by 4. Numpy arrays are supported as well, of course.

In [ ]:
import numpy as np

# np.array_split allows pieces of unequal size (150 rows -> 38, 38, 37, 37);
# np.split would raise an error here because 150 is not divisible by 4
dfs = np.array_split(df, 4)
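
A quick check of the resulting pieces (using the iris DataFrame loaded above):

In [ ]:
[piece.shape for piece in dfs]  # [(38, 6), (38, 6), (37, 6), (37, 6)]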

PyTorch

The best way to split a PyTorch Dataset is the function torch.utils.data.random_split, which splits a Dataset into non-overlapping pieces of the given lengths (the lengths must sum to the size of the Dataset) and returns them as torch.utils.data.dataset.Subset objects, e.g., (train, test) when two lengths are given.

In [ ]:
import torch

# the lengths must sum to len(dataset); here 6000 + 2055 = 8055
train, test = torch.utils.data.random_split(dataset, [6000, 2055])
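
A self-contained sketch for illustration (the TensorDataset below is made up; passing a seeded generator makes the split reproducible, similar to random_state above):

In [ ]:
import torch
from torch.utils.data import TensorDataset, random_split

# a made-up dataset of 100 examples with 3 features each
dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))

# an 80/20 split; the seeded generator makes the shuffling reproducible
train, test = random_split(
    dataset, [80, 20], generator=torch.Generator().manual_seed(119)
)
len(train), len(test)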
