import numpy as np
import xarray as xr


class FeatsModel:
    """Apply a pixel-wise estimator to band-stacked raster features.

    Feature rasters are ``(nbands, ny, nx)`` arrays, single-band ``(ny, nx)``
    arrays, or ``xarray.DataArray`` objects wrapping either. Every pixel
    becomes one sample and every band one feature. Any estimator exposing the
    scikit-learn style ``fit`` / ``fit_transform`` / ``transform`` /
    ``predict`` (and optionally ``predict_proba``) methods can therefore be
    plugged in.
    """

    def __init__(self, model=None):
        """Wrap ``model``; default to an untrained ``MLPRegressor``.

        The sklearn import is kept local, so scikit-learn is only needed
        when no model is supplied.
        """
        if model is None:
            from sklearn.neural_network import MLPRegressor
            model = MLPRegressor()
        self.model = model

    @staticmethod
    def _as_array(feats):
        """Return the array behind ``feats``, unwrapping an xarray DataArray."""
        if isinstance(feats, np.ndarray):
            return feats
        if isinstance(feats, xr.DataArray):
            return feats.data
        return np.asarray(feats)

    def _to_vec(self, feats, squeeze=True):
        """Flatten a feature raster into a ``(npixels, nbands)`` sample matrix.

        Parameters
        ----------
        feats : array-like or xarray.DataArray
            ``(nbands, ny, nx)`` stack, or a single ``(ny, nx)`` band.
        squeeze : bool, default True
            When True and there is exactly one band, return a 1-D
            ``(npixels,)`` vector (the layout used for single-output
            targets). Feature matrices should pass False so they stay 2-D.

        Returns
        -------
        array
            ``(npixels, nbands)``, or ``(npixels,)`` when squeezed.

        Raises
        ------
        ValueError
            If ``feats`` is neither 2-D nor 3-D.
        """
        data = self._as_array(feats)
        if data.ndim == 2:
            data = data[np.newaxis, ...]
        if data.ndim != 3:
            raise ValueError(
                "Expected a (nbands, ny, nx) or (ny, nx) array, "
                "got shape {}.".format(data.shape))
        X = data.reshape(data.shape[0], -1).T
        if squeeze and X.shape[1] == 1:
            # Drop only the band axis. np.squeeze would also collapse the
            # pixel axis of a 1-pixel raster.
            X = X[:, 0]
        return X

    def _from_vec(self, X, shape):
        """Inverse of ``_to_vec``: reshape ``(npixels, nbands)`` to ``shape``."""
        feats = X.T.reshape(*shape)
        return feats

    def _to_raster(self, vec, feats_x):
        """Reshape model output onto the spatial grid of ``feats_x``.

        The number of output bands is inferred from the size of ``vec``.
        """
        ny, nx = np.shape(feats_x)[-2:]
        nbands = vec.size // (ny * nx)
        return self._from_vec(vec, (nbands, ny, nx))

    @staticmethod
    def _minmax(preds):
        """Rescale each column of ``preds`` to [0, 1]; constant columns map to 0."""
        shifted = preds - preds.min(axis=0)
        span = shifted.max(axis=0)
        # A constant channel has zero range. Divide it by 1 instead of 0 so
        # it becomes all zeros rather than NaN.
        span = np.where(span == 0, 1, span)
        return shifted / span

    def fit(self, feats_x, feats_y):
        """Fit the wrapped model on per-pixel samples.

        ``feats_x`` supplies the features (always passed as a 2-D matrix).
        ``feats_y`` supplies the targets (1-D when it has a single band).
        Returns ``self``.
        """
        X = self._to_vec(feats_x, squeeze=False)
        y = self._to_vec(feats_y)
        self.model.fit(X, y)
        return self

    def fit_transform(self, feats_x, normalize=True):
        """Fit the model and transform ``feats_x`` in one step.

        Parameters
        ----------
        feats_x : array-like or xarray.DataArray
            ``(nbands, ny, nx)`` or ``(ny, nx)`` feature raster.
        normalize : bool, default True
            Min-max rescale each output channel to [0, 1]. Constant channels
            become 0.

        Returns
        -------
        array
            ``(ncomponents, ny, nx)`` raster of transformed values.
        """
        X = self._to_vec(feats_x, squeeze=False)
        preds = self.model.fit_transform(X)
        if normalize:
            preds = self._minmax(preds)
        return self._to_raster(preds, feats_x)

    def predict(self, feats_x):
        """Run the model on every pixel of ``feats_x``.

        Uses ``model.transform`` when the model has one, otherwise
        ``model.predict``. In the ``predict`` case, ``predict_proba`` is also
        tried on a best-effort basis.

        Returns
        -------
        array or tuple of arrays
            ``preds`` as an ``(nbands, ny, nx)`` raster. If probabilities
            were obtained, returns ``(preds, probs)``, where ``probs`` has one
            band per ``predict_proba`` output column.

        Raises
        ------
        ValueError
            If the model has neither ``transform`` nor ``predict``.
        """
        X = self._to_vec(feats_x, squeeze=False)
        probs = None
        if hasattr(self.model, 'transform'):
            preds = self.model.transform(X)
        elif hasattr(self.model, 'predict'):
            preds = self.model.predict(X)
            try:
                probs = self.model.predict_proba(X)
            except Exception:
                # Best effort: regressors (and some classifiers) provide no
                # probability estimates. Unlike a bare except, this still
                # lets KeyboardInterrupt/SystemExit through.
                probs = None
        else:
            raise ValueError("Model has no transform"
                             " or predict method.")

        preds = self._to_raster(preds, feats_x)
        if probs is None:
            return preds
        return preds, self._to_raster(probs, feats_x)
