web-dev-qa-db-fra.com

Comment puis-je appeler des classificateurs scikit-learn depuis Java?

J'ai un classificateur que j'ai formé à l'aide de scikit-learn de Python. Comment puis-je utiliser le classificateur à partir d'un programme Java? Puis-je utiliser Jython? Existe-t-il un moyen d'enregistrer le classificateur dans Python et de le charger en Java?) Existe-t-il une autre manière de l'utiliser?

30
Thomas Johnson

Vous ne pouvez pas utiliser jython car scikit-learn s'appuie fortement sur numpy et scipy qui ont de nombreuses extensions C et Fortran compilées et ne peuvent donc pas fonctionner en jython.

La façon la plus simple d'utiliser scikit-learn dans un environnement Java serait de:

  • exposer le classificateur en tant que service HTTP/Json, par exemple en utilisant un microframework tel que flask ou bottle ou cornice et l'appeler depuis Java utilisant une bibliothèque client HTTP

  • écrire une application d'encapsuleur de ligne de commande dans python qui lit les données sur stdin et les prédictions de sortie sur stdout en utilisant un format tel que CSV ou JSON (ou une représentation binaire de niveau inférieur) et appelez le python programme de Java par exemple en utilisant Apache Commons Exec .

  • faire sortir le programme python les paramètres numériques bruts appris au moment de l'ajustement (généralement sous forme de tableau de valeurs à virgule flottante) et réimplémenter la fonction de prédiction dans Java (ce est généralement facile pour les modèles linéaires prédictifs où la prédiction est souvent juste un produit scalaire seuil).

La dernière approche nécessitera beaucoup plus de travail si vous devez réimplémenter l'extraction de fonctionnalités dans Java également.

Enfin, vous pouvez utiliser une bibliothèque Java telle que Weka ou Mahout qui implémente les algorithmes dont vous avez besoin au lieu d'essayer d'utiliser scikit-learn à partir de Java.

45
ogrisel

Il y a JPMML projet à cet effet.

Tout d'abord, vous pouvez sérialiser le modèle scikit-learn en PMML (qui est XML en interne) à l'aide de la bibliothèque sklearn2pmml directement depuis python ou le vider dans python d'abord et convertir en utilisant jpmml-sklearn dans Java ou à partir d'une ligne de commande fournie par cette bibliothèque. Ensuite, vous pouvez charger le fichier pmml, désérialiser et exécuter chargé modèle utilisant jpmml-evaluator dans votre Java.

De cette façon, cela ne fonctionne pas avec tous les modèles scikit-learn, mais avec many d'entre eux.

16
Dmitry Spikhalskiy

Vous pouvez soit utiliser un porteur, j'ai testé le sklearn-porter ( https://github.com/nok/sklearn-porter ), et cela fonctionne bien pour Java.

Mon code est le suivant:

import pandas as pd
from sklearn import tree
from sklearn_porter import Porter

train_dataset = pd.read_csv('./result2.csv').as_matrix()

X_train = train_dataset[:90, :8]
Y_train = train_dataset[:90, 8:]

X_test = train_dataset[90:, :8]
Y_test = train_dataset[90:, 8:]

print X_train.shape
print Y_train.shape


clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)

porter = Porter(clf, language='Java')
output = porter.export(embed_data=True)
print(output)

Dans mon cas, j'utilise un DecisionTreeClassifier, et la sortie de

impression (sortie)

est le code suivant sous forme de texte dans la console:

class DecisionTreeClassifier {

  private static int findMax(int[] nums) {
    int index = 0;
    for (int i = 0; i < nums.length; i++) {
        index = nums[i] > nums[index] ? i : index;
    }
    return index;
  }


  public static int predict(double[] features) {
    int[] classes = new int[2];

    if (features[5] <= 51.5) {
        if (features[6] <= 21.0) {

            // HUGE amount of ifs..........

        }
    }

    return findMax(classes);
  }

  public static void main(String[] args) {
    if (args.length == 8) {

        // Features:
        double[] features = new double[args.length];
        for (int i = 0, l = args.length; i < l; i++) {
            features[i] = Double.parseDouble(args[i]);
        }

        // Prediction:
        int prediction = DecisionTreeClassifier.predict(features);
        System.out.println(prediction);

    }
  }
}
4
gustavoresque

Voici du code pour la solution JPMML:

--PARTIE PYTHON--

# helper function to determine the string columns which have to be one-hot-encoded in order to apply an estimator.
def determine_categorical_columns(df):
    categorical_columns = []
    x = 0
    for col in df.dtypes:
        if col == 'object':
            val = df[df.columns[x]].iloc[0]
            if not isinstance(val,Decimal):
                categorical_columns.append(df.columns[x])
        x += 1
    return categorical_columns

categorical_columns = determine_categorical_columns(df)
other_columns = list(set(df.columns).difference(categorical_columns))


#construction of transformators for our example
labelBinarizers = [(d, LabelBinarizer()) for d in categorical_columns]
nones = [(d, None) for d in other_columns]
transformators = labelBinarizers+nones

mapper = DataFrameMapper(transformators,df_out=True)
gbc = GradientBoostingClassifier()

#construction of the pipeline
lm = PMMLPipeline([
    ("mapper", mapper),
    ("estimator", gbc)
])

--Pièce Java -

//Initialisation.
String pmmlFile = "ScikitLearnNew.pmml";
PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(new FileInputStream(pmmlFile));
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
MiningModelEvaluator evaluator = (MiningModelEvaluator) modelEvaluatorFactory.newModelEvaluator(pmml);

//Determine which features are required as input
HashMap<String, Field>() inputFieldMap = new HashMap<String, Field>();
for (int i = 0; i < evaluator.getInputFields().size();i++) {
  InputField curInputField = evaluator.getInputFields().get(i);
  String fieldName = curInputField.getName().getValue();
  inputFieldMap.put(fieldName.toLowerCase(),curInputField.getField());
}


//prediction

HashMap<String,String> argsMap = new HashMap<String,String>();
//... fill argsMap with input

Map<FieldName, ?> res;
// here we keep only features that are required by the model
Map<FieldName,String> args = new HashMap<FieldName, String>();
Iterator<String> iter = argsMap.keySet().iterator();
while (iter.hasNext()) {
  String key = iter.next();
  Field f = inputFieldMap.get(key);
  if (f != null) {
    FieldName name =f.getName();
    String value = argsMap.get(key);
    args.put(name, value);
  }
}
//the model is applied to input, a probability distribution is obtained
res = evaluator.evaluate(args);
SegmentResult segmentResult = (SegmentResult) res;
Object targetValue = segmentResult.getTargetValue();
ProbabilityDistribution probabilityDistribution = (ProbabilityDistribution) targetValue;
2
Volokh

Je me suis retrouvé dans une situation similaire. Je recommanderai de créer un microservice de classificateur. Vous pourriez avoir un microservice de classificateur qui s'exécute en python puis exposer les appels à ce service sur une API RESTFul produisant un format d'échange de données JSON/XML. Je pense que c'est une approche plus propre.

1
theyCallMeJun

Alternativement, vous pouvez simplement générer un code Python à partir d'un modèle formé. Voici un outil qui peut vous aider avec cela https://github.com/BayesWitnesses/m2cgen

0
Viktor Ershov