Source code for datafusiontools.machine_learning.baseclass

from dataclasses import dataclass
from typing import List, Union
from abc import abstractmethod
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from datafusiontools._core.base_class import BaseClass


@dataclass
class BaseClassMachineLearning(BaseClass):
    """
    Common base for machine-learning models that can be trained either as a
    classifier or as a regressor.

    Concrete classes implement :meth:`train_classification`,
    :meth:`train_regression` and :meth:`predict`; :meth:`train` selects the
    training routine from the ``classification`` flag.
    """

    # True -> train() calls train_classification(); False -> train_regression()
    classification: bool
    # feature data stored by train()
    training_data: Union[List, None, np.ndarray] = None
    # target values stored by train()
    target: Union[List, None, np.ndarray] = None
    # not read or written in this class
    # NOTE(review): presumably filled by concrete classes - confirm
    target_label: Union[List, None, np.ndarray] = None
    # values read by plot_fitted_line(); not assigned anywhere in this class
    prediction: Union[List, None, np.ndarray] = None

    @abstractmethod
    def train_classification(self):
        """
        Trains the model as a classifier. To be implemented in concrete classes.
        """
        raise NotImplementedError(
            "The method should be implemented in concrete classes."
        )

    @abstractmethod
    def train_regression(self):
        """
        Trains the model as a regressor. To be implemented in concrete classes.
        """
        raise NotImplementedError(
            "The method should be implemented in concrete classes."
        )

    def train(self, data: np.ndarray, target: np.ndarray) -> None:
        """
        Trains the NN with the data and multiple class target values based on
        the model selected (classification or regression).

        The inputs are stored on the instance (``training_data`` and
        ``target``) before the selected training routine is called.

        :param data: data features
        :param target: multiple class target value
        """
        self.training_data = data
        self.target = target
        if self.classification:
            self.train_classification()
        else:
            self.train_regression()

    @abstractmethod
    def predict(self):
        """
        Computes the model prediction. To be implemented in concrete classes.
        """
        raise NotImplementedError(
            "The method should be implemented in concrete classes."
        )

    def plot_fitted_line(
        self, validation_target: np.ndarray, output_folder: Path = Path("./")
    ) -> None:
        """
        Plots the prediction against the validation target together with the
        identity line, and saves the figure as ``fitted_line.png``.

        :param validation_target: actual values the prediction is compared to
        :param output_folder: location where the plot is saved (created if it
            does not exist)
        :raises ValueError: if ``self.prediction`` is None
        """
        if self.prediction is None:
            raise ValueError(
                "prediction is None; a prediction must be available before plotting."
            )

        # prediction / target may be plain lists, so coerce before flattening
        predicted = np.asarray(self.prediction).ravel()
        actual = np.asarray(validation_target).ravel()

        # accept str as well as Path
        output_folder = Path(output_folder)
        output_folder.mkdir(parents=True, exist_ok=True)

        # common range of both data sets, so both axes share the same limits
        max_all = max(np.amax(predicted), np.amax(actual))
        min_all = min(np.amin(predicted), np.amin(actual))

        fig, ax = plt.subplots(figsize=(6, 4))
        try:
            ax.set_position([0.15, 0.15, 0.8, 0.8])
            ax.plot(predicted, actual, "ro")
            # identity line: points on it are perfectly predicted
            ax.plot(
                [min_all - 0.2, max_all + 0.2], [min_all - 0.2, max_all + 0.2], "-"
            )
            ax.set_xlim([min_all - 0.1, max_all + 0.1])
            ax.set_ylim([min_all - 0.1, max_all + 0.1])
            ax.set_xlabel("prediction")
            ax.set_ylabel("actual data")
            ax.grid()
            fig.savefig(str(Path(output_folder, "fitted_line.png")))
        finally:
            # close this specific figure even when plotting or saving fails
            plt.close(fig)