機械学習済みモデルのスマートフォンでの利用(H2O→android)

機械学習は多量のデータと計算を必要とするので、これをPCやクラウドで済ませ、学修済みのモデルをスマートフォンで活用したいというケースがあります。

普段、よく用いているとH2O(R言語上)からjavaへのexportが可能でしたのでメモします。javaなのでandroidで利用できるというわけです。H2Oのドキュメントを参照しています。

step1. R + H2Oでの学習(deep learning)

ここでは典型的な学習例として「iris」データを用いた例を示します。

library(h2o)
localH2O <- h2o.init()

datanum <- nrow(iris)
trainnum <- datanum / 2

trainid <- sample(1:datanum, trainnum)

iris.train <- as.h2o(iris[trainid,])
iris.test <- as.h2o(iris[-trainid,])

res.dl <- h2o.deeplearning(x = 1:4, y = 5, training_frame = iris.train, validation = iris.test, activation = "TanhWithDropout")

pred.dl <- predict(res.dl, iris.test)
correct.dl <- ifelse(pred.dl[,1]==iris.test[,5], 1, 0)
rate.dl <- sum(correct.dl) / nrow(correct.dl)

rate.dl

h2o.download_pojo(res.dl)

最後に表示される内容をファイルに保存するか、http://localhost:54321にアクセスし、該当モデルをexport pojoします。

次に、「http://localhost:54321/3/h2o-genmodel.jar」にアクセスし、ファイルへ保存します。

step2. androidアプリでの利用

上でexportしたjavaファイルを、android studioで該当プロジェクトのjavaフォルダー内のパッケージへdropします。そしてandroid studio上でこのファイルを開き、先頭にpackageを追加します。これは生成済みの別ファイル(例えばMainActivity)と同じ内容にすればOKです。

次に、上でダウンロードしたjarファイルをapp/libsフォルダへコピーしておきます。それから、Build→Edit Libraries & ..より、jar dependencyとして追加します。

最後の利用例を示します。この例ではGUIを無視して、起動後に処理した結果をconsole出力しています。

package ...;

import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.*;

public class MainActivity extends AppCompatActivity {

    private static String modelClassName = "DeepLearning_model_R_1507166311509_2";

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        hex.genmodel.GenModel rawModel;
        try {
            rawModel = new DeepLearning_model_R_1507166311509_2();

            // rawModel = (hex.genmodel.GenModel) Class.forName("DeepLearning_model_R_1507166311509_2").newInstance();
            EasyPredictModelWrapper model = new EasyPredictModelWrapper(rawModel);

            RowData row = new RowData();
            row.put("Sepal.Length", 5.1);
            row.put("Sepal.Width", 3.5);
            row.put("Petal.Length", 1.4);
            row.put("Petal.Width", 0.2);

            // BinomialModelPrediction p = model.predictBinomial(row);
            MultinomialModelPrediction p = model.predictMultinomial(row);
            System.out.println("Label (aka prediction) is flight departure delayed: " + p.label);
            System.out.print("Class probabilities: ");
            for (int i = 0; i < p.classProbabilities.length; i++) {
                if (i > 0) {
                    System.out.print(",");
                }
                System.out.print(p.classProbabilities[i]);
            }
            System.out.println("");
        } catch (PredictException e) {
            e.printStackTrace();
        }

    }
}

 

基本的にはサンプルをそのまま使っていますが、2点ほど変更しています。コメントアウトしている部分です。

インスタンスの取得部分はどうしても実行時エラーになるため、上のように通常のインスタンス生成の手順に入れ替えました。

この例ではBinominalだとうまくいかないので、ここも変更しています。

 

まとめ

普段、データ処理や他の手法との比較が容易であることから、R上からH2Oを使っているので、その延長線上でそのまま利用できるのはとてもメリットが大きいです。

 

タグ: , ,

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

*