相关#

import polars as pl
from plotnine import *

Regression Plot#

# Load the mpg dataset from a local CSV file into a polars DataFrame.
mpg = pl.read_csv("data/mpg.csv")
# Preview the first rows to check column names and dtypes.
mpg.head()
shape: (5, 11)
manufacturer  model  displ  year  cyl  trans  drv  cty  hwy  fl  class
str  str  f64  i64  i64  str  str  i64  i64  str  str
"audi"  "a4"  1.8  1999  4  "auto(l5)"  "f"  18  29  "p"  "compact"
"audi"  "a4"  1.8  1999  4  "manual(m5)"  "f"  21  29  "p"  "compact"
"audi"  "a4"  2.0  2008  4  "manual(m6)"  "f"  20  31  "p"  "compact"
"audi"  "a4"  2.0  2008  4  "auto(av)"  "f"  21  30  "p"  "compact"
"audi"  "a4"  2.8  1999  6  "auto(l5)"  "f"  16  26  "p"  "compact"
# Restrict the data to cars whose cylinder count is exactly 4 or 8.
mpg_select = mpg.filter(pl.col("cyl").is_in([4, 8]))

# Scatter of hwy against displ with a linear fit, one panel per cylinder
# count. Layers come first (points, then the fitted line drawn on top),
# followed by the scales, axis limits, labels and theme.
(
    ggplot(mpg_select, aes(x="displ", y="hwy", color="cyl", fill="cyl"))
    + geom_point(size=2)
    + geom_smooth(method="lm", size=1)
    + facet_wrap("cyl")
    + scale_color_gradient(low="#E69F00", high="#56B4E9")
    + scale_fill_gradient(low="#E69F00", high="#56B4E9")
    + lims(x=(0, 8), y=(0, 50))
    + labs(title="Scatterplot with line of best fit grouped by number of cylinders")
    + theme(plot_title=element_text(hjust=0.5), legend_position="none")
)

Jittering#

# Shared base plot for this section: `cty` on the x axis, `hwy` on the y axis.
g = ggplot(mpg, aes(x="cty", y="hwy"))

# Variant 1: jittered points, so observations that share the same
# (cty, hwy) pair do not sit exactly on top of each other.
(
    g
    + geom_jitter(height=0.5, size=2, color="#06527f")
    + labs(title="Jittered Points")
    + theme(plot_title=element_text(hjust=0.5))
)
# Variant 2: a counts plot built on the same base; the legend is hidden.
(
    g
    + geom_count(color="red")
    + labs(title="Counts Plot")
    + theme(plot_title=element_text(hjust=0.5), legend_position="none")
)

FacetGrid#

# Load the mtcars dataset from a local CSV file into a polars DataFrame.
mtcars = pl.read_csv("data/mtcars.csv")
# Preview the first rows to check column names and dtypes.
mtcars.head()
shape: (5, 13)
mpg  cyl  disp  hp  drat  wt  qsec  vs  am  gear  carb  fast  cars
f64  i64  f64  i64  f64  f64  f64  i64  i64  i64  i64  i64  str
4.582576  6  160.0  110  3.9  2.62  16.46  0  1  4  4  1  "Mazda RX4"
4.582576  6  160.0  110  3.9  2.875  17.02  0  1  4  4  1  "Mazda RX4 Wag"
4.774935  4  108.0  93  3.85  2.32  18.61  1  1  4  1  1  "Datsun 710"
4.626013  6  258.0  110  3.08  3.215  19.44  1  0  3  1  1  "Hornet 4 Drive"
4.32435  8  360.0  175  3.15  3.44  17.02  0  0  3  2  1  "Hornet Sportabout"
# Scatter of `wt` against `mpg`, split into panels by the `am` and `vs`
# columns. NOTE(review): the section heading says "FacetGrid" but the code
# uses facet_wrap with two variables — confirm which layout is intended.
(
    ggplot(mtcars, aes("mpg", "wt"))
    + geom_point(color="#06527f")
    + facet_wrap(["am", "vs"])
)

Heatmap#

# Display the full mtcars frame before computing correlations.
mtcars
shape: (32, 13)
mpg  cyl  disp  hp  drat  wt  qsec  vs  am  gear  carb  fast  cars
f64  i64  f64  i64  f64  f64  f64  i64  i64  i64  i64  i64  str
4.582576  6  160.0  110  3.9  2.62  16.46  0  1  4  4  1  "Mazda RX4"
4.582576  6  160.0  110  3.9  2.875  17.02  0  1  4  4  1  "Mazda RX4 Wag"
4.774935  4  108.0  93  3.85  2.32  18.61  1  1  4  1  1  "Datsun 710"
4.626013  6  258.0  110  3.08  3.215  19.44  1  0  3  1  1  "Hornet 4 Drive"
4.32435  8  360.0  175  3.15  3.44  17.02  0  0  3  2  1  "Hornet Sportabout"
5.51362  4  95.1  113  3.77  1.513  16.9  1  1  5  2  1  "Lotus Europa"
3.974921  8  351.0  264  4.22  3.17  14.5  0  1  5  4  0  "Ford Pantera L"
4.438468  6  145.0  175  3.62  2.77  15.5  0  1  5  6  1  "Ferrari Dino"
3.872983  8  301.0  335  3.54  3.57  14.6  0  1  5  8  0  "Maserati Bora"
4.626013  4  121.0  109  4.11  2.78  18.6  1  1  4  2  1  "Volvo 142E"
# Select only the Float64/Int64 columns (the string `cars` column is left
# out), then compute their pairwise correlation matrix.
numeric_cols = pl.selectors.by_dtype([pl.Float64, pl.Int64])
corr_matrix = mtcars.select(numeric_cols).corr()
corr_matrix.head()
shape: (5, 12)
mpg  cyl  disp  hp  drat  wt  qsec  vs  am  gear  carb  fast
f64  f64  f64  f64  f64  f64  f64  f64  f64  f64  f64  f64
1.0  -0.858539  -0.867536  -0.787309  0.680312  -0.883453  0.420317  0.66926  0.593153  0.487226  -0.553703  0.730748
-0.858539  1.0  0.902033  0.832447  -0.699938  0.782496  -0.591242  -0.810812  -0.522607  -0.492687  0.526988  -0.695182
-0.867536  0.902033  1.0  0.790949  -0.710214  0.88798  -0.433698  -0.710416  -0.591227  -0.555569  0.394977  -0.732073
-0.787309  0.832447  0.790949  1.0  -0.448759  0.658748  -0.708223  -0.723097  -0.243204  -0.125704  0.749812  -0.751422
0.680312  -0.699938  -0.710214  -0.448759  1.0  -0.712441  0.091205  0.440278  0.712711  0.69961  -0.09079  0.40043
# Reshape the square correlation matrix into long format for plotting:
# one row per matrix cell with columns `index`, `variable` and `value`.
corr_matrix2 = (
    # Add an `index` column holding the row number of each matrix row.
    corr_matrix.with_row_index()
    # Melt every other column into (`variable`, `value`) pairs.
    .unpivot(index="index")
    # Convert the row number to a float and shift it by one — presumably so
    # it lines up with the 1-based positions of the discrete axis limits
    # used in the heatmap below; TODO confirm against the rendered plot.
    .with_columns(pl.col("index").cast(pl.Float32) + 1)
    # Strongest positive correlations first.
    .sort(by="value", descending=True)
)
corr_matrix2.head()
shape: (5, 3)
index  variable  value
f32  str  f64
1.0  "mpg"  1.0
2.0  "cyl"  1.0
3.0  "disp"  1.0
4.0  "hp"  1.0
6.0  "wt"  1.0
# Correlogram: one tile per matrix cell, filled by correlation value and
# annotated with the value formatted to two decimals.
# NOTE(review): x is mapped to the numeric `index` column while the x scale
# is discrete with the column names as limits — this appears to rely on the
# numeric index matching the position of each name; confirm on the output.
(
    ggplot(corr_matrix2, aes("index", "variable", fill="value"))
    + geom_tile()
    + scale_fill_gradient(low="#0c71b0", high="#ff0000")
    + geom_text(aes(label="value"), size=8, format_string="{:.2f}", color="white")
    + labs(title="Correlogram of mtcars")
    + scale_x_discrete(limits=corr_matrix.columns)
    # The y limits are the same names in reverse order.
    + scale_y_discrete(limits=corr_matrix.columns[::-1])
    # Fixed aspect ratio for the tiles.
    + coord_fixed()
)

Marginal Plot#

# Main panel: scatter of `hwy` against `displ`, coloured by manufacturer.
g1 = ggplot(mpg) + geom_point(aes("displ", "hwy", color="manufacturer"))
# Boxplot with `displ` on x and `hwy` on y — presumably one of the marginal
# panels named in the section heading; TODO confirm how it is combined.
g2 = ggplot(mpg) + geom_boxplot(aes("displ", "hwy"))
# Boxplot with the aesthetics swapped and the coordinates flipped.
g3 = ggplot(mpg) + geom_boxplot(aes("hwy", "displ")) + coord_flip()
# Display the main scatter panel.
g1