web-dev-qa-db-de.com

Speichern Sie den Klassifikator in scikit-learn auf der Festplatte

Wie speichere ich einen trainierten Naive Bayes-Klassifikator bis Datenträger und verwende ihn für Vorhersage Daten?

Ich habe das folgende Beispielprogramm von der Scikit-Learn-Website:

from sklearn import datasets
iris = datasets.load_iris()
from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()
y_pred = gnb.fit(iris.data, iris.target).predict(iris.data)
print "Number of mislabeled points : %d" % (iris.target != y_pred).sum()
173
garak

Klassifikatoren sind nur Objekte, die wie jedes andere auch eingelegt und entsorgt werden können. So setzen Sie Ihr Beispiel fort:

import cPickle
# save the classifier
with open('my_dumped_classifier.pkl', 'wb') as fid:
    cPickle.dump(gnb, fid)    

# load it again
with open('my_dumped_classifier.pkl', 'rb') as fid:
    gnb_loaded = cPickle.load(fid)
182
mwv

Sie können auch joblib.dump und joblib.load verwenden, was beim Umgang mit numerischen Arrays wesentlich effizienter ist als der Standard-Pickler python pickler).

Joblib ist in scikit-learn enthalten:

>>> from sklearn.externals import joblib
>>> from sklearn.datasets import load_digits
>>> from sklearn.linear_model import SGDClassifier

>>> digits = load_digits()
>>> clf = SGDClassifier().fit(digits.data, digits.target)
>>> clf.score(digits.data, digits.target)  # evaluate training error
0.9526989426822482

>>> filename = '/tmp/digits_classifier.joblib.pkl'
>>> _ = joblib.dump(clf, filename, compress=9)

>>> clf2 = joblib.load(filename)
>>> clf2
SGDClassifier(alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0,
       fit_intercept=True, learning_rate='optimal', loss='hinge', n_iter=5,
       n_jobs=1, penalty='l2', power_t=0.5, rho=0.85, seed=0,
       shuffle=False, verbose=0, warm_start=False)
>>> clf2.score(digits.data, digits.target)
0.9526989426822482
196
ogrisel

Was Sie suchen, heißt Modellpersistenz in sklearn Wörtern und ist in Einführung und in Modell) dokumentiert Persistenz Abschnitte.

Sie haben also Ihren Klassifikator initialisiert und lange mit trainiert

clf = some.classifier()
clf.fit(X, y)

Danach haben Sie zwei Möglichkeiten:

1) Mit Gurke

import pickle
# now you can save it to a file
with open('filename.pkl', 'wb') as f:
    pickle.dump(clf, f)

# and later you can load it
with open('filename.pkl', 'rb') as f:
    clf = pickle.load(f)

2) Verwenden von Joblib

from sklearn.externals import joblib
# now you can save it to a file
joblib.dump(clf, 'filename.pkl') 
# and later you can load it
clf = joblib.load('filename.pkl')

Ein weiteres Mal ist es hilfreich, die oben genannten Links zu lesen

93
Salvador Dali

In vielen Fällen, insbesondere bei der Textklassifizierung, reicht es nicht aus, nur den Klassifizierer zu speichern, sondern Sie müssen auch den Vektorisierer speichern, damit Sie Ihre Eingabe in Zukunft vektorisieren können.

import pickle
with open('model.pkl', 'wb') as fout:
  pickle.dump((vectorizer, clf), fout)

zukünftiger Anwendungsfall:

with open('model.pkl', 'rb') as fin:
  vectorizer, clf = pickle.load(fin)

X_new = vectorizer.transform(new_samples)
X_new_preds = clf.predict(X_new)

Bevor Sie den Vectorizer sichern, können Sie die Eigenschaft stop_words_ des Vectorizers löschen, indem Sie:

vectorizer.stop_words_ = None

um das Dumping effizienter zu gestalten. Auch wenn Ihre Klassifizierer-Parameter spärlich sind (wie in den meisten Beispielen für die Textklassifizierung), können Sie die Parameter von dicht in dünn konvertieren, was einen großen Unterschied in Bezug auf den Speicherverbrauch, das Laden und das Ausgeben von Daten darstellt. Sparsify das Modell durch:

clf.sparsify()

Was automatisch für SGDClassifier funktioniert, aber wenn Sie wissen, dass Ihr Modell dünn ist (viele Nullen in clf.coef_), können Sie clf.coef_ manuell konvertieren. in eine csr scipy sparse Matrix durch:

clf.coef_ = scipy.sparse.csr_matrix(clf.coef_)

und dann können Sie es effizienter speichern.

27
Ash

sklearn Schätzer implementieren Methoden, mit denen Sie auf einfache Weise relevante geschulte Eigenschaften eines Schätzers speichern können. Einige Schätzer implementieren __getstate__ Methoden selbst, aber andere, wie die GMM verwenden einfach die Basisimplementierung , die einfach das innere Objektwörterbuch speichert:

def __getstate__(self):
    try:
        state = super(BaseEstimator, self).__getstate__()
    except AttributeError:
        state = self.__dict__.copy()

    if type(self).__module__.startswith('sklearn.'):
        return dict(state.items(), _sklearn_version=__version__)
    else:
        return state

Die empfohlene Methode zum Speichern Ihres Modells auf einem Datenträger ist die Verwendung des Moduls pickle :

from sklearn import datasets
from sklearn.svm import SVC
iris = datasets.load_iris()
X = iris.data[:100, :2]
y = iris.target[:100]
model = SVC()
model.fit(X,y)
import pickle
with open('mymodel','wb') as f:
    pickle.dump(model,f)

Sie sollten jedoch zusätzliche Daten speichern, damit Sie Ihr Modell in Zukunft neu trainieren können, oder schwerwiegende Konsequenzen haben (z. B. das Einschließen in eine alte Version von sklearn) .

Aus der Dokumentation :

Um ein ähnliches Modell mit zukünftigen Versionen von scikit-learn neu zu erstellen, sollten zusätzliche Metadaten entlang des ausgewählten Modells gespeichert werden:

Die Trainingsdaten, z.B. ein Verweis auf einen unveränderlichen Schnappschuss

Der python Quellcode, der zum Generieren des Modells verwendet wird

Die Versionen von scikit-learn und ihre Abhängigkeiten

Die Kreuzvalidierungsbewertung, die anhand der Trainingsdaten erhalten wurde

Dies gilt insbesondere für Ensemble-Schätzer , die auf tree.pyx Modul, das in Cython geschrieben wurde (z. B. IsolationForest), da es eine Kopplung mit der Implementierung herstellt, die nicht garantiert zwischen den Versionen von sklearn stabil ist. In der Vergangenheit gab es inkompatible Änderungen.

Wenn Ihre Modelle sehr groß werden und das Laden zu einem Ärgernis wird, können Sie auch das effizientere joblib verwenden. Aus der Dokumentation:

Im speziellen Fall des Scikits kann es interessanter sein, joblibs Ersatz für pickle (joblib.dump & joblib.load), der effizienter für Objekte ist, die intern große Numpy-Arrays enthalten, wie dies häufig bei angepassten Scikit-Learn-Schätzern der Fall ist, die jedoch nur auf die Festplatte und nicht auf einen String zugreifen können:

4
Sebastian Wozny