Image Classification With Cubical Persistent Homology

In this example, we will show how to use Ripserer in an image classification context. Persistent homology is not a predictive algorithm, but it can be used to extract useful features from data.

Setting up

using Ripserer
using PersistenceDiagrams
using Images # also required: ImageIO to read .png files
using Plots
using ProgressMeter
using Random
Random.seed!(1337)

data_dir = joinpath(@__DIR__, "../assets/data/malaria") # replace with the correct path.

Let's load the data. We will use a data set of microscope images of healthy cells and cells infected with malaria. The original data set is quite large, but we can pretend we were only given 200 images to work with. We have chosen the 200 images randomly.

uninfected = shuffle!(Images.load.(readdir(joinpath(data_dir, "uninfected"); join=true)))
infected = shuffle!(Images.load.(readdir(joinpath(data_dir, "infected"); join=true)))

images = [uninfected; infected]
classes = [fill(false, length(uninfected)); fill(true, length(infected))]

Let's see what the images look like.

plot(
    plot(uninfected[1]; title="Healthy"),
    plot(uninfected[2]; title="Healthy"),
    plot(infected[1]; title="Infected"),
    plot(infected[2]; title="Infected"),
)

To make the images work with Ripserer, we convert them to floating-point grayscale values. We do not have to resize the images. Some additional preprocessing, such as normalization, might help, but we'll skip it for this example.

inputs = [Gray.(image) for image in images]
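
As an illustration of the normalization mentioned above, a min-max rescaling could look roughly like the sketch below. It is not used in the rest of this example. The helper name `rescale01` is hypothetical, and it works on a plain matrix of floats rather than on an image type.

```julia
# Hypothetical helper: linearly rescale values so the darkest pixel is 0 and the brightest is 1.
function rescale01(img::AbstractMatrix{<:Real})
    lo, hi = extrema(img)
    # A constant image has no contrast to stretch; map it to all zeros.
    lo == hi && return zeros(Float64, size(img))
    return (img .- lo) ./ (hi - lo)
end

toy = [0.2 0.4; 0.6 1.0]
rescaled = rescale01(toy)
```

Whether this kind of rescaling improves the classification results here has not been tested.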

Now we can compute persistence diagrams. Since we are working with images, we have to use the Cubical filtration type. Cubical persistent homology should detect the dark spots (local minima) in the images. It's pretty efficient, so this should only take a few seconds.

diagrams = @showprogress [ripserer(Cubical(i)) for i in inputs]
200-element Vector{Vector{PersistenceDiagram}}:
 [562-element 0-dimensional PersistenceDiagram, 323-element 1-dimensional PersistenceDiagram]
 [689-element 0-dimensional PersistenceDiagram, 436-element 1-dimensional PersistenceDiagram]
 [749-element 0-dimensional PersistenceDiagram, 565-element 1-dimensional PersistenceDiagram]
 [1336-element 0-dimensional PersistenceDiagram, 907-element 1-dimensional PersistenceDiagram]
 [678-element 0-dimensional PersistenceDiagram, 481-element 1-dimensional PersistenceDiagram]
 [1235-element 0-dimensional PersistenceDiagram, 818-element 1-dimensional PersistenceDiagram]
 [510-element 0-dimensional PersistenceDiagram, 360-element 1-dimensional PersistenceDiagram]
 [535-element 0-dimensional PersistenceDiagram, 384-element 1-dimensional PersistenceDiagram]
 [760-element 0-dimensional PersistenceDiagram, 487-element 1-dimensional PersistenceDiagram]
 [480-element 0-dimensional PersistenceDiagram, 360-element 1-dimensional PersistenceDiagram]
 ⋮
 [726-element 0-dimensional PersistenceDiagram, 572-element 1-dimensional PersistenceDiagram]
 [457-element 0-dimensional PersistenceDiagram, 318-element 1-dimensional PersistenceDiagram]
 [489-element 0-dimensional PersistenceDiagram, 342-element 1-dimensional PersistenceDiagram]
 [772-element 0-dimensional PersistenceDiagram, 510-element 1-dimensional PersistenceDiagram]
 [659-element 0-dimensional PersistenceDiagram, 452-element 1-dimensional PersistenceDiagram]
 [817-element 0-dimensional PersistenceDiagram, 529-element 1-dimensional PersistenceDiagram]
 [656-element 0-dimensional PersistenceDiagram, 495-element 1-dimensional PersistenceDiagram]
 [404-element 0-dimensional PersistenceDiagram, 292-element 1-dimensional PersistenceDiagram]
 [583-element 0-dimensional PersistenceDiagram, 323-element 1-dimensional PersistenceDiagram]
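
To get a feel for why sublevel-set persistence picks up dark spots, here is a minimal sketch of the 0-dimensional case on a tiny matrix. It is not Ripserer's implementation. It assumes that pixels are added from darkest to brightest and that only horizontally or vertically adjacent pixels are connected. The function names `find_root!` and `sublevel_h0` are hypothetical.

```julia
# Union-find lookup with path halving.
function find_root!(parent, i)
    while parent[i] != i
        parent[i] = parent[parent[i]]
        i = parent[i]
    end
    return i
end

# 0-dimensional sublevel-set persistence of a matrix (4-neighbor adjacency assumed).
# Returns (birth, death) tuples; the component of the global minimum never dies.
function sublevel_h0(img::AbstractMatrix{<:Real})
    rows, cols = size(img)
    order = sortperm(vec(img))            # pixels from darkest to brightest
    parent = zeros(Int, length(img))      # 0 means "pixel not added yet"
    birth = zeros(Float64, length(img))   # birth value, read at component roots
    intervals = Tuple{Float64,Float64}[]
    for p in order
        parent[p] = p
        birth[p] = img[p]
        r, c = (p - 1) % rows + 1, (p - 1) ÷ rows + 1
        for (dr, dc) in ((-1, 0), (1, 0), (0, -1), (0, 1))
            rr, cc = r + dr, c + dc
            (1 <= rr <= rows && 1 <= cc <= cols) || continue
            q = (cc - 1) * rows + rr
            parent[q] == 0 && continue
            a, b = find_root!(parent, p), find_root!(parent, q)
            a == b && continue
            # Elder rule: the component that was born later dies at this pixel's value.
            young, old = birth[a] > birth[b] ? (a, b) : (b, a)
            # Skip intervals of zero length.
            birth[young] < img[p] && push!(intervals, (birth[young], Float64(img[p])))
            parent[young] = old
        end
    end
    push!(intervals, (Float64(minimum(img)), Inf))
    return intervals
end

# Four dark corners separated by brighter pixels.
toy_image = [0.1 0.9 0.3; 0.9 0.9 0.9; 0.5 0.9 0.2]
toy_intervals = sort(sublevel_h0(toy_image))
```

Each dark corner of `toy_image` is a local minimum and gives birth to one interval. The three shallower ones die when the bright pixels connect them to the darkest corner.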

This is what some of the diagrams look like.

plot(plot(images[1]; title="Healthy"), plot(diagrams[1]))
plot(plot(images[end]; title="Infected"), plot(diagrams[end]))

Notice that there is a lot more going on in the middle of the infected diagram, especially in $H_0$.

The persistence diagrams might look nice, but they are hard to use with machine learning algorithms. The number of points in the diagram may be different for every image, even when the images are of the same size. We can solve this problem by using a vectorization method, such as converting all diagrams to persistence images.

Persistence images work by replacing each point in the diagram with a distribution centered on that point. The distribution defaults to a Gaussian, but any function of two arguments can be used. Each point is also scaled by a weighting function that should be equal to zero along the $x$-axis, that is, for points with zero persistence. It defaults to a function that is zero on the $x$-axis and increases linearly up to the maximum persistence in the diagram.
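
The idea can be sketched in a few lines of plain Julia. This is a simplified illustration, not the PersistenceDiagrams.jl implementation. It assumes birth-persistence coordinates, an isotropic Gaussian evaluated at grid points, and a linear weight. The names `gaussian`, `linear_weight` and `toy_persistence_image` are hypothetical.

```julia
# Isotropic Gaussian density centered at the origin with standard deviation σ.
gaussian(x, y, σ) = exp(-(x^2 + y^2) / (2 * σ^2)) / (2 * π * σ^2)

# Linear weight: zero at persistence 0, one at the maximum persistence.
linear_weight(pers, maxpers) = clamp(pers / maxpers, 0, 1)

# `points` holds (birth, death) tuples; `xs` and `ys` are the grid coordinates
# for birth and persistence, respectively.
function toy_persistence_image(points, xs, ys; σ=0.1)
    maxpers = maximum(d - b for (b, d) in points)
    image = zeros(length(ys), length(xs))
    for (b, d) in points
        pers = d - b
        w = linear_weight(pers, maxpers)
        for (j, x) in enumerate(xs), (i, y) in enumerate(ys)
            image[i, j] += w * gaussian(x - b, y - pers, σ)
        end
    end
    return image
end

grid = range(0.0, 1.0; length=5)
# The first point has zero persistence, so its weight is zero.
with_noise = toy_persistence_image([(0.2, 0.2), (0.1, 0.6)], grid, grid)
without_noise = toy_persistence_image([(0.1, 0.6)], grid, grid)
```

However many points the diagram contains, the result is always a matrix of the same size. Points with zero persistence contribute nothing, so `with_noise` and `without_noise` are equal.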

We start by splitting the diagrams into their 0- and 1-dimensional components.

dim_0 = first.(diagrams)
dim_1 = last.(diagrams)

We feed the diagrams to the PersistenceImage constructor, which will choose ranges that fit all the diagrams. We set the sigma value to 0.1, since all persistence pairs are in the $[0,1]×[0,1]$ square and the default sigma of 1 would be too wide. We will use the default image size, which is 5×5.

image_0 = PersistenceImage(dim_0; sigma=0.1)
5×5 PersistenceImage(
  distribution = PersistenceDiagrams.Binormal(0.1),
  weight = PersistenceDiagrams.DefaultWeightingFunction(0.8713725490196078),
)
image_1 = PersistenceImage(dim_1; sigma=0.1)
5×5 PersistenceImage(
  distribution = PersistenceDiagrams.Binormal(0.1),
  weight = PersistenceDiagrams.DefaultWeightingFunction(0.33215686274509804),
)

Let's see what some of the images look like.

plot(plot(dim_0[end]; persistence=true), heatmap(image_0(dim_0[end]); aspect_ratio=1))
plot(plot(dim_1[end]; persistence=true), heatmap(image_1(dim_1[end]); aspect_ratio=1))

Next, we convert all diagrams to images and use vec to turn them into flat vectors. We then concatenate the zero- and one-dimensional images. The result is a vector of length 50 (two 5×5 images) for each input image.

persims = [[vec(image_0(dim_0[i])); vec(image_1(dim_1[i]))] for i in 1:length(diagrams)]
200-element Vector{Vector{Float64}}:
 [0.18378693755398753, 0.3159004729179951, 5.999029413234287, 14.346932715741573, 9.378986400880102, 3.7126509074541363, 1.2687353172596347, 2.6063392762006763, 6.118152311228485, 3.999498467039263  …  98.30336347139273, 81.97262431935395, 47.38851078734253, 18.939880569781195, 5.205352268274971, 34.52727145077215, 30.960941273495163, 19.179257384160618, 8.171939411757268, 2.378733222128269]
 [1.9511623710037234e-5, 0.022454732389367932, 2.6403925109133373, 18.266469571432033, 20.470237313293467, 0.05882891628088876, 0.022428037414665782, 1.1261241616374436, 7.7894043751744775, 8.729160981979454  …  121.19337830392288, 95.15787179509874, 51.87211746903214, 19.69935649791533, 5.2255414595480945, 83.15803672587785, 66.52537034221021, 36.989335933280046, 14.331337013040272, 3.873603696973592]
 [4.0969323434235456e-5, 0.04941164432344015, 4.8685788695194985, 28.32881919811308, 29.2444819421783, 0.07454702167952737, 0.03764951279139233, 2.0763624888976264, 12.080310731982902, 12.470778271933096  …  121.24492950316424, 87.71810601544772, 42.84905690115792, 14.048194442945066, 3.0692816369294005, 71.74987308186884, 52.161808015463336, 25.616722298919722, 8.448240144338445, 1.8578508295110647]
 [1.1358310547259625e-5, 0.016330780450562958, 2.3204158407487245, 18.056539190821596, 20.971479841175995, 0.05507821154576636, 0.018680684848407014, 0.9896497904354546, 7.699883344920993, 8.942906755899642  …  216.05203014454096, 156.03519181071135, 75.93883236543232, 24.743908185463667, 5.356571081805438, 149.8258258914798, 108.3805698482355, 52.840827259901815, 17.252046091304948, 3.7430628716117873]
 [0.00038579052717306644, 0.14399218976386058, 5.575674385374099, 19.065622434440893, 16.05615228927796, 0.3366144055012067, 0.13051161908970613, 2.378490455051405, 8.130188984286345, 6.846854579135608  …  91.63560609151158, 65.5285446068929, 31.577785169105336, 10.189092827675092, 2.184814135485038, 33.1801287143104, 23.88293983464832, 11.59219880667366, 3.7703710459872, 0.8156616933451422]
 [0.00027989215927874146, 0.07185324652374847, 4.072959813027243, 14.742545822732895, 12.128864747380991, 0.4907591087008036, 0.15225517387146958, 1.7393242460889486, 6.286694798375448, 5.172134123124119  …  214.72211222157864, 157.04407802509175, 77.53238558812345, 25.680910523051367, 5.665231586912662, 91.53347285390332, 67.33317271996347, 33.458469000041234, 11.163730613962992, 2.4830887080287747]
 [0.00014359455559164084, 0.10781570930542317, 6.054832111634374, 24.049503888289816, 21.368744053512835, 0.10310307249265085, 0.06884073789876813, 2.5823159776614277, 10.2554745156537, 9.112312865317744  …  68.26648786567112, 49.25421386933707, 24.05524645132635, 7.916140717607677, 1.7463414263771178, 36.184011685591344, 26.282171993139606, 12.925115552848355, 4.283520513802259, 0.9515790847354259]
 [8.957070020939651e-6, 0.020690570741255573, 2.6244431758501197, 18.43068151507527, 20.47425611762805, 0.01169634137588437, 0.011269701477912068, 1.119177313815325, 7.859429468456025, 8.730874728105888  …  71.6857613506733, 51.23908218702332, 24.70540046252414, 7.9871207029966484, 1.719263546252137, 54.34247036979936, 38.886102303911834, 18.772534622487637, 6.077407336031359, 1.3101817142667864]
 [3.418035627997588e-5, 0.008201232314045518, 1.2633198475537804, 12.191223012193925, 16.591458724429042, 0.12788980935132543, 0.03201244754358883, 0.5391683226765487, 5.198726154375238, 7.075126287765207  …  119.67766646234183, 89.84253077793777, 46.00410973183364, 16.025162164299868, 3.785417644525939, 80.2518898037527, 61.7769860221528, 32.505698138761964, 11.657763255071288, 2.8388022827710095]
 [1.1138416574269145e-5, 0.021258614851079143, 2.637623626058409, 18.406105271406318, 20.457652650114788, 0.020749324003396154, 0.01387474435004396, 1.1248442848804756, 7.848949429335919, 8.723794480891634  …  73.6979868182655, 53.40205066502328, 26.18723354345536, 8.647559120100144, 1.9119060334827898, 50.76375978160655, 36.86524970376519, 18.121997768214687, 6.000240534178019, 1.3304383266385493]
 ⋮
 [0.038177456663513444, 0.1273561992211679, 3.334437739103469, 14.616363787150657, 13.318646746714192, 0.9820158489991613, 1.8304421374825757, 2.596770789623723, 6.330496959756984, 5.679918452468068  …  113.57601299800275, 81.82439390278343, 39.88091852901463, 13.088105258312863, 2.876928494278018, 69.38025714060686, 50.04572632805966, 24.426786406569917, 8.029606613412886, 1.768386406370124]
 [0.32751961619843006, 1.0122065498814636, 8.477445597608988, 13.801132908488386, 6.38815079289151, 2.633033344588887, 2.969248987895892, 5.277486108784153, 6.062748934622155, 2.725121871848892  …  72.67553475001748, 55.56407941887381, 29.569038954876984, 11.097347077627969, 2.9902910265119673, 20.527108265327467, 15.847710195424046, 8.502122713099654, 3.2034422094000736, 0.8602636789953471]
 [0.009456683592938811, 0.07470119635272596, 3.937031549689573, 14.728900367220167, 12.332233957477573, 0.8326203491546252, 0.5654626410134206, 1.7446128553774298, 6.2813709493068774, 5.258857414679481  …  86.68053600525214, 64.68461087358783, 32.8598143884274, 11.331603437311253, 2.644407363898276, 39.72667036569449, 29.891068156950208, 15.316296222685848, 5.329891948409149, 1.2556966278844643]
 [0.33239846035046156, 1.1909184567276347, 13.004437300132757, 33.98835667214518, 24.31553353535771, 3.789778641557496, 3.451385355352616, 7.479684635530825, 14.787511685707152, 10.371431938086113  …  131.54497947121772, 101.51571729145927, 54.600058976714394, 20.846878724942126, 5.817592756357402, 38.288267592608506, 30.927434426842932, 17.54803314099481, 7.11366139009819, 2.1078346858600696]
 [0.3012993138260101, 0.9354060061199079, 6.946137208570551, 29.324463053101958, 27.888865376420963, 1.7567912846338587, 2.7153606666228107, 5.6380369639932, 13.585757735908093, 11.925869848553996  …  100.99531244480603, 72.82337360855844, 35.415764430751786, 11.545844117152264, 2.505156782786974, 52.24886484484996, 37.82451871473657, 18.485789977509043, 6.063755953056579, 1.325890019919151]
 [0.08678504687089142, 0.46500183514744275, 4.671351795090407, 19.56989759108423, 18.047072448917966, 0.8733972420857041, 2.39151559200282, 4.504088207783342, 9.050993804021369, 7.708541574692296  …  126.75698993389202, 91.74217173591478, 44.77301068305383, 14.640748339771008, 3.1836984593920503, 63.74727109496232, 46.20220903368008, 22.580670371756923, 7.394898410664533, 1.6105147674436349]
 [0.46339279442550535, 0.7362981284144161, 5.4088714280585, 18.85554971815369, 16.740121595485657, 2.332249574144515, 2.3868719074147005, 3.1544849209586827, 8.075171805885665, 7.138591721519203  …  104.65009591073562, 77.34603054691404, 39.14538122393634, 13.58190042387118, 3.237511413725706, 50.96256322752152, 37.76330160940184, 19.1498462619718, 6.650413599524648, 1.5842931833817815]
 [0.17858708608462173, 0.34990516709475217, 6.605527398971073, 19.617225489489556, 14.570000486614715, 2.229673102375313, 2.248792771529487, 3.5727537571536816, 8.392010493038141, 6.213159624138093  …  62.279742966082615, 45.50716565043407, 22.66404917450508, 7.6910350189855885, 1.7799236868585542, 27.729472284585924, 20.21048855050869, 10.021262027633353, 3.3765978164182493, 0.7730659672754912]
 [0.3722450398330204, 1.5414294395175012, 13.204876309571432, 23.328335361957848, 11.998074595812863, 5.517177000761088, 3.81094665462445, 7.280186826407634, 10.133171819070103, 5.117489910516538  …  92.0543284425924, 78.55614960276354, 48.570553592838586, 22.96106594189336, 8.978970993511345, 36.78768670708403, 30.942749916870298, 18.241431541479812, 7.682390823048525, 2.409760878322222]

Fitting A Model

Now it's time to fit our model. We will use GLMNet.jl to fit a regularized linear model.

using GLMNet

Convert the image vectors to a matrix that will be understood by glmnet.

X = reduce(hcat, persims)'
y = classes
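
The orientation matters here: after the conversion, each row holds one sample and each column holds one feature. A small sketch with made-up numbers (the names `samples`, `X_toy` and `X_dense` are hypothetical) shows what the conversion does.

```julia
# Three hypothetical samples with two features each.
samples = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

# hcat stacks the samples as columns; the adjoint then turns them into rows.
X_toy = reduce(hcat, samples)'
size(X_toy)  # (3, 2): one row per sample, one column per feature

# `'` returns a lazy wrapper. `permutedims` gives an ordinary Matrix with the same layout.
X_dense = permutedims(reduce(hcat, samples))
```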

Start by randomly splitting the data into two sets, a training and a testing set.

perm = shuffle(1:200)
train_x = X[perm[1:100], :]
train_y = y[perm[1:100]]
test_x = X[perm[101:end], :]
test_y = y[perm[101:end]]
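
When splitting with a permutation like this, it helps to remember that row `i` of the test set is not sample `i` of the original data. A toy sketch of the bookkeeping, with a hypothetical sample count and hypothetical variable names:

```julia
using Random

rng = MersenneTwister(1337)   # seed chosen only for this sketch
n = 10                        # hypothetical number of samples
perm_toy = shuffle(rng, 1:n)

train_idx = perm_toy[1:(n ÷ 2)]        # original indices of the training samples
test_idx = perm_toy[(n ÷ 2 + 1):end]   # original indices of the test samples

# Row `i` of the test set corresponds to original sample `test_idx[i]`.
```

The two index sets are disjoint and together cover every sample exactly once.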

Fit the model and predict.

path = glmnet(train_x, train_y)
cv = glmnetcv(train_x, train_y)

λ = path.lambda[argmin(cv.meanloss)]
path = glmnet(train_x, train_y; lambda=[λ])

predictions = .!iszero.(round.(GLMNet.predict(path, test_x)))
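
The last line turns real-valued scores into Boolean labels by rounding. Assuming the model returns unbounded regression scores, the following toy comparison (with made-up scores) shows how rounding differs from an explicit 0.5 threshold. The threshold is an alternative shown for illustration and is not what this example uses.

```julia
# Made-up scores standing in for the output of a linear model.
scores = [-0.7, 0.2, 0.49, 0.51, 1.3]

# Rounding, as in the example: any score that rounds to a nonzero integer counts as `true`.
by_rounding = .!iszero.(round.(scores))   # [true, false, false, true, true]

# Explicit threshold: only scores above 0.5 count as `true`.
by_threshold = scores .> 0.5              # [false, false, false, true, true]
```

The two rules only disagree for scores below -0.5, which round to a nonzero negative integer.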

Get the classification accuracy.

count(predictions .== test_y) / length(test_y)
0.95

Not half bad considering we haven't touched the images and we left pretty much all settings on default.
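
Accuracy alone does not say which kind of mistake was made. As a sketch, the four confusion counts can be computed from two Boolean vectors as shown below. The labels and predictions are made up and all variable names are hypothetical.

```julia
# Hypothetical labels and predictions; `true` stands for "infected".
truth = [true, true, false, false, true]
guess = [true, false, false, true, true]

true_pos = count(truth .& guess)       # infected, predicted infected
false_neg = count(truth .& .!guess)    # infected, predicted healthy
false_pos = count(.!truth .& guess)    # healthy, predicted infected
true_neg = count(.!truth .& .!guess)   # healthy, predicted healthy

accuracy_toy = (true_pos + true_neg) / length(truth)
```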

Now let's look at the misclassified examples.

missed = findall(predictions .!= test_y)
label = ("Healthy", "Infected")
test_images = images[perm[101:end]] # same order as test_x and test_y
plts = [plot(test_images[i]; title="$(label[test_y[i] + 1])", ticks=nothing) for i in missed]
plot(plts...)

Finally, let's look at which parts of the persistence images glmnet considered important.

plot(
    heatmap(reshape(path.betas[1:25], (5, 5)); title="H₀ coefficients"),
    heatmap(reshape(path.betas[26:50], (5, 5)); title="H₁ coefficients"),
)

These correspond to the area we identified at the beginning. Also note that in this case, the classifier does not care about $H_1$ at all.

Using MLJ

Another, more straightforward way to execute a similar pipeline is to use Ripserer's MLJ.jl integration. We will use a random forest classifier for this example.

We start by loading MLJ and the classifier. Note that MLJDecisionTreeInterface.jl needs to be installed for this to work.

using MLJ
tree = @load RandomForestClassifier pkg = "DecisionTree" verbosity = 0
MLJDecisionTreeInterface.RandomForestClassifier

We create a pipeline of CubicalPersistentHomology followed by the classifier. In this case, CubicalPersistentHomology takes care of both the homology computation and the conversion to persistence images.

pipe = @pipeline(CubicalPersistentHomology(), tree)

We train the pipeline the same way you would fit any other MLJ model. Remember that we need to use the grayscale versions of the images, which are stored in inputs.

classes = coerce(classes, Binary)
train, test = partition(eachindex(classes), 0.7; shuffle=true, rng=1337)
mach = machine(pipe, inputs, classes)
fit!(mach; rows=train)

Next, we predict the classes on the test data and print out the classification accuracy.

yhat = predict_mode(mach, inputs[test])
accuracy(yhat, classes[test])

The result is quite a bit worse than before. We can try mitigating that by using a different vectorizer.

pipe.cubical_persistent_homology.vectorizer = PersistenceCurveVectorizer()
mach = machine(pipe, inputs, classes)
fit!(mach; rows=train)

yhat = predict_mode(mach, inputs[test])
accuracy(yhat, classes[test])

The result could be improved further by choosing a different model and vectorizer. However, this is just a short introduction. Please see the MLJ.jl documentation for more information on model tuning and selection, and the PersistenceDiagrams.jl documentation for a list of vectorizers and their options.


This page was generated using Literate.jl.