Machine Learning for the masses with AWS Deep Java Library and Spring Boot
Last year, as every year, I attended the SpringOne conference, although this time it was virtual. There I saw one of the best presentations on AI with Spring and Java I have ever seen: "Integrate Machine Learning into your Spring Application in less than an hour".
This post is based on that presentation. The basic idea is to easily integrate Spring Boot (the most popular Java framework) with Machine Learning, natively in Java, using Deep Java Library (DJL), a library created by AWS. There are few options on the market for doing that.
Since ML is one of the hardest topics in programming, the idea is to put Machine Learning in the hands of all developers.
ML is a top priority for many organizations. In the past, engineering and data science teams led the way: they would produce a neural network and then try to push it to the business to find a use case (computing resources were also limited).
Nowadays the business drives the use cases, such as fraud detection, sentiment analysis and natural language processing, and computing resources are almost unlimited.
What are the common challenges?
- Skills gap - not enough people can build ML models.
- ML model building is a time-consuming and complex process.
- Finding the right business use cases that could benefit from ML.
AWS organizes its machine learning offering in three layers. At the top are the AI services, such as Amazon Rekognition for vision or Amazon Translate for text. Anyone can use these services, because they are already-trained models. The middle layer is for machine learning developers and data scientists, people who know what they're doing; it makes common activities easier (for example, in the SageMaker IDE you can debug, run experiments, etc.).
The bottom layer is for experts: the ML frameworks and infrastructure are for people who want full control over their application code, their application logic and their infrastructure.
At this layer, Amazon supports three deep learning frameworks, TensorFlow, MXNet and PyTorch, and contributes to these open source libraries. In the same spirit, it created Deep Java Library (DJL).
What is Deep Java Library? It is an open-source, high-level, engine-agnostic Java framework for deep learning. DJL is designed to be easy to get started with and simple to use for Java developers: you don't have to be a machine learning or deep learning expert to get started.
DJL covers most ML features, including training and inference. It also offers multithreading support and memory control and, of course, supports the frameworks mentioned above (TensorFlow, MXNet and PyTorch).
DJL also provides a ModelZoo with more than 70 pre-trained models out of the box. Some of these models come from Keras and Gluon, other open source libraries that Amazon already contributes to.
Because it is engine-agnostic, DJL can, for example, run inference on a pre-trained model with PyTorch and then run exactly the same test with TensorFlow, changing only the dependency in your project.
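Roughly speaking, DJL achieves this by letting your code depend only on its engine-agnostic API, while the concrete engine is discovered at runtime from the engine artifact on the classpath. The toy sketch below illustrates that design in plain Java, without DJL: `InferenceEngine`, the registry and the two fake engines are hypothetical names, not DJL classes, and both "engines" compute the same trivial model (y = 2x + 1), so the caller gets identical results whichever one it selects.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

public class EngineAgnosticSketch {

    // Hypothetical engine abstraction: application code only ever sees this interface.
    interface InferenceEngine {
        String name();
        float[] predict(float[] input);
    }

    // Two fake "backends" computing the same toy model (y = 2x + 1) in slightly different ways.
    static final class FakePyTorchEngine implements InferenceEngine {
        public String name() { return "PyTorch"; }
        public float[] predict(float[] input) {
            float[] out = new float[input.length];
            for (int i = 0; i < input.length; i++) {
                out[i] = 2 * input[i] + 1;
            }
            return out;
        }
    }

    static final class FakeTensorFlowEngine implements InferenceEngine {
        public String name() { return "TensorFlow"; }
        public float[] predict(float[] input) {
            float[] out = input.clone(); // leave the caller's array untouched
            for (int i = 0; i < out.length; i++) {
                out[i] = out[i] * 2 + 1;
            }
            return out;
        }
    }

    // Hypothetical registry; in DJL the available engines come from the classpath instead.
    static final Map<String, InferenceEngine> ENGINES = new LinkedHashMap<>();
    static {
        register(new FakePyTorchEngine());
        register(new FakeTensorFlowEngine());
    }

    static void register(InferenceEngine engine) {
        ENGINES.put(engine.name(), engine);
    }

    static InferenceEngine getEngine(String name) {
        InferenceEngine engine = ENGINES.get(name);
        if (engine == null) {
            throw new IllegalArgumentException("No engine registered under " + name);
        }
        return engine;
    }

    // Application code: identical no matter which engine name is configured.
    static float[] runInference(String engineName, float[] input) {
        return getEngine(engineName).predict(input);
    }

    public static void main(String[] args) {
        float[] input = {1f, 2f, 3f};
        for (String name : ENGINES.keySet()) {
            System.out.println(name + " -> " + Arrays.toString(runInference(name, input)));
        }
    }
}
```

In DJL itself you do not write the registry: adding the PyTorch or the TensorFlow engine dependency to the build is what makes that engine available.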
Why should you use DJL?
- Many enterprises already use Java.
- Existing Java APIs for deep learning are painful to use.
- Deploying ML in Java is challenging and hard to maintain.
- The Java community lacks a deep learning standard.
- It is very easy to use: you can set up a model and run inference in only 10 lines of code.
- It has minimal dependency requirements.
- It runs fast: up to a 2x performance boost on small-model inference. Amazon Ads uses it for query understanding and image classification.
- It can be used at large scale: at Amazon it has been used to process an 800M-inference load on Apache Spark in 6 hours.
- It is stable: also at Amazon, it passed a benchmark of more than 100 hours of continuous inference calls.
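To put the Spark figure in perspective, a back-of-the-envelope calculation, assuming the 800M inferences were spread evenly over the 6 hours, gives a sustained throughput of roughly 37,000 inferences per second:

```latex
\frac{800 \times 10^{6}\ \text{inferences}}{6 \times 3600\ \text{s}}
  = \frac{8 \times 10^{8}\ \text{inferences}}{21\,600\ \text{s}}
  \approx 37\,037\ \text{inferences/s}
```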
We will now run an image recognition microservice that demonstrates a common deep learning use case. You will find all the implementation details on the demo page. All the modules are built with Gradle; here are the main points:
In the "djl-spring-boot-app" module, you need to define which deep learning engine you will use; in our case, PyTorch.
In the InferenceConfiguration class, we tell the application to use "Object_Detection" as the inference type.
The same InferenceConfiguration class also defines the Predictor, the DJL object that runs inference on the loaded model.
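One detail worth keeping in mind when wiring a Predictor into a Spring bean: DJL's documentation states that a Predictor is not designed to be thread-safe, while a Spring web app serves requests on many threads. One common workaround is one predictor per thread. The sketch below shows that pattern in plain Java; `ToyPredictor` and `PerThreadProvider` are hypothetical names standing in for DJL's Predictor (with DJL you would create real instances with `ZooModel.newPredictor()`), and this is not necessarily how the demo's InferenceConfiguration does it.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

public class PerThreadPredictorSketch {

    // Hypothetical stand-in for a predictor that must not be shared between threads.
    static final class ToyPredictor implements AutoCloseable {
        private static final AtomicInteger CREATED = new AtomicInteger();
        final int id = CREATED.incrementAndGet();
        private volatile boolean closed;

        String predict(String input) {
            if (closed) {
                throw new IllegalStateException("predictor " + id + " is closed");
            }
            return "predictor-" + id + " processed " + input;
        }

        @Override
        public void close() {
            closed = true;
        }
    }

    // Lazily creates one predictor per thread and tracks them all so they can be closed on shutdown.
    static final class PerThreadProvider<P extends AutoCloseable> implements Supplier<P>, AutoCloseable {
        private final Set<P> created = ConcurrentHashMap.newKeySet();
        private final ThreadLocal<P> perThread;

        PerThreadProvider(Supplier<P> factory) {
            this.perThread = ThreadLocal.withInitial(() -> {
                P predictor = factory.get();
                created.add(predictor);
                return predictor;
            });
        }

        @Override
        public P get() {
            return perThread.get();
        }

        int createdCount() {
            return created.size();
        }

        @Override
        public void close() throws Exception {
            for (P predictor : created) {
                predictor.close();
            }
            created.clear();
        }
    }

    public static void main(String[] args) throws Exception {
        try (PerThreadProvider<ToyPredictor> provider = new PerThreadProvider<>(ToyPredictor::new)) {
            // Each "request" thread gets its own predictor instance.
            Runnable request = () -> System.out.println(provider.get().predict("photo.jpg"));
            Thread t1 = new Thread(request);
            Thread t2 = new Thread(request);
            t1.start();
            t2.start();
            t1.join();
            t2.join();
            System.out.println("predictors created: " + provider.createdCount());
        }
    }
}
```

The trade-off is memory: with a servlet thread pool you end up with one predictor per pool thread, so a small pool of predictors is an alternative when predictors are expensive.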
In the application.properties file of the same module, you need to define the AWS S3 bucket where the app will store the processed images.
You need to set the same value in the application.properties file of the "djl-spring-boot-web" module.
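Because the bucket name has to be set in two places, it is easy for the two modules to drift apart. A minimal sketch of a sanity check, assuming a hypothetical property key `aws.s3.bucket` and a hypothetical bucket name (check the demo's application.properties for the real key); it also sets `server.port`, the standard Spring Boot property, to a different value in each module.

```java
import java.io.IOException;
import java.io.StringReader;
import java.util.Properties;

public class BucketConfigCheck {

    // Hypothetical key name; replace it with the key actually used in the demo's application.properties.
    static final String BUCKET_KEY = "aws.s3.bucket";

    static Properties parse(String content) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(content));
        } catch (IOException e) {
            // Properties.load declares IOException even though a StringReader will not fail here.
            throw new IllegalStateException(e);
        }
        return props;
    }

    // True only when both modules define a non-blank bucket and it is the same one.
    static boolean sameBucket(Properties app, Properties web) {
        String a = app.getProperty(BUCKET_KEY);
        String w = web.getProperty(BUCKET_KEY);
        if (a == null || a.isBlank() || w == null) {
            return false;
        }
        return a.trim().equals(w.trim());
    }

    public static void main(String[] args) {
        // Hypothetical contents of the two application.properties files.
        Properties app = parse("aws.s3.bucket=my-djl-demo-bucket\nserver.port=8081\n");
        Properties web = parse("aws.s3.bucket=my-djl-demo-bucket\nserver.port=8080\n");
        System.out.println("buckets match: " + sameBucket(app, web));
    }
}
```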
Then start both modules, the backend and the web frontend (written in Kotlin), making sure they run on different ports.
Once both apps are running, open the web page to upload your photos:
Let's start with the first example, a picture of Diego Maradona winning the soccer World Cup. As you can see, the model recognizes the people in the picture. Let's try another one.
The model recognizes our hero, Chuck, as a person, but then it makes a mistake and thinks he is holding a baseball glove instead of a machine gun. Nice try. Another one:
Now the model correctly identifies my beloved Maica as a dog. Another one:
In this very interesting picture, the model recognizes a person (who is not standing), a dog (jumping in the air) and a frisbee! Quite an accomplishment!
The last one:
In this one, the model recognizes a truck, various cars and even a person inside the truck, but it mistakes some of the cars for an airplane :-)
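Mistakes like these come with the territory: for every box, an object detector returns a class name and a confidence score, and the application decides which boxes to show. A minimal sketch of that post-processing step in plain Java, assuming a hypothetical `Detection` class and made-up scores (in DJL the corresponding output type is `DetectedObjects`). The real scores for the pictures above are not known here, so whether a threshold would hide a given mistake depends on them.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;

public class DetectionFilterSketch {

    // Hypothetical detection: class label, confidence in [0, 1] and a bounding box in relative coordinates.
    static final class Detection {
        final String className;
        final double probability;
        final double x, y, width, height;

        Detection(String className, double probability, double x, double y, double width, double height) {
            this.className = className;
            this.probability = probability;
            this.x = x;
            this.y = y;
            this.width = width;
            this.height = height;
        }

        @Override
        public String toString() {
            return String.format(Locale.ROOT, "%s (%.2f)", className, probability);
        }
    }

    // Keeps detections at or above the threshold, most confident first.
    static List<Detection> filter(List<Detection> detections, double threshold) {
        List<Detection> kept = new ArrayList<>();
        for (Detection d : detections) {
            if (d.probability >= threshold) {
                kept.add(d);
            }
        }
        kept.sort(Comparator.comparingDouble((Detection d) -> d.probability).reversed());
        return kept;
    }

    public static void main(String[] args) {
        // Made-up scores for illustration only.
        List<Detection> raw = List.of(
                new Detection("bus", 0.93, 0.10, 0.20, 0.50, 0.40),
                new Detection("bicycle", 0.64, 0.60, 0.50, 0.20, 0.15),
                new Detection("kite", 0.31, 0.70, 0.05, 0.10, 0.10));
        System.out.println(filter(raw, 0.5));
    }
}
```

Raising the threshold trades recall for precision: fewer wrong labels, but also more real objects left unlabelled.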
That's it. Remember that we used an image recognition microservice, but it can easily be switched to another kind of service. These are all the options:
- for computer vision (Application.CV):
Application IMAGE_CLASSIFICATION = new Application("cv/image_classification");
Application OBJECT_DETECTION = new Application("cv/object_detection");
Application SEMANTIC_SEGMENTATION = new Application("cv/semantic_segmentation");
Application INSTANCE_SEGMENTATION = new Application("cv/instance_segmentation");
Application POSE_ESTIMATION = new Application("cv/pose_estimation");
Application ACTION_RECOGNITION = new Application("cv/action_recognition");
- for natural language processing (Application.NLP):
Application QUESTION_ANSWER = new Application("nlp/question_answer");
Application TEXT_CLASSIFICATION = new Application("nlp/text_classification");
Application SENTIMENT_ANALYSIS = new Application("nlp/sentiment_analysis");
Application WORD_EMBEDDING = new Application("nlp/word_embedding");
Application MACHINE_TRANSLATION = new Application("nlp/machine_translation");
Application MULTIPLE_CHOICE = new Application("nlp/multiple_choice");
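Since each application type is identified by a path string, one way to keep the microservice easy to switch is to read that string from configuration and validate it at startup. A minimal sketch in plain Java, assuming a hypothetical configured value such as "Object_Detection" or "cv/object_detection"; with DJL you would then pass the matching `Application` constant to the model `Criteria` through `optApplication`.

```java
import java.util.List;
import java.util.Locale;
import java.util.Optional;

public class ApplicationPathSketch {

    // The application paths listed above (computer vision and natural language processing).
    static final List<String> KNOWN_PATHS = List.of(
            "cv/image_classification", "cv/object_detection", "cv/semantic_segmentation",
            "cv/instance_segmentation", "cv/pose_estimation", "cv/action_recognition",
            "nlp/question_answer", "nlp/text_classification", "nlp/sentiment_analysis",
            "nlp/word_embedding", "nlp/machine_translation", "nlp/multiple_choice");

    // Accepts either a full path ("cv/object_detection") or just the task name ("Object_Detection").
    static Optional<String> resolve(String configured) {
        String wanted = configured.trim().toLowerCase(Locale.ROOT);
        for (String path : KNOWN_PATHS) {
            String task = path.substring(path.indexOf('/') + 1);
            if (path.equals(wanted) || task.equals(wanted)) {
                return Optional.of(path);
            }
        }
        return Optional.empty();
    }

    // The top-level domain of a path: "cv" or "nlp".
    static String domain(String path) {
        return path.substring(0, path.indexOf('/'));
    }

    public static void main(String[] args) {
        // "cv/face_swap" is a deliberately unsupported, hypothetical value.
        for (String value : List.of("Object_Detection", "nlp/sentiment_analysis", "cv/face_swap")) {
            String result = resolve(value).map(p -> p + " [" + domain(p) + "]").orElse("unsupported");
            System.out.println(value + " -> " + result);
        }
    }
}
```

Failing fast on an unknown value at startup is usually friendlier than discovering the typo on the first request.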