#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Example of RBF regression using the built-in k-means and cdist functions.

We "manually" construct the necessary matrices, solve for the weights with
pinv, and evaluate the fitted model on the test set.

Key thing when using: pay attention to the dimensions of things!

@author: drh
"""

import numpy as np
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt


def rbf_design_matrix(centers, points, spread):
    """Return the augmented RBF feature matrix for *points*.

    Parameters
    ----------
    centers : array of shape (num_centers, dim)
        RBF centers.
    points : array of shape (num_points, dim)
        Samples to evaluate the basis functions at.
    spread : float
        Factor multiplying the squared distance inside the exponential.

    Returns
    -------
    array of shape (num_centers + 1, num_points)
        Rows 0..num_centers-1 hold exp(-spread * d**2), where d is the
        Euclidean distance from each center to each point.  The last row is
        all ones so that the bias can be solved for together with the
        weights.
    """
    dists = cdist(centers, points, 'euclidean')  # num_centers x num_points
    phi = np.exp(-spread * (dists ** 2))
    # Append a row of ones (not zeros): it multiplies the bias term b.
    return np.vstack((phi, np.ones((1, phi.shape[1]))))


#%%
# Step 1: Get or create the data.
#
numpts = 150
X = np.random.uniform(0., 1., numpts)
X = np.sort(X, axis=0)
noise = np.random.uniform(-0.1, 0.1, numpts)
y = np.sin(3 * np.pi * X) + noise

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

#%%
# Step 2: k-means on the data to get the centers
# (for now we won't worry about the std)
#
# KMeans expects 2-D input, so the 1-D samples become a single column.
kmeans = KMeans(n_clusters=4, random_state=0).fit(X_train[:, np.newaxis])
C = kmeans.cluster_centers_  # C is num_centers x dim

#%%
# Step 3: Build the RBF matrices and solve for the weights and biases.
#
# Optional, this is for the spread of the Gaussian.
# NOTE(review): the basis is exp(-spr * D**2), i.e. spr multiplies the squared
# distance directly.  If the intent was a Gaussian that equals 0.5 at unit
# distance, that would be exp(-(spr * D)**2) instead -- confirm.
spr = np.sqrt(-np.log(0.5))

# Phi_hat is (num_centers + 1) x num_train
Phi_hat = rbf_design_matrix(C, X_train[:, np.newaxis], spr)

# Solve W_hat @ Phi_hat ~= y_train (as a row vector) via the pseudo-inverse.
# W_hat is [ W b ], shape 1 x (num_centers + 1)
W_hat = y_train[:, np.newaxis].T @ np.linalg.pinv(Phi_hat)

#%%
# Step 4: Predict output using the RBF on the test set.
#
# Same feature construction as in Step 3, now evaluated at the test points.
Phi_hat = rbf_design_matrix(C, X_test[:, np.newaxis], spr)

# y_out is 1 x num_test
y_out = W_hat @ Phi_hat

#%%
# Step 5: Show results (by plotting if possible)
plt.plot(X, y, '-o', label='Orig Data')
plt.plot(X_test, y_out.T, 'o', label='RBF on Test')
plt.legend()
plt.tight_layout()
plt.show()