# Source code for datafusiontools.machine_learning.support_vector_machine

import os
from dataclasses import dataclass

import matplotlib.pylab as plt
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.svm import SVC, SVR

from .enumeration_classes import ActivationFunctionSVM, GammaList
from .baseclass import BaseClassMachineLearning



@dataclass
class SVM(BaseClassMachineLearning):
    """
    Class of the Support Vector Machine.

    A support vector classifier (``SVC``) is fitted when ``self.classification``
    is truthy, otherwise a support vector regressor (``SVR``) is fitted.

    :param kernel: Kernel type to be used in the algorithm
    :param gamma: Kernel coefficient for 'rbf', 'poly' and 'sigmoid'
    """

    kernel: ActivationFunctionSVM = ActivationFunctionSVM.rbf
    gamma: GammaList = GammaList.scale

    def train_svm(self) -> None:
        """
        Trains the SVM on ``self.training_data`` and ``self.target``.

        Stores the distinct target values in ``self.target_label``, the fitted
        estimator in ``self.model`` and finally calls :meth:`check_score` to
        compute the training score.
        """
        # Sorted labels keep the confusion-matrix axes in a reproducible order
        # (plain set iteration order is not guaranteed to be stable).
        self.target_label = sorted(set(self.target))

        # A regressor returns continuous values, which can neither be compared
        # for exact equality with class labels nor be fed to a confusion
        # matrix, so classification needs the classifier variant.
        estimator = SVC if self.classification else SVR
        self.model = estimator(kernel=self.kernel.value, gamma=self.gamma.value)

        # Fit to data
        self.model.fit(self.training_data, self.target)

        # compute accuracy
        self.check_score()

    def train_classification(self) -> None:
        """
        Trains the model for a classification problem (delegates to
        :meth:`train_svm`).
        """
        self.train_svm()

    def train_regression(self) -> None:
        """
        Trains the model for a regression problem (delegates to
        :meth:`train_svm`).
        """
        self.train_svm()

    def check_score(self) -> None:
        """
        Computes score of training.

        The result is stored in ``self.accuracy`` and printed: the fraction of
        exactly matching predictions when ``self.classification`` is truthy,
        otherwise the root-mean-square error on the training data.
        """
        # Flatten both arrays so a column-vector target of shape (n, 1) is not
        # broadcast against the (n,) prediction into an (n, n) array.
        predicted = np.asarray(self.model.predict(self.training_data)).ravel()
        observed = np.asarray(self.target).ravel()

        if self.classification:
            self.accuracy = int(np.count_nonzero(predicted == observed)) / len(
                observed
            )
            print(f"Accuracy of training: {round(self.accuracy * 100, 2)} %")
        else:
            self.accuracy = float(np.sqrt(((predicted - observed) ** 2).mean()))
            print(f"RMSE of training: {round(self.accuracy, 2)}")

    def predict(self, data: np.ndarray) -> None:
        """
        Predict the values at the data points.

        The result is stored in ``self.prediction``; nothing is returned.

        :param data: dataset with features for prediction
        """
        self.prediction = self.model.predict(data)

    def plot_confusion(self, validation: np.ndarray, output_folder: str = "./") -> None:
        """
        Plots the confusion matrix for the validation dataset.

        Uses ``self.prediction`` and ``self.target_label``, so the model must
        have been trained and :meth:`predict` called beforehand. The matrix is
        printed and the figure is saved as ``confusion_matrix`` inside
        ``output_folder``.

        :param validation: Validation data at the predicted points
        :param output_folder: location where the plot is saved
        """
        # exist_ok avoids the check-then-create race of a separate isdir test
        os.makedirs(output_folder, exist_ok=True)

        confusion = confusion_matrix(
            validation, self.prediction, labels=self.target_label
        )
        print(f"Confusion matrix:\n {confusion}")

        disp = ConfusionMatrixDisplay(
            confusion_matrix=confusion, display_labels=self.target_label
        )
        fig, ax = plt.subplots(figsize=(6, 4))
        try:
            ax.set_position([0.15, 0.15, 0.8, 0.8])
            disp.plot(cmap="binary", ax=ax)
            fig.savefig(os.path.join(output_folder, "confusion_matrix"))
        finally:
            # Always release the figure, even if plotting or saving fails
            plt.close(fig)