L'algoritmo di clustering K-Means in Java

1. Panoramica

Il clustering è un termine generico per una classe di algoritmi non supervisionati per scoprire gruppi di cose, persone o idee strettamente correlate tra loro .

In questa definizione di una riga apparentemente semplice, abbiamo visto alcune parole d'ordine. Cos'è esattamente il clustering? Cos'è un algoritmo non supervisionato?

In questo tutorial, per prima cosa faremo luce su questi concetti. Quindi, vedremo come possono manifestarsi in Java.

2. Algoritmi non supervisionati

Prima di utilizzare la maggior parte degli algoritmi di apprendimento, dovremmo in qualche modo fornire loro alcuni dati di esempio e consentire all'algoritmo di apprendere da quei dati. Nella terminologia di Machine Learning, si chiamano dati di addestramento del set di dati di esempio. Inoltre, l'intero processo è noto come processo di formazione.

Ad ogni modo, possiamo classificare gli algoritmi di apprendimento in base alla quantità di supervisione di cui hanno bisogno durante il processo di formazione. I due principali tipi di algoritmi di apprendimento in questa categoria sono:

  • Apprendimento supervisionato : negli algoritmi supervisionati, i dati di addestramento dovrebbero includere la soluzione effettiva per ogni punto. Ad esempio, se stiamo per addestrare il nostro algoritmo di filtraggio dello spam, alimentiamo sia le email di esempio che la loro etichetta, ovvero spam o non spam, all'algoritmo. Matematicamente parlando, dedurremo f (x) da un set di addestramento che includa sia xs che ys.
  • Apprendimento non supervisionato : quando non ci sono etichette nei dati di addestramento, l'algoritmo è uno non supervisionato. Ad esempio, abbiamo molti dati sui musicisti e scopriremo gruppi di musicisti simili nei dati.

3. Clustering

Il clustering è un algoritmo non supervisionato per scoprire gruppi di cose, idee o persone simili. A differenza degli algoritmi supervisionati, non stiamo addestrando algoritmi di clustering con esempi di etichette note. Invece, il clustering cerca di trovare strutture all'interno di un set di addestramento in cui nessun punto dei dati è l'etichetta.

3.1. K-Means Clustering

K-Means è un algoritmo di clustering con una proprietà fondamentale: il numero di cluster è definito in anticipo . Oltre a K-Means, esistono altri tipi di algoritmi di clustering come Hierarchical Clustering, Affinity Propagation o Spectral Clustering.

3.2. Come funziona K-Means

Supponiamo che il nostro obiettivo sia trovare alcuni gruppi simili in un set di dati come:

K-Means inizia con k centroidi posizionati casualmente. I centroidi, come suggerisce il nome, sono i punti centrali dei cluster . Ad esempio, qui stiamo aggiungendo quattro centroidi casuali:

Quindi assegniamo ogni punto dati esistente al suo centroide più vicino:

Dopo l'assegnazione, spostiamo i centroidi nella posizione media dei punti ad esso assegnati. Ricorda, i centroidi dovrebbero essere i punti centrali dei cluster:

L'iterazione corrente si conclude ogni volta che abbiamo finito di riposizionare i centroidi. Ripetiamo queste iterazioni fino a quando l'assegnazione tra più iterazioni consecutive smette di cambiare:

Quando l'algoritmo termina, quei quattro cluster vengono trovati come previsto. Ora che sappiamo come funziona K-Means, implementiamolo in Java.

3.3. Rappresentazione delle caratteristiche

Quando si modellano diversi set di dati di addestramento, è necessaria una struttura dati per rappresentare gli attributi del modello e i valori corrispondenti. Ad esempio, un musicista può avere un attributo di genere con un valore come Rock . Di solito usiamo il termine caratteristica per fare riferimento alla combinazione di un attributo e del suo valore.

Per preparare un set di dati per un particolare algoritmo di apprendimento, di solito utilizziamo un insieme comune di attributi numerici che possono essere utilizzati per confrontare elementi diversi. Ad esempio, se permettiamo ai nostri utenti di taggare ogni artista con un genere, alla fine della giornata, possiamo contare quante volte ogni artista è taggato con un genere specifico:

Il vettore delle caratteristiche per un artista come i Linkin Park è [rock -> 7890, nu-metal -> 700, alternative -> 520, pop -> 3]. Quindi, se potessimo trovare un modo per rappresentare gli attributi come valori numerici, allora possiamo semplicemente confrontare due elementi diversi, ad esempio gli artisti, confrontando le loro voci vettoriali corrispondenti.

Poiché i vettori numerici sono strutture di dati così versatili, rappresenteremo le caratteristiche che li utilizzano . Ecco come implementiamo i vettori di funzionalità in Java:

public class Record { private final String description; private final Map features; // constructor, getter, toString, equals and hashcode }

3.4. Trovare oggetti simili

In ogni iterazione di K-Means, abbiamo bisogno di un modo per trovare il centroide più vicino a ciascun elemento nel set di dati. Uno dei modi più semplici per calcolare la distanza tra due vettori di caratteristiche è utilizzare la distanza euclidea. La distanza euclidea tra due vettori come [p1, q1] e [p2, q2] è uguale a:

Implementiamo questa funzione in Java. Innanzitutto, l'astrazione:

public interface Distance { double calculate(Map f1, Map f2); }

Oltre alla distanza euclidea, esistono altri approcci per calcolare la distanza o la somiglianza tra diversi elementi come il coefficiente di correlazione di Pearson . Questa astrazione rende facile passare da una metrica all'altra della distanza.

Vediamo l'implementazione per la distanza euclidea:

public class EuclideanDistance implements Distance { @Override public double calculate(Map f1, Map f2) { double sum = 0; for (String key : f1.keySet()) { Double v1 = f1.get(key); Double v2 = f2.get(key); if (v1 != null && v2 != null) { sum += Math.pow(v1 - v2, 2); } } return Math.sqrt(sum); } }

Innanzitutto, calcoliamo la somma delle differenze al quadrato tra le voci corrispondenti. Quindi, applicando la funzione sqrt , calcoliamo la distanza euclidea effettiva.

3.5. Rappresentazione del centroide

I centroidi si trovano nello stesso spazio degli elementi normali, quindi possiamo rappresentarli in modo simile agli elementi:

public class Centroid { private final Map coordinates; // constructors, getter, toString, equals and hashcode }

Ora che abbiamo alcune astrazioni necessarie in atto, è il momento di scrivere la nostra implementazione di K-Means. Ecco una rapida occhiata alla nostra firma del metodo:

public class KMeans { private static final Random random = new Random(); public static Map
    
      fit(List records, int k, Distance distance, int maxIterations) { // omitted } }
    

Analizziamo questa firma del metodo:

  • Il set di dati è un insieme di vettori di caratteristiche. Poiché ogni vettore di feature è un record, il tipo di dataset è List
  • Il parametro k determina il numero di cluster, che dovremmo fornire in anticipo
  • distance encapsulates the way we're going to calculate the difference between two features
  • K-Means terminates when the assignment stops changing for a few consecutive iterations. In addition to this termination condition, we can place an upper bound for the number of iterations, too. The maxIterations argument determines that upper bound
  • When K-Means terminates, each centroid should have a few assigned features, hence we're using a Map as the return type. Basically, each map entry corresponds to a cluster

3.6. Centroid Generation

The first step is to generate k randomly placed centroids.

Although each centroid can contain totally random coordinates, it's a good practice to generate random coordinates between the minimum and maximum possible values for each attribute. Generating random centroids without considering the range of possible values would cause the algorithm to converge more slowly.

First, we should compute the minimum and maximum value for each attribute, and then, generate the random values between each pair of them:

private static List randomCentroids(List records, int k) { List centroids = new ArrayList(); Map maxs = new HashMap(); Map mins = new HashMap(); for (Record record : records) { record.getFeatures().forEach((key, value) -> ); } Set attributes = records.stream() .flatMap(e -> e.getFeatures().keySet().stream()) .collect(toSet()); for (int i = 0; i < k; i++) { Map coordinates = new HashMap(); for (String attribute : attributes) { double max = maxs.get(attribute); double min = mins.get(attribute); coordinates.put(attribute, random.nextDouble() * (max - min) + min); } centroids.add(new Centroid(coordinates)); } return centroids; }

Now, we can assign each record to one of these random centroids.

3.7. Assignment

First off, given a Record, we should find the centroid nearest to it:

private static Centroid nearestCentroid(Record record, List centroids, Distance distance) { double minimumDistance = Double.MAX_VALUE; Centroid nearest = null; for (Centroid centroid : centroids) { double currentDistance = distance.calculate(record.getFeatures(), centroid.getCoordinates()); if (currentDistance < minimumDistance) { minimumDistance = currentDistance; nearest = centroid; } } return nearest; }

Each record belongs to its nearest centroid cluster:

private static void assignToCluster(Map
    
      clusters, Record record, Centroid centroid) { clusters.compute(centroid, (key, list) -> { if (list == null) { list = new ArrayList(); } list.add(record); return list; }); }
    

3.8. Centroid Relocation

If, after one iteration, a centroid does not contain any assignments, then we won't relocate it. Otherwise, we should relocate the centroid coordinate for each attribute to the average location of all assigned records:

private static Centroid average(Centroid centroid, List records) { if (records == null || records.isEmpty()) { return centroid; } Map average = centroid.getCoordinates(); records.stream().flatMap(e -> e.getFeatures().keySet().stream()) .forEach(k -> average.put(k, 0.0)); for (Record record : records) { record.getFeatures().forEach( (k, v) -> average.compute(k, (k1, currentValue) -> v + currentValue) ); } average.forEach((k, v) -> average.put(k, v / records.size())); return new Centroid(average); }

Since we can relocate a single centroid, now it's possible to implement the relocateCentroids method:

private static List relocateCentroids(Map
    
      clusters) { return clusters.entrySet().stream().map(e -> average(e.getKey(), e.getValue())).collect(toList()); }
    

This simple one-liner iterates through all centroids, relocates them, and returns the new centroids.

3.9. Putting It All Together

In each iteration, after assigning all records to their nearest centroid, first, we should compare the current assignments with the last iteration.

If the assignments were identical, then the algorithm terminates. Otherwise, before jumping to the next iteration, we should relocate the centroids:

public static Map
    
      fit(List records, int k, Distance distance, int maxIterations) { List centroids = randomCentroids(records, k); Map
     
       clusters = new HashMap(); Map
      
        lastState = new HashMap(); // iterate for a pre-defined number of times for (int i = 0; i < maxIterations; i++) { boolean isLastIteration = i == maxIterations - 1; // in each iteration we should find the nearest centroid for each record for (Record record : records) { Centroid centroid = nearestCentroid(record, centroids, distance); assignToCluster(clusters, record, centroid); } // if the assignments do not change, then the algorithm terminates boolean shouldTerminate = isLastIteration || clusters.equals(lastState); lastState = clusters; if (shouldTerminate) { break; } // at the end of each iteration we should relocate the centroids centroids = relocateCentroids(clusters); clusters = new HashMap(); } return lastState; }
      
     
    

4. Example: Discovering Similar Artists on Last.fm

Last.fm builds a detailed profile of each user's musical taste by recording details of what the user listens to. In this section, we're going to find clusters of similar artists. To build a dataset appropriate for this task, we'll use three APIs from Last.fm:

  1. API to get a collection of top artists on Last.fm.
  2. Another API to find popular tags. Each user can tag an artist with something, e.g. rock. So, Last.fm maintains a database of those tags and their frequencies.
  3. And an API to get the top tags for an artist, ordered by popularity. Since there are many such tags, we'll only keep those tags that are among the top global tags.

4.1. Last.fm's API

To use these APIs, we should get an API Key from Last.fm and send it in every HTTP request. We're going to use the following Retrofit service for calling those APIs:

public interface LastFmService { @GET("/2.0/?method=chart.gettopartists&format=json&limit=50") Call topArtists(@Query("page") int page); @GET("/2.0/?method=artist.gettoptags&format=json&limit=20&autocorrect=1") Call topTagsFor(@Query("artist") String artist); @GET("/2.0/?method=chart.gettoptags&format=json&limit=100") Call topTags(); // A few DTOs and one interceptor }

So, let's find the most popular artists on Last.fm:

// setting up the Retrofit service private static List getTop100Artists() throws IOException { List artists = new ArrayList(); // Fetching the first two pages, each containing 50 records. for (int i = 1; i <= 2; i++) { artists.addAll(lastFm.topArtists(i).execute().body().all()); } return artists; }

Similarly, we can fetch the top tags:

private static Set getTop100Tags() throws IOException { return lastFm.topTags().execute().body().all(); }

Finally, we can build a dataset of artists along with their tag frequencies:

private static List datasetWithTaggedArtists(List artists, Set topTags) throws IOException { List records = new ArrayList(); for (String artist : artists) { Map tags = lastFm.topTagsFor(artist).execute().body().all(); // Only keep popular tags. tags.entrySet().removeIf(e -> !topTags.contains(e.getKey())); records.add(new Record(artist, tags)); } return records; }

4.2. Forming Artist Clusters

Now, we can feed the prepared dataset to our K-Means implementation:

List artists = getTop100Artists(); Set topTags = getTop100Tags(); List records = datasetWithTaggedArtists(artists, topTags); Map
    
      clusters = KMeans.fit(records, 7, new EuclideanDistance(), 1000); // Printing the cluster configuration clusters.forEach((key, value) -> { System.out.println("-------------------------- CLUSTER ----------------------------"); // Sorting the coordinates to see the most significant tags first. System.out.println(sortedCentroid(key)); String members = String.join(", ", value.stream().map(Record::getDescription).collect(toSet())); System.out.print(members); System.out.println(); System.out.println(); });
    

If we run this code, then it would visualize the clusters as text output:

------------------------------ CLUSTER ----------------------------------- Centroid {classic rock=65.58333333333333, rock=64.41666666666667, british=20.333333333333332, ... } David Bowie, Led Zeppelin, Pink Floyd, System of a Down, Queen, blink-182, The Rolling Stones, Metallica, Fleetwood Mac, The Beatles, Elton John, The Clash ------------------------------ CLUSTER ----------------------------------- Centroid {Hip-Hop=97.21428571428571, rap=64.85714285714286, hip hop=29.285714285714285, ... } Kanye West, Post Malone, Childish Gambino, Lil Nas X, A$AP Rocky, Lizzo, xxxtentacion, Travi$ Scott, Tyler, the Creator, Eminem, Frank Ocean, Kendrick Lamar, Nicki Minaj, Drake ------------------------------ CLUSTER ----------------------------------- Centroid {indie rock=54.0, rock=52.0, Psychedelic Rock=51.0, psychedelic=47.0, ... } Tame Impala, The Black Keys ------------------------------ CLUSTER ----------------------------------- Centroid {pop=81.96428571428571, female vocalists=41.285714285714285, indie=22.785714285714285, ... } Ed Sheeran, Taylor Swift, Rihanna, Miley Cyrus, Billie Eilish, Lorde, Ellie Goulding, Bruno Mars, Katy Perry, Khalid, Ariana Grande, Bon Iver, Dua Lipa, Beyoncé, Sia, P!nk, Sam Smith, Shawn Mendes, Mark Ronson, Michael Jackson, Halsey, Lana Del Rey, Carly Rae Jepsen, Britney Spears, Madonna, Adele, Lady Gaga, Jonas Brothers ------------------------------ CLUSTER ----------------------------------- Centroid {indie=95.23076923076923, alternative=70.61538461538461, indie rock=64.46153846153847, ... } Twenty One Pilots, The Smiths, Florence + the Machine, Two Door Cinema Club, The 1975, Imagine Dragons, The Killers, Vampire Weekend, Foster the People, The Strokes, Cage the Elephant, Arcade Fire, Arctic Monkeys ------------------------------ CLUSTER ----------------------------------- Centroid {electronic=91.6923076923077, House=39.46153846153846, dance=38.0, ... } Charli XCX, The Weeknd, Daft Punk, Calvin Harris, MGMT, Martin Garrix, Depeche Mode, The Chainsmokers, Avicii, Kygo, Marshmello, David Guetta, Major Lazer ------------------------------ CLUSTER ----------------------------------- Centroid {rock=87.38888888888889, alternative=72.11111111111111, alternative rock=49.16666666, ... } Weezer, The White Stripes, Nirvana, Foo Fighters, Maroon 5, Oasis, Panic! at the Disco, Gorillaz, Green Day, The Cure, Fall Out Boy, OneRepublic, Paramore, Coldplay, Radiohead, Linkin Park, Red Hot Chili Peppers, Muse

Since centroid coordinations are sorted by the average tag frequency, we can easily spot the dominant genre in each cluster. For example, the last cluster is a cluster of a good old rock-bands, or the second one is filled with rap stars.

Although this clustering makes sense, for the most part, it's not perfect since the data is merely collected from user behavior.

5. Visualization

A few moments ago, our algorithm visualized the cluster of artists in a terminal-friendly way. If we convert our cluster configuration to JSON and feed it to D3.js, then with a few lines of JavaScript, we'll have a nice human-friendly Radial Tidy-Tree:

We have to convert our Map to a JSON with a similar schema like this d3.js example.

6. Number of Clusters

One of the fundamental properties of K-Means is the fact that we should define the number of clusters in advance. So far, we used a static value for k, but determining this value can be a challenging problem. There are two common ways to calculate the number of clusters:

  1. Domain Knowledge
  2. Mathematical Heuristics

If we're lucky enough that we know so much about the domain, then we might be able to simply guess the right number. Otherwise, we can apply a few heuristics like Elbow Method or Silhouette Method to get a sense on the number of clusters.

Before going any further, we should know that these heuristics, although useful, are just heuristics and may not provide clear-cut answers.

6.1. Elbow Method

To use the elbow method, we should first calculate the difference between each cluster centroid and all its members. As we group more unrelated members in a cluster, the distance between the centroid and its members goes up, hence the cluster quality decreases.

One way to perform this distance calculation is to use the Sum of Squared Errors. Sum of squared errors or SSE is equal to the sum of squared differences between a centroid and all its members:

public static double sse(Map
    
      clustered, Distance distance) { double sum = 0; for (Map.Entry
     
       entry : clustered.entrySet()) { Centroid centroid = entry.getKey(); for (Record record : entry.getValue()) { double d = distance.calculate(centroid.getCoordinates(), record.getFeatures()); sum += Math.pow(d, 2); } } return sum; }
     
    

Then, we can run the K-Means algorithm for different values of kand calculate the SSE for each of them:

List records = // the dataset; Distance distance = new EuclideanDistance(); List sumOfSquaredErrors = new ArrayList(); for (int k = 2; k <= 16; k++) { Map
    
      clusters = KMeans.fit(records, k, distance, 1000); double sse = Errors.sse(clusters, distance); sumOfSquaredErrors.add(sse); }
    

At the end of the day, it's possible to find an appropriate k by plotting the number of clusters against the SSE:

Usually, as the number of clusters increases, the distance between cluster members decreases. However, we can't choose any arbitrary large values for k, since having multiple clusters with just one member defeats the whole purpose of clustering.

L'idea alla base del metodo del gomito è trovare un valore appropriato per k in modo che l'ESS diminuisca drasticamente attorno a quel valore. Ad esempio, k = 9 può essere un buon candidato qui.

7. Conclusione

In questo tutorial, per prima cosa, abbiamo trattato alcuni importanti concetti di Machine Learning. Poi abbiamo acquisito la meccanica dell'algoritmo di clustering K-Means. Infine, abbiamo scritto una semplice implementazione per K-Means, testato il nostro algoritmo con un set di dati del mondo reale da Last.fm e visualizzato il risultato del clustering in un bel modo grafico.

Come al solito, il codice di esempio è disponibile nel nostro progetto GitHub, quindi assicurati di controllarlo!