Note that the goal here isn't necessarily to fit the best model; rather, it's just to demonstrate an sklearn pipeline. Also note that I wouldn't call myself an expert Python programmer, so there may be better or more efficient ways to do this.
import polars as pl
import numpy as np
import math

from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LinearRegression
from sklearn.compose import ColumnTransformer
from sklearn.metrics import mean_squared_error

penguins = pl.read_csv(
    "https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-07-28/penguins.csv",
    null_values="NA",
)

# filtering to only rows with available body mass data
penguins_complete = penguins.filter(pl.col("body_mass_g").is_not_null())

# coercing null to nan
penguins_complete = penguins_complete.with_columns(pl.all().fill_null(np.nan))

# separate into X and y
y = penguins_complete["body_mass_g"]
X = penguins_complete.select(pl.exclude("body_mass_g"))

# train test split
X_trn, X_tst, y_trn, y_tst = train_test_split(X, y, random_state=408)

# create pipeline for categorical features
cat_feats = ["species", "island", "sex"]
cat_transform = Pipeline(
    [
        ("cat_imputer", SimpleImputer(strategy="most_frequent")),
        ("oh_encoder", OneHotEncoder(drop="first")),
    ]
)

# create pipeline for numerical features
cont_feats = ["bill_length_mm", "bill_depth_mm", "flipper_length_mm", "year"]
cont_transform = Pipeline(
    [
        ("cont_imputer", SimpleImputer(strategy="mean")),
        ("standardizer", StandardScaler()),
    ]
)

# create a preprocessing pipeline
preprocessor = ColumnTransformer(
    [("cats", cat_transform, cat_feats), ("conts", cont_transform, cont_feats)]
)

# make full pipeline
pipe = Pipeline([("preprocess", preprocessor), ("lin_reg", LinearRegression())])

# fit pipeline
pipe.fit(X_trn, y_trn)

# predict training y's
y_hat = pipe.predict(X_trn)

# evaluate model
y_hat_tst = pipe.predict(X_tst)
math.sqrt(mean_squared_error(y_tst, y_hat_tst))
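One nice side effect of wrapping all of the preprocessing and the model into a single Pipeline is that the imputers, encoder, and scaler get re-fit inside each fold when you cross-validate, so nothing leaks from the held-out folds. As a minimal sketch (not part of the original example, and assuming the pipe and training split from above, plus a scikit-learn version recent enough to have the "neg_root_mean_squared_error" scorer and ColumnTransformer.get_feature_names_out(); if your version isn't happy slicing a polars DataFrame during cross-validation, passing X_trn.to_pandas() instead should work):

from sklearn.model_selection import cross_val_score

# 5-fold CV RMSE on the training data; scores are negated RMSEs, so flip the sign
cv_rmse = -cross_val_score(
    pipe, X_trn, y_trn, cv=5, scoring="neg_root_mean_squared_error"
)
print(cv_rmse.mean())

# peek inside the fitted pipeline: expanded feature names and their coefficients
feat_names = pipe.named_steps["preprocess"].get_feature_names_out()
coefs = pipe.named_steps["lin_reg"].coef_
for name, coef in zip(feat_names, coefs):
    print(name, round(coef, 2))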